tolgacangoz
committed on
Upload matryoshka.py
Browse files
- unet/matryoshka.py +7 -3
unet/matryoshka.py
CHANGED
@@ -3059,6 +3059,7 @@ class MatryoshkaUNet2DConditionModel(
|
|
3059 |
added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
|
3060 |
added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
|
3061 |
added_cond_kwargs["from_nested"] = from_nested
|
|
|
3062 |
|
3063 |
if not from_nested:
|
3064 |
encoder_hidden_states = self.process_encoder_hidden_states(
|
@@ -3507,6 +3508,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
|
|
3507 |
added_cond_kwargs = added_cond_kwargs or {}
|
3508 |
added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention
|
3509 |
added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
|
|
|
3510 |
|
3511 |
if not self.config.nesting:
|
3512 |
encoder_hidden_states = self.inner_unet.process_encoder_hidden_states(
|
@@ -3529,6 +3531,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
|
|
3529 |
added_cond_kwargs = added_cond_kwargs or {}
|
3530 |
added_cond_kwargs["masked_cross_attention"] = self.inner_unet.inner_unet.config.masked_cross_attention
|
3531 |
added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
|
|
|
3532 |
|
3533 |
encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states(
|
3534 |
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
@@ -3603,7 +3606,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
|
|
3603 |
encoder_hidden_states=encoder_hidden_states[:bh],
|
3604 |
attention_mask=attention_mask,
|
3605 |
cross_attention_kwargs=cross_attention_kwargs,
|
3606 |
-
encoder_attention_mask=
|
3607 |
**additional_residuals,
|
3608 |
)
|
3609 |
else:
|
@@ -4025,7 +4028,7 @@ class MatryoshkaPipeline(
|
|
4025 |
# Retrieve the original scale by scaling back the LoRA layers
|
4026 |
unscale_lora_layers(self.text_encoder, lora_scale)
|
4027 |
|
4028 |
-
return prompt_embeds, negative_prompt_embeds
|
4029 |
|
4030 |
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
4031 |
dtype = next(self.image_encoder.parameters()).dtype
|
@@ -4458,7 +4461,7 @@ class MatryoshkaPipeline(
|
|
4458 |
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
4459 |
)
|
4460 |
|
4461 |
-
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
4462 |
prompt,
|
4463 |
device,
|
4464 |
num_images_per_prompt,
|
@@ -4548,6 +4551,7 @@ class MatryoshkaPipeline(
|
|
4548 |
timestep_cond=timestep_cond,
|
4549 |
cross_attention_kwargs=self.cross_attention_kwargs,
|
4550 |
added_cond_kwargs=added_cond_kwargs,
|
|
|
4551 |
return_dict=False,
|
4552 |
)[0]
|
4553 |
|
|
|
3059 |
added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
|
3060 |
added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
|
3061 |
added_cond_kwargs["from_nested"] = from_nested
|
3062 |
+
added_cond_kwargs["conditioning_mask"] = encoder_attention_mask
|
3063 |
|
3064 |
if not from_nested:
|
3065 |
encoder_hidden_states = self.process_encoder_hidden_states(
|
|
|
3508 |
added_cond_kwargs = added_cond_kwargs or {}
|
3509 |
added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention
|
3510 |
added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
|
3511 |
+
added_cond_kwargs["conditioning_mask"] = encoder_attention_mask
|
3512 |
|
3513 |
if not self.config.nesting:
|
3514 |
encoder_hidden_states = self.inner_unet.process_encoder_hidden_states(
|
|
|
3531 |
added_cond_kwargs = added_cond_kwargs or {}
|
3532 |
added_cond_kwargs["masked_cross_attention"] = self.inner_unet.inner_unet.config.masked_cross_attention
|
3533 |
added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
|
3534 |
+
added_cond_kwargs["conditioning_mask"] = encoder_attention_mask
|
3535 |
|
3536 |
encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states(
|
3537 |
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
|
|
3606 |
encoder_hidden_states=encoder_hidden_states[:bh],
|
3607 |
attention_mask=attention_mask,
|
3608 |
cross_attention_kwargs=cross_attention_kwargs,
|
3609 |
+
encoder_attention_mask=cond_mask_inner_unet[:bh] if cond_mask_inner_unet is not None else cond_mask_inner_unet,
|
3610 |
**additional_residuals,
|
3611 |
)
|
3612 |
else:
|
|
|
4028 |
# Retrieve the original scale by scaling back the LoRA layers
|
4029 |
unscale_lora_layers(self.text_encoder, lora_scale)
|
4030 |
|
4031 |
+
return prompt_embeds, negative_prompt_embeds, attention_mask
|
4032 |
|
4033 |
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
4034 |
dtype = next(self.image_encoder.parameters()).dtype
|
|
|
4461 |
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
4462 |
)
|
4463 |
|
4464 |
+
prompt_embeds, negative_prompt_embeds, encoder_attention_mask = self.encode_prompt(
|
4465 |
prompt,
|
4466 |
device,
|
4467 |
num_images_per_prompt,
|
|
|
4551 |
timestep_cond=timestep_cond,
|
4552 |
cross_attention_kwargs=self.cross_attention_kwargs,
|
4553 |
added_cond_kwargs=added_cond_kwargs,
|
4554 |
+
encoder_attention_mask=encoder_attention_mask,
|
4555 |
return_dict=False,
|
4556 |
)[0]
|
4557 |
|