rhfeiyang committed on
Commit
fb8d464
1 Parent(s): d7b9e64
Files changed (3)
  1. hf_demo.py +162 -82
  2. hf_demo_test.ipynb +188 -125
  3. utils/train_util.py +2 -1
hf_demo.py CHANGED
@@ -16,70 +16,93 @@ pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
  from inference import get_lora_network, inference, get_validation_dataloader
  lora_map = {
      "None": "None",
-     "Andre Derain": "andre-derain_subset1",
-     "Vincent van Gogh": "van_gogh_subset1",
-     "Andy Warhol": "andy_subset1",
      "Walter Battiss": "walter-battiss_subset2",
-     "Camille Corot": "camille-corot_subset1",
-     "Claude Monet": "monet_subset2",
-     "Pablo Picasso": "picasso_subset1",
      "Jackson Pollock": "jackson-pollock_subset1",
-     "Gerhard Richter": "gerhard-richter_subset1",
      "M.C. Escher": "m.c.-escher_subset1",
      "Albert Gleizes": "albert-gleizes_subset1",
-     "Hokusai": "katsushika-hokusai_subset1",
      "Wassily Kandinsky": "kandinsky_subset1",
-     "Gustav Klimt": "klimt_subset3",
      "Roy Lichtenstein": "roy-lichtenstein_subset1",
-     "Henri Matisse": "henri-matisse_subset1",
      "Joan Miro": "joan-miro_subset2",
  }
  @spaces.GPU
- def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1, seed:int=0, steps=50, guidance_scale=7.5):
      adapter_path = lora_map[adapter_choice]
      if adapter_path not in [None, "None"]:
          adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
          style_prompt = "sks art"
      else:
          style_prompt = None
-     prompts = [prompt] * samples
      infer_loader = get_validation_dataloader(prompts, num_workers=0)
      network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]

      pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
-                             height=512, width=512, scales=[1.0],
                              save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,
                              start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
-                             from_scratch=True, device=device, weight_dtype=dtype)[0][1.0]
      return pred_images
  @spaces.GPU
- def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800, seed:int=0):
-     infer_loader = get_validation_dataloader(prompts, image, num_workers=0)
-     network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
      pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
-                             height=512, width=512, scales=[0., 1.],
-                             save_dir=None, seed=seed, steps=20, guidance_scale=7.5,
-                             start_noise=start_noise, show=True, style_prompt="sks art", no_load=True,
-                             from_scratch=False, device=device, weight_dtype=dtype)[0][1.0]
      return pred_images

- # def infer(prompt, samples, steps, scale, seed):
- #     generator = torch.Generator(device=device).manual_seed(seed)
- #     images_list = pipe(  # type: ignore
- #         [prompt] * samples,
- #         num_inference_steps=steps,
- #         guidance_scale=scale,
- #         generator=generator,
- #     )
- #     images = []
- #     safe_image = Image.open(r"data/unsafe.png")
- #     print(images_list)
- #     for i, image in enumerate(images_list["images"]):  # type: ignore
- #         if images_list["nsfw_content_detected"][i]:  # type: ignore
- #             images.append(safe_image)
- #         else:
- #             images.append(image)
- #     return images
@@ -92,62 +115,119 @@ with block:
          gr.Markdown("(More features in development...)")
          with gr.Row():
              text = gr.Textbox(
-                 label="Enter your prompt",
                  max_lines=2,
-                 placeholder="Enter your prompt",
-                 container=False,
                  value="Park with cherry blossom trees, picnickers and a clear blue pond.",
              )

-             btn = gr.Button("Run", scale=0)
-             gallery = gr.Gallery(
-                 label="Generated images",
-                 show_label=False,
-                 elem_id="gallery",
-                 columns=[1],
-             )
-
-             advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
-
-         with gr.Row(elem_id="advanced-options"):
              adapter_choice = gr.Dropdown(
                  label="Select Art Adapter",
-                 choices=["None", "Andre Derain", "Vincent van Gogh", "Andy Warhol", "Walter Battiss",
-                          "Camille Corot", "Claude Monet", "Pablo Picasso",
-                          "Jackson Pollock", "Gerhard Richter", "M.C. Escher",
-                          "Albert Gleizes", "Hokusai", "Wassily Kandinsky", "Gustav Klimt", "Roy Lichtenstein",
-                          "Henri Matisse", "Joan Miro"
-                 ],
-                 value="None"
              )
-             # print(adapter_choice[0])
-             # lora_path = lora_map[adapter_choice.value]
-             # if lora_path is not None:
-             #     lora_path = f"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"

-             samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
              steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
-             scale = gr.Slider(
-                 label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
-             )
-             print(scale)
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=2147483647,
-                 step=1,
-                 randomize=True,
-             )

-     gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)
-     advanced_button.click(
-         None,
-         [],
-         text,
-     )

- block.launch()
  from inference import get_lora_network, inference, get_validation_dataloader
  lora_map = {
      "None": "None",
+     "Andre Derain (fauvism)": "andre-derain_subset1",
+     "Vincent van Gogh (post impressionism)": "van_gogh_subset1",
+     "Andy Warhol (pop art)": "andy_subset1",
      "Walter Battiss": "walter-battiss_subset2",
+     "Camille Corot (realism)": "camille-corot_subset1",
+     "Claude Monet (impressionism)": "monet_subset2",
+     "Pablo Picasso (cubism)": "picasso_subset1",
      "Jackson Pollock": "jackson-pollock_subset1",
+     "Gerhard Richter (abstract expressionism)": "gerhard-richter_subset1",
      "M.C. Escher": "m.c.-escher_subset1",
      "Albert Gleizes": "albert-gleizes_subset1",
+     "Hokusai (ukiyo-e)": "katsushika-hokusai_subset1",
      "Wassily Kandinsky": "kandinsky_subset1",
+     "Gustav Klimt (art nouveau)": "klimt_subset3",
      "Roy Lichtenstein": "roy-lichtenstein_subset1",
+     "Henri Matisse (abstract expressionism)": "henri-matisse_subset1",
      "Joan Miro": "joan-miro_subset2",
  }
+
  @spaces.GPU
+ def demo_inference_gen_artistic(adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0):
      adapter_path = lora_map[adapter_choice]
      if adapter_path not in [None, "None"]:
          adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
          style_prompt = "sks art"
      else:
          style_prompt = None
+     prompts = [prompt]
      infer_loader = get_validation_dataloader(prompts, num_workers=0)
      network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]

      pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
+                             height=512, width=512, scales=[adapter_scale],
                              save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,
                              start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
+                             from_scratch=True, device=device, weight_dtype=dtype)[0][1.0][0]
      return pred_images
+
  @spaces.GPU
+ def demo_inference_gen_ori(prompt:str, seed:int=0, steps=50, guidance_scale=7.5):
+     style_prompt = None
+     prompts = [prompt]
+     infer_loader = get_validation_dataloader(prompts, num_workers=0)
+     network = get_lora_network(pipe.unet, "None", weight_dtype=dtype)["network"]
+
      pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
+                             height=512, width=512, scales=[0.0],
+                             save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,
+                             start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
+                             from_scratch=True, device=device, weight_dtype=dtype)[0][0.0][0]
      return pred_images

+
+ @spaces.GPU
+ def demo_inference_stylization_ori(ref_image, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, start_noise=800):
+     style_prompt = None
+     prompts = [prompt]
+     # convert np to pil
+     ref_image = [Image.fromarray(ref_image)]
+     network = get_lora_network(pipe.unet, "None", weight_dtype=dtype)["network"]
+     infer_loader = get_validation_dataloader(prompts, ref_image, num_workers=0)
+     pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
+                             height=512, width=512, scales=[0.0],
+                             save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,
+                             start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,
+                             from_scratch=False, device=device, weight_dtype=dtype)[0][0.0][0]
+     return pred_images
+
+ @spaces.GPU
+ def demo_inference_stylization_artistic(ref_image, adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0, start_noise=800):
+     adapter_path = lora_map[adapter_choice]
+     if adapter_path not in [None, "None"]:
+         adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
+         style_prompt = "sks art"
+     else:
+         style_prompt = None
+     prompts = [prompt]
+     # convert np to pil
+     ref_image = [Image.fromarray(ref_image)]
+     network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
+     infer_loader = get_validation_dataloader(prompts, ref_image, num_workers=0)
+     pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
+                             height=512, width=512, scales=[adapter_scale],
+                             save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,
+                             start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,
+                             from_scratch=False, device=device, weight_dtype=dtype)[0][1.0][0]
+     return pred_images

          gr.Markdown("(More features in development...)")
          with gr.Row():
              text = gr.Textbox(
+                 label="Enter your prompt (long and detailed would be better):",
                  max_lines=2,
+                 placeholder="Enter your prompt (long and detailed would be better)",
+                 container=True,
                  value="Park with cherry blossom trees, picnickers and a clear blue pond.",
              )

+         with gr.Tab('Generation'):
+             with gr.Row():
+                 with gr.Column():
+                     # gr.Markdown("## Art-Free Generation")
+                     # gr.Markdown("Generate images from text prompts.")
+                     gallery_gen_ori = gr.Image(
+                         label="W/O Adapter",
+                         show_label=True,
+                         elem_id="gallery",
+                         height="auto"
+                     )
+
+                 with gr.Column():
+                     gallery_gen_art = gr.Image(
+                         label="W/ Adapter",
+                         show_label=True,
+                         elem_id="gallery",
+                         height="auto"
+                     )
+
+             with gr.Row():
+                 btn_gen_ori = gr.Button("Art-Free Generate", scale=1)
+                 btn_gen_art = gr.Button("Artistic Generate", scale=1)
+
+         with gr.Tab('Stylization'):
+             with gr.Row():
+                 with gr.Column():
+                     gallery_stylization_ref = gr.Image(
+                         label="Ref Image",
+                         show_label=True,
+                         elem_id="gallery",
+                         height="auto",
+                         scale=1,
+                     )
+                 with gr.Column(scale=2):
+                     with gr.Row():
+                         with gr.Column():
+                             gallery_stylization_ori = gr.Image(
+                                 label="W/O Adapter",
+                                 show_label=True,
+                                 elem_id="gallery",
+                                 height="auto",
+                                 scale=1,
+                             )
+                         with gr.Column():
+                             gallery_stylization_art = gr.Image(
+                                 label="W/ Adapter",
+                                 show_label=True,
+                                 elem_id="gallery",
+                                 height="auto",
+                                 scale=1,
+                             )
+                     start_timestep = gr.Slider(label="Adapter Timestep", minimum=0, maximum=1000, value=800, step=1)
+             with gr.Row():
+                 btn_style_ori = gr.Button("Art-Free Stylization", scale=1)
+                 btn_style_art = gr.Button("Artistic Stylization", scale=1)

+         with gr.Row():
+             # samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1, scale=1)
+             scale = gr.Slider(
+                 label="Guidance Scale", minimum=0, maximum=20, value=7.5, step=0.1
+             )
              adapter_choice = gr.Dropdown(
                  label="Select Art Adapter",
+                 choices=["Andre Derain (fauvism)", "Vincent van Gogh (post impressionism)", "Andy Warhol (pop art)",
+                          "Camille Corot (realism)", "Claude Monet (impressionism)", "Pablo Picasso (cubism)", "Gerhard Richter (abstract expressionism)",
+                          "Hokusai (ukiyo-e)", "Gustav Klimt (art nouveau)", "Henri Matisse (abstract expressionism)",
+                          "Walter Battiss", "Jackson Pollock", "M.C. Escher", "Albert Gleizes", "Wassily Kandinsky",
+                          "Roy Lichtenstein", "Joan Miro"
+                 ],
+                 value="Andre Derain (fauvism)",
+                 scale=1
              )

+         with gr.Row():
              steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
+             adapter_scale = gr.Slider(label="Stylization Scale", minimum=0, maximum=1.5, value=1., step=0.1, scale=1)
+
+         with gr.Row():
+             seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True, scale=1)

+     gr.on([btn_gen_ori.click], demo_inference_gen_ori, inputs=[text, seed, steps, scale], outputs=gallery_gen_ori)
+     gr.on([btn_gen_art.click], demo_inference_gen_artistic, inputs=[adapter_choice, text, seed, steps, scale, adapter_scale], outputs=gallery_gen_art)

+     gr.on([btn_style_ori.click], demo_inference_stylization_ori, inputs=[gallery_stylization_ref, text, seed, steps, scale, start_timestep], outputs=gallery_stylization_ori)
+     gr.on([btn_style_art.click], demo_inference_stylization_artistic, inputs=[gallery_stylization_ref, adapter_choice, text, seed, steps, scale, adapter_scale, start_timestep], outputs=gallery_stylization_art)
+ block.launch(share=True)
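
For reference, the new generation entry points can also be exercised without the Gradio UI. A minimal sketch (assuming the same setup as the top of hf_demo.py — pipe, device, and dtype initialized, and the data/Art_adapters checkpoints present; the output filename is illustrative):

    prompt = "Park with cherry blossom trees, picnickers and a clear blue pond."

    # Art-free baseline: adapter weights scaled to 0.0.
    plain = demo_inference_gen_ori(prompt, seed=0, steps=20, guidance_scale=7.5)

    # Same prompt through the Monet adapter at full strength (scales=[adapter_scale]).
    styled = demo_inference_gen_artistic("Claude Monet (impressionism)", prompt,
                                         seed=0, steps=20, guidance_scale=7.5,
                                         adapter_scale=1.0)

    # The trailing [0][scale][0] indexing in both functions unwraps a single image,
    # so each call should return one PIL image.
    styled.save("monet_park.png")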
hf_demo_test.ipynb CHANGED
@@ -45,7 +45,9 @@
    },
    "outputs": [],
    "source": [
-    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
    ]
   },
   {
@@ -70,7 +72,7 @@
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
-      "model_id": "9df8347307674ba8afb0250e23109aa1",
       "version_major": 2,
       "version_minor": 0
      },
@@ -83,8 +85,8 @@
     }
    ],
    "source": [
-    "pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\",).to(\"cuda\")\n",
-    "device = \"cuda\""
    ]
   },
   {
@@ -102,77 +104,105 @@
    "from inference import get_lora_network, inference, get_validation_dataloader\n",
    "lora_map = {\n",
    "    \"None\": \"None\",\n",
-   "    \"Andre Derain\": \"andre-derain_subset1\",\n",
-   "    \"Vincent van Gogh\": \"van_gogh_subset1\",\n",
-   "    \"Andy Warhol\": \"andy_subset1\",\n",
    "    \"Walter Battiss\": \"walter-battiss_subset2\",\n",
-   "    \"Camille Corot\": \"camille-corot_subset1\",\n",
-   "    \"Claude Monet\": \"monet_subset2\",\n",
-   "    \"Pablo Picasso\": \"picasso_subset1\",\n",
    "    \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
-   "    \"Gerhard Richter\": \"gerhard-richter_subset1\",\n",
    "    \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
    "    \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
-   "    \"Hokusai\": \"katsushika-hokusai_subset1\",\n",
    "    \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
-   "    \"Gustav Klimt\": \"klimt_subset3\",\n",
    "    \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
-   "    \"Henri Matisse\": \"henri-matisse_subset1\",\n",
    "    \"Joan Miro\": \"joan-miro_subset2\",\n",
    "}\n",
    "\n",
-   "def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1, seed:int=0, steps=50, guidance_scale=7.5):\n",
    "    adapter_path = lora_map[adapter_choice]\n",
    "    if adapter_path not in [None, \"None\"]:\n",
    "        adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
    "\n",
-   "    prompts = [prompt] * samples\n",
-   "    infer_loader = get_validation_dataloader(prompts)\n",
-   "    network = get_lora_network(pipe.unet, adapter_path)[\"network\"]\n",
    "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
-   "                            height=512, width=512, scales=[1.0],\n",
    "                            save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,\n",
-   "                            start_noise=-1, show=False, style_prompt=\"sks art\", no_load=True,\n",
-   "                            from_scratch=True)[0][1.0]\n",
    "    return pred_images\n",
    "\n",
-   "def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800, seed:int=0):\n",
-   "    infer_loader = get_validation_dataloader(prompts, image)\n",
-   "    network = get_lora_network(pipe.unet, adapter_path, \"all_up\")[\"network\"]\n",
    "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
-   "                            height=512, width=512, scales=[0., 1.],\n",
-   "                            save_dir=None, seed=seed, steps=20, guidance_scale=7.5,\n",
-   "                            start_noise=start_noise, show=True, style_prompt=\"sks art\", no_load=True,\n",
-   "                            from_scratch=False)\n",
    "    return pred_images\n",
    "\n",
-   "# def infer(prompt, samples, steps, scale, seed):\n",
-   "#     generator = torch.Generator(device=device).manual_seed(seed)\n",
-   "#     images_list = pipe(  # type: ignore\n",
-   "#         [prompt] * samples,\n",
-   "#         num_inference_steps=steps,\n",
-   "#         guidance_scale=scale,\n",
-   "#         generator=generator,\n",
-   "#     )\n",
-   "#     images = []\n",
-   "#     safe_image = Image.open(r\"data/unsafe.png\")\n",
-   "#     print(images_list)\n",
-   "#     for i, image in enumerate(images_list[\"images\"]):  # type: ignore\n",
-   "#         if images_list[\"nsfw_content_detected\"][i]:  # type: ignore\n",
-   "#             images.append(safe_image)\n",
-   "#         else:\n",
-   "#             images.append(image)\n",
-   "#     return images\n"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 6,
   "id": "aa33e9d104023847",
   "metadata": {
    "ExecuteTime": {
-    "end_time": "2024-12-09T12:09:39.339583Z",
-    "start_time": "2024-12-09T12:09:38.953936Z"
    }
   },
   "outputs": [
@@ -180,9 +210,10 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "<gradio.components.slider.Slider object at 0x7fa12d3a5280>\n",
-     "Running on local URL: http://127.0.0.1:7876\n",
-     "Running on public URL: https://be7cce8fec75395c82.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
     ]
@@ -190,7 +221,7 @@
    {
     "data": {
      "text/html": [
-      "<div><iframe src=\"https://be7cce8fec75395c82.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
     ],
     "text/plain": [
      "<IPython.core.display.HTML object>"
@@ -203,103 +234,135 @@
    "data": {
     "text/plain": []
    },
-   "execution_count": 6,
    "metadata": {},
    "output_type": "execute_result"
-  },
-  {
-   "name": "stdout",
-   "output_type": "stream",
-   "text": [
-    "Train method: None\n",
-    "Rank: 1, Alpha: 1\n",
-    "create LoRA for U-Net: 0 modules.\n",
-    "save dir: None\n",
-    "['Park with cherry blossom trees, picnicker’s and a clear blue pond in the style of sks art'], seed=949192390\n"
-   ]
-  },
-  {
-   "name": "stderr",
-   "output_type": "stream",
-   "text": [
-    "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608883701/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
-    "  return F.conv2d(input, weight, bias, self.stride,\n",
-    "\n",
-    "100%|████████████████████████████████████████| 21/21 [00:03<00:00, 6.90it/s]"
-   ]
-  },
-  {
-   "name": "stdout",
-   "output_type": "stream",
-   "text": [
-    "Time taken for one batch, Art Adapter scale=1.0: 3.2747044563293457\n"
-   ]
   }
  ],
 "source": [
  "block = gr.Blocks()\n",
  "# Direct infer\n",
  "with block:\n",
  "    with gr.Group():\n",
  "        gr.Markdown(\" # Art-Free Diffusion Demo\")\n",
  "        with gr.Row():\n",
  "            text = gr.Textbox(\n",
- "                label=\"Enter your prompt\",\n",
  "                max_lines=2,\n",
- "                placeholder=\"Enter your prompt\",\n",
- "                container=False,\n",
  "                value=\"Park with cherry blossom trees, picnickers and a clear blue pond.\",\n",
  "            )\n",
  "\n",
- "            btn = gr.Button(\"Run\", scale=0)\n",
- "            gallery = gr.Gallery(\n",
- "                label=\"Generated images\",\n",
- "                show_label=False,\n",
- "                elem_id=\"gallery\",\n",
- "                columns=[2],\n",
- "            )\n",
  "\n",
- "            advanced_button = gr.Button(\"Advanced options\", elem_id=\"advanced-btn\")\n",
  "\n",
- "        with gr.Row(elem_id=\"advanced-options\"):\n",
- "            adapter_choice = gr.Dropdown(\n",
- "                label=\"Choose adapter\",\n",
- "                choices=[\"None\", \"Andre Derain\", \"Vincent van Gogh\", \"Andy Warhol\", \"Walter Battiss\",\n",
- "                         \"Camille Corot\", \"Claude Monet\", \"Pablo Picasso\",\n",
- "                         \"Jackson Pollock\", \"Gerhard Richter\", \"M.C. Escher\",\n",
- "                         \"Albert Gleizes\", \"Hokusai\", \"Wassily Kandinsky\", \"Gustav Klimt\", \"Roy Lichtenstein\",\n",
- "                         \"Henri Matisse\", \"Joan Miro\"\n",
- "                ],\n",
- "                value=\"None\"\n",
- "            )\n",
- "            # print(adapter_choice[0])\n",
- "            # lora_path = lora_map[adapter_choice.value]\n",
- "            # if lora_path is not None:\n",
- "            #     lora_path = f\"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
  "\n",
- "            samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=1, step=1)\n",
- "            steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=20, step=1)\n",
  "        scale = gr.Slider(\n",
- "            label=\"Guidance Scale\", minimum=0, maximum=50, value=7.5, step=0.1\n",
  "        )\n",
- "        print(scale)\n",
- "        seed = gr.Slider(\n",
- "            label=\"Seed\",\n",
- "            minimum=0,\n",
- "            maximum=2147483647,\n",
- "            step=1,\n",
- "            randomize=True,\n",
  "        )\n",
  "\n",
- "    gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)\n",
- "    advanced_button.click(\n",
- "        None,\n",
- "        [],\n",
- "        text,\n",
- "    )\n",
  "\n",
  "\n",
  "block.launch(share=True)"
 ]
   },
   "outputs": [],
   "source": [
+   "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
+   "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+   "dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16"
   ]
  },
  {
   {
    "data": {
     "application/vnd.jupyter.widget-view+json": {
+     "model_id": "acc42f294243439798e4d77d1a59296d",
      "version_major": 2,
      "version_minor": 0
     },
    }
   ],
   "source": [
+   "pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\",\n",
+   "                                         torch_dtype=dtype).to(device)"
   ]
  },
  {
   "from inference import get_lora_network, inference, get_validation_dataloader\n",
   "lora_map = {\n",
   "    \"None\": \"None\",\n",
+  "    \"Andre Derain (fauvism)\": \"andre-derain_subset1\",\n",
+  "    \"Vincent van Gogh (post impressionism)\": \"van_gogh_subset1\",\n",
+  "    \"Andy Warhol (pop art)\": \"andy_subset1\",\n",
   "    \"Walter Battiss\": \"walter-battiss_subset2\",\n",
+  "    \"Camille Corot (realism)\": \"camille-corot_subset1\",\n",
+  "    \"Claude Monet (impressionism)\": \"monet_subset2\",\n",
+  "    \"Pablo Picasso (cubism)\": \"picasso_subset1\",\n",
   "    \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
+  "    \"Gerhard Richter (abstract expressionism)\": \"gerhard-richter_subset1\",\n",
   "    \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
   "    \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
+  "    \"Hokusai (ukiyo-e)\": \"katsushika-hokusai_subset1\",\n",
   "    \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
+  "    \"Gustav Klimt (art nouveau)\": \"klimt_subset3\",\n",
   "    \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
+  "    \"Henri Matisse (abstract expressionism)\": \"henri-matisse_subset1\",\n",
   "    \"Joan Miro\": \"joan-miro_subset2\",\n",
   "}\n",
   "\n",
+  "\n",
+  "def demo_inference_gen_artistic(adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0):\n",
   "    adapter_path = lora_map[adapter_choice]\n",
   "    if adapter_path not in [None, \"None\"]:\n",
   "        adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
+  "        style_prompt = \"sks art\"\n",
+  "    else:\n",
+  "        style_prompt = None\n",
+  "    prompts = [prompt]\n",
+  "    infer_loader = get_validation_dataloader(prompts, num_workers=0)\n",
+  "    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)[\"network\"]\n",
   "\n",
   "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
+  "                            height=512, width=512, scales=[adapter_scale],\n",
   "                            save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,\n",
+  "                            start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,\n",
+  "                            from_scratch=True, device=device, weight_dtype=dtype)[0][1.0][0]\n",
   "    return pred_images\n",
   "\n",
+  "\n",
+  "def demo_inference_gen_ori(prompt:str, seed:int=0, steps=50, guidance_scale=7.5):\n",
+  "    style_prompt = None\n",
+  "    prompts = [prompt]\n",
+  "    infer_loader = get_validation_dataloader(prompts, num_workers=0)\n",
+  "    network = get_lora_network(pipe.unet, \"None\", weight_dtype=dtype)[\"network\"]\n",
+  "\n",
   "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
+  "                            height=512, width=512, scales=[0.0],\n",
+  "                            save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,\n",
+  "                            start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,\n",
+  "                            from_scratch=True, device=device, weight_dtype=dtype)[0][0.0][0]\n",
   "    return pred_images\n",
   "\n",
+  "\n",
+  "def demo_inference_stylization_ori(ref_image, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, start_noise=800):\n",
+  "    style_prompt = None\n",
+  "    prompts = [prompt]\n",
+  "    # convert np to pil\n",
+  "    ref_image = [Image.fromarray(ref_image)]\n",
+  "    network = get_lora_network(pipe.unet, \"None\", weight_dtype=dtype)[\"network\"]\n",
+  "    infer_loader = get_validation_dataloader(prompts, ref_image, num_workers=0)\n",
+  "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
+  "                            height=512, width=512, scales=[0.0],\n",
+  "                            save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,\n",
+  "                            start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,\n",
+  "                            from_scratch=False, device=device, weight_dtype=dtype)[0][0.0][0]\n",
+  "    return pred_images\n",
+  "\n",
+  "\n",
+  "def demo_inference_stylization_artistic(ref_image, adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0, start_noise=800):\n",
+  "    adapter_path = lora_map[adapter_choice]\n",
+  "    if adapter_path not in [None, \"None\"]:\n",
+  "        adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
+  "        style_prompt = \"sks art\"\n",
+  "    else:\n",
+  "        style_prompt = None\n",
+  "    prompts = [prompt]\n",
+  "    # convert np to pil\n",
+  "    ref_image = [Image.fromarray(ref_image)]\n",
+  "    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)[\"network\"]\n",
+  "    infer_loader = get_validation_dataloader(prompts, ref_image, num_workers=0)\n",
+  "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
+  "                            height=512, width=512, scales=[adapter_scale],\n",
+  "                            save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,\n",
+  "                            start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,\n",
+  "                            from_scratch=False, device=device, weight_dtype=dtype)[0][1.0][0]\n",
+  "    return pred_images\n",
+  "\n"
  ]
 },
 {
  "cell_type": "code",
+ "execution_count": 15,
  "id": "aa33e9d104023847",
  "metadata": {
   "ExecuteTime": {
+   "end_time": "2024-12-10T02:56:13.419303Z",
+   "start_time": "2024-12-10T02:56:13.002796Z"
   }
  },
  "outputs": [
    "name": "stdout",
    "output_type": "stream",
    "text": [
+    "Running on local URL: http://127.0.0.1:7869\n",
+    "\n",
+    "Thanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB\n",
+    "Running on public URL: https://0fd0c028b349b76a72.gradio.live\n",
     "\n",
     "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
    ]
   {
    "data": {
     "text/html": [
+     "<div><iframe src=\"https://0fd0c028b349b76a72.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
    ],
    "text/plain": [
     "<IPython.core.display.HTML object>"
   "data": {
    "text/plain": []
   },
+  "execution_count": 15,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "block = gr.Blocks()\n",
  "# Direct infer\n",
  "with block:\n",
  "    with gr.Group():\n",
  "        gr.Markdown(\" # Art-Free Diffusion Demo\")\n",
+ "        gr.Markdown(\"(More features in development...)\")\n",
  "        with gr.Row():\n",
  "            text = gr.Textbox(\n",
+ "                label=\"Enter your prompt (long and detailed would be better):\",\n",
  "                max_lines=2,\n",
+ "                placeholder=\"Enter your prompt (long and detailed would be better)\",\n",
+ "                container=True,\n",
  "                value=\"Park with cherry blossom trees, picnickers and a clear blue pond.\",\n",
  "            )\n",
  "\n",
+ "        with gr.Tab('Generation'):\n",
+ "            with gr.Row():\n",
+ "                with gr.Column():\n",
+ "                    # gr.Markdown(\"## Art-Free Generation\")\n",
+ "                    # gr.Markdown(\"Generate images from text prompts.\")\n",
+ "                    gallery_gen_ori = gr.Image(\n",
+ "                        label=\"W/O Adapter\",\n",
+ "                        show_label=True,\n",
+ "                        elem_id=\"gallery\",\n",
+ "                        height=\"auto\"\n",
+ "                    )\n",
  "\n",
+ "                with gr.Column():\n",
+ "                    gallery_gen_art = gr.Image(\n",
+ "                        label=\"W/ Adapter\",\n",
+ "                        show_label=True,\n",
+ "                        elem_id=\"gallery\",\n",
+ "                        height=\"auto\"\n",
+ "                    )\n",
+ "\n",
+ "            with gr.Row():\n",
+ "                btn_gen_ori = gr.Button(\"Art-Free Generate\", scale=1)\n",
+ "                btn_gen_art = gr.Button(\"Artistic Generate\", scale=1)\n",
+ "\n",
+ "        with gr.Tab('Stylization'):\n",
+ "            with gr.Row():\n",
+ "                with gr.Column():\n",
+ "                    gallery_stylization_ref = gr.Image(\n",
+ "                        label=\"Ref Image\",\n",
+ "                        show_label=True,\n",
+ "                        elem_id=\"gallery\",\n",
+ "                        height=\"auto\",\n",
+ "                        scale=1,\n",
+ "                    )\n",
+ "                with gr.Column(scale=2):\n",
+ "                    with gr.Row():\n",
+ "                        with gr.Column():\n",
+ "                            gallery_stylization_ori = gr.Image(\n",
+ "                                label=\"W/O Adapter\",\n",
+ "                                show_label=True,\n",
+ "                                elem_id=\"gallery\",\n",
+ "                                height=\"auto\",\n",
+ "                                scale=1,\n",
+ "                            )\n",
+ "                        with gr.Column():\n",
+ "                            gallery_stylization_art = gr.Image(\n",
+ "                                label=\"W/ Adapter\",\n",
+ "                                show_label=True,\n",
+ "                                elem_id=\"gallery\",\n",
+ "                                height=\"auto\",\n",
+ "                                scale=1,\n",
+ "                            )\n",
+ "                    start_timestep = gr.Slider(label=\"Adapter Timestep\", minimum=0, maximum=1000, value=800, step=1)\n",
+ "            with gr.Row():\n",
+ "                btn_style_ori = gr.Button(\"Art-Free Stylization\", scale=1)\n",
+ "                btn_style_art = gr.Button(\"Artistic Stylization\", scale=1)\n",
+ "\n",
+ "        with gr.Row():\n",
+ "            # samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=1, step=1, scale=1)\n",
  "        scale = gr.Slider(\n",
+ "            label=\"Guidance Scale\", minimum=0, maximum=20, value=7.5, step=0.1\n",
  "        )\n",
+ "            adapter_choice = gr.Dropdown(\n",
+ "                label=\"Select Art Adapter\",\n",
+ "                choices=[\"Andre Derain (fauvism)\", \"Vincent van Gogh (post impressionism)\", \"Andy Warhol (pop art)\",\n",
+ "                         \"Camille Corot (realism)\", \"Claude Monet (impressionism)\", \"Pablo Picasso (cubism)\", \"Gerhard Richter (abstract expressionism)\",\n",
+ "                         \"Hokusai (ukiyo-e)\", \"Gustav Klimt (art nouveau)\", \"Henri Matisse (abstract expressionism)\",\n",
+ "                         \"Walter Battiss\", \"Jackson Pollock\", \"M.C. Escher\", \"Albert Gleizes\", \"Wassily Kandinsky\",\n",
+ "                         \"Roy Lichtenstein\", \"Joan Miro\"\n",
+ "                ],\n",
+ "                value=\"Andre Derain (fauvism)\",\n",
+ "                scale=1\n",
  "        )\n",
  "\n",
+ "        with gr.Row():\n",
+ "            steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=20, step=1)\n",
+ "            adapter_scale = gr.Slider(label=\"Stylization Scale\", minimum=0, maximum=1.5, value=1., step=0.1, scale=1)\n",
+ "\n",
+ "        with gr.Row():\n",
+ "            seed = gr.Slider(label=\"Seed\", minimum=0, maximum=2147483647, step=1, randomize=True, scale=1)\n",
  "\n",
+ "    gr.on([btn_gen_ori.click], demo_inference_gen_ori, inputs=[text, seed, steps, scale], outputs=gallery_gen_ori)\n",
+ "    gr.on([btn_gen_art.click], demo_inference_gen_artistic, inputs=[adapter_choice, text, seed, steps, scale, adapter_scale], outputs=gallery_gen_art)\n",
  "\n",
+ "    gr.on([btn_style_ori.click], demo_inference_stylization_ori, inputs=[gallery_stylization_ref, text, seed, steps, scale, start_timestep], outputs=gallery_stylization_ori)\n",
+ "    gr.on([btn_style_art.click], demo_inference_stylization_artistic, inputs=[gallery_stylization_ref, adapter_choice, text, seed, steps, scale, adapter_scale, start_timestep], outputs=gallery_stylization_art)\n",
  "\n",
  "block.launch(share=True)"
 ]
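
One caveat on the notebook's new setup cell: it falls back to float16 when CUDA is unavailable, but PyTorch's float16 coverage on CPU is patchy. A more defensive variant (a suggestion, not what the commit ships) would be:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 halves memory on GPU; float32 avoids missing fp16 kernels on CPU.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32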
utils/train_util.py CHANGED
@@ -249,7 +249,8 @@ def get_noisy_image(
      image = img
      # im_orig = image
      device = vae.device
-     image = image_processor.preprocess(image).to(device)
+     weight_dtype = vae.dtype
+     image = image_processor.preprocess(image).to(device).to(weight_dtype)

      init_latents = vae.encode(image).latent_dist.sample(None)
      init_latents = vae.config.scaling_factor * init_latents
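
The train_util.py change pins the preprocessed image tensor to the VAE's dtype so that a bfloat16/float16 VAE never receives float32 input. A self-contained sketch of the same pattern (the model id and subfolder layout here are assumptions for illustration):

    import torch
    from PIL import Image
    from diffusers import AutoencoderKL
    from diffusers.image_processor import VaeImageProcessor

    vae = AutoencoderKL.from_pretrained("rhfeiyang/art-free-diffusion-v1",
                                        subfolder="vae", torch_dtype=torch.bfloat16)
    image_processor = VaeImageProcessor()

    def encode_image(img: Image.Image) -> torch.Tensor:
        # Match the VAE's device and dtype in one call; a float32 tensor fed to a
        # bf16 VAE would raise a dtype-mismatch error in the conv layers.
        pixels = image_processor.preprocess(img).to(vae.device, vae.dtype)
        latents = vae.encode(pixels).latent_dist.sample()
        return vae.config.scaling_factor * latents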