# Modified from Meta DiT

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.utils.checkpoint


def get_layernorm(hidden_size: int, eps: float, affine: bool, use_kernel: bool):
    """Return a LayerNorm module, optionally backed by apex's fused kernel."""
    if use_kernel:
        try:
            from apex.normalization import FusedLayerNorm

            return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps)
        except ImportError:
            raise RuntimeError("FusedLayerNorm not available. Please install apex.")
    else:
        return nn.LayerNorm(hidden_size, eps=eps, elementwise_affine=affine)


def modulate(norm_func, x, shift, scale, use_kernel=False):
    # Suppose x is (N, T, D), shift is (N, D), scale is (N, D)
    dtype = x.dtype
    # Normalize in float32 for numerical stability, then cast back.
    x = norm_func(x.to(torch.float32)).to(dtype)
    if use_kernel:
        try:
            from videosys.kernels.fused_modulate import fused_modulate

            x = fused_modulate(x, scale, shift)
        except ImportError:
            raise RuntimeError("FusedModulate kernel not available. Please install triton.")
    else:
        # Broadcast the per-sample shift/scale over the token dimension T.
        x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
    x = x.to(dtype)
    return x


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x, c):
        # Predict a per-sample shift/scale pair from the conditioning vector c.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final, x, shift, scale)
        x = self.linear(x)
        return x


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Compute the variance in float32 to avoid overflow in half precision.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
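

# --------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows how the
# building blocks above fit together when the fused kernels are unavailable.
# The shapes and sizes below are illustrative assumptions only.
# --------------------------------------------------------
if __name__ == "__main__":
    batch, tokens, hidden, patch, channels = 2, 16, 128, 4, 8

    x = torch.randn(batch, tokens, hidden)  # token sequence (N, T, D)
    c = torch.randn(batch, hidden)          # conditioning vector (N, D)

    # RMS-normalize the tokens, then project them through the adaLN final layer.
    rms = LlamaRMSNorm(hidden)
    final = FinalLayer(hidden, patch, channels)
    out = final(rms(x), c)
    print(out.shape)  # expected: (2, 16, patch * channels) == (2, 16, 32)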