|
import os |
|
import time |
|
import random |
|
import datetime |
|
import os.path as osp |
|
from functools import partial |
|
|
|
import tqdm |
|
from omegaconf import OmegaConf |
|
|
|
import torch |
|
import gradio as gr |
|
|
|
from mld.config import get_module_config |
|
from mld.data.get_data import get_dataset |
|
from mld.models.modeltype.mld import MLD |
|
from mld.utils.utils import set_seed |
|
from mld.data.humanml.utils.plot_script import plot_3d_motion |
|
|
|
# Turn off HuggingFace tokenizers' internal parallelism. NOTE(review): presumably
# to avoid the tokenizers fork warning / deadlock once the server forks — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Header HTML shown at the top of the demo page: paper title, authors and
# affiliations.
WEBSITE = """

<div class="embed_hidden">

<h1 style='text-align: center'> MotionLCM: Real-time Controllable Motion Generation via Latent Consistency Model </h1>

<h2 style='text-align: center'>

<a href="https://github.com/Dai-Wenxun/" target="_blank"><nobr>Wenxun Dai</nobr><sup>1</sup></a>  

<a href="https://lhchen.top/" target="_blank"><nobr>Ling-Hao Chen</nobr></a><sup>1</sup>  

<a href="https://wangjingbo1219.github.io/" target="_blank"><nobr>Jingbo Wang</nobr></a><sup>2</sup>  

<a href="https://moonsliu.github.io/" target="_blank"><nobr>Jinpeng Liu</nobr></a><sup>1</sup>  

<a href="https://daibo.info/" target="_blank"><nobr>Bo Dai</nobr></a><sup>2</sup>  

<a href="https://andytang15.github.io/" target="_blank"><nobr>Yansong Tang</nobr></a><sup>1</sup>

</h2>

<h2 style='text-align: center'>

<nobr><sup>1</sup>Tsinghua University</nobr>  

<nobr><sup>2</sup>Shanghai AI Laboratory</nobr>

</h2>

</div>

"""

# Footer HTML crediting the TMR and MoMask Spaces this demo was adapted from.
WEBSITE_bottom = """

<div class="embed_hidden">

<p>

Space adapted from <a href="https://huggingface.co/spaces/Mathux/TMR" target="_blank">TMR</a>

and <a href="https://huggingface.co/spaces/MeYourHint/MoMask" target="_blank">MoMask</a>.

</p>

</div>

"""

# Example prompts listed in the gr.Examples panel; selecting one fills the
# "Text prompt" textbox.
EXAMPLES = [
    "a person does a jump",
    "a person waves both arms in the air.",
    "The person takes 4 steps backwards.",
    "this person bends forward as if to bow.",
    "The person was pushed but did not fall.",
    "a man walks forward in a snake like pattern.",
    "a man paces back and forth along the same line.",
    "with arms out to the sides a person walks forward",
    "A man bends down and picks something up with his right hand.",
    "The man walked forward, spun right on one foot and walked back to his original position.",
    "a person slightly bent over with right hand pressing against the air walks forward slowly"
]
|
|
|
# Fetch any missing assets before the config/model loading below needs them.
# Each entry pairs a directory with the shell script that is run (from the
# current working directory) when that directory does not exist yet.
_PREPARE_STEPS = (
    ("./experiments_t2m/", "prepare/download_pretrained_models.sh"),
    ("./deps/glove/", "prepare/download_glove.sh"),
    ("./deps/sentence-t5-large/", "prepare/prepare_t5.sh"),
    ("./deps/t2m/", "prepare/download_t2m_evaluators.sh"),
    ("./datasets/humanml3d/", "prepare/prepare_tiny_humanml3d.sh"),
)

for _target_dir, _script in _PREPARE_STEPS:
    if os.path.exists(_target_dir):
        continue
    # os.system returns the shell's exit status. Surface failures here so a
    # broken download is visible at its source instead of later as an
    # unrelated loading error. Not fatal: startup stays best-effort.
    _status = os.system(f"bash {_script}")
    if _status != 0:
        print(f"WARNING: 'bash {_script}' exited with status {_status}; "
              f"{_target_dir} may be missing or incomplete.")
|
|
|
# Prompt pre-filled in the textbox and restored by the "Clear" button.
DEFAULT_TEXT = "cheerfully walking forward with each step."

# Result grid: MAX_VIDEOS samples laid out as NUM_ROWS rows of NUM_COLS players.
MAX_VIDEOS = 8
NUM_ROWS = 2
NUM_COLS = MAX_VIDEOS // NUM_ROWS

EXAMPLES_PER_PAGE = 12
T2M_CFG = "./configs/mld_t2m.yaml"

# "Inference steps" radio choice -> value assigned to
# model.cfg.model.scheduler.num_inference_steps before generating.
step_map = {
    1: 10,
    2: 25,
    4: 50,
}

# Run on the GPU whenever one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: ", device)
|
|
|
# Load the text-to-motion config, then merge in the sub-config that
# get_module_config returns for cfg.model.target (looked up with the config
# file's own directory as root).
cfg = OmegaConf.load(T2M_CFG)
cfg_root = osp.dirname(T2M_CFG)
cfg_model = get_module_config(cfg.model, cfg.model.target, cfg_root)
cfg = OmegaConf.merge(cfg, cfg_model)
set_seed(cfg.SEED_VALUE)

# Each launch writes into a fresh "<TEST_FOLDER>/<NAME>/<timestamp>" folder;
# rendered videos go to its "samples" sub-folder. exist_ok=False turns a
# clash with an existing folder into an error.
_timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
name_time_str = osp.join(cfg.NAME, _timestamp)
cfg.output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)
vis_dir = osp.join(cfg.output_dir, 'samples')
for _new_dir in (cfg.output_dir, vis_dir):
    os.makedirs(_new_dir, exist_ok=False)
|
|
|
# Read the checkpoint's weights onto the CPU (map_location) regardless of device.
_ckpt_path = cfg.TEST.CHECKPOINTS
state_dict = torch.load(_ckpt_path, map_location="cpu")["state_dict"]
print(f"Loading checkpoints from {_ckpt_path}")

# A checkpoint counts as LCM when it contains the denoiser's cond_proj weight.
# That weight's second dimension is copied into the denoiser config
# (time_cond_proj_dim) so the model built from cfg matches the checkpoint.
lcm_key = 'denoiser.time_embedding.cond_proj.weight'
is_lcm = lcm_key in state_dict
if is_lcm:
    time_cond_proj_dim = state_dict[lcm_key].shape[1]
    cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim
print(f'Is LCM: {is_lcm}')
|
|
|
# Build the MLD model for the configured dataset, move it to the target
# device, put it in inference mode with gradients disabled, then load the
# checkpoint weights.
dataset = get_dataset(cfg)
model = MLD(cfg, dataset)
model.to(device)
model.eval()
model.requires_grad_(False)
model.load_state_dict(state_dict)

# Frame rate of the active dataset: for DATASET.NAME "humanml3d" this reads
# cfg.DATASET.HUMANML3D.FRAME_RATE. Look the section up by key instead of
# eval()-ing Python source assembled from a config value.
FPS = cfg.DATASET[cfg.DATASET.NAME.upper()].FRAME_RATE
|
|
|
|
|
@torch.no_grad()
def generate(text_, motion_len_):
    """Sample MAX_VIDEOS motions for one prompt and render each to an MP4.

    Args:
        text_: Text prompt, repeated once per sample in the batch.
        motion_len_: Motion length in frames, shared by every sample.

    Returns:
        ``(paths, runtime_info)``: the list of written video paths (under
        ``vis_dir``, one per returned motion) and a human-readable string with
        the inference / drawing timings and the device.
    """
    batch = {"text": [text_] * MAX_VIDEOS, "length": [motion_len_] * MAX_VIDEOS}

    start = time.perf_counter()
    # NOTE(review): joints presumably indexed by sample along dim 0,
    # e.g. (MAX_VIDEOS, frames, n_joints, 3) — confirm against MLD.forward.
    joints = model(batch)[0]
    if device.type == "cuda":
        # CUDA kernels run asynchronously: wait for them so the measured time
        # covers the sampling itself, not just the kernel launches.
        torch.cuda.synchronize()
    runtime_infer = round(time.perf_counter() - start, 3)

    start = time.perf_counter()
    path = []
    for sample in tqdm.tqdm(joints):
        uid = random.randrange(999999999)
        video_path = osp.join(vis_dir, f"sample_{uid}.mp4")
        plot_3d_motion(video_path, sample.detach().cpu().numpy(), '', fps=FPS)
        path.append(video_path)
    runtime_draw = round(time.perf_counter() - start, 3)

    runtime_info = f'Inference {len(joints)} motions, Runtime (Inference): {runtime_infer}s, ' \
                   f'Runtime (Draw Skeleton): {runtime_draw}s, device: {device} '

    return path, runtime_info
|
|
|
|
|
def generate_component(generate_function, text_, motion_len_, num_inference_steps_, guidance_scale_):
    """Gradio callback: validate the UI inputs, configure the model, generate.

    Args:
        generate_function: Callable ``(text, num_frames) -> (paths, info)``.
        text_: Prompt from the textbox.
        motion_len_: Requested motion duration in seconds (slider value).
        num_inference_steps_: "Inference steps" radio choice; a key of step_map.
        guidance_scale_: Classifier-free guidance scale assigned to the model.

    Returns:
        MAX_VIDEOS video paths (padded with ``None``) followed by the info
        string; without a usable prompt, ``MAX_VIDEOS`` ``None``s plus a hint.
    """
    # Reject missing, empty and whitespace-only prompts.
    if text_ is None or not text_.strip():
        return [None] * MAX_VIDEOS + ["Please modify the text prompt."]

    # int() so a radio value delivered as 4.0 or "4" still finds its key.
    model.cfg.model.scheduler.num_inference_steps = step_map[int(num_inference_steps_)]
    model.guidance_scale = guidance_scale_

    # Seconds -> frames, clamped to [36, 196]. round() instead of int(): a
    # float seconds value just below a grid point (e.g. 2.1999999999999997)
    # yields a product like 43.99999999999999, which truncation turns into 43.
    motion_len_ = max(36, min(round(float(motion_len_) * FPS), 196))

    paths, info = generate_function(text_, motion_len_)
    paths = paths + [None] * (MAX_VIDEOS - len(paths))
    return paths + [info]
|
|
|
|
|
theme = gr.themes.Default(primary_hue="purple", secondary_hue="gray")

# Callback with the concrete generator bound, so Gradio only passes UI inputs
# (text, motion length, inference steps, guidance scale).
generate_and_show = partial(generate_component, generate)

with gr.Blocks(theme=theme) as demo:
    gr.HTML(WEBSITE)
    videos = []

    with gr.Row():
        # Left: prompt, generation controls, buttons and runtime info.
        with gr.Column(scale=3):
            text = gr.Textbox(
                show_label=True,
                label="Text prompt",
                value=DEFAULT_TEXT,
            )

            with gr.Row():
                with gr.Column(scale=2):
                    motion_len = gr.Slider(
                        minimum=1.8,
                        maximum=9.8,
                        step=0.2,
                        value=5.0,
                        label="Motion length",
                        info="Motion duration in seconds: [1.8s, 9.8s] (FPS = 20)."
                    )

                with gr.Column(scale=1):
                    num_inference_steps = gr.Radio(
                        [1, 2, 4],
                        label="Inference steps",
                        value=4,
                        info="Number of inference steps.",
                    )

            # Named guidance_scale rather than cfg so the module-level
            # OmegaConf config `cfg` is not rebound to a UI component.
            guidance_scale = gr.Slider(
                minimum=1,
                maximum=15,
                step=0.5,
                value=7.5,
                label="CFG",
                info="Classifier-free diffusion guidance.",
            )

            gen_btn = gr.Button("Generate", variant="primary")
            clear = gr.Button("Clear", variant="secondary")

            results = gr.Textbox(show_label=True,
                                 label='Inference info (runtime and device)',
                                 info='Real-time inference cannot be achieved using the free CPU. Local GPU deployment is recommended.',
                                 interactive=False)

        # Right: clickable example prompts that fill the textbox.
        with gr.Column(scale=2):
            examples = gr.Examples(
                examples=EXAMPLES,
                inputs=[text],
                examples_per_page=EXAMPLES_PER_PAGE)

    # NUM_ROWS x NUM_COLS grid of players, one per generated sample.
    for _row in range(NUM_ROWS):
        with gr.Row():
            for _col in range(NUM_COLS):
                videos.append(gr.Video(autoplay=True, loop=True))

    # Attribution footer (TMR / MoMask credits).
    gr.HTML(WEBSITE_bottom)

    # Both the button and pressing Enter in the textbox trigger generation.
    generation_inputs = [text, motion_len, num_inference_steps, guidance_scale]
    generation_outputs = videos + [results]
    gen_btn.click(
        fn=generate_and_show,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )
    text.submit(
        fn=generate_and_show,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )

    def clear_videos():
        """Empty every video player, restore DEFAULT_TEXT and blank the info box."""
        return [None] * MAX_VIDEOS + [DEFAULT_TEXT] + [None]

    clear.click(fn=clear_videos, outputs=videos + [text] + [results])

demo.launch()
|
|