# Modified from Meta DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.utils.checkpoint


def get_layernorm(hidden_size: int, eps: float, affine: bool, use_kernel: bool):
    """Return a LayerNorm over `hidden_size`, using apex's fused kernel when requested."""
    if use_kernel:
        try:
            from apex.normalization import FusedLayerNorm

            return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps)
        except ImportError:
            raise RuntimeError("FusedLayerNorm not available. Please install apex.")
    else:
        return nn.LayerNorm(hidden_size, eps=eps, elementwise_affine=affine)
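
# Usage sketch (illustrative, not part of the original file; the hidden size 1152
# is an arbitrary DiT-XL-like value): the kernel path requires NVIDIA apex, the
# fallback is a plain nn.LayerNorm.
#
#   norm = get_layernorm(hidden_size=1152, eps=1e-6, affine=False, use_kernel=False)
#   y = norm(torch.randn(2, 16, 1152))  # output keeps the input shape (N, T, D)
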

def modulate(norm_func, x, shift, scale, use_kernel=False):
    # Suppose x is (N, T, D), shift is (N, D), scale is (N, D).
    # The normalization is computed in float32 for stability, then cast back to the input dtype.
    dtype = x.dtype
    x = norm_func(x.to(torch.float32)).to(dtype)
    if use_kernel:
        try:
            from videosys.kernels.fused_modulate import fused_modulate

            x = fused_modulate(x, scale, shift)
        except ImportError:
            raise RuntimeError("FusedModulate kernel not available. Please install triton.")
    else:
        x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
    x = x.to(dtype)
    return x
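
# Usage sketch (illustrative, assuming the DiT-style adaLN convention above):
# shift and scale are per-sample (N, D) vectors broadcast over the T tokens.
#
#   x = torch.randn(2, 16, 1152)           # (N, T, D)
#   shift = torch.randn(2, 1152)           # (N, D)
#   scale = torch.randn(2, 1152)           # (N, D)
#   norm = get_layernorm(1152, eps=1e-6, affine=False, use_kernel=False)
#   out = modulate(norm, x, shift, scale)  # (N, T, D)
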

class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final, x, shift, scale)
        x = self.linear(x)
        return x
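
# Usage sketch (illustrative; the constructor arguments are arbitrary example
# values, not defaults from the original file):
#
#   layer = FinalLayer(hidden_size=1152, patch_size=4, out_channels=8)
#   x = torch.randn(2, 16, 1152)  # (N, T, D) tokens
#   c = torch.randn(2, 1152)      # (N, D) conditioning embedding
#   out = layer(x, c)             # (2, 16, patch_size * out_channels) = (2, 16, 32)
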

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
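
# Usage sketch (illustrative): RMSNorm rescales by the root mean square over the
# last dimension, computed in float32 and cast back to the input dtype before the
# learned per-channel weight is applied.
#
#   norm = LlamaRMSNorm(hidden_size=1152)
#   y = norm(torch.randn(2, 16, 1152))  # same shape as the input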