# Page header captured from the Hugging Face file viewer (not part of the program):
# HYDRAS_Latte-1 / demo.py
# maxin-cn's picture
# Upload folder using huggingface_hub
# 4051d56 verified
import gradio as gr
import os
import torch
import argparse
import torchvision
from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
EulerDiscreteScheduler, DPMSolverMultistepScheduler,
HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from omegaconf import OmegaConf
from transformers import T5EncoderModel, T5Tokenizer
import os, sys
sys.path.append(os.path.split(sys.path[0])[0])
from sample.pipeline_latte import LattePipeline
from models import get_models
import imageio
from torchvision.utils import save_image
import spaces
# ---------------------------------------------------------------------------
# Runtime configuration and model loading (runs once, at import time).
# ---------------------------------------------------------------------------

# The only command-line flag is the path to a YAML config. After parsing, `args`
# holds the OmegaConf object loaded from that file, not the argparse namespace.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/t2x/t2v_sample.yaml")
args = OmegaConf.load(parser.parse_args().config)

# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# NOTE(review): every model below is loaded in float16 even when `device` falls
# back to "cpu" -- confirm half precision is acceptable on the CPU path.
transformer_model = get_models(args).to(device, dtype=torch.float16)

# The config flag selects which VAE flavour (and which checkpoint subfolder) to load.
if args.enable_vae_temporal_decoder:
    _vae_cls, _vae_subfolder = AutoencoderKLTemporalDecoder, "vae_temporal_decoder"
else:
    _vae_cls, _vae_subfolder = AutoencoderKL, "vae"
vae = _vae_cls.from_pretrained(
    args.pretrained_model_path,
    subfolder=_vae_subfolder,
    torch_dtype=torch.float16,
).to(device)

tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(
    args.pretrained_model_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
).to(device)

# Put every network into eval mode.
for _module in (transformer_model, vae, text_encoder):
    _module.eval()
# Sampler name -> (scheduler class, keyword overrides specific to that sampler).
# Every scheduler additionally receives the shared noise-schedule settings from
# the config; see _build_scheduler.
_SCHEDULERS = {
    'DDIM': (DDIMScheduler, {'clip_sample': False}),
    'EulerDiscrete': (EulerDiscreteScheduler, {}),
    'DDPM': (DDPMScheduler, {'clip_sample': False}),
    'DPMSolverMultistep': (DPMSolverMultistepScheduler, {}),
    'DPMSolverSinglestep': (DPMSolverSinglestepScheduler, {}),
    'PNDM': (PNDMScheduler, {}),
    'HeunDiscrete': (HeunDiscreteScheduler, {}),
    'EulerAncestralDiscrete': (EulerAncestralDiscreteScheduler, {}),
    'DEISMultistep': (DEISMultistepScheduler, {}),
    'KDPM2AncestralDiscrete': (KDPM2AncestralDiscreteScheduler, {}),
}


def _build_scheduler(sample_method):
    """Instantiate the diffusers scheduler registered under *sample_method*.

    Raises:
        ValueError: if *sample_method* is not a key of ``_SCHEDULERS``.
    """
    try:
        scheduler_cls, extra_kwargs = _SCHEDULERS[sample_method]
    except KeyError:
        supported = ", ".join(_SCHEDULERS)
        raise ValueError(
            f"Unsupported sample method {sample_method!r}; expected one of: {supported}"
        ) from None
    return scheduler_cls.from_pretrained(
        args.pretrained_model_path,
        subfolder="scheduler",
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        beta_schedule=args.beta_schedule,
        variance_type=args.variance_type,
        **extra_kwargs,
    )


@spaces.GPU
def gen_video(text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step):
    """Generate a clip for *text_input* and return the path of the written MP4.

    Args:
        text_input: Prompt passed to the Latte pipeline.
        sample_method: Name of the sampler; must be a key of ``_SCHEDULERS``.
        scfg_scale: Forwarded to the pipeline as ``guidance_scale``.
        seed: Value given to ``torch.manual_seed`` before sampling.
        height: Forwarded to the pipeline as ``height``.
        width: Forwarded to the pipeline as ``width``.
        video_length: Forwarded to the pipeline as ``video_length``.
        diffusion_step: Forwarded to the pipeline as ``num_inference_steps``.

    Returns:
        Path of the MP4 file written inside ``args.save_img_path``.

    Raises:
        ValueError: if *sample_method* is not a supported sampler name.
    """
    torch.manual_seed(seed)

    # Resolve the scheduler first so an unknown sampler name fails with a clear
    # message instead of an unbound-variable error further down.
    scheduler = _build_scheduler(sample_method)

    videogen_pipeline = LattePipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        transformer=transformer_model,
    ).to(device)

    videos = videogen_pipeline(
        text_input,
        video_length=video_length,
        height=height,
        width=width,
        num_inference_steps=diffusion_step,
        guidance_scale=scfg_scale,
        enable_temporal_attentions=args.enable_temporal_attentions,
        num_images_per_prompt=1,
        mask_feature=True,
        enable_vae_temporal_decoder=args.enable_vae_temporal_decoder,
    ).video

    # Join rather than concatenate, so the file lands inside save_img_path
    # whether or not the configured value ends with a path separator.
    # NOTE(review): every call writes the same file name; overlapping requests
    # would overwrite each other's output -- confirm requests are serialized.
    save_path = os.path.join(args.save_img_path, 'temp.mp4')
    # Only one sample is requested (num_images_per_prompt=1), so write the first entry.
    imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
    return save_path
# Make sure the output directory exists before the first video is written.
# exist_ok=True avoids the check-then-create race of a separate exists() test.
os.makedirs(args.save_img_path, exist_ok=True)
# HTML heading for the demo page. Nothing in the live layout code renders it at
# the moment, so it currently has no effect on the page.
intro = """
<div style="display: flex;align-items: center;justify-content: center">
<h1 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte: Latent Diffusion Transformer for Video Generation</h1>
</div>
"""
# Gradio page: title and links, a column of generation controls, and a column
# holding the resulting video plus the Run button.
# NOTE(review): the Row/Column nesting below was reconstructed from a copy of
# this file that had lost its indentation -- confirm against the rendered page.
# Column `scale` values are integers (Gradio warns on fractional scales); the
# ratios match the earlier fractional values (0.5:0.5 -> 1:1, 1:0.6 -> 5:3).
with gr.Blocks() as demo:
    gr.Markdown("<font color=red size=10><center>Latte: Latent Diffusion Transformer for Video Generation</center></font>")
    gr.Markdown(
        """<div style="display: flex;align-items: center;justify-content: center">
<h2 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte supports both T2I and T2V, and will be continuously updated, so stay tuned!</h2></div>
"""
    )
    gr.Markdown(
        """<div style="display: flex;align-items: center;justify-content: center">
[<a href="https://arxiv.org/abs/2401.03048">Arxiv Report</a>] | [<a href="https://maxin-cn.github.io/latte_project/">Project Page</a>] | [<a href="https://github.com/Vchitect/Latte">Github</a>]</div>
"""
    )

    with gr.Row():
        # Generation controls.
        with gr.Column(scale=5, visible=True) as input_raws:
            with gr.Row():
                with gr.Column(scale=1):
                    text_input = gr.Textbox(show_label=True, interactive=True, label="Prompt")

            with gr.Row():
                with gr.Column(scale=1):
                    sample_method = gr.Dropdown(choices=["DDIM", "EulerDiscrete", "PNDM"], label="Sample Method", value="DDIM")
                with gr.Column(scale=1):
                    video_length = gr.Dropdown(choices=[1, 16], label="Video Length (1 for T2I and 16 for T2V)", value=16)

            with gr.Row():
                with gr.Column(scale=1):
                    scfg_scale = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=7.5,
                        step=0.1,
                        interactive=True,
                        label="Guidence Scale",
                    )

            with gr.Row():
                with gr.Column(scale=1):
                    seed = gr.Slider(
                        minimum=1,
                        maximum=2147483647,
                        value=100,
                        step=1,
                        interactive=True,
                        label="Seed",
                    )

            # Height and width are shown but not editable (interactive=False).
            with gr.Row():
                with gr.Column(scale=1):
                    height = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Height",
                    )
                with gr.Column(scale=1):
                    width = gr.Slider(
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=16,
                        interactive=False,
                        label="Width",
                    )

            with gr.Row():
                with gr.Column(scale=1):
                    diffusion_step = gr.Slider(
                        minimum=20,
                        maximum=250,
                        value=50,
                        step=1,
                        interactive=True,
                        label="Sampling Step",
                    )

        # Generated video and the Run button.
        with gr.Column(scale=3, visible=True) as video_upload:
            output = gr.Video(interactive=False, include_audio=True, elem_id="输出的视频")

            with gr.Row():
                with gr.Column(scale=1, min_width=0):
                    run = gr.Button("💭Run")

    # Components in the positional order of gen_video's parameters; shared by
    # the example gallery and the Run button so the two cannot drift apart.
    gen_inputs = [text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step]

    # Every example uses the same settings after the prompt:
    # sample_method, scfg_scale, seed, height, width, video_length, diffusion_step.
    _EXAMPLE_SETTINGS = ["DDIM", 7.5, 100, 512, 512, 16, 50]
    _EXAMPLE_PROMPTS = [
        "3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. The creature looks up in awe at a large, glowing tree that seems to be the heart of the forest.",
        "A grandmother with neatly combed grey hair stands behind a colorful birthday cake with numerous candles at a wood dining room table, expression is one of pure joy and happiness, with a happy glow in her eye. She leans forward and blows out the candles with a gentle puff, the cake has pink frosting and sprinkles and the candles cease to flicker, the grandmother wears a light blue blouse adorned with floral patterns, several happy friends and family sitting at the table can be seen celebrating, out of focus. The scene is beautifully captured, cinematic, showing a 3/4 view of the grandmother and the dining room. Warm color tones and soft lighting enhance the mood.",
        "A wizard wearing a pointed hat and a blue robe with white stars casting a spell that shoots lightning from his hand and holding an old tome in his other hand.",
        "A young man at his 20s is sitting on a piece of cloud in the sky, reading a book.",
        "Cinematic trailer for a group of samoyed puppies learning to become chefs.",
        "Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
        "A cyborg koala dj in front of aturntable, in heavy raining futuristic tokyo rooftop cyberpunk night, sci-f, fantasy, intricate, neon light, soft light smooth, sharp focus, illustration.",
    ]
    EXAMPLES = [[prompt, *_EXAMPLE_SETTINGS] for prompt in _EXAMPLE_PROMPTS]

    # "lazy" caching: example outputs are not all generated up front at startup.
    examples = gr.Examples(
        examples=EXAMPLES,
        fn=gen_video,
        inputs=gen_inputs,
        outputs=[output],
        cache_examples="lazy",
    )

    run.click(gen_video, gen_inputs, [output])

demo.launch(debug=False, share=True)