from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange

from .activation_layers import get_activation_layer
from .attenion import attention
from .embed_layers import TimestepEmbedder, TextProjection
from .mlp_layers import MLP
from .modulate_layers import modulate, apply_gate
from .norm_layers import get_norm_layer


class IndividualTokenRefinerBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.self_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        act_layer = get_activation_layer(act_type)
        self.mlp = MLP(
            in_channels=hidden_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop_rate,
            **factory_kwargs,
        )

        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )
        # Zero-init the modulation projection so both gates start at zero and the
        # residual branches contribute nothing at initialization.
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        # The conditioning vector c yields one per-sample gate for the attention
        # branch and one for the MLP branch.
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        # Self-attention branch with optional QK normalization.
        norm_x = self.norm1(x)
        qkv = self.self_attn_qkv(norm_x)
        # q, k, v each have shape (B, L, H, D).
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)

        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        # Gated residual connections for attention and MLP outputs.
        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)

        return x
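
# Usage sketch (illustrative, not part of the original module; hidden_size=1024
# and heads_num=16 are assumed values):
#   block = IndividualTokenRefinerBlock(hidden_size=1024, heads_num=16)
#   x = torch.randn(2, 77, 1024)   # (B, L, hidden_size) token features
#   c = torch.randn(2, 1024)       # (B, hidden_size) conditioning vector
#   out = block(x, c)              # attn_mask defaults to None; out has x's shape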


class IndividualTokenRefiner(nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                IndividualTokenRefinerBlock(
                    hidden_size=hidden_size,
                    heads_num=heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    act_type=act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    **factory_kwargs,
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ):
        self_attn_mask = None
        if mask is not None:
            batch_size = mask.shape[0]
            seq_len = mask.shape[1]
            mask = mask.to(x.device)

            # Build a pairwise mask: position (i, j) is valid only if both
            # token i and token j are unmasked.
            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
                1, 1, seq_len, 1
            )
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
            # Keep the first column unmasked so fully padded rows still attend
            # to at least one position and the softmax stays well defined.
            self_attn_mask[:, :, :, 0] = True

        for block in self.blocks:
            x = block(x, c, self_attn_mask)
        return x
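
# Mask sketch (illustrative, not part of the original module): a padding mask of
# shape (B, L), with 1 for real tokens and 0 for padding, becomes a boolean
# (B, 1, L, L) self-attention mask. For mask = [[1, 1, 0]] the pairwise mask is
#   [[True,  True,  False],
#    [True,  True,  False],
#    [False, False, False]]
# and forcing column 0 to True turns the last (fully padded) row into
#   [True, False, False], so its attention softmax does not produce NaNs.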


class SingleTokenRefiner(nn.Module):
    """
    A single token refiner block for refining LLM text embeddings.
    """

    def __init__(
        self,
        in_channels,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        attn_mode: str = "torch",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_mode = attn_mode
        assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."

        self.input_embedder = nn.Linear(
            in_channels, hidden_size, bias=True, **factory_kwargs
        )

        act_layer = get_activation_layer(act_type)

        # Timestep and pooled text embeddings are combined into a single
        # conditioning vector in forward().
        self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
        self.c_embedder = TextProjection(
            in_channels, hidden_size, act_layer, **factory_kwargs
        )

        self.individual_token_refiner = IndividualTokenRefiner(
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            mlp_width_ratio=mlp_width_ratio,
            mlp_drop_rate=mlp_drop_rate,
            act_type=act_type,
            qk_norm=qk_norm,
            qk_norm_type=qk_norm_type,
            qkv_bias=qkv_bias,
            **factory_kwargs,
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.LongTensor,
        mask: Optional[torch.LongTensor] = None,
    ):
        timestep_aware_representations = self.t_embedder(t)

        # Pool the text embeddings: a plain mean when no mask is given,
        # otherwise a masked mean over valid (unpadded) tokens.
        if mask is None:
            context_aware_representations = x.mean(dim=1)
        else:
            mask_float = mask.float().unsqueeze(-1)  # (B, L, 1)
            context_aware_representations = (x * mask_float).sum(
                dim=1
            ) / mask_float.sum(dim=1)
        context_aware_representations = self.c_embedder(context_aware_representations)
        c = timestep_aware_representations + context_aware_representations

        x = self.input_embedder(x)

        x = self.individual_token_refiner(x, c, mask)

        return x
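

# Minimal smoke-test sketch (not part of the original module). The sizes below
# (in_channels=4096, hidden_size=1024, heads_num=16, depth=2) are illustrative
# assumptions; run as a module (python -m <package>.<this_file>) so the relative
# imports resolve.
if __name__ == "__main__":
    refiner = SingleTokenRefiner(
        in_channels=4096, hidden_size=1024, heads_num=16, depth=2
    )
    text_emb = torch.randn(2, 77, 4096)         # (B, L, in_channels) LLM hidden states
    t = torch.randint(0, 1000, (2,))            # (B,) diffusion timesteps
    mask = torch.ones(2, 77, dtype=torch.long)  # (B, L) padding mask, 1 = valid token
    out = refiner(text_emb, t, mask)
    print(out.shape)                            # expected: torch.Size([2, 77, 1024])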