rhfeiyang committed on
Commit e04ccbd · 1 Parent(s): d4fbe32
Files changed (1):
  1. inference.py +3 -2
inference.py CHANGED

@@ -14,6 +14,7 @@ import sys
 import gc
 from transformers import CLIPTextModel, CLIPTokenizer, BertModel, BertTokenizer
 
+from hf_demo import dtype
 # import train_util
 
 from utils.train_util import get_noisy_image, encode_prompts
@@ -319,8 +320,8 @@ def inference(network: LoRANetwork, tokenizer: CLIPTokenizer, text_encoder: CLIP
         uncond_embeddings = uncond_embed.repeat(bcz, 1, 1)
     else:
         uncond_embeddings = uncond_embed
-    style_text_embeddings = torch.cat([uncond_embeddings, style_embeddings])
-    original_embeddings = torch.cat([uncond_embeddings, original_embeddings])
+    style_text_embeddings = torch.cat([uncond_embeddings, style_embeddings], dtype=weight_dtype)
+    original_embeddings = torch.cat([uncond_embeddings, original_embeddings], dtype=weight_dtype)
 
     generator = torch.manual_seed(single_seed) if single_seed is not None else None
     noise_scheduler.set_timesteps(steps)
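Note on the second hunk: in current PyTorch, torch.cat accepts only tensors, dim, and out, not a dtype keyword, so passing dtype=weight_dtype as in the + lines above would raise a TypeError at runtime. Below is a minimal standalone sketch of the cast the change appears to intend, applying .to(weight_dtype) after concatenation; the tensor shapes and torch.float16 value are illustrative assumptions, not values taken from the repository.

    import torch

    # Assumption: weight_dtype is the working precision of the pipeline (e.g. fp16).
    weight_dtype = torch.float16

    # Placeholder prompt embeddings shaped like CLIP text-encoder output (batch, tokens, dim).
    uncond_embeddings = torch.randn(2, 77, 768)
    style_embeddings = torch.randn(2, 77, 768)
    original_embeddings = torch.randn(2, 77, 768)

    # torch.cat has no dtype argument; cast the concatenated result explicitly instead.
    style_text_embeddings = torch.cat([uncond_embeddings, style_embeddings]).to(weight_dtype)
    original_embeddings = torch.cat([uncond_embeddings, original_embeddings]).to(weight_dtype)

Casting the inputs to weight_dtype before torch.cat would be equivalent; either way keeps the concatenated classifier-free-guidance embeddings in the same dtype as the model weights.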