# Source page metadata: uploaded by Blackroot ("Upload 4 files"),
# commit e9959b7 (verified), 15.6 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
# Changelog since original version:
# xATGLU instead of top linear in transformer block
# Added a learned residual scale to every block and every residual connection. This stabilized bfloat16 training; before this change it kept exploding.
# This architecture was my attempt at the following Simple Diffusion paper with some modifications:
# https://arxiv.org/pdf/2410.19324v1
# A GLU variant very similar to GeGLU/SwiGLU: a learned gate projection, but with arctan as the gate activation function.
class xATGLU(nn.Module):
    """Arctan-gated linear unit.

    A single fused linear layer produces ``2 * output_dim`` features. The first
    half is the gate, the second half is the value. The gate is squashed into
    (0, 1) with a shifted and rescaled arctan. It is then stretched affinely
    onto (-alpha, 1 + alpha) by a learned scalar ``alpha``, initialised to 0,
    so training starts with a plain (0, 1) gate. The result is ``gate * value``.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        bias: whether the fused projection has a bias term.
    """

    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        # Fused projection: [gate | value] along the last dimension.
        self.proj = nn.Linear(input_dim, output_dim * 2, bias=bias)
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear')
        # Learned gate-range expansion -- https://arxiv.org/pdf/2405.20768
        self.alpha = nn.Parameter(torch.zeros(1))
        # Plain Python float constants (not parameters or buffers).
        self.half_pi = torch.pi / 2
        self.inv_pi = 1 / torch.pi

    def forward(self, x):
        gate_logits, values = torch.chunk(self.proj(x), 2, dim=-1)
        # arctan maps R onto (-pi/2, pi/2); shift by pi/2 and divide by pi -> (0, 1).
        unit_gate = self.inv_pi * (self.half_pi + torch.arctan(gate_logits))
        # Affine stretch of (0, 1) onto (-alpha, 1 + alpha).
        widened_gate = unit_gate * (1 + 2 * self.alpha) - self.alpha
        return values * widened_gate
# Tensor product attention, modified. Original code from:
# https://github.com/tensorgi/T6/blob/main/model/T6_ropek.py
# https://arxiv.org/pdf/2501.06425
class CPLinear(nn.Module):
    """Low-rank (tensor-product) factorised Q/K/V projection.

    For every token, Q, K and V are each built as the product of two
    token-dependent factors:

        q[b, t] = A_q[b, t] @ B_q[b, t] / q_rank

    Here A_q has shape (n_head, q_rank) and B_q has shape (q_rank, head_dim).
    K and V are built the same way with ``rank`` in place of ``q_rank``.
    A comes from the W_A_* projections and B from the W_B_* projections.

    Input:  x of shape (batch, seq, in_features).
    Output: (q, k, v), each of shape (batch, seq, n_head, head_dim).
    """

    def __init__(self, in_features, n_head, head_dim, rank: int = 1, q_rank: int = 12):
        super().__init__()
        self.in_features = in_features
        self.n_head = n_head
        self.head_dim = head_dim
        self.rank = rank
        self.q_rank = q_rank

        # "A" factors: per-head coefficients over the rank dimension.
        # Creation and init order is kept fixed so seeded initialisation is unchanged.
        self.W_A_q = nn.Linear(in_features, n_head * q_rank, bias=False)
        self.W_A_k = nn.Linear(in_features, n_head * rank, bias=False)
        self.W_A_v = nn.Linear(in_features, n_head * rank, bias=False)
        for layer in (self.W_A_q, self.W_A_k, self.W_A_v):
            nn.init.xavier_normal_(layer.weight)

        # "B" factors: rank-many head_dim-sized vectors shared by all heads.
        self.W_B_q = nn.Linear(in_features, q_rank * head_dim, bias=False)
        self.W_B_k = nn.Linear(in_features, rank * head_dim, bias=False)
        self.W_B_v = nn.Linear(in_features, rank * head_dim, bias=False)
        for layer in (self.W_B_q, self.W_B_k, self.W_B_v):
            nn.init.xavier_normal_(layer.weight)

    def _factorised(self, x, proj_a, proj_b, r):
        """Return (proj_a(x) @ proj_b(x)) / r per token, as (batch, seq, n_head, head_dim)."""
        batch, seq, _ = x.size()
        tokens = batch * seq
        coeffs = proj_a(x).view(tokens, self.n_head, r)
        basis = proj_b(x).view(tokens, r, self.head_dim)
        return torch.bmm(coeffs, basis).div_(r).view(batch, seq, self.n_head, self.head_dim)

    def forward(self, x):
        q = self._factorised(x, self.W_A_q, self.W_B_q, self.q_rank)
        k = self._factorised(x, self.W_A_k, self.W_B_k, self.rank)
        v = self._factorised(x, self.W_A_v, self.W_B_v, self.rank)
        return q, k, v
# This may well not be a good positional encoding for DiT; it may even be actively harmful. It does help on small datasets, though.
# Using no positional embedding at all deserves serious consideration in high-compute / large-data scenarios.
class Rotary(torch.nn.Module):
    """Builds (and caches) rotary position embedding (RoPE) cos/sin tables.

    Positions are 0 .. T-1, where T = x.shape[1]. The tables are computed in
    float32, stored as bfloat16, and returned with shape
    (1, T, 1, len(inv_freq)), ready to broadcast against (B, T, heads, dim/2)
    halves.

    The cache is rebuilt whenever the sequence length OR the input's device
    changes. The cached tensors are plain attributes, so ``module.to(device)``
    does not move them. Keying only on the length would return tables on a
    stale device after the model is moved.

    NOTE(review): the tables are bfloat16 even when the model runs in float32,
    which limits angle precision at large positions.

    Args:
        dim: rotary dimension (the attention head dim in this file).
        base: frequency base for the inverse frequencies.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        # Plain float32 attribute (not a buffer), so module-wide dtype casts such
        # as .bfloat16() cannot degrade the frequencies.
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.seq_len_cached = None
        self.device_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached or x.device != self.device_cached:
            inv_freq = self.inv_freq.to(x.device)
            t = torch.arange(seq_len, device=x.device, dtype=inv_freq.dtype)
            freqs = torch.outer(t, inv_freq)
            self.cos_cached = freqs.cos().bfloat16()
            self.sin_cached = freqs.sin().bfloat16()
            # Update the cache keys only after the tables are successfully built.
            self.seq_len_cached = seq_len
            self.device_cached = x.device
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last axis of ``x`` by the RoPE angles.

    ``x`` must be 4-D (enforced with an assert). In this file it is
    (batch, seq, heads, head_dim). With ``first`` the first half and
    ``second`` the remaining half of the last axis, each pair is rotated:

        first'  =  first * cos + second * sin
        second' = -first * sin + second * cos

    The result is concatenated back along the last axis and cast to x's dtype.
    """
    assert x.ndim == 4  # multihead attention
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = first * (-sin) + second * cos
    return torch.cat((rotated_first, rotated_second), dim=3).type_as(x)
class TensorProductAttentionWithRope(nn.Module):
    """Non-causal multi-head attention with factorised Q/K/V and rotary embeddings.

    The pipeline is:
      1. CPLinear produces q, k, v of shape (B, T, n_head, head_dim).
      2. RoPE is applied to q and k (not v).
      3. scaled_dot_product_attention runs with is_causal=False.
      4. Heads are merged back to (B, T, n_head * head_dim).
      5. The gated xATGLU output projection maps to n_embd.

    Input x: (B, T, n_embd). Output: (B, T, n_embd).
    """

    def __init__(self, n_head, head_dim, n_embd, kv_rank=2, q_rank=6):
        super().__init__()
        self.n_head = n_head
        self.head_dim = head_dim
        self.n_embd = n_embd
        self.kv_rank = kv_rank
        self.q_rank = q_rank
        # Factorised Q/K/V projection (submodule creation order kept for seeded init).
        self.c_qkv = CPLinear(n_embd, n_head, head_dim, kv_rank, q_rank)
        # Gated output projection; the bias lets each output feature learn a shift.
        self.o_proj = xATGLU(n_head * head_dim, n_embd, bias=True)
        # Parameter-free helper that builds the RoPE cos/sin tables.
        self.rotary = Rotary(head_dim)

    def forward(self, x):
        q, k, v = self.c_qkv(x)
        cos, sin = self.rotary(q)
        # Rotate queries and keys, then (B, T, H, D) -> (B, H, T, D) as SDPA expects.
        q = apply_rotary_emb(q, cos, sin).transpose(1, 2)
        k = apply_rotary_emb(k, cos, sin).transpose(1, 2)
        v = v.transpose(1, 2)
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        # (B, H, T, D) -> (B, T, H * D)
        merged = attended.transpose(1, 2).flatten(2)
        return self.o_proj(merged)
class ResBlock(nn.Module):
    """Pre-activation convolutional residual block.

        out = x + scale * conv2(silu(norm2(conv1(silu(norm1(x))))))

    Both convs are 3x3 with padding 1 and keep the channel count, so the shape
    is preserved. GroupNorm uses 32 groups, so ``channels`` must be divisible
    by 32. ``scale`` is a learned scalar initialised to 0.1.
    """

    def __init__(self, channels):
        super().__init__()
        # Registration order kept identical (affects seeded init and state_dict order).
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        self.learned_residual_scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x):
        branch = x
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            branch = conv(F.silu(norm(branch)))
        return x + branch * self.learned_residual_scale
class TransformerBlock(nn.Module):
    """Pre-norm transformer block applied to a (B, C, H, W) feature map.

    The feature map is flattened into H*W tokens in row-major (raster) order.
    Each token is the C-channel feature vector at one spatial position. Then:

        x = x + scale_attn * attn(norm1(x))
        x = x + scale_mlp  * mlp(norm2(x))

    Finally the tokens are folded back into (B, C, H, W). Both residual scales
    are learned scalars initialised to 0.1.

    Args:
        channels: feature/embedding size C; attention uses
            head_dim = channels // num_heads.
        num_heads: number of attention heads.
    """

    def __init__(self, channels, num_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.norm2 = nn.LayerNorm(channels)
        # Params recommended by TPA paper, seem to work fine.
        self.attn = TensorProductAttentionWithRope(
            n_head=num_heads,
            head_dim=channels // num_heads,
            n_embd=channels,
            kv_rank=2,
            q_rank=6,
        )
        # Gated MLP: C -> 2C (arctan-gated) -> C.
        self.mlp = nn.Sequential(
            xATGLU(channels, 2 * channels, bias=False),
            nn.Linear(2 * channels, channels, bias=False),  # Candidate for a bias
        )
        self.learned_residual_scale_attn = nn.Parameter(torch.ones(1) * 0.1)
        self.learned_residual_scale_mlp = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x):
        b, c, h, w = x.shape
        # (B, C, H, W) -> (B, H*W, C). This must be flatten + transpose. A bare
        # reshape(b, h*w, c) reinterprets the channel-major memory, so each
        # "token" would mix values from different channels and positions.
        tokens = x.flatten(2).transpose(1, 2)
        # Pre-norm residual branches; this was really helpful for stability with bf16.
        tokens = tokens + self.attn(self.norm1(tokens)) * self.learned_residual_scale_attn
        tokens = tokens + self.mlp(self.norm2(tokens)) * self.learned_residual_scale_mlp
        # (B, H*W, C) -> (B, C, H, W): exact inverse of the flatten above. The
        # batch axis stays in place, so samples are never mixed across the batch.
        return tokens.transpose(1, 2).reshape(b, c, h, w)
class LevelBlock(nn.Module):
    """A sequential stack of ``num_blocks`` same-type blocks at one U-Net level.

    Args:
        channels: channel count for every block in the stack.
        num_blocks: how many blocks to stack.
        block_type: ``'transformer'`` builds TransformerBlocks; any other
            string builds ResBlocks. A boolean is also accepted, so callers can
            pass a "use transformer" flag directly: ``True`` selects
            TransformerBlocks and ``False`` selects ResBlocks. Previously a
            ``True`` flag silently produced ResBlocks.
    """

    def __init__(self, channels, num_blocks, block_type='res'):
        super().__init__()
        use_transformer = block_type is True or block_type == 'transformer'
        block_cls = TransformerBlock if use_transformer else ResBlock
        # Blocks are created in order, so seeded initialisation stays deterministic.
        self.blocks = nn.ModuleList([block_cls(channels) for _ in range(num_blocks)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
class AsymmetricResidualUDiT(nn.Module):
    """U-shaped diffusion backbone mixing convolutional and transformer blocks.

    Forward pipeline:
      1. Patchify with a strided conv:
         (B, in_channels, H, W) -> (B, base_channels, H // p, W // p).
      2. Encoder: ``num_levels`` LevelBlocks. Between consecutive levels, a 1x1
         conv doubles the channels and 2x average pooling halves H and W.
      3. Middle: ``mid_blocks`` TransformerBlocks at the bottleneck. A scaled
         copy of the bottleneck's input features is then subtracted.
      4. Decoder: ``num_levels`` LevelBlocks. Between levels, a 1x1 conv halves
         the channels, nearest-neighbour upsampling doubles H and W, and the
         encoder features saved at that resolution are added back (scaled).
      5. Un-patchify with a transposed conv back to ``in_channels``.

    The skip additions require (H // p) and (W // p) to be divisible by
    2 ** (num_levels - 1). Otherwise the upsampled and saved shapes differ.
    """

    def __init__(self,
                 in_channels=3,  # Input color channels.
                 base_channels=128,  # Initial feature size; dramatically increases the parameter count of the network.
                 patch_size=2,  # Smaller patches dramatically increase FLOPs and compute cost. Recommend >=4 unless you have real compute.
                 num_levels=3,  # Number of resolution levels (the U-Net depth); features are down/upsampled num_levels - 1 times. Dramatically increases parameters as it grows.
                 encoder_blocks=3,  # Blocks per encoder level; may differ from decoder_blocks.
                 decoder_blocks=7,  # Blocks per decoder level; may differ from encoder_blocks.
                 encoder_transformer_thresh=2,  # Encoder levels with index >= this use transformer blocks instead of res blocks.
                 decoder_transformer_thresh=4,  # Decoder levels with index <= this use transformer blocks (decoder level 0 is the bottleneck resolution).
                 mid_blocks=16,  # Number of middle transformer blocks. Relatively cheap, as they run at the bottom of the U-Net feature bottleneck.
                 ):
        super().__init__()
        # One learned scale shared by the bottleneck subtraction and every decoder skip addition.
        self.learned_middle_residual_scale = nn.Parameter(torch.ones(1) * 0.1)
        # Initial projection from image space: non-overlapping patch_size x patch_size patches.
        self.patch_embed = nn.Conv2d(in_channels, base_channels,
                                     kernel_size=patch_size, stride=patch_size)

        # LevelBlock chooses its block class by comparing block_type with the
        # string 'transformer', so pass the string explicitly rather than a bool.
        # NOTE: with the default thresholds, every decoder level (including
        # full patch resolution) uses transformer blocks, which is expensive.
        self.encoders = nn.ModuleList()
        curr_channels = base_channels
        for level in range(num_levels):
            use_transformer = level >= encoder_transformer_thresh  # Transformers for the deeper levels.
            self.encoders.append(
                LevelBlock(curr_channels, encoder_blocks,
                           'transformer' if use_transformer else 'res')
            )
            # Every encoder level except the last ends with a 1x1 conv that doubles
            # the channels; forward() then halves H and W with average pooling.
            if level < num_levels - 1:
                self.encoders.append(
                    nn.Conv2d(curr_channels, curr_channels * 2, 1)
                )
                curr_channels *= 2

        # Middle transformer blocks -- N = mid_blocks
        self.middle = nn.ModuleList([
            TransformerBlock(curr_channels) for _ in range(mid_blocks)
        ])

        self.decoders = nn.ModuleList()
        for level in range(num_levels):
            use_transformer = level <= decoder_transformer_thresh  # Transformers for the early (deep) levels, mirroring the encoder.
            self.decoders.append(
                LevelBlock(curr_channels, decoder_blocks,
                           'transformer' if use_transformer else 'res')
            )
            # Every decoder level except the last ends with a 1x1 conv that halves
            # the channels; forward() then doubles H and W by upsampling.
            if level < num_levels - 1:
                self.decoders.append(
                    nn.Conv2d(curr_channels, curr_channels // 2, 1)
                )
                curr_channels //= 2

        # Final projection back to image space.
        self.final_proj = nn.ConvTranspose2d(base_channels, in_channels,
                                             kernel_size=patch_size, stride=patch_size)

    def downsample(self, x):
        return F.avg_pool2d(x, kernel_size=2)

    def upsample(self, x):
        return F.interpolate(x, scale_factor=2, mode='nearest')

    def forward(self, x, t=None):
        """Run the U-shaped network on ``x``.

        Args:
            x: (B, in_channels, H, W) input.
            t: accepted but unused -- nothing in this module reads it, so the
                network is not timestep-conditioned.

        Returns:
            Tensor of shape (B, in_channels, (H // p) * p, (W // p) * p),
            where p = patch_size.

        Roughly f(x) = fu(U(fm(D(h)) - D(h)) + h), where fd/fm/fu are the
        encoder/middle/decoder and D/U are down/upsampling. Precisely, the
        features saved for each skip (and subtracted at the bottleneck) are the
        ones present at that resolution *before* that level's encoder
        LevelBlock runs. At the top level this is the patch embedding itself;
        at lower levels it is the freshly downsampled features.
        NOTE(review): the "+ h" in the formula suggests encoder outputs were
        intended -- confirm which is wanted.
        """
        # Patchify, e.g. (2, 3, 256, 256) -> (2, 128, 64, 64) with the defaults.
        x = self.patch_embed(x)

        residuals = []
        curr_res = x
        h = x
        for module in self.encoders:
            if isinstance(module, LevelBlock):
                h = module(h)
            else:
                # Save this resolution's features, then widen and downsample.
                residuals.append(curr_res)
                h = self.downsample(module(h))
                curr_res = h

        # Middle blocks (fm), then subtract the scaled bottleneck input (D(h)).
        x = h
        for block in self.middle:
            x = block(x)
        x = x - curr_res * self.learned_middle_residual_scale

        for module in self.decoders:
            if isinstance(module, LevelBlock):
                x = module(x)
            else:
                # Narrow the channels, upsample, then add the matching encoder
                # features. The stack is LIFO because of the U shape.
                x = self.upsample(module(x))
                x = x + residuals.pop() * self.learned_middle_residual_scale

        return self.final_proj(x)