multimodalart HF staff commited on
Commit
2ed4418
·
verified ·
1 Parent(s): 463aefd

Add live previews (now for realz)

Browse files
Files changed (1) hide show
  1. live_preview_helpers.py +6 -5
live_preview_helpers.py CHANGED
@@ -59,6 +59,7 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
59
  return_dict: bool = True,
60
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
61
  max_sequence_length: int = 512,
 
62
  ):
63
  height = height or self.default_sample_size * self.vae_scale_factor
64
  width = width or self.default_sample_size * self.vae_scale_factor
@@ -156,10 +157,10 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
156
  yield self.image_processor.postprocess(image, output_type=output_type)[0]
157
  torch.cuda.empty_cache()
158
 
159
- # Final image
160
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
161
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
162
- image = self.vae.decode(latents, return_dict=False)[0]
163
  self.maybe_free_model_hooks()
164
  torch.cuda.empty_cache()
165
- return self.image_processor.postprocess(image, output_type=output_type)[0], 0
 
59
  return_dict: bool = True,
60
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
61
  max_sequence_length: int = 512,
62
+ good_vae: Optional[Any] = None,
63
  ):
64
  height = height or self.default_sample_size * self.vae_scale_factor
65
  width = width or self.default_sample_size * self.vae_scale_factor
 
157
  yield self.image_processor.postprocess(image, output_type=output_type)[0]
158
  torch.cuda.empty_cache()
159
 
160
+ # Final image using good_vae
161
+ latents = self._unpack_latents(latents, height, width, good_vae.config.vae_scale_factor)
162
+ latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
163
+ image = good_vae.decode(latents, return_dict=False)[0]
164
  self.maybe_free_model_hooks()
165
  torch.cuda.empty_cache()
166
+ yield self.image_processor.postprocess(image, output_type=output_type)[0]