multimodalart HF staff Linoy Tsaban commited on
Commit
345d7b4
·
1 Parent(s): 45e73ca

Update pipeline_semantic_stable_diffusion_img2img_solver.py (#9)

Browse files

- Update pipeline_semantic_stable_diffusion_img2img_solver.py (4065064f7aab311c2f9705f66c0b0aa7669cfdac)
- Update app.py (24b22ad95d5e32ad4374733cf2b0acb5c0e13f26)


Co-authored-by: Linoy Tsaban <[email protected]>

app.py CHANGED
@@ -35,9 +35,9 @@ def caption_image(input_image):
35
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
36
  return generated_caption, generated_caption
37
 
38
- def sample(zs, wts, attention_store, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
39
  latents = wts[-1].expand(1, -1, -1, -1)
40
- img, attention_store = pipe(
41
  prompt=prompt_tar,
42
  init_latents=latents,
43
  guidance_scale=cfg_scale_tar,
@@ -45,10 +45,10 @@ def sample(zs, wts, attention_store, prompt_tar="", cfg_scale_tar=15, skip=36, e
45
  # num_inference_steps=steps,
46
  # use_ddpm=True,
47
  # wts=wts.value,
48
- attention_store = attention_store,
49
  zs=zs,
50
  )
51
- return img.images[0], attention_store
52
 
53
 
54
  def reconstruct(
@@ -59,6 +59,7 @@ def reconstruct(
59
  wts,
60
  zs,
61
  attention_store,
 
62
  do_reconstruction,
63
  reconstruction,
64
  reconstruct_button,
@@ -79,8 +80,8 @@ def reconstruct(
79
  ): # if image caption was not changed, run actual reconstruction
80
  tar_prompt = ""
81
  latents = wts[-1].expand(1, -1, -1, -1)
82
- reconstruction, attention_store = sample(
83
- zs, wts, attention_store=attention_store, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
84
  )
85
  do_reconstruction = False
86
  return (
@@ -130,7 +131,7 @@ def load_and_invert(
130
  ## SEGA ##
131
 
132
  def edit(input_image,
133
- wts, zs, attention_store,
134
  tar_prompt,
135
  image_caption,
136
  steps,
@@ -197,27 +198,27 @@ def edit(input_image,
197
  )
198
 
199
  latnets = wts[-1].expand(1, -1, -1, -1)
200
- sega_out, attention_store = pipe(prompt=tar_prompt,
201
  init_latents=latnets,
202
  guidance_scale = tar_cfg_scale,
203
  # num_images_per_prompt=1,
204
  # num_inference_steps=steps,
205
  # use_ddpm=True,
206
  # wts=wts.value,
207
- zs=zs, attention_store=attention_store, **editing_args)
208
 
209
- return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
210
 
211
 
212
  else: # if sega concepts were not added, performs regular ddpm sampling
213
 
214
  if do_reconstruction: # if ddpm sampling wasn't computed
215
- pure_ddpm_img, attention_store = sample(zs, wts, attention_store=attention_store, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
216
  reconstruction = pure_ddpm_img
217
  do_reconstruction = False
218
- return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
219
 
220
- return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
221
 
222
 
223
  def randomize_seed_fn(seed, is_random):
@@ -461,6 +462,7 @@ with gr.Blocks(css="style.css") as demo:
461
  wts = gr.State()
462
  zs = gr.State()
463
  attention_store=gr.State()
 
464
  reconstruction = gr.State()
465
  do_inversion = gr.State(value=True)
466
  do_reconstruction = gr.State(value=True)
@@ -697,6 +699,7 @@ with gr.Blocks(css="style.css") as demo:
697
  fn=edit,
698
  inputs=[input_image,
699
  wts, zs, attention_store,
 
700
  tar_prompt,
701
  image_caption,
702
  steps,
@@ -716,7 +719,7 @@ with gr.Blocks(css="style.css") as demo:
716
 
717
 
718
  ],
719
- outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs,attention_store, do_inversion, share_btn_container])
720
  # .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
721
 
722
 
 
35
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
36
  return generated_caption, generated_caption
37
 
38
+ def sample(zs, wts, attention_store, text_cross_attention_maps, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
39
  latents = wts[-1].expand(1, -1, -1, -1)
40
+ img, attention_store, text_cross_attention_maps = pipe(
41
  prompt=prompt_tar,
42
  init_latents=latents,
43
  guidance_scale=cfg_scale_tar,
 
45
  # num_inference_steps=steps,
46
  # use_ddpm=True,
47
  # wts=wts.value,
48
+ attention_store = attention_store, text_cross_attention_maps=text_cross_attention_maps,
49
  zs=zs,
50
  )
51
+ return img.images[0], attention_store, text_cross_attention_maps
52
 
53
 
54
  def reconstruct(
 
59
  wts,
60
  zs,
61
  attention_store,
62
+ text_cross_attention_maps,
63
  do_reconstruction,
64
  reconstruction,
65
  reconstruct_button,
 
80
  ): # if image caption was not changed, run actual reconstruction
81
  tar_prompt = ""
82
  latents = wts[-1].expand(1, -1, -1, -1)
83
+ reconstruction, attention_store, text_cross_attention_maps = sample(
84
+ zs, wts, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps,prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
85
  )
86
  do_reconstruction = False
87
  return (
 
131
  ## SEGA ##
132
 
133
  def edit(input_image,
134
+ wts, zs, attention_store, text_cross_attention_maps,
135
  tar_prompt,
136
  image_caption,
137
  steps,
 
198
  )
199
 
200
  latnets = wts[-1].expand(1, -1, -1, -1)
201
+ sega_out, attention_store, text_cross_attention_maps = pipe(prompt=tar_prompt,
202
  init_latents=latnets,
203
  guidance_scale = tar_cfg_scale,
204
  # num_images_per_prompt=1,
205
  # num_inference_steps=steps,
206
  # use_ddpm=True,
207
  # wts=wts.value,
208
+ zs=zs, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, **editing_args)
209
 
210
+ return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
211
 
212
 
213
  else: # if sega concepts were not added, performs regular ddpm sampling
214
 
215
  if do_reconstruction: # if ddpm sampling wasn't computed
216
+ pure_ddpm_img, attention_store, text_cross_attention_maps = sample(zs, wts, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
217
  reconstruction = pure_ddpm_img
218
  do_reconstruction = False
219
+ return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
220
 
221
+ return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
222
 
223
 
224
  def randomize_seed_fn(seed, is_random):
 
462
  wts = gr.State()
463
  zs = gr.State()
464
  attention_store=gr.State()
465
+ text_cross_attention_maps = gr.State()
466
  reconstruction = gr.State()
467
  do_inversion = gr.State(value=True)
468
  do_reconstruction = gr.State(value=True)
 
699
  fn=edit,
700
  inputs=[input_image,
701
  wts, zs, attention_store,
702
+ text_cross_attention_maps,
703
  tar_prompt,
704
  image_caption,
705
  steps,
 
719
 
720
 
721
  ],
722
+ outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs,attention_store, text_cross_attention_maps, do_inversion, share_btn_container])
723
  # .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
724
 
725
 
pipeline_semantic_stable_diffusion_img2img_solver.py CHANGED
@@ -500,6 +500,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
500
  use_cross_attn_mask: bool = False,
501
  # Attention store (just for visualization purposes)
502
  attention_store = None,
 
503
  attn_store_steps: Optional[List[int]] = [],
504
  store_averaged_over_steps: bool = True,
505
  use_intersect_mask: bool = False,
@@ -755,10 +756,10 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
755
  # For classifier free guidance, we need to do two forward passes.
756
  # Here we concatenate the unconditional and text embeddings into a single batch
757
  # to avoid doing two forward passes
758
- self.text_cross_attention_maps = [org_prompt] if isinstance(org_prompt, str) else org_prompt
759
  if enable_edit_guidance:
760
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
761
- self.text_cross_attention_maps += \
762
  ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
763
  else:
764
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
@@ -920,11 +921,11 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
920
  if use_cross_attn_mask:
921
  out = attention_store.aggregate_attention(
922
  attention_maps=attention_store.step_store,
923
- prompts=self.text_cross_attention_maps,
924
  res=16,
925
  from_where=["up", "down"],
926
  is_cross=True,
927
- select=self.text_cross_attention_maps.index(editing_prompt[c]),
928
  )
929
  attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext
930
 
@@ -1105,7 +1106,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1105
  if not return_dict:
1106
  return (image, has_nsfw_concept), attention_store
1107
 
1108
- return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), attention_store
1109
 
1110
  def encode_text(self, prompts):
1111
  text_inputs = self.tokenizer(
 
500
  use_cross_attn_mask: bool = False,
501
  # Attention store (just for visualization purposes)
502
  attention_store = None,
503
+ text_cross_attention_maps = None,
504
  attn_store_steps: Optional[List[int]] = [],
505
  store_averaged_over_steps: bool = True,
506
  use_intersect_mask: bool = False,
 
756
  # For classifier free guidance, we need to do two forward passes.
757
  # Here we concatenate the unconditional and text embeddings into a single batch
758
  # to avoid doing two forward passes
759
+ text_cross_attention_maps = [org_prompt] if isinstance(org_prompt, str) else org_prompt
760
  if enable_edit_guidance:
761
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
762
+ text_cross_attention_maps += \
763
  ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
764
  else:
765
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
921
  if use_cross_attn_mask:
922
  out = attention_store.aggregate_attention(
923
  attention_maps=attention_store.step_store,
924
+ prompts=text_cross_attention_maps,
925
  res=16,
926
  from_where=["up", "down"],
927
  is_cross=True,
928
+ select=text_cross_attention_maps.index(editing_prompt[c]),
929
  )
930
  attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext
931
 
 
1106
  if not return_dict:
1107
  return (image, has_nsfw_concept), attention_store
1108
 
1109
+ return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), attention_store, text_cross_attention_maps
1110
 
1111
  def encode_text(self, prompts):
1112
  text_inputs = self.tokenizer(