|
from diffusers import StableDiffusionPipeline |
|
import torch |
|
from dataclasses import dataclass |
|
from typing import Callable, List, Optional, Union |
|
import numpy as np |
|
from diffusers.utils import deprecate, logging, BaseOutput |
|
from einops import rearrange, repeat |
|
from torch.nn.functional import grid_sample |
|
import torchvision.transforms as T |
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer |
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel |
|
from diffusers.schedulers import KarrasDiffusionSchedulers |
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker |
|
|
|
@dataclass
class TextToVideoPipelineOutput(BaseOutput):
    """Container for what the text-to-video pipeline hands back to the caller."""

    # Decoded videos. NOTE(review): the pipeline's ``__call__`` in this file
    # passes a Python list here (one entry per generated video), although the
    # annotation names a single tensor/array -- confirm the intended type.
    videos: Union[torch.Tensor, np.ndarray]

    # Initial latent noise the videos were generated from. NOTE(review):
    # ``__call__`` passes the result of ``torch.split(..., 1, dim=0)``, i.e. a
    # tuple of per-sample tensors -- confirm the intended type.
    code: Union[torch.Tensor, np.ndarray]
|
|
|
|
|
|
|
def coords_grid(batch, ht, wd, device):
    """Build a batch of pixel-coordinate grids.

    Args:
        batch: Number of identical grids to stack along the first dimension.
        ht: Grid height (number of rows).
        wd: Grid width (number of columns).
        device: Device on which the coordinate tensors are created.

    Returns:
        A float tensor of shape ``(batch, 2, ht, wd)``. Channel 0 holds the
        column index (x) of every cell and channel 1 holds the row index (y).
    """
    # ``indexing="ij"`` is what the implicit default has always produced; being
    # explicit avoids the deprecation warning raised by recent torch releases.
    ys, xs = torch.meshgrid(
        torch.arange(ht, device=device),
        torch.arange(wd, device=device),
        indexing="ij",
    )
    # Stack as (x, y): the order expected when the grid is later used for sampling.
    grid = torch.stack((xs, ys), dim=0).float()
    return grid[None].repeat(batch, 1, 1, 1)
|
|
|
|
|
|
|
class TextToVideoPipeline(StableDiffusionPipeline):
    """Text-to-video sampling built on top of ``StableDiffusionPipeline``.

    A call first denoises an initial latent while capturing the latents at two
    intermediate timesteps (``t0`` and ``t1``). Depending on the flags it then
    synthesises the remaining frames by warping the first frame with a constant
    motion field, and/or blends background latents across frames, before
    running the remaining denoising steps and decoding every frame with the VAE.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        """Forward all components unchanged to ``StableDiffusionPipeline``."""
        # Keyword arguments keep every component bound to the right parameter
        # even if the parent signature gains parameters between these ones.
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

    @staticmethod
    def _sample_noise(shape, generator, device, dtype):
        """Draw standard-normal noise of ``shape``, honouring ``generator``.

        ``generator`` may be ``None``, a single ``torch.Generator`` or a list
        with one generator per batch element. When a generator is given the
        sample is drawn on that generator's device; callers move the result to
        the execution device afterwards.
        """
        def draw(size, gen):
            sample_device = device if gen is None else gen.device
            return torch.randn(size, generator=gen, device=sample_device, dtype=dtype)

        shape = tuple(shape)
        if isinstance(generator, list):
            single = (1,) + shape[1:]
            return torch.cat([draw(single, generator[i]) for i in range(shape[0])], dim=0)
        return draw(shape, generator)

    @staticmethod
    def _flow_to_sampling_grid(reference_flow, num_frames, out_size, device, dtype):
        """Turn a pixel-space flow ``(f, 2, H, W)`` into a ``grid_sample`` grid.

        The flow is added to the identity coordinates, rescaled to the
        ``[-1, 1]`` convention, resized to ``out_size`` and returned with the
        coordinate channel last, i.e. with shape ``(f, h, w, 2)``.
        """
        _, _, H, W = reference_flow.size()
        coords = coords_grid(num_frames, H, W, device=device).to(dtype) + reference_flow
        coords[:, 0] /= W
        coords[:, 1] /= H
        coords = coords * 2.0 - 1.0
        coords = T.Resize(out_size)(coords)
        return rearrange(coords, 'f c h w -> f h w c')

    def DDPM_forward(self, x0, t0, tMax, generator, device, shape, text_embeddings):
        """Noise ``x0`` forward from timestep ``t0`` to ``tMax``.

        Args:
            x0: Latents to noise, or ``None`` to return pure noise of ``shape``.
            t0: First index (inclusive) into ``self.scheduler.alphas``.
            tMax: Last index (exclusive) into ``self.scheduler.alphas``.
            generator: ``torch.Generator`` (or list of them) used for sampling.
            device: Execution device of the returned tensor.
            shape: Shape of the noise tensor; only used when ``x0`` is ``None``.
            text_embeddings: Only its dtype is read, to choose the noise dtype.

        Returns:
            ``sqrt(a) * x0 + sqrt(1 - a) * eps`` with ``a`` the product of the
            scheduler alphas in ``[t0, tMax)``, or plain noise if ``x0`` is None.
        """
        noise_dtype = text_embeddings.dtype
        # The original sampled on the CPU when running on "mps"; kept as is.
        rand_device = "cpu" if device.type == "mps" else device

        if x0 is None:
            return torch.randn(shape, generator=generator, device=rand_device, dtype=noise_dtype).to(device)

        # Sample through the generator so that seeded runs are reproducible
        # (``torch.randn_like`` cannot take a generator).
        eps = self._sample_noise(x0.shape, generator, rand_device, noise_dtype).to(device)
        alpha_vec = torch.prod(self.scheduler.alphas[t0:tMax])
        return torch.sqrt(alpha_vec) * x0 + torch.sqrt(1 - alpha_vec) * eps

    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
        """Create (or adopt) the initial video latents.

        Args:
            batch_size: Number of videos to sample.
            num_channels_latents: Channel count of the latent space.
            video_length: Number of frames.
            height: Output height in pixels (divided by ``vae_scale_factor``).
            width: Output width in pixels (divided by ``vae_scale_factor``).
            dtype: dtype of freshly sampled latents.
            device: Execution device.
            generator: ``torch.Generator`` or a list with one per batch element.
            latents: Optional caller-provided latents; only moved to ``device``.

        Returns:
            Latents of shape ``(batch, channels, frames, h, w)`` scaled by the
            scheduler's ``init_noise_sigma``.

        Raises:
            ValueError: If a generator list does not match ``batch_size``.
        """
        shape = (
            batch_size,
            num_channels_latents,
            video_length,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            rand_device = "cpu" if device.type == "mps" else device
            latents = self._sample_noise(shape, generator, rand_device, dtype).to(device)
        else:
            latents = latents.to(device)

        return latents * self.scheduler.init_noise_sigma

    def warp_latents(self, latents, reference_flow):
        """Warp the first frame of ``latents`` once per frame of the flow.

        Args:
            latents: Tensor of shape ``(b, c, f, h, w)``; only frame 0 is used.
            reference_flow: Tensor of shape ``(f, 2, H, W)`` in pixel units.

        Returns:
            Tensor of shape ``(b, c, f, h, w)`` holding the warped first frame.
        """
        _, _, f, h, w = latents.size()
        grid = self._flow_to_sampling_grid(reference_flow, f, (h, w), latents.device, latents.dtype)
        first_frame = latents[:, :, 0].repeat(f, 1, 1, 1)
        warped = grid_sample(first_frame, grid, mode='nearest', padding_mode='reflection')
        return rearrange(warped, '(b f) c h w -> b c f h w', f=f)

    def warp_latents_independently(self, latents, reference_flow):
        """Warp every frame of ``latents`` with its own flow field.

        Args:
            latents: Tensor of shape ``(1, c, f, h, w)``.
            reference_flow: Tensor of shape ``(f, 2, H, W)`` in pixel units.

        Returns:
            Tensor of shape ``(1, c, f, h, w)`` with each frame warped.

        Raises:
            ValueError: If the batch dimension of ``latents`` is not 1.
        """
        b, _, f, h, w = latents.size()
        if b != 1:
            # An explicit raise survives ``python -O``, unlike an assert.
            raise ValueError(f"warp_latents_independently expects a batch of 1, got {b}")
        grid = self._flow_to_sampling_grid(reference_flow, f, (h, w), latents.device, latents.dtype)
        frames = rearrange(latents[0], 'c f h w -> f c h w')
        warped = grid_sample(frames, grid, mode='nearest', padding_mode='reflection')
        return rearrange(warped, '(b f) c h w -> b c f h w', f=f)

    def DDIM_backward(self, num_inference_steps, timesteps, skip_t, t0, t1, do_classifier_free_guidance, null_embs, text_embeddings, latents_local, latents_dtype, guidance_scale, guidance_stop_step, callback, callback_steps, extra_step_kwargs, num_warmup_steps):
        """Run the scheduler's denoising loop on video latents.

        Timesteps greater than ``skip_t`` are skipped. While stepping, the
        latents produced right before the timesteps ``t0`` and ``t1`` are
        captured (``t0`` takes precedence if both are equal).

        Args:
            num_inference_steps: Total used for the progress bar.
            timesteps: Scheduler timesteps to iterate over.
            skip_t: Steps with ``t > skip_t`` are not executed.
            t0: Timestep whose incoming latents are stored as ``"x_t0_1"``.
            t1: Timestep whose incoming latents are stored as ``"x_t1_1"``.
            do_classifier_free_guidance: Whether to run the UNet on a doubled
                batch and combine unconditional/conditional predictions.
            null_embs: Optional per-step embeddings; ``null_embs[i][0]``
                overwrites ``text_embeddings[0]`` in place at step ``i``.
            text_embeddings: Prompt embeddings, one row per UNet batch entry
                before frame expansion.
            latents_local: Latents of shape ``(b, c, f, w, h)``.
            latents_dtype: dtype the noise prediction is cast to.
            guidance_scale: Classifier-free guidance weight.
            guidance_stop_step: Accepted for interface compatibility; it has
                no effect on the computation.
            callback: Optional ``callback(i, t, latents)``.
            callback_steps: Call ``callback`` when ``i % callback_steps == 0``.
            extra_step_kwargs: Extra keyword arguments for ``scheduler.step``.
            num_warmup_steps: Warm-up step count used for progress updates.

        Returns:
            Dict with ``"x0"`` (final latents, ``(b, c, f, w, h)``) and, when
            captured, ``"x_t0_1"`` / ``"x_t1_1"`` in the same layout.
        """
        entered = False
        f = latents_local.shape[2]
        # Frames are folded into the batch so the 2D UNet sees one image each.
        latents = rearrange(latents_local, "b c f w h -> (b f) c w h").detach().clone()
        x_t0_1 = None
        x_t1_1 = None
        last_index = len(timesteps) - 1

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if t > skip_t:
                    continue
                if not entered:
                    print(
                        f"Continue DDIM with i = {i}, t = {t}, latent = {latents.shape}, device = {latents.device}, type = {latents.dtype}")
                    entered = True

                latents = latents.detach()
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                with torch.no_grad():
                    if null_embs is not None:
                        text_embeddings[0] = null_embs[i][0]
                    # Repeat every embedding row once per frame, in the same
                    # order the frames were folded into the batch. This equals
                    # the former two-row construction and also works without
                    # guidance and for more than one video.
                    te = text_embeddings.repeat_interleave(f, dim=0)
                    noise_pred = self.unet(
                        latent_model_input, t, encoder_hidden_states=te).sample.to(dtype=latents_dtype)

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # Capture the latents that the *next* step (t0 or t1) starts from.
                if i < last_index and timesteps[i + 1] == t0:
                    x_t0_1 = latents.detach().clone()
                    print(f"latent t0 found at i = {i}, t = {t}")
                elif i < last_index and timesteps[i + 1] == t1:
                    x_t1_1 = latents.detach().clone()
                    print(f"latent t1 found at i={i}, t = {t}")

                if i == last_index or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        latents = rearrange(latents, "(b f) c w h -> b c f w h", f=f)

        res = {"x0": latents.detach().clone()}
        if x_t0_1 is not None:
            res["x_t0_1"] = rearrange(x_t0_1, "(b f) c w h -> b c f w h", f=f).detach().clone()
        if x_t1_1 is not None:
            res["x_t1_1"] = rearrange(x_t1_1, "(b f) c w h -> b c f w h", f=f).detach().clone()
        return res

    def decode_latents(self, latents):
        """Decode video latents ``(b, c, f, h, w)`` into frames in ``[0, 1]``.

        Returns:
            Tensor of shape ``(b, c, f, H, W)`` produced by the VAE decoder,
            mapped from ``[-1, 1]`` to ``[0, 1]`` and clamped.
        """
        video_length = latents.shape[2]
        # Prefer the VAE's configured scaling factor; fall back to the value
        # that was hard-coded here when the config does not provide one.
        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        latents = 1 / scaling_factor * latents
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        video = self.vae.decode(latents).sample
        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        return (video / 2 + 0.5).clamp(0, 1)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        video_length: Optional[int],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        guidance_stop_step: float = 0.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        xT: Optional[torch.FloatTensor] = None,
        null_embs: Optional[torch.FloatTensor] = None,
        motion_field_strength_x: float = 12,
        motion_field_strength_y: float = 12,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        callback: Optional[Callable[[
            int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        use_motion_field: bool = True,
        smooth_bg: bool = True,
        smooth_bg_strength: float = 0.4,
        **kwargs,
    ):
        """Generate videos for ``prompt``.

        Args:
            prompt: Prompt or list of prompts.
            video_length: Number of frames per video.
            height: Frame height in pixels; defaults to the UNet sample size
                times ``vae_scale_factor``.
            width: Frame width in pixels; same default rule as ``height``.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance weight; guidance is used
                when it is greater than 1.
            guidance_stop_step: Forwarded to ``DDIM_backward``.
            negative_prompt: Optional negative prompt(s).
            num_videos_per_prompt: Videos to sample per prompt.
            eta: Forwarded to ``prepare_extra_step_kwargs``.
            generator: Generator(s) used for all random sampling.
            xT: Optional initial latents ``(b, c, f, h, w)``.
            null_embs: Optional per-step embeddings, see ``DDIM_backward``.
            motion_field_strength_x: Per-frame horizontal shift of the motion
                field, in pixels of a 512x512 reference grid.
            motion_field_strength_y: Per-frame vertical shift, same units.
            output_type: ``"tensor"`` or ``"numpy"``; any other value yields an
                empty ``videos`` list.
            return_dict: Return a ``TextToVideoPipelineOutput`` when true,
                otherwise the list of videos.
            callback: Optional per-step callback.
            callback_steps: Frequency of ``callback``.
            use_motion_field: Denoise only the first frame and create the
                others by warping its latents with a constant motion field.
            smooth_bg: Blend background latents across frames using foreground
                masks obtained from ``self.sod_model``.
            smooth_bg_strength: Weight of the un-smoothed latents inside the
                background region.
            **kwargs: Must contain ``t0`` and ``t1``, the timesteps at which
                intermediate latents are captured.

        Returns:
            ``TextToVideoPipelineOutput`` with ``videos`` (one entry per video)
            and ``code`` (the initial latents split per sample), or just the
            ``videos`` list when ``return_dict`` is false.

        Raises:
            KeyError: If ``t0`` or ``t1`` is missing from ``kwargs``.
            AttributeError: If ``smooth_bg`` is set but no ``sod_model`` is
                attached to the pipeline.
            ValueError: If the latents for ``t0``/``t1`` could not be captured.
        """
        print(motion_field_strength_x,motion_field_strength_y)
        print(f" Use: Motion field = {use_motion_field}")
        print(f" Use: Background smoothing = {smooth_bg}")

        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        self.check_inputs(prompt, height, width, callback_steps)

        # This class never assigns ``sod_model``; it has to be attached from
        # outside. Fail before any expensive work instead of midway through.
        if smooth_bg and not hasattr(self, "sod_model"):
            raise AttributeError(
                "smooth_bg=True requires a salient-object detection model to be "
                "attached to the pipeline as `sod_model`; pass smooth_bg=False otherwise."
            )

        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0

        text_embeddings = self._encode_prompt(
            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        num_channels_latents = self.unet.config.in_channels
        num_videos = batch_size * num_videos_per_prompt

        xT = self.prepare_latents(
            num_videos,
            num_channels_latents,
            video_length,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            xT,
        )
        dtype = xT.dtype

        if use_motion_field:
            # Only the first frame is denoised; the rest is derived by warping.
            xT = xT[:, :, :1]
        elif xT.shape[2] < video_length:
            # Caller-provided latents were too short: sample the missing frames.
            xT_missing = self.prepare_latents(
                num_videos,
                num_channels_latents,
                video_length - xT.shape[2],
                height,
                width,
                text_embeddings.dtype,
                device,
                generator,
                None,
            )
            xT = torch.cat([xT, xT_missing], dim=2)

        xInit = xT.clone()
        t0 = kwargs["t0"]
        t1 = kwargs["t1"]

        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # Arguments shared by every denoising pass below.
        ddim_kwargs = dict(
            num_inference_steps=num_inference_steps,
            timesteps=timesteps,
            do_classifier_free_guidance=do_classifier_free_guidance,
            null_embs=null_embs,
            text_embeddings=text_embeddings,
            latents_dtype=dtype,
            guidance_scale=guidance_scale,
            guidance_stop_step=guidance_stop_step,
            callback=callback,
            callback_steps=callback_steps,
            extra_step_kwargs=extra_step_kwargs,
            num_warmup_steps=num_warmup_steps,
        )

        # First pass: full denoising, capturing the latents at t0 and t1.
        ddim_res = self.DDIM_backward(skip_t=1000, t0=t0, t1=t1, latents_local=xT, **ddim_kwargs)
        x0 = ddim_res["x0"].detach()
        x_t0_1 = ddim_res.get("x_t0_1")
        x_t1_1 = ddim_res.get("x_t1_1")
        del ddim_res
        del xT

        if use_motion_field and x_t0_1 is None:
            raise ValueError(
                f"No latent was captured for t0={t0}: it must equal one of the scheduler "
                "timesteps other than the first one."
            )
        if (use_motion_field or smooth_bg) and x_t1_1 is None:
            raise ValueError(
                f"No latent was captured for t1={t1}: it must equal one of the scheduler "
                "timesteps other than the first one and differ from t0."
            )

        if use_motion_field:
            del x0
            shape = (batch_size, num_channels_latents, 1,
                     height // self.vae_scale_factor, width // self.vae_scale_factor)

            x_t0_k = x_t0_1[:, :, :1, :, :].repeat(1, 1, video_length - 1, 1, 1)

            # Constant translation that grows linearly with the frame index.
            reference_flow = torch.zeros(
                (video_length - 1, 2, 512, 512), device=x_t0_1.device, dtype=x_t0_1.dtype)
            for fr_idx in range(video_length - 1):
                reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (fr_idx + 1)
                reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (fr_idx + 1)

            for idx, latent in enumerate(x_t0_k):
                x_t0_k[idx] = self.warp_latents_independently(latent[None], reference_flow)

            if t1 > t0:
                x_t1_k = self.DDPM_forward(
                    x0=x_t0_k, t0=t0, tMax=t1, device=device, shape=shape,
                    text_embeddings=text_embeddings, generator=generator)
            else:
                x_t1_k = x_t0_k

            x_t1 = torch.cat([x_t1_1, x_t1_k], dim=2).clone().detach()

            ddim_res = self.DDIM_backward(skip_t=t1, t0=-1, t1=-1, latents_local=x_t1, **ddim_kwargs)
            x0 = ddim_res["x0"].detach()
            del ddim_res
        elif smooth_bg:
            # All frames were denoised in the first pass: keep them as x_t1 and
            # split off the first frame for the background blending below.
            x_t1 = x_t1_1.clone()
            x_t1_1 = x_t1[:, :, :1, :, :].clone()

        if smooth_bg:
            h, w = x0.shape[3], x0.shape[4]
            # One mask per generated video and frame.
            M_FG = torch.zeros((x0.shape[0], video_length, h, w), device=x0.device).to(x0.dtype)
            to_latent_size = T.Resize(size=(h, w), interpolation=T.InterpolationMode.NEAREST)
            for batch_idx, x0_b in enumerate(x0):
                z0_b = self.decode_latents(x0_b[None]).detach()
                z0_b = rearrange(z0_b[0], "c f h w -> f h w c")
                for frame_idx, z0_f in enumerate(z0_b):
                    z0_f = torch.round(z0_f * 255).cpu().numpy().astype(np.uint8)
                    m_f = torch.tensor(self.sod_model.process_data(z0_f), device=x0.device).to(x0.dtype)
                    mask = to_latent_size(m_f[None])
                    # Dilate with a flat 5x5 window. A stride-1 max-pool is the
                    # same operation as a morphological dilation with an
                    # all-ones 5x5 kernel (the border is padded with -inf), and
                    # needs no dependency beyond torch.
                    mask = torch.nn.functional.max_pool2d(
                        mask[None].to(x0.device), kernel_size=5, stride=1, padding=2)[0]
                    M_FG[batch_idx, frame_idx, :, :] = mask

            # First-frame latents with the foreground removed.
            x_t1_1_fg_masked = x_t1_1 * \
                (1 - repeat(M_FG[:, 0, :, :], "b w h -> b c 1 w h", c=x_t1_1.shape[1]))

            # Propagate the masked first frame to every other frame.
            x_t1_1_fg_masked_moved = []
            for x_t1_1_fg_masked_b in x_t1_1_fg_masked:
                moved = x_t1_1_fg_masked_b.clone().repeat(1, video_length - 1, 1, 1)[None]
                if use_motion_field:
                    moved = self.warp_latents_independently(moved, reference_flow)
                x_t1_1_fg_masked_moved.append(
                    torch.cat([x_t1_1_fg_masked_b[None], moved], dim=2))
            x_t1_1_fg_masked_moved = torch.cat(x_t1_1_fg_masked_moved, dim=0)

            # First-frame foreground mask, moved the same way as the latents.
            M_FG_warped = []
            for m_fg_1_b in M_FG[:, :1, :, :]:
                m_fg_1_b = m_fg_1_b[None, None]
                m_fg_b = m_fg_1_b.repeat(1, 1, video_length - 1, 1, 1)
                if use_motion_field:
                    m_fg_b = self.warp_latents_independently(m_fg_b.clone(), reference_flow)
                M_FG_warped.append(torch.cat([m_fg_1_b[:1, 0], m_fg_b[:1, 0]], dim=1))
            M_FG_warped = torch.cat(M_FG_warped, dim=0)

            # Background = neither the frame's own nor the moved foreground.
            M_BG = (1 - M_FG) * (1 - M_FG_warped)
            M_BG = repeat(M_BG, "b f h w -> b c f h w", c=x0.shape[1])
            a_convex = smooth_bg_strength

            latents = (1 - M_BG) * x_t1 + M_BG * (
                a_convex * x_t1 + (1 - a_convex) * x_t1_1_fg_masked_moved)

            ddim_res = self.DDIM_backward(skip_t=t1, t0=-1, t1=-1, latents_local=latents, **ddim_kwargs)
            x0 = ddim_res["x0"].detach()
            del ddim_res

        # Decode frame by frame, one VAE call per frame.
        video_list = []
        for latent in x0:
            tmp = latent[None]
            print("Frame spit shape", tmp.shape)
            frames = []
            for fr_split in range(tmp.shape[2]):
                print("frame decoding")
                frames.append(self.decode_latents(tmp[:, :, fr_split, None]).detach())
            video_list.append(torch.cat(frames, dim=2).cpu().float().numpy())

        videos = []
        if output_type == "tensor":
            videos = [torch.from_numpy(video) for video in video_list]
        if output_type == 'numpy':
            videos = [rearrange(video, 'b c f h w -> (b f) h w c') for video in video_list]

        if not return_dict:
            # Return every converted video, not a leftover loop variable.
            return videos

        return TextToVideoPipelineOutput(videos=videos, code=torch.split(xInit.detach().cpu(), 1, dim=0))