tolgacangoz
commited on
Upload matryoshka.py
Browse files- unet/matryoshka.py +114 -82
unet/matryoshka.py
CHANGED
@@ -664,9 +664,7 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
|
|
664 |
variance_noise = []
|
665 |
for m_o in model_output:
|
666 |
variance_noise.append(
|
667 |
-
randn_tensor(
|
668 |
-
m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype
|
669 |
-
)
|
670 |
)
|
671 |
else:
|
672 |
variance_noise = randn_tensor(
|
@@ -1897,6 +1895,8 @@ class MatryoshkaCombinedTimestepTextEmbedding(nn.Module):
|
|
1897 |
dim=1, keepdim=True
|
1898 |
)
|
1899 |
cond_emb = self.cond_emb(y)
|
|
|
|
|
1900 |
|
1901 |
if not masked_cross_attention:
|
1902 |
conditioning_mask = None
|
@@ -1905,11 +1905,8 @@ class MatryoshkaCombinedTimestepTextEmbedding(nn.Module):
|
|
1905 |
if micro is not None:
|
1906 |
temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype))
|
1907 |
temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype))
|
1908 |
-
if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False):
|
1909 |
-
|
1910 |
-
return cond_emb_micro, conditioning_mask, cond_emb
|
1911 |
-
else:
|
1912 |
-
return temb_micro_conditioning, conditioning_mask, None
|
1913 |
|
1914 |
return cond_emb, conditioning_mask, cond_emb
|
1915 |
|
@@ -3035,11 +3032,6 @@ class MatryoshkaUNet2DConditionModel(
|
|
3035 |
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
3036 |
attention_mask = attention_mask.unsqueeze(1)
|
3037 |
|
3038 |
-
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
3039 |
-
if encoder_attention_mask is not None:
|
3040 |
-
encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0
|
3041 |
-
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
3042 |
-
|
3043 |
# 0. center input if necessary
|
3044 |
if self.config.center_input_sample:
|
3045 |
sample = 2 * sample - 1.0
|
@@ -3074,6 +3066,11 @@ class MatryoshkaUNet2DConditionModel(
|
|
3074 |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3075 |
)
|
3076 |
|
|
|
|
|
|
|
|
|
|
|
3077 |
if self.config.addition_embed_type == "image_hint":
|
3078 |
aug_emb, hint = aug_emb
|
3079 |
sample = torch.cat([sample, hint], dim=1)
|
@@ -3484,11 +3481,6 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
|
|
3484 |
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
3485 |
attention_mask = attention_mask.unsqueeze(1)
|
3486 |
|
3487 |
-
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
3488 |
-
if encoder_attention_mask is not None:
|
3489 |
-
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
3490 |
-
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
3491 |
-
|
3492 |
# 0. center input if necessary
|
3493 |
if self.config.center_input_sample:
|
3494 |
sample = 2 * sample - 1.0
|
@@ -3515,15 +3507,15 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
|
|
3515 |
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3516 |
)
|
3517 |
|
3518 |
-
aug_emb_inner_unet,
|
3519 |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3520 |
)
|
3521 |
-
|
3522 |
-
aug_emb,
|
3523 |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3524 |
)
|
3525 |
else:
|
3526 |
-
aug_emb,
|
3527 |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3528 |
)
|
3529 |
|
@@ -3537,14 +3529,19 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
|
|
3537 |
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3538 |
)
|
3539 |
|
3540 |
-
aug_emb_inner_unet,
|
3541 |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3542 |
)
|
3543 |
|
3544 |
-
aug_emb,
|
3545 |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3546 |
)
|
3547 |
|
|
|
|
|
|
|
|
|
|
|
3548 |
if self.config.addition_embed_type == "image_hint":
|
3549 |
aug_emb, hint = aug_emb
|
3550 |
sample = torch.cat([sample, hint], dim=1)
|
@@ -3606,7 +3603,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
|
|
3606 |
encoder_hidden_states=encoder_hidden_states[:bh],
|
3607 |
attention_mask=attention_mask,
|
3608 |
cross_attention_kwargs=cross_attention_kwargs,
|
3609 |
-
encoder_attention_mask=
|
3610 |
**additional_residuals,
|
3611 |
)
|
3612 |
else:
|
@@ -3626,7 +3623,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
|
|
3626 |
timestep,
|
3627 |
cond_emb=cond_emb,
|
3628 |
encoder_hidden_states=encoder_hidden_states,
|
3629 |
-
encoder_attention_mask=
|
3630 |
from_nested=True,
|
3631 |
)
|
3632 |
x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner
|
@@ -3914,9 +3911,6 @@ class MatryoshkaPipeline(
|
|
3914 |
|
3915 |
text_inputs = self.tokenizer(
|
3916 |
prompt,
|
3917 |
-
padding="max_length",
|
3918 |
-
max_length=self.tokenizer.model_max_length,
|
3919 |
-
truncation=True,
|
3920 |
return_tensors="pt",
|
3921 |
)
|
3922 |
text_input_ids = text_inputs.input_ids
|
@@ -3934,26 +3928,9 @@ class MatryoshkaPipeline(
|
|
3934 |
)
|
3935 |
|
3936 |
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
3937 |
-
|
3938 |
else:
|
3939 |
-
|
3940 |
-
|
3941 |
-
if clip_skip is None:
|
3942 |
-
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
3943 |
-
prompt_embeds = prompt_embeds[0]
|
3944 |
-
else:
|
3945 |
-
prompt_embeds = self.text_encoder(
|
3946 |
-
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
3947 |
-
)
|
3948 |
-
# Access the `hidden_states` first, that contains a tuple of
|
3949 |
-
# all the hidden states from the encoder layers. Then index into
|
3950 |
-
# the tuple to access the hidden states from the desired layer.
|
3951 |
-
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
3952 |
-
# We also need to apply the final LayerNorm here to not mess with the
|
3953 |
-
# representations. The `last_hidden_states` that we typically use for
|
3954 |
-
# obtaining the final prompt representations passes through the LayerNorm
|
3955 |
-
# layer.
|
3956 |
-
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
3957 |
|
3958 |
if self.text_encoder is not None:
|
3959 |
prompt_embeds_dtype = self.text_encoder.dtype
|
@@ -3962,13 +3939,6 @@ class MatryoshkaPipeline(
|
|
3962 |
else:
|
3963 |
prompt_embeds_dtype = prompt_embeds.dtype
|
3964 |
|
3965 |
-
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
3966 |
-
|
3967 |
-
bs_embed, seq_len, _ = prompt_embeds.shape
|
3968 |
-
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
3969 |
-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
3970 |
-
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
3971 |
-
|
3972 |
# get unconditional embeddings for classifier free guidance
|
3973 |
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
3974 |
uncond_tokens: List[str]
|
@@ -3994,41 +3964,78 @@ class MatryoshkaPipeline(
|
|
3994 |
if isinstance(self, TextualInversionLoaderMixin):
|
3995 |
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
3996 |
|
3997 |
-
max_length = prompt_embeds.shape[1]
|
3998 |
uncond_input = self.tokenizer(
|
3999 |
uncond_tokens,
|
4000 |
-
padding="max_length",
|
4001 |
-
max_length=max_length,
|
4002 |
-
truncation=True,
|
4003 |
return_tensors="pt",
|
4004 |
)
|
|
|
4005 |
|
4006 |
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
4007 |
-
|
4008 |
else:
|
4009 |
-
|
4010 |
|
4011 |
-
|
4012 |
-
|
4013 |
-
attention_mask=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4014 |
)
|
4015 |
-
|
4016 |
|
4017 |
-
|
4018 |
-
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
4019 |
-
seq_len = negative_prompt_embeds.shape[1]
|
4020 |
-
|
4021 |
-
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
4022 |
-
|
4023 |
-
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
4024 |
-
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
4025 |
|
4026 |
if self.text_encoder is not None:
|
4027 |
if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
|
4028 |
# Retrieve the original scale by scaling back the LoRA layers
|
4029 |
unscale_lora_layers(self.text_encoder, lora_scale)
|
4030 |
|
4031 |
-
|
|
|
|
|
4032 |
|
4033 |
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
4034 |
dtype = next(self.image_encoder.parameters()).dtype
|
@@ -4461,7 +4468,12 @@ class MatryoshkaPipeline(
|
|
4461 |
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
4462 |
)
|
4463 |
|
4464 |
-
|
|
|
|
|
|
|
|
|
|
|
4465 |
prompt,
|
4466 |
device,
|
4467 |
num_images_per_prompt,
|
@@ -4477,7 +4489,12 @@ class MatryoshkaPipeline(
|
|
4477 |
# Here we concatenate the unconditional and text embeddings into a single batch
|
4478 |
# to avoid doing two forward passes
|
4479 |
if self.do_classifier_free_guidance:
|
4480 |
-
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
|
|
|
|
|
|
|
|
|
|
4481 |
|
4482 |
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
4483 |
image_embeds = self.prepare_ip_adapter_image_embeds(
|
@@ -4489,10 +4506,13 @@ class MatryoshkaPipeline(
|
|
4489 |
)
|
4490 |
|
4491 |
# 4. Prepare timesteps
|
4492 |
-
|
4493 |
-
|
4494 |
-
|
4495 |
-
|
|
|
|
|
|
|
4496 |
|
4497 |
# 5. Prepare latent variables
|
4498 |
num_channels_latents = self.unet.config.in_channels
|
@@ -4551,7 +4571,7 @@ class MatryoshkaPipeline(
|
|
4551 |
timestep_cond=timestep_cond,
|
4552 |
cross_attention_kwargs=self.cross_attention_kwargs,
|
4553 |
added_cond_kwargs=added_cond_kwargs,
|
4554 |
-
encoder_attention_mask=
|
4555 |
return_dict=False,
|
4556 |
)[0]
|
4557 |
|
@@ -4568,7 +4588,19 @@ class MatryoshkaPipeline(
|
|
4568 |
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
4569 |
|
4570 |
# compute the previous noisy sample x_t -> x_t-1
|
4571 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4572 |
|
4573 |
if callback_on_step_end is not None:
|
4574 |
callback_kwargs = {}
|
|
|
664 |
variance_noise = []
|
665 |
for m_o in model_output:
|
666 |
variance_noise.append(
|
667 |
+
randn_tensor(m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype)
|
|
|
|
|
668 |
)
|
669 |
else:
|
670 |
variance_noise = randn_tensor(
|
|
|
1895 |
dim=1, keepdim=True
|
1896 |
)
|
1897 |
cond_emb = self.cond_emb(y)
|
1898 |
+
else:
|
1899 |
+
cond_emb = None
|
1900 |
|
1901 |
if not masked_cross_attention:
|
1902 |
conditioning_mask = None
|
|
|
1905 |
if micro is not None:
|
1906 |
temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype))
|
1907 |
temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype))
|
1908 |
+
# if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False):
|
1909 |
+
return temb_micro_conditioning, conditioning_mask, cond_emb
|
|
|
|
|
|
|
1910 |
|
1911 |
return cond_emb, conditioning_mask, cond_emb
|
1912 |
|
|
|
3032 |
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
3033 |
attention_mask = attention_mask.unsqueeze(1)
|
3034 |
|
|
|
|
|
|
|
|
|
|
|
3035 |
# 0. center input if necessary
|
3036 |
if self.config.center_input_sample:
|
3037 |
sample = 2 * sample - 1.0
|
|
|
3066 |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3067 |
)
|
3068 |
|
3069 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
3070 |
+
if encoder_attention_mask is not None:
|
3071 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0
|
3072 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
3073 |
+
|
3074 |
if self.config.addition_embed_type == "image_hint":
|
3075 |
aug_emb, hint = aug_emb
|
3076 |
sample = torch.cat([sample, hint], dim=1)
|
|
|
3481 |
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
3482 |
attention_mask = attention_mask.unsqueeze(1)
|
3483 |
|
|
|
|
|
|
|
|
|
|
|
3484 |
# 0. center input if necessary
|
3485 |
if self.config.center_input_sample:
|
3486 |
sample = 2 * sample - 1.0
|
|
|
3507 |
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3508 |
)
|
3509 |
|
3510 |
+
aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.get_aug_embed(
|
3511 |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3512 |
)
|
3513 |
+
added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
|
3514 |
+
aug_emb, __, _ = self.get_aug_embed(
|
3515 |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3516 |
)
|
3517 |
else:
|
3518 |
+
aug_emb, cond_mask, _ = self.get_aug_embed(
|
3519 |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3520 |
)
|
3521 |
|
|
|
3529 |
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3530 |
)
|
3531 |
|
3532 |
+
aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.inner_unet.get_aug_embed(
|
3533 |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3534 |
)
|
3535 |
|
3536 |
+
aug_emb, __, _ = self.get_aug_embed(
|
3537 |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
3538 |
)
|
3539 |
|
3540 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
3541 |
+
if encoder_attention_mask is not None:
|
3542 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
3543 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
3544 |
+
|
3545 |
if self.config.addition_embed_type == "image_hint":
|
3546 |
aug_emb, hint = aug_emb
|
3547 |
sample = torch.cat([sample, hint], dim=1)
|
|
|
3603 |
encoder_hidden_states=encoder_hidden_states[:bh],
|
3604 |
attention_mask=attention_mask,
|
3605 |
cross_attention_kwargs=cross_attention_kwargs,
|
3606 |
+
encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
|
3607 |
**additional_residuals,
|
3608 |
)
|
3609 |
else:
|
|
|
3623 |
timestep,
|
3624 |
cond_emb=cond_emb,
|
3625 |
encoder_hidden_states=encoder_hidden_states,
|
3626 |
+
encoder_attention_mask=cond_mask,
|
3627 |
from_nested=True,
|
3628 |
)
|
3629 |
x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner
|
|
|
3911 |
|
3912 |
text_inputs = self.tokenizer(
|
3913 |
prompt,
|
|
|
|
|
|
|
3914 |
return_tensors="pt",
|
3915 |
)
|
3916 |
text_input_ids = text_inputs.input_ids
|
|
|
3928 |
)
|
3929 |
|
3930 |
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
3931 |
+
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
3932 |
else:
|
3933 |
+
prompt_attention_mask = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3934 |
|
3935 |
if self.text_encoder is not None:
|
3936 |
prompt_embeds_dtype = self.text_encoder.dtype
|
|
|
3939 |
else:
|
3940 |
prompt_embeds_dtype = prompt_embeds.dtype
|
3941 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3942 |
# get unconditional embeddings for classifier free guidance
|
3943 |
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
3944 |
uncond_tokens: List[str]
|
|
|
3964 |
if isinstance(self, TextualInversionLoaderMixin):
|
3965 |
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
3966 |
|
|
|
3967 |
uncond_input = self.tokenizer(
|
3968 |
uncond_tokens,
|
|
|
|
|
|
|
3969 |
return_tensors="pt",
|
3970 |
)
|
3971 |
+
uncond_input_ids = uncond_input.input_ids
|
3972 |
|
3973 |
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
3974 |
+
negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
|
3975 |
else:
|
3976 |
+
negative_prompt_attention_mask = None
|
3977 |
|
3978 |
+
if not do_classifier_free_guidance:
|
3979 |
+
if clip_skip is None:
|
3980 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
3981 |
+
prompt_embeds = prompt_embeds[0]
|
3982 |
+
else:
|
3983 |
+
prompt_embeds = self.text_encoder(
|
3984 |
+
text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=True
|
3985 |
+
)
|
3986 |
+
# Access the `hidden_states` first, that contains a tuple of
|
3987 |
+
# all the hidden states from the encoder layers. Then index into
|
3988 |
+
# the tuple to access the hidden states from the desired layer.
|
3989 |
+
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
3990 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
3991 |
+
# representations. The `last_hidden_states` that we typically use for
|
3992 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
3993 |
+
# layer.
|
3994 |
+
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
3995 |
+
else:
|
3996 |
+
max_len = max(len(text_input_ids[0]), len(uncond_input_ids[0]))
|
3997 |
+
if len(text_input_ids[0]) < max_len:
|
3998 |
+
text_input_ids = torch.cat(
|
3999 |
+
[text_input_ids, torch.zeros(batch_size, max_len - len(text_input_ids[0]), dtype=torch.long)],
|
4000 |
+
dim=1,
|
4001 |
+
)
|
4002 |
+
prompt_attention_mask = torch.cat(
|
4003 |
+
[
|
4004 |
+
prompt_attention_mask,
|
4005 |
+
torch.zeros(batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long),
|
4006 |
+
],
|
4007 |
+
dim=1,
|
4008 |
+
)
|
4009 |
+
elif len(uncond_input_ids[0]) < max_len:
|
4010 |
+
uncond_input_ids = torch.cat(
|
4011 |
+
[uncond_input_ids, torch.zeros(batch_size, max_len - len(uncond_input_ids[0]), dtype=torch.long)],
|
4012 |
+
dim=1,
|
4013 |
+
)
|
4014 |
+
negative_prompt_attention_mask = torch.cat(
|
4015 |
+
[
|
4016 |
+
negative_prompt_attention_mask,
|
4017 |
+
torch.zeros(batch_size, max_len - len(negative_prompt_attention_mask[0]), dtype=torch.long),
|
4018 |
+
],
|
4019 |
+
dim=1,
|
4020 |
+
)
|
4021 |
+
cfg_input_ids = torch.cat([uncond_input_ids, text_input_ids], dim=0)
|
4022 |
+
cfg_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
4023 |
+
prompt_embeds = self.text_encoder(
|
4024 |
+
cfg_input_ids.to(device),
|
4025 |
+
attention_mask=cfg_attention_mask,
|
4026 |
)
|
4027 |
+
prompt_embeds = prompt_embeds[0]
|
4028 |
|
4029 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4030 |
|
4031 |
if self.text_encoder is not None:
|
4032 |
if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
|
4033 |
# Retrieve the original scale by scaling back the LoRA layers
|
4034 |
unscale_lora_layers(self.text_encoder, lora_scale)
|
4035 |
|
4036 |
+
if not do_classifier_free_guidance:
|
4037 |
+
return prompt_embeds, None, prompt_attention_mask, None
|
4038 |
+
return prompt_embeds[1], prompt_embeds[0], prompt_attention_mask, negative_prompt_attention_mask
|
4039 |
|
4040 |
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
4041 |
dtype = next(self.image_encoder.parameters()).dtype
|
|
|
4468 |
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
4469 |
)
|
4470 |
|
4471 |
+
(
|
4472 |
+
prompt_embeds,
|
4473 |
+
negative_prompt_embeds,
|
4474 |
+
prompt_attention_mask,
|
4475 |
+
negative_prompt_attention_mask,
|
4476 |
+
) = self.encode_prompt(
|
4477 |
prompt,
|
4478 |
device,
|
4479 |
num_images_per_prompt,
|
|
|
4489 |
# Here we concatenate the unconditional and text embeddings into a single batch
|
4490 |
# to avoid doing two forward passes
|
4491 |
if self.do_classifier_free_guidance:
|
4492 |
+
prompt_embeds = torch.cat([negative_prompt_embeds.unsqueeze(0), prompt_embeds.unsqueeze(0)])
|
4493 |
+
attention_masks = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
|
4494 |
+
else:
|
4495 |
+
attention_masks = prompt_attention_mask
|
4496 |
+
|
4497 |
+
prompt_embeds = prompt_embeds * attention_masks.unsqueeze(-1)
|
4498 |
|
4499 |
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
4500 |
image_embeds = self.prepare_ip_adapter_image_embeds(
|
|
|
4506 |
)
|
4507 |
|
4508 |
# 4. Prepare timesteps
|
4509 |
+
if isinstance(self.scheduler, MatryoshkaDDIMScheduler):
|
4510 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
4511 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
4512 |
+
)
|
4513 |
+
timesteps = timesteps[:-1] # is this correct???
|
4514 |
+
else:
|
4515 |
+
timesteps = self.scheduler.timesteps
|
4516 |
|
4517 |
# 5. Prepare latent variables
|
4518 |
num_channels_latents = self.unet.config.in_channels
|
|
|
4571 |
timestep_cond=timestep_cond,
|
4572 |
cross_attention_kwargs=self.cross_attention_kwargs,
|
4573 |
added_cond_kwargs=added_cond_kwargs,
|
4574 |
+
encoder_attention_mask=attention_masks,
|
4575 |
return_dict=False,
|
4576 |
)[0]
|
4577 |
|
|
|
4588 |
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
4589 |
|
4590 |
# compute the previous noisy sample x_t -> x_t-1
|
4591 |
+
if self.scheduler.scales is not None and not isinstance(self.scheduler, MatryoshkaDDIMScheduler):
|
4592 |
+
latents[0] = self.scheduler.step(
|
4593 |
+
noise_pred[0], t, latents[0], **extra_step_kwargs, return_dict=False
|
4594 |
+
)[0]
|
4595 |
+
latents[1] = self.scheduler.inner_scheduler.step(
|
4596 |
+
noise_pred[1], t, latents[1], **extra_step_kwargs, return_dict=False
|
4597 |
+
)[0]
|
4598 |
+
if len(latents) > 2:
|
4599 |
+
latents[2] = self.scheduler.inner_scheduler.inner_scheduler.step(
|
4600 |
+
noise_pred[2], t, latents[2], **extra_step_kwargs, return_dict=False
|
4601 |
+
)[0]
|
4602 |
+
else:
|
4603 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
4604 |
|
4605 |
if callback_on_step_end is not None:
|
4606 |
callback_kwargs = {}
|