radames commited on
Commit
c7f8801
·
1 Parent(s): ec09a64

copy from diffusers

Browse files
Files changed (1) hide show
  1. latent_consistency_controlnet.py +20 -15
latent_consistency_controlnet.py CHANGED
@@ -25,7 +25,6 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
25
 
26
  from diffusers import (
27
  AutoencoderKL,
28
- AutoencoderTiny,
29
  ConfigMixin,
30
  DiffusionPipeline,
31
  SchedulerMixin,
@@ -50,6 +49,17 @@ import PIL.Image
50
 
51
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
 
 
 
 
 
 
 
 
 
 
 
 
53
  class LatentConsistencyModelPipeline_controlnet(DiffusionPipeline):
54
  _optional_components = ["scheduler"]
55
 
@@ -276,22 +286,17 @@ class LatentConsistencyModelPipeline_controlnet(DiffusionPipeline):
276
  )
277
 
278
  elif isinstance(generator, list):
279
- if isinstance(self.vae, AutoencoderTiny):
280
- init_latents = [
281
- self.vae.encode(image[i : i + 1]).latents
282
- for i in range(batch_size)
283
- ]
284
- else:
285
- init_latents = [
286
- self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i])
287
- for i in range(batch_size)
288
- ]
289
  init_latents = torch.cat(init_latents, dim=0)
290
  else:
291
- if isinstance(self.vae, AutoencoderTiny):
292
- init_latents = self.vae.encode(image).latents
293
- else:
294
- init_latents = self.vae.encode(image).latent_dist.sample(generator)
295
 
296
  init_latents = self.vae.config.scaling_factor * init_latents
297
 
 
25
 
26
  from diffusers import (
27
  AutoencoderKL,
 
28
  ConfigMixin,
29
  DiffusionPipeline,
30
  SchedulerMixin,
 
49
 
50
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
 
52
+
53
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
54
+ def retrieve_latents(encoder_output, generator):
55
+ if hasattr(encoder_output, "latent_dist"):
56
+ return encoder_output.latent_dist.sample(generator)
57
+ elif hasattr(encoder_output, "latents"):
58
+ return encoder_output.latents
59
+ else:
60
+ raise AttributeError("Could not access latents of provided encoder_output")
61
+
62
+
63
  class LatentConsistencyModelPipeline_controlnet(DiffusionPipeline):
64
  _optional_components = ["scheduler"]
65
 
 
286
  )
287
 
288
  elif isinstance(generator, list):
289
+ init_latents = [
290
+ retrieve_latents(
291
+ self.vae.encode(image[i : i + 1]), generator=generator[i]
292
+ )
293
+ for i in range(batch_size)
294
+ ]
 
 
 
 
295
  init_latents = torch.cat(init_latents, dim=0)
296
  else:
297
+ init_latents = retrieve_latents(
298
+ self.vae.encode(image), generator=generator
299
+ )
 
300
 
301
  init_latents = self.vae.config.scaling_factor * init_latents
302