fffiloni committed
Commit 669715f · verified · 1 Parent(s): cb886e8

Update app.py

Files changed (1)
  1. app.py +195 -4
app.py CHANGED
@@ -6,14 +6,205 @@ from pathlib import Path
 stable_cascade_path = Path(__file__).parent / "third_party" / "StableCascade"
 sys.path.append(str(stable_cascade_path))
 
-import gradio as gr
+import yaml
+import torch
+from tqdm import tqdm
+from accelerate.utils import set_module_tensor_to_device
+import torch.nn.functional as F
+import torchvision.transforms as T
+from lang_sam import LangSAM
 from inference.utils import *
+from core.utils import load_or_fail
+from train import WurstCoreC, WurstCoreB
+from gdf_rbm import RBM
+from stage_c_rbm import StageCRBM
+from utils import WurstCoreCRBM
+from gdf.schedulers import CosineSchedule
+from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
+from gdf.targets import EpsilonTarget
+
+# Device configuration
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print(device)
+
+# Flag for low VRAM usage
+low_vram = False
+
+# Function definition for low VRAM usage
+if low_vram:
+    def models_to(model, device="cpu", excepts=None):
+        """
+        Change the device of nn.Modules within a class, skipping specified attributes.
+        """
+        for attr_name in dir(model):
+            if attr_name.startswith('__') and attr_name.endswith('__'):
+                continue # skip special attributes
+
+            attr_value = getattr(model, attr_name, None)
+
+            if isinstance(attr_value, torch.nn.Module):
+                if excepts and attr_name in excepts:
+                    print(f"Except '{attr_name}'")
+                    continue
+                print(f"Change device of '{attr_name}' to {device}")
+                attr_value.to(device)
+
+        torch.cuda.empty_cache()
+
+# Stage C model configuration
+config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
+with open(config_file, "r", encoding="utf-8") as file:
+    loaded_config = yaml.safe_load(file)
+
+core = WurstCoreCRBM(config_dict=loaded_config, device=device, training=False)
+
+# Stage B model configuration
+config_file_b = 'third_party/StableCascade/configs/inference/stage_b_3b.yaml'
+with open(config_file_b, "r", encoding="utf-8") as file:
+    config_file_b = yaml.safe_load(file)
+
+core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
+
+# Setup extras and models for Stage C
+extras = core.setup_extras_pre()
+
+gdf_rbm = RBM(
+    schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+    input_scaler=VPScaler(), target=EpsilonTarget(),
+    noise_cond=CosineTNoiseCond(),
+    loss_weight=AdaptiveLossWeight(),
+)
+
+sampling_configs = {
+    "cfg": 5,
+    "sampler": DDPMSampler(gdf_rbm),
+    "shift": 1,
+    "timesteps": 20
+}
+
+extras = core.Extras(
+    gdf=gdf_rbm,
+    sampling_configs=sampling_configs,
+    transforms=extras.transforms,
+    effnet_preprocess=extras.effnet_preprocess,
+    clip_preprocess=extras.clip_preprocess
+)
+
+models = core.setup_models(extras)
+models.generator.eval().requires_grad_(False)
+
+# Setup extras and models for Stage B
+extras_b = core_b.setup_extras_pre()
+models_b = core_b.setup_models(extras_b, skip_clip=True)
+models_b = WurstCoreB.Models(
+    **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
+)
+models_b.generator.bfloat16().eval().requires_grad_(False)
+
+# Off-load old generator (low VRAM mode)
+if low_vram:
+    models.generator.to("cpu")
+    torch.cuda.empty_cache()
+
+# Load and configure new generator
+generator_rbm = StageCRBM()
+for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
+    set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
 
-def infer(un, deux, trois):
-    return "quatre"
+generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
+generator_rbm = core.load_model(generator_rbm, 'generator')
+
+# Create models_rbm instance
+models_rbm = core.Models(
+    effnet=models.effnet,
+    previewer=models.previewer,
+    generator=generator_rbm,
+    generator_ema=models.generator_ema,
+    tokenizer=models.tokenizer,
+    text_model=models.text_model,
+    image_model=models.image_model
+)
+models_rbm.generator.eval().requires_grad_(False)
+
+def infer(style_description, ref_style_file, caption):
+
+    height=1024
+    width=1024
+    batch_size=1
+    output_file='output.png'
+
+    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+
+    extras.sampling_configs['cfg'] = 4
+    extras.sampling_configs['shift'] = 2
+    extras.sampling_configs['timesteps'] = 20
+    extras.sampling_configs['t_start'] = 1.0
+
+    extras_b.sampling_configs['cfg'] = 1.1
+    extras_b.sampling_configs['shift'] = 1
+    extras_b.sampling_configs['timesteps'] = 10
+    extras_b.sampling_configs['t_start'] = 1.0
+
+    ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
+
+    batch = {'captions': [caption] * batch_size}
+    batch['style'] = ref_style
+
+    x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
+
+    conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
+    unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+    conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+    unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+
+    if low_vram:
+        # The sampling process uses more vram, so we offload everything except two modules to the cpu.
+        models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
+
+    # Stage C reverse process.
+    sampling_c = extras.gdf.sample(
+        models_rbm.generator, conditions, stage_c_latent_shape,
+        unconditions, device=device,
+        **extras.sampling_configs,
+        x0_style_forward=x0_style_forward,
+        apply_pushforward=False, tau_pushforward=8,
+        num_iter=3, eta=0.1, tau=20, eval_csd=True,
+        extras=extras, models=models_rbm,
+        lam_style=1, lam_txt_alignment=1.0,
+        use_ddim_sampler=True,
+    )
+    for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
+        sampled_c = sampled_c
+
+    # Stage B reverse process.
+    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
+        conditions_b['effnet'] = sampled_c
+        unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+
+        sampling_b = extras_b.gdf.sample(
+            models_b.generator, conditions_b, stage_b_latent_shape,
+            unconditions_b, device=device, **extras_b.sampling_configs,
+        )
+        for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
+            sampled_b = sampled_b
+        sampled = models_b.stage_a.decode(sampled_b).float()
+
+    sampled = torch.cat([
+        torch.nn.functional.interpolate(ref_style.cpu(), size=height),
+        sampled.cpu(),
+    ],
+    dim=0)
+
+    # Save the sampled image to a file
+    sampled_image = T.ToPILImage()(sampled.squeeze(0)) # Convert tensor to PIL image
+    sampled_image.save(output_file) # Save the image
+
+    return output_file # Return the path to the saved image
+
+import gradio as gr
 
 gr.Interface(
     fn = infer,
     inputs=[gr.Textbox(label="style description"), gr.Image(label="Ref Style File", type="filepath"), gr.Textbox(label="caption")],
-    outputs=[gr.Textbox()]
+    outputs=[gr.Image()]
 ).launch()
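
For quick reference, a minimal sketch of how the three Gradio inputs map onto the new infer signature when called directly in a session where app.py's model setup has already run; the prompt text and the style_ref.png path are placeholders, not part of this commit:

# Hypothetical direct call mirroring the Gradio wiring above
# (placeholder prompts and image path; assumes the Stage B/C models in app.py are loaded).
result_path = infer(
    style_description="impressionist oil painting, thick brushstrokes",
    ref_style_file="style_ref.png",  # gr.Image(type="filepath") passes a file path
    caption="a cat sitting on a windowsill",
)
print(result_path)  # "output.png", which the gr.Image() output component then displays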