# Spaces: Running on Zero  (hosting-page status text captured with this file; not part of the module)
# Adapted from Open-Sora-Plan
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
import glob
import importlib
import os
from typing import Optional, Tuple, Union

import numpy as np
import torch
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange
from torch import nn
def Normalize(in_channels, num_groups=32):
    """Build the group-normalisation layer used by the autoencoder blocks.

    Args:
        in_channels: Number of channels of the tensor to be normalised.
        num_groups: Number of groups the channels are divided into.

    Returns:
        A ``torch.nn.GroupNorm`` with ``eps=1e-6`` and learnable affine
        parameters.
    """
    norm_kwargs = dict(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    return torch.nn.GroupNorm(**norm_kwargs)
def tensor_to_video(x):
    """Convert a 4-D video tensor with values in [-1, 1] to a uint8 array.

    The tensor is detached, moved to the CPU, clamped to [-1, 1], rescaled to
    [0, 1], has its first two axes swapped, and is finally multiplied by 255
    and cast to ``uint8`` (the cast truncates the fractional part).

    Args:
        x: 4-D tensor; the original code labels the axes ``c t h w``.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with the first two axes swapped
        relative to ``x``.
    """
    frames = torch.clamp(x.detach().cpu(), -1, 1)
    frames = (frames + 1) / 2
    # c t h w -> t c h w
    frames = frames.permute(1, 0, 2, 3).float().numpy()
    return (255 * frames).astype(np.uint8)
def nonlinearity(x):
    """Apply the swish activation ``x * sigmoid(x)`` element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian described by a tensor of stacked mean and log-variance.

    ``parameters`` is split into two equal halves along dim 1: the first half
    is the mean, the second the log-variance (clamped to ``[-30, 20]``).

    Args:
        parameters: Tensor whose dim 1 holds ``[mean, logvar]``.
        deterministic: When true, the variance and standard deviation are
            forced to zero, so ``sample()`` returns the mean and ``kl`` /
            ``nll`` return a single zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)``.

        The noise is still drawn on the CPU (so seeded runs keep producing the
        same values) but is cast to the dtype of the mean. Without the cast, a
        half-precision posterior is promoted to float32 by the addition and no
        longer matches the dtype of the parameters it came from.
        """
        noise = torch.randn(self.mean.shape).to(device=self.mean.device, dtype=self.mean.dtype)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """Return the KL divergence summed over dims 1, 2 and 3.

        Args:
            other: Distribution to compare against; ``None`` means the
                standard normal.

        Returns:
            Tensor of divergences, or a one-element zero tensor in
            deterministic mode.
        """
        if self.deterministic:
            # Created on the parameters' device so callers can combine it with
            # other on-device loss terms.
            return torch.zeros(1, device=self.parameters.device)
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """Return the negative log-likelihood of ``sample`` summed over ``dims``.

        Returns a one-element zero tensor in deterministic mode.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        logtwopi = float(np.log(2.0 * np.pi))
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=list(dims),
        )

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean
def resolve_str_to_obj(str_val, append=True):
    """Import and return the object named by a dotted path.

    Args:
        str_val: Dotted path ``package.module.Name``, or just the part after
            the modules package when ``append`` is true.
        append: When true, ``"videosys.models.open_sora_plan.modules."`` is
            prepended to ``str_val`` before resolving.

    Returns:
        The attribute named by the last path component of the imported module.
    """
    dotted = "videosys.models.open_sora_plan.modules." + str_val if append else str_val
    # Map the legacy Open-Sora-Plan prefix onto this package; ``replace`` is a
    # no-op when the prefix is absent.
    dotted = dotted.replace("opensora.models.ae.videobase.", "videosys.models.open_sora_plan.")
    module_name, class_name = dotted.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)
class VideoBaseAE_PL(ModelMixin, ConfigMixin):
    """Base class for the video autoencoders in this file.

    Provides ``encode`` / ``decode`` placeholders and a ``from_pretrained``
    that can initialise from a ``*.ckpt`` file found in the model directory.
    """

    config_name = "config.json"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def encode(self, x: torch.Tensor, *args, **kwargs):
        """Placeholder; subclasses provide the implementation."""
        pass

    def decode(self, encoding: torch.Tensor, *args, **kwargs):
        """Placeholder; subclasses provide the implementation."""
        pass

    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices.

        NOTE(review): reads ``self.trainer`` and ``self.train_dataloader()``,
        neither of which is defined in this file - presumably left over from a
        PyTorch Lightning base class; confirm before calling.
        """
        if self.trainer.max_steps:
            return self.trainer.max_steps
        limit_batches = self.trainer.limit_train_batches
        batches = len(self.train_dataloader())
        batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)
        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)
        effective_accum = self.trainer.accumulate_grad_batches * num_devices
        return (batches // effective_accum) * self.trainer.max_epochs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        """Instantiate the model from a directory or hub identifier.

        If ``pretrained_model_name_or_path`` contains ``*.ckpt`` files, the
        model is built from ``config.json`` in that directory and initialised
        from one of them via ``init_from_ckpt``. Otherwise loading is delegated
        to the parent ``from_pretrained``.

        This must be a classmethod: it is invoked on the class itself and uses
        ``cls.from_config``.

        Args:
            pretrained_model_name_or_path: Model directory or identifier.
            **kwargs: Forwarded to the parent ``from_pretrained`` when no
                ``*.ckpt`` file is found.

        Returns:
            The loaded model instance.
        """
        # glob returns files in arbitrary order; sort so that the checkpoint
        # picked is deterministic (the lexicographically last file name).
        ckpt_files = sorted(glob.glob(os.path.join(pretrained_model_name_or_path, "*.ckpt")))
        if not ckpt_files:
            print(f"Loading model from {pretrained_model_name_or_path}")
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Adapt to PyTorch Lightning
        last_ckpt_file = ckpt_files[-1]
        config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
        model = cls.from_config(config_file)
        print("init from {}".format(last_ckpt_file))
        model.init_from_ckpt(last_ckpt_file)
        return model
class Encoder(nn.Module):
    """Convolutional encoder built from blocks that are looked up by name.

    Every ``str`` argument names a class that is resolved with
    ``resolve_str_to_obj``; an empty string in the down-sampling tuples means
    "no layer at this level".

    Args:
        z_channels: Latent channels; the output has ``2 * z_channels``
            channels when ``double_z`` is true.
        hidden_size: Base channel width.
        hidden_size_mult: Per-level multiplier applied to ``hidden_size``.
        attn_resolutions: Resolutions at which an attention block follows
            each residual block.
        conv_in: Class name of the input convolution (3 input channels).
        conv_out: Class name of the output convolution.
            NOTE(review): the default is spelled "CasualConv3d" while
            ``CausalVAEModel`` passes "CausalConv3d" - confirm the default
            resolves.
        attention: Class name of the attention block.
        resnet_blocks: Class name of the residual block for each level; must
            have one entry per entry of ``hidden_size_mult``.
        spatial_downsample: Per-level spatial down-sampling class name.
        temporal_downsample: Per-level temporal down-sampling class name.
        mid_resnet: Class name of the two residual blocks in the middle stage.
        dropout: Dropout value forwarded to the residual blocks.
        resolution: Input resolution, used to decide where attention is added.
        num_res_blocks: Residual blocks per level.
        double_z: Whether the output carries ``2 * z_channels`` channels.
    """

    def __init__(
        self,
        z_channels: int,
        hidden_size: int,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = (16,),
        conv_in: str = "Conv2d",
        conv_out: str = "CasualConv3d",
        attention: str = "AttnBlock",
        resnet_blocks: Tuple[str] = (
            "ResnetBlock2D",
            "ResnetBlock2D",
            "ResnetBlock2D",
            "ResnetBlock3D",
        ),
        spatial_downsample: Tuple[str] = (
            "Downsample",
            "Downsample",
            "Downsample",
            "",
        ),
        temporal_downsample: Tuple[str] = ("", "", "TimeDownsampleRes2x", ""),
        mid_resnet: str = "ResnetBlock3D",
        dropout: float = 0.0,
        resolution: int = 256,
        num_res_blocks: int = 2,
        double_z: bool = True,
    ) -> None:
        super().__init__()
        # The previous message expression was a ``print(...)`` call, which
        # evaluates to None; give the assertion a real message instead.
        assert len(resnet_blocks) == len(hidden_size_mult), (
            "resnet_blocks needs one entry per resolution level: "
            f"hidden_size_mult={hidden_size_mult}, resnet_blocks={resnet_blocks}"
        )
        # ---- Config ----
        self.num_resolutions = len(hidden_size_mult)
        self.resolution = resolution
        self.num_res_blocks = num_res_blocks

        # ---- In ----
        self.conv_in = resolve_str_to_obj(conv_in)(3, hidden_size, kernel_size=3, stride=1, padding=1)

        # ---- Downsample ----
        curr_res = resolution
        in_ch_mult = (1,) + tuple(hidden_size_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = hidden_size * in_ch_mult[i_level]
            block_out = hidden_size * hidden_size_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    resolve_str_to_obj(resnet_blocks[i_level])(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(resolve_str_to_obj(attention)(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if spatial_downsample[i_level]:
                down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(block_in, block_in)
                curr_res = curr_res // 2
            if temporal_downsample[i_level]:
                down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(block_in, block_in)
            self.down.append(down)

        # ---- Mid ----
        self.mid = nn.Module()
        self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
        self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )

        # ---- Out ----
        self.norm_out = Normalize(block_in)
        self.conv_out = resolve_str_to_obj(conv_out)(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Encode ``x`` and return the output of the final convolution.

        Only the current activation is carried through the stack. The earlier
        implementation appended every intermediate activation to a list
        (keeping them all alive until the call returned) and fed the middle
        stage from the last residual output, which would have skipped a
        down-sampling layer configured on the last level.
        """
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if hasattr(level, "downsample"):
                h = level.downsample(h)
            if hasattr(level, "time_downsample"):
                h = level.time_downsample(h)

        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class Decoder(nn.Module):
    """Convolutional decoder built from blocks that are looked up by name.

    Every ``str`` argument names a class resolved with ``resolve_str_to_obj``;
    an empty string in the up-sampling tuples means "no layer at this level".
    Levels are constructed from the last entry of ``hidden_size_mult`` down to
    the first, and ``forward`` visits them in that same order.

    Args:
        z_channels: Channels of the latent input.
        hidden_size: Base channel width.
        hidden_size_mult: Per-level multiplier applied to ``hidden_size``.
        attn_resolutions: Resolutions at which an attention block follows
            each residual block.
        conv_in: Class name of the input convolution.
        conv_out: Class name of the output convolution (3 output channels).
        attention: Class name of the attention block.
        resnet_blocks: Class name of the residual block for each level.
        spatial_upsample: Per-level spatial up-sampling class name.
        temporal_upsample: Per-level temporal up-sampling class name.
        mid_resnet: Class name of the two residual blocks in the middle stage.
        dropout: Dropout value forwarded to the residual blocks.
        resolution: Output resolution, used to decide where attention is added.
        num_res_blocks: Each level holds ``num_res_blocks + 1`` residual blocks.
    """

    def __init__(
        self,
        z_channels: int,
        hidden_size: int,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = (16,),
        conv_in: str = "Conv2d",
        conv_out: str = "CasualConv3d",
        attention: str = "AttnBlock",
        resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        spatial_upsample: Tuple[str] = (
            "",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
        ),
        temporal_upsample: Tuple[str] = ("", "", "", "TimeUpsampleRes2x"),
        mid_resnet: str = "ResnetBlock3D",
        dropout: float = 0.0,
        resolution: int = 256,
        num_res_blocks: int = 2,
    ):
        super().__init__()

        # ---- Config ----
        num_levels = len(hidden_size_mult)
        self.num_resolutions = num_levels
        self.resolution = resolution
        self.num_res_blocks = num_res_blocks

        # ---- In ----
        # Start at the widest channel count and the lowest resolution.
        channels = hidden_size * hidden_size_mult[num_levels - 1]
        curr_res = resolution // 2 ** (num_levels - 1)
        self.conv_in = resolve_str_to_obj(conv_in)(z_channels, channels, kernel_size=3, padding=1)

        # ---- Mid ----
        mid = nn.Module()
        self.mid = mid
        mid_kwargs = dict(in_channels=channels, out_channels=channels, dropout=dropout)
        mid.block_1 = resolve_str_to_obj(mid_resnet)(**mid_kwargs)
        mid.attn_1 = resolve_str_to_obj(attention)(channels)
        mid.block_2 = resolve_str_to_obj(mid_resnet)(**mid_kwargs)

        # ---- Upsample ----
        self.up = nn.ModuleList()
        for level_idx in range(num_levels - 1, -1, -1):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            out_channels = hidden_size * hidden_size_mult[level_idx]
            for _ in range(self.num_res_blocks + 1):
                res_cls = resolve_str_to_obj(resnet_blocks[level_idx])
                res_blocks.append(res_cls(in_channels=channels, out_channels=out_channels, dropout=dropout))
                channels = out_channels
                if curr_res in attn_resolutions:
                    attn_blocks.append(resolve_str_to_obj(attention)(channels))

            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            spatial_name = spatial_upsample[level_idx]
            if spatial_name:
                stage.upsample = resolve_str_to_obj(spatial_name)(channels, channels)
                curr_res = curr_res * 2
            temporal_name = temporal_upsample[level_idx]
            if temporal_name:
                stage.time_upsample = resolve_str_to_obj(temporal_name)(channels, channels)
            # Prepend so that ``self.up[i]`` is the stage for level ``i``.
            self.up.insert(0, stage)

        # ---- Out ----
        self.norm_out = Normalize(channels)
        self.conv_out = resolve_str_to_obj(conv_out)(channels, 3, kernel_size=3, padding=1)

    def forward(self, z):
        """Decode the latent ``z`` and return the output of the final convolution."""
        out = self.conv_in(z)

        mid = self.mid
        out = mid.block_2(mid.attn_1(mid.block_1(out)))

        for level_idx in range(self.num_resolutions - 1, -1, -1):
            stage = self.up[level_idx]
            has_attn = len(stage.attn) > 0
            for block_idx in range(self.num_res_blocks + 1):
                out = stage.block[block_idx](out)
                if has_attn:
                    out = stage.attn[block_idx](out)
            if hasattr(stage, "upsample"):
                out = stage.upsample(out)
            if hasattr(stage, "time_upsample"):
                out = stage.time_upsample(out)

        out = nonlinearity(self.norm_out(out))
        return self.conv_out(out)
class CausalVAEModel(VideoBaseAE_PL):
    """Video VAE made of an ``Encoder``, a ``Decoder`` and two ``q_conv`` layers.

    ``encode`` returns a ``DiagonalGaussianDistribution`` over latents and
    ``decode`` maps latents back to video. When tiling is enabled, inputs
    larger than the configured tile sizes are processed in overlapping
    spatial tiles and temporal chunks.

    NOTE(review): the training and validation methods use ``self.log``,
    ``self.log_dict``, ``self.optimizers``, ``self.manual_backward``,
    ``self.clip_gradients``, ``self.global_step`` and ``self.logger``, none of
    which are defined in this file - presumably PyTorch Lightning hooks;
    confirm the base class provides them before training with this class.
    """

    # Without this decorator the constructor arguments are never recorded in
    # ``self.config``, although ``register_to_config`` is imported for it.
    @register_to_config
    def __init__(
        self,
        lr: float = 1e-5,
        hidden_size: int = 128,
        z_channels: int = 4,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = (),
        dropout: float = 0.0,
        resolution: int = 256,
        double_z: bool = True,
        embed_dim: int = 4,
        num_res_blocks: int = 2,
        loss_type: str = "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator",
        loss_params: dict = {
            "kl_weight": 0.000001,
            "logvar_init": 0.0,
            "disc_start": 2001,
            "disc_weight": 0.5,
        },
        q_conv: str = "CausalConv3d",
        encoder_conv_in: str = "CausalConv3d",
        encoder_conv_out: str = "CausalConv3d",
        encoder_attention: str = "AttnBlock3D",
        encoder_resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        encoder_spatial_downsample: Tuple[str] = (
            "SpatialDownsample2x",
            "SpatialDownsample2x",
            "SpatialDownsample2x",
            "",
        ),
        encoder_temporal_downsample: Tuple[str] = (
            "",
            "TimeDownsample2x",
            "TimeDownsample2x",
            "",
        ),
        encoder_mid_resnet: str = "ResnetBlock3D",
        decoder_conv_in: str = "CausalConv3d",
        decoder_conv_out: str = "CausalConv3d",
        decoder_attention: str = "AttnBlock3D",
        decoder_resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        decoder_spatial_upsample: Tuple[str] = (
            "",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
        ),
        decoder_temporal_upsample: Tuple[str] = ("", "", "TimeUpsample2x", "TimeUpsample2x"),
        decoder_mid_resnet: str = "ResnetBlock3D",
    ) -> None:
        """Build the loss, encoder, decoder and the two ``q_conv`` layers.

        ``loss_type`` is resolved as a full dotted path and instantiated with
        ``loss_params``; the ``encoder_*`` / ``decoder_*`` arguments are
        forwarded to ``Encoder`` / ``Decoder``. ``loss_params`` is only ever
        unpacked with ``**``, so the shared default dict is never mutated.
        """
        super().__init__()

        # ---- Tiling configuration ----
        self.tile_sample_min_size = 256
        self.tile_sample_min_size_t = 65
        # One halving of the spatial tile size per level after the first.
        self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
        # One halving of the temporal tile size per temporal down-sampling layer.
        t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0]
        self.tile_latent_min_size_t = int((self.tile_sample_min_size_t - 1) / (2 ** len(t_down_ratio))) + 1
        self.tile_overlap_factor = 0.25
        self.use_tiling = False

        self.learning_rate = lr
        self.lr_g_factor = 1.0

        self.loss = resolve_str_to_obj(loss_type, append=False)(**loss_params)

        self.encoder = Encoder(
            z_channels=z_channels,
            hidden_size=hidden_size,
            hidden_size_mult=hidden_size_mult,
            attn_resolutions=attn_resolutions,
            conv_in=encoder_conv_in,
            conv_out=encoder_conv_out,
            attention=encoder_attention,
            resnet_blocks=encoder_resnet_blocks,
            spatial_downsample=encoder_spatial_downsample,
            temporal_downsample=encoder_temporal_downsample,
            mid_resnet=encoder_mid_resnet,
            dropout=dropout,
            resolution=resolution,
            num_res_blocks=num_res_blocks,
            double_z=double_z,
        )

        self.decoder = Decoder(
            z_channels=z_channels,
            hidden_size=hidden_size,
            hidden_size_mult=hidden_size_mult,
            attn_resolutions=attn_resolutions,
            conv_in=decoder_conv_in,
            conv_out=decoder_conv_out,
            attention=decoder_attention,
            resnet_blocks=decoder_resnet_blocks,
            spatial_upsample=decoder_spatial_upsample,
            temporal_upsample=decoder_temporal_upsample,
            mid_resnet=decoder_mid_resnet,
            dropout=dropout,
            resolution=resolution,
            num_res_blocks=num_res_blocks,
        )

        quant_conv_cls = resolve_str_to_obj(q_conv)
        self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)
        if hasattr(self.loss, "discriminator"):
            # Two optimizers are stepped by hand in ``_training_step_gan``.
            self.automatic_optimization = False

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def encode(self, x):
        """Encode ``x`` and return the posterior over latents.

        Uses ``tiled_encode`` when tiling is enabled and any of the last three
        dims of ``x`` exceeds its tile size.
        """
        if self.use_tiling and (
            x.shape[-1] > self.tile_sample_min_size
            or x.shape[-2] > self.tile_sample_min_size
            or x.shape[-3] > self.tile_sample_min_size_t
        ):
            return self.tiled_encode(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode the latent ``z``.

        Uses ``tiled_decode`` when tiling is enabled and any of the last three
        dims of ``z`` exceeds its latent tile size.
        """
        if self.use_tiling and (
            z.shape[-1] > self.tile_latent_min_size
            or z.shape[-2] > self.tile_latent_min_size
            or z.shape[-3] > self.tile_latent_min_size_t
        ):
            return self.tiled_decode(z)
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Encode then decode ``input``.

        Args:
            input: Video tensor to reconstruct.
            sample_posterior: Sample from the posterior when true, otherwise
                use its mode.

        Returns:
            Tuple ``(reconstruction, posterior)``.
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def get_input(self, batch, k):
        """Return ``batch[k]`` as a contiguous float tensor.

        A 3-D entry gets a trailing singleton dimension appended.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx):
        """Dispatch to the GAN step when the loss has a discriminator."""
        if hasattr(self.loss, "discriminator"):
            return self._training_step_gan(batch, batch_idx=batch_idx)
        else:
            return self._training_step(batch, batch_idx=batch_idx)

    def _training_step(self, batch, batch_idx):
        """Compute and log the autoencoder loss for ``batch["video"]``."""
        inputs = self.get_input(batch, "video")
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            split="train",
        )
        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return aeloss

    def _training_step_gan(self, batch, batch_idx):
        """Run one manual optimisation step for the autoencoder and the discriminator.

        The loss is called twice with the same tensors: optimizer index 0 for
        the autoencoder, then 1 for the discriminator. Each backward pass is
        followed by gradient-norm clipping at 1 and an optimizer step.
        """
        inputs = self.get_input(batch, "video")
        reconstructions, posterior = self(inputs)
        opt1, opt2 = self.optimizers()

        # ---- AE Loss ----
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )
        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        opt1.zero_grad()
        self.manual_backward(aeloss)
        self.clip_gradients(opt1, gradient_clip_val=1, gradient_clip_algorithm="norm")
        opt1.step()

        # ---- GAN Loss ----
        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )
        self.log(
            "discloss",
            discloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        opt2.zero_grad()
        self.manual_backward(discloss)
        self.clip_gradients(opt2, gradient_clip_val=1, gradient_clip_algorithm="norm")
        opt2.step()

        self.log_dict(
            {**log_dict_ae, **log_dict_disc},
            prog_bar=False,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

    def configure_optimizers(self):
        """Create the Adam optimizer(s).

        Autoencoder parameters are split into two groups by whether their name
        contains ``"time"``; both groups currently use the same learning rate.
        A second optimizer is added for the discriminator when the loss has
        one.

        Returns:
            Tuple ``(optimizers, [])``.
        """
        from itertools import chain

        lr = self.learning_rate
        modules_to_train = [
            self.encoder.named_parameters(),
            self.decoder.named_parameters(),
            self.post_quant_conv.named_parameters(),
            self.quant_conv.named_parameters(),
        ]
        params_with_time = []
        params_without_time = []
        for name, param in chain(*modules_to_train):
            if "time" in name:
                params_with_time.append(param)
            else:
                params_without_time.append(param)

        optimizers = []
        opt_ae = torch.optim.Adam(
            [
                {"params": params_with_time, "lr": lr},
                {"params": params_without_time, "lr": lr},
            ],
            lr=lr,
            betas=(0.5, 0.9),
        )
        optimizers.append(opt_ae)

        if hasattr(self.loss, "discriminator"):
            opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
            optimizers.append(opt_disc)

        return optimizers, []

    def get_last_layer(self):
        """Return the weight of the decoder's output convolution.

        Handles wrappers that hold the actual convolution in a ``conv``
        attribute.
        """
        if hasattr(self.decoder.conv_out, "conv"):
            return self.decoder.conv_out.conv.weight
        else:
            return self.decoder.conv_out.weight

    # ------------------------------------------------------------------
    # Tiling helpers
    # ------------------------------------------------------------------
    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Cross-fade the last rows (dim 3) of ``a`` into the first rows of ``b``.

        ``b`` is modified in place and returned. The weight of ``b`` grows
        linearly from 0 to ``(blend_extent - 1) / blend_extent``.
        """
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Cross-fade the last columns (dim 4) of ``a`` into the first columns of ``b``.

        ``b`` is modified in place and returned.
        """
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    @staticmethod
    def _temporal_chunks(num_frames, stride):
        """Return ``[start, end)`` ranges along dim 2 that overlap by one frame.

        Consecutive chunk starts are ``stride`` apart and every chunk but the
        last spans ``stride + 1`` frames; a final shorter chunk covers whatever
        remains. ``stride`` must be positive (``range`` rejects a zero step).
        """
        starts = list(range(0, num_frames, stride))
        if len(starts) == 1:
            return [[0, num_frames]]
        chunks = [[start, nxt + 1] for start, nxt in zip(starts, starts[1:])]
        # ``nxt + 1`` can never exceed ``num_frames`` because ``nxt`` comes
        # from the range above, so only the "frames left over" case remains.
        if chunks[-1][-1] < num_frames:
            chunks.append([starts[-1], num_frames])
        return chunks

    def _blend_and_stitch(self, rows, blend_extent, row_limit):
        """Blend neighbouring tiles and concatenate them into one tensor.

        Each tile is blended with the tile above and the tile to its left
        (in place, so later tiles blend against already-blended neighbours),
        cropped to ``row_limit`` along dims 3 and 4 and concatenated.
        """
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))
        return torch.cat(result_rows, dim=3)

    def tiled_encode(self, x):
        """Encode ``x`` in temporal chunks, each encoded with spatial tiling.

        Chunks overlap by one frame along dim 2, so the first latent frame of
        every chunk after the first is dropped before concatenation.

        Returns:
            The posterior built from the concatenated moments.
        """
        chunks = self._temporal_chunks(x.shape[2], self.tile_sample_min_size_t - 1)
        moments = []
        for idx, (start, end) in enumerate(chunks):
            moment = self.tiled_encode2d(x[:, :, start:end], return_moments=True)
            if idx != 0:
                moment = moment[:, :, 1:]
            moments.append(moment)
        moments = torch.cat(moments, dim=2)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def tiled_decode(self, x):
        """Decode ``x`` in temporal chunks, each decoded with spatial tiling.

        Chunks overlap by one latent frame along dim 2, so the first decoded
        frame of every chunk after the first is dropped before concatenation.
        """
        chunks = self._temporal_chunks(x.shape[2], self.tile_latent_min_size_t - 1)
        dec_ = []
        for idx, (start, end) in enumerate(chunks):
            dec = self.tiled_decode2d(x[:, :, start:end])
            if idx != 0:
                dec = dec[:, :, 1:]
            dec_.append(dec)
        dec_ = torch.cat(dec_, dim=2)
        return dec_

    def tiled_encode2d(self, x, return_moments=False):
        """Encode ``x`` tile by tile over dims 3 and 4 and blend the results.

        Args:
            x: Input tensor with five dims.
            return_moments: Return the raw moments tensor instead of a
                ``DiagonalGaussianDistribution``.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the input into overlapping tiles of side ``tile_sample_min_size``
        # and encode them separately.
        rows = []
        for i in range(0, x.shape[3], overlap_size):
            row = []
            for j in range(0, x.shape[4], overlap_size):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)

        moments = self._blend_and_stitch(rows, blend_extent, row_limit)
        if return_moments:
            # Skip building the distribution when the caller only wants moments.
            return moments
        return DiagonalGaussianDistribution(moments)

    def tiled_decode2d(self, z):
        """Decode ``z`` tile by tile over dims 3 and 4 and blend the results."""
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping tiles of side ``tile_latent_min_size`` and
        # decode them separately. The overlap avoids seams between tiles.
        rows = []
        for i in range(0, z.shape[3], overlap_size):
            row = []
            for j in range(0, z.shape[4], overlap_size):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)

        return self._blend_and_stitch(rows, blend_extent, row_limit)

    def enable_tiling(self, use_tiling: bool = True):
        """Turn tiled encoding/decoding on (or off when ``use_tiling`` is false)."""
        self.use_tiling = use_tiling

    def disable_tiling(self):
        """Turn tiled encoding/decoding off."""
        self.enable_tiling(False)

    # ------------------------------------------------------------------
    # Checkpoints / validation
    # ------------------------------------------------------------------
    def init_from_ckpt(self, path, ignore_keys=(), remove_loss=False):
        """Load weights from a checkpoint file, non-strictly.

        Args:
            path: Checkpoint path; a top-level ``"state_dict"`` entry is
                unwrapped when present.
            ignore_keys: Key prefixes to drop from the state dict before
                loading.
            remove_loss: Accepted for compatibility; not used.
        """
        # NOTE(security): torch.load unpickles the file - only load
        # checkpoints from trusted sources.
        sd = torch.load(path, map_location="cpu")
        print("init from " + path)
        if "state_dict" in sd:
            sd = sd["state_dict"]
        # Test all prefixes at once: deleting inside a per-prefix loop raised
        # KeyError when two prefixes matched the same key.
        ignore_prefixes = tuple(ignore_keys)
        for k in list(sd.keys()):
            if k.startswith(ignore_prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)

    def validation_step(self, batch, batch_idx):
        """Reconstruct ``batch["video"]`` and log every reconstruction as a video."""
        inputs = self.get_input(batch, "video")
        latents = self.encode(inputs).sample()
        video_recon = self.decode(latents)
        for idx, recon in enumerate(video_recon):
            self.logger.log_video(f"recon {batch_idx} {idx}", [tensor_to_video(recon)], fps=[10])
class CausalVAEModelWrapper(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained ``CausalVAEModel``.

    Latents are multiplied by 0.18215 on encode and divided by the same
    constant on decode.
    """

    def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs):
        """Load the VAE from ``model_path``; extra kwargs go to ``from_pretrained``."""
        super().__init__()
        self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs)

    def encode(self, x):  # b c t h w
        """Return a scaled sample from the posterior of ``x``."""
        latents = self.vae.encode(x).sample()
        latents.mul_(0.18215)
        return latents

    def decode(self, x):
        """Decode scaled latents and return the video as ``b t c h w``."""
        video = self.vae.decode(x / 0.18215)
        return rearrange(video, "b c t h w -> b t c h w").contiguous()

    def dtype(self):
        """Return the dtype reported by the wrapped VAE."""
        return self.vae.dtype
# Registries keyed by autoencoder name.
# Stride per model; presumably ordered (t, h, w) to match the "4x8x8" name - TODO confirm.
videobase_ae_stride = {"CausalVAEModel_4x8x8": [4, 8, 8]}
# Latent channel count per model.
videobase_ae_channel = {"CausalVAEModel_4x8x8": 4}
# Wrapper class per model.
videobase_ae = {"CausalVAEModel_4x8x8": CausalVAEModelWrapper}

# Aggregate configs, seeded with the video-base entries.
ae_stride_config = dict(videobase_ae_stride)
ae_channel_config = dict(videobase_ae_channel)
def getae_wrapper(ae):
    """Return the wrapper class registered in ``videobase_ae`` under ``ae``.

    The original docstring marks this helper as deprecated.

    Args:
        ae: Registry key, e.g. ``"CausalVAEModel_4x8x8"``.

    Returns:
        The wrapper class (not an instance).

    Raises:
        AssertionError: If ``ae`` is not a registered name.
    """
    wrapper_cls = videobase_ae.get(ae, None)
    assert wrapper_cls is not None, f"Unknown autoencoder {ae!r}; registered names: {sorted(videobase_ae)}"
    return wrapper_cls