rhfeiyang commited on
Commit
c62a333
·
1 Parent(s): a7876f7
Files changed (1) hide show
  1. inference.py +3 -3
inference.py CHANGED
@@ -349,12 +349,12 @@ def inference(network: LoRANetwork, tokenizer: CLIPTokenizer, text_encoder: CLIP
349
  network.set_lora_slider(scale=current_scale)
350
  text_embedding = style_text_embeddings
351
  # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
352
- latent_model_input = torch.cat([latents] * 2)
353
 
354
- latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
355
  # predict the noise residual
356
  with network:
357
- noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embedding).sample
358
 
359
  # perform guidance
360
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
 
349
  network.set_lora_slider(scale=current_scale)
350
  text_embedding = style_text_embeddings
351
  # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
352
+ latent_model_input = torch.cat([latents] * 2).to(weight_dtype)
353
 
354
+ latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t).to(weight_dtype)
355
  # predict the noise residual
356
  with network:
357
+ noise_pred = unet(latent_model_input, t , encoder_hidden_states=text_embedding).sample
358
 
359
  # perform guidance
360
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)