adamelliotfields committed
Commit 22a0476 · verified · 1 Parent(s): 2221c84

Tidy-up model loading

Files changed (1):
  1. generate.py +61 -55
generate.py CHANGED
@@ -56,8 +56,6 @@ class Loader:
     def __new__(cls):
         if cls._instance is None:
             cls._instance = super(Loader, cls).__new__(cls)
-            cls._instance.cpu = torch.device("cpu")
-            cls._instance.gpu = torch.device("cuda")
             cls._instance.gan = None
             cls._instance.pipe = None
         return cls._instance
@@ -66,7 +64,7 @@ class Loader:
         has_deepcache = hasattr(self.pipe, "deepcache")

         if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
-            return self.pipe.deepcache
+            return
         if has_deepcache:
             self.pipe.deepcache.disable()
         else:
@@ -74,9 +72,8 @@ class Loader:

         self.pipe.deepcache.set_params(cache_interval=interval)
         self.pipe.deepcache.enable()
-        return self.pipe.deepcache

-    def _load_vae(self, model_name=None, taesd=False, dtype=None):
+    def _load_vae(self, model_name=None, taesd=False, variant=None):
         vae_type = type(self.pipe.vae)
         is_kl = issubclass(vae_type, (AutoencoderKL, OptimizedModule))
         is_tiny = issubclass(vae_type, AutoencoderTiny)
@@ -88,25 +85,24 @@ class Loader:
             self.pipe.vae = AutoencoderTiny.from_pretrained(
                 pretrained_model_name_or_path="madebyollin/taesd",
                 use_safetensors=True,
-                torch_dtype=dtype,
-            ).to(self.gpu)
-            return self.pipe.vae
+            ).to(device=self.pipe.device)
+            return

         if is_tiny and not taesd:
             print("Switching to KL VAE...")
+            model = AutoencoderKL.from_pretrained(
+                pretrained_model_name_or_path=model_name,
+                use_safetensors=True,
+                subfolder="vae",
+                variant=variant,
+            ).to(device=self.pipe.device)
             self.pipe.vae = torch.compile(
-                fullgraph=True,
                 mode="reduce-overhead",
-                model=AutoencoderKL.from_pretrained(
-                    pretrained_model_name_or_path=model_name,
-                    use_safetensors=True,
-                    torch_dtype=dtype,
-                    subfolder="vae",
-                ).to(self.gpu),
+                fullgraph=True,
+                model=model,
             )
-            return self.pipe.vae

-    def load(self, model, scheduler, karras, taesd, deepcache_interval, upscale, dtype=None):
+    def load(self, model, scheduler, karras, taesd, deepcache_interval, upscale, dtype, device):
         model_lower = model.lower()

         schedulers = {
@@ -131,13 +127,23 @@ class Loader:
         if scheduler in ["Euler a", "PNDM"]:
             del scheduler_kwargs["use_karras_sigmas"]

+        # no fp16 variant
+        if not ZERO_GPU and model_lower not in [
+            "sg161222/realistic_vision_v5.1_novae",
+            "prompthero/openjourney-v4",
+            "linaqruf/anything-v3-1",
+        ]:
+            variant = "fp16"
+        else:
+            variant = None
+
         pipe_kwargs = {
             "scheduler": schedulers[scheduler](**scheduler_kwargs),
             "pretrained_model_name_or_path": model_lower,
             "requires_safety_checker": False,
             "use_safetensors": True,
             "safety_checker": None,
-            "torch_dtype": dtype,
+            "variant": variant,
         }

         # already loaded
@@ -150,6 +156,10 @@ class Loader:
                 or self.pipe.scheduler.config.use_karras_sigmas == karras
             )

+            if upscale and not self.gan:
+                print("Loading fal/AuraSR-v2...")
+                self.gan = AuraSR.from_pretrained("fal/AuraSR-v2")
+
             if same_model:
                 if not same_scheduler:
                     print(f"Switching to {scheduler}...")
@@ -157,30 +167,23 @@ class Loader:
                     print(f"{'Enabling' if karras else 'Disabling'} Karras sigmas...")
                 if not same_scheduler or not same_karras:
                     self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
-
-                self._load_vae(model_lower, taesd, dtype)
+                self._load_vae(model_lower, taesd, variant)
                 self._load_deepcache(interval=deepcache_interval)
                 return self.pipe, self.gan
             else:
                 print(f"Unloading {model_name.lower()}...")
                 self.pipe = None
-                torch.cuda.empty_cache()
-
-        # no fp16 variant
-        if not ZERO_GPU and model_lower not in [
-            "sg161222/realistic_vision_v5.1_novae",
-            "prompthero/openjourney-v4",
-            "linaqruf/anything-v3-1",
-        ]:
-            pipe_kwargs["variant"] = "fp16"

         print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
-        self.pipe = StableDiffusionPipeline.from_pretrained(**pipe_kwargs).to(self.gpu)
+        self.pipe = StableDiffusionPipeline.from_pretrained(**pipe_kwargs).to(
+            device=device,
+            dtype=dtype,
+        )
         self.pipe.load_textual_inversion(
             pretrained_model_name_or_path=list(EMBEDDINGS.keys()),
             tokens=list(EMBEDDINGS.values()),
         )
-        self._load_vae(model_lower, taesd, dtype)
+        self._load_vae(model_lower, taesd, variant)
         self._load_deepcache(interval=deepcache_interval)

         if upscale and self.gan is None:
@@ -190,8 +193,8 @@ class Loader:
         if not upscale and self.gan is not None:
             print("Unloading fal/AuraSR-v2...")
             self.gan = None
-            torch.cuda.empty_cache

+        torch.cuda.empty_cache()
         return self.pipe, self.gan


@@ -269,11 +272,11 @@ def generate(
     if seed is None or seed < 0:
         seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)

-    GPU = torch.device("cuda")
+    DEVICE = torch.device("cuda")

-    TORCH_DTYPE = (
+    DTYPE = (
         torch.bfloat16
-        if torch.cuda.is_available() and torch.cuda.get_device_properties(GPU).major >= 8
+        if torch.cuda.is_available() and torch.cuda.get_device_properties(DEVICE).major >= 8
         else torch.float16
     )

@@ -293,18 +296,19 @@ def generate(
         taesd,
         deepcache_interval,
         upscale,
-        TORCH_DTYPE,
+        DTYPE,
+        DEVICE,
     )

     # prompt embeds
     compel = Compel(
         textual_inversion_manager=DiffusersTextualInversionManager(pipe),
-        dtype_for_device_getter=lambda _: TORCH_DTYPE,
+        dtype_for_device_getter=lambda _: DTYPE,
         returned_embeddings_type=EMBEDDINGS_TYPE,
         truncate_long_prompts=truncate_prompts,
         text_encoder=pipe.text_encoder,
         tokenizer=pipe.tokenizer,
-        device=GPU,
+        device=pipe.device,
     )

     images = []
@@ -318,7 +322,7 @@ def generate(

     for i in range(num_images):
         # seeded generator for each iteration
-        generator = torch.Generator(device=GPU).manual_seed(current_seed)
+        generator = torch.Generator(device=pipe.device).manual_seed(current_seed)

         try:
             all_positive_prompts = parse_prompt(positive_prompt)
@@ -333,22 +337,24 @@ def generate(
             raise Error("ParsingException: Invalid prompt")

         with token_merging(pipe, tome_ratio=tome_ratio):
-            image = pipe(
-                num_inference_steps=inference_steps,
-                negative_prompt_embeds=neg_embeds,
-                guidance_scale=guidance_scale,
-                prompt_embeds=pos_embeds,
-                generator=generator,
-                height=height,
-                width=width,
-            ).images[0]
-
-            if upscale:
-                print("Upscaling image...")
-                batch_size = 12 if ZERO_GPU else 4  # smaller batch to fit in 8GB
-                image = gan.upscale_4x_overlapped(image, max_batch_size=batch_size)
-
-            images.append((image, str(current_seed)))
+            try:
+                image = pipe(
+                    num_inference_steps=inference_steps,
+                    negative_prompt_embeds=neg_embeds,
+                    guidance_scale=guidance_scale,
+                    prompt_embeds=pos_embeds,
+                    generator=generator,
+                    height=height,
+                    width=width,
+                ).images[0]
+                if upscale:
+                    print("Upscaling image...")
+                    batch_size = 12 if ZERO_GPU else 4  # smaller batch to fit in 8GB
+                    image = gan.upscale_4x_overlapped(image, max_batch_size=batch_size)
+                images.append((image, str(current_seed)))
+            finally:
+                if not ZERO_GPU:
+                    torch.cuda.empty_cache()

         if increment_seed:
             current_seed += 1
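
For reference, below is a minimal sketch (not part of the commit) of how the refactored loader is driven from generate() after this change: dtype and device are resolved once in generate() and passed into Loader().load(...) instead of being created inside the Loader. The model id, scheduler name, and flag values are illustrative only (they reuse names that appear in the diff), and the import assumes generate.py can be imported as a module.

import torch

from generate import Loader  # assumption: generate.py is importable as a module

# Same dtype/device selection as generate() above: bfloat16 on compute capability >= 8, else float16.
DEVICE = torch.device("cuda")
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_properties(DEVICE).major >= 8
    else torch.float16
)

# Illustrative arguments; "PNDM" and "prompthero/openjourney-v4" both appear in the diff above.
pipe, gan = Loader().load(
    "prompthero/openjourney-v4",  # model
    "PNDM",                       # scheduler (use_karras_sigmas is dropped for PNDM)
    False,                        # karras
    False,                        # taesd: keep the KL VAE instead of switching to TAESD
    1,                            # deepcache_interval
    False,                        # upscale: skip loading fal/AuraSR-v2, so gan is returned as None
    DTYPE,
    DEVICE,
)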