rhfeiyang committed on
Commit d7b9e64 · 1 Parent(s): d9e6174
Files changed (1)
  1. hf_demo.py +5 -2
hf_demo.py CHANGED
@@ -39,14 +39,17 @@ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0,
     adapter_path = lora_map[adapter_choice]
     if adapter_path not in [None, "None"]:
         adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
-
+        style_prompt="sks art"
+    else:
+        style_prompt=None
     prompts = [prompt]*samples
     infer_loader = get_validation_dataloader(prompts,num_workers=0)
     network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
+
     pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
                             height=512, width=512, scales=[1.0],
                             save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
-                            start_noise=-1, show=False, style_prompt="sks art", no_load=True,
+                            start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
                             from_scratch=True, device=device, weight_dtype=dtype)[0][1.0]
     return pred_images
 @spaces.GPU
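For context, a minimal usage sketch of the patched entry point. Only the parameters visible in the hunk header (adapter_choice, prompt, samples, seed) come from the source; the import, the adapter name "None", the prompt string, and the keyword values below are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch (assumption), not part of the commit.
# With adapter_choice == "None", the new else-branch leaves style_prompt as None,
# so the "sks art" trigger phrase is applied only when a LoRA adapter is actually loaded.
from hf_demo import demo_inference_gen  # assumes hf_demo.py is importable as a module

images = demo_inference_gen(
    adapter_choice="None",            # no adapter selected -> style_prompt=None after this commit
    prompt="a watercolor landscape",  # illustrative prompt (assumption)
    samples=1,
    seed=0,
)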