# Spaces: Running on Zero
from typing import List | |
from PIL import Image | |
import numpy as np | |
import math | |
import random | |
import cv2 | |
from typing import List | |
import torch | |
import einops | |
from pytorch_lightning import seed_everything | |
from transparent_background import Remover | |
from dataset.opencv_transforms.functional import to_tensor, center_crop | |
from vtdm.model import create_model | |
from vtdm.util import tensor2vid | |
# Module-level background remover, created once at import time and shared by
# prepare_white_image() and MultiViewGenerator.video_pipeline(), both of which
# call remover.process(..., type='rgba').
remover = Remover(jit=False)
def cv2_to_pil(cv_image: np.ndarray) -> Image.Image:
    """Wrap an OpenCV array in a PIL image after a BGR -> RGB channel swap."""
    rgb_array = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb_array)
    return pil_image
def pil_to_cv2(pil_image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a NumPy array and apply OpenCV's RGB -> BGR swap."""
    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
def prepare_white_image(input_image: Image.Image) -> Image.Image:
    """Remove the background and centre the cut-out on a square RGBA canvas.

    The module-level ``remover`` produces the cut-out, which is then pasted
    into the middle of a square canvas whose side equals the cut-out's longer
    edge.  Despite the function name, the canvas is created with the fill
    ``(0, 0, 0, 0)`` (fully transparent), not white.
    """
    # Background removal; the remover is asked for its 'rgba' output type.
    cutout = remover.process(input_image, type='rgba')

    # Pad to a square by centring the cut-out on a new canvas.
    cut_w, cut_h = cutout.size
    side = max(cut_w, cut_h)
    canvas = Image.new('RGBA', (side, side), (0, 0, 0, 0))
    top_left = ((side - cut_w) // 2, (side - cut_h) // 2)
    canvas.paste(cutout, top_left)
    return canvas
class MultiViewGenerator:
    """Generate a list of background-removed frames from a single input image.

    The class owns one denoising model (built by ``create_model``) stored under
    ``self.models["denoising_model"]``.  ``infer`` is the high-level entry
    point; ``process``, ``video_pipeline`` and ``denoising`` are the successive
    stages it drives.
    """

    def __init__(self, checkpoint_path, config_path="inference.yaml"):
        """Build the model, load its checkpoint and move it to CUDA as fp16.

        Args:
            checkpoint_path: passed to the model's ``init_from_ckpt``.
            config_path: passed to ``create_model``.
        """
        self.models = {}
        # The checkpoint is loaded while the model is on the CPU; only the
        # fully initialised model is moved to the GPU and cast to half.
        denoising_model = create_model(config_path).cpu()
        denoising_model.init_from_ckpt(checkpoint_path)
        self.models["denoising_model"] = denoising_model.cuda().half()

    def denoising(self, frames, args):
        """Run one sampling pass over ``frames`` and return the decoded video.

        Args:
            frames: tensor whose shape unpacks as ``(C, T, H, W)``.
            args: mapping; only ``args["elevation"]`` is read here.

        Returns:
            The result of ``tensor2vid`` applied to the decoded samples
            rearranged to ``b c t h w``.
        """
        model = self.models["denoising_model"]
        with torch.no_grad():
            _, T, H, W = frames.shape

            # Conditioning batch: the clip plus scalar conditions as int64
            # tensors on the same device as the frames.
            batch = {"video": frames.unsqueeze(0)}
            batch["elevation"] = (
                torch.Tensor([args["elevation"]]).to(torch.int64).to(frames.device)
            )
            batch["fps_id"] = torch.Tensor([7]).to(torch.int64).to(frames.device)
            batch["motion_bucket_id"] = (
                torch.Tensor([127]).to(torch.int64).to(frames.device)
            )
            batch = model.add_custom_cond(batch, infer=True)

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    force_uc_zero_embeddings=[
                        "cond_frames",
                        "cond_frames_without_noise",
                    ],
                )

            # "num_video_frames" is not set above, so it is expected to have
            # been added to the batch by add_custom_cond -- TODO confirm.
            additional_model_inputs = {
                "image_only_indicator": torch.zeros(2, T).to(model.device),
                "num_video_frames": batch["num_video_frames"],
            }

            def denoiser(input, sigma, c):
                return model.denoiser(
                    model.model, input, sigma, c, **additional_model_inputs
                )

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                # One latent per frame: 4 channels at 1/8 spatial resolution.
                randn = torch.randn([T, 4, H // 8, W // 8], device=model.device)
                samples = model.sampler(denoiser, randn, cond=c, uc=uc)

            # NOTE(review): block nesting was reconstructed from a listing
            # with no indentation; decoding is placed outside autocast because
            # of the explicit .half() cast -- confirm against upstream.
            samples = model.decode_first_stage(samples.half())
            samples = einops.rearrange(samples, "(b t) c h w -> b c t h w", t=T)
        return tensor2vid(samples)

    def video_pipeline(self, frames, args) -> List[Image.Image]:
        """Sample ``args["num_iter"]`` clips back-to-back and strip backgrounds.

        After every pass the last generated frame is written back into
        ``frames[:, 0]`` (the tensor is modified in place) so that the next
        pass continues from it; from the second pass on, the first frame of
        the new clip is dropped when concatenating.

        Args:
            frames: tensor indexed as ``frames[:, 0]`` for its first frame.
            args: mapping providing ``"num_iter"`` plus whatever
                ``denoising`` reads.

        Returns:
            One ``remover.process(..., type='rgba')`` result per generated
            frame.
        """
        out_list = []
        for _ in range(args["num_iter"]):
            with torch.no_grad():
                results = self.denoising(frames, args)
            # Skip the leading frame of every clip after the first one.
            out_list = out_list + (results if not out_list else results[1:])

            # Feed the newest frame back as the first frame, rescaled with
            # the same (x - 0.5) * 2 mapping used in process().
            last_frame = to_tensor(out_list[-1])
            frames[:, 0] = (last_frame - 0.5) * 2.0

        # NOTE(review): cv2_to_pil performs a BGR -> RGB swap, while process()
        # feeds the model frames already converted to RGB; confirm the channel
        # order returned by tensor2vid.
        return [
            remover.process(cv2_to_pil(frame), type='rgba') for frame in out_list
        ]

    def process(self, white_image: Image.Image, args) -> List[Image.Image]:
        """Preprocess ``white_image`` into a clip tensor and run the pipeline.

        The image is resized so that it covers ``args["input_resolution"]``,
        centre-cropped to exactly that size, converted with
        ``cv2.COLOR_BGR2RGB``, turned into a tensor and rescaled with
        ``(x - 0.5) * 2``.  That single frame is repeated
        ``args["clip_size"]`` times along dimension 1 and moved to CUDA.

        Args:
            white_image: input PIL image.
            args: mapping providing ``"clip_size"``, ``"input_resolution"``
                and the keys read by ``video_pipeline`` / ``denoising``.

        Returns:
            The list produced by ``video_pipeline``.

        Raises:
            ValueError: if ``args["clip_size"]`` is smaller than 1.
        """
        clip_size = args["clip_size"]
        if clip_size < 1:
            raise ValueError(f"clip_size must be at least 1, got {clip_size}")
        target_h, target_w = args["input_resolution"][0], args["input_resolution"][1]

        # Every frame of the initial clip is the same image, so the
        # preprocessing is done once and the resulting tensor is replicated.
        img = pil_to_cv2(white_image)
        h, w = img.shape[0:2]
        # Scale by the larger ratio so both sides reach the target size,
        # then crop the overflow.
        rate = max(target_h * 1.0 / h, target_w * 1.0 / w)
        img = cv2.resize(img, [math.ceil(w * rate), math.ceil(h * rate)])
        img = center_crop(img, [target_h, target_w])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        frame = (to_tensor(img) - 0.5) * 2.0

        # torch.stack allocates a new tensor, so the in-place write done by
        # video_pipeline never aliases `frame`.
        frames = torch.stack([frame] * clip_size, 1).cuda()

        model = self.models["denoising_model"]
        model.num_samples = clip_size
        model.image_size = args["input_resolution"]
        return self.video_pipeline(frames, args)

    def infer(self, white_image: Image.Image, seed=None) -> List[Image.Image]:
        """Run the full pipeline with the built-in default parameters.

        Args:
            white_image: input PIL image, forwarded to ``process``.
            seed: optional integer passed to ``seed_everything`` for
                reproducible runs.  When omitted, a seed is drawn with
                ``random.randint(0, 65535)``.

        Returns:
            The list produced by ``process``.
        """
        if seed is None:
            seed = random.randint(0, 65535)
        seed_everything(seed)
        # "aes" and "mv" are carried along in the mapping but are not read by
        # any method of this class.
        params = {
            "clip_size": 25,
            "input_resolution": [512, 512],
            "num_iter": 1,
            "aes": 6.0,
            "mv": [0.0, 0.0, 0.0, 10.0],
            "elevation": 0,
        }
        return self.process(white_image, params)