Spaces: Running on Zero
adamelliotfields committed on
Tidy-up model loading
generate.py CHANGED (+61, -55)
@@ -56,8 +56,6 @@ class Loader:
     def __new__(cls):
         if cls._instance is None:
             cls._instance = super(Loader, cls).__new__(cls)
-            cls._instance.cpu = torch.device("cpu")
-            cls._instance.gpu = torch.device("cuda")
             cls._instance.gan = None
             cls._instance.pipe = None
         return cls._instance
@@ -66,7 +64,7 @@ class Loader:
         has_deepcache = hasattr(self.pipe, "deepcache")

         if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
-            return self.pipe.deepcache
+            return
         if has_deepcache:
             self.pipe.deepcache.disable()
         else:
@@ -74,9 +72,8 @@ class Loader:

         self.pipe.deepcache.set_params(cache_interval=interval)
         self.pipe.deepcache.enable()
-        return self.pipe.deepcache

-    def _load_vae(self, model_name=None, taesd=False, dtype=None):
+    def _load_vae(self, model_name=None, taesd=False, variant=None):
         vae_type = type(self.pipe.vae)
         is_kl = issubclass(vae_type, (AutoencoderKL, OptimizedModule))
         is_tiny = issubclass(vae_type, AutoencoderTiny)
@@ -88,25 +85,24 @@ class Loader:
             self.pipe.vae = AutoencoderTiny.from_pretrained(
                 pretrained_model_name_or_path="madebyollin/taesd",
                 use_safetensors=True,
-                torch_dtype=dtype,
-            ).to(self.gpu)
-            return self.pipe.vae
+            ).to(device=self.pipe.device)
+            return

         if is_tiny and not taesd:
             print("Switching to KL VAE...")
+            model = AutoencoderKL.from_pretrained(
+                pretrained_model_name_or_path=model_name,
+                use_safetensors=True,
+                subfolder="vae",
+                variant=variant,
+            ).to(device=self.pipe.device)
             self.pipe.vae = torch.compile(
-                fullgraph=True,
                 mode="reduce-overhead",
-                model=AutoencoderKL.from_pretrained(
-                    pretrained_model_name_or_path=model_name,
-                    use_safetensors=True,
-                    torch_dtype=dtype,
-                    subfolder="vae",
-                ).to(self.gpu),
+                fullgraph=True,
+                model=model,
             )
-            return self.pipe.vae

-    def load(self, model, scheduler, karras, taesd, deepcache_interval, upscale, dtype):
+    def load(self, model, scheduler, karras, taesd, deepcache_interval, upscale, dtype, device):
         model_lower = model.lower()

         schedulers = {
@@ -131,13 +127,23 @@ class Loader:
         if scheduler in ["Euler a", "PNDM"]:
             del scheduler_kwargs["use_karras_sigmas"]

+        # no fp16 variant
+        if not ZERO_GPU and model_lower not in [
+            "sg161222/realistic_vision_v5.1_novae",
+            "prompthero/openjourney-v4",
+            "linaqruf/anything-v3-1",
+        ]:
+            variant = "fp16"
+        else:
+            variant = None
+
         pipe_kwargs = {
             "scheduler": schedulers[scheduler](**scheduler_kwargs),
             "pretrained_model_name_or_path": model_lower,
             "requires_safety_checker": False,
             "use_safetensors": True,
             "safety_checker": None,
-            "torch_dtype": dtype,
+            "variant": variant,
         }

         # already loaded
@@ -150,6 +156,10 @@ class Loader:
                 or self.pipe.scheduler.config.use_karras_sigmas == karras
             )

+            if upscale and not self.gan:
+                print("Loading fal/AuraSR-v2...")
+                self.gan = AuraSR.from_pretrained("fal/AuraSR-v2")
+
             if same_model:
                 if not same_scheduler:
                     print(f"Switching to {scheduler}...")
@@ -157,30 +167,23 @@
                     print(f"{'Enabling' if karras else 'Disabling'} Karras sigmas...")
                 if not same_scheduler or not same_karras:
                     self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
-
-                self._load_vae(model_lower, taesd, dtype)
+                self._load_vae(model_lower, taesd, variant)
                 self._load_deepcache(interval=deepcache_interval)
                 return self.pipe, self.gan
             else:
                 print(f"Unloading {model_name.lower()}...")
                 self.pipe = None
-                torch.cuda.empty_cache()
-
-        # no fp16 variant
-        if not ZERO_GPU and model_lower not in [
-            "sg161222/realistic_vision_v5.1_novae",
-            "prompthero/openjourney-v4",
-            "linaqruf/anything-v3-1",
-        ]:
-            pipe_kwargs["variant"] = "fp16"

         print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
-        self.pipe = StableDiffusionPipeline.from_pretrained(**pipe_kwargs).to(self.gpu)
+        self.pipe = StableDiffusionPipeline.from_pretrained(**pipe_kwargs).to(
+            device=device,
+            dtype=dtype,
+        )
         self.pipe.load_textual_inversion(
             pretrained_model_name_or_path=list(EMBEDDINGS.keys()),
             tokens=list(EMBEDDINGS.values()),
         )
-        self._load_vae(model_lower, taesd, dtype)
+        self._load_vae(model_lower, taesd, variant)
         self._load_deepcache(interval=deepcache_interval)

         if upscale and self.gan is None:
@@ -190,8 +193,8 @@
         if not upscale and self.gan is not None:
             print("Unloading fal/AuraSR-v2...")
             self.gan = None
-            torch.cuda.empty_cache()

+        torch.cuda.empty_cache()
         return self.pipe, self.gan


@@ -269,11 +272,11 @@ def generate(
     if seed is None or seed < 0:
         seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)

-    device = torch.device("cuda")
+    DEVICE = torch.device("cuda")

-    dtype = (
+    DTYPE = (
         torch.bfloat16
-        if torch.cuda.is_available() and torch.cuda.get_device_properties(device).major >= 8
+        if torch.cuda.is_available() and torch.cuda.get_device_properties(DEVICE).major >= 8
         else torch.float16
     )

@@ -293,18 +296,19 @@ def generate(
         taesd,
         deepcache_interval,
         upscale,
-        dtype,
+        DTYPE,
+        DEVICE,
     )

     # prompt embeds
     compel = Compel(
         textual_inversion_manager=DiffusersTextualInversionManager(pipe),
-        dtype_for_device_getter=lambda _: dtype,
+        dtype_for_device_getter=lambda _: DTYPE,
         returned_embeddings_type=EMBEDDINGS_TYPE,
         truncate_long_prompts=truncate_prompts,
         text_encoder=pipe.text_encoder,
         tokenizer=pipe.tokenizer,
-        device=device,
+        device=pipe.device,
     )

     images = []
@@ -318,7 +322,7 @@ def generate(

     for i in range(num_images):
         # seeded generator for each iteration
-        generator = torch.Generator(device=device).manual_seed(current_seed)
+        generator = torch.Generator(device=pipe.device).manual_seed(current_seed)

         try:
             all_positive_prompts = parse_prompt(positive_prompt)
@@ -333,22 +337,24 @@ def generate(
             raise Error("ParsingException: Invalid prompt")

         with token_merging(pipe, tome_ratio=tome_ratio):
-            image = pipe(
-                num_inference_steps=inference_steps,
-                negative_prompt_embeds=neg_embeds,
-                guidance_scale=guidance_scale,
-                prompt_embeds=pos_embeds,
-                generator=generator,
-                height=height,
-                width=width,
-            ).images[0]
-            if upscale:
-                print("Upscaling image...")
-                image = gan.upscale_4x_overlapped(image)
-            images.append((image, str(current_seed)))
-
-            if not ZERO_GPU:
-                torch.cuda.empty_cache()
+            try:
+                image = pipe(
+                    num_inference_steps=inference_steps,
+                    negative_prompt_embeds=neg_embeds,
+                    guidance_scale=guidance_scale,
+                    prompt_embeds=pos_embeds,
+                    generator=generator,
+                    height=height,
+                    width=width,
+                ).images[0]
+                if upscale:
+                    print("Upscaling image...")
+                    batch_size = 12 if ZERO_GPU else 4  # smaller batch to fit in 8GB
+                    image = gan.upscale_4x_overlapped(image, max_batch_size=batch_size)
+                images.append((image, str(current_seed)))
+            finally:
+                if not ZERO_GPU:
+                    torch.cuda.empty_cache()

         if increment_seed:
             current_seed += 1
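A note on the dtype selection added to generate() above: torch.bfloat16 is only chosen when the CUDA device reports compute capability 8.0 or newer (Ampere and later); older GPUs fall back to torch.float16. Below is a minimal standalone sketch of the same check, where pick_dtype is an illustrative helper name rather than code from this Space:

    import torch

    def pick_dtype(device: torch.device = torch.device("cuda")) -> torch.dtype:
        # bfloat16 requires compute capability >= 8 (Ampere or newer); otherwise use fp16
        if torch.cuda.is_available() and torch.cuda.get_device_properties(device).major >= 8:
            return torch.bfloat16
        return torch.float16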