Spaces: Running on Zero

Commit: update

Files changed:
- hf_demo.py (+162 -82)
- hf_demo_test.ipynb (+188 -125)
- utils/train_util.py (+2 -1)
hf_demo.py CHANGED

@@ -16,70 +16,93 @@ pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
 from inference import get_lora_network, inference, get_validation_dataloader
 lora_map = {
     "None": "None",
-    "Andre Derain": "andre-derain_subset1",
-    "Vincent van Gogh": "van_gogh_subset1",
-    "Andy Warhol": "andy_subset1",
+    "Andre Derain (fauvism)": "andre-derain_subset1",
+    "Vincent van Gogh (post impressionism)": "van_gogh_subset1",
+    "Andy Warhol (pop art)": "andy_subset1",
     "Walter Battiss": "walter-battiss_subset2",
-    "Camille Corot": "camille-corot_subset1",
-    "Claude Monet": "monet_subset2",
-    "Pablo Picasso": "picasso_subset1",
+    "Camille Corot (realism)": "camille-corot_subset1",
+    "Claude Monet (impressionism)": "monet_subset2",
+    "Pablo Picasso (cubism)": "picasso_subset1",
     "Jackson Pollock": "jackson-pollock_subset1",
-    "Gerhard Richter": "gerhard-richter_subset1",
+    "Gerhard Richter (abstract expressionism)": "gerhard-richter_subset1",
     "M.C. Escher": "m.c.-escher_subset1",
     "Albert Gleizes": "albert-gleizes_subset1",
-    "Hokusai": "katsushika-hokusai_subset1",
+    "Hokusai (ukiyo-e)": "katsushika-hokusai_subset1",
     "Wassily Kandinsky": "kandinsky_subset1",
-    "Gustav Klimt": "klimt_subset3",
+    "Gustav Klimt (art nouveau)": "klimt_subset3",
     "Roy Lichtenstein": "roy-lichtenstein_subset1",
-    "Henri Matisse": "henri-matisse_subset1",
+    "Henri Matisse (abstract expressionism)": "henri-matisse_subset1",
     "Joan Miro": "joan-miro_subset2",
 }
+
 @spaces.GPU
-def
+def demo_inference_gen_artistic(adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0):
     adapter_path = lora_map[adapter_choice]
     if adapter_path not in [None, "None"]:
         adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
         style_prompt="sks art"
     else:
         style_prompt=None
-    prompts = [prompt]
+    prompts = [prompt]
     infer_loader = get_validation_dataloader(prompts,num_workers=0)
     network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]

     pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
-                            height=512, width=512, scales=[
+                            height=512, width=512, scales=[adapter_scale],
                             save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
                             start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
-                            from_scratch=True, device=device, weight_dtype=dtype)[0][1.0]
+                            from_scratch=True, device=device, weight_dtype=dtype)[0][1.0][0]
     return pred_images
+
 @spaces.GPU
-def
-
-
+def demo_inference_gen_ori( prompt:str, seed:int=0, steps=50, guidance_scale=7.5):
+    style_prompt=None
+    prompts = [prompt]
+    infer_loader = get_validation_dataloader(prompts,num_workers=0)
+    network = get_lora_network(pipe.unet, "None", weight_dtype=dtype)["network"]
+
     pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
-                            height=512, width=512, scales=[0
-                            save_dir=None, seed=seed,steps=
-                            start_noise
-                            from_scratch=
+                            height=512, width=512, scales=[0.0],
+                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
+                            start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
+                            from_scratch=True, device=device, weight_dtype=dtype)[0][0.0][0]
     return pred_images

-
-
-
-
-
-#
-
-
-
-
-
-
-
-
-
-
+
+@spaces.GPU
+def demo_inference_stylization_ori(ref_image, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, start_noise=800):
+    style_prompt=None
+    prompts = [prompt]
+    # convert np to pil
+    ref_image = [Image.fromarray(ref_image)]
+    network = get_lora_network(pipe.unet, "None", weight_dtype=dtype)["network"]
+    infer_loader = get_validation_dataloader(prompts, ref_image,num_workers=0)
+    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
+                            height=512, width=512, scales=[0.0],
+                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
+                            start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,
+                            from_scratch=False, device=device, weight_dtype=dtype)[0][0.0][0]
+    return pred_images
+
+@spaces.GPU
+def demo_inference_stylization_artistic(ref_image, adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0,start_noise=800):
+    adapter_path = lora_map[adapter_choice]
+    if adapter_path not in [None, "None"]:
+        adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
+        style_prompt="sks art"
+    else:
+        style_prompt=None
+    prompts = [prompt]
+    # convert np to pil
+    ref_image = [Image.fromarray(ref_image)]
+    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
+    infer_loader = get_validation_dataloader(prompts, ref_image,num_workers=0)
+    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
+                            height=512, width=512, scales=[adapter_scale],
+                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
+                            start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,
+                            from_scratch=False, device=device, weight_dtype=dtype)[0][1.0][0]
+    return pred_images

@@ -92,62 +115,119 @@ with block:
         gr.Markdown("(More features in development...)")
         with gr.Row():
             text = gr.Textbox(
-                label="Enter your prompt",
+                label="Enter your prompt(long and detailed would be better):",
                 max_lines=2,
-                placeholder="Enter your prompt",
-                container=
+                placeholder="Enter your prompt(long and detailed would be better)",
+                container=True,
                 value="Park with cherry blossom trees, picnicker’s and a clear blue pond.",
             )

-
-
-                label="
-
-
-
-
-
-        advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
-
-        with gr.Row(elem_id="advanced-options"):
+        with gr.Tab('Generation'):
+            with gr.Row():
+                with gr.Column():
+                    # gr.Markdown("## Art-Free Generation")
+                    # gr.Markdown("Generate images from text prompts.")
+
+                    gallery_gen_ori = gr.Image(
+                        label="W/O Adapter",
+                        show_label=True,
+                        elem_id="gallery",
+                        height="auto"
+                    )
+
+
+                with gr.Column():
+                    # gr.Markdown("## Art-Free Generation")
+                    # gr.Markdown("Generate images from text prompts.")
+                    gallery_gen_art = gr.Image(
+                        label="W/ Adapter",
+                        show_label=True,
+                        elem_id="gallery",
+                        height="auto"
+                    )
+
+
+            with gr.Row():
+                btn_gen_ori = gr.Button("Art-Free Generate", scale=1)
+                btn_gen_art = gr.Button("Artistic Generate", scale=1)
+
+
+        with gr.Tab('Stylization'):
+            with gr.Row():
+
+                with gr.Column():
+                    # gr.Markdown("## Art-Free Generation")
+                    # gr.Markdown("Generate images from text prompts.")
+
+                    gallery_stylization_ref = gr.Image(
+                        label="Ref Image",
+                        show_label=True,
+                        elem_id="gallery",
+                        height="auto",
+                        scale=1,
+                    )
+                with gr.Column(scale=2):
+                    with gr.Row():
+                        with gr.Column():
+                            # gr.Markdown("## Art-Free Generation")
+                            # gr.Markdown("Generate images from text prompts.")
+
+                            gallery_stylization_ori = gr.Image(
+                                label="W/O Adapter",
+                                show_label=True,
+                                elem_id="gallery",
+                                height="auto",
+                                scale=1,
+                            )
+
+
+                        with gr.Column():
+                            # gr.Markdown("## Art-Free Generation")
+                            # gr.Markdown("Generate images from text prompts.")
+                            gallery_stylization_art = gr.Image(
+                                label="W/ Adapter",
+                                show_label=True,
+                                elem_id="gallery",
+                                height="auto",
+                                scale=1,
+                            )
+                    start_timestep = gr.Slider(label="Adapter Timestep", minimum=0, maximum=1000, value=800, step=1)
+            with gr.Row():
+                btn_style_ori = gr.Button("Art-Free Stylization", scale=1)
+                btn_style_art = gr.Button("Artistic Stylization", scale=1)
+
+        with gr.Row():
+            # with gr.Column():
+            #     samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1, scale=1)
+            scale = gr.Slider(
+                label="Guidance Scale", minimum=0, maximum=20, value=7.5, step=0.1
+            )
+            # with gr.Column():
             adapter_choice = gr.Dropdown(
                 label="Select Art Adapter",
-                choices=[
-
-
-
-
-
-                value="
+                choices=[ "Andre Derain (fauvism)","Vincent van Gogh (post impressionism)","Andy Warhol (pop art)",
+                          "Camille Corot (realism)", "Claude Monet (impressionism)", "Pablo Picasso (cubism)", "Gerhard Richter (abstract expressionism)",
+                          "Hokusai (ukiyo-e)", "Gustav Klimt (art nouveau)", "Henri Matisse (abstract expressionism)",
+                          "Walter Battiss", "Jackson Pollock", "M.C. Escher", "Albert Gleizes", "Wassily Kandinsky",
+                          "Roy Lichtenstein", "Joan Miro"
+                ],
+                value="Andre Derain (fauvism)",
+                scale=1
             )
-        # print(adapter_choice[0])
-        # lora_path = lora_map[adapter_choice.value]
-        # if lora_path is not None:
-        #     lora_path = f"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"

-
+        with gr.Row():
             steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
-
-
-
-
-        seed = gr.Slider(
-            label="Seed",
-            minimum=0,
-            maximum=2147483647,
-            step=1,
-            randomize=True,
-        )
+            adapter_scale = gr.Slider(label="Stylization Scale", minimum=0, maximum=1.5, value=1., step=0.1, scale=1)
+
+        with gr.Row():
+            seed = gr.Slider(label="Seed",minimum=0,maximum=2147483647,step=1,randomize=True,scale=1)

-    gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)
-    advanced_button.click(
-        None,
-        [],
-        text,
-    )
+    gr.on([btn_gen_ori.click], demo_inference_gen_ori, inputs=[text, seed, steps, scale], outputs=gallery_gen_ori)
+    gr.on([btn_gen_art.click], demo_inference_gen_artistic, inputs=[adapter_choice, text, seed, steps, scale, adapter_scale], outputs=gallery_gen_art)

+    gr.on([btn_style_ori.click], demo_inference_stylization_ori, inputs=[gallery_stylization_ref, text, seed, steps, scale, start_timestep], outputs=gallery_stylization_ori)
+    gr.on([btn_style_art.click], demo_inference_stylization_artistic, inputs=[gallery_stylization_ref, adapter_choice, text, seed, steps, scale, adapter_scale, start_timestep], outputs=gallery_stylization_art)

-block.launch()
+block.launch(sharing=True)
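Note on the new structure: the commit splits the old single demo_inference_gen entry point into four functions (art-free vs. artistic, generation vs. stylization), each wired to its own button through gr.on, and each now returning ...[0][scale][0], i.e. the first image generated at the requested adapter scale, which is why the outputs are single gr.Image components instead of a gallery. A minimal sketch of that wiring pattern, with a hypothetical stub in place of the real inference call (the stub and its return value are illustrative, not the repo's API):

    import gradio as gr

    def gen_ori(prompt, seed, steps, guidance_scale):
        # stand-in for demo_inference_gen_ori; the real function returns one PIL image
        return None

    with gr.Blocks() as block:
        text = gr.Textbox(label="Enter your prompt")
        gallery_gen_ori = gr.Image(label="W/O Adapter")
        btn_gen_ori = gr.Button("Art-Free Generate")
        seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
        steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
        scale = gr.Slider(label="Guidance Scale", minimum=0, maximum=20, value=7.5, step=0.1)

        # one gr.on call per button, mirroring the commit's four wirings
        gr.on([btn_gen_ori.click], gen_ori, inputs=[text, seed, steps, scale], outputs=gallery_gen_ori)

    block.launch(share=True)

One flag: the new hf_demo.py ends with block.launch(sharing=True), while the notebook in the same commit uses block.launch(share=True). Gradio's parameter is share, so sharing= looks like a typo that would fail at startup.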
hf_demo_test.ipynb CHANGED

@@ -45,7 +45,9 @@
    },
    "outputs": [],
    "source": [
-    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
+    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
+    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+    "dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16"
    ]
   },
   {

@@ -70,7 +72,7 @@
   {
    "data": {
     "application/vnd.jupyter.widget-view+json": {
-     "model_id": "
+     "model_id": "acc42f294243439798e4d77d1a59296d",
      "version_major": 2,
      "version_minor": 0
    },

@@ -83,8 +85,8 @@
    }
   ],
   "source": [
-   "pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\"
-   "device
+   "pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\",\n",
+   "                                         torch_dtype=dtype).to(device)"
   ]
  },
  {

@@ -102,77 +104,105 @@
    "from inference import get_lora_network, inference, get_validation_dataloader\n",
    "lora_map = {\n",
    "    \"None\": \"None\",\n",
-   "    \"Andre Derain\": \"andre-derain_subset1\",\n",
-   "    \"Vincent van Gogh\": \"van_gogh_subset1\",\n",
-   "    \"Andy Warhol\": \"andy_subset1\",\n",
+   "    \"Andre Derain (fauvism)\": \"andre-derain_subset1\",\n",
+   "    \"Vincent van Gogh (post impressionism)\": \"van_gogh_subset1\",\n",
+   "    \"Andy Warhol (pop art)\": \"andy_subset1\",\n",
    "    \"Walter Battiss\": \"walter-battiss_subset2\",\n",
-   "    \"Camille Corot\": \"camille-corot_subset1\",\n",
-   "    \"Claude Monet\": \"monet_subset2\",\n",
-   "    \"Pablo Picasso\": \"picasso_subset1\",\n",
+   "    \"Camille Corot (realism)\": \"camille-corot_subset1\",\n",
+   "    \"Claude Monet (impressionism)\": \"monet_subset2\",\n",
+   "    \"Pablo Picasso (cubism)\": \"picasso_subset1\",\n",
    "    \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
-   "    \"Gerhard Richter\": \"gerhard-richter_subset1\",\n",
+   "    \"Gerhard Richter (abstract expressionism)\": \"gerhard-richter_subset1\",\n",
    "    \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
    "    \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
-   "    \"Hokusai\": \"katsushika-hokusai_subset1\",\n",
+   "    \"Hokusai (ukiyo-e)\": \"katsushika-hokusai_subset1\",\n",
    "    \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
-   "    \"Gustav Klimt\": \"klimt_subset3\",\n",
+   "    \"Gustav Klimt (art nouveau)\": \"klimt_subset3\",\n",
    "    \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
-   "    \"Henri Matisse\": \"henri-matisse_subset1\",\n",
+   "    \"Henri Matisse (abstract expressionism)\": \"henri-matisse_subset1\",\n",
    "    \"Joan Miro\": \"joan-miro_subset2\",\n",
    "}\n",
    "\n",
-   "
+   "\n",
+   "\n",
+   "def demo_inference_gen_artistic(adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0):\n",
    "    adapter_path = lora_map[adapter_choice]\n",
    "    if adapter_path not in [None, \"None\"]:\n",
    "        adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
+   "        style_prompt=\"sks art\"\n",
+   "    else:\n",
+   "        style_prompt=None\n",
+   "    prompts = [prompt]\n",
+   "    infer_loader = get_validation_dataloader(prompts,num_workers=0)\n",
+   "    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)[\"network\"]\n",
    "\n",
-   "    prompts = [prompt]*samples\n",
-   "    infer_loader = get_validation_dataloader(prompts)\n",
-   "    network = get_lora_network(pipe.unet, adapter_path)[\"network\"]\n",
    "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
-   "                            height=512, width=512, scales=[
+   "                            height=512, width=512, scales=[adapter_scale],\n",
    "                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
-   "                            start_noise=-1, show=False, style_prompt
-   "                            from_scratch=True)[0][1.0]\n",
+   "                            start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,\n",
+   "                            from_scratch=True, device=device, weight_dtype=dtype)[0][1.0][0]\n",
    "    return pred_images\n",
    "\n",
-   "
-   "
-   "
+   "\n",
+   "def demo_inference_gen_ori( prompt:str, seed:int=0, steps=50, guidance_scale=7.5):\n",
+   "    style_prompt=None\n",
+   "    prompts = [prompt]\n",
+   "    infer_loader = get_validation_dataloader(prompts,num_workers=0)\n",
+   "    network = get_lora_network(pipe.unet, \"None\", weight_dtype=dtype)[\"network\"]\n",
+   "\n",
    "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
-   "                            height=512, width=512, scales=[0
-   "                            save_dir=None, seed=seed,steps=
-   "                            start_noise
-   "                            from_scratch=
+   "                            height=512, width=512, scales=[0.0],\n",
+   "                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
+   "                            start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,\n",
+   "                            from_scratch=True, device=device, weight_dtype=dtype)[0][0.0][0]\n",
    "    return pred_images\n",
    "\n",
-   "
-   "
-   "
-   "
-   "
-   "#
-   "
-   "
-   "
-   "
-   "
-   "
-   "
-   "
-   "
-   "
+   "\n",
+   "\n",
+   "def demo_inference_stylization_ori(ref_image, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, start_noise=800):\n",
+   "    style_prompt=None\n",
+   "    prompts = [prompt]\n",
+   "    # convert np to pil\n",
+   "    ref_image = [Image.fromarray(ref_image)]\n",
+   "    network = get_lora_network(pipe.unet, \"None\", weight_dtype=dtype)[\"network\"]\n",
+   "    infer_loader = get_validation_dataloader(prompts, ref_image,num_workers=0)\n",
+   "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
+   "                            height=512, width=512, scales=[0.0],\n",
+   "                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
+   "                            start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,\n",
+   "                            from_scratch=False, device=device, weight_dtype=dtype)[0][0.0][0]\n",
+   "    return pred_images\n",
+   "\n",
+   "\n",
+   "def demo_inference_stylization_artistic(ref_image, adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0,start_noise=800):\n",
+   "    adapter_path = lora_map[adapter_choice]\n",
+   "    if adapter_path not in [None, \"None\"]:\n",
+   "        adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
+   "        style_prompt=\"sks art\"\n",
+   "    else:\n",
+   "        style_prompt=None\n",
+   "    prompts = [prompt]\n",
+   "    # convert np to pil\n",
+   "    ref_image = [Image.fromarray(ref_image)]\n",
+   "    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)[\"network\"]\n",
+   "    infer_loader = get_validation_dataloader(prompts, ref_image,num_workers=0)\n",
+   "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
+   "                            height=512, width=512, scales=[adapter_scale],\n",
+   "                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
+   "                            start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,\n",
+   "                            from_scratch=False, device=device, weight_dtype=dtype)[0][1.0][0]\n",
+   "    return pred_images\n",
+   "\n"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 15,
   "id": "aa33e9d104023847",
   "metadata": {
    "ExecuteTime": {
-    "end_time": "2024-12-
-    "start_time": "2024-12-
+    "end_time": "2024-12-10T02:56:13.419303Z",
+    "start_time": "2024-12-10T02:56:13.002796Z"
   }
  },
  "outputs": [

@@ -180,9 +210,10 @@
   "name": "stdout",
   "output_type": "stream",
   "text": [
-   "
-   "
-   "
+   "Running on local URL: http://127.0.0.1:7869\n",
+   "\n",
+   "Thanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB\n",
+   "Running on public URL: https://0fd0c028b349b76a72.gradio.live\n",
   "\n",
   "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
  ]

@@ -190,7 +221,7 @@
  {
   "data": {
    "text/html": [
-    "<div><iframe src=\"https://
+    "<div><iframe src=\"https://0fd0c028b349b76a72.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
   ],
   "text/plain": [
    "<IPython.core.display.HTML object>"

@@ -203,103 +234,135 @@
   "data": {
    "text/plain": []
   },
-  "execution_count":
+  "execution_count": 15,
   "metadata": {},
   "output_type": "execute_result"
- },
- {
-  "name": "stdout",
-  "output_type": "stream",
-  "text": [
-   "Train method: None\n",
-   "Rank: 1, Alpha: 1\n",
-   "create LoRA for U-Net: 0 modules.\n",
-   "save dir: None\n",
-   "['Park with cherry blossom trees, picnicker’s and a clear blue pond in the style of sks art'], seed=949192390\n"
-  ]
- },
- {
-  "name": "stderr",
-  "output_type": "stream",
-  "text": [
-   "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608883701/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
-   "  return F.conv2d(input, weight, bias, self.stride,\n",
-   "\n",
-   "00%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:03<00:00, 6.90it/s]"
-  ]
- },
- {
-  "name": "stdout",
-  "output_type": "stream",
-  "text": [
-   "Time taken for one batch, Art Adapter scale=1.0: 3.2747044563293457\n"
-  ]
  }
 ],
 "source": [
  "block = gr.Blocks()\n",
  "# Direct infer\n",
+ "# Direct infer\n",
  "with block:\n",
  "    with gr.Group():\n",
  "        gr.Markdown(\" # Art-Free Diffusion Demo\")\n",
+ "        gr.Markdown(\"(More features in development...)\")\n",
  "        with gr.Row():\n",
  "            text = gr.Textbox(\n",
- "                label=\"Enter your prompt
+ "                label=\"Enter your prompt(long and detailed would be better):\",\n",
  "                max_lines=2,\n",
- "                placeholder=\"Enter your prompt\",\n",
- "                container=
+ "                placeholder=\"Enter your prompt(long and detailed would be better)\",\n",
+ "                container=True,\n",
  "                value=\"Park with cherry blossom trees, picnicker’s and a clear blue pond.\",\n",
  "            )\n",
- "            \n",
  "\n",
- "
- "
- "
- "
- "
- "                elem_id=\"gallery\",\n",
- "                columns=[2],\n",
- "            )\n",
- "
- "        with gr.Row(elem_id=\"advanced-options\"):\n",
- "            adapter_choice = gr.Dropdown(\n",
- "                label=\"Choose adapter\",\n",
- "                choices=[\"None\", \"Andre Derain\",\"Vincent van Gogh\",\"Andy Warhol\", \"Walter Battiss\",\n",
- "                         \"Camille Corot\", \"Claude Monet\", \"Pablo Picasso\",\n",
- "                         \"Jackson Pollock\", \"Gerhard Richter\", \"M.C. Escher\",\n",
- "                         \"Albert Gleizes\", \"Hokusai\", \"Wassily Kandinsky\", \"Gustav Klimt\", \"Roy Lichtenstein\",\n",
- "                         \"Henri Matisse\", \"Joan Miro\"\n",
- "                ],\n",
- "                value=\"None\"\n",
- "            )\n",
- "            # print(adapter_choice[0])\n",
- "            # lora_path = lora_map[adapter_choice.value]\n",
- "            # if lora_path is not None:\n",
- "            #     lora_path = f\"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
+ "        with gr.Tab('Generation'):\n",
+ "            with gr.Row():\n",
+ "                with gr.Column():\n",
+ "                    # gr.Markdown(\"## Art-Free Generation\")\n",
+ "                    # gr.Markdown(\"Generate images from text prompts.\")\n",
+ "\n",
+ "                    gallery_gen_ori = gr.Image(\n",
+ "                        label=\"W/O Adapter\",\n",
+ "                        show_label=True,\n",
+ "                        elem_id=\"gallery\",\n",
+ "                        height=\"auto\"\n",
+ "                    )\n",
+ "\n",
+ "\n",
+ "                with gr.Column():\n",
+ "                    # gr.Markdown(\"## Art-Free Generation\")\n",
+ "                    # gr.Markdown(\"Generate images from text prompts.\")\n",
+ "                    gallery_gen_art = gr.Image(\n",
+ "                        label=\"W/ Adapter\",\n",
+ "                        show_label=True,\n",
+ "                        elem_id=\"gallery\",\n",
+ "                        height=\"auto\"\n",
+ "                    )\n",
+ "\n",
+ "\n",
+ "            with gr.Row():\n",
+ "                btn_gen_ori = gr.Button(\"Art-Free Generate\", scale=1)\n",
+ "                btn_gen_art = gr.Button(\"Artistic Generate\", scale=1)\n",
+ "\n",
+ "\n",
+ "        with gr.Tab('Stylization'):\n",
+ "            with gr.Row():\n",
+ "\n",
+ "                with gr.Column():\n",
+ "                    # gr.Markdown(\"## Art-Free Generation\")\n",
+ "                    # gr.Markdown(\"Generate images from text prompts.\")\n",
+ "\n",
+ "                    gallery_stylization_ref = gr.Image(\n",
+ "                        label=\"Ref Image\",\n",
+ "                        show_label=True,\n",
+ "                        elem_id=\"gallery\",\n",
+ "                        height=\"auto\",\n",
+ "                        scale=1,\n",
+ "                    )\n",
+ "                with gr.Column(scale=2):\n",
+ "                    with gr.Row():\n",
+ "                        with gr.Column():\n",
+ "                            # gr.Markdown(\"## Art-Free Generation\")\n",
+ "                            # gr.Markdown(\"Generate images from text prompts.\")\n",
+ "                            \n",
+ "                            gallery_stylization_ori = gr.Image(\n",
+ "                                label=\"W/O Adapter\",\n",
+ "                                show_label=True,\n",
+ "                                elem_id=\"gallery\",\n",
+ "                                height=\"auto\",\n",
+ "                                scale=1,\n",
+ "                            )\n",
+ "                            \n",
+ "                            \n",
+ "                        with gr.Column():\n",
+ "                            # gr.Markdown(\"## Art-Free Generation\")\n",
+ "                            # gr.Markdown(\"Generate images from text prompts.\")\n",
+ "                            gallery_stylization_art = gr.Image(\n",
+ "                                label=\"W/ Adapter\",\n",
+ "                                show_label=True,\n",
+ "                                elem_id=\"gallery\",\n",
+ "                                height=\"auto\",\n",
+ "                                scale=1,\n",
+ "                            )\n",
+ "            start_timestep = gr.Slider(label=\"Adapter Timestep\", minimum=0, maximum=1000, value=800, step=1)\n",
+ "            with gr.Row():\n",
+ "                btn_style_ori = gr.Button(\"Art-Free Stylization\", scale=1)\n",
+ "                btn_style_art = gr.Button(\"Artistic Stylization\", scale=1)\n",
  "\n",
- "
- "
+ "\n",
+ "        with gr.Row():\n",
+ "            # with gr.Column():\n",
+ "            #     samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=1, step=1, scale=1)\n",
  "            scale = gr.Slider(\n",
- "                label=\"Guidance Scale\", minimum=0, maximum=
+ "                label=\"Guidance Scale\", minimum=0, maximum=20, value=7.5, step=0.1\n",
  "            )\n",
- "
- "
- "                label=\"
- "
- "
- "
- "
+ "            # with gr.Column():\n",
+ "            adapter_choice = gr.Dropdown(\n",
+ "                label=\"Select Art Adapter\",\n",
+ "                choices=[ \"Andre Derain (fauvism)\",\"Vincent van Gogh (post impressionism)\",\"Andy Warhol (pop art)\",\n",
+ "                          \"Camille Corot (realism)\", \"Claude Monet (impressionism)\", \"Pablo Picasso (cubism)\", \"Gerhard Richter (abstract expressionism)\",\n",
+ "                          \"Hokusai (ukiyo-e)\", \"Gustav Klimt (art nouveau)\", \"Henri Matisse (abstract expressionism)\",\n",
+ "                          \"Walter Battiss\", \"Jackson Pollock\", \"M.C. Escher\", \"Albert Gleizes\", \"Wassily Kandinsky\",\n",
+ "                          \"Roy Lichtenstein\", \"Joan Miro\"\n",
+ "                ],\n",
+ "                value=\"Andre Derain (fauvism)\",\n",
+ "                scale=1\n",
  "            )\n",
  "\n",
- "            gr.
- "
- "
- "
- "
- "
+ "        with gr.Row():\n",
+ "            steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=20, step=1)\n",
+ "            adapter_scale = gr.Slider(label=\"Stylization Scale\", minimum=0, maximum=1.5, value=1., step=0.1, scale=1)\n",
+ "\n",
+ "        with gr.Row():\n",
+ "            seed = gr.Slider(label=\"Seed\",minimum=0,maximum=2147483647,step=1,randomize=True,scale=1)\n",
+ "\n",
+ "\n",
+ "    gr.on([btn_gen_ori.click], demo_inference_gen_ori, inputs=[text, seed, steps, scale], outputs=gallery_gen_ori)\n",
+ "    gr.on([btn_gen_art.click], demo_inference_gen_artistic, inputs=[adapter_choice, text, seed, steps, scale, adapter_scale], outputs=gallery_gen_art)\n",
  "\n",
+ "    gr.on([btn_style_ori.click], demo_inference_stylization_ori, inputs=[gallery_stylization_ref, text, seed, steps, scale, start_timestep], outputs=gallery_stylization_ori)\n",
+ "    gr.on([btn_style_art.click], demo_inference_stylization_artistic, inputs=[gallery_stylization_ref, adapter_choice, text, seed, steps, scale, adapter_scale, start_timestep], outputs=gallery_stylization_art)\n",
  "\n",
  "block.launch(share=True)"
 ]
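The new first cell pins a global device/dtype pair that later cells reuse for the pipeline, the LoRA network, and inference. One hedged caveat: torch.float16 as the CPU fallback is unusual, since many CPU ops have no half-precision kernels; a float32 fallback, sketched below as an assumption about intent rather than the committed code, is the safer default:

    import torch

    # Committed: dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
    # Safer sketch, assuming the goal is "bf16 on GPU, something that always runs on CPU":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32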
utils/train_util.py CHANGED

@@ -249,7 +249,8 @@ def get_noisy_image(
     image = img
     # im_orig = image
     device = vae.device
-
+    weight_dtype = vae.dtype
+    image = image_processor.preprocess(image).to(device).to(weight_dtype).to(weight_dtype)

     init_latents = vae.encode(image).latent_dist.sample(None)
     init_latents = vae.config.scaling_factor * init_latents
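The get_noisy_image fix moves the preprocessed image onto the VAE's device and dtype before encoding, which is what makes the encode succeed under mixed precision. The committed line casts to weight_dtype twice; a single combined .to(...) is equivalent. A self-contained sketch, assuming image_processor is a diffusers VaeImageProcessor as the preprocess call suggests:

    from diffusers.image_processor import VaeImageProcessor

    def encode_to_latents(img, vae, image_processor: VaeImageProcessor):
        # Mirrors the patched preamble of get_noisy_image, minus the duplicated cast.
        device, weight_dtype = vae.device, vae.dtype
        image = image_processor.preprocess(img).to(device=device, dtype=weight_dtype)
        init_latents = vae.encode(image).latent_dist.sample(None)
        return vae.config.scaling_factor * init_latents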