tolgacangoz committed on
Commit
05fa96d
·
verified ·
1 Parent(s): 5eb4145

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. unet/matryoshka.py +7 -3
unet/matryoshka.py CHANGED
@@ -3059,6 +3059,7 @@ class MatryoshkaUNet2DConditionModel(
3059
  added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
3060
  added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
3061
  added_cond_kwargs["from_nested"] = from_nested
 
3062
 
3063
  if not from_nested:
3064
  encoder_hidden_states = self.process_encoder_hidden_states(
@@ -3507,6 +3508,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
3507
  added_cond_kwargs = added_cond_kwargs or {}
3508
  added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention
3509
  added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
 
3510
 
3511
  if not self.config.nesting:
3512
  encoder_hidden_states = self.inner_unet.process_encoder_hidden_states(
@@ -3529,6 +3531,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
3529
  added_cond_kwargs = added_cond_kwargs or {}
3530
  added_cond_kwargs["masked_cross_attention"] = self.inner_unet.inner_unet.config.masked_cross_attention
3531
  added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
 
3532
 
3533
  encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states(
3534
  encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
@@ -3603,7 +3606,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
3603
  encoder_hidden_states=encoder_hidden_states[:bh],
3604
  attention_mask=attention_mask,
3605
  cross_attention_kwargs=cross_attention_kwargs,
3606
- encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
3607
  **additional_residuals,
3608
  )
3609
  else:
@@ -4025,7 +4028,7 @@ class MatryoshkaPipeline(
4025
  # Retrieve the original scale by scaling back the LoRA layers
4026
  unscale_lora_layers(self.text_encoder, lora_scale)
4027
 
4028
- return prompt_embeds, negative_prompt_embeds
4029
 
4030
  def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
4031
  dtype = next(self.image_encoder.parameters()).dtype
@@ -4458,7 +4461,7 @@ class MatryoshkaPipeline(
4458
  self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
4459
  )
4460
 
4461
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(
4462
  prompt,
4463
  device,
4464
  num_images_per_prompt,
@@ -4548,6 +4551,7 @@ class MatryoshkaPipeline(
4548
  timestep_cond=timestep_cond,
4549
  cross_attention_kwargs=self.cross_attention_kwargs,
4550
  added_cond_kwargs=added_cond_kwargs,
 
4551
  return_dict=False,
4552
  )[0]
4553
 
 
3059
  added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
3060
  added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
3061
  added_cond_kwargs["from_nested"] = from_nested
3062
+ added_cond_kwargs["conditioning_mask"] = encoder_attention_mask
3063
 
3064
  if not from_nested:
3065
  encoder_hidden_states = self.process_encoder_hidden_states(
 
3508
  added_cond_kwargs = added_cond_kwargs or {}
3509
  added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention
3510
  added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
3511
+ added_cond_kwargs["conditioning_mask"] = encoder_attention_mask
3512
 
3513
  if not self.config.nesting:
3514
  encoder_hidden_states = self.inner_unet.process_encoder_hidden_states(
 
3531
  added_cond_kwargs = added_cond_kwargs or {}
3532
  added_cond_kwargs["masked_cross_attention"] = self.inner_unet.inner_unet.config.masked_cross_attention
3533
  added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
3534
+ added_cond_kwargs["conditioning_mask"] = encoder_attention_mask
3535
 
3536
  encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states(
3537
  encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
 
3606
  encoder_hidden_states=encoder_hidden_states[:bh],
3607
  attention_mask=attention_mask,
3608
  cross_attention_kwargs=cross_attention_kwargs,
3609
+ encoder_attention_mask=cond_mask_inner_unet[:bh] if cond_mask_inner_unet is not None else cond_mask_inner_unet,
3610
  **additional_residuals,
3611
  )
3612
  else:
 
4028
  # Retrieve the original scale by scaling back the LoRA layers
4029
  unscale_lora_layers(self.text_encoder, lora_scale)
4030
 
4031
+ return prompt_embeds, negative_prompt_embeds, attention_mask
4032
 
4033
  def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
4034
  dtype = next(self.image_encoder.parameters()).dtype
 
4461
  self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
4462
  )
4463
 
4464
+ prompt_embeds, negative_prompt_embeds, encoder_attention_mask = self.encode_prompt(
4465
  prompt,
4466
  device,
4467
  num_images_per_prompt,
 
4551
  timestep_cond=timestep_cond,
4552
  cross_attention_kwargs=self.cross_attention_kwargs,
4553
  added_cond_kwargs=added_cond_kwargs,
4554
+ encoder_attention_mask=encoder_attention_mask,
4555
  return_dict=False,
4556
  )[0]
4557