import argparse, os, sys, glob, yaml, math, random import datetime, time import numpy as np from omegaconf import OmegaConf from collections import OrderedDict from tqdm import trange, tqdm from einops import repeat from einops import rearrange, repeat from functools import partial import torch from pytorch_lightning import seed_everything from .funcs import load_model_checkpoint, load_image_batch, get_filelist, save_videos from .funcs import batch_ddim_sampling from .utils import instantiate_from_config def get_parser(): parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=20230211, help="seed for seed_everything") parser.add_argument("--mode", default="base", type=str, help="which kind of inference mode: {'base', 'i2v'}") parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path") parser.add_argument("--config", type=str, help="config (yaml) path") parser.add_argument("--savefps", type=str, default=10, help="video fps to generate") parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt", ) parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM", ) parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)", ) parser.add_argument("--bs", type=int, default=1, help="batch size for inference") parser.add_argument("--height", type=int, default=512, help="image height, in pixel space") parser.add_argument("--width", type=int, default=512, help="image width, in pixel space") parser.add_argument("--frames", type=int, default=-1, help="frames num to inference") parser.add_argument("--fps", type=int, default=24) parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance") parser.add_argument("--unconditional_guidance_scale_temporal", type=float, default=None, help="temporal consistency guidance") ## for conditional i2v only # parser.add_argument("--cond_input", type=str, default=None, help="data dir of conditional input") return parser class VideoCrafterPipeline(): def __init__(self, arg_list, device, rank: int = 0, gpu_num: int = 1): """ Initialize the pipeline of videocrafter. It is always on one GPU. Args: arg_list: The parameters needed for the model. device: rank: gpu_num: """ parser = get_parser() self.args = parser.parse_args(args=arg_list) self.gpu_no, self.gpu_num = rank, gpu_num _dict = {'model': {'target': 'lvdm.models.ddpm3d.LatentDiffusion', 'params': {'linear_start': 0.00085, 'linear_end': 0.012, 'num_timesteps_cond': 1, 'timesteps': 1000, 'first_stage_key': 'video', 'cond_stage_key': 'caption', 'cond_stage_trainable': False, 'conditioning_key': 'crossattn', 'image_size': [40, 64], 'channels': 4, 'scale_by_std': False, 'scale_factor': 0.18215, 'use_ema': False, 'uncond_type': 'empty_seq', 'use_scale': True, 'scale_b': 0.7, 'unet_config': {'target': 'lvdm.modules.networks.openaimodel3d.UNetModel', 'params': {'in_channels': 4, 'out_channels': 4, 'model_channels': 320, 'attention_resolutions': [4, 2, 1], 'num_res_blocks': 2, 'channel_mult': [1, 2, 4, 4], 'num_head_channels': 64, 'transformer_depth': 1, 'context_dim': 1024, 'use_linear': True, 'use_checkpoint': True, 'temporal_conv': True, 'temporal_attention': True, 'temporal_selfatt_only': True, 'use_relative_position': False, 'use_causal_attention': False, 'temporal_length': 16, 'addition_attention': True, 'fps_cond': True}}, 'first_stage_config': {'target': 'lvdm.models.autoencoder.AutoencoderKL', 'params': {'embed_dim': 4, 'monitor': 'val/rec_loss', 'ddconfig': {'double_z': True, 'z_channels': 4, 'resolution': 512, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}, 'lossconfig': {'target': 'torch.nn.Identity'}}}, 'cond_stage_config': {'target': 'lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder', 'params': {'freeze': True, 'layer': 'penultimate'}}}}} config = OmegaConf.create(_dict) #config = OmegaConf.load(self.args.config) # data_config = config.pop("data", OmegaConf.create()) model_config = config.pop("model", OmegaConf.create()) model = instantiate_from_config(model_config) model = model.cuda(self.gpu_no) print("About to load model") assert os.path.exists(self.args.ckpt_path), f"Error: checkpoint [{self.args.ckpt_path}] Not Found!" self.model = load_model_checkpoint(model, self.args.ckpt_path) self.model.eval() def run_inference(self, prompt, video_length, height, width, **kwargs): """ https://github.com/AILab-CVC/VideoCrafter Generate video from the provided text prompt. Args: prompt: The provided text prompt. video_length: The length (num of frames) of the generated video. height: The height of the video frame. width: The width of the video frame. **kwargs: Returns: The generated video represented as tensor with shape (1, 1, channels, height, width, num of frames) """ ## step 1: model config ## ----------------------------------------------------------------- ## sample shape assert (self.args.height % 16 == 0) and ( self.args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!" ## latent noise shape h, w = height // 8, width // 8 frames = video_length channels = self.model.channels ## step 2: load data ## ----------------------------------------------------------------- prompt_list = [prompt] num_samples = len(prompt_list) # filename_list = [f"{id + 1:04d}" for id in range(num_samples)] gpu_num = self.gpu_num gpu_no = self.gpu_no samples_split = num_samples // gpu_num residual_tail = num_samples % gpu_num print(f'[rank:{gpu_no}] {samples_split}/{num_samples} samples loaded.') indices = list(range(samples_split * gpu_no, samples_split * (gpu_no + 1))) if gpu_no == 0 and residual_tail != 0: indices = indices + list(range(num_samples - residual_tail, num_samples)) prompt_list_rank = [prompt_list[i] for i in indices] # # conditional input # if self.args.mode == "i2v": # ## each video or frames dir per prompt # cond_inputs = get_filelist(self.args.cond_input, ext='[mpj][pn][4gj]') # '[mpj][pn][4gj]' # assert len( # cond_inputs) == num_samples, f"Error: conditional input ({len(cond_inputs)}) NOT match prompt ({num_samples})!" # filename_list = [f"{os.path.split(cond_inputs[id])[-1][:-4]}" for id in range(num_samples)] # cond_inputs_rank = [cond_inputs[i] for i in indices] # filename_list_rank = [filename_list[i] for i in indices] ## step 3: run over samples ## ----------------------------------------------------------------- # start = time.time() n_rounds = len(prompt_list_rank) // self.args.bs n_rounds = n_rounds + 1 if len(prompt_list_rank) % self.args.bs != 0 else n_rounds for idx in range(0, n_rounds): print(f'[rank:{gpu_no}] batch-{idx + 1} ({self.args.bs})x{self.args.n_samples} ...') idx_s = idx * self.args.bs idx_e = min(idx_s + self.args.bs, len(prompt_list_rank)) batch_size = idx_e - idx_s # filenames = filename_list_rank[idx_s:idx_e] noise_shape = [batch_size, channels, frames, h, w] fps = torch.tensor([self.args.fps] * batch_size).to(self.model.device).long() prompts = prompt_list_rank[idx_s:idx_e] if isinstance(prompts, str): prompts = [prompts] # prompts = batch_size * [""] text_emb = self.model.get_learned_conditioning(prompts) if self.args.mode == 'base': cond = {"c_crossattn": [text_emb], "fps": fps} # elif self.args.mode == 'i2v': # # cond_images = torch.zeros(noise_shape[0],3,224,224).to(model.device) # cond_images = load_image_batch(cond_inputs_rank[idx_s:idx_e], (self.args.height, self.args.width)) # cond_images = cond_images.to(self.model.device) # img_emb = self.model.get_image_embeds(cond_images) # imtext_cond = torch.cat([text_emb, img_emb], dim=1) # cond = {"c_crossattn": [imtext_cond], "fps": fps} else: raise NotImplementedError ## inference batch_samples = batch_ddim_sampling(self.model, cond, noise_shape, self.args.n_samples, self.args.ddim_steps, self.args.ddim_eta, self.args.unconditional_guidance_scale, **kwargs) return batch_samples