|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
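# xATGLU: arctan-gated linear unit with an "expanded" gate.  The projection is
# split into gate and value halves; arctan squashes the gate into (0, 1) and a
# learned alpha widens that range to (-alpha, 1 + alpha).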
class xATGLU(nn.Module): |
|
def __init__(self, input_dim, output_dim, bias=True): |
|
super().__init__() |
|
|
|
self.proj = nn.Linear(input_dim, output_dim * 2, bias=bias) |
|
nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear') |
|
|
|
self.alpha = nn.Parameter(torch.zeros(1)) |
|
self.half_pi = torch.pi / 2 |
|
self.inv_pi = 1 / torch.pi |
|
|
|
def forward(self, x): |
|
projected = self.proj(x) |
|
gate_path, value_path = projected.chunk(2, dim=-1) |
|
|
|
|
|
gate = (torch.arctan(gate_path) + self.half_pi) * self.inv_pi |
|
expanded_gate = gate * (1 + 2 * self.alpha) - self.alpha |
|
|
|
return expanded_gate * value_path |
|
|
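# CPLinear: factorized QKV projection used by TensorProductAttentionWithRope.
# Q, K and V are each formed as a low-rank contraction of per-head coefficients
# (the A factors) with shared head_dim-sized basis vectors (the B factors).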
class CPLinear(nn.Module): |
|
def __init__(self, in_features, n_head, head_dim, rank: int = 1, q_rank: int = 12): |
|
        super().__init__()
|
self.in_features = in_features |
|
self.n_head = n_head |
|
self.head_dim = head_dim |
|
self.rank = rank |
|
self.q_rank = q_rank |
|
|
|
self.W_A_q = nn.Linear(in_features, n_head * q_rank, bias=False) |
|
self.W_A_k = nn.Linear(in_features, n_head * rank, bias=False) |
|
self.W_A_v = nn.Linear(in_features, n_head * rank, bias=False) |
|
|
|
nn.init.xavier_normal_(self.W_A_q.weight) |
|
nn.init.xavier_normal_(self.W_A_k.weight) |
|
nn.init.xavier_normal_(self.W_A_v.weight) |
|
|
|
self.W_B_q = nn.Linear(in_features, q_rank * head_dim, bias=False) |
|
self.W_B_k = nn.Linear(in_features, rank * head_dim, bias=False) |
|
self.W_B_v = nn.Linear(in_features, rank * head_dim, bias=False) |
|
|
|
nn.init.xavier_normal_(self.W_B_q.weight) |
|
nn.init.xavier_normal_(self.W_B_k.weight) |
|
nn.init.xavier_normal_(self.W_B_v.weight) |
|
|
|
def forward(self, x): |
|
batch_size, seq_len, _ = x.size() |
|
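        # A factors: per-token, per-head coefficients over the rank dimension.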
A_q = self.W_A_q(x).view(batch_size, seq_len, self.n_head, self.q_rank) |
|
A_k = self.W_A_k(x).view(batch_size, seq_len, self.n_head, self.rank) |
|
A_v = self.W_A_v(x).view(batch_size, seq_len, self.n_head, self.rank) |
|
|
|
|
|
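        # B factors: per-token basis vectors of size head_dim for each rank.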
B_q = self.W_B_q(x).view(batch_size, seq_len, self.q_rank, self.head_dim) |
|
B_k = self.W_B_k(x).view(batch_size, seq_len, self.rank, self.head_dim) |
|
B_v = self.W_B_v(x).view(batch_size, seq_len, self.rank, self.head_dim) |
|
|
|
|
|
A_q = A_q.view(batch_size * seq_len, self.n_head, self.q_rank) |
|
A_k = A_k.view(batch_size * seq_len, self.n_head, self.rank) |
|
A_v = A_v.view(batch_size * seq_len, self.n_head, self.rank) |
|
|
|
|
|
B_q = B_q.view(batch_size * seq_len, self.q_rank, self.head_dim) |
|
B_k = B_k.view(batch_size * seq_len, self.rank, self.head_dim) |
|
B_v = B_v.view(batch_size * seq_len, self.rank, self.head_dim) |
|
|
|
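        # Contract A and B over the rank dimension (batched matmul), then average by the rank.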
q = torch.bmm(A_q, B_q).div_(self.q_rank).view(batch_size, seq_len, self.n_head, self.head_dim) |
|
k = torch.bmm(A_k, B_k).div_(self.rank).view(batch_size, seq_len, self.n_head, self.head_dim) |
|
v = torch.bmm(A_v, B_v).div_(self.rank).view(batch_size, seq_len, self.n_head, self.head_dim) |
|
|
|
return q, k, v |
|
|
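# Rotary positional embedding with cos/sin tables cached in bfloat16 and
# recomputed only when the sequence length changes.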
class Rotary(nn.Module):
|
def __init__(self, dim, base=10000): |
|
super().__init__() |
|
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) |
|
self.seq_len_cached = None |
|
self.cos_cached = None |
|
self.sin_cached = None |
|
|
|
def forward(self, x): |
|
seq_len = x.shape[1] |
|
if seq_len != self.seq_len_cached: |
|
self.seq_len_cached = seq_len |
|
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) |
|
freqs = torch.outer(t, self.inv_freq).to(x.device) |
|
self.cos_cached = freqs.cos().bfloat16() |
|
self.sin_cached = freqs.sin().bfloat16() |
|
return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :] |
|
|
|
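# Rotate the two halves of the last (head) dimension by the cached angles.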
def apply_rotary_emb(x, cos, sin): |
|
assert x.ndim == 4 |
|
d = x.shape[3] // 2 |
|
x1 = x[..., :d] |
|
x2 = x[..., d:] |
|
y1 = x1 * cos + x2 * sin |
|
y2 = x1 * (-sin) + x2 * cos |
|
return torch.cat([y1, y2], 3).type_as(x) |
|
|
|
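# Attention block combining the factorized CPLinear QKV projection with rotary
# position embeddings and an xATGLU output projection.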
class TensorProductAttentionWithRope(nn.Module): |
|
def __init__(self, n_head, head_dim, n_embd, kv_rank=2, q_rank=6): |
|
super().__init__() |
|
self.n_head = n_head |
|
self.head_dim = head_dim |
|
self.n_embd = n_embd |
|
self.kv_rank = kv_rank |
|
self.q_rank = q_rank |
|
|
|
self.c_qkv = CPLinear(self.n_embd, self.n_head, self.head_dim, self.kv_rank, self.q_rank) |
|
|
|
|
|
self.o_proj = xATGLU(self.n_head * self.head_dim, self.n_embd, bias=True) |
|
|
|
|
|
self.rotary = Rotary(self.head_dim) |
|
|
|
def forward(self, x): |
|
B, T, C = x.size() |
|
|
|
|
|
q, k, v = self.c_qkv(x) |
|
|
|
cos, sin = self.rotary(q) |
|
q = apply_rotary_emb(q, cos, sin) |
|
k = apply_rotary_emb(k, cos, sin) |
|
|
|
|
|
q = q.permute(0, 2, 1, 3) |
|
k = k.permute(0, 2, 1, 3) |
|
v = v.permute(0, 2, 1, 3) |
|
|
|
|
|
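        # Full (non-causal) attention over the token sequence.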
y = F.scaled_dot_product_attention(q, k, v, is_causal=False) |
|
|
|
|
|
y = y.transpose(1, 2).flatten(2) |
|
y = self.o_proj(y) |
|
|
|
return y |
|
|
|
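# Pre-activation residual conv block (GroupNorm + SiLU) with a learned scale on
# the residual branch.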
class ResBlock(nn.Module): |
|
def __init__(self, channels): |
|
super().__init__() |
|
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1) |
|
self.norm1 = nn.GroupNorm(32, channels) |
|
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1) |
|
self.norm2 = nn.GroupNorm(32, channels) |
|
|
|
self.learned_residual_scale = nn.Parameter(torch.ones(1) * 0.1) |
|
|
|
def forward(self, x): |
|
h = self.conv1(F.silu(self.norm1(x))) |
|
h = self.conv2(F.silu(self.norm2(h))) |
|
return x + h * self.learned_residual_scale |
|
|
|
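# Transformer block that operates on flattened spatial tokens: pre-norm
# attention and an xATGLU MLP, each with its own learned residual scale.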
class TransformerBlock(nn.Module): |
|
def __init__(self, channels, num_heads=8): |
|
super().__init__() |
|
self.norm1 = nn.LayerNorm(channels) |
|
self.norm2 = nn.LayerNorm(channels) |
|
|
|
|
|
self.attn = TensorProductAttentionWithRope( |
|
n_head=num_heads, |
|
head_dim=channels // num_heads, |
|
n_embd=channels, |
|
kv_rank=2, |
|
q_rank=6 |
|
) |
|
|
|
self.mlp = nn.Sequential( |
|
xATGLU(channels, 2 * channels, bias=False), |
|
nn.Linear(2 * channels, channels, bias=False) |
|
) |
|
|
|
self.learned_residual_scale_attn = nn.Parameter(torch.ones(1) * 0.1) |
|
self.learned_residual_scale_mlp = nn.Parameter(torch.ones(1) * 0.1) |
|
|
|
def forward(self, x): |
|
|
|
b, c, h, w = x.shape |
|
|
|
        # Flatten spatial dims into a token sequence: (B, C, H, W) -> (B, H*W, C).
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
|
|
|
|
identity = x |
|
x = self.norm1(x) |
|
h_attn = self.attn(x) |
|
|
|
x = identity + h_attn * self.learned_residual_scale_attn |
|
|
|
identity = x |
|
x = self.norm2(x) |
|
h_mlp = self.mlp(x) |
|
x = identity + h_mlp * self.learned_residual_scale_mlp |
|
|
|
|
|
        # Restore the image layout: (B, H*W, C) -> (B, C, H, W).
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
|
return x |
|
|
|
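# A stack of ResBlocks or TransformerBlocks at a single resolution level.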
class LevelBlock(nn.Module): |
|
def __init__(self, channels, num_blocks, block_type='res'): |
|
super().__init__() |
|
self.blocks = nn.ModuleList() |
|
for _ in range(num_blocks): |
|
if block_type == 'transformer': |
|
self.blocks.append(TransformerBlock(channels)) |
|
else: |
|
self.blocks.append(ResBlock(channels)) |
|
|
|
def forward(self, x): |
|
for block in self.blocks: |
|
x = block(x) |
|
return x |
|
|
|
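# U-Net-style transformer with an asymmetric encoder/decoder: fewer encoder
# blocks than decoder blocks by default, 1x1 convs for channel changes,
# avg-pool downsampling, nearest-neighbor upsampling, and a learned scale
# shared by the middle and skip residuals.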
class AsymmetricResidualUDiT(nn.Module): |
|
def __init__(self, |
|
in_channels=3, |
|
base_channels=128, |
|
patch_size=2, |
|
num_levels=3, |
|
encoder_blocks=3, |
|
decoder_blocks=7, |
|
encoder_transformer_thresh=2, |
|
decoder_transformer_thresh=4, |
|
mid_blocks=16, |
|
): |
|
super().__init__() |
|
self.learned_middle_residual_scale = nn.Parameter(torch.ones(1) * 0.1) |
|
|
|
self.patch_embed = nn.Conv2d(in_channels, base_channels, |
|
kernel_size=patch_size, stride=patch_size) |
|
|
|
self.encoders = nn.ModuleList() |
|
curr_channels = base_channels |
|
|
|
for level in range(num_levels): |
|
use_transformer = level >= encoder_transformer_thresh |
|
|
|
|
|
self.encoders.append( |
|
                # LevelBlock expects a block_type string, not a bool
                LevelBlock(curr_channels, encoder_blocks, 'transformer' if use_transformer else 'res')
|
) |
|
|
|
|
|
if level < num_levels - 1: |
|
self.encoders.append( |
|
nn.Conv2d(curr_channels, curr_channels * 2, 1) |
|
) |
|
curr_channels *= 2 |
|
|
|
|
|
self.middle = nn.ModuleList([ |
|
TransformerBlock(curr_channels) for _ in range(mid_blocks) |
|
]) |
|
|
|
|
|
self.decoders = nn.ModuleList() |
|
|
|
for level in range(num_levels): |
|
use_transformer = level <= decoder_transformer_thresh |
|
|
|
|
|
self.decoders.append( |
|
                # LevelBlock expects a block_type string, not a bool
                LevelBlock(curr_channels, decoder_blocks, 'transformer' if use_transformer else 'res')
|
) |
|
|
|
|
|
if level < num_levels - 1: |
|
self.decoders.append( |
|
nn.Conv2d(curr_channels, curr_channels // 2, 1) |
|
) |
|
curr_channels //= 2 |
|
|
|
|
|
self.final_proj = nn.ConvTranspose2d(base_channels, in_channels, |
|
kernel_size=patch_size, stride=patch_size) |
|
|
|
def downsample(self, x): |
|
return F.avg_pool2d(x, kernel_size=2) |
|
|
|
def upsample(self, x): |
|
return F.interpolate(x, scale_factor=2, mode='nearest') |
|
|
|
def forward(self, x, t=None): |
|
x = self.patch_embed(x) |
|
|
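        # Encoder: each LevelBlock refines features; at every channel-doubling 1x1
        # conv the level's input feature map is stashed as a skip connection and
        # the features are downsampled.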
residuals = [] |
|
curr_res = x |
|
|
|
|
|
h = x |
|
for i, blocks in enumerate(self.encoders): |
|
if isinstance(blocks, LevelBlock): |
|
h = blocks(h) |
|
else: |
|
|
|
residuals.append(curr_res) |
|
|
|
h = self.downsample(blocks(h)) |
|
curr_res = h |
|
|
|
|
|
x = h |
|
for block in self.middle: |
|
x = block(x) |
|
|
|
|
|
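        # Asymmetric middle residual: subtract the scaled pre-middle features
        # rather than adding them.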
x = x - curr_res * self.learned_middle_residual_scale |
|
|
|
|
|
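        # Decoder: LevelBlocks refine; each 1x1 conv halves channels, then the map
        # is upsampled and the matching encoder skip is added back (scaled).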
for i, blocks in enumerate(self.decoders): |
|
if isinstance(blocks, LevelBlock): |
|
x = blocks(x) |
|
else: |
|
|
|
x = blocks(x) |
|
|
|
x = self.upsample(x) |
|
|
|
curr_res = residuals.pop() |
|
x = x + curr_res * self.learned_middle_residual_scale |
|
|
|
|
|
x = self.final_proj(x) |
|
|
|
return x |
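

# Minimal shape-check sketch (not part of the original module; the reduced widths
# and block counts below are illustrative assumptions, not an intended training
# configuration).  It instantiates a small AsymmetricResidualUDiT and confirms
# that the output resolution matches the input.
if __name__ == "__main__":
    model = AsymmetricResidualUDiT(
        in_channels=3,
        base_channels=64,   # keep divisible by 32 for the GroupNorm in ResBlock
        patch_size=2,
        num_levels=2,
        encoder_blocks=1,
        decoder_blocks=1,
        mid_blocks=2,
    )
    # Spatial size must be divisible by patch_size * 2 ** (num_levels - 1)
    # so that the downsample/upsample path round-trips cleanly.
    x = torch.randn(2, 3, 32, 32)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([2, 3, 32, 32])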