Tony Lian commited on
Commit
61ac46b
·
1 Parent(s): 7b0a7ad

Add batched single object generation

Browse files
Files changed (5) hide show
  1. app.py +4 -4
  2. generation.py +53 -24
  3. models/models.py +19 -17
  4. models/pipelines.py +47 -29
  5. models/sam.py +50 -29
app.py CHANGED
@@ -109,7 +109,7 @@ def get_ours_image(response, seed, num_inference_steps=20, dpm_scheduler=True, u
109
  spec, bg_seed=seed, overall_prompt_override=overall_prompt_override, fg_seed_start=fg_seed_start,
110
  fg_blending_ratio=fg_blending_ratio,frozen_step_ratio=frozen_step_ratio, use_autocast=use_autocast,
111
  gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key,
112
- so_negative_prompt=so_negative_prompt, overall_negative_prompt=overall_negative_prompt
113
  )
114
  images = [image_np]
115
  if show_so_imgs:
@@ -201,7 +201,7 @@ html = f"""<h1>LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to
201
  <p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
202
  <p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the object boxes bigger).</p>
203
  <p>3. You can also try prompts in Simplified Chinese. If you want to try prompts in another language, translate the first line of last example to your language.</p>
204
- <p>4. The diffusion model only runs 20 steps by default. You can make it run 50 steps to get higher quality images (or tweak frozen steps/guidance steps for better guidance and coherence).</p>
205
  <p>5. Duplicate this space and add GPU or clone the space and run locally to skip the queue and run our model faster. (Currently we are using a T4, and you can add a A10G to make it 5x faster) {duplicate_html}</p>
206
  <br/>
207
  <p>Implementation note: In this demo, we replace the attention manipulation in our layout-guided Stable Diffusion described in our paper with GLIGEN due to much faster inference speed (<b>FlashAttention supported, no backprop needed</b> during inference). Compared to vanilla GLIGEN, we have better coherence. Other parts of text-to-image pipeline, including single object generation and SAM, remain the same. The settings and examples in the prompt are simplified in this demo.</p>
@@ -237,12 +237,12 @@ with gr.Blocks(
237
  with gr.Column(scale=1):
238
  response = gr.Textbox(lines=8, label="Paste ChatGPT response here (no original caption needed)", placeholder=layout_placeholder)
239
  overall_prompt_override = gr.Textbox(lines=2, label="Prompt for overall generation (optional but recommended)", placeholder="You can put your input prompt for layout generation here, helpful if your scene cannot be represented by background prompt and boxes only, e.g., with object interactions. If left empty: background prompt with [objects].", value="")
 
240
  seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
241
  with gr.Accordion("Advanced options (play around for better generation)", open=False):
242
  frozen_step_ratio = gr.Slider(0, 1, value=0.4, step=0.1, label="Foreground frozen steps ratio (higher: preserve object attributes; lower: higher coherence; set to 0: (almost) equivalent to vanilla GLIGEN except details)")
243
  gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.3, step=0.1, label="GLIGEN guidance steps ratio (the beta value)")
244
- num_inference_steps = gr.Slider(1, 50, value=20, step=1, label="Number of inference steps")
245
- dpm_scheduler = gr.Checkbox(label="Use DPM scheduler (unchecked: DDIM scheduler, may have better coherence, recommend 50 or even more inference steps)", show_label=False, value=True)
246
  use_autocast = gr.Checkbox(label="Use FP16 Mixed Precision", show_label=False, value=True)
247
  fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
248
  fg_blending_ratio = gr.Slider(0, 1, value=0.1, step=0.01, label="Variations added to foreground for single object generation (0: no variation, 1: max variation)")
 
109
  spec, bg_seed=seed, overall_prompt_override=overall_prompt_override, fg_seed_start=fg_seed_start,
110
  fg_blending_ratio=fg_blending_ratio,frozen_step_ratio=frozen_step_ratio, use_autocast=use_autocast,
111
  gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key,
112
+ so_negative_prompt=so_negative_prompt, overall_negative_prompt=overall_negative_prompt, so_batch_size=8
113
  )
114
  images = [image_np]
115
  if show_so_imgs:
 
201
  <p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
202
  <p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the object boxes bigger).</p>
203
  <p>3. You can also try prompts in Simplified Chinese. If you want to try prompts in another language, translate the first line of last example to your language.</p>
204
+ <p>4. The diffusion model only runs 50 steps by default in this demo. You can make it run more/fewer steps to get higher quality images or faster generation (or tweak frozen steps/guidance steps for better guidance and coherence).</p>
205
  <p>5. Duplicate this space and add GPU or clone the space and run locally to skip the queue and run our model faster. (Currently we are using a T4, and you can add a A10G to make it 5x faster) {duplicate_html}</p>
206
  <br/>
207
  <p>Implementation note: In this demo, we replace the attention manipulation in our layout-guided Stable Diffusion described in our paper with GLIGEN due to much faster inference speed (<b>FlashAttention supported, no backprop needed</b> during inference). Compared to vanilla GLIGEN, we have better coherence. Other parts of text-to-image pipeline, including single object generation and SAM, remain the same. The settings and examples in the prompt are simplified in this demo.</p>
 
237
  with gr.Column(scale=1):
238
  response = gr.Textbox(lines=8, label="Paste ChatGPT response here (no original caption needed)", placeholder=layout_placeholder)
239
  overall_prompt_override = gr.Textbox(lines=2, label="Prompt for overall generation (optional but recommended)", placeholder="You can put your input prompt for layout generation here, helpful if your scene cannot be represented by background prompt and boxes only, e.g., with object interactions. If left empty: background prompt with [objects].", value="")
240
+ num_inference_steps = gr.Slider(1, 250, value=50, step=1, label="Number of denoising steps (set to 20 to trade quality for faster generation)")
241
  seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
242
  with gr.Accordion("Advanced options (play around for better generation)", open=False):
243
  frozen_step_ratio = gr.Slider(0, 1, value=0.4, step=0.1, label="Foreground frozen steps ratio (higher: preserve object attributes; lower: higher coherence; set to 0: (almost) equivalent to vanilla GLIGEN except details)")
244
  gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.3, step=0.1, label="GLIGEN guidance steps ratio (the beta value)")
245
+ dpm_scheduler = gr.Checkbox(label="Use DPM scheduler (unchecked: DDIM scheduler, may have better coherence, recommend >=50 inference steps)", show_label=False, value=True)
 
246
  use_autocast = gr.Checkbox(label="Use FP16 Mixed Precision", show_label=False, value=True)
247
  fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
248
  fg_blending_ratio = gr.Slider(0, 1, value=0.1, step=0.01, label="Variations added to foreground for single object generation (0: no variation, 1: max variation)")
generation.py CHANGED
@@ -1,6 +1,7 @@
1
  version = "v3.0"
2
 
3
  import torch
 
4
  import models
5
  import utils
6
  from models import pipelines, sam
@@ -21,7 +22,6 @@ H, W = height // 8, width // 8 # size of the latent
21
  guidance_scale = 7.5 # Scale for classifier-free guidance
22
 
23
  # batch size that is not 1 is not supported
24
- so_batch_size = 1
25
  overall_batch_size = 1
26
 
27
  # discourage masks with confidence below
@@ -33,41 +33,70 @@ discourage_mask_below_coarse_iou = 0.25
33
  run_ind = None
34
 
35
 
36
- def generate_single_object_with_box(prompt, box, phrase, word, input_latents, input_embeddings,
37
  sam_refine_kwargs, num_inference_steps, gligen_scheduled_sampling_beta=0.3,
38
- verbose=False, scheduler_key=None, visualize=True):
 
39
 
40
- bboxes, phrases, words = [box], [phrase], [word]
41
 
42
- latents, single_object_images, single_object_pil_images_box_ann, latents_all = pipelines.generate_gligen(
43
- model_dict, input_latents, input_embeddings, num_inference_steps, bboxes, phrases, gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
44
- guidance_scale=guidance_scale, return_saved_cross_attn=False,
45
- return_box_vis=True, save_all_latents=True, scheduler_key=scheduler_key
46
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- mask_selected, conf_score_selected = sam.sam_refine_box(sam_input_image=single_object_images[0], box=box, model_dict=model_dict, verbose=verbose, **sam_refine_kwargs)
49
 
 
 
 
 
 
 
50
  mask_selected_tensor = torch.tensor(mask_selected)
51
 
52
- return latents_all, mask_selected_tensor, single_object_pil_images_box_ann[0]
 
 
53
 
54
  def get_masked_latents_all_list(so_prompt_phrase_word_box_list, input_latents_list, so_input_embeddings, verbose=False, **kwargs):
55
- latents_all_list, mask_tensor_list, so_img_list = [], [], []
56
 
57
  if not so_prompt_phrase_word_box_list:
58
  return latents_all_list, mask_tensor_list
59
 
60
- so_uncond_embeddings, so_cond_embeddings = so_input_embeddings
61
 
62
- for idx, ((prompt, phrase, word, box), input_latents) in enumerate(zip(so_prompt_phrase_word_box_list, input_latents_list)):
63
- so_current_cond_embeddings = so_cond_embeddings[idx:idx+1]
64
- so_current_text_embeddings = torch.cat([so_uncond_embeddings, so_current_cond_embeddings], dim=0)
65
- so_current_input_embeddings = so_current_text_embeddings, so_uncond_embeddings, so_current_cond_embeddings
66
-
67
- latents_all, mask_tensor, so_img = generate_single_object_with_box(prompt, box, phrase, word, input_latents, input_embeddings=so_current_input_embeddings, verbose=verbose, **kwargs)
68
- latents_all_list.append(latents_all)
69
- mask_tensor_list.append(mask_tensor)
70
- so_img_list.append(so_img)
71
 
72
  return latents_all_list, mask_tensor_list, so_img_list
73
 
@@ -77,7 +106,7 @@ def get_masked_latents_all_list(so_prompt_phrase_word_box_list, input_latents_li
77
  def run(
78
  spec, bg_seed = 1, overall_prompt_override="", fg_seed_start = 20, frozen_step_ratio=0.4, gligen_scheduled_sampling_beta = 0.3, num_inference_steps = 20,
79
  so_center_box = False, fg_blending_ratio = 0.1, scheduler_key='dpm_scheduler', so_negative_prompt = DEFAULT_SO_NEGATIVE_PROMPT, overall_negative_prompt = DEFAULT_OVERALL_NEGATIVE_PROMPT, so_horizontal_center_only = True,
80
- align_with_overall_bboxes = False, horizontal_shift_only = True, use_autocast = False
81
  ):
82
  """
83
  so_center_box: using centered box in single object generation
@@ -130,7 +159,7 @@ def run(
130
  latents_all_list, mask_tensor_list, so_img_list = get_masked_latents_all_list(
131
  so_prompt_phrase_word_box_list, input_latents_list,
132
  gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
133
- sam_refine_kwargs=sam_refine_kwargs, so_input_embeddings=so_input_embeddings, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key, verbose=verbose
134
  )
135
 
136
 
 
1
  version = "v3.0"
2
 
3
  import torch
4
+ import numpy as np
5
  import models
6
  import utils
7
  from models import pipelines, sam
 
22
  guidance_scale = 7.5 # Scale for classifier-free guidance
23
 
24
  # batch size that is not 1 is not supported
 
25
  overall_batch_size = 1
26
 
27
  # discourage masks with confidence below
 
33
  run_ind = None
34
 
35
 
36
+ def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input_latents_list, input_embeddings,
37
  sam_refine_kwargs, num_inference_steps, gligen_scheduled_sampling_beta=0.3,
38
+ verbose=False, scheduler_key=None, visualize=True, batch_size=None):
39
+ # batch_size=None: does not limit the batch size (pass all input together)
40
 
41
+ # prompts and words are not used since we don't have cross-attention control in this function
42
 
43
+ input_latents = torch.cat(input_latents_list, dim=0)
44
+
45
+ # We need to "unsqueeze" to tell that we have only one box and phrase in each batch item
46
+ bboxes, phrases = [[item] for item in bboxes], [[item] for item in phrases]
47
+
48
+ input_len = len(bboxes)
49
+ assert len(bboxes) == len(phrases), f"{len(bboxes)} != {len(phrases)}"
50
+
51
+ if batch_size is None:
52
+ batch_size = input_len
53
+
54
+ run_times = int(np.ceil(input_len / batch_size))
55
+ single_object_images, single_object_pil_images_box_ann, latents_all = [], [], []
56
+ for batch_idx in range(run_times):
57
+ input_latents_batch, bboxes_batch, phrases_batch = input_latents[batch_idx * batch_size:(batch_idx + 1) * batch_size], \
58
+ bboxes[batch_idx * batch_size:(batch_idx + 1) * batch_size], phrases[batch_idx * batch_size:(batch_idx + 1) * batch_size]
59
+ input_embeddings_batch = input_embeddings[0], input_embeddings[1][batch_idx * batch_size:(batch_idx + 1) * batch_size]
60
+
61
+ _, single_object_images_batch, single_object_pil_images_box_ann_batch, latents_all_batch = pipelines.generate_gligen(
62
+ model_dict, input_latents_batch, input_embeddings_batch, num_inference_steps, bboxes_batch, phrases_batch, gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
63
+ guidance_scale=guidance_scale, return_saved_cross_attn=False,
64
+ return_box_vis=True, save_all_latents=True, batched_condition=True, scheduler_key=scheduler_key
65
+ )
66
+
67
+ single_object_images.append(single_object_images_batch)
68
+ single_object_pil_images_box_ann.append(single_object_pil_images_box_ann_batch)
69
+ latents_all.append(latents_all_batch)
70
 
71
+ single_object_images, single_object_pil_images_box_ann, latents_all = np.concatenate(single_object_images, axis=0), sum(single_object_pil_images_box_ann, []), torch.cat(latents_all, dim=1)
72
 
73
+ mask_selected, conf_score_selected = sam.sam_refine_boxes(sam_input_images=single_object_images, boxes=bboxes, model_dict=model_dict, verbose=verbose, **sam_refine_kwargs)
74
+
75
+ # mask_selected: List[List[Array of shape (64, 64)]]
76
+
77
+ mask_selected = np.array(mask_selected)[:, 0]
78
+
79
  mask_selected_tensor = torch.tensor(mask_selected)
80
 
81
+ latents_all = latents_all.transpose(0,1)[:,:,None,...]
82
+
83
+ return latents_all, mask_selected_tensor, single_object_pil_images_box_ann
84
 
85
  def get_masked_latents_all_list(so_prompt_phrase_word_box_list, input_latents_list, so_input_embeddings, verbose=False, **kwargs):
86
+ latents_all_list, mask_tensor_list = [], []
87
 
88
  if not so_prompt_phrase_word_box_list:
89
  return latents_all_list, mask_tensor_list
90
 
91
+ prompts, bboxes, phrases, words = [], [], [], []
92
 
93
+ for prompt, phrase, word, box in so_prompt_phrase_word_box_list:
94
+ prompts.append(prompt)
95
+ bboxes.append(box)
96
+ phrases.append(phrase)
97
+ words.append(word)
98
+
99
+ latents_all_list, mask_tensor_list, so_img_list = generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input_latents_list, input_embeddings=so_input_embeddings, verbose=verbose, **kwargs)
 
 
100
 
101
  return latents_all_list, mask_tensor_list, so_img_list
102
 
 
106
  def run(
107
  spec, bg_seed = 1, overall_prompt_override="", fg_seed_start = 20, frozen_step_ratio=0.4, gligen_scheduled_sampling_beta = 0.3, num_inference_steps = 20,
108
  so_center_box = False, fg_blending_ratio = 0.1, scheduler_key='dpm_scheduler', so_negative_prompt = DEFAULT_SO_NEGATIVE_PROMPT, overall_negative_prompt = DEFAULT_OVERALL_NEGATIVE_PROMPT, so_horizontal_center_only = True,
109
+ align_with_overall_bboxes = False, horizontal_shift_only = True, use_autocast = False, so_batch_size = None
110
  ):
111
  """
112
  so_center_box: using centered box in single object generation
 
159
  latents_all_list, mask_tensor_list, so_img_list = get_masked_latents_all_list(
160
  so_prompt_phrase_word_box_list, input_latents_list,
161
  gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
162
+ sam_refine_kwargs=sam_refine_kwargs, so_input_embeddings=so_input_embeddings, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key, verbose=verbose, batch_size=so_batch_size
163
  )
164
 
165
 
models/models.py CHANGED
@@ -75,20 +75,22 @@ def encode_prompts(tokenizer, text_encoder, prompts, negative_prompt="", return_
75
  return text_embeddings
76
  return text_embeddings, uncond_embeddings, cond_embeddings
77
 
78
- def attn_list_to_tensor(cross_attention_probs):
79
- # timestep, CrossAttnBlock, Transformer2DModel, 1xBasicTransformerBlock
80
-
81
- num_cross_attn_block = len(cross_attention_probs[0])
82
- cross_attention_probs_all = []
83
-
84
- for i in range(num_cross_attn_block):
85
- # cross_attention_probs_timestep[i]: Transformer2DModel
86
- # 1xBasicTransformerBlock is skipped
87
- cross_attention_probs_current = []
88
- for cross_attention_probs_timestep in cross_attention_probs:
89
- cross_attention_probs_current.append(torch.stack([item for item in cross_attention_probs_timestep[i]], dim=0))
90
-
91
- cross_attention_probs_current = torch.stack(cross_attention_probs_current, dim=0)
92
- cross_attention_probs_all.append(cross_attention_probs_current)
93
-
94
- return cross_attention_probs_all
 
 
 
75
  return text_embeddings
76
  return text_embeddings, uncond_embeddings, cond_embeddings
77
 
78
+ def process_input_embeddings(input_embeddings):
79
+ assert isinstance(input_embeddings, (tuple, list))
80
+ if len(input_embeddings) == 3:
81
+ # input_embeddings: text_embeddings, uncond_embeddings, cond_embeddings
82
+ # Assume `uncond_embeddings` is full (has batch size the same as cond_embeddings)
83
+ _, uncond_embeddings, cond_embeddings = input_embeddings
84
+ assert uncond_embeddings.shape[0] == cond_embeddings.shape[0], f"{uncond_embeddings.shape[0]} != {cond_embeddings.shape[0]}"
85
+ return input_embeddings
86
+ elif len(input_embeddings) == 2:
87
+ # input_embeddings: uncond_embeddings, cond_embeddings
88
+ # uncond_embeddings may have only one item
89
+ uncond_embeddings, cond_embeddings = input_embeddings
90
+ if uncond_embeddings.shape[0] == 1:
91
+ uncond_embeddings = uncond_embeddings.expand(cond_embeddings.shape)
92
+ # We follow the convention: negative (unconditional) prompt comes first
93
+ text_embeddings = torch.cat((uncond_embeddings, cond_embeddings), dim=0)
94
+ return text_embeddings, uncond_embeddings, cond_embeddings
95
+ else:
96
+ raise ValueError(f"input_embeddings length: {len(input_embeddings)}")
models/pipelines.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
5
  import gc
6
  import numpy as np
7
  from .attention import GatedSelfAttentionDense
8
- from .models import torch_device
9
 
10
  @torch.no_grad()
11
  def encode(model_dict, image, generator):
@@ -88,17 +88,56 @@ def gligen_enable_fuser(unet, enabled=True):
88
  if isinstance(module, GatedSelfAttentionDense):
89
  module.enabled = enabled
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  @torch.no_grad()
92
  def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, num_images_per_prompt=1, gligen_scheduled_sampling_beta: float = 0.3, guidance_scale=7.5,
93
  frozen_steps=20, frozen_mask=None,
94
  return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
95
  offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
96
- return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler'):
97
  """
98
  The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
99
  """
100
  vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
101
- text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
 
102
 
103
  if latents.dim() == 5:
104
  # latents_all from the input side, different from the latents_all to be saved
@@ -122,33 +161,12 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
122
  if frozen_mask is not None:
123
  frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
124
 
125
- batch_size = 1
126
-
127
  # 5.1 Prepare GLIGEN variables
128
- assert len(phrases) == len(bboxes)
129
- # assert batch_size == 1
130
- max_objs = 30
131
- _boxes = bboxes
132
 
133
- n_objs = min(len(_boxes), max_objs)
134
- boxes = torch.zeros(max_objs, 4, device=torch_device, dtype=dtype)
135
- phrase_embeddings = torch.zeros(max_objs, 768, device=torch_device, dtype=dtype)
136
- masks = torch.zeros(max_objs, device=torch_device, dtype=dtype)
137
-
138
- if n_objs > 0:
139
- boxes[:n_objs] = torch.tensor(_boxes[:n_objs])
140
- tokenizer_inputs = tokenizer(phrases, padding=True, return_tensors="pt").to(torch_device)
141
- _phrase_embeddings = text_encoder(**tokenizer_inputs).pooler_output
142
- phrase_embeddings[:n_objs] = _phrase_embeddings[:n_objs]
143
- masks[:n_objs] = 1
144
-
145
- # Classifier-free guidance
146
- repeat_batch = batch_size * num_images_per_prompt * 2
147
-
148
- boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
149
- phrase_embeddings = phrase_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
150
- masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
151
- masks[:repeat_batch // 2] = 0
152
 
153
  if return_saved_cross_attn:
154
  saved_attns = []
@@ -215,7 +233,7 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
215
  if return_saved_cross_attn:
216
  ret.append(saved_attns)
217
  if return_box_vis:
218
- pil_images = [utils.draw_box(Image.fromarray(image), bboxes, phrases) for image in images]
219
  ret.append(pil_images)
220
  if save_all_latents:
221
  latents_all = torch.stack(latents_all, dim=0)
 
5
  import gc
6
  import numpy as np
7
  from .attention import GatedSelfAttentionDense
8
+ from .models import process_input_embeddings, torch_device
9
 
10
  @torch.no_grad()
11
  def encode(model_dict, image, generator):
 
88
  if isinstance(module, GatedSelfAttentionDense):
89
  module.enabled = enabled
90
 
91
+ def prepare_gligen_condition(bboxes, phrases, dtype, tokenizer, text_encoder, num_images_per_prompt):
92
+ batch_size = len(bboxes)
93
+
94
+ assert len(phrases) == len(bboxes)
95
+ max_objs = 30
96
+
97
+ n_objs = min(max([len(bboxes_item) for bboxes_item in bboxes]), max_objs)
98
+ boxes = torch.zeros((batch_size, max_objs, 4), device=torch_device, dtype=dtype)
99
+ phrase_embeddings = torch.zeros((batch_size, max_objs, 768), device=torch_device, dtype=dtype)
100
+ # masks is a 1D tensor deciding which of the enteries to be enabled
101
+ masks = torch.zeros((batch_size, max_objs), device=torch_device, dtype=dtype)
102
+
103
+ if n_objs > 0:
104
+ for idx, (bboxes_item, phrases_item) in enumerate(zip(bboxes, phrases)):
105
+ # the length of `bboxes_item` could be smaller than `n_objs` because n_objs takes the max of item length
106
+ bboxes_item = torch.tensor(bboxes_item[:n_objs])
107
+ boxes[idx, :bboxes_item.shape[0]] = bboxes_item
108
+
109
+ tokenizer_inputs = tokenizer(phrases_item[:n_objs], padding=True, return_tensors="pt").to(torch_device)
110
+ _phrase_embeddings = text_encoder(**tokenizer_inputs).pooler_output
111
+ phrase_embeddings[idx, :_phrase_embeddings.shape[0]] = _phrase_embeddings
112
+ assert bboxes_item.shape[0] == _phrase_embeddings.shape[0], f"{bboxes_item.shape[0]} != {_phrase_embeddings.shape[0]}"
113
+
114
+ masks[idx, :bboxes_item.shape[0]] = 1
115
+
116
+ # Classifier-free guidance
117
+ repeat_times = num_images_per_prompt * 2
118
+ condition_len = batch_size * repeat_times
119
+
120
+ boxes = boxes.repeat(repeat_times, 1, 1)
121
+ phrase_embeddings = phrase_embeddings.repeat(repeat_times, 1, 1)
122
+ masks = masks.repeat(repeat_times, 1)
123
+ masks[:condition_len // 2] = 0
124
+
125
+ # print("shapes:", boxes.shape, phrase_embeddings.shape, masks.shape)
126
+
127
+ return boxes, phrase_embeddings, masks, condition_len
128
+
129
  @torch.no_grad()
130
  def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, num_images_per_prompt=1, gligen_scheduled_sampling_beta: float = 0.3, guidance_scale=7.5,
131
  frozen_steps=20, frozen_mask=None,
132
  return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
133
  offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
134
+ return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler', batched_condition=False):
135
  """
136
  The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
137
  """
138
  vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
139
+
140
+ text_embeddings, _, cond_embeddings = process_input_embeddings(input_embeddings)
141
 
142
  if latents.dim() == 5:
143
  # latents_all from the input side, different from the latents_all to be saved
 
161
  if frozen_mask is not None:
162
  frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
163
 
 
 
164
  # 5.1 Prepare GLIGEN variables
165
+ if not batched_condition:
166
+ # Add batch dimension to bboxes and phrases
167
+ bboxes, phrases = [bboxes], [phrases]
 
168
 
169
+ boxes, phrase_embeddings, masks, condition_len = prepare_gligen_condition(bboxes, phrases, dtype, tokenizer, text_encoder, num_images_per_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  if return_saved_cross_attn:
172
  saved_attns = []
 
233
  if return_saved_cross_attn:
234
  ret.append(saved_attns)
235
  if return_box_vis:
236
+ pil_images = [utils.draw_box(Image.fromarray(image), bboxes_item, phrases_item) for image, bboxes_item, phrases_item in zip(images, bboxes, phrases)]
237
  ret.append(pil_images)
238
  if save_all_latents:
239
  latents_all = torch.stack(latents_all, dim=0)
models/sam.py CHANGED
@@ -2,6 +2,7 @@ import gc
2
  import matplotlib.pyplot as plt
3
  import numpy as np
4
  import torch
 
5
  from models import torch_device
6
  from transformers import SamModel, SamProcessor
7
  import utils
@@ -20,10 +21,18 @@ def load_sam():
20
 
21
  # Not fully backward compatible with the previous implementation
22
  # Reference: lmdv2/notebooks/gen_masked_latents_multi_object_ref_ca_loss_modular.ipynb
23
- def sam(sam_model_dict, image, input_points=None, input_boxes=None, target_mask_shape=None):
24
  """target_mask_shape: (h, w)"""
25
  sam_model, sam_processor = sam_model_dict['sam_model'], sam_model_dict['sam_processor']
26
 
 
 
 
 
 
 
 
 
27
  with torch.no_grad():
28
  with torch.autocast(torch_device):
29
  inputs = sam_processor(image, input_points=input_points, input_boxes=input_boxes, return_tensors="pt").to(torch_device)
@@ -31,18 +40,17 @@ def sam(sam_model_dict, image, input_points=None, input_boxes=None, target_mask_
31
  masks = sam_processor.image_processor.post_process_masks(
32
  outputs.pred_masks.cpu().float(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
33
  )
34
- conf_scores = outputs.iou_scores.to(device="cpu", dtype=torch.float32).numpy()[0,0]
35
  del inputs, outputs
36
-
37
- gc.collect()
38
- if torch_device == "cuda":
39
- torch.cuda.empty_cache()
40
-
41
- masks = masks[0][0].numpy()
42
 
43
- if target_mask_shape is not None:
44
- masks = np.array([cv2.resize(mask.astype(np.uint8) * 255, target_mask_shape[::-1], cv2.INTER_LINEAR).astype(bool) for mask in masks])
45
 
 
 
 
 
 
46
  return masks, conf_scores
47
 
48
  def sam_point_input(sam_model_dict, image, input_points, **kwargs):
@@ -154,26 +162,39 @@ def sam_refine_attn(sam_input_image, token_attn_np, model_dict, height, width, H
154
 
155
  return mask_selected, conf_score_selected
156
 
157
- def sam_refine_box(sam_input_image, box, model_dict, height, width, H, W, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
 
 
 
 
158
  # (w, h)
159
- input_boxes = utils.scale_proportion(box, H=height, W=width)
160
- input_boxes = [input_boxes]
161
 
162
- masks, conf_scores = sam_box_input(model_dict, image=sam_input_image, input_boxes=input_boxes, target_mask_shape=(H, W))
163
 
164
- mask_binary = utils.proportion_to_mask(box, H, W, return_np=True)
165
- if verbose:
166
- # Also the box is the input for SAM
167
- plt.title("Binary mask from input box (for iou)")
168
- plt.imshow(mask_binary)
169
- plt.show()
170
 
171
- coarse_ious = get_iou_with_resize(mask_binary, masks, masks_shape=mask_binary.shape)
172
-
173
- mask_selected, conf_score_selected = select_mask(masks, conf_scores, coarse_ious=coarse_ious,
174
- rule="largest_over_conf",
175
- discourage_mask_below_confidence=discourage_mask_below_confidence,
176
- discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
177
- verbose=True)
178
-
179
- return mask_selected, conf_score_selected
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import matplotlib.pyplot as plt
3
  import numpy as np
4
  import torch
5
+ import torch.nn.functional as F
6
  from models import torch_device
7
  from transformers import SamModel, SamProcessor
8
  import utils
 
21
 
22
  # Not fully backward compatible with the previous implementation
23
  # Reference: lmdv2/notebooks/gen_masked_latents_multi_object_ref_ca_loss_modular.ipynb
24
+ def sam(sam_model_dict, image, input_points=None, input_boxes=None, target_mask_shape=None, return_numpy=True):
25
  """target_mask_shape: (h, w)"""
26
  sam_model, sam_processor = sam_model_dict['sam_model'], sam_model_dict['sam_processor']
27
 
28
+ if input_boxes and isinstance(input_boxes[0], tuple):
29
+ # Convert tuple to list
30
+ input_boxes = [list(input_box) for input_box in input_boxes]
31
+
32
+ if input_boxes and input_boxes[0] and isinstance(input_boxes[0][0], tuple):
33
+ # Convert tuple to list
34
+ input_boxes = [[list(input_box) for input_box in input_boxes_item] for input_boxes_item in input_boxes]
35
+
36
  with torch.no_grad():
37
  with torch.autocast(torch_device):
38
  inputs = sam_processor(image, input_points=input_points, input_boxes=input_boxes, return_tensors="pt").to(torch_device)
 
40
  masks = sam_processor.image_processor.post_process_masks(
41
  outputs.pred_masks.cpu().float(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
42
  )
43
+ conf_scores = outputs.iou_scores.cpu().numpy()[0,0]
44
  del inputs, outputs
 
 
 
 
 
 
45
 
46
+ gc.collect()
47
+ torch.cuda.empty_cache()
48
 
49
+ if return_numpy:
50
+ masks = [F.interpolate(masks_item.type(torch.float), target_mask_shape, mode='bilinear').type(torch.bool).numpy() for masks_item in masks]
51
+ else:
52
+ masks = [F.interpolate(masks_item.type(torch.float), target_mask_shape, mode='bilinear').type(torch.bool) for masks_item in masks]
53
+
54
  return masks, conf_scores
55
 
56
  def sam_point_input(sam_model_dict, image, input_points, **kwargs):
 
162
 
163
  return mask_selected, conf_score_selected
164
 
165
+ def sam_refine_box(sam_input_image, box, *args, **kwargs):
166
+ sam_input_images, boxes = [sam_input_image], [box]
167
+ return sam_refine_boxes(sam_input_images, boxes, *args, **kwargs)
168
+
169
+ def sam_refine_boxes(sam_input_images, boxes, model_dict, height, width, H, W, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
170
  # (w, h)
171
+ input_boxes = [[utils.scale_proportion(box, H=height, W=width) for box in boxes_item] for boxes_item in boxes]
 
172
 
173
+ masks, conf_scores = sam_box_input(model_dict, image=sam_input_images, input_boxes=input_boxes, target_mask_shape=(H, W))
174
 
175
+ mask_selected_batched_list, conf_score_selected_batched_list = [], []
 
 
 
 
 
176
 
177
+ for boxes_item, masks_item in zip(boxes, masks):
178
+ mask_selected_list, conf_score_selected_list = [], []
179
+ for box, three_masks in zip(boxes_item, masks_item):
180
+ mask_binary = utils.proportion_to_mask(box, H, W, return_np=True)
181
+ if verbose:
182
+ # Also the box is the input for SAM
183
+ plt.title("Binary mask from input box (for iou)")
184
+ plt.imshow(mask_binary)
185
+ plt.show()
186
+
187
+ coarse_ious = get_iou_with_resize(mask_binary, three_masks, masks_shape=mask_binary.shape)
188
+
189
+ mask_selected, conf_score_selected = select_mask(three_masks, conf_scores, coarse_ious=coarse_ious,
190
+ rule="largest_over_conf",
191
+ discourage_mask_below_confidence=discourage_mask_below_confidence,
192
+ discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
193
+ verbose=True)
194
+
195
+ mask_selected_list.append(mask_selected)
196
+ conf_score_selected_list.append(conf_score_selected)
197
+ mask_selected_batched_list.append(mask_selected_list)
198
+ conf_score_selected_batched_list.append(conf_score_selected_list)
199
+
200
+ return mask_selected_batched_list, conf_score_selected_batched_list