import os
import imageio
from PIL import Image
from typing import List

import torch
import torch.nn.functional as F
from diffusers import IFSuperResolutionPipeline, VideoToVideoSDPipeline
from diffusers.utils.torch_utils import randn_tensor

class ShowOnePipeline:
    def __init__(self, base_path, interp_path, deepfloyd_path, sr1_path, sr2_path):
        """
        Download the necessary models from Hugging Face and use them to load the
        corresponding Show-1 pipelines.
        https://github.com/showlab/Show-1
        """
        from .showone.pipelines import TextToVideoIFPipeline, TextToVideoIFInterpPipeline, \
            TextToVideoIFSuperResolutionPipeline
        from .showone.pipelines.pipeline_t2v_base_pixel import tensor2vid
        from .showone.pipelines.pipeline_t2v_sr_pixel_cond import TextToVideoIFSuperResolutionPipeline_Cond

        self.tensor2vid = tensor2vid

        # Base model
        # When using "showlab/show-1-base-0.0", it is advisable to increase the number
        # of inference steps (e.g., 100) and opt for a larger guidance scale (e.g., 12.0)
        # to enhance visual quality.
        self.pipe_base = TextToVideoIFPipeline.from_pretrained(
            base_path,
            torch_dtype=torch.float16,
            variant="fp16"
        )
        self.pipe_base.enable_model_cpu_offload()

        # Interpolation model
        self.pipe_interp_1 = TextToVideoIFInterpPipeline.from_pretrained(
            interp_path,
            torch_dtype=torch.float16,
            variant="fp16"
        )
        self.pipe_interp_1.enable_model_cpu_offload()

        # Super-resolution model 1
        # Image super-resolution model from DeepFloyd: https://huggingface.co/DeepFloyd/IF-II-L-v1.0
        # pretrained_model_path = "./checkpoints/DeepFloyd/IF-II-L-v1.0"
        self.pipe_sr_1_image = IFSuperResolutionPipeline.from_pretrained(
            deepfloyd_path,
            text_encoder=None,
            torch_dtype=torch.float16,
            variant="fp16"
        )
        self.pipe_sr_1_image.enable_model_cpu_offload()

        self.pipe_sr_1_cond = TextToVideoIFSuperResolutionPipeline_Cond.from_pretrained(
            sr1_path,
            torch_dtype=torch.float16
        )
        self.pipe_sr_1_cond.enable_model_cpu_offload()

        # Super-resolution model 2
        self.pipe_sr_2 = VideoToVideoSDPipeline.from_pretrained(
            sr2_path,
            torch_dtype=torch.float16
        )
        self.pipe_sr_2.enable_model_cpu_offload()
        self.pipe_sr_2.enable_vae_slicing()
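
    # The constructor arguments are local checkpoint directories. A minimal sketch of
    # how they might be fetched (an assumption based on the Show-1 model cards; verify
    # the repo ids before use):
    #
    #     from huggingface_hub import snapshot_download
    #     base_path = snapshot_download("showlab/show-1-base")
    #     interp_path = snapshot_download("showlab/show-1-interpolation")
    #     deepfloyd_path = snapshot_download("DeepFloyd/IF-II-L-v1.0")
    #     sr1_path = snapshot_download("showlab/show-1-sr1")
    #     sr2_path = snapshot_download("showlab/show-1-sr2")
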
    def inference(self, prompt: str = "",
                  negative_prompt: str = "",
                  output_size: List[int] = [240, 560],
                  initial_num_frames: int = 8,
                  scaling_factor: int = 4,
                  seed: int = 42):
        """
        Generate a single video from a textual prompt. The output is a tensor
        representing the video.
        https://github.com/showlab/Show-1

        Args:
            prompt (str, optional): The text prompt that guides the video generation. Defaults to "".
            negative_prompt (str, optional): The negative prompt that guides the video generation. Defaults to "".
            output_size (list, optional): The resolution of the output video as [height, width]. Defaults to [240, 560].
            initial_num_frames (int, optional): The number of keyframes generated by the base model. Defaults to 8, as proposed in the paper.
            scaling_factor (int, optional): The temporal upsampling factor used in the interpolation step. Defaults to 4, as proposed in the paper, which increases the frame count from 8 to 29.
            seed (int, optional): A seed for random number generation, ensuring reproducibility of the video generation process. Defaults to 42.

        Returns:
            The generated video as a tensor with shape (num_frames, channels, height, width).
        """
        # Text embeddings
        prompt_embeds, negative_embeds = self.pipe_base.encode_prompt(prompt)

        # Keyframe generation (8x64x40, 2 fps)
        video_frames = self.pipe_base(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_embeds,
            num_frames=initial_num_frames,
            height=40,
            width=64,
            num_inference_steps=75,
            guidance_scale=9.0,
            generator=torch.manual_seed(seed),
            output_type="pt"
        ).frames

        # Frame interpolation (8x64x40 @ 2 fps -> 29x64x40 @ 7.5 fps)
        bsz, channel, num_frames_1, height, width = video_frames.shape
        k = scaling_factor
        new_num_frames = (k - 1) * (num_frames_1 - 1) + num_frames_1
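        # e.g., with the default k = 4 and 8 keyframes: (4 - 1) * (8 - 1) + 8 = 29 frames.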
        new_video_frames = torch.zeros((bsz, channel, new_num_frames, height, width),
                                       dtype=video_frames.dtype, device=video_frames.device)
        # Place the original keyframes at every k-th position.
        new_video_frames[:, :, torch.arange(0, new_num_frames, k), ...] = video_frames
        init_noise = randn_tensor((bsz, channel, k + 1, height, width), dtype=video_frames.dtype,
                                  device=video_frames.device, generator=torch.manual_seed(seed))

        for i in range(num_frames_1 - 1):
            # Each window of k + 1 frames is conditioned on the two neighbouring
            # keyframes, placed in its first and last slots.
            batch_i = torch.zeros((bsz, channel, k + 1, height, width), dtype=video_frames.dtype,
                                  device=video_frames.device)
            batch_i[:, :, 0, ...] = video_frames[:, :, i, ...]
            batch_i[:, :, -1, ...] = video_frames[:, :, i + 1, ...]
            batch_i = self.pipe_interp_1(
                pixel_values=batch_i,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_embeds,
                num_frames=batch_i.shape[2],
                height=40,
                width=64,
                num_inference_steps=75,
                guidance_scale=4.0,
                generator=torch.manual_seed(seed),
                output_type="pt",
                init_noise=init_noise,
                cond_interpolation=True,
            ).frames
            # Window i fills frames [i*k, i*k + k]; adjacent windows share one endpoint.
            new_video_frames[:, :, i * k:i * k + k + 1, ...] = batch_i

        video_frames = new_video_frames

        # Super-resolution 1 (29x64x40 -> 29x256x160)
        bsz, channel, num_frames_2, height, width = video_frames.shape
        window_size, stride = 8, 7
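        # With 29 frames these windows start at i = 0, 7, 14, 21, so consecutive windows
        # overlap by one frame; that shared frame is re-used as conditioning below to
        # keep the upscaled windows temporally consistent.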
        new_video_frames = torch.zeros(
            (bsz, channel, num_frames_2, height * 4, width * 4),
            dtype=video_frames.dtype,
            device=video_frames.device)

        for i in range(0, num_frames_2 - window_size + 1, stride):
            batch_i = video_frames[:, :, i:i + window_size, ...]
            if i == 0:
                # First window: upscale frame 0 with the DeepFloyd image SR model.
                first_frame_cond = self.pipe_sr_1_image(
                    image=video_frames[:, :, 0, ...],
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=negative_embeds,
                    height=height * 4,
                    width=width * 4,
                    num_inference_steps=70,
                    guidance_scale=4.0,
                    noise_level=150,
                    generator=torch.manual_seed(seed),
                    output_type="pt"
                ).images
                first_frame_cond = first_frame_cond.unsqueeze(2)
            else:
                # Later windows: condition on the frame already upscaled by the
                # previous window.
                first_frame_cond = new_video_frames[:, :, i:i + 1, ...]
            batch_i = self.pipe_sr_1_cond(
                image=batch_i,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_embeds,
                first_frame_cond=first_frame_cond,
                height=height * 4,
                width=width * 4,
                num_inference_steps=125,
                guidance_scale=7.0,
                noise_level=250,
                generator=torch.manual_seed(seed),
                output_type="pt"
            ).frames
            new_video_frames[:, :, i:i + window_size, ...] = batch_i

        video_frames = new_video_frames

        # Super-resolution 2 (29x256x160 -> 29x560x240 with the default output_size)
        video_frames = [Image.fromarray(frame).resize((output_size[1], output_size[0])) for frame in
                        self.tensor2vid(video_frames.clone())]
        video_frames = self.pipe_sr_2(
            prompt,
            negative_prompt=negative_prompt,
            video=video_frames,
            strength=0.8,
            num_inference_steps=50,
            generator=torch.manual_seed(seed),
            output_type="pt"
        ).frames
        output = video_frames.squeeze()
        return output
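

# Minimal usage sketch (not part of the original file). The checkpoint paths below are
# hypothetical placeholders, and because of the relative imports above this module must
# be run as part of its package (e.g. `python -m <package>.<module>`).
if __name__ == "__main__":
    pipeline = ShowOnePipeline(
        base_path="./checkpoints/show-1-base",
        interp_path="./checkpoints/show-1-interpolation",
        deepfloyd_path="./checkpoints/DeepFloyd/IF-II-L-v1.0",
        sr1_path="./checkpoints/show-1-sr1",
        sr2_path="./checkpoints/show-1-sr2",
    )
    video = pipeline.inference(prompt="a panda taking a selfie", seed=42)

    # Assuming the returned tensor is (num_frames, channels, height, width) in [0, 1],
    # convert it to uint8 HWC frames and write them out with imageio.
    frames = (video.clamp(0, 1) * 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
    imageio.mimsave("output.mp4", list(frames), fps=8)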