# Modified from Meta DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.utils.checkpoint
def get_layernorm(hidden_size: int, eps: float, affine: bool, use_kernel: bool):
    if use_kernel:
        try:
            from apex.normalization import FusedLayerNorm

            return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps)
        except ImportError:
            raise RuntimeError("FusedLayerNorm not available. Please install apex.")
    else:
        return nn.LayerNorm(hidden_size, eps=eps, elementwise_affine=affine)
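

# Illustrative usage sketch (not part of the original file): with use_kernel=False the
# factory simply wraps nn.LayerNorm, so the call below normalizes over the last
# dimension and preserves the input shape. The sizes are arbitrary example values.
#
# >>> norm = get_layernorm(hidden_size=1152, eps=1e-6, affine=False, use_kernel=False)
# >>> norm(torch.randn(2, 16, 1152)).shape
# torch.Size([2, 16, 1152])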
def modulate(norm_func, x, shift, scale, use_kernel=False):
    # x is (N, T, D); shift and scale are (N, D).
    dtype = x.dtype
    # Normalize in float32 for numerical stability, then cast back to the input dtype.
    x = norm_func(x.to(torch.float32)).to(dtype)
    if use_kernel:
        try:
            from videosys.kernels.fused_modulate import fused_modulate

            x = fused_modulate(x, scale, shift)
        except ImportError:
            raise RuntimeError("FusedModulate kernel not available. Please install triton.")
    else:
        x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
    x = x.to(dtype)
    return x
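

# Illustrative sketch (not part of the original file): modulate applies adaLN-style
# conditioning after normalization, i.e. norm(x) * (1 + scale) + shift, with shift
# and scale broadcast over the token dimension. Shapes below are arbitrary.
#
# >>> norm = nn.LayerNorm(64, elementwise_affine=False, eps=1e-6)
# >>> x, shift, scale = torch.randn(2, 8, 64), torch.randn(2, 64), torch.randn(2, 64)
# >>> modulate(norm, x, shift, scale).shape
# torch.Size([2, 8, 64])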
class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final, x, shift, scale)
        x = self.linear(x)
        return x
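

# Illustrative sketch (not part of the original file): FinalLayer projects each token
# to patch_size * out_channels values, with shift/scale produced from the conditioning
# vector c via the adaLN MLP. The sizes below are arbitrary example values.
#
# >>> layer = FinalLayer(hidden_size=64, patch_size=4, out_channels=8)
# >>> x, c = torch.randn(2, 16, 64), torch.randn(2, 64)
# >>> layer(x, c).shape
# torch.Size([2, 16, 32])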
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
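

# Illustrative sketch (not part of the original file): LlamaRMSNorm rescales each token
# by the reciprocal root-mean-square of its features (no mean subtraction), then applies
# a learned per-channel gain, which is all-ones at initialization. The check below
# compares the module against a manual computation on arbitrary float32 inputs.
#
# >>> rms_norm = LlamaRMSNorm(hidden_size=64, eps=1e-6)
# >>> h = torch.randn(2, 16, 64)
# >>> manual = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + 1e-6)
# >>> torch.allclose(rms_norm(h), manual, atol=1e-5)
# True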