Spaces:
Running
on
Zero
Running
on
Zero
# Adapted from CogVideo
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# CogVideo: https://github.com/THUDM/CogVideo
# diffusers: https://github.com/huggingface/diffusers
# --------------------------------------------------------
from typing import Optional, Tuple, Union | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, get_2d_sincos_pos_embed_from_grid | |
class CogVideoXDownsample3D(nn.Module):
    r"""
    A 3D downsampling layer used in CogVideoX (Tsinghua University & ZhipuAI).

    Every frame is zero-padded by one pixel on the right and at the bottom and then passed through a
    (by default strided) `nn.Conv2d`. With the defaults (kernel 3, stride 2, padding 0), an even height or
    width is halved. When `compress_time` is enabled, the frame axis is first average-pooled with kernel 2
    and stride 2. For an odd frame count, the first frame is kept unpooled and only the remaining frames
    are pooled, so `F` frames become `F // 2 + 1` (odd) or `F // 2` (even).

    Args:
        in_channels (`int`):
            Number of channels in the input.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `2`):
            Stride of the convolution.
        padding (`int`, defaults to `0`):
            Padding the convolution adds to all four sides. This is in addition to the fixed
            right/bottom padding applied in `forward`.
        compress_time (`bool`, defaults to `False`):
            Whether to also downsample the frame axis.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 2,
        padding: int = 0,
        compress_time: bool = False,
    ) -> None:
        super().__init__()
        self.compress_time = compress_time
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample `x` of shape (batch_size, channels, frames, height, width)."""
        if self.compress_time:
            x = self._pool_frames(x)

        # F.pad consumes the last dims first: (0, 1) pads the right of the width axis,
        # and the next (0, 1) pads the bottom of the height axis.
        x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)

        # Fold frames into the batch so the 2D conv processes every frame independently.
        n, c, t, h, w = x.shape
        y = self.conv(x.permute(0, 2, 1, 3, 4).reshape(n * t, c, h, w))
        # (n * t, c', h', w') -> (n, t, c', h', w') -> (n, c', t, h', w')
        return y.reshape(n, t, *y.shape[1:]).permute(0, 2, 1, 3, 4)

    @staticmethod
    def _pool_frames(x: torch.Tensor) -> torch.Tensor:
        """Average-pool the frame axis by 2, keeping the first frame unpooled when the count is odd."""
        n, c, t, h, w = x.shape
        # (n, c, t, h, w) -> (n, h, w, c, t) -> (n * h * w, c, t): one 1D time series per pixel.
        series = x.permute(0, 3, 4, 1, 2).reshape(n * h * w, c, t)
        if t % 2 == 0:
            series = F.avg_pool1d(series, kernel_size=2, stride=2)
        else:
            first, rest = series[..., :1], series[..., 1:]
            # A single-frame input has nothing left to pool.
            if rest.shape[-1] > 0:
                rest = F.avg_pool1d(rest, kernel_size=2, stride=2)
            series = torch.cat([first, rest], dim=-1)
        # (n * h * w, c, t') -> (n, h, w, c, t') -> (n, c, t', h, w)
        return series.reshape(n, h, w, c, series.shape[-1]).permute(0, 3, 4, 1, 2)
class CogVideoXUpsample3D(nn.Module):
    r"""
    A 3D upsampling layer used in CogVideoX (Tsinghua University & ZhipuAI).

    Height and width are doubled with `F.interpolate(..., scale_factor=2.0)` (PyTorch's default nearest
    mode). Every frame is then passed through an `nn.Conv2d`. When `compress_time` is enabled, the frame
    axis is doubled too, with two exceptions:
    - For an odd frame count greater than 1, the first frame is only upsampled spatially, so `F` frames
      become `2F - 1`.
    - A single-frame input stays a single frame.

    Args:
        in_channels (`int`):
            Number of channels in the input.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `1`):
            Stride of the convolution.
        padding (`int`, defaults to `1`):
            Padding the convolution adds to all four sides.
        compress_time (`bool`, defaults to `False`):
            Whether to also upsample the frame axis.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        super().__init__()
        self.compress_time = compress_time
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Upsample `inputs` of shape (batch_size, channels, frames, height, width)."""
        num_frames = inputs.shape[2]

        if not self.compress_time:
            # Spatial-only: fold frames into the batch and upsample each frame as a 2D image.
            b, c, t, h, w = inputs.shape
            frames = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
            frames = F.interpolate(frames, scale_factor=2.0)
            inputs = frames.reshape(b, t, c, *frames.shape[2:]).permute(0, 2, 1, 3, 4)
        elif num_frames <= 1:
            # Nothing to upsample in time: drop the frame axis, upsample 2D, restore it.
            inputs = F.interpolate(inputs.squeeze(2), scale_factor=2.0).unsqueeze(2)
        elif num_frames % 2 == 1:
            # The first frame is upsampled spatially only; the rest go through 3D interpolation.
            head = F.interpolate(inputs[:, :, 0], scale_factor=2.0).unsqueeze(2)
            tail = F.interpolate(inputs[:, :, 1:], scale_factor=2.0)
            inputs = torch.cat([head, tail], dim=2)
        else:
            # Even frame count: a 5D input makes F.interpolate scale frames, height and width.
            inputs = F.interpolate(inputs, scale_factor=2.0)

        # Apply the 2D conv to every frame by folding frames into the batch.
        b, c, t, h, w = inputs.shape
        out = self.conv(inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w))
        return out.reshape(b, t, *out.shape[1:]).permute(0, 2, 1, 3, 4)
def get_3d_sincos_pos_embed(
    embed_dim: int,
    spatial_size: Union[int, Tuple[int, int]],
    temporal_size: int,
    spatial_interpolation_scale: float = 1.0,
    temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
    r"""
    Build a fixed 3D sine-cosine positional embedding for a (frames, height, width) token grid.

    The first quarter of the channels encodes the frame index (via `get_1d_sincos_pos_embed_from_grid`).
    The remaining three quarters encode the spatial position (via `get_2d_sincos_pos_embed_from_grid`).
    The spatial table is shared by all frames.

    Args:
        embed_dim (`int`):
            Total embedding size `D`. Must be divisible by 4.
        spatial_size (`int` or `Tuple[int, int]`):
            Spatial grid size given as `(width, height)`. An int means a square grid.
        temporal_size (`int`):
            Number of frames `T`.
        spatial_interpolation_scale (`float`, defaults to 1.0):
            Spatial positions are divided by this value.
        temporal_interpolation_scale (`float`, defaults to 1.0):
            Frame positions are divided by this value.

    Returns:
        `np.ndarray` of shape `[T, H * W, D]`.

    Raises:
        ValueError: If `embed_dim` is not divisible by 4.
    """
    if embed_dim % 4 != 0:
        raise ValueError("`embed_dim` must be divisible by 4")
    # NOTE(review): the diffusers helpers presumably also require even sub-dimensions, which would
    # effectively need `embed_dim % 16 == 0`. Confirm before relying on smaller sizes.

    if isinstance(spatial_size, int):
        width = height = spatial_size
    else:
        width, height = spatial_size[0], spatial_size[1]

    temporal_dim = embed_dim // 4
    spatial_dim = 3 * embed_dim // 4

    # Spatial table: grid coordinates with the width axis listed first.
    ys = np.arange(height, dtype=np.float32) / spatial_interpolation_scale
    xs = np.arange(width, dtype=np.float32) / spatial_interpolation_scale
    grid = np.stack(np.meshgrid(xs, ys), axis=0).reshape([2, 1, height, width])
    spatial = get_2d_sincos_pos_embed_from_grid(spatial_dim, grid)

    # Temporal table: one row per frame.
    ts = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
    temporal = get_1d_sincos_pos_embed_from_grid(temporal_dim, ts)

    # Broadcast both tables onto the [T, H*W, ...] grid and join along channels (temporal first).
    spatial = np.repeat(spatial[np.newaxis, :, :], temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
    temporal = np.repeat(temporal[:, np.newaxis, :], width * height, axis=1)  # [T, H*W, D // 4]
    return np.concatenate([temporal, spatial], axis=-1)  # [T, H*W, D]
class CogVideoXPatchEmbed(nn.Module):
    r"""
    Turn text embeddings and latent video frames into one token sequence.

    Text embeddings go through a linear projection. Every frame of the latent video is cut into
    non-overlapping `patch_size x patch_size` patches by a strided `nn.Conv2d`. In the output, the text
    tokens come first. They are followed by the patch tokens of frame 0, frame 1, and so on, with each
    frame's patches in row-major order.

    Args:
        patch_size (`int`, defaults to `2`):
            Side length of each square patch (used as both conv kernel size and stride).
        in_channels (`int`, defaults to `16`):
            Channels of each latent frame.
        embed_dim (`int`, defaults to `1920`):
            Width of the output tokens.
        text_embed_dim (`int`, defaults to `4096`):
            Width of the incoming text embeddings.
        bias (`bool`, defaults to `True`):
            Whether the patch convolution has a bias. The text projection always has one.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        patch_kernel = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_kernel, stride=patch_size, bias=bias)
        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).

        Returns:
            `torch.Tensor` of shape (batch_size, seq_length + num_frames * num_patches, embed_dim).
        """
        text_tokens = self.text_proj(text_embeds)

        batch, num_frames, channels, height, width = image_embeds.shape
        # Fold frames into the batch so the conv patchifies every frame independently.
        patches = self.proj(image_embeds.reshape(-1, channels, height, width))
        # [batch * num_frames, embed_dim, h', w'] -> [batch, num_frames, embed_dim, h' * w']
        patches = patches.view(batch, num_frames, *patches.shape[1:]).flatten(3)
        # -> [batch, num_frames, h' * w', embed_dim] -> [batch, num_frames * h' * w', embed_dim]
        image_tokens = patches.transpose(2, 3).flatten(1, 2)

        joined = torch.cat([text_tokens, image_tokens], dim=1)
        return joined.contiguous()
class CogVideoXLayerNormZero(nn.Module):
    r"""
    Adaptive LayerNorm that modulates two token streams from a single conditioning embedding.

    `temb` goes through SiLU and a linear layer, and the output is split into six vectors: shift, scale
    and gate for `hidden_states`, and shift, scale and gate for `encoder_hidden_states`. Both streams are
    normalized with the same `nn.LayerNorm`, then scaled by `(1 + scale)` and shifted. The gates are
    returned unapplied, so the caller can use them.

    Args:
        conditioning_dim (`int`):
            Width of the conditioning embedding `temb`.
        embedding_dim (`int`):
            Width of the token features being normalized.
        elementwise_affine (`bool`, defaults to `True`):
            Whether the LayerNorm has learnable affine parameters.
        eps (`float`, defaults to `1e-5`):
            Epsilon of the LayerNorm.
        bias (`bool`, defaults to `True`):
            Whether the modulation linear layer has a bias.
    """

    def __init__(
        self,
        conditioning_dim: int,
        embedding_dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.silu = nn.SiLU()
        # Produces 6 * embedding_dim values: three modulation vectors per stream.
        self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Normalize and modulate both streams.

        Assumes `temb` is (batch, conditioning_dim) and both streams are (batch, seq_len, embedding_dim),
        since the modulation vectors are broadcast with `[:, None, :]`.

        Returns:
            A 4-tuple of tensors:
            - modulated `hidden_states`
            - modulated `encoder_hidden_states`
            - `gate`, shaped (batch, 1, embedding_dim)
            - `enc_gate`, shaped (batch, 1, embedding_dim)
        """
        shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
        # Insert a sequence axis so each per-sample vector broadcasts over all tokens.
        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
        encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
        return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    The conditioning vector goes through SiLU and a linear layer of width `output_dim`. The conditioning
    vector is either `temb` or, when `num_embeddings` is given, `self.emb(timestep)`. The linear output
    is split into a scale and a shift, which modulate `LayerNorm(x)` as `norm(x) * (1 + scale) + shift`.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*): The size of the embeddings dictionary. When given, `forward`
            looks up `timestep` in an `nn.Embedding` and ignores `temb`.
        output_dim (`int`, *optional*): Width of the modulation projection. Defaults to
            `2 * embedding_dim`. The normalized feature size is `output_dim // 2`.
        norm_elementwise_affine (`bool`, defaults to `False`): Whether the LayerNorm has learnable
            affine parameters.
        norm_eps (`float`, defaults to `1e-5`): Epsilon of the LayerNorm.
        chunk_dim (`int`, defaults to `0`): How the projection output is split.
            - `1` (CogVideoX): split along dim 1 as `(shift, scale)`, broadcast over dim 1 of `x`.
            - Any other value: split along dim 0 as `(scale, shift)`.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: Optional[int] = None,
        output_dim: Optional[int] = None,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        chunk_dim: int = 0,
    ):
        super().__init__()

        self.chunk_dim = chunk_dim
        output_dim = output_dim or embedding_dim * 2

        if num_embeddings is not None:
            self.emb = nn.Embedding(num_embeddings, embedding_dim)
        else:
            self.emb = None

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        # Positional args are (normalized_shape, eps, elementwise_affine).
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

    def forward(
        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return `norm(x) * (1 + scale) + shift` with scale/shift derived from the conditioning.

        Raises:
            TypeError: If the layer has no embedding table and `temb` is not provided.
        """
        if self.emb is not None:
            temb = self.emb(timestep)
        elif temb is None:
            # Without this check, SiLU would fail on None with a message that does not name the argument.
            raise TypeError("`temb` must be provided when AdaLayerNorm is built without `num_embeddings`.")

        temb = self.linear(self.silu(temb))

        if self.chunk_dim == 1:
            # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
            # other if-branch. This branch is specific to CogVideoX for now.
            shift, scale = temb.chunk(2, dim=1)
            shift = shift[:, None, :]
            scale = scale[:, None, :]
        else:
            scale, shift = temb.chunk(2, dim=0)

        x = self.norm(x) * (1 + scale) + shift
        return x