# Adapted from Open-Sora-Plan
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from .ops import shift_dim


class Codebook(nn.Module):
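    """EMA vector-quantization codebook (VQ-VAE style), adapted from Open-Sora-Plan.

    Quantizes encoder outputs of shape [b, c, t, h, w] to their nearest
    codebook entries; the codebook itself is maintained with exponential
    moving averages rather than gradient updates.
    """
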
    def __init__(self, n_codes, embedding_dim):
        super().__init__()
        self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
        self.register_buffer("N", torch.zeros(n_codes))
        self.register_buffer("z_avg", self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True

    def _tile(self, x):
        # Repeat the flattened inputs (with small Gaussian noise) until there
        # are at least n_codes candidate vectors to sample codebook entries from.
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _init_embeddings(self, z):
        # z: [b, c, t, h, w]
        self._need_init = False
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        y = self._tile(flat_inputs)

        # Initialize the codebook from a random subset of the (tiled) inputs,
        # broadcast from rank 0 so every worker starts from the same codes.
        _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

    def forward(self, z):
        # z: [b, c, t, h, w]
        if self._need_init and self.training:
            self._init_embeddings(z)
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)

        # Squared Euclidean distance to every code: ||x||^2 - 2 x.e + ||e||^2
        distances = (
            (flat_inputs**2).sum(dim=1, keepdim=True)
            - 2 * flat_inputs @ self.embeddings.t()
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
        )
        encoding_indices = torch.argmin(distances, dim=1)
        encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
        encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])

        embeddings = F.embedding(encoding_indices, self.embeddings)
        embeddings = shift_dim(embeddings, -1, 1)

        # Commitment loss pulls the encoder output toward its (detached) code.
        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
        # EMA codebook update
        if self.training:
            n_total = encode_onehot.sum(dim=0)
            encode_sum = flat_inputs.t() @ encode_onehot
            if dist.is_initialized():
                dist.all_reduce(n_total)
                dist.all_reduce(encode_sum)

            # Exponential moving averages of per-code counts and summed inputs.
            self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
            self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)

            # Laplace-smoothed cluster sizes; each code becomes the mean of its cluster.
            n = self.N.sum()
            weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
            encode_normalized = self.z_avg / weights.unsqueeze(1)
            self.embeddings.data.copy_(encode_normalized)

            # Re-seed codes that are (almost) never used with random input vectors.
            y = self._tile(flat_inputs)
            _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
            if dist.is_initialized():
                dist.broadcast(_k_rand, 0)
            usage = (self.N.view(self.n_codes, 1) >= 1).float()
            self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))

        # Straight-through estimator: values come from the codes, gradients flow to z.
        embeddings_st = (embeddings - z).detach() + z

        # Perplexity of code usage (exp of the entropy of the assignment distribution).
        avg_probs = torch.mean(encode_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return dict(
            embeddings=embeddings_st,
            encodings=encoding_indices,
            commitment_loss=commitment_loss,
            perplexity=perplexity,
        )

    def dictionary_lookup(self, encodings):
        # Map integer code indices back to their embedding vectors.
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings
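

# Minimal usage sketch (illustrative only, not part of the original module).
# With 512 codes of dimension 64 and a random [b, c, t, h, w] input, the forward
# pass returns the quantized embeddings, the code indices, the commitment loss,
# and the codebook perplexity. Run it from within the package (e.g. via
# `python -m ...`) so the relative import of `shift_dim` resolves.
if __name__ == "__main__":
    codebook = Codebook(n_codes=512, embedding_dim=64)
    codebook.train()

    z = torch.randn(2, 64, 4, 8, 8)  # [b, c=embedding_dim, t, h, w]
    out = codebook(z)

    print(out["embeddings"].shape)  # torch.Size([2, 64, 4, 8, 8])
    print(out["encodings"].shape)   # torch.Size([2, 4, 8, 8])
    print(out["commitment_loss"].item(), out["perplexity"].item())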