Spaces:
Runtime error
Runtime error
import torch | |
import av | |
import pims | |
import numpy as np | |
from typing import Optional, Tuple | |
from torchvision import transforms | |
from torch.utils.data import Dataset | |
from PIL import Image | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
class VideoReader(Dataset):
    """Map-style dataset over the frames of a video file.

    Frames are decoded with ``pims.PyAVVideoReader``, converted to PIL images,
    and passed through ``transform`` when one is supplied.
    """

    def __init__(self, path, transform=None):
        self.transform = transform
        self.video = pims.PyAVVideoReader(path)
        # Cache the rate reported by the pims reader.
        self.rate = self.video.frame_rate

    def frame_rate(self):
        """Return the frame rate cached from the pims reader."""
        return self.rate

    def __len__(self):
        """Return ``len`` of the underlying pims reader."""
        return len(self.video)

    def __getitem__(self, idx):
        """Return frame ``idx`` as a PIL image, transformed if a transform is set."""
        image = Image.fromarray(np.asarray(self.video[idx]))
        return image if self.transform is None else self.transform(image)
class VideoWriter:
    """Encode batches of frame tensors into an H.264 (yuv420p) video with PyAV.

    Args:
        path: Output file path, opened for writing with ``av.open``.
        frame_rate: Output frame rate; formatted with 4 decimals for PyAV.
        bit_rate: Encoder bit rate in bits per second.
    """

    def __init__(self, path, frame_rate, bit_rate=1000000):
        self.container = av.open(path, mode="w")
        self.stream = self.container.add_stream("h264", rate=f"{frame_rate:.4f}")
        self.stream.pix_fmt = "yuv420p"
        self.stream.bit_rate = bit_rate

    @staticmethod
    def _to_rgb24(frames):
        """Convert float frames [T, C, H, W] to a uint8 numpy array [T, H, W, C].

        Single-channel input is replicated to 3 channels (grayscale -> RGB).
        """
        if frames.size(1) == 1:
            frames = frames.repeat(1, 3, 1, 1)  # convert grayscale to RGB
        # Clamp before the uint8 cast: the float -> uint8 conversion does not
        # saturate, so values slightly outside [0, 1] would wrap around
        # (e.g. -0.01 -> 253) instead of clipping to 0 / 255.
        return frames.clamp(0, 1).mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()

    def write(self, frames):
        """Encode and mux a batch of frames.

        Args:
            frames: Tensor [T, C, H, W] with values expected in [0, 1]
                (values outside are clamped). C is expected to be 1 or 3.
        """
        self.stream.width = frames.size(3)
        self.stream.height = frames.size(2)
        for frame in self._to_rgb24(frames):
            video_frame = av.VideoFrame.from_ndarray(frame, format="rgb24")
            self.container.mux(self.stream.encode(video_frame))

    def close(self):
        """Flush the encoder (``encode()`` with no frame), then close the container."""
        self.container.mux(self.stream.encode())
        self.container.close()
def auto_downsample_ratio(h, w, max_side=512):
    """
    Choose a downsample ratio so that the longer side of an ``h`` x ``w`` frame
    becomes at most ``max_side`` pixels.

    Frames whose longer side is already <= ``max_side`` get a ratio of 1
    (no upsampling).

    Args:
        h: Frame height in pixels.
        w: Frame width in pixels.
        max_side: Target length of the longer side (default 512).

    Returns:
        A ratio in (0, 1].

    Raises:
        ValueError: If both sides are non-positive.
    """
    longest = max(h, w)
    if longest <= 0:
        raise ValueError(f"Resolution must be positive, got {h}x{w}.")
    return min(max_side / longest, 1)
def convert_video(
    model,
    input_source: str,
    input_resize: Optional[Tuple[int, int]] = None,
    downsample_ratio: Optional[float] = None,
    output_composition: Optional[str] = None,
    output_alpha: Optional[str] = None,
    output_foreground: Optional[str] = None,
    output_video_mbps: Optional[float] = None,
    seq_chunk: int = 1,
    num_workers: int = 0,
    progress: bool = True,
    device: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
):
    """
    Run a recurrent matting model over a video file and write the results as videos.

    Args:
        input_source: Path to the input video file (decoded with ``VideoReader``).
        input_resize: If provided, frames are first resized to (w, h).
        downsample_ratio: The model's downsample_ratio hyperparameter, in (0, 1].
            If not provided, it is chosen from the first chunk's resolution
            with ``auto_downsample_ratio``.
        output_composition: Output video path for the foreground composited
            over a solid black background.
        output_alpha: Output video path for the alpha matte from the model.
        output_foreground: Output video path for the foreground from the model.
        output_video_mbps: Bit rate of every output video in Mbps (default 1).
        seq_chunk: Number of frames to process at once. Increase it for better parallelism.
        num_workers: PyTorch's DataLoader workers.
        progress: Show a tqdm progress bar.
        device: Only needs to be provided if the model is a frozen TorchScript
            model; otherwise it is taken from the model's first parameter.
        dtype: Same as ``device``, for the input dtype.

    Raises:
        AssertionError: If an argument fails validation.
    """
    assert downsample_ratio is None or (
        0 < downsample_ratio <= 1
    ), "Downsample ratio must be between 0 (exclusive) and 1 (inclusive)."
    assert any(
        [output_composition, output_alpha, output_foreground]
    ), "Must provide at least one output."
    assert seq_chunk >= 1, "Sequence chunk must be >= 1"
    assert num_workers >= 0, "Number of workers must be >= 0"

    # Initialize transform. Resize takes (h, w), hence the reversal of (w, h).
    if input_resize is not None:
        transform = transforms.Compose(
            [transforms.Resize(input_resize[::-1]), transforms.ToTensor()]
        )
    else:
        transform = transforms.ToTensor()

    # Initialize reader; each batch is a chunk of seq_chunk frames.
    source = VideoReader(input_source, transform)
    reader = DataLoader(
        source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers
    )

    # VideoReader.frame_rate is defined as a method, so `source.frame_rate`
    # alone is a bound method and VideoWriter's f"{frame_rate:.4f}" would raise
    # TypeError. Call it when callable; also accept a plain attribute/property.
    frame_rate = source.frame_rate
    if callable(frame_rate):
        frame_rate = frame_rate()
    mbps = 1 if output_video_mbps is None else output_video_mbps
    bit_rate = int(mbps * 1000000)

    writer_com = writer_pha = writer_fgr = None
    bar = None
    try:
        # Writers are opened inside the try so that a failure opening a later
        # one still closes those already opened.
        if output_composition is not None:
            writer_com = VideoWriter(
                path=output_composition, frame_rate=frame_rate, bit_rate=bit_rate
            )
        if output_alpha is not None:
            writer_pha = VideoWriter(
                path=output_alpha, frame_rate=frame_rate, bit_rate=bit_rate
            )
        if output_foreground is not None:
            writer_fgr = VideoWriter(
                path=output_foreground, frame_rate=frame_rate, bit_rate=bit_rate
            )

        # Inference
        model = model.eval()
        if device is None or dtype is None:
            # Fill in only what the caller did not pass explicitly.
            param = next(model.parameters())
            dtype = param.dtype if dtype is None else dtype
            device = param.device if device is None else device

        if output_composition is not None:
            bgr = (
                torch.tensor([0, 0, 0], device=device, dtype=dtype)
                .div(255)
                .view(1, 1, 3, 1, 1)
            )  # solid black background, broadcast over [B, T, C, H, W]

        bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
        with torch.no_grad():
            rec = [None] * 4  # returned by the model and fed back on the next chunk
            for src in reader:
                # src is [T, C, H, W]: ToTensor frames batched by the DataLoader.
                if downsample_ratio is None:
                    downsample_ratio = auto_downsample_ratio(*src.shape[2:])
                src = src.to(device, dtype, non_blocking=True).unsqueeze(0)  # [B, T, C, H, W]
                fgr, pha, *rec = model(src, *rec, downsample_ratio)
                if writer_fgr is not None:
                    writer_fgr.write(fgr[0])
                if writer_pha is not None:
                    writer_pha.write(pha[0])
                if writer_com is not None:
                    com = fgr * pha + bgr * (1 - pha)
                    writer_com.write(com[0])
                bar.update(src.size(1))
    finally:
        # Clean up
        if bar is not None:
            bar.close()
        for writer in (writer_com, writer_pha, writer_fgr):
            if writer is not None:
                writer.close()