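"""Gradio demo for Latte: Latent Diffusion Transformer for Video Generation.

Loads the Latte transformer, a VAE (optionally with a temporal decoder), and a
T5 text encoder, then exposes a text-to-video / text-to-image UI as a Hugging
Face Space.
"""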
import argparse
import os
import sys

import gradio as gr
import imageio
import spaces
import torch
import torchvision
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
                                  EulerDiscreteScheduler, DPMSolverMultistepScheduler,
                                  HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
                                  DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from omegaconf import OmegaConf
from torchvision.utils import save_image
from transformers import T5EncoderModel, T5Tokenizer

sys.path.append(os.path.split(sys.path[0])[0])
from sample.pipeline_latte import LattePipeline
from models import get_models
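# Parse the config path from the command line, then replace the argparse
# namespace with the full sampling config loaded via OmegaConf.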
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/t2x/t2v_sample.yaml")
args = parser.parse_args()
args = OmegaConf.load(args.config)

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
transformer_model = get_models(args).to(device, dtype=torch.float16)

if args.enable_vae_temporal_decoder:
    vae = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
else:
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae", torch_dtype=torch.float16).to(device)

tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)

# Inference only: put every module in eval mode.
transformer_model.eval()
vae.eval()
text_encoder.eval()
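# `spaces` is imported but otherwise unused; on a ZeroGPU Space the handler is
# normally decorated with @spaces.GPU so a GPU is attached per call. Restoring
# the decorator here is an assumption based on that import.
@spaces.GPU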
def gen_video(text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step):
    torch.manual_seed(seed)

    # Every scheduler is built from the same pretrained "scheduler" subfolder
    # with the beta/variance settings from the config; DDIM and DDPM
    # additionally disable sample clipping. A lookup table replaces the
    # original if/elif chain without changing behavior.
    scheduler_classes = {
        'DDIM': DDIMScheduler,
        'EulerDiscrete': EulerDiscreteScheduler,
        'DDPM': DDPMScheduler,
        'DPMSolverMultistep': DPMSolverMultistepScheduler,
        'DPMSolverSinglestep': DPMSolverSinglestepScheduler,
        'PNDM': PNDMScheduler,
        'HeunDiscrete': HeunDiscreteScheduler,
        'EulerAncestralDiscrete': EulerAncestralDiscreteScheduler,
        'DEISMultistep': DEISMultistepScheduler,
        'KDPM2AncestralDiscrete': KDPM2AncestralDiscreteScheduler,
    }
    scheduler_kwargs = dict(
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        beta_schedule=args.beta_schedule,
        variance_type=args.variance_type,
    )
    if sample_method in ('DDIM', 'DDPM'):
        scheduler_kwargs['clip_sample'] = False
    scheduler = scheduler_classes[sample_method].from_pretrained(
        args.pretrained_model_path, subfolder="scheduler", **scheduler_kwargs)

    videogen_pipeline = LattePipeline(vae=vae,
                                      text_encoder=text_encoder,
                                      tokenizer=tokenizer,
                                      scheduler=scheduler,
                                      transformer=transformer_model).to(device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()

    videos = videogen_pipeline(text_input,
                               video_length=video_length,
                               height=height,
                               width=width,
                               num_inference_steps=diffusion_step,
                               guidance_scale=scfg_scale,
                               enable_temporal_attentions=args.enable_temporal_attentions,
                               num_images_per_prompt=1,
                               mask_feature=True,
                               enable_vae_temporal_decoder=args.enable_vae_temporal_decoder
                               ).video

    # os.path.join is robust to a missing trailing slash in save_img_path,
    # unlike the original string concatenation.
    save_path = os.path.join(args.save_img_path, 'temp.mp4')
    # torchvision.io.write_video(save_path, videos[0], fps=8)
    imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
    return save_path
os.makedirs(args.save_img_path, exist_ok=True)
intro = """ | |
<div style="display: flex;align-items: center;justify-content: center"> | |
<h1 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte: Latent Diffusion Transformer for Video Generation</h1> | |
</div> | |
""" | |
with gr.Blocks() as demo: | |
# gr.HTML(intro) | |
# with gr.Accordion("README", open=False): | |
# gr.HTML( | |
# """ | |
# <p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block"> | |
# <a href="https://maxin-cn.github.io/latte_project/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2401.03048" target="_blank">paper</a> | |
# </p> | |
# We will continue update Latte. | |
# """ | |
# ) | |
gr.Markdown("<font color=red size=10><center>Latte: Latent Diffusion Transformer for Video Generation</center></font>") | |
gr.Markdown( | |
"""<div style="display: flex;align-items: center;justify-content: center"> | |
<h2 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte supports both T2I and T2V, and will be continuously updated, so stay tuned!</h2></div> | |
""" | |
) | |
gr.Markdown( | |
"""<div style="display: flex;align-items: center;justify-content: center"> | |
[<a href="https://arxiv.org/abs/2401.03048">Arxiv Report</a>] | [<a href="https://maxin-cn.github.io/latte_project/">Project Page</a>] | [<a href="https://github.com/Vchitect/Latte">Github</a>]</div> | |
""" | |
) | |
    with gr.Row():
        with gr.Column(visible=True) as input_raws:
            with gr.Row():
                with gr.Column(scale=1.0):
                    # text_input = gr.Textbox(show_label=True, interactive=True, label="Text prompt").style(container=False)
                    text_input = gr.Textbox(show_label=True, interactive=True, label="Prompt")
            # with gr.Row():
            #     with gr.Column(scale=0.5):
            #         image_input = gr.Image(show_label=True, interactive=True, label="Reference image").style(container=False)
            #     with gr.Column(scale=0.5):
            #         preframe_input = gr.Image(show_label=True, interactive=True, label="First frame").style(container=False)
            with gr.Row():
                with gr.Column(scale=0.5):
                    sample_method = gr.Dropdown(choices=["DDIM", "EulerDiscrete", "PNDM"], label="Sample Method", value="DDIM")
                # with gr.Row():
                #     with gr.Column(scale=1.0):
                #         video_length = gr.Slider(
                #             minimum=1,
                #             maximum=24,
                #             value=1,
                #             step=1,
                #             interactive=True,
                #             label="Video Length (1 for T2I and 16 for T2V)",
                #         )
                with gr.Column(scale=0.5):
                    video_length = gr.Dropdown(choices=[1, 16], label="Video Length (1 for T2I and 16 for T2V)", value=16)
            with gr.Row():
                with gr.Column(scale=1.0):
                    scfg_scale = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=7.5,
                        step=0.1,
                        interactive=True,
                        label="Guidance Scale",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    seed = gr.Slider(
                        minimum=1,
                        maximum=2147483647,
                        value=100,
                        step=1,
                        interactive=True,
                        label="Seed",
                    )
            with gr.Row():
                with gr.Column(scale=0.5):
                    height = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Height",
                    )
                # with gr.Row():
                with gr.Column(scale=0.5):
                    width = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Width",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    diffusion_step = gr.Slider(
                        minimum=20,
                        maximum=250,
                        value=50,
                        step=1,
                        interactive=True,
                        label="Sampling Step",
                    )

        with gr.Column(scale=0.6, visible=True) as video_upload:
            # with gr.Column(visible=True) as video_upload:
            output = gr.Video(interactive=False, include_audio=True, elem_id="output_video")  # .style(height=360)
            # with gr.Column(elem_id="image", scale=0.5) as img_part:
            #     with gr.Tab("Video", elem_id='video_tab'):
            #     with gr.Tab("Image", elem_id='image_tab'):
            #         up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload").style(height=360)
            # upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            # clear = gr.Button("Restart")

            with gr.Row():
                with gr.Column(scale=1.0, min_width=0):
                    run = gr.Button("💭Run")
                # with gr.Column(scale=0.5, min_width=0):
                #     clear = gr.Button("🔄Clear️")
    EXAMPLES = [
        ["3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. The creature looks up in awe at a large, glowing tree that seems to be the heart of the forest.", "DDIM", 7.5, 100, 512, 512, 16, 50],
        ["A grandmother with neatly combed grey hair stands behind a colorful birthday cake with numerous candles at a wood dining room table, expression is one of pure joy and happiness, with a happy glow in her eye. She leans forward and blows out the candles with a gentle puff, the cake has pink frosting and sprinkles and the candles cease to flicker, the grandmother wears a light blue blouse adorned with floral patterns, several happy friends and family sitting at the table can be seen celebrating, out of focus. The scene is beautifully captured, cinematic, showing a 3/4 view of the grandmother and the dining room. Warm color tones and soft lighting enhance the mood.", "DDIM", 7.5, 100, 512, 512, 16, 50],
        ["A wizard wearing a pointed hat and a blue robe with white stars casting a spell that shoots lightning from his hand and holding an old tome in his other hand.", "DDIM", 7.5, 100, 512, 512, 16, 50],
        ["A young man in his 20s is sitting on a piece of cloud in the sky, reading a book.", "DDIM", 7.5, 100, 512, 512, 16, 50],
        ["Cinematic trailer for a group of samoyed puppies learning to become chefs.", "DDIM", 7.5, 100, 512, 512, 16, 50],
        ["Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.", "DDIM", 7.5, 100, 512, 512, 16, 50],
        ["A cyborg koala DJ in front of a turntable, in heavy raining futuristic tokyo rooftop cyberpunk night, sci-fi, fantasy, intricate, neon light, soft light smooth, sharp focus, illustration.", "DDIM", 7.5, 100, 512, 512, 16, 50],
    ]
    examples = gr.Examples(
        examples=EXAMPLES,
        fn=gen_video,
        inputs=[text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step],
        outputs=[output],
        # cache_examples=True,
        cache_examples="lazy",
    )

    run.click(gen_video, [text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step], [output])

demo.launch(debug=False, share=True)
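# Typical invocation, using the argparse default config path defined above:
#   python app.py --config ./configs/t2x/t2v_sample.yaml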