|
import spaces |
|
import os |
|
import random |
|
from datetime import datetime |
|
from typing import Optional |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from diffusers import ( |
|
AnimateDiffPipeline, |
|
DiffusionPipeline, |
|
LCMScheduler, |
|
MotionAdapter, |
|
) |
|
from diffusers.utils import export_to_video |
|
from peft import PeftModel |
|
|
|
# Target device every pipeline is moved to after it has been loaded.
device = "cuda"

# Hugging Face Hub repository holding the MCM LoRA checkpoints; each MCM
# variant is loaded from its own subfolder of this repo.
mcm_id = "yhzhai/mcm"

# Generated videos are written to <cwd>/samples/Gradio-<timestamp>, where the
# timestamp is taken once, when this module is imported.
basedir = os.getcwd()
savedir = os.path.join(
    basedir,
    "samples",
    f"Gradio-{datetime.now():%Y-%m-%dT%H-%M-%S}",
)

# Inclusive upper bound for seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
|
def get_modelscope_pipeline(
    mcm_variant: Optional[str] = "WebVid",
):
    """Build the ModelScope text-to-video pipeline with an MCM LoRA merged in.

    Args:
        mcm_variant: Which MCM checkpoint (a subfolder of ``mcm_id``) to merge
            into the UNet: "WebVid", "LAION-aes", "Anime", "Realistic" or
            "3D Cartoon". Any other value, including None, falls back to the
            LAION-aes checkpoint.

    Returns:
        The pipeline with an ``LCMScheduler``, VAE slicing enabled, the LoRA
        weights merged into its UNet, and all modules moved to ``device``.
    """
    model_id = "ali-vilab/text-to-video-ms-1.7b"

    pipe = DiffusionPipeline.from_pretrained(model_id)
    # Replace the default scheduler with an LCMScheduler initialised from the
    # base model's own scheduler config.
    pipe.scheduler = LCMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
        timestep_scaling=4.0,
    )
    pipe.enable_vae_slicing()

    # Subfolder of `mcm_id` that holds each variant's LoRA checkpoint.
    subfolders = {
        "WebVid": "modelscopet2v-webvid",
        "LAION-aes": "modelscopet2v-laion",
        "Anime": "modelscopet2v-anime",
        "Realistic": "modelscopet2v-real",
        "3D Cartoon": "modelscopet2v-3d-cartoon",
    }
    subfolder = subfolders.get(mcm_variant, "modelscopet2v-laion")

    lora = PeftModel.from_pretrained(
        pipe.unet,
        model_id=mcm_id,
        subfolder=subfolder,
        adapter_name="lora",
        torch_device="cpu",
    )
    # merge_and_unload() folds the LoRA weights into the base UNet and returns
    # that plain model; install it rather than the PeftModel wrapper.
    pipe.unet = lora.merge_and_unload()

    return pipe.to(device)
|
|
|
|
|
def get_animatediff_pipeline(
    real_variant: Optional[str] = "realvision",
    motion_module_path: str = "guoyww/animatediff-motion-adapter-v1-5-2",
    mcm_variant: Optional[str] = "WebVid",
):
    """Build an AnimateDiff pipeline with an MCM LoRA merged into its UNet.

    Args:
        real_variant: Image base model. None selects Stable Diffusion 1.5,
            "epicrealism" selects epiCRealism, "realvision" selects
            Realistic Vision V6.0.
        motion_module_path: Hub id or path of the ``MotionAdapter`` to load.
        mcm_variant: Which MCM checkpoint (a subfolder of ``mcm_id``) to merge
            into the UNet: "WebVid" or "LAION-aes". Any other value, including
            None, falls back to the LAION-aes checkpoint.

    Returns:
        The pipeline with an ``LCMScheduler``, VAE slicing enabled, the LoRA
        weights merged into its UNet, and all modules moved to ``device``.

    Raises:
        ValueError: If ``real_variant`` is not one of the values above.
    """
    if real_variant is None:
        # NOTE: the original runwayml/stable-diffusion-v1-5 Hub repo is no
        # longer available; this is its maintained mirror with the same weights.
        model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
    elif real_variant == "epicrealism":
        model_id = "emilianJR/epiCRealism"
    elif real_variant == "realvision":
        model_id = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
    else:
        raise ValueError(f"Unknown real_variant {real_variant}")

    adapter = MotionAdapter.from_pretrained(motion_module_path)
    pipe = AnimateDiffPipeline.from_pretrained(
        model_id,
        motion_adapter=adapter,
    )
    # LCMScheduler built from the base model's scheduler config, with the
    # beta schedule / spacing settings overridden explicitly.
    pipe.scheduler = LCMScheduler.from_pretrained(
        model_id,
        subfolder="scheduler",
        timestep_scaling=4.0,
        clip_sample=False,
        timestep_spacing="linspace",
        beta_schedule="linear",
        beta_start=0.00085,
        beta_end=0.012,
        steps_offset=1,
    )
    pipe.enable_vae_slicing()

    # Subfolder of `mcm_id` that holds each variant's LoRA checkpoint.
    subfolders = {
        "WebVid": "animatediff-webvid",
        "LAION-aes": "animatediff-laion",
    }
    subfolder = subfolders.get(mcm_variant, "animatediff-laion")

    lora = PeftModel.from_pretrained(
        pipe.unet,
        model_id=mcm_id,
        subfolder=subfolder,
        adapter_name="lora",
        torch_device="cpu",
    )
    # merge_and_unload() folds the LoRA weights into the base UNet and returns
    # that plain model; install it rather than the PeftModel wrapper.
    pipe.unet = lora.merge_and_unload()

    return pipe.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Single-slot cache used by `infer`: the (base_model, variant) pair that the
# stored pipeline was built for, plus the pipeline itself. All start as None.
cache_pipeline = dict.fromkeys(("base_model", "variant", "pipeline"))
|
|
|
|
|
@spaces.GPU
def infer(
    base_model, variant, prompt, seed=0, randomize_seed=True, num_inference_steps=6
):
    """Generate a 16-frame video for ``prompt`` and return the saved .mp4 path.

    The pipeline for the requested (base_model, variant) pair is built on
    first use and kept in the single-slot ``cache_pipeline``. Requesting a
    different pair evicts the cached pipeline before the new one is built.

    Args:
        base_model: "ModelScope T2V", "AnimateDiff (SD1.5)",
            "AnimateDiff (RealisticVision)" or "AnimateDiff (epiCRealism)".
        variant: MCM variant, forwarded as ``mcm_variant`` to the builder.
        prompt: Text prompt.
        seed: Seed for the CPU ``torch.Generator``; ignored when
            ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed from ``[0, MAX_SEED]`` instead.
        num_inference_steps: Number of sampling steps.

    Returns:
        Path of the exported video inside ``savedir``; the file name contains
        the base model, variant and the seed actually used.

    Raises:
        ValueError: If ``base_model`` is not one of the names above.
    """
    import gc

    # real_variant passed to get_animatediff_pipeline for each AnimateDiff model.
    animatediff_real_variants = {
        "AnimateDiff (SD1.5)": None,
        "AnimateDiff (RealisticVision)": "realvision",
        "AnimateDiff (epiCRealism)": "epicrealism",
    }

    cache_hit = (
        cache_pipeline["pipeline"] is not None
        and cache_pipeline["base_model"] == base_model
        and cache_pipeline["variant"] == variant
    )
    if not cache_hit:
        # Validate first so an unknown name never evicts a usable pipeline.
        if base_model != "ModelScope T2V" and base_model not in animatediff_real_variants:
            raise ValueError(f"Unknown base_model {base_model}")

        # Drop the previously cached pipeline before building the next one so
        # its weights can be freed instead of coexisting with the new model.
        cache_pipeline.update(base_model=None, variant=None, pipeline=None)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if base_model == "ModelScope T2V":
            pipeline = get_modelscope_pipeline(mcm_variant=variant)
        else:
            pipeline = get_animatediff_pipeline(
                real_variant=animatediff_real_variants[base_model],
                motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
                mcm_variant=variant,
            )

        cache_pipeline.update(
            base_model=base_model, variant=variant, pipeline=pipeline
        )

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; the generator needs an int and the
    # file name should not read "123.0".
    seed = int(seed)
    generator = torch.Generator("cpu").manual_seed(seed)

    frames = cache_pipeline["pipeline"](
        prompt=prompt,
        num_frames=16,
        guidance_scale=1.0,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
    ).frames
    # `frames` is indexed batch-first whether it is a list, ndarray or tensor,
    # so [0] selects the first generated video in every case.
    video = frames[0]

    os.makedirs(savedir, exist_ok=True)
    save_path = os.path.join(
        savedir, f"sample_{base_model}_{variant}_{seed}.mp4".replace(" ", "_")
    )
    export_to_video(
        video,
        save_path,
        fps=7,
    )
    print(f"Saved to {save_path}")
    return save_path
|
|
|
|
|
# Example rows for gr.Examples, each as [base model, MCM variant, prompt].
# Rows are written as tuples for compactness and converted to lists, the
# list-of-lists shape gr.Examples expects.
examples = [
    list(example)
    for example in (
        (
            "ModelScope T2V",
            "LAION-aes",
            "Aerial uhd 4k view. mid-air flight over fresh and clean mountain river at sunny summer morning. Green trees and sun rays on horizon. Direct on sun.",
        ),
        ("ModelScope T2V", "Anime", "Timelapse misty mountain landscape"),
        (
            "ModelScope T2V",
            "WebVid",
            "Back of woman in shorts going near pure creek in beautiful mountains.",
        ),
        (
            "ModelScope T2V",
            "3D Cartoon",
            "A rotating pandoro (a traditional italian sweet yeast bread, most popular around christmas and new year) being eaten in time-lapse.",
        ),
        (
            "ModelScope T2V",
            "Realistic",
            "Slow motion avocado with a stone falls and breaks into 2 parts with splashes",
        ),
        (
            "AnimateDiff (SD1.5)",
            "LAION-aes",
            "Slow motion of delicious salmon sachimi set with green vegetables leaves served on wood plate. make homemade japanese food at home.-dan",
        ),
        (
            "AnimateDiff (SD1.5)",
            "WebVid",
            "Blooming meadow panorama zoom-out shot heavenly clouds and upcoming thunderstorm in mountain range harz, germany.",
        ),
        (
            "AnimateDiff (RealisticVision)",
            "LAION-aes",
            "A young woman in a yellow sweater uses vr glasses, sitting on the shore of a pond on a background of dark waves. a strong wind develops her hair, the sun's rays are reflected from the water.",
        ),
        (
            "AnimateDiff (RealisticVision)",
            "LAION-aes",
            "Female running at sunset. healthy fitness concept",
        ),
    )
]
|
|
|
# Custom CSS passed to gr.Blocks: horizontally centers the element with
# id "col-container" (the main gr.Column of the UI).
css = """
#col-container {
margin: 0 auto;
}
"""
|
|
|
# MCM variants selectable for each base model. ModelScope has five
# checkpoints; every AnimateDiff base model shares the same two, though each
# entry gets its own list object.
variants = {
    "ModelScope T2V": ["WebVid", "LAION-aes", "Anime", "Realistic", "3D Cartoon"],
    **{
        animatediff_name: ["WebVid", "LAION-aes"]
        for animatediff_name in (
            "AnimateDiff (SD1.5)",
            "AnimateDiff (RealisticVision)",
            "AnimateDiff (epiCRealism)",
        )
    },
}
|
|
|
|
|
def update_variant(rs):
    """Return a dropdown update listing the MCM variants for base model ``rs``.

    The current selection is cleared (value=None) so the user picks a variant
    that belongs to the newly selected base model.
    """
    allowed_variants = variants[rs]
    return gr.update(choices=allowed_variants, value=None)
|
|
|
|
|
# Build the Gradio UI. Components are laid out in the order they are created
# inside each `with` context, so the statements below are order-sensitive.
with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        # Title banner with links to the project page, paper and checkpoints.
        gr.HTML(
            """
<div style="text-align: center; margin-bottom: 20px;">
<h1 align="center">
<a href="https://yhzhai.github.io/mcm/"><b>Motion Consistency Model: Accelerating Video Diffusion with Disentangled Motion-Appearance Distillation</b></a>
</h1>
<h4>Our motion consistency model not only accelerates text2video diffusion model sampling process, but also can benefit from an additional high-quality image dataset to improve the frame quality of generated videos.</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href='https://yhzhai.github.io/mcm/'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
<a href='https://arxiv.org/abs/2406.06890'><img src='https://img.shields.io/badge/Paper-arXiv-red'></a>
<a href='https://huggingface.co/yhzhai/mcm'><img src='https://img.shields.io/badge/HF-checkpoint-yellow'></a>
</div>
</div>
"""
        )

        # Shows the module-level `device` string the pipelines are moved to.
        gr.Markdown(
            f"""
<p align="center"> Currently running on {device}.</p>
"""
        )
        with gr.Row():
            # NOTE(review): these choices duplicate the keys of `variants` and
            # the names dispatched in `infer`; keep all three in sync.
            base_model = gr.Dropdown(
                label="Base model",
                choices=[
                    "ModelScope T2V",
                    "AnimateDiff (SD1.5)",
                    "AnimateDiff (RealisticVision)",
                    "AnimateDiff (epiCRealism)",
                ],
                value="ModelScope T2V",
                interactive=True,
            )
            # Starts with no selection; infer's builders fall back to the
            # LAION-aes checkpoint when the variant is None.
            variant_dropdown = gr.Dropdown(
                variants["ModelScope T2V"],
                label="MCM Variant",
                interactive=True,
                value=None,
            )
        # Changing the base model replaces the variant choices and clears the
        # current variant selection (see update_variant).
        base_model.change(
            update_variant, inputs=[base_model], outputs=[variant_dropdown]
        )

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0)

        with gr.Row():
            with gr.Column():
                with gr.Accordion("Advanced Settings", open=True):

                    # NOTE(review): when "Randomize seed" is checked the seed
                    # actually used only appears in the output file name; this
                    # slider is not updated with it.
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )

                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    with gr.Row():
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=16,
                            step=1,
                            value=4,
                        )

            with gr.Column():

                result = gr.Video(
                    label="Result", show_label=False, interactive=False, autoplay=True, height=512, width=512,
                )

        # With cache_examples=True the examples are rendered by calling infer
        # with only (base_model, variant, prompt). The remaining arguments use
        # infer's defaults: a random seed and 6 inference steps.
        # NOTE(review): the steps slider defaults to 4, so cached examples use
        # a different step count than a default interactive run.
        gr.Examples(
            examples=examples,
            inputs=[base_model, variant_dropdown, prompt],
            cache_examples=True,
            fn=infer,
            outputs=[result],
        )

    run_button.click(
        fn=infer,
        inputs=[
            base_model,
            variant_dropdown,
            prompt,
            seed,
            randomize_seed,
            num_inference_steps,
        ],
        outputs=[result],
    )

# Launches the app as soon as this module is executed/imported.
demo.queue().launch()
|
|