# Yuanshi's picture
# add all
# 6ed1db6
# raw
# history blame
# 13.1 kB
import torch
from typing import List, Union, Optional, Dict, Any, Callable
from diffusers.models.attention_processor import Attention, F
from .lora_controller import enable_lora
def attn_forward(
    attn: Attention,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    condition_latents: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    image_rotary_emb: Optional[torch.Tensor] = None,
    cond_rotary_emb: Optional[torch.Tensor] = None,
    model_config: Optional[Dict[str, Any]] = None,
) -> torch.FloatTensor:
    """Run one attention call over sample, context and condition tokens.

    `hidden_states` is projected with `attn.to_q/to_k/to_v`. When
    `encoder_hidden_states` is given, its `add_*_proj` projections are
    prepended along the sequence axis; when `condition_latents` is given, its
    `to_q/to_k/to_v` projections are appended. A single
    `scaled_dot_product_attention` call is made over the concatenation and the
    result is split back into the individual streams.

    Args:
        attn: Attention module providing the projections, norms and `heads`.
        hidden_states: 3-D sample tokens; the first axis is the batch.
        encoder_hidden_states: Optional 3-D context tokens.
        condition_latents: Optional condition tokens.
        attention_mask: Passed to `scaled_dot_product_attention` as
            `attn_mask`. It is replaced when one of the condition masks below
            is built.
        image_rotary_emb: Rotary embedding applied to query/key after the
            context tokens are prepended (and before condition tokens are
            appended).
        cond_rotary_emb: Rotary embedding applied to the condition query/key.
            Ignored when `condition_latents` is None.
        model_config: Optional settings. Keys read here: `"latent_lora"`
            (default False, forwarded to `enable_lora`) and
            `"union_cond_attn"` (default True; when False, attention between
            condition tokens and all other tokens is masked out).
            None is treated as an empty dict.

    Returns:
        * `(hidden_states, encoder_hidden_states, condition_latents)` when
          both context and condition tokens are given;
        * `(hidden_states, encoder_hidden_states)` when only context tokens
          are given;
        * `(hidden_states, condition_latents)` when only condition tokens are
          given;
        * `hidden_states` otherwise.
        The output projections (`to_out`, `to_add_out`) are applied only in
        the two cases where context tokens are present.
    """
    # A None default avoids sharing one dict across calls and lets callers
    # pass None explicitly, as the Optional annotation advertises.
    model_config = {} if model_config is None else model_config
    latent_lora = model_config.get("latent_lora", False)
    has_cond = condition_latents is not None

    reference = hidden_states if encoder_hidden_states is None else encoder_hidden_states
    batch_size, _, _ = reference.shape

    # NOTE(review): enable_lora is defined in lora_controller; presumably it
    # toggles LoRA on the listed modules according to the flag. The extent of
    # each `with` block is reconstructed from a copy of the file that had lost
    # its indentation - confirm against the upstream source.
    with enable_lora((attn.to_q, attn.to_k, attn.to_v), latent_lora):
        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    def split_heads(projected: torch.Tensor) -> torch.Tensor:
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        return projected.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    query = split_heads(query)
    key = split_heads(key)
    value = split_heads(value)

    if attn.norm_q is not None:
        query = attn.norm_q(query)
    if attn.norm_k is not None:
        key = attn.norm_k(key)

    # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
    if encoder_hidden_states is not None:
        # `context` projections.
        context_query = split_heads(attn.add_q_proj(encoder_hidden_states))
        context_key = split_heads(attn.add_k_proj(encoder_hidden_states))
        context_value = split_heads(attn.add_v_proj(encoder_hidden_states))

        if attn.norm_added_q is not None:
            context_query = attn.norm_added_q(context_query)
        if attn.norm_added_k is not None:
            context_key = attn.norm_added_k(context_key)

        # Context tokens go first along the sequence axis.
        query = torch.cat([context_query, query], dim=2)
        key = torch.cat([context_key, key], dim=2)
        value = torch.cat([context_value, value], dim=2)

    # Import whenever *either* rotary embedding is supplied; previously the
    # name was bound only in the image branch, so a condition-only rotary
    # embedding failed with an unbound-name error.
    if image_rotary_emb is not None or cond_rotary_emb is not None:
        from diffusers.models.embeddings import apply_rotary_emb

    if image_rotary_emb is not None:
        query = apply_rotary_emb(query, image_rotary_emb)
        key = apply_rotary_emb(key, image_rotary_emb)

    if has_cond:
        # Condition tokens reuse the sample projections, but outside the
        # enable_lora context used for the sample tokens above.
        cond_query = split_heads(attn.to_q(condition_latents))
        cond_key = split_heads(attn.to_k(condition_latents))
        cond_value = split_heads(attn.to_v(condition_latents))

        if attn.norm_q is not None:
            cond_query = attn.norm_q(cond_query)
        if attn.norm_k is not None:
            cond_key = attn.norm_k(cond_key)

        if cond_rotary_emb is not None:
            cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
            cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)

        # Condition tokens go last along the sequence axis.
        query = torch.cat([query, cond_query], dim=2)
        key = torch.cat([key, cond_key], dim=2)
        value = torch.cat([value, cond_value], dim=2)

        # The masks below only make sense when condition tokens exist; they
        # used to be built unconditionally and crashed on the missing
        # `cond_query` otherwise.
        condition_n = cond_query.shape[2]

        if not model_config.get("union_cond_attn", True):
            # If we don't want to use the union condition attention, we need
            # to mask the attention between the hidden states and the
            # condition latents.
            attention_mask = torch.ones(
                query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
            )
            attention_mask[-condition_n:, :-condition_n] = False
            attention_mask[:-condition_n, -condition_n:] = False

        if hasattr(attn, "c_factor"):
            # Additive bias of log(c_factor[0]) on the blocks that connect
            # condition tokens with the other tokens.
            # NOTE(review): this replaces any mask built above or passed in by
            # the caller - kept as in the original; confirm that is intended.
            attention_mask = torch.zeros(
                query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
            )
            bias = torch.log(attn.c_factor[0])
            attention_mask[-condition_n:, :-condition_n] = bias
            attention_mask[:-condition_n, -condition_n:] = bias

    hidden_states = F.scaled_dot_product_attention(
        query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
    )
    # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
    hidden_states = hidden_states.transpose(1, 2).reshape(
        batch_size, -1, attn.heads * head_dim
    )
    hidden_states = hidden_states.to(query.dtype)

    if encoder_hidden_states is not None:
        context_n = encoder_hidden_states.shape[1]
        if has_cond:
            # Explicit split point instead of a negative slice, so a
            # zero-length condition does not turn `[:-0]` into an empty slice.
            cond_start = hidden_states.shape[1] - condition_latents.shape[1]
            encoder_hidden_states, hidden_states, condition_latents = (
                hidden_states[:, :context_n],
                hidden_states[:, context_n:cond_start],
                hidden_states[:, cond_start:],
            )
        else:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, :context_n],
                hidden_states[:, context_n:],
            )

        with enable_lora((attn.to_out[0],), latent_lora):
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if has_cond:
            # Same output projection as the sample tokens, outside the
            # enable_lora context.
            condition_latents = attn.to_out[0](condition_latents)
            condition_latents = attn.to_out[1](condition_latents)
            return hidden_states, encoder_hidden_states, condition_latents
        return hidden_states, encoder_hidden_states

    if has_cond:
        # if there are condition_latents, we need to separate the
        # hidden_states and the condition_latents
        cond_start = hidden_states.shape[1] - condition_latents.shape[1]
        return hidden_states[:, :cond_start], hidden_states[:, cond_start:]

    return hidden_states
def block_forward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor,
    condition_latents: torch.FloatTensor,
    temb: torch.FloatTensor,
    cond_temb: torch.FloatTensor,
    cond_rotary_emb=None,
    image_rotary_emb=None,
    model_config: Optional[Dict[str, Any]] = None,
):
    """Forward pass of a dual-stream block with an optional condition stream.

    Intended to be bound as a method: `self` must provide `norm1`,
    `norm1_context`, `attn`, `norm2`, `norm2_context`, `ff` and `ff_context`.
    The condition stream reuses the sample-stream modules (`norm1`, `norm2`,
    `ff`) but is modulated with `cond_temb` instead of `temb`.

    Args:
        hidden_states: Sample tokens.
        encoder_hidden_states: Context tokens.
        condition_latents: Condition tokens, or None to disable that stream.
        temb: Embedding passed as `emb` to `norm1` / `norm1_context`.
        cond_temb: Embedding passed as `emb` to `norm1` for condition tokens.
        cond_rotary_emb: Rotary embedding for the condition tokens; only
            forwarded when `condition_latents` is given.
        image_rotary_emb: Rotary embedding forwarded to `attn_forward`.
        model_config: Optional settings. Keys read here: `"latent_lora"`
            (default False, forwarded to `enable_lora`) and `"add_cond_attn"`
            (default False; when true the gated condition attention output is
            also added to `hidden_states`). The dict is forwarded to
            `attn_forward`. None is treated as an empty dict.

    Returns:
        `(encoder_hidden_states, hidden_states, condition_latents)`; the last
        element is None when no condition tokens were supplied.
    """
    # Normalise once: avoids a shared mutable default and a crash on None.
    model_config = {} if model_config is None else model_config
    latent_lora = model_config.get("latent_lora", False)
    use_cond = condition_latents is not None

    # NOTE(review): the extent of each enable_lora block is reconstructed from
    # a copy of the file that had lost its indentation; the condition-stream
    # calls are kept outside of it - confirm against the upstream source.
    with enable_lora((self.norm1.linear,), latent_lora):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, emb=temb
        )

    (
        norm_encoder_hidden_states,
        c_gate_msa,
        c_shift_mlp,
        c_scale_mlp,
        c_gate_mlp,
    ) = self.norm1_context(encoder_hidden_states, emb=temb)

    if use_cond:
        (
            norm_condition_latents,
            cond_gate_msa,
            cond_shift_mlp,
            cond_scale_mlp,
            cond_gate_mlp,
        ) = self.norm1(condition_latents, emb=cond_temb)

    # Attention.
    result = attn_forward(
        self.attn,
        model_config=model_config,
        hidden_states=norm_hidden_states,
        encoder_hidden_states=norm_encoder_hidden_states,
        condition_latents=norm_condition_latents if use_cond else None,
        image_rotary_emb=image_rotary_emb,
        cond_rotary_emb=cond_rotary_emb if use_cond else None,
    )
    attn_output, context_attn_output = result[:2]
    cond_attn_output = result[2] if use_cond else None

    # Process attention outputs: gated residual per stream.
    # 1. hidden_states
    hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output
    # 2. encoder_hidden_states
    encoder_hidden_states = (
        encoder_hidden_states + c_gate_msa.unsqueeze(1) * context_attn_output
    )
    # 3. condition_latents
    if use_cond:
        cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
        condition_latents = condition_latents + cond_attn_output
        if model_config.get("add_cond_attn", False):
            # Out-of-place add: same values as `+=`, without mutating a tensor
            # in place.
            hidden_states = hidden_states + cond_attn_output

    # LayerNorm + MLP modulation.
    # 1. hidden_states
    norm_hidden_states = self.norm2(hidden_states)
    norm_hidden_states = (
        norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
    )
    # 2. encoder_hidden_states
    norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
    norm_encoder_hidden_states = (
        norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
        + c_shift_mlp[:, None]
    )
    # 3. condition_latents
    if use_cond:
        norm_condition_latents = self.norm2(condition_latents)
        norm_condition_latents = (
            norm_condition_latents * (1 + cond_scale_mlp[:, None])
            + cond_shift_mlp[:, None]
        )

    # Feed-forward.
    with enable_lora((self.ff.net[2],), latent_lora):
        # 1. hidden_states
        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    # 2. encoder_hidden_states
    context_ff_output = self.ff_context(norm_encoder_hidden_states)
    context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
    # 3. condition_latents
    if use_cond:
        cond_ff_output = self.ff(norm_condition_latents)
        cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output

    # Process feed-forward outputs.
    hidden_states = hidden_states + ff_output
    encoder_hidden_states = encoder_hidden_states + context_ff_output
    if use_cond:
        condition_latents = condition_latents + cond_ff_output

    # Clip to avoid overflow (65504 is the largest finite float16 value).
    if encoder_hidden_states.dtype == torch.float16:
        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

    return (
        encoder_hidden_states,
        hidden_states,
        condition_latents if use_cond else None,
    )
def single_block_forward(
    self,
    hidden_states: torch.FloatTensor,
    temb: torch.FloatTensor,
    image_rotary_emb=None,
    condition_latents: Optional[torch.FloatTensor] = None,
    cond_temb: Optional[torch.FloatTensor] = None,
    cond_rotary_emb=None,
    model_config: Optional[Dict[str, Any]] = None,
):
    """Forward pass of a single-stream block with an optional condition stream.

    Intended to be bound as a method: `self` must provide `norm`, `proj_mlp`,
    `act_mlp`, `attn` and `proj_out`. Attention output and MLP activations are
    concatenated on the last axis, projected by `proj_out`, gated and added to
    the residual. Condition tokens go through the same modules, modulated with
    `cond_temb`.

    Args:
        hidden_states: Input tokens.
        temb: Embedding passed as `emb` to `norm` for `hidden_states`.
        image_rotary_emb: Rotary embedding forwarded to `attn_forward`.
        condition_latents: Condition tokens, or None to disable that stream.
        cond_temb: Embedding passed as `emb` to `norm` for condition tokens.
        cond_rotary_emb: Rotary embedding for the condition tokens; only
            forwarded when `condition_latents` is given.
        model_config: Optional settings. `"latent_lora"` (default False) is
            forwarded to `enable_lora`; the dict is also forwarded to
            `attn_forward`. None is treated as an empty dict.

    Returns:
        `hidden_states` when no condition tokens are given, otherwise
        `(hidden_states, condition_latents)`.
    """
    # Normalise once: avoids a shared mutable default and a crash on None.
    model_config = {} if model_config is None else model_config
    latent_lora = model_config.get("latent_lora", False)
    using_cond = condition_latents is not None

    residual = hidden_states
    # NOTE(review): the extent of each enable_lora block is reconstructed from
    # a copy of the file that had lost its indentation; the condition-stream
    # calls are kept outside of it - confirm against the upstream source.
    with enable_lora((self.norm.linear, self.proj_mlp), latent_lora):
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

    if using_cond:
        residual_cond = condition_latents
        norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
        mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))

    # Without condition tokens both condition arguments are None, which is
    # attn_forward's default for them.
    attn_output = attn_forward(
        self.attn,
        model_config=model_config,
        hidden_states=norm_hidden_states,
        image_rotary_emb=image_rotary_emb,
        condition_latents=norm_condition_latents if using_cond else None,
        cond_rotary_emb=cond_rotary_emb if using_cond else None,
    )
    if using_cond:
        # With condition tokens attn_forward returns a (sample, condition) pair.
        attn_output, cond_attn_output = attn_output

    with enable_lora((self.proj_out,), latent_lora):
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states

    if using_cond:
        condition_latents = torch.cat(
            [cond_attn_output, mlp_cond_hidden_states], dim=2
        )
        condition_latents = cond_gate.unsqueeze(1) * self.proj_out(condition_latents)
        condition_latents = residual_cond + condition_latents

    # Clip to avoid overflow (65504 is the largest finite float16 value).
    if hidden_states.dtype == torch.float16:
        hidden_states = hidden_states.clip(-65504, 65504)

    if using_cond:
        return hidden_states, condition_latents
    return hidden_states