# mcm / app.py (by yhzhai)
# Revision e61b61b: "format"
import spaces
import os
import random
from datetime import datetime
from typing import Optional
import gradio as gr
import numpy as np
import torch
from diffusers import (
AnimateDiffPipeline,
DiffusionPipeline,
LCMScheduler,
MotionAdapter,
)
from diffusers.utils import export_to_video
from peft import PeftModel
# Device every pipeline is moved to once its MCM LoRA has been merged.
device = "cuda"
# Hub repository holding the MCM LoRA weights; each variant lives in a subfolder.
mcm_id = "yhzhai/mcm"

# Videos are written to <cwd>/samples/Gradio-<timestamp taken at import time>/.
basedir = os.getcwd()
_launch_stamp = f"Gradio-{datetime.now():%Y-%m-%dT%H-%M-%S}"
savedir = os.path.join(basedir, "samples", _launch_stamp)

# Upper bound for seeds drawn or entered in the UI (int32 max).
MAX_SEED = np.iinfo("int32").max
def get_modelscope_pipeline(
    mcm_variant: Optional[str] = "WebVid",
):
    """Build a ModelScope T2V pipeline with the requested MCM LoRA merged in.

    The base model is loaded in fp16. Its scheduler is replaced by an
    ``LCMScheduler`` (``timestep_scaling=4.0``) and VAE slicing is enabled.
    The MCM LoRA for ``mcm_variant`` is then merged into the UNet, and the
    pipeline is moved to ``device``.

    Args:
        mcm_variant: One of "WebVid", "LAION-aes", "Anime", "Realistic" or
            "3D Cartoon". Any other value, including ``None``, falls back to
            the LAION-aes weights.

    Returns:
        The diffusers pipeline, placed on ``device``.
    """
    model_id = "ali-vilab/text-to-video-ms-1.7b"
    # MCM variant label -> subfolder of the `mcm_id` repo with its LoRA.
    subfolders = {
        "WebVid": "modelscopet2v-webvid",
        "LAION-aes": "modelscopet2v-laion",
        "Anime": "modelscopet2v-anime",
        "Realistic": "modelscopet2v-real",
        "3D Cartoon": "modelscopet2v-3d-cartoon",
    }
    subfolder = subfolders.get(mcm_variant, "modelscopet2v-laion")

    pipe = DiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16, variant="fp16"
    )
    pipe.scheduler = LCMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
        timestep_scaling=4.0,
    )
    pipe.enable_vae_slicing()

    lora = PeftModel.from_pretrained(
        pipe.unet,
        model_id=mcm_id,
        subfolder=subfolder,
        adapter_name="lora",
        torch_device="cpu",
    )
    # merge_and_unload() folds the LoRA deltas into the base weights and
    # returns the plain UNet. Keep that return value so the pipeline holds an
    # ordinary UNet rather than the PeftModel wrapper around it.
    pipe.unet = lora.merge_and_unload()
    return pipe.to(device)
def get_animatediff_pipeline(
    real_variant: Optional[str] = "realvision",
    motion_module_path: str = "guoyww/animatediff-motion-adapter-v1-5-2",
    mcm_variant: Optional[str] = "WebVid",
):
    """Build an AnimateDiff pipeline with the requested MCM LoRA merged in.

    An SD1.5-family image model is loaded together with the motion adapter at
    ``motion_module_path``, both in fp16. The scheduler is replaced by a
    linear-beta ``LCMScheduler`` and VAE slicing is enabled. The MCM LoRA is
    then merged into the UNet, and the pipeline is moved to ``device``.

    Args:
        real_variant: Which base image model to use. ``None`` selects
            Stable Diffusion 1.5, "epicrealism" selects epiCRealism, and
            "realvision" selects Realistic Vision V6.0.
        motion_module_path: Hub id or path of the ``MotionAdapter`` weights.
        mcm_variant: "WebVid" or "LAION-aes". Any other value, including
            ``None``, falls back to the LAION-aes weights.

    Returns:
        The diffusers pipeline, placed on ``device``.

    Raises:
        ValueError: If ``real_variant`` is not one of the values above.
    """
    base_model_ids = {
        None: "runwayml/stable-diffusion-v1-5",
        "epicrealism": "emilianJR/epiCRealism",
        "realvision": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
    }
    if real_variant not in base_model_ids:
        raise ValueError(f"Unknown real_variant {real_variant}")
    model_id = base_model_ids[real_variant]

    # MCM variant label -> subfolder of the `mcm_id` repo with its LoRA.
    subfolder = {
        "WebVid": "animatediff-webvid",
        "LAION-aes": "animatediff-laion",
    }.get(mcm_variant, "animatediff-laion")

    # Load the adapter in the same dtype as the rest of the pipeline so the
    # motion modules are not left in fp32 next to an fp16 UNet.
    adapter = MotionAdapter.from_pretrained(
        motion_module_path, torch_dtype=torch.float16
    )
    pipe = AnimateDiffPipeline.from_pretrained(
        model_id,
        motion_adapter=adapter,
        torch_dtype=torch.float16,
    )
    pipe.scheduler = LCMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
        timestep_scaling=4.0,
        clip_sample=False,
        timestep_spacing="linspace",
        beta_schedule="linear",
        beta_start=0.00085,
        beta_end=0.012,
        steps_offset=1,
    )
    pipe.enable_vae_slicing()

    lora = PeftModel.from_pretrained(
        pipe.unet,
        model_id=mcm_id,
        subfolder=subfolder,
        adapter_name="lora",
        torch_device="cpu",
    )
    # merge_and_unload() folds the LoRA deltas into the base weights and
    # returns the plain UNet. Keep that rather than the PeftModel wrapper.
    pipe.unet = lora.merge_and_unload()
    return pipe.to(device)
# One slot per (base model, MCM variant) pair; every slot starts as None.
# Nothing in the active code path here fills these slots; pipelines are kept
# in `cache_pipeline` below instead.
pipe_dict = {
    "ModelScope T2V": dict.fromkeys(
        ["WebVid", "LAION-aes", "Anime", "Realistic", "3D Cartoon"]
    ),
    "AnimateDiff (SD1.5)": dict.fromkeys(["WebVid", "LAION-aes"]),
    "AnimateDiff (RealisticVision)": dict.fromkeys(["WebVid", "LAION-aes"]),
    "AnimateDiff (epiCRealism)": dict.fromkeys(["WebVid", "LAION-aes"]),
}

# Single-slot cache: the most recently built pipeline plus the base model and
# variant it was built for. All three entries start as None.
cache_pipeline = dict.fromkeys(("base_model", "variant", "pipeline"))
# UI label of each AnimateDiff base model -> `real_variant` for
# get_animatediff_pipeline.
_ANIMATEDIFF_REAL_VARIANTS = {
    "AnimateDiff (SD1.5)": None,
    "AnimateDiff (RealisticVision)": "realvision",
    "AnimateDiff (epiCRealism)": "epicrealism",
}
# Motion adapter shared by every AnimateDiff base model.
_MOTION_ADAPTER_ID = "guoyww/animatediff-motion-adapter-v1-5-2"


def _pipeline_loader(base_model, variant):
    """Return a zero-argument callable that builds the requested pipeline.

    Raises:
        ValueError: If ``base_model`` is not a known UI label. The check runs
            before anything is loaded.
    """
    if base_model == "ModelScope T2V":
        return lambda: get_modelscope_pipeline(mcm_variant=variant)
    if base_model in _ANIMATEDIFF_REAL_VARIANTS:
        real_variant = _ANIMATEDIFF_REAL_VARIANTS[base_model]
        return lambda: get_animatediff_pipeline(
            real_variant=real_variant,
            motion_module_path=_MOTION_ADAPTER_ID,
            mcm_variant=variant,
        )
    raise ValueError(f"Unknown base_model {base_model}")


def _release_cached_pipeline():
    """Drop the cached pipeline and hand its CUDA memory back to the allocator."""
    import gc

    cache_pipeline["base_model"] = None
    cache_pipeline["variant"] = None
    cache_pipeline["pipeline"] = None
    gc.collect()
    torch.cuda.empty_cache()  # no-op when CUDA was never initialised


@spaces.GPU(duration=60)
def infer(
    base_model,
    variant,
    prompt,
    num_inference_steps=4,
    height=256,
    width=256,
    seed=0,
    randomize_seed=True,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a 16-frame video for ``prompt`` and save it as an MP4.

    The pipeline for ``(base_model, variant)`` is built on first use. It is
    then kept in ``cache_pipeline`` until a different pair is requested.

    Args:
        base_model: "ModelScope T2V" or one of the AnimateDiff UI labels.
        variant: MCM variant label, or None. With None the pipeline builders
            fall back to the LAION-aes weights.
        prompt: Text prompt.
        num_inference_steps: Sampling steps; coerced to int.
        height: Frame height in pixels; coerced to int.
        width: Frame width in pixels; coerced to int.
        seed: RNG seed. Ignored when ``randomize_seed`` is true.
        randomize_seed: If true, draw a fresh seed from [0, MAX_SEED].
        progress: Gradio progress tracker (created with ``track_tqdm=True``).

    Returns:
        A tuple ``(path of the saved MP4, seed actually used)``.

    Raises:
        ValueError: If ``base_model`` is not a known label.
    """
    if (cache_pipeline["base_model"], cache_pipeline["variant"]) != (
        base_model,
        variant,
    ):
        # Resolve the loader first, so an unknown base_model raises before
        # the currently cached pipeline is thrown away.
        load_pipeline = _pipeline_loader(base_model, variant)
        # Free the old pipeline *before* building the new one. Otherwise both
        # sit on the GPU at once while the new weights are moved there.
        _release_cached_pipeline()
        cache_pipeline["pipeline"] = load_pipeline()
        cache_pipeline["base_model"] = base_model
        cache_pipeline["variant"] = variant

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats. torch.Generator.manual_seed and the
    # pipelines' step/size arguments expect ints.
    seed = int(seed)
    generator = torch.Generator("cpu").manual_seed(seed)

    frames = cache_pipeline["pipeline"](
        prompt=prompt,
        num_frames=16,
        guidance_scale=1.0,
        num_inference_steps=int(num_inference_steps),
        height=int(height),
        width=int(width),
        generator=generator,
    ).frames
    # `.frames` is either a list or an array whose first axis indexes the
    # generated videos. In both cases element 0 is the first video.
    first_video = frames[0]

    os.makedirs(savedir, exist_ok=True)
    file_name = f"sample_{base_model}_{variant}_{seed}.mp4".replace(" ", "_")
    save_path = os.path.join(savedir, file_name)
    export_to_video(first_video, save_path, fps=7)
    print(f"Saved to {save_path}")
    return save_path, seed
# Default (num_inference_steps, height, width) per base model used by the
# gallery examples: ModelScope at 256x256 / 4 steps, AnimateDiff at
# 512x512 / 8 steps.
_EXAMPLE_SETTINGS = {
    "ModelScope T2V": (4, 256, 256),
    "AnimateDiff (epiCRealism)": (8, 512, 512),
}

# (base model, MCM variant, prompt) for each gallery example, in display order.
_EXAMPLE_PROMPTS = [
    (
        "ModelScope T2V",
        "LAION-aes",
        "Aerial uhd 4k view. mid-air flight over fresh and clean mountain river at sunny summer morning. Green trees and sun rays on horizon. Direct on sun.",
    ),
    (
        "ModelScope T2V",
        "Anime",
        "Timelapse misty mountain landscape",
    ),
    (
        "ModelScope T2V",
        "WebVid",
        "Back of woman in shorts going near pure creek in beautiful mountains.",
    ),
    (
        "ModelScope T2V",
        "3D Cartoon",
        "A rotating pandoro (a traditional italian sweet yeast bread, most popular around christmas and new year) being eaten in time-lapse.",
    ),
    (
        "ModelScope T2V",
        "Realistic",
        "Slow motion avocado with a stone falls and breaks into 2 parts with splashes",
    ),
    (
        "AnimateDiff (epiCRealism)",
        "LAION-aes",
        "Slow motion of delicious salmon sachimi set with green vegetables leaves served on wood plate. make homemade japanese food at home.-dan",
    ),
    (
        "AnimateDiff (epiCRealism)",
        "WebVid",
        "Blooming meadow panorama zoom-out shot heavenly clouds and upcoming thunderstorm in mountain range harz, germany.",
    ),
    (
        "AnimateDiff (epiCRealism)",
        "LAION-aes",
        "A young woman in a yellow sweater uses vr glasses, sitting on the shore of a pond on a background of dark waves. a strong wind develops her hair, the sun's rays are reflected from the water.",
    ),
    (
        "AnimateDiff (epiCRealism)",
        "LAION-aes",
        "Female running at sunset. healthy fitness concept",
    ),
]

# One row per example, in the order of the gr.Examples inputs:
# [base_model, variant, prompt, num_inference_steps, height, width].
examples = [
    [base, variant, prompt, *_EXAMPLE_SETTINGS[base]]
    for base, variant, prompt in _EXAMPLE_PROMPTS
]

# Page CSS: centre the main column.
css = "\n#col-container {\nmargin: 0 auto;\n}\n"

# MCM variants offered for each base model in the variant dropdown.
variants = {
    "ModelScope T2V": ["WebVid", "LAION-aes", "Anime", "Realistic", "3D Cartoon"],
    **{
        animatediff_label: ["WebVid", "LAION-aes"]
        for animatediff_label in (
            "AnimateDiff (SD1.5)",
            "AnimateDiff (RealisticVision)",
            "AnimateDiff (epiCRealism)",
        )
    },
}
def update_variant(rs):
    """Refresh the MCM-variant dropdown after the base model changes.

    Args:
        rs: The newly selected base-model label (a key of ``variants``).

    Returns:
        A Gradio update that replaces the dropdown's choices with the variants
        available for ``rs`` and clears the current selection.
    """
    new_choices = variants[rs]
    return gr.update(value=None, choices=new_choices)
# init_pipelines()
# Build the Gradio UI: model/variant pickers, prompt box, sampling settings,
# a video output, and clickable examples. All generation runs through `infer`.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Title, abstract and project/paper/checkpoint badges.
        gr.HTML(
            """
<div style="text-align: center; margin-bottom: 20px;">
<h1 align="center">
<a href="https://yhzhai.github.io/mcm/"><b>Motion Consistency Model: Accelerating Video Diffusion with Disentangled Motion-Appearance Distillation</b></a>
</h1>
<h4>Our motion consistency model not only accelerates text2video diffusion model sampling process, but also can benefit from an additional high-quality image dataset to improve the frame quality of generated videos.</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href='https://yhzhai.github.io/mcm/'><img src='https://img.shields.io/badge/Project-Page-Green'></a> &nbsp;
<a href='https://arxiv.org/abs/2406.06890'><img src='https://img.shields.io/badge/Paper-arXiv-red'></a> &nbsp;
<a href='https://huggingface.co/yhzhai/mcm'><img src='https://img.shields.io/badge/HF-checkpoint-yellow'></a>
</div>
</div>
"""
        )
        # Shows the configured `device` string, not a runtime probe of the hardware.
        gr.Markdown(
            f"""
<p align="center">Currently running on {device}.</p>
<p align="center">Model loading takes extra time.</p>
"""
        )
        # <p align="center">ModelScope T2V works the best for resolution 256x256, and AnimateDiff works the best for 512x512.</p>
        with gr.Row():
            base_model = gr.Dropdown(
                label="Base model",
                choices=[
                    "ModelScope T2V",
                    "AnimateDiff (SD1.5)",
                    "AnimateDiff (RealisticVision)",
                    "AnimateDiff (epiCRealism)",
                ],
                value="ModelScope T2V",
                interactive=True,
            )
            # NOTE(review): starts with no selection (value=None), and
            # update_variant resets it to None as well. An empty selection
            # reaches `infer` as variant=None, which the pipeline builders map
            # to the LAION-aes weights even though no variant is displayed.
            variant_dropdown = gr.Dropdown(
                variants["ModelScope T2V"],
                label="MCM Variant",
                interactive=True,
                value=None,
            )
            # Swap the variant choices whenever the base model changes.
            base_model.change(
                update_variant, inputs=[base_model], outputs=[variant_dropdown]
            )
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        with gr.Row():
            with gr.Column():
                with gr.Accordion("Advanced Settings", open=True):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    # When checked, `infer` ignores the seed slider and returns
                    # the seed it drew, which is written back into the slider.
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    with gr.Row():
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=12,
                            step=1,
                            value=4,
                        )
                    with gr.Group():
                        with gr.Row():
                            text_hint = gr.Textbox(
                                "Hint: ModelScope T2V works the best for resolution 256x256, and AnimateDiff works the best for resolution 512x512.",
                                interactive=False,
                                label="Hint",
                                container=False,
                            )
                        # NOTE(review): both sizes default to 512 although the
                        # default base model is ModelScope T2V, which the hint
                        # above says works best at 256x256.
                        with gr.Row():
                            height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=1024,
                                step=64,
                                value=512,
                                interactive=True,
                            )
                            width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=1024,
                                step=64,
                                value=512,
                                interactive=True,
                            )
            with gr.Column(show_progress=True):
                # result = gr.Video(label="Result", show_label=False, interactive=False, height=512, width=512, autoplay=True)
                result = gr.Video(
                    label="Result",
                    show_label=False,
                    interactive=False,
                    autoplay=True,
                    # height=512,
                    # width=512,
                )
        # Example rows fill only the first six inputs, so `infer` runs them
        # with its own defaults for seed (0) and randomize_seed (True).
        gr.Examples(
            examples=examples,
            inputs=[base_model, variant_dropdown, prompt, num_inference_steps, height, width],
            fn=infer,
            outputs=[result, seed],
        )
    # Positional order must match infer's parameter order.
    run_button.click(
        fn=infer,
        inputs=[
            base_model,
            variant_dropdown,
            prompt,
            num_inference_steps,
            height,
            width,
            seed,
            randomize_seed,
        ],
        outputs=[result, seed],
    )
# Starts the server at import time (no __main__ guard).
demo.queue().launch()