import torch import torch.nn as nn import torch.nn.functional as F # Changelog since original version: # xATGLU instead of top linear in transformer block # Added a learned residual scale to all blocks and all residuals. This allowed bfloat16 training to stabilize, prior it was just exploding. # This architecture was my attempt at the following Simple Diffusion paper with some modifications: # https://arxiv.org/pdf/2410.19324v1 # Very similar to GeGLU or SwiGLU, there's a learned gate FN, uses arctan as the activation fn. class xATGLU(nn.Module): def __init__(self, input_dim, output_dim, bias=True): super().__init__() # GATE path | VALUE path self.proj = nn.Linear(input_dim, output_dim * 2, bias=bias) nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear') self.alpha = nn.Parameter(torch.zeros(1)) self.half_pi = torch.pi / 2 self.inv_pi = 1 / torch.pi def forward(self, x): projected = self.proj(x) gate_path, value_path = projected.chunk(2, dim=-1) # Apply arctan gating with expanded range via learned alpha -- https://arxiv.org/pdf/2405.20768 gate = (torch.arctan(gate_path) + self.half_pi) * self.inv_pi expanded_gate = gate * (1 + 2 * self.alpha) - self.alpha return expanded_gate * value_path # g(x) × y # Tensor product attention, modified. 
# Original code from:
# https://github.com/tensorgi/T6/blob/main/model/T6_ropek.py
# https://arxiv.org/pdf/2501.06425
class CPLinear(nn.Module):
    """Produce per-head Q, K and V through a low-rank tensor-product factorization.

    For every token, two linear maps of the token give:

    * ``A`` of shape ``(n_head, r)``: per-head mixing weights.
    * ``B`` of shape ``(r, head_dim)``: shared factors.

    The head outputs are ``A @ B / r``. Q uses ``r = q_rank``; K and V use
    ``r = rank``.

    Args:
        in_features: Size of the last axis of the input.
        n_head: Number of attention heads.
        head_dim: Per-head feature size.
        rank: Factorization rank for K and V.
        q_rank: Factorization rank for Q.
    """

    def __init__(self, in_features, n_head, head_dim, rank: int = 1, q_rank: int = 12):
        super().__init__()
        self.in_features = in_features
        self.n_head = n_head
        self.head_dim = head_dim
        self.rank = rank
        self.q_rank = q_rank

        # It's somewhat standard to call the two low-rank matrices A and B, so that naming is followed here.
        # "A" projections: per-head mixing weights.
        self.W_A_q = nn.Linear(in_features, n_head * q_rank, bias=False)
        self.W_A_k = nn.Linear(in_features, n_head * rank, bias=False)
        self.W_A_v = nn.Linear(in_features, n_head * rank, bias=False)
        for proj in (self.W_A_q, self.W_A_k, self.W_A_v):
            nn.init.xavier_normal_(proj.weight)

        # "B" projections: shared factors.
        self.W_B_q = nn.Linear(in_features, q_rank * head_dim, bias=False)
        self.W_B_k = nn.Linear(in_features, rank * head_dim, bias=False)
        self.W_B_v = nn.Linear(in_features, rank * head_dim, bias=False)
        for proj in (self.W_B_q, self.W_B_k, self.W_B_v):
            nn.init.xavier_normal_(proj.weight)

    def _low_rank_product(self, x, proj_a, proj_b, n_factors):
        """Contract per-token A (n_head x n_factors) with B (n_factors x head_dim), scaled by 1/n_factors."""
        batch_size, seq_len = x.shape[0], x.shape[1]
        n_tokens = batch_size * seq_len
        mixing = proj_a(x).view(n_tokens, self.n_head, n_factors)
        factors = proj_b(x).view(n_tokens, n_factors, self.head_dim)
        heads = torch.bmm(mixing, factors).div_(n_factors)
        return heads.view(batch_size, seq_len, self.n_head, self.head_dim)

    def forward(self, x):
        """Map ``x`` of shape ``(B, T, in_features)`` to ``(q, k, v)``, each ``(B, T, n_head, head_dim)``."""
        batch_size, seq_len, _ = x.size()
        q = self._low_rank_product(x, self.W_A_q, self.W_B_q, self.q_rank)
        k = self._low_rank_product(x, self.W_A_k, self.W_B_k, self.rank)
        v = self._low_rank_product(x, self.W_A_v, self.W_B_v, self.rank)
        return q, k, v


# Very possibly this is not a good method for positional encoding in a DiT; in fact it may be actively harmful.
# It does help on small datasets, though.
# No positional embedding at all should be seriously considered for high-compute / large-data scenarios.
class Rotary(torch.nn.Module):
    """Build rotary position-embedding (RoPE) cos/sin tables for a 1-D token sequence.

    Not a learnable layer. ``inv_freq`` is a plain tensor attribute (not a
    parameter or registered buffer), so it does not appear in ``state_dict()``.

    Args:
        dim: Size of the last axis of the tensors that will be rotated (the
            attention head dimension). Assumed even -- TODO confirm for
            non-default head counts.
        base: RoPE frequency base.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None
        # Device the cached tables live on; a device change invalidates the cache.
        self.device_cached = None

    def forward(self, x):
        """Return ``(cos, sin)``, each shaped ``(1, T, 1, len(inv_freq))`` with ``T = x.shape[1]``.

        Tables are computed in float32, cached per (sequence length, device),
        and cast to ``x.dtype`` on return. Casting per call avoids the
        precision loss of always storing them in bfloat16 when ``x`` is fp32.
        """
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached or x.device != self.device_cached:
            self.seq_len_cached = seq_len
            self.device_cached = x.device
            inv_freq = self.inv_freq.to(x.device)
            t = torch.arange(seq_len, device=x.device, dtype=inv_freq.dtype)
            freqs = torch.outer(t, inv_freq)
            self.cos_cached = freqs.cos()
            self.sin_cached = freqs.sin()
        cos = self.cos_cached.to(x.dtype)
        sin = self.sin_cached.to(x.dtype)
        return cos[None, :, None, :], sin[None, :, None, :]


def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of ``x``'s last axis by the angles encoded in ``cos``/``sin``.

    Args:
        x: 4-D tensor; in this file it is ``(B, T, n_head, head_dim)``.
        cos, sin: Tables from ``Rotary.forward``, broadcastable against each
            half of ``x``'s last axis.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    assert x.ndim == 4, "apply_rotary_emb expects a 4-D (B, T, heads, dim) tensor"  # multihead attention
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3).type_as(x)


class TensorProductAttentionWithRope(nn.Module):
    """Non-causal multi-head attention with TPA-factorized Q/K/V and rotary embeddings.

    Q, K and V come from ``CPLinear``. RoPE is applied to Q and K along the
    token axis. The heads are mapped back to ``n_embd`` by an ``xATGLU``
    gated projection.

    Args:
        n_head: Number of attention heads.
        head_dim: Per-head feature size.
        n_embd: Input/output embedding size.
        kv_rank: Factorization rank for K and V.
        q_rank: Factorization rank for Q.
    """

    def __init__(self, n_head, head_dim, n_embd, kv_rank=2, q_rank=6):
        super().__init__()
        self.n_head = n_head
        self.head_dim = head_dim
        self.n_embd = n_embd
        self.kv_rank = kv_rank
        self.q_rank = q_rank
        self.c_qkv = CPLinear(self.n_embd, self.n_head, self.head_dim, self.kv_rank, self.q_rank)
        # Output projection. Bias seems sensible here: each head can learn a shift.
        self.o_proj = xATGLU(self.n_head * self.head_dim, self.n_embd, bias=True)
        # Not a layer, just a helper.
        self.rotary = Rotary(self.head_dim)

    def forward(self, x):
        """Attend over the token axis of ``x`` ``(B, T, n_embd)``; returns ``(B, T, n_embd)``."""
        # Q, K, V via the CPLinear factorization; each is (B, T, n_head, head_dim).
        q, k, v = self.c_qkv(x)

        cos, sin = self.rotary(q)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # SDPA expects (B, n_head, T, head_dim).
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)

        # (B, n_head, T, head_dim) -> (B, T, n_head * head_dim)
        y = y.transpose(1, 2).flatten(2)
        return self.o_proj(y)


class ResBlock(nn.Module):
    """Convolutional residual block: ``x + scale * conv(silu(norm(conv(silu(norm(x))))))``.

    ``channels`` must be divisible by 32 (GroupNorm with 32 groups). The
    residual branch is scaled by a learned scalar initialized to 0.1.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        self.learned_residual_scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return x + h * self.learned_residual_scale


class TransformerBlock(nn.Module):
    """Pre-norm transformer block operating on a ``(B, C, H, W)`` feature map.

    The map is flattened to ``H*W`` tokens of ``C`` features. The tokens pass
    through attention and then an MLP, each residual branch scaled by a
    learned scalar initialized to 0.1. The result is reshaped back to
    ``(B, C, H, W)``.
    """

    def __init__(self, channels, num_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.norm2 = nn.LayerNorm(channels)
        # Params recommended by the TPA paper; they seem to work fine.
        self.attn = TensorProductAttentionWithRope(
            n_head=num_heads,
            head_dim=channels // num_heads,
            n_embd=channels,
            kv_rank=2,
            q_rank=6
        )
        self.mlp = nn.Sequential(
            xATGLU(channels, 2 * channels, bias=False),
            nn.Linear(2 * channels, channels, bias=False)  # Candidate for a bias
        )
        self.learned_residual_scale_attn = nn.Parameter(torch.ones(1) * 0.1)
        self.learned_residual_scale_mlp = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x):
        b, c, h, w = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position, channels last.
        # A bare reshape would reinterpret memory and mix channel/spatial values.
        x = x.flatten(2).transpose(1, 2)

        # Pre-norm architecture; this was really helpful for network stability in bf16.
        identity = x
        h_attn = self.attn(self.norm1(x))
        x = identity + h_attn * self.learned_residual_scale_attn

        identity = x
        h_mlp = self.mlp(self.norm2(x))
        x = identity + h_mlp * self.learned_residual_scale_mlp

        # (B, H*W, C) -> (B, C, H, W). Batch stays the leading axis, so samples
        # are never interleaved.
        return x.transpose(1, 2).reshape(b, c, h, w)


class LevelBlock(nn.Module):
    """A stack of ``num_blocks`` blocks of one type, all at the same channel count.

    Args:
        channels: Channel count of every block.
        num_blocks: Number of blocks to stack.
        block_type: ``'transformer'`` (or ``True``) for ``TransformerBlock``;
            anything else (e.g. ``'res'`` or ``False``) for ``ResBlock``.
            A bool is accepted because ``AsymmetricResidualUDiT`` historically
            passed ``use_transformer`` directly. Comparing that bool to the
            string always failed, so transformer levels were silently built
            as ResBlocks.
    """

    def __init__(self, channels, num_blocks, block_type='res'):
        super().__init__()
        if isinstance(block_type, bool):
            use_transformer = block_type
        else:
            use_transformer = block_type == 'transformer'
        block_cls = TransformerBlock if use_transformer else ResBlock
        self.blocks = nn.ModuleList(block_cls(channels) for _ in range(num_blocks))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class AsymmetricResidualUDiT(nn.Module):
    """U-shaped diffusion backbone mixing ResBlock and TransformerBlock levels.

    Pipeline:

    1. Patch embedding.
    2. Encoder levels. A 1x1 conv doubles the channels and average pooling
       halves the resolution between levels.
    3. Middle transformer blocks.
    4. Decoder levels. A 1x1 conv halves the channels and nearest upsampling
       doubles the resolution between levels.
    5. Transposed-conv un-patchify back to ``in_channels``.
    """

    def __init__(self,
                 in_channels=3,  # Input color channels.
                 base_channels=128,  # Initial feature size; dramatically increases the parameter count.
                 patch_size=2,  # Smaller patches dramatically increase flops. Recommend >=4 unless you have real compute.
                 num_levels=3,  # UNet depth: number of resolution levels (downsample num_levels - 1 times). Dramatically increases parameters.
                 encoder_blocks=3,  # Blocks per encoder level; may differ from decoder_blocks.
                 decoder_blocks=7,  # Blocks per decoder level; may differ from encoder_blocks.
                 encoder_transformer_thresh=2,  # Encoder levels with index >= this use transformer blocks instead of res blocks.
                 decoder_transformer_thresh=4,  # Decoder levels with index <= this use transformer blocks instead of res blocks.
                 mid_blocks=16,  # Number of middle transformer blocks. Relatively cheap: they run at the bottleneck resolution.
                 ):
        super().__init__()
        # One learned scalar shared by the middle subtraction and every decoder skip addition.
        self.learned_middle_residual_scale = nn.Parameter(torch.ones(1) * 0.1)

        # Initial projection from image space.
        self.patch_embed = nn.Conv2d(in_channels, base_channels, kernel_size=patch_size, stride=patch_size)

        self.encoders = nn.ModuleList()
        curr_channels = base_channels
        for level in range(num_levels):
            use_transformer = level >= encoder_transformer_thresh  # Transformers for the later (deeper) levels.
            self.encoders.append(
                LevelBlock(curr_channels, encoder_blocks, 'transformer' if use_transformer else 'res')
            )
            # Between levels (not after the last one) a 1x1 conv doubles the channels;
            # forward() then halves the spatial size.
            if level < num_levels - 1:
                self.encoders.append(nn.Conv2d(curr_channels, curr_channels * 2, 1))
                curr_channels *= 2

        # Middle transformer blocks -- N = mid_blocks.
        self.middle = nn.ModuleList([
            TransformerBlock(curr_channels)
            for _ in range(mid_blocks)
        ])

        # NOTE: with the defaults (num_levels=3, decoder_transformer_thresh=4) every decoder level,
        # including the full-resolution one, uses transformer blocks.
        self.decoders = nn.ModuleList()
        for level in range(num_levels):
            use_transformer = level <= decoder_transformer_thresh  # Transformers for the early (deep) levels, mirroring the encoder.
            self.decoders.append(
                LevelBlock(curr_channels, decoder_blocks, 'transformer' if use_transformer else 'res')
            )
            # Between levels (not after the last one) a 1x1 conv halves the channels;
            # forward() then doubles the spatial size.
            if level < num_levels - 1:
                self.decoders.append(nn.Conv2d(curr_channels, curr_channels // 2, 1))
                curr_channels //= 2

        # Final projection back to image space.
        self.final_proj = nn.ConvTranspose2d(base_channels, in_channels, kernel_size=patch_size, stride=patch_size)

    def downsample(self, x):
        return F.avg_pool2d(x, kernel_size=2)

    def upsample(self, x):
        return F.interpolate(x, scale_factor=2, mode='nearest')

    def forward(self, x, t=None):
        """Run the network on ``x`` of shape ``(B, in_channels, H, W)``.

        H and W are assumed divisible by ``patch_size * 2 ** (num_levels - 1)``
        so pooling and upsampling round-trip -- TODO confirm with callers.
        ``t`` is accepted but not used here: no timestep conditioning is
        wired into this forward pass.
        """
        # Patchify. E.g. with patch_size=2 and base_channels=128,
        # (2, 3, 256, 256) -> (2, 128, 128, 128).
        x = self.patch_embed(x)

        # Per resolution level, roughly:
        #   f(x) = fu( U(fm(D(h)) - D(h)) + h )   where h = fd(x)
        # where
        #   1. h = fd(x)       : encoder path processes the input
        #   2. D(h)            : downsample the encoded features
        #   3. fm(D(h))        : deeper levels / middle blocks process them
        #   4. fm(D(h)) - D(h) : subtract the downsampled features (residual connection)
        #   5. U(...)          : upsample the processed features
        #   6. ... + h         : add back encoder features (skip connection)
        #   7. fu(...)         : decoder path processes the combined features
        # Every subtracted/added residual is scaled by learned_middle_residual_scale.
        # NOTE(review): the tensors saved as skips (and the one subtracted after the
        # middle) are the *inputs* to each encoder level (patch embedding, or the output
        # of the previous downsample), not the level outputs `h` named in the formula.
        # Confirm this is intended.
        residuals = []
        curr_res = x

        # Encoder path (h = fd(x)).
        h = x
        for module in self.encoders:
            if isinstance(module, LevelBlock):
                h = module(h)
            else:
                residuals.append(curr_res)
                # 1x1 channel-doubling conv, then halve the spatial size.
                h = self.downsample(module(h))
                curr_res = h

        # Middle blocks (fm).
        x = h
        for block in self.middle:
            x = block(x)
        x = x - curr_res * self.learned_middle_residual_scale

        # Decoder path (fu).
        for module in self.decoders:
            x = module(x)
            if not isinstance(module, LevelBlock):
                # After the 1x1 channel-halving conv: double the spatial size, then add the
                # matching encoder residual (LIFO: the deepest saved residual comes first).
                x = self.upsample(x)
                x = x + residuals.pop() * self.learned_middle_residual_scale

        # Final projection back to image space.
        return self.final_proj(x)