wxDai commited on
Commit
eb339cb
·
0 Parent(s):
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +12 -0
  2. LICENSE +25 -0
  3. README.md +11 -0
  4. app.py +236 -0
  5. configs/mld_t2m.yaml +104 -0
  6. configs/modules/denoiser.yaml +28 -0
  7. configs/modules/motion_vae.yaml +18 -0
  8. configs/modules/noise_optimizer.yaml +15 -0
  9. configs/modules/scheduler_ddim.yaml +14 -0
  10. configs/modules/scheduler_lcm.yaml +19 -0
  11. configs/modules/text_encoder.yaml +5 -0
  12. configs/modules/traj_encoder.yaml +17 -0
  13. configs/motionlcm_control_s.yaml +113 -0
  14. configs/motionlcm_control_t.yaml +111 -0
  15. configs/motionlcm_t2m.yaml +109 -0
  16. configs/motionlcm_t2m_clt.yaml +69 -0
  17. configs/vae.yaml +103 -0
  18. configs_v1/modules/denoiser.yaml +28 -0
  19. configs_v1/modules/motion_vae.yaml +18 -0
  20. configs_v1/modules/scheduler_lcm.yaml +11 -0
  21. configs_v1/modules/text_encoder.yaml +5 -0
  22. configs_v1/modules/traj_encoder.yaml +17 -0
  23. configs_v1/motionlcm_control_t.yaml +114 -0
  24. configs_v1/motionlcm_t2m.yaml +109 -0
  25. demo.py +196 -0
  26. fit.py +136 -0
  27. mld/__init__.py +0 -0
  28. mld/config.py +52 -0
  29. mld/data/__init__.py +0 -0
  30. mld/data/base.py +58 -0
  31. mld/data/data.py +73 -0
  32. mld/data/get_data.py +79 -0
  33. mld/data/humanml/__init__.py +0 -0
  34. mld/data/humanml/common/quaternion.py +29 -0
  35. mld/data/humanml/dataset.py +348 -0
  36. mld/data/humanml/scripts/motion_process.py +51 -0
  37. mld/data/humanml/utils/__init__.py +0 -0
  38. mld/data/humanml/utils/paramUtil.py +62 -0
  39. mld/data/humanml/utils/plot_script.py +98 -0
  40. mld/data/humanml/utils/word_vectorizer.py +82 -0
  41. mld/data/utils.py +52 -0
  42. mld/launch/__init__.py +0 -0
  43. mld/launch/blender.py +23 -0
  44. mld/models/__init__.py +0 -0
  45. mld/models/architectures/__init__.py +0 -0
  46. mld/models/architectures/dno.py +79 -0
  47. mld/models/architectures/mld_clip.py +72 -0
  48. mld/models/architectures/mld_denoiser.py +200 -0
  49. mld/models/architectures/mld_traj_encoder.py +64 -0
  50. mld/models/architectures/mld_vae.py +136 -0
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ **/*.pyc
2
+ .idea/
3
+ __pycache__/
4
+
5
+ deps/
6
+ datasets/
7
+ experiments_t2m/
8
+ experiments_t2m_test/
9
+ experiments_control/
10
+ experiments_control_test/
11
+ experiments_recons/
12
+ experiments_recons_test/
LICENSE ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright Tsinghua University and Shanghai AI Laboratory. All Rights Reserved.
2
+
3
+ License for Non-commercial Scientific Research Purposes.
4
+
5
+ For more information see <https://github.com/Dai-Wenxun/MotionLCM>.
6
+ If you use this software, please cite the corresponding publications
7
+ listed on the above website.
8
+
9
+ Permission to use, copy, modify, and distribute this software and its
10
+ documentation for educational, research, and non-profit purposes only.
11
+ Any modification based on this work must be open-source and prohibited
12
+ for commercial, pornographic, military, or surveillance use.
13
+
14
+ The authors grant you a non-exclusive, worldwide, non-transferable,
15
+ non-sublicensable, revocable, royalty-free, and limited license under
16
+ our copyright interests to reproduce, distribute, and create derivative
17
+ works of the text, videos, and codes solely for your non-commercial
18
+ research purposes.
19
+
20
+ You must retain, in the source form of any derivative works that you
21
+ distribute, all copyright, patent, trademark, and attribution notices
22
+ from the source form of this work.
23
+
24
+ For commercial uses of this software, please send email to all people
25
+ in the author list.
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MotionLCM
3
+ emoji: 🏎️💨
4
+ colorFrom: yellow
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.44.1
8
+ app_file: app.py
9
+ pinned: false
10
+ python_version: 3.10.12
11
+ ---
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import random
4
+ import datetime
5
+ import os.path as osp
6
+ from functools import partial
7
+
8
+ import tqdm
9
+ from omegaconf import OmegaConf
10
+
11
+ import torch
12
+ import gradio as gr
13
+
14
+ from mld.config import get_module_config
15
+ from mld.data.get_data import get_dataset
16
+ from mld.models.modeltype.mld import MLD
17
+ from mld.utils.utils import set_seed
18
+ from mld.data.humanml.utils.plot_script import plot_3d_motion
19
+
20
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
+
22
+ WEBSITE = """
23
+ <div class="embed_hidden">
24
+ <h1 style='text-align: center'> MotionLCM: Real-time Controllable Motion Generation via Latent Consistency Model </h1>
25
+ <h2 style='text-align: center'>
26
+ <a href="https://github.com/Dai-Wenxun/" target="_blank"><nobr>Wenxun Dai</nobr><sup>1</sup></a> &emsp;
27
+ <a href="https://lhchen.top/" target="_blank"><nobr>Ling-Hao Chen</nobr></a><sup>1</sup> &emsp;
28
+ <a href="https://wangjingbo1219.github.io/" target="_blank"><nobr>Jingbo Wang</nobr></a><sup>2</sup> &emsp;
29
+ <a href="https://moonsliu.github.io/" target="_blank"><nobr>Jinpeng Liu</nobr></a><sup>1</sup> &emsp;
30
+ <a href="https://daibo.info/" target="_blank"><nobr>Bo Dai</nobr></a><sup>2</sup> &emsp;
31
+ <a href="https://andytang15.github.io/" target="_blank"><nobr>Yansong Tang</nobr></a><sup>1</sup>
32
+ </h2>
33
+ <h2 style='text-align: center'>
34
+ <nobr><sup>1</sup>Tsinghua University</nobr> &emsp;
35
+ <nobr><sup>2</sup>Shanghai AI Laboratory</nobr>
36
+ </h2>
37
+ </div>
38
+ """
39
+
40
+ WEBSITE_bottom = """
41
+ <div class="embed_hidden">
42
+ <p>
43
+ Space adapted from <a href="https://huggingface.co/spaces/Mathux/TMR" target="_blank">TMR</a>
44
+ and <a href="https://huggingface.co/spaces/MeYourHint/MoMask" target="_blank">MoMask</a>.
45
+ </p>
46
+ </div>
47
+ """
48
+
49
+ EXAMPLES = [
50
+ "a person does a jump",
51
+ "a person waves both arms in the air.",
52
+ "The person takes 4 steps backwards.",
53
+ "this person bends forward as if to bow.",
54
+ "The person was pushed but did not fall.",
55
+ "a man walks forward in a snake like pattern.",
56
+ "a man paces back and forth along the same line.",
57
+ "with arms out to the sides a person walks forward",
58
+ "A man bends down and picks something up with his right hand.",
59
+ "The man walked forward, spun right on one foot and walked back to his original position.",
60
+ "a person slightly bent over with right hand pressing against the air walks forward slowly"
61
+ ]
62
+
63
+ if not os.path.exists("./experiments_t2m/"):
64
+ os.system("bash prepare/download_pretrained_models.sh")
65
+ if not os.path.exists('./deps/glove/'):
66
+ os.system("bash prepare/download_glove.sh")
67
+ if not os.path.exists('./deps/sentence-t5-large/'):
68
+ os.system("bash prepare/prepare_t5.sh")
69
+ if not os.path.exists('./deps/t2m/'):
70
+ os.system("bash prepare/download_t2m_evaluators.sh")
71
+ if not os.path.exists('./datasets/humanml3d/'):
72
+ os.system("bash prepare/prepare_tiny_humanml3d.sh")
73
+
74
+ DEFAULT_TEXT = "cheerfully walking forward with each step."
75
+ MAX_VIDEOS = 8
76
+ NUM_ROWS = 2
77
+ NUM_COLS = MAX_VIDEOS // NUM_ROWS
78
+ EXAMPLES_PER_PAGE = 12
79
+ T2M_CFG = "./configs/mld_t2m.yaml"
80
+ step_map = {1: 10, 2: 25, 4: 50}
81
+
82
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
83
+ print("device: ", device)
84
+
85
+ cfg = OmegaConf.load(T2M_CFG)
86
+ cfg_root = os.path.dirname(T2M_CFG)
87
+ cfg_model = get_module_config(cfg.model, cfg.model.target, cfg_root)
88
+ cfg = OmegaConf.merge(cfg, cfg_model)
89
+ set_seed(cfg.SEED_VALUE)
90
+
91
+ name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
92
+ cfg.output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)
93
+ vis_dir = osp.join(cfg.output_dir, 'samples')
94
+ os.makedirs(cfg.output_dir, exist_ok=False)
95
+ os.makedirs(vis_dir, exist_ok=False)
96
+
97
+ state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
98
+ print("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))
99
+
100
+ is_lcm = False
101
+ lcm_key = 'denoiser.time_embedding.cond_proj.weight' # unique key for CFG
102
+ if lcm_key in state_dict:
103
+ is_lcm = True
104
+ time_cond_proj_dim = state_dict[lcm_key].shape[1]
105
+ cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim
106
+ print(f'Is LCM: {is_lcm}')
107
+
108
+ dataset = get_dataset(cfg)
109
+ model = MLD(cfg, dataset)
110
+ model.to(device)
111
+ model.eval()
112
+ model.requires_grad_(False)
113
+ model.load_state_dict(state_dict)
114
+
115
+ FPS = eval(f"cfg.DATASET.{cfg.DATASET.NAME.upper()}.FRAME_RATE")
116
+
117
+
118
+ @torch.no_grad()
119
+ def generate(text_, motion_len_):
120
+ batch = {"text": [text_] * MAX_VIDEOS, "length": [motion_len_] * MAX_VIDEOS}
121
+
122
+ s = time.time()
123
+ joints = model(batch)[0]
124
+ runtime_infer = round(time.time() - s, 3)
125
+
126
+ s = time.time()
127
+ path = []
128
+ for i in tqdm.tqdm(range(len(joints))):
129
+ uid = random.randrange(999999999)
130
+ video_path = osp.join(vis_dir, f"sample_{uid}.mp4")
131
+ plot_3d_motion(video_path, joints[i].detach().cpu().numpy(), '', fps=FPS)
132
+ path.append(video_path)
133
+ runtime_draw = round(time.time() - s, 3)
134
+
135
+ runtime_info = f'Inference {len(joints)} motions, Runtime (Inference): {runtime_infer}s, ' \
136
+ f'Runtime (Draw Skeleton): {runtime_draw}s, device: {device} '
137
+
138
+ return path, runtime_info
139
+
140
+
141
+ def generate_component(generate_function, text_, motion_len_, num_inference_steps_, guidance_scale_):
142
+ if text_ == "" or text_ is None:
143
+ return [None] * MAX_VIDEOS + ["Please modify the text prompt."]
144
+
145
+ model.cfg.model.scheduler.num_inference_steps = step_map[num_inference_steps_]
146
+ model.guidance_scale = guidance_scale_
147
+ motion_len_ = max(36, min(int(float(motion_len_) * FPS), 196))
148
+ paths, info = generate_function(text_, motion_len_)
149
+ paths = paths + [None] * (MAX_VIDEOS - len(paths))
150
+ return paths + [info]
151
+
152
+
153
+ theme = gr.themes.Default(primary_hue="purple", secondary_hue="gray")
154
+ generate_and_show = partial(generate_component, generate)
155
+
156
+ with gr.Blocks(theme=theme) as demo:
157
+ gr.HTML(WEBSITE)
158
+ videos = []
159
+
160
+ with gr.Row():
161
+ with gr.Column(scale=3):
162
+ text = gr.Textbox(
163
+ show_label=True,
164
+ label="Text prompt",
165
+ value=DEFAULT_TEXT,
166
+ )
167
+
168
+ with gr.Row():
169
+ with gr.Column(scale=2):
170
+ motion_len = gr.Slider(
171
+ minimum=1.8,
172
+ maximum=9.8,
173
+ step=0.2,
174
+ value=5.0,
175
+ label="Motion length",
176
+ info="Motion duration in seconds: [1.8s, 9.8s] (FPS = 20)."
177
+ )
178
+
179
+ with gr.Column(scale=1):
180
+ num_inference_steps = gr.Radio(
181
+ [1, 2, 4],
182
+ label="Inference steps",
183
+ value=4,
184
+ info="Number of inference steps.",
185
+ )
186
+
187
+ cfg = gr.Slider(
188
+ minimum=1,
189
+ maximum=15,
190
+ step=0.5,
191
+ value=7.5,
192
+ label="CFG",
193
+ info="Classifier-free diffusion guidance.",
194
+ )
195
+
196
+ gen_btn = gr.Button("Generate", variant="primary")
197
+ clear = gr.Button("Clear", variant="secondary")
198
+
199
+ results = gr.Textbox(show_label=True,
200
+ label='Inference info (runtime and device)',
201
+ info='Real-time inference cannot be achieved using the free CPU. Local GPU deployment is recommended.',
202
+ interactive=False)
203
+
204
+ with gr.Column(scale=2):
205
+ examples = gr.Examples(
206
+ examples=EXAMPLES,
207
+ inputs=[text],
208
+ examples_per_page=EXAMPLES_PER_PAGE)
209
+
210
+ for i in range(NUM_ROWS):
211
+ with gr.Row():
212
+ for j in range(NUM_COLS):
213
+ video = gr.Video(autoplay=True, loop=True)
214
+ videos.append(video)
215
+
216
+ # gr.HTML(WEBSITE_bottom)
217
+
218
+ gen_btn.click(
219
+ fn=generate_and_show,
220
+ inputs=[text, motion_len, num_inference_steps, cfg],
221
+ outputs=videos + [results],
222
+ )
223
+ text.submit(
224
+ fn=generate_and_show,
225
+ inputs=[text, motion_len, num_inference_steps, cfg],
226
+ outputs=videos + [results],
227
+ )
228
+
229
+
230
+ def clear_videos():
231
+ return [None] * MAX_VIDEOS + [DEFAULT_TEXT] + [None]
232
+
233
+
234
+ clear.click(fn=clear_videos, outputs=videos + [text] + [results])
235
+
236
+ demo.launch()
configs/mld_t2m.yaml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FOLDER: './experiments_t2m'
2
+ TEST_FOLDER: './experiments_t2m_test'
3
+
4
+ NAME: 'mld_humanml'
5
+
6
+ SEED_VALUE: 1234
7
+
8
+ TRAIN:
9
+ BATCH_SIZE: 64
10
+ SPLIT: 'train'
11
+ NUM_WORKERS: 8
12
+ PERSISTENT_WORKERS: true
13
+
14
+ PRETRAINED: 'experiments_recons/vae_humanml/vae_humanml.ckpt'
15
+
16
+ validation_steps: -1
17
+ validation_epochs: 50
18
+ checkpointing_steps: -1
19
+ checkpointing_epochs: 50
20
+ max_train_steps: -1
21
+ max_train_epochs: 3000
22
+ learning_rate: 1e-4
23
+ lr_scheduler: "cosine"
24
+ lr_warmup_steps: 1000
25
+ adam_beta1: 0.9
26
+ adam_beta2: 0.999
27
+ adam_weight_decay: 0.0
28
+ adam_epsilon: 1e-08
29
+ max_grad_norm: 1.0
30
+ model_ema: false
31
+ model_ema_steps: 32
32
+ model_ema_decay: 0.999
33
+
34
+ VAL:
35
+ BATCH_SIZE: 32
36
+ SPLIT: 'test'
37
+ NUM_WORKERS: 12
38
+ PERSISTENT_WORKERS: true
39
+
40
+ TEST:
41
+ BATCH_SIZE: 32
42
+ SPLIT: 'test'
43
+ NUM_WORKERS: 12
44
+ PERSISTENT_WORKERS: true
45
+
46
+ CHECKPOINTS: 'experiments_t2m/mld_humanml/mld_humanml.ckpt'
47
+
48
+ # Testing Args
49
+ REPLICATION_TIMES: 20
50
+ MM_NUM_SAMPLES: 100
51
+ MM_NUM_REPEATS: 30
52
+ MM_NUM_TIMES: 10
53
+ DIVERSITY_TIMES: 300
54
+ DO_MM_TEST: true
55
+
56
+ DATASET:
57
+ NAME: 'humanml3d'
58
+ SMPL_PATH: './deps/smpl'
59
+ WORD_VERTILIZER_PATH: './deps/glove/'
60
+ HUMANML3D:
61
+ FRAME_RATE: 20.0
62
+ UNIT_LEN: 4
63
+ ROOT: './datasets/humanml3d'
64
+ CONTROL_ARGS:
65
+ CONTROL: false
66
+ TEMPORAL: false
67
+ TRAIN_JOINTS: [0]
68
+ TEST_JOINTS: [0]
69
+ TRAIN_DENSITY: 'random'
70
+ TEST_DENSITY: 100
71
+ MEAN_STD_PATH: './datasets/humanml_spatial_norm'
72
+ SAMPLER:
73
+ MAX_LEN: 200
74
+ MIN_LEN: 40
75
+ MAX_TEXT_LEN: 20
76
+ PADDING_TO_MAX: false
77
+ WINDOW_SIZE: null
78
+
79
+ METRIC:
80
+ DIST_SYNC_ON_STEP: true
81
+ TYPE: ['TM2TMetrics']
82
+
83
+ model:
84
+ target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_ddim', 'noise_optimizer']
85
+ latent_dim: [16, 32]
86
+ guidance_scale: 7.5
87
+ guidance_uncondp: 0.1
88
+
89
+ t2m_textencoder:
90
+ dim_word: 300
91
+ dim_pos_ohot: 15
92
+ dim_text_hidden: 512
93
+ dim_coemb_hidden: 512
94
+
95
+ t2m_motionencoder:
96
+ dim_move_hidden: 512
97
+ dim_move_latent: 512
98
+ dim_motion_hidden: 1024
99
+ dim_motion_latent: 512
100
+
101
+ bert_path: './deps/distilbert-base-uncased'
102
+ clip_path: './deps/clip-vit-large-patch14'
103
+ t5_path: './deps/sentence-t5-large'
104
+ t2m_path: './deps/t2m/'
configs/modules/denoiser.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ denoiser:
2
+ target: mld.models.architectures.mld_denoiser.MldDenoiser
3
+ params:
4
+ latent_dim: ${model.latent_dim}
5
+ hidden_dim: 256
6
+ text_dim: 768
7
+ time_dim: 768
8
+ ff_size: 1024
9
+ num_layers: 9
10
+ num_heads: 4
11
+ dropout: 0.1
12
+ normalize_before: false
13
+ norm_eps: 1e-5
14
+ activation: 'gelu'
15
+ norm_post: true
16
+ activation_post: null
17
+ flip_sin_to_cos: true
18
+ freq_shift: 0
19
+ time_act_fn: 'silu'
20
+ time_post_act_fn: null
21
+ position_embedding: 'learned'
22
+ arch: 'trans_enc'
23
+ add_mem_pos: true
24
+ force_pre_post_proj: true
25
+ text_act_fn: null
26
+ zero_init_cond: true
27
+ controlnet_embed_dim: 256
28
+ controlnet_act_fn: 'silu'
configs/modules/motion_vae.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ motion_vae:
2
+ target: mld.models.architectures.mld_vae.MldVae
3
+ params:
4
+ nfeats: ${DATASET.NFEATS}
5
+ latent_dim: ${model.latent_dim}
6
+ hidden_dim: 256
7
+ force_pre_post_proj: true
8
+ ff_size: 1024
9
+ num_layers: 9
10
+ num_heads: 4
11
+ dropout: 0.1
12
+ arch: 'encoder_decoder'
13
+ normalize_before: false
14
+ norm_eps: 1e-5
15
+ activation: 'gelu'
16
+ norm_post: true
17
+ activation_post: null
18
+ position_embedding: 'learned'
configs/modules/noise_optimizer.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ noise_optimizer:
2
+ target: mld.models.architectures.dno.DNO
3
+ params:
4
+ optimize: false
5
+ max_train_steps: 400
6
+ learning_rate: 0.1
7
+ lr_scheduler: 'cosine'
8
+ lr_warmup_steps: 50
9
+ clip_grad: true
10
+ loss_hint_type: 'l2'
11
+ loss_diff_penalty: 0.000
12
+ loss_correlate_penalty: 100
13
+ visualize_samples: 0
14
+ visualize_ske_steps: []
15
+ output_dir: ${output_dir}
configs/modules/scheduler_ddim.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scheduler:
2
+ target: diffusers.DDIMScheduler
3
+ num_inference_steps: 50
4
+ eta: 0.0
5
+ params:
6
+ num_train_timesteps: 1000
7
+ beta_start: 0.00085
8
+ beta_end: 0.012
9
+ beta_schedule: 'scaled_linear'
10
+ prediction_type: 'epsilon'
11
+ clip_sample: false
12
+ # below are for ddim
13
+ set_alpha_to_one: false
14
+ steps_offset: 1
configs/modules/scheduler_lcm.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scheduler:
2
+ target: mld.models.schedulers.scheduling_lcm.LCMScheduler
3
+ num_inference_steps: 1
4
+ cfg_step_map:
5
+ 1: 8.0
6
+ 2: 12.5
7
+ 4: 13.5
8
+ params:
9
+ num_train_timesteps: 1000
10
+ beta_start: 0.00085
11
+ beta_end: 0.012
12
+ beta_schedule: 'scaled_linear'
13
+ clip_sample: false
14
+ set_alpha_to_one: false
15
+ original_inference_steps: 10
16
+ timesteps_step_map:
17
+ 1: [799]
18
+ 2: [699, 299]
19
+ 4: [699, 399, 299, 299]
configs/modules/text_encoder.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ text_encoder:
2
+ target: mld.models.architectures.mld_clip.MldTextEncoder
3
+ params:
4
+ last_hidden_state: false
5
+ modelpath: ${model.t5_path}
configs/modules/traj_encoder.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ traj_encoder:
2
+ target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder
3
+ params:
4
+ nfeats: ${DATASET.NJOINTS}
5
+ latent_dim: ${model.latent_dim}
6
+ hidden_dim: 256
7
+ force_post_proj: true
8
+ ff_size: 1024
9
+ num_layers: 9
10
+ num_heads: 4
11
+ dropout: 0.1
12
+ normalize_before: false
13
+ norm_eps: 1e-5
14
+ activation: 'gelu'
15
+ norm_post: true
16
+ activation_post: null
17
+ position_embedding: 'learned'
configs/motionlcm_control_s.yaml ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FOLDER: './experiments_control/spatial'
2
+ TEST_FOLDER: './experiments_control_test/spatial'
3
+
4
+ NAME: 'motionlcm_humanml'
5
+
6
+ SEED_VALUE: 1234
7
+
8
+ TRAIN:
9
+ DATASET: 'humanml3d'
10
+ BATCH_SIZE: 128
11
+ SPLIT: 'train'
12
+ NUM_WORKERS: 8
13
+ PERSISTENT_WORKERS: true
14
+
15
+ PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
16
+
17
+ validation_steps: -1
18
+ validation_epochs: 50
19
+ checkpointing_steps: -1
20
+ checkpointing_epochs: 50
21
+ max_train_steps: -1
22
+ max_train_epochs: 1000
23
+ learning_rate: 1e-4
24
+ learning_rate_spatial: 1e-4
25
+ lr_scheduler: "cosine"
26
+ lr_warmup_steps: 1000
27
+ adam_beta1: 0.9
28
+ adam_beta2: 0.999
29
+ adam_weight_decay: 0.0
30
+ adam_epsilon: 1e-08
31
+ max_grad_norm: 1.0
32
+
33
+ VAL:
34
+ DATASET: 'humanml3d'
35
+ BATCH_SIZE: 32
36
+ SPLIT: 'test'
37
+ NUM_WORKERS: 12
38
+ PERSISTENT_WORKERS: true
39
+
40
+ TEST:
41
+ DATASET: 'humanml3d'
42
+ BATCH_SIZE: 32
43
+ SPLIT: 'test'
44
+ NUM_WORKERS: 12
45
+ PERSISTENT_WORKERS: true
46
+
47
+ CHECKPOINTS: 'experiments_control/spatial/motionlcm_humanml/motionlcm_humanml_s_pelvis.ckpt'
48
+ # CHECKPOINTS: 'experiments_control/spatial/motionlcm_humanml/motionlcm_humanml_s_all.ckpt'
49
+
50
+ # Testing Args
51
+ REPLICATION_TIMES: 1
52
+ DIVERSITY_TIMES: 300
53
+ DO_MM_TEST: false
54
+ MAX_NUM_SAMPLES: 1024
55
+
56
+ DATASET:
57
+ NAME: 'humanml3d'
58
+ SMPL_PATH: './deps/smpl'
59
+ WORD_VERTILIZER_PATH: './deps/glove/'
60
+ HUMANML3D:
61
+ FRAME_RATE: 20.0
62
+ UNIT_LEN: 4
63
+ ROOT: './datasets/humanml3d'
64
+ CONTROL_ARGS:
65
+ CONTROL: true
66
+ TEMPORAL: false
67
+ TRAIN_JOINTS: [0]
68
+ TEST_JOINTS: [0]
69
+ TRAIN_DENSITY: 'random'
70
+ TEST_DENSITY: 100
71
+ MEAN_STD_PATH: './datasets/humanml_spatial_norm'
72
+ SAMPLER:
73
+ MAX_LEN: 200
74
+ MIN_LEN: 40
75
+ MAX_TEXT_LEN: 20
76
+ PADDING_TO_MAX: false
77
+ WINDOW_SIZE: null
78
+
79
+ METRIC:
80
+ DIST_SYNC_ON_STEP: true
81
+ TYPE: ['TM2TMetrics', 'ControlMetrics']
82
+
83
+ model:
84
+ target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'traj_encoder', 'noise_optimizer']
85
+ latent_dim: [16, 32]
86
+ guidance_scale: 'dynamic'
87
+
88
+ # ControlNet Args
89
+ is_controlnet: true
90
+ vaeloss: true
91
+ vaeloss_type: 'mask'
92
+ cond_ratio: 1.0
93
+ control_loss_func: 'l1_smooth'
94
+ use_3d: true
95
+ lcm_w_min_nax: [5, 15]
96
+ lcm_num_ddim_timesteps: 10
97
+
98
+ t2m_textencoder:
99
+ dim_word: 300
100
+ dim_pos_ohot: 15
101
+ dim_text_hidden: 512
102
+ dim_coemb_hidden: 512
103
+
104
+ t2m_motionencoder:
105
+ dim_move_hidden: 512
106
+ dim_move_latent: 512
107
+ dim_motion_hidden: 1024
108
+ dim_motion_latent: 512
109
+
110
+ bert_path: './deps/distilbert-base-uncased'
111
+ clip_path: './deps/clip-vit-large-patch14'
112
+ t5_path: './deps/sentence-t5-large'
113
+ t2m_path: './deps/t2m/'
configs/motionlcm_control_t.yaml ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FOLDER: './experiments_control/temporal'
2
+ TEST_FOLDER: './experiments_control_test/temporal'
3
+
4
+ NAME: 'motionlcm_humanml'
5
+
6
+ SEED_VALUE: 1234
7
+
8
+ TRAIN:
9
+ DATASET: 'humanml3d'
10
+ BATCH_SIZE: 128
11
+ SPLIT: 'train'
12
+ NUM_WORKERS: 8
13
+ PERSISTENT_WORKERS: true
14
+
15
+ PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
16
+
17
+ validation_steps: -1
18
+ validation_epochs: 50
19
+ checkpointing_steps: -1
20
+ checkpointing_epochs: 50
21
+ max_train_steps: -1
22
+ max_train_epochs: 1000
23
+ learning_rate: 1e-4
24
+ learning_rate_spatial: 1e-4
25
+ lr_scheduler: "cosine"
26
+ lr_warmup_steps: 1000
27
+ adam_beta1: 0.9
28
+ adam_beta2: 0.999
29
+ adam_weight_decay: 0.0
30
+ adam_epsilon: 1e-08
31
+ max_grad_norm: 1.0
32
+
33
+ VAL:
34
+ DATASET: 'humanml3d'
35
+ BATCH_SIZE: 32
36
+ SPLIT: 'test'
37
+ NUM_WORKERS: 12
38
+ PERSISTENT_WORKERS: true
39
+
40
+ TEST:
41
+ DATASET: 'humanml3d'
42
+ BATCH_SIZE: 32
43
+ SPLIT: 'test'
44
+ NUM_WORKERS: 12
45
+ PERSISTENT_WORKERS: true
46
+
47
+ CHECKPOINTS: 'experiments_control/temporal/motionlcm_humanml/motionlcm_humanml_t.ckpt'
48
+
49
+ # Testing Args
50
+ REPLICATION_TIMES: 20
51
+ DIVERSITY_TIMES: 300
52
+ DO_MM_TEST: false
53
+
54
+ DATASET:
55
+ NAME: 'humanml3d'
56
+ SMPL_PATH: './deps/smpl'
57
+ WORD_VERTILIZER_PATH: './deps/glove/'
58
+ HUMANML3D:
59
+ FRAME_RATE: 20.0
60
+ UNIT_LEN: 4
61
+ ROOT: './datasets/humanml3d'
62
+ CONTROL_ARGS:
63
+ CONTROL: true
64
+ TEMPORAL: true
65
+ TRAIN_JOINTS: [0, 10, 11, 15, 20, 21]
66
+ TEST_JOINTS: [0, 10, 11, 15, 20, 21]
67
+ TRAIN_DENSITY: [25, 25]
68
+ TEST_DENSITY: 25
69
+ MEAN_STD_PATH: './datasets/humanml_spatial_norm'
70
+ SAMPLER:
71
+ MAX_LEN: 200
72
+ MIN_LEN: 40
73
+ MAX_TEXT_LEN: 20
74
+ PADDING_TO_MAX: false
75
+ WINDOW_SIZE: null
76
+
77
+ METRIC:
78
+ DIST_SYNC_ON_STEP: true
79
+ TYPE: ['TM2TMetrics', 'ControlMetrics']
80
+
81
+ model:
82
+ target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'traj_encoder', 'noise_optimizer']
83
+ latent_dim: [16, 32]
84
+ guidance_scale: 'dynamic'
85
+
86
+ # ControlNet Args
87
+ is_controlnet: true
88
+ vaeloss: true
89
+ vaeloss_type: 'sum'
90
+ cond_ratio: 1.0
91
+ control_loss_func: 'l2'
92
+ use_3d: false
93
+ lcm_w_min_nax: [5, 15]
94
+ lcm_num_ddim_timesteps: 10
95
+
96
+ t2m_textencoder:
97
+ dim_word: 300
98
+ dim_pos_ohot: 15
99
+ dim_text_hidden: 512
100
+ dim_coemb_hidden: 512
101
+
102
+ t2m_motionencoder:
103
+ dim_move_hidden: 512
104
+ dim_move_latent: 512
105
+ dim_motion_hidden: 1024
106
+ dim_motion_latent: 512
107
+
108
+ bert_path: './deps/distilbert-base-uncased'
109
+ clip_path: './deps/clip-vit-large-patch14'
110
+ t5_path: './deps/sentence-t5-large'
111
+ t2m_path: './deps/t2m/'
configs/motionlcm_t2m.yaml ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FOLDER: './experiments_t2m'
2
+ TEST_FOLDER: './experiments_t2m_test'
3
+
4
+ NAME: 'motionlcm_humanml'
5
+
6
+ SEED_VALUE: 1234
7
+
8
+ TRAIN:
9
+ BATCH_SIZE: 128
10
+ SPLIT: 'train'
11
+ NUM_WORKERS: 8
12
+ PERSISTENT_WORKERS: true
13
+
14
+ PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml.ckpt'
15
+
16
+ validation_steps: -1
17
+ validation_epochs: 50
18
+ checkpointing_steps: -1
19
+ checkpointing_epochs: 50
20
+ max_train_steps: -1
21
+ max_train_epochs: 1000
22
+ learning_rate: 2e-4
23
+ lr_scheduler: "cosine"
24
+ lr_warmup_steps: 1000
25
+ adam_beta1: 0.9
26
+ adam_beta2: 0.999
27
+ adam_weight_decay: 0.0
28
+ adam_epsilon: 1e-08
29
+ max_grad_norm: 1.0
30
+
31
+ # Latent Consistency Distillation Specific Arguments
32
+ w_min: 5.0
33
+ w_max: 15.0
34
+ num_ddim_timesteps: 10
35
+ loss_type: 'huber'
36
+ huber_c: 0.5
37
+ unet_time_cond_proj_dim: 256
38
+ ema_decay: 0.95
39
+
40
+ VAL:
41
+ BATCH_SIZE: 32
42
+ SPLIT: 'test'
43
+ NUM_WORKERS: 12
44
+ PERSISTENT_WORKERS: true
45
+
46
+ TEST:
47
+ BATCH_SIZE: 32
48
+ SPLIT: 'test'
49
+ NUM_WORKERS: 12
50
+ PERSISTENT_WORKERS: true
51
+
52
+ CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
53
+
54
+ # Testing Args
55
+ REPLICATION_TIMES: 20
56
+ MM_NUM_SAMPLES: 100
57
+ MM_NUM_REPEATS: 30
58
+ MM_NUM_TIMES: 10
59
+ DIVERSITY_TIMES: 300
60
+ DO_MM_TEST: true
61
+
62
+ DATASET:
63
+ NAME: 'humanml3d'
64
+ SMPL_PATH: './deps/smpl'
65
+ WORD_VERTILIZER_PATH: './deps/glove/'
66
+ HUMANML3D:
67
+ FRAME_RATE: 20.0
68
+ UNIT_LEN: 4
69
+ ROOT: './datasets/humanml3d'
70
+ CONTROL_ARGS:
71
+ CONTROL: false
72
+ TEMPORAL: false
73
+ TRAIN_JOINTS: [0]
74
+ TEST_JOINTS: [0]
75
+ TRAIN_DENSITY: 'random'
76
+ TEST_DENSITY: 100
77
+ MEAN_STD_PATH: './datasets/humanml_spatial_norm'
78
+ SAMPLER:
79
+ MAX_LEN: 200
80
+ MIN_LEN: 40
81
+ MAX_TEXT_LEN: 20
82
+ PADDING_TO_MAX: false
83
+ WINDOW_SIZE: null
84
+
85
+ METRIC:
86
+ DIST_SYNC_ON_STEP: true
87
+ TYPE: ['TM2TMetrics']
88
+
89
+ model:
90
+ target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'noise_optimizer']
91
+ latent_dim: [16, 32]
92
+ guidance_scale: 'dynamic'
93
+
94
+ t2m_textencoder:
95
+ dim_word: 300
96
+ dim_pos_ohot: 15
97
+ dim_text_hidden: 512
98
+ dim_coemb_hidden: 512
99
+
100
+ t2m_motionencoder:
101
+ dim_move_hidden: 512
102
+ dim_move_latent: 512
103
+ dim_motion_hidden: 1024
104
+ dim_motion_latent: 512
105
+
106
+ bert_path: './deps/distilbert-base-uncased'
107
+ clip_path: './deps/clip-vit-large-patch14'
108
+ t5_path: './deps/sentence-t5-large'
109
+ t2m_path: './deps/t2m/'
configs/motionlcm_t2m_clt.yaml ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FOLDER: './experiments_t2m'
2
+ TEST_FOLDER: './experiments_t2m_test'
3
+
4
+ NAME: 'motionlcm_humanml'
5
+
6
+ SEED_VALUE: 1234
7
+
8
+ TEST:
9
+ BATCH_SIZE: 1
10
+ SPLIT: 'test'
11
+ NUM_WORKERS: 12
12
+ PERSISTENT_WORKERS: true
13
+
14
+ CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
15
+
16
+ # Testing Args
17
+ REPLICATION_TIMES: 1
18
+ DIVERSITY_TIMES: 300
19
+ DO_MM_TEST: false
20
+ MAX_NUM_SAMPLES: 1024
21
+
22
+ DATASET:
23
+ NAME: 'humanml3d'
24
+ SMPL_PATH: './deps/smpl'
25
+ WORD_VERTILIZER_PATH: './deps/glove/'
26
+ HUMANML3D:
27
+ FRAME_RATE: 20.0
28
+ UNIT_LEN: 4
29
+ ROOT: './datasets/humanml3d'
30
+ CONTROL_ARGS:
31
+ CONTROL: true
32
+ TEMPORAL: false
33
+ TRAIN_JOINTS: [0]
34
+ TEST_JOINTS: [0]
35
+ TRAIN_DENSITY: 'random'
36
+ TEST_DENSITY: 100
37
+ MEAN_STD_PATH: './datasets/humanml_spatial_norm'
38
+ SAMPLER:
39
+ MAX_LEN: 200
40
+ MIN_LEN: 40
41
+ MAX_TEXT_LEN: 20
42
+ PADDING_TO_MAX: false
43
+ WINDOW_SIZE: null
44
+
45
+ METRIC:
46
+ DIST_SYNC_ON_STEP: true
47
+ TYPE: ['TM2TMetrics', 'ControlMetrics']
48
+
49
+ model:
50
+ target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'noise_optimizer']
51
+ latent_dim: [16, 32]
52
+ guidance_scale: 'dynamic'
53
+
54
+ t2m_textencoder:
55
+ dim_word: 300
56
+ dim_pos_ohot: 15
57
+ dim_text_hidden: 512
58
+ dim_coemb_hidden: 512
59
+
60
+ t2m_motionencoder:
61
+ dim_move_hidden: 512
62
+ dim_move_latent: 512
63
+ dim_motion_hidden: 1024
64
+ dim_motion_latent: 512
65
+
66
+ bert_path: './deps/distilbert-base-uncased'
67
+ clip_path: './deps/clip-vit-large-patch14'
68
+ t5_path: './deps/sentence-t5-large'
69
+ t2m_path: './deps/t2m/'
configs/vae.yaml ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FOLDER: './experiments_recons'
2
+ TEST_FOLDER: './experiments_recons_test'
3
+
4
+ NAME: 'vae_humanml'
5
+
6
+ SEED_VALUE: 1234
7
+
8
+ TRAIN:
9
+ BATCH_SIZE: 128
10
+ SPLIT: 'train'
11
+ NUM_WORKERS: 8
12
+ PERSISTENT_WORKERS: true
13
+ PRETRAINED: ''
14
+
15
+ validation_steps: -1
16
+ validation_epochs: 100
17
+ checkpointing_steps: -1
18
+ checkpointing_epochs: 100
19
+ max_train_steps: -1
20
+ max_train_epochs: 6000
21
+ learning_rate: 2e-4
22
+ lr_scheduler: "cosine"
23
+ lr_warmup_steps: 1000
24
+ adam_beta1: 0.9
25
+ adam_beta2: 0.999
26
+ adam_weight_decay: 0.0
27
+ adam_epsilon: 1e-08
28
+ max_grad_norm: 1.0
29
+
30
+ VAL:
31
+ BATCH_SIZE: 32
32
+ SPLIT: 'test'
33
+ NUM_WORKERS: 12
34
+ PERSISTENT_WORKERS: true
35
+
36
+ TEST:
37
+ BATCH_SIZE: 32
38
+ SPLIT: 'test'
39
+ NUM_WORKERS: 12
40
+ PERSISTENT_WORKERS: true
41
+
42
+ CHECKPOINTS: 'experiments_recons/vae_humanml/vae_humanml.ckpt'
43
+
44
+ # Testing Args
45
+ REPLICATION_TIMES: 20
46
+ DIVERSITY_TIMES: 300
47
+ DO_MM_TEST: false
48
+
49
+ DATASET:
50
+ NAME: 'humanml3d'
51
+ SMPL_PATH: './deps/smpl'
52
+ WORD_VERTILIZER_PATH: './deps/glove/'
53
+ HUMANML3D:
54
+ FRAME_RATE: 20.0
55
+ UNIT_LEN: 4
56
+ ROOT: './datasets/humanml3d'
57
+ CONTROL_ARGS:
58
+ CONTROL: false
59
+ TEMPORAL: false
60
+ TRAIN_JOINTS: [0]
61
+ TEST_JOINTS: [0]
62
+ TRAIN_DENSITY: 'random'
63
+ TEST_DESITY: 100
64
+ MEAN_STD_PATH: './datasets/humanml_spatial_norm'
65
+ SAMPLER:
66
+ MAX_LEN: 200
67
+ MIN_LEN: 40
68
+ MAX_TEXT_LEN: 20
69
+ PADDING_TO_MAX: true
70
+ WINDOW_SIZE: 64
71
+
72
+ METRIC:
73
+ DIST_SYNC_ON_STEP: true
74
+ TYPE: ['TM2TMetrics', "PosMetrics"]
75
+
76
+ model:
77
+ target: ['motion_vae']
78
+ latent_dim: [16, 32]
79
+
80
+ # VAE Args
81
+ rec_feats_ratio: 1.0
82
+ rec_joints_ratio: 1.0
83
+ rec_velocity_ratio: 0.0
84
+ kl_ratio: 1e-4
85
+
86
+ rec_feats_loss: 'l1_smooth'
87
+ rec_joints_loss: 'l1_smooth'
88
+ rec_velocity_loss: 'l1_smooth'
89
+ mask_loss: true
90
+
91
+ t2m_textencoder:
92
+ dim_word: 300
93
+ dim_pos_ohot: 15
94
+ dim_text_hidden: 512
95
+ dim_coemb_hidden: 512
96
+
97
+ t2m_motionencoder:
98
+ dim_move_hidden: 512
99
+ dim_move_latent: 512
100
+ dim_motion_hidden: 1024
101
+ dim_motion_latent: 512
102
+
103
+ t2m_path: './deps/t2m/'
configs_v1/modules/denoiser.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ denoiser:
2
+ target: mld.models.architectures.mld_denoiser.MldDenoiser
3
+ params:
4
+ latent_dim: ${model.latent_dim}
5
+ hidden_dim: null
6
+ text_dim: 768
7
+ time_dim: 768
8
+ ff_size: 1024
9
+ num_layers: 9
10
+ num_heads: 4
11
+ dropout: 0.1
12
+ normalize_before: false
13
+ norm_eps: 1e-5
14
+ activation: 'gelu'
15
+ norm_post: true
16
+ activation_post: null
17
+ flip_sin_to_cos: true
18
+ freq_shift: 0
19
+ time_act_fn: 'silu'
20
+ time_post_act_fn: null
21
+ position_embedding: 'learned'
22
+ arch: 'trans_enc'
23
+ add_mem_pos: true
24
+ force_pre_post_proj: false
25
+ text_act_fn: 'relu'
26
+ zero_init_cond: true
27
+ controlnet_embed_dim: 256
28
+ controlnet_act_fn: null
configs_v1/modules/motion_vae.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ motion_vae:
2
+ target: mld.models.architectures.mld_vae.MldVae
3
+ params:
4
+ nfeats: ${DATASET.NFEATS}
5
+ latent_dim: ${model.latent_dim}
6
+ hidden_dim: null
7
+ force_pre_post_proj: false
8
+ ff_size: 1024
9
+ num_layers: 9
10
+ num_heads: 4
11
+ dropout: 0.1
12
+ arch: 'encoder_decoder'
13
+ normalize_before: false
14
+ norm_eps: 1e-5
15
+ activation: 'gelu'
16
+ norm_post: true
17
+ activation_post: null
18
+ position_embedding: 'learned'
configs_v1/modules/scheduler_lcm.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scheduler:
2
+ target: diffusers.LCMScheduler
3
+ num_inference_steps: 1
4
+ params:
5
+ num_train_timesteps: 1000
6
+ beta_start: 0.00085
7
+ beta_end: 0.012
8
+ beta_schedule: 'scaled_linear'
9
+ clip_sample: false
10
+ set_alpha_to_one: false
11
+ original_inference_steps: 50
configs_v1/modules/text_encoder.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ text_encoder:
2
+ target: mld.models.architectures.mld_clip.MldTextEncoder
3
+ params:
4
+ last_hidden_state: false
5
+ modelpath: ${model.t5_path}
configs_v1/modules/traj_encoder.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ traj_encoder:
2
+ target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder
3
+ params:
4
+ nfeats: ${DATASET.NJOINTS}
5
+ latent_dim: ${model.latent_dim}
6
+ hidden_dim: null
7
+ force_post_proj: false
8
+ ff_size: 1024
9
+ num_layers: 9
10
+ num_heads: 4
11
+ dropout: 0.1
12
+ normalize_before: false
13
+ norm_eps: 1e-5
14
+ activation: 'gelu'
15
+ norm_post: true
16
+ activation_post: null
17
+ position_embedding: 'learned'
configs_v1/motionlcm_control_t.yaml ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FOLDER: './experiments_control/temporal'
2
+ TEST_FOLDER: './experiments_control_test/temporal'
3
+
4
+ NAME: 'motionlcm_humanml'
5
+
6
+ SEED_VALUE: 1234
7
+
8
+ TRAIN:
9
+ DATASET: 'humanml3d'
10
+ BATCH_SIZE: 128
11
+ SPLIT: 'train'
12
+ NUM_WORKERS: 8
13
+ PERSISTENT_WORKERS: true
14
+
15
+ PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml_v1.ckpt'
16
+
17
+ validation_steps: -1
18
+ validation_epochs: 50
19
+ checkpointing_steps: -1
20
+ checkpointing_epochs: 50
21
+ max_train_steps: -1
22
+ max_train_epochs: 1000
23
+ learning_rate: 1e-4
24
+ learning_rate_spatial: 1e-4
25
+ lr_scheduler: "cosine"
26
+ lr_warmup_steps: 1000
27
+ adam_beta1: 0.9
28
+ adam_beta2: 0.999
29
+ adam_weight_decay: 0.0
30
+ adam_epsilon: 1e-08
31
+ max_grad_norm: 1.0
32
+
33
+ VAL:
34
+ DATASET: 'humanml3d'
35
+ BATCH_SIZE: 32
36
+ SPLIT: 'test'
37
+ NUM_WORKERS: 12
38
+ PERSISTENT_WORKERS: true
39
+
40
+ TEST:
41
+ DATASET: 'humanml3d'
42
+ BATCH_SIZE: 32
43
+ SPLIT: 'test'
44
+ NUM_WORKERS: 12
45
+ PERSISTENT_WORKERS: true
46
+
47
+ CHECKPOINTS: 'experiments_control/temporal/motionlcm_humanml/motionlcm_humanml_t_v1.ckpt'
48
+
49
+ # Testing Args
50
+ REPLICATION_TIMES: 20
51
+ MM_NUM_SAMPLES: 100
52
+ MM_NUM_REPEATS: 30
53
+ MM_NUM_TIMES: 10
54
+ DIVERSITY_TIMES: 300
55
+ DO_MM_TEST: false
56
+
57
+ DATASET:
58
+ NAME: 'humanml3d'
59
+ SMPL_PATH: './deps/smpl'
60
+ WORD_VERTILIZER_PATH: './deps/glove/'
61
+ HUMANML3D:
62
+ FRAME_RATE: 20.0
63
+ UNIT_LEN: 4
64
+ ROOT: './datasets/humanml3d'
65
+ CONTROL_ARGS:
66
+ CONTROL: true
67
+ TEMPORAL: true
68
+ TRAIN_JOINTS: [0, 10, 11, 15, 20, 21]
69
+ TEST_JOINTS: [0, 10, 11, 15, 20, 21]
70
+ TRAIN_DENSITY: [25, 25]
71
+ TEST_DENSITY: 25
72
+ MEAN_STD_PATH: './datasets/humanml_spatial_norm'
73
+ SAMPLER:
74
+ MAX_LEN: 200
75
+ MIN_LEN: 40
76
+ MAX_TEXT_LEN: 20
77
+ PADDING_TO_MAX: false
78
+ WINDOW_SIZE: null
79
+
80
+ METRIC:
81
+ DIST_SYNC_ON_STEP: true
82
+ TYPE: ['TM2TMetrics', 'ControlMetrics']
83
+
84
+ model:
85
+ target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm', 'traj_encoder']
86
+ latent_dim: [1, 256]
87
+ guidance_scale: 7.5
88
+
89
+ # ControlNet Args
90
+ is_controlnet: true
91
+ vaeloss: true
92
+ vaeloss_type: 'sum'
93
+ cond_ratio: 1.0
94
+ control_loss_func: 'l2'
95
+ use_3d: false
96
+ lcm_w_min_nax: null
97
+ lcm_num_ddim_timesteps: null
98
+
99
+ t2m_textencoder:
100
+ dim_word: 300
101
+ dim_pos_ohot: 15
102
+ dim_text_hidden: 512
103
+ dim_coemb_hidden: 512
104
+
105
+ t2m_motionencoder:
106
+ dim_move_hidden: 512
107
+ dim_move_latent: 512
108
+ dim_motion_hidden: 1024
109
+ dim_motion_latent: 512
110
+
111
+ bert_path: './deps/distilbert-base-uncased'
112
+ clip_path: './deps/clip-vit-large-patch14'
113
+ t5_path: './deps/sentence-t5-large'
114
+ t2m_path: './deps/t2m/'
configs_v1/motionlcm_t2m.yaml ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FOLDER: './experiments_t2m'
2
+ TEST_FOLDER: './experiments_t2m_test'
3
+
4
+ NAME: 'motionlcm_humanml'
5
+
6
+ SEED_VALUE: 1234
7
+
8
+ TRAIN:
9
+ BATCH_SIZE: 256
10
+ SPLIT: 'train'
11
+ NUM_WORKERS: 8
12
+ PERSISTENT_WORKERS: true
13
+
14
+ PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml_v1.ckpt'
15
+
16
+ validation_steps: -1
17
+ validation_epochs: 50
18
+ checkpointing_steps: -1
19
+ checkpointing_epochs: 50
20
+ max_train_steps: -1
21
+ max_train_epochs: 1000
22
+ learning_rate: 2e-4
23
+ lr_scheduler: "cosine"
24
+ lr_warmup_steps: 1000
25
+ adam_beta1: 0.9
26
+ adam_beta2: 0.999
27
+ adam_weight_decay: 0.0
28
+ adam_epsilon: 1e-08
29
+ max_grad_norm: 1.0
30
+
31
+ # Latent Consistency Distillation Specific Arguments
32
+ w_min: 5.0
33
+ w_max: 15.0
34
+ num_ddim_timesteps: 50
35
+ loss_type: 'huber'
36
+ huber_c: 0.001
37
+ unet_time_cond_proj_dim: 256
38
+ ema_decay: 0.95
39
+
40
+ VAL:
41
+ BATCH_SIZE: 32
42
+ SPLIT: 'test'
43
+ NUM_WORKERS: 12
44
+ PERSISTENT_WORKERS: true
45
+
46
+ TEST:
47
+ BATCH_SIZE: 32
48
+ SPLIT: 'test'
49
+ NUM_WORKERS: 12
50
+ PERSISTENT_WORKERS: true
51
+
52
+ CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml_v1.ckpt'
53
+
54
+ # Testing Args
55
+ REPLICATION_TIMES: 20
56
+ MM_NUM_SAMPLES: 100
57
+ MM_NUM_REPEATS: 30
58
+ MM_NUM_TIMES: 10
59
+ DIVERSITY_TIMES: 300
60
+ DO_MM_TEST: true
61
+
62
+ DATASET:
63
+ NAME: 'humanml3d'
64
+ SMPL_PATH: './deps/smpl'
65
+ WORD_VERTILIZER_PATH: './deps/glove/'
66
+ HUMANML3D:
67
+ FRAME_RATE: 20.0
68
+ UNIT_LEN: 4
69
+ ROOT: './datasets/humanml3d'
70
+ CONTROL_ARGS:
71
+ CONTROL: false
72
+ TEMPORAL: false
73
+ TRAIN_JOINTS: [0]
74
+ TEST_JOINTS: [0]
75
+ TRAIN_DENSITY: 'random'
76
+ TEST_DENSITY: 100
77
+ MEAN_STD_PATH: './datasets/humanml_spatial_norm'
78
+ SAMPLER:
79
+ MAX_LEN: 200
80
+ MIN_LEN: 40
81
+ MAX_TEXT_LEN: 20
82
+ PADDING_TO_MAX: false
83
+ WINDOW_SIZE: null
84
+
85
+ METRIC:
86
+ DIST_SYNC_ON_STEP: true
87
+ TYPE: ['TM2TMetrics']
88
+
89
+ model:
90
+ target: ['motion_vae', 'text_encoder', 'denoiser', 'scheduler_lcm']
91
+ latent_dim: [1, 256]
92
+ guidance_scale: 7.5
93
+
94
+ t2m_textencoder:
95
+ dim_word: 300
96
+ dim_pos_ohot: 15
97
+ dim_text_hidden: 512
98
+ dim_coemb_hidden: 512
99
+
100
+ t2m_motionencoder:
101
+ dim_move_hidden: 512
102
+ dim_move_latent: 512
103
+ dim_motion_hidden: 1024
104
+ dim_motion_latent: 512
105
+
106
+ bert_path: './deps/distilbert-base-uncased'
107
+ clip_path: './deps/clip-vit-large-patch14'
108
+ t5_path: './deps/sentence-t5-large'
109
+ t2m_path: './deps/t2m/'
demo.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import sys
4
+ import datetime
5
+ import logging
6
+ import os.path as osp
7
+
8
+ from omegaconf import OmegaConf
9
+
10
+ import torch
11
+
12
+ from mld.config import parse_args
13
+ from mld.data.get_data import get_dataset
14
+ from mld.models.modeltype.mld import MLD
15
+ from mld.models.modeltype.vae import VAE
16
+ from mld.utils.utils import set_seed, move_batch_to_device
17
+ from mld.data.humanml.utils.plot_script import plot_3d_motion
18
+ from mld.utils.temos_utils import remove_padding
19
+
20
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
+
22
+
23
+ def load_example_hint_input(text_path: str) -> tuple:
24
+ with open(text_path, "r") as f:
25
+ lines = f.readlines()
26
+
27
+ n_frames, control_type_ids, control_hint_ids = [], [], []
28
+ for line in lines:
29
+ s = line.strip()
30
+ n_frame, control_type_id, control_hint_id = s.split(' ')
31
+ n_frames.append(int(n_frame))
32
+ control_type_ids.append(int(control_type_id))
33
+ control_hint_ids.append(int(control_hint_id))
34
+
35
+ return n_frames, control_type_ids, control_hint_ids
36
+
37
+
38
+ def load_example_input(text_path: str) -> tuple:
39
+ with open(text_path, "r") as f:
40
+ lines = f.readlines()
41
+
42
+ texts, lens = [], []
43
+ for line in lines:
44
+ s = line.strip()
45
+ s_l = s.split(" ")[0]
46
+ s_t = s[(len(s_l) + 1):]
47
+ lens.append(int(s_l))
48
+ texts.append(s_t)
49
+ return texts, lens
50
+
51
+
52
+ def main():
53
+ cfg = parse_args()
54
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
55
+ set_seed(cfg.SEED_VALUE)
56
+
57
+ name_time_str = osp.join(cfg.NAME, "demo_" + datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
58
+ cfg.output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)
59
+ vis_dir = osp.join(cfg.output_dir, 'samples')
60
+ os.makedirs(cfg.output_dir, exist_ok=False)
61
+ os.makedirs(vis_dir, exist_ok=False)
62
+
63
+ steam_handler = logging.StreamHandler(sys.stdout)
64
+ file_handler = logging.FileHandler(osp.join(cfg.output_dir, 'output.log'))
65
+ logging.basicConfig(level=logging.INFO,
66
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
67
+ datefmt="%m/%d/%Y %H:%M:%S",
68
+ handlers=[steam_handler, file_handler])
69
+ logger = logging.getLogger(__name__)
70
+
71
+ OmegaConf.save(cfg, osp.join(cfg.output_dir, 'config.yaml'))
72
+
73
+ state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
74
+ logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))
75
+
76
+ # Step 1: Check if the checkpoint is VAE-based.
77
+ is_vae = False
78
+ vae_key = 'vae.skel_embedding.weight'
79
+ if vae_key in state_dict:
80
+ is_vae = True
81
+ logger.info(f'Is VAE: {is_vae}')
82
+
83
+ # Step 2: Check if the checkpoint is MLD-based.
84
+ is_mld = False
85
+ mld_key = 'denoiser.time_embedding.linear_1.weight'
86
+ if mld_key in state_dict:
87
+ is_mld = True
88
+ logger.info(f'Is MLD: {is_mld}')
89
+
90
+ # Step 3: Check if the checkpoint is LCM-based.
91
+ is_lcm = False
92
+ lcm_key = 'denoiser.time_embedding.cond_proj.weight' # unique key for CFG
93
+ if lcm_key in state_dict:
94
+ is_lcm = True
95
+ time_cond_proj_dim = state_dict[lcm_key].shape[1]
96
+ cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim
97
+ logger.info(f'Is LCM: {is_lcm}')
98
+
99
+ # Step 4: Check if the checkpoint is Controlnet-based.
100
+ cn_key = "controlnet.controlnet_cond_embedding.0.weight"
101
+ is_controlnet = True if cn_key in state_dict else False
102
+ cfg.model.is_controlnet = is_controlnet
103
+ logger.info(f'Is Controlnet: {is_controlnet}')
104
+
105
+ if is_mld or is_lcm or is_controlnet:
106
+ target_model_class = MLD
107
+ else:
108
+ target_model_class = VAE
109
+
110
+ if cfg.optimize:
111
+ assert cfg.model.get('noise_optimizer') is not None
112
+ cfg.model.noise_optimizer.params.optimize = True
113
+ logger.info('Optimization enabled. Set the batch size to 1.')
114
+ logger.info(f'Original batch size: {cfg.TEST.BATCH_SIZE}')
115
+ cfg.TEST.BATCH_SIZE = 1
116
+
117
+ dataset = get_dataset(cfg)
118
+ model = target_model_class(cfg, dataset)
119
+ model.to(device)
120
+ model.eval()
121
+ model.requires_grad_(False)
122
+ logger.info(model.load_state_dict(state_dict))
123
+
124
+ FPS = eval(f"cfg.DATASET.{cfg.DATASET.NAME.upper()}.FRAME_RATE")
125
+
126
+ if cfg.example is not None and not is_controlnet:
127
+ text, length = load_example_input(cfg.example)
128
+ for t, l in zip(text, length):
129
+ logger.info(f"{l}: {t}")
130
+
131
+ batch = {"length": length, "text": text}
132
+
133
+ for rep_i in range(cfg.replication):
134
+ with torch.no_grad():
135
+ joints = model(batch)[0]
136
+
137
+ num_samples = len(joints)
138
+ for i in range(num_samples):
139
+ res = dict()
140
+ pkl_path = osp.join(vis_dir, f"sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl")
141
+ res['joints'] = joints[i].detach().cpu().numpy()
142
+ res['text'] = text[i]
143
+ res['length'] = length[i]
144
+ res['hint'] = None
145
+ with open(pkl_path, 'wb') as f:
146
+ pickle.dump(res, f)
147
+ logger.info(f"Motions are generated here:\n{pkl_path}")
148
+
149
+ if not cfg.no_plot:
150
+ plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(), text[i], fps=FPS)
151
+
152
+ else:
153
+ test_dataloader = dataset.test_dataloader()
154
+ for rep_i in range(cfg.replication):
155
+ for batch_id, batch in enumerate(test_dataloader):
156
+ batch = move_batch_to_device(batch, device)
157
+ with torch.no_grad():
158
+ joints, joints_ref = model(batch)
159
+
160
+ num_samples = len(joints)
161
+ text = batch['text']
162
+ length = batch['length']
163
+ if 'hint' in batch:
164
+ hint, hint_mask = batch['hint'], batch['hint_mask']
165
+ hint = dataset.denorm_spatial(hint) * hint_mask
166
+ hint = remove_padding(hint, lengths=length)
167
+ else:
168
+ hint = None
169
+
170
+ for i in range(num_samples):
171
+ res = dict()
172
+ pkl_path = osp.join(vis_dir, f"batch_id_{batch_id}_sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl")
173
+ res['joints'] = joints[i].detach().cpu().numpy()
174
+ res['text'] = text[i]
175
+ res['length'] = length[i]
176
+ res['hint'] = hint[i].detach().cpu().numpy() if hint is not None else None
177
+ with open(pkl_path, 'wb') as f:
178
+ pickle.dump(res, f)
179
+ logger.info(f"Motions are generated here:\n{pkl_path}")
180
+
181
+ if not cfg.no_plot:
182
+ plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(),
183
+ text[i], fps=FPS, hint=hint[i].detach().cpu().numpy() if hint is not None else None)
184
+
185
+ if rep_i == 0:
186
+ res['joints'] = joints_ref[i].detach().cpu().numpy()
187
+ with open(pkl_path.replace('.pkl', '_ref.pkl'), 'wb') as f:
188
+ pickle.dump(res, f)
189
+ logger.info(f"Motions are generated here:\n{pkl_path.replace('.pkl', '_ref.pkl')}")
190
+ if not cfg.no_plot:
191
+ plot_3d_motion(pkl_path.replace('.pkl', '_ref.mp4'), joints_ref[i].detach().cpu().numpy(),
192
+ text[i], fps=FPS, hint=hint[i].detach().cpu().numpy() if hint is not None else None)
193
+
194
+
195
+ if __name__ == "__main__":
196
+ main()
fit.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # borrow from optimization https://github.com/wangsen1312/joints2smpl
2
+ import os
3
+ import argparse
4
+ import pickle
5
+
6
+ import h5py
7
+ import natsort
8
+ import smplx
9
+
10
+ import torch
11
+
12
+ from mld.transforms.joints2rots import config
13
+ from mld.transforms.joints2rots.smplify import SMPLify3D
14
+
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("--pkl", type=str, default=None, help="pkl motion file")
17
+ parser.add_argument("--dir", type=str, default=None, help="pkl motion folder")
18
+ parser.add_argument("--num_smplify_iters", type=int, default=150, help="num of smplify iters")
19
+ parser.add_argument("--cuda", type=bool, default=True, help="enables cuda")
20
+ parser.add_argument("--gpu_ids", type=int, default=0, help="choose gpu ids")
21
+ parser.add_argument("--num_joints", type=int, default=22, help="joint number")
22
+ parser.add_argument("--joint_category", type=str, default="AMASS", help="use correspondence")
23
+ parser.add_argument("--fix_foot", type=str, default="False", help="fix foot or not")
24
+ opt = parser.parse_args()
25
+ print(opt)
26
+
27
+ if opt.pkl:
28
+ paths = [opt.pkl]
29
+ elif opt.dir:
30
+ paths = []
31
+ file_list = natsort.natsorted(os.listdir(opt.dir))
32
+ for item in file_list:
33
+ if item.endswith('.pkl') and not item.endswith("_mesh.pkl"):
34
+ paths.append(os.path.join(opt.dir, item))
35
+ else:
36
+ raise ValueError(f'{opt.pkl} and {opt.dir} are both None!')
37
+
38
+ for path in paths:
39
+ # load joints
40
+ if os.path.exists(path.replace('.pkl', '_mesh.pkl')):
41
+ print(f"{path} is rendered! skip!")
42
+ continue
43
+
44
+ with open(path, 'rb') as f:
45
+ data = pickle.load(f)
46
+
47
+ joints = data['joints']
48
+ # load predefined something
49
+ device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu")
50
+ print(config.SMPL_MODEL_DIR)
51
+ smplxmodel = smplx.create(
52
+ config.SMPL_MODEL_DIR,
53
+ model_type="smpl",
54
+ gender="neutral",
55
+ ext="pkl",
56
+ batch_size=joints.shape[0],
57
+ ).to(device)
58
+
59
+ # load the mean pose as original
60
+ smpl_mean_file = config.SMPL_MEAN_FILE
61
+
62
+ file = h5py.File(smpl_mean_file, "r")
63
+ init_mean_pose = (
64
+ torch.from_numpy(file["pose"][:])
65
+ .unsqueeze(0).repeat(joints.shape[0], 1)
66
+ .float()
67
+ .to(device)
68
+ )
69
+ init_mean_shape = (
70
+ torch.from_numpy(file["shape"][:])
71
+ .unsqueeze(0).repeat(joints.shape[0], 1)
72
+ .float()
73
+ .to(device)
74
+ )
75
+ cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device)
76
+
77
+ # initialize SMPLify
78
+ smplify = SMPLify3D(
79
+ smplxmodel=smplxmodel,
80
+ batch_size=joints.shape[0],
81
+ joints_category=opt.joint_category,
82
+ num_iters=opt.num_smplify_iters,
83
+ device=device,
84
+ )
85
+ print("initialize SMPLify3D done!")
86
+
87
+ print("Start SMPLify!")
88
+ keypoints_3d = torch.Tensor(joints).to(device).float()
89
+
90
+ if opt.joint_category == "AMASS":
91
+ confidence_input = torch.ones(opt.num_joints)
92
+ # make sure the foot and ankle
93
+ if opt.fix_foot:
94
+ confidence_input[7] = 1.5
95
+ confidence_input[8] = 1.5
96
+ confidence_input[10] = 1.5
97
+ confidence_input[11] = 1.5
98
+ else:
99
+ print("Such category not settle down!")
100
+
101
+ # ----- from initial to fitting -------
102
+ (
103
+ new_opt_vertices,
104
+ new_opt_joints,
105
+ new_opt_pose,
106
+ new_opt_betas,
107
+ new_opt_cam_t,
108
+ new_opt_joint_loss,
109
+ ) = smplify(
110
+ init_mean_pose.detach(),
111
+ init_mean_shape.detach(),
112
+ cam_trans_zero.detach(),
113
+ keypoints_3d,
114
+ conf_3d=confidence_input.to(device)
115
+ )
116
+
117
+ # fix shape
118
+ betas = torch.zeros_like(new_opt_betas)
119
+ root = keypoints_3d[:, 0, :]
120
+
121
+ output = smplxmodel(
122
+ betas=betas,
123
+ global_orient=new_opt_pose[:, :3],
124
+ body_pose=new_opt_pose[:, 3:],
125
+ transl=root,
126
+ return_verts=True
127
+ )
128
+ vertices = output.vertices.detach().cpu().numpy()
129
+ floor_height = vertices[..., 1].min()
130
+ vertices[..., 1] -= floor_height
131
+ data['vertices'] = vertices
132
+
133
+ save_file = path.replace('.pkl', '_mesh.pkl')
134
+ with open(save_file, 'wb') as f:
135
+ pickle.dump(data, f)
136
+ print(f'vertices saved in {save_file}')
mld/__init__.py ADDED
File without changes
mld/config.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import importlib
3
+ from typing import Type, TypeVar
4
+ from argparse import ArgumentParser
5
+
6
+ from omegaconf import OmegaConf, DictConfig
7
+
8
+
9
+ def get_module_config(cfg_model: DictConfig, paths: list[str], cfg_root: str) -> DictConfig:
10
+ files = [os.path.join(cfg_root, 'modules', p+'.yaml') for p in paths]
11
+ for file in files:
12
+ assert os.path.exists(file), f'{file} is not exists.'
13
+ with open(file, 'r') as f:
14
+ cfg_model.merge_with(OmegaConf.load(f))
15
+ return cfg_model
16
+
17
+
18
+ def get_obj_from_str(string: str, reload: bool = False) -> Type:
19
+ module, cls = string.rsplit(".", 1)
20
+ if reload:
21
+ module_imp = importlib.import_module(module)
22
+ importlib.reload(module_imp)
23
+ return getattr(importlib.import_module(module, package=None), cls)
24
+
25
+
26
+ def instantiate_from_config(config: DictConfig) -> TypeVar:
27
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
28
+
29
+
30
+ def parse_args() -> DictConfig:
31
+ parser = ArgumentParser()
32
+ parser.add_argument("--cfg", type=str, required=True, help="The main config file")
33
+ parser.add_argument('--example', type=str, required=False, help="The input texts and lengths with txt format")
34
+ parser.add_argument('--example_hint', type=str, required=False, help="The input hint ids and lengths with txt format")
35
+ parser.add_argument('--no-plot', action="store_true", required=False, help="Whether to plot the skeleton-based motion")
36
+ parser.add_argument('--replication', type=int, default=1, help="The number of replications of sampling")
37
+ parser.add_argument('--vis', type=str, default="tb", choices=['tb', 'swanlab'], help="The visualization backends: tensorboard or swanlab")
38
+ parser.add_argument('--optimize', action='store_true', help="Enable optimization for motion control")
39
+ args = parser.parse_args()
40
+
41
+ cfg = OmegaConf.load(args.cfg)
42
+ cfg_root = os.path.dirname(args.cfg)
43
+ cfg_model = get_module_config(cfg.model, cfg.model.target, cfg_root)
44
+ cfg = OmegaConf.merge(cfg, cfg_model)
45
+
46
+ cfg.example = args.example
47
+ cfg.example_hint = args.example_hint
48
+ cfg.no_plot = args.no_plot
49
+ cfg.replication = args.replication
50
+ cfg.vis = args.vis
51
+ cfg.optimize = args.optimize
52
+ return cfg
mld/data/__init__.py ADDED
File without changes
mld/data/base.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from os.path import join as pjoin
3
+ from typing import Any, Callable
4
+
5
+ from torch.utils.data import DataLoader
6
+
7
+
8
+ class BaseDataModule:
9
+ def __init__(self, collate_fn: Callable) -> None:
10
+ super(BaseDataModule, self).__init__()
11
+ self.collate_fn = collate_fn
12
+ self.is_mm = False
13
+
14
+ def get_sample_set(self, overrides: dict) -> Any:
15
+ sample_params = copy.deepcopy(self.hparams)
16
+ sample_params.update(overrides)
17
+ split_file = pjoin(
18
+ eval(f"self.cfg.DATASET.{self.name.upper()}.ROOT"),
19
+ self.cfg.TEST.SPLIT + ".txt"
20
+ )
21
+ return self.Dataset(split_file=split_file, **sample_params)
22
+
23
+ def __getattr__(self, item: str) -> Any:
24
+ if item.endswith("_dataset") and not item.startswith("_"):
25
+ subset = item[:-len("_dataset")].upper()
26
+ item_c = "_" + item
27
+ if item_c not in self.__dict__:
28
+ split_file = pjoin(
29
+ eval(f"self.cfg.DATASET.{self.name.upper()}.ROOT"),
30
+ eval(f"self.cfg.{subset}.SPLIT") + ".txt"
31
+ )
32
+ self.__dict__[item_c] = self.Dataset(split_file=split_file, **self.hparams)
33
+ return getattr(self, item_c)
34
+ classname = self.__class__.__name__
35
+ raise AttributeError(f"'{classname}' object has no attribute '{item}'")
36
+
37
+ def get_dataloader_options(self, stage: str) -> dict:
38
+ stage_args = eval(f"self.cfg.{stage.upper()}")
39
+ dataloader_options = {
40
+ "batch_size": stage_args.BATCH_SIZE,
41
+ "num_workers": stage_args.NUM_WORKERS,
42
+ "collate_fn": self.collate_fn,
43
+ "persistent_workers": stage_args.PERSISTENT_WORKERS,
44
+ }
45
+ return dataloader_options
46
+
47
+ def train_dataloader(self) -> DataLoader:
48
+ dataloader_options = self.get_dataloader_options('TRAIN')
49
+ return DataLoader(self.train_dataset, shuffle=True, **dataloader_options)
50
+
51
+ def val_dataloader(self) -> DataLoader:
52
+ dataloader_options = self.get_dataloader_options('VAL')
53
+ return DataLoader(self.val_dataset, shuffle=False, **dataloader_options)
54
+
55
+ def test_dataloader(self) -> DataLoader:
56
+ dataloader_options = self.get_dataloader_options('TEST')
57
+ dataloader_options["batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
58
+ return DataLoader(self.test_dataset, shuffle=False, **dataloader_options)
mld/data/data.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Callable, Optional
3
+
4
+ import numpy as np
5
+ from omegaconf import DictConfig
6
+
7
+ import torch
8
+
9
+ from .base import BaseDataModule
10
+ from .humanml.dataset import Text2MotionDataset, MotionDataset
11
+ from .humanml.scripts.motion_process import recover_from_ric
12
+
13
+
14
+ # (nfeats, njoints)
15
+ dataset_map = {'humanml3d': (263, 22), 'kit': (251, 21)}
16
+
17
+
18
+ class DataModule(BaseDataModule):
19
+
20
+ def __init__(self,
21
+ name: str,
22
+ cfg: DictConfig,
23
+ motion_only: bool,
24
+ collate_fn: Optional[Callable] = None,
25
+ **kwargs) -> None:
26
+ super().__init__(collate_fn=collate_fn)
27
+ self.cfg = cfg
28
+ self.name = name
29
+ self.nfeats, self.njoints = dataset_map[name]
30
+ self.hparams = copy.deepcopy({**kwargs, 'njoints': self.njoints})
31
+ self.Dataset = MotionDataset if motion_only else Text2MotionDataset
32
+ sample_overrides = {"tiny": True, "progress_bar": False}
33
+ self._sample_set = self.get_sample_set(overrides=sample_overrides)
34
+
35
+ def denorm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
36
+ raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
37
+ raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
38
+ hint = hint * raw_std + raw_mean
39
+ return hint
40
+
41
+ def norm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
42
+ raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
43
+ raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
44
+ hint = (hint - raw_mean) / raw_std
45
+ return hint
46
+
47
+ def feats2joints(self, features: torch.Tensor) -> torch.Tensor:
48
+ mean = torch.tensor(self.hparams['mean']).to(features)
49
+ std = torch.tensor(self.hparams['std']).to(features)
50
+ features = features * std + mean
51
+ return recover_from_ric(features, self.njoints)
52
+
53
+ def renorm4t2m(self, features: torch.Tensor) -> torch.Tensor:
54
+ # renorm to t2m norms for using t2m evaluators
55
+ ori_mean = torch.tensor(self.hparams['mean']).to(features)
56
+ ori_std = torch.tensor(self.hparams['std']).to(features)
57
+ eval_mean = torch.tensor(self.hparams['mean_eval']).to(features)
58
+ eval_std = torch.tensor(self.hparams['std_eval']).to(features)
59
+ features = features * ori_std + ori_mean
60
+ features = (features - eval_mean) / eval_std
61
+ return features
62
+
63
+ def mm_mode(self, mm_on: bool = True) -> None:
64
+ if mm_on:
65
+ self.is_mm = True
66
+ self.name_list = self.test_dataset.name_list
67
+ self.mm_list = np.random.choice(self.name_list,
68
+ self.cfg.TEST.MM_NUM_SAMPLES,
69
+ replace=False)
70
+ self.test_dataset.name_list = self.mm_list
71
+ else:
72
+ self.is_mm = False
73
+ self.test_dataset.name_list = self.name_list
mld/data/get_data.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from os.path import join as pjoin
3
+
4
+ import numpy as np
5
+
6
+ from omegaconf import DictConfig
7
+
8
+ from .data import DataModule
9
+ from .base import BaseDataModule
10
+ from .utils import mld_collate, mld_collate_motion_only
11
+ from .humanml.utils.word_vectorizer import WordVectorizer
12
+
13
+
14
+ def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]:
15
+ name = "t2m" if dataset_name == "humanml3d" else dataset_name
16
+ assert name in ["t2m", "kit"]
17
+ if phase in ["val"]:
18
+ if name == 't2m':
19
+ data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD01", "meta")
20
+ elif name == 'kit':
21
+ data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD005", "meta")
22
+ else:
23
+ raise ValueError("Only support t2m and kit")
24
+ mean = np.load(pjoin(data_root, "mean.npy"))
25
+ std = np.load(pjoin(data_root, "std.npy"))
26
+ else:
27
+ data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
28
+ mean = np.load(pjoin(data_root, "Mean.npy"))
29
+ std = np.load(pjoin(data_root, "Std.npy"))
30
+
31
+ return mean, std
32
+
33
+
34
+ def get_WordVectorizer(cfg: DictConfig, dataset_name: str) -> Optional[WordVectorizer]:
35
+ if dataset_name.lower() in ["humanml3d", "kit"]:
36
+ return WordVectorizer(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab")
37
+ else:
38
+ raise ValueError("Only support WordVectorizer for HumanML3D and KIT")
39
+
40
+
41
+ dataset_module_map = {"humanml3d": DataModule, "kit": DataModule}
42
+ motion_subdir = {"humanml3d": "new_joint_vecs", "kit": "new_joint_vecs"}
43
+
44
+
45
+ def get_dataset(cfg: DictConfig, motion_only: bool = False) -> BaseDataModule:
46
+ dataset_name = cfg.DATASET.NAME
47
+ if dataset_name.lower() in ["humanml3d", "kit"]:
48
+ data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
49
+ mean, std = get_mean_std('train', cfg, dataset_name)
50
+ mean_eval, std_eval = get_mean_std("val", cfg, dataset_name)
51
+ wordVectorizer = None if motion_only else get_WordVectorizer(cfg, dataset_name)
52
+ collate_fn = mld_collate_motion_only if motion_only else mld_collate
53
+ dataset = dataset_module_map[dataset_name.lower()](
54
+ name=dataset_name.lower(),
55
+ cfg=cfg,
56
+ motion_only=motion_only,
57
+ collate_fn=collate_fn,
58
+ mean=mean,
59
+ std=std,
60
+ mean_eval=mean_eval,
61
+ std_eval=std_eval,
62
+ w_vectorizer=wordVectorizer,
63
+ text_dir=pjoin(data_root, "texts"),
64
+ motion_dir=pjoin(data_root, motion_subdir[dataset_name]),
65
+ max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN,
66
+ min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN,
67
+ max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN,
68
+ unit_length=eval(f"cfg.DATASET.{dataset_name.upper()}.UNIT_LEN"),
69
+ fps=eval(f"cfg.DATASET.{dataset_name.upper()}.FRAME_RATE"),
70
+ padding_to_max=cfg.DATASET.PADDING_TO_MAX,
71
+ window_size=cfg.DATASET.WINDOW_SIZE,
72
+ control_args=eval(f"cfg.DATASET.{dataset_name.upper()}.CONTROL_ARGS"))
73
+
74
+ cfg.DATASET.NFEATS = dataset.nfeats
75
+ cfg.DATASET.NJOINTS = dataset.njoints
76
+ return dataset
77
+
78
+ elif dataset_name.lower() in ["humanact12", 'uestc', "amass"]:
79
+ raise NotImplementedError
mld/data/humanml/__init__.py ADDED
File without changes
mld/data/humanml/common/quaternion.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def qinv(q: torch.Tensor) -> torch.Tensor:
5
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
6
+ mask = torch.ones_like(q)
7
+ mask[..., 1:] = -mask[..., 1:]
8
+ return q * mask
9
+
10
+
11
+ def qrot(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
12
+ """
13
+ Rotate vector(s) v about the rotation described by quaternion(s) q.
14
+ Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
15
+ where * denotes any number of dimensions.
16
+ Returns a tensor of shape (*, 3).
17
+ """
18
+ assert q.shape[-1] == 4
19
+ assert v.shape[-1] == 3
20
+ assert q.shape[:-1] == v.shape[:-1]
21
+
22
+ original_shape = list(v.shape)
23
+ q = q.contiguous().view(-1, 4)
24
+ v = v.contiguous().view(-1, 3)
25
+
26
+ qvec = q[:, 1:]
27
+ uv = torch.cross(qvec, v, dim=1)
28
+ uuv = torch.cross(qvec, uv, dim=1)
29
+ return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
mld/data/humanml/dataset.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import logging
4
+ import codecs as cs
5
+ from os.path import join as pjoin
6
+
7
+ import numpy as np
8
+ from rich.progress import track
9
+
10
+ import torch
11
+ from torch.utils.data import Dataset
12
+
13
+ from .scripts.motion_process import recover_from_ric
14
+ from .utils.word_vectorizer import WordVectorizer
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class MotionDataset(Dataset):
20
+ def __init__(self, mean: np.ndarray, std: np.ndarray,
21
+ split_file: str, motion_dir: str, window_size: int,
22
+ tiny: bool = False, progress_bar: bool = True, **kwargs) -> None:
23
+ self.data = []
24
+ self.lengths = []
25
+ id_list = []
26
+ with cs.open(split_file, "r") as f:
27
+ for line in f.readlines():
28
+ id_list.append(line.strip())
29
+
30
+ maxdata = 10 if tiny else 1e10
31
+ if progress_bar:
32
+ enumerator = enumerate(
33
+ track(
34
+ id_list,
35
+ f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}",
36
+ ))
37
+ else:
38
+ enumerator = enumerate(id_list)
39
+
40
+ count = 0
41
+ for i, name in enumerator:
42
+ if count > maxdata:
43
+ break
44
+ try:
45
+ motion = np.load(pjoin(motion_dir, name + '.npy'))
46
+ if motion.shape[0] < window_size:
47
+ continue
48
+ self.lengths.append(motion.shape[0] - window_size)
49
+ self.data.append(motion)
50
+ except Exception as e:
51
+ print(e)
52
+ pass
53
+
54
+ self.cumsum = np.cumsum([0] + self.lengths)
55
+ if not tiny:
56
+ logger.info("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
57
+
58
+ self.mean = mean
59
+ self.std = std
60
+ self.window_size = window_size
61
+
62
+ def __len__(self) -> int:
63
+ return self.cumsum[-1]
64
+
65
+ def __getitem__(self, item: int) -> tuple:
66
+ if item != 0:
67
+ motion_id = np.searchsorted(self.cumsum, item) - 1
68
+ idx = item - self.cumsum[motion_id] - 1
69
+ else:
70
+ motion_id = 0
71
+ idx = 0
72
+ motion = self.data[motion_id][idx:idx + self.window_size]
73
+ "Z Normalization"
74
+ motion = (motion - self.mean) / self.std
75
+ return motion, self.window_size
76
+
77
+
78
+ class Text2MotionDataset(Dataset):
79
+
80
+ def __init__(
81
+ self,
82
+ mean: np.ndarray,
83
+ std: np.ndarray,
84
+ split_file: str,
85
+ w_vectorizer: WordVectorizer,
86
+ max_motion_length: int,
87
+ min_motion_length: int,
88
+ max_text_len: int,
89
+ unit_length: int,
90
+ motion_dir: str,
91
+ text_dir: str,
92
+ fps: int,
93
+ padding_to_max: bool,
94
+ njoints: int,
95
+ tiny: bool = False,
96
+ progress_bar: bool = True,
97
+ **kwargs,
98
+ ) -> None:
99
+ self.w_vectorizer = w_vectorizer
100
+ self.max_motion_length = max_motion_length
101
+ self.min_motion_length = min_motion_length
102
+ self.max_text_len = max_text_len
103
+ self.unit_length = unit_length
104
+ self.padding_to_max = padding_to_max
105
+ self.njoints = njoints
106
+
107
+ data_dict = {}
108
+ id_list = []
109
+ with cs.open(split_file, "r") as f:
110
+ for line in f.readlines():
111
+ id_list.append(line.strip())
112
+ self.id_list = id_list
113
+
114
+ maxdata = 10 if tiny else 1e10
115
+ if progress_bar:
116
+ enumerator = enumerate(
117
+ track(
118
+ id_list,
119
+ f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}",
120
+ ))
121
+ else:
122
+ enumerator = enumerate(id_list)
123
+ count = 0
124
+ bad_count = 0
125
+ new_name_list = []
126
+ length_list = []
127
+ for i, name in enumerator:
128
+ if count > maxdata:
129
+ break
130
+ try:
131
+ motion = np.load(pjoin(motion_dir, name + ".npy"))
132
+ if len(motion) < self.min_motion_length or len(motion) >= self.max_motion_length:
133
+ bad_count += 1
134
+ continue
135
+ text_data = []
136
+ flag = False
137
+ with cs.open(pjoin(text_dir, name + ".txt")) as f:
138
+ for line in f.readlines():
139
+ text_dict = {}
140
+ line_split = line.strip().split("#")
141
+ caption = line_split[0]
142
+ tokens = line_split[1].split(" ")
143
+ f_tag = float(line_split[2])
144
+ to_tag = float(line_split[3])
145
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
146
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
147
+
148
+ text_dict["caption"] = caption
149
+ text_dict["tokens"] = tokens
150
+ if f_tag == 0.0 and to_tag == 0.0:
151
+ flag = True
152
+ text_data.append(text_dict)
153
+ else:
154
+ try:
155
+ n_motion = motion[int(f_tag * fps): int(to_tag * fps)]
156
+ if (len(n_motion)) < self.min_motion_length or \
157
+ len(n_motion) >= self.max_motion_length:
158
+ continue
159
+ new_name = random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name
160
+ while new_name in data_dict:
161
+ new_name = random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name
162
+ data_dict[new_name] = {
163
+ "motion": n_motion,
164
+ "length": len(n_motion),
165
+ "text": [text_dict],
166
+ }
167
+ new_name_list.append(new_name)
168
+ length_list.append(len(n_motion))
169
+ except ValueError:
170
+ print(line_split)
171
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
172
+
173
+ if flag:
174
+ data_dict[name] = {
175
+ "motion": motion,
176
+ "length": len(motion),
177
+ "text": text_data,
178
+ }
179
+ new_name_list.append(name)
180
+ length_list.append(len(motion))
181
+ count += 1
182
+ except Exception as e:
183
+ print(e)
184
+ pass
185
+
186
+ name_list, length_list = zip(
187
+ *sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
188
+
189
+ if not tiny:
190
+ logger.info(f"Reading {len(self.id_list)} motions from {split_file}.")
191
+ logger.info(f"Total {len(name_list)} motions are used.")
192
+ logger.info(f"{bad_count} motion sequences not within the length range of "
193
+ f"[{self.min_motion_length}, {self.max_motion_length}) are filtered out.")
194
+
195
+ self.mean = mean
196
+ self.std = std
197
+
198
+ control_args = kwargs['control_args']
199
+ self.control_mode = None
200
+ if os.path.exists(control_args.MEAN_STD_PATH):
201
+ self.raw_mean = np.load(pjoin(control_args.MEAN_STD_PATH, 'Mean_raw.npy'))
202
+ self.raw_std = np.load(pjoin(control_args.MEAN_STD_PATH, 'Std_raw.npy'))
203
+ else:
204
+ self.raw_mean = self.raw_std = None
205
+ if not tiny and control_args.CONTROL:
206
+ self.t_ctrl = control_args.TEMPORAL
207
+ self.training_control_joints = np.array(control_args.TRAIN_JOINTS)
208
+ self.testing_control_joints = np.array(control_args.TEST_JOINTS)
209
+ self.training_density = control_args.TRAIN_DENSITY
210
+ self.testing_density = control_args.TEST_DENSITY
211
+
212
+ self.control_mode = 'val' if ('test' in split_file or 'val' in split_file) else 'train'
213
+ if self.control_mode == 'train':
214
+ logger.info(f'Training Control Joints: {self.training_control_joints}')
215
+ logger.info(f'Training Control Density: {self.training_density}')
216
+ else:
217
+ logger.info(f'Testing Control Joints: {self.testing_control_joints}')
218
+ logger.info(f'Testing Control Density: {self.testing_density}')
219
+ logger.info(f"Temporal Control: {self.t_ctrl}")
220
+
221
+ self.data_dict = data_dict
222
+ self.name_list = name_list
223
+
224
+ def __len__(self) -> int:
225
+ return len(self.name_list)
226
+
227
+ def random_mask(self, joints: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
228
+ choose_joint = self.testing_control_joints
229
+
230
+ length = joints.shape[0]
231
+ density = self.testing_density
232
+ if density in [1, 2, 5]:
233
+ choose_seq_num = density
234
+ else:
235
+ choose_seq_num = int(length * density / 100)
236
+
237
+ if self.t_ctrl:
238
+ choose_seq = np.arange(0, choose_seq_num)
239
+ else:
240
+ choose_seq = np.random.choice(length, choose_seq_num, replace=False)
241
+ choose_seq.sort()
242
+
243
+ mask_seq = np.zeros((length, self.njoints, 3))
244
+ for cj in choose_joint:
245
+ mask_seq[choose_seq, cj] = 1.0
246
+
247
+ joints = (joints - self.raw_mean) / self.raw_std
248
+ joints = joints * mask_seq
249
+ return joints, mask_seq
250
+
251
+ def random_mask_train(self, joints: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
252
+ if self.t_ctrl:
253
+ choose_joint = self.training_control_joints
254
+ else:
255
+ num_joints = len(self.training_control_joints)
256
+ num_joints_control = 1
257
+ choose_joint = np.random.choice(num_joints, num_joints_control, replace=False)
258
+ choose_joint = self.training_control_joints[choose_joint]
259
+
260
+ length = joints.shape[0]
261
+
262
+ if self.training_density == 'random':
263
+ choose_seq_num = np.random.choice(length - 1, 1) + 1
264
+ else:
265
+ choose_seq_num = int(length * random.uniform(self.training_density[0], self.training_density[1]) / 100)
266
+
267
+ if self.t_ctrl:
268
+ choose_seq = np.arange(0, choose_seq_num)
269
+ else:
270
+ choose_seq = np.random.choice(length, choose_seq_num, replace=False)
271
+ choose_seq.sort()
272
+
273
+ mask_seq = np.zeros((length, self.njoints, 3))
274
+ for cj in choose_joint:
275
+ mask_seq[choose_seq, cj] = 1
276
+
277
+ joints = (joints - self.raw_mean) / self.raw_std
278
+ joints = joints * mask_seq
279
+ return joints, mask_seq
280
+
281
+ def __getitem__(self, idx: int) -> tuple:
282
+ data = self.data_dict[self.name_list[idx]]
283
+ motion, m_length, text_list = data["motion"], data["length"], data["text"]
284
+ # Randomly select a caption
285
+ text_data = random.choice(text_list)
286
+ caption, tokens = text_data["caption"], text_data["tokens"]
287
+
288
+ if len(tokens) < self.max_text_len:
289
+ # pad with "unk"
290
+ tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
291
+ sent_len = len(tokens)
292
+ tokens = tokens + ["unk/OTHER"] * (self.max_text_len + 2 - sent_len)
293
+ else:
294
+ # crop
295
+ tokens = tokens[:self.max_text_len]
296
+ tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
297
+ sent_len = len(tokens)
298
+ pos_one_hots = []
299
+ word_embeddings = []
300
+ for token in tokens:
301
+ word_emb, pos_oh = self.w_vectorizer[token]
302
+ pos_one_hots.append(pos_oh[None, :])
303
+ word_embeddings.append(word_emb[None, :])
304
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
305
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
306
+
307
+ # Crop the motions in to times of 4, and introduce small variations
308
+ if self.unit_length < 10:
309
+ coin2 = np.random.choice(["single", "single", "double"])
310
+ else:
311
+ coin2 = "single"
312
+
313
+ if coin2 == "double":
314
+ m_length = (m_length // self.unit_length - 1) * self.unit_length
315
+ elif coin2 == "single":
316
+ m_length = (m_length // self.unit_length) * self.unit_length
317
+ idx = random.randint(0, len(motion) - m_length)
318
+ motion = motion[idx:idx + m_length]
319
+
320
+ hint, hint_mask = None, None
321
+ if self.control_mode is not None:
322
+ joints = recover_from_ric(torch.from_numpy(motion).float(), self.njoints)
323
+ joints = joints.numpy()
324
+ if self.control_mode == 'train':
325
+ hint, hint_mask = self.random_mask_train(joints)
326
+ else:
327
+ hint, hint_mask = self.random_mask(joints)
328
+
329
+ if self.padding_to_max:
330
+ padding = np.zeros((self.max_motion_length - m_length, *hint.shape[1:]))
331
+ hint = np.concatenate([hint, padding], axis=0)
332
+ hint_mask = np.concatenate([hint_mask, padding], axis=0)
333
+
334
+ "Z Normalization"
335
+ motion = (motion - self.mean) / self.std
336
+
337
+ if self.padding_to_max:
338
+ padding = np.zeros((self.max_motion_length - m_length, motion.shape[1]))
339
+ motion = np.concatenate([motion, padding], axis=0)
340
+
341
+ return (word_embeddings,
342
+ pos_one_hots,
343
+ caption,
344
+ sent_len,
345
+ motion,
346
+ m_length,
347
+ "_".join(tokens),
348
+ (hint, hint_mask))
mld/data/humanml/scripts/motion_process.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ..common.quaternion import qinv, qrot
4
+
5
+
6
+ # Recover global angle and positions for rotation dataset
7
+ # root_rot_velocity (B, seq_len, 1)
8
+ # root_linear_velocity (B, seq_len, 2)
9
+ # root_y (B, seq_len, 1)
10
+ # ric_data (B, seq_len, (joint_num - 1)*3)
11
+ # rot_data (B, seq_len, (joint_num - 1)*6)
12
+ # local_velocity (B, seq_len, joint_num*3)
13
+ # foot contact (B, seq_len, 4)
14
+ def recover_root_rot_pos(data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
15
+ rot_vel = data[..., 0]
16
+ r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
17
+ '''Get Y-axis rotation from rotation velocity'''
18
+ r_rot_ang[..., 1:] = rot_vel[..., :-1]
19
+ r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
20
+
21
+ r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
22
+ r_rot_quat[..., 0] = torch.cos(r_rot_ang)
23
+ r_rot_quat[..., 2] = torch.sin(r_rot_ang)
24
+
25
+ r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
26
+ r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
27
+ '''Add Y-axis rotation to root position'''
28
+ r_pos = qrot(qinv(r_rot_quat), r_pos)
29
+
30
+ r_pos = torch.cumsum(r_pos, dim=-2)
31
+
32
+ r_pos[..., 1] = data[..., 3]
33
+ return r_rot_quat, r_pos
34
+
35
+
36
+ def recover_from_ric(data: torch.Tensor, joints_num: int) -> torch.Tensor:
37
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
38
+ positions = data[..., 4:(joints_num - 1) * 3 + 4]
39
+ positions = positions.view(positions.shape[:-1] + (-1, 3))
40
+
41
+ '''Add Y-axis rotation to local joints'''
42
+ positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
43
+
44
+ '''Add root XZ to joints'''
45
+ positions[..., 0] += r_pos[..., 0:1]
46
+ positions[..., 2] += r_pos[..., 2:3]
47
+
48
+ '''Concat root and joints'''
49
+ positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
50
+
51
+ return positions
mld/data/humanml/utils/__init__.py ADDED
File without changes
mld/data/humanml/utils/paramUtil.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ # Define a kinematic tree for the skeletal structure
4
+ kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
5
+
6
+ kit_raw_offsets = np.array(
7
+ [
8
+ [0, 0, 0],
9
+ [0, 1, 0],
10
+ [0, 1, 0],
11
+ [0, 1, 0],
12
+ [0, 1, 0],
13
+ [1, 0, 0],
14
+ [0, -1, 0],
15
+ [0, -1, 0],
16
+ [-1, 0, 0],
17
+ [0, -1, 0],
18
+ [0, -1, 0],
19
+ [1, 0, 0],
20
+ [0, -1, 0],
21
+ [0, -1, 0],
22
+ [0, 0, 1],
23
+ [0, 0, 1],
24
+ [-1, 0, 0],
25
+ [0, -1, 0],
26
+ [0, -1, 0],
27
+ [0, 0, 1],
28
+ [0, 0, 1]
29
+ ]
30
+ )
31
+
32
+ t2m_raw_offsets = np.array([[0, 0, 0],
33
+ [1, 0, 0],
34
+ [-1, 0, 0],
35
+ [0, 1, 0],
36
+ [0, -1, 0],
37
+ [0, -1, 0],
38
+ [0, 1, 0],
39
+ [0, -1, 0],
40
+ [0, -1, 0],
41
+ [0, 1, 0],
42
+ [0, 0, 1],
43
+ [0, 0, 1],
44
+ [0, 1, 0],
45
+ [1, 0, 0],
46
+ [-1, 0, 0],
47
+ [0, 0, 1],
48
+ [0, -1, 0],
49
+ [0, -1, 0],
50
+ [0, -1, 0],
51
+ [0, -1, 0],
52
+ [0, -1, 0],
53
+ [0, -1, 0]])
54
+
55
+ t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21],
56
+ [9, 13, 16, 18, 20]]
57
+ t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
58
+ t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
59
+
60
+ kit_tgt_skel_id = '03950'
61
+
62
+ t2m_tgt_skel_id = '000021'
mld/data/humanml/utils/plot_script.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from textwrap import wrap
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+
6
+ import matplotlib.pyplot as plt
7
+ import mpl_toolkits.mplot3d.axes3d as p3
8
+ from matplotlib.animation import FuncAnimation
9
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
10
+
11
+ import mld.data.humanml.utils.paramUtil as paramUtil
12
+
13
+ skeleton = paramUtil.t2m_kinematic_chain
14
+
15
+
16
+ def plot_3d_motion(save_path: str, joints: np.ndarray, title: str,
17
+ figsize: tuple[int, int] = (3, 3),
18
+ fps: int = 120, radius: int = 3, kinematic_tree: list = skeleton,
19
+ hint: Optional[np.ndarray] = None) -> None:
20
+
21
+ title = '\n'.join(wrap(title, 20))
22
+
23
+ def init():
24
+ ax.set_xlim3d([-radius / 2, radius / 2])
25
+ ax.set_ylim3d([0, radius])
26
+ ax.set_zlim3d([-radius / 3., radius * 2 / 3.])
27
+ fig.suptitle(title, fontsize=10)
28
+ ax.grid(b=False)
29
+
30
+ def plot_xzPlane(minx, maxx, miny, minz, maxz):
31
+ # Plot a plane XZ
32
+ verts = [
33
+ [minx, miny, minz],
34
+ [minx, miny, maxz],
35
+ [maxx, miny, maxz],
36
+ [maxx, miny, minz]
37
+ ]
38
+ xz_plane = Poly3DCollection([verts])
39
+ xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
40
+ ax.add_collection3d(xz_plane)
41
+
42
+ # (seq_len, joints_num, 3)
43
+ data = joints.copy().reshape(len(joints), -1, 3)
44
+
45
+ data *= 1.3 # scale for visualization
46
+ if hint is not None:
47
+ mask = hint.sum(-1) != 0
48
+ hint = hint[mask]
49
+ hint *= 1.3
50
+
51
+ fig = plt.figure(figsize=figsize)
52
+ plt.tight_layout()
53
+ ax = p3.Axes3D(fig)
54
+ init()
55
+ MINS = data.min(axis=0).min(axis=0)
56
+ MAXS = data.max(axis=0).max(axis=0)
57
+ colors = ["#DD5A37", "#D69E00", "#B75A39", "#DD5A37", "#D69E00",
58
+ "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00",
59
+ "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", ]
60
+
61
+ frame_number = data.shape[0]
62
+
63
+ height_offset = MINS[1]
64
+ data[:, :, 1] -= height_offset
65
+ if hint is not None:
66
+ hint[..., 1] -= height_offset
67
+ trajec = data[:, 0, [0, 2]]
68
+
69
+ data[..., 0] -= data[:, 0:1, 0]
70
+ data[..., 2] -= data[:, 0:1, 2]
71
+
72
+ def update(index):
73
+ ax.lines = []
74
+ ax.collections = []
75
+ ax.view_init(elev=120, azim=-90)
76
+ ax.dist = 7.5
77
+ plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1],
78
+ MAXS[2] - trajec[index, 1])
79
+
80
+ if hint is not None:
81
+ ax.scatter(hint[..., 0] - trajec[index, 0], hint[..., 1], hint[..., 2] - trajec[index, 1], color="#80B79A")
82
+
83
+ for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
84
+ if i < 5:
85
+ linewidth = 4.0
86
+ else:
87
+ linewidth = 2.0
88
+ ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth,
89
+ color=color)
90
+
91
+ plt.axis('off')
92
+ ax.set_xticklabels([])
93
+ ax.set_yticklabels([])
94
+ ax.set_zticklabels([])
95
+
96
+ ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
97
+ ani.save(save_path, fps=fps)
98
+ plt.close()
mld/data/humanml/utils/word_vectorizer.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from os.path import join as pjoin
3
+
4
+ import numpy as np
5
+
6
+
7
+ POS_enumerator = {
8
+ 'VERB': 0,
9
+ 'NOUN': 1,
10
+ 'DET': 2,
11
+ 'ADP': 3,
12
+ 'NUM': 4,
13
+ 'AUX': 5,
14
+ 'PRON': 6,
15
+ 'ADJ': 7,
16
+ 'ADV': 8,
17
+ 'Loc_VIP': 9,
18
+ 'Body_VIP': 10,
19
+ 'Obj_VIP': 11,
20
+ 'Act_VIP': 12,
21
+ 'Desc_VIP': 13,
22
+ 'OTHER': 14
23
+ }
24
+
25
+ Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
26
+ 'up', 'down', 'straight', 'curve')
27
+
28
+ Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
29
+
30
+ Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
31
+
32
+ Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
33
+ 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
34
+ 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
35
+
36
+ Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
37
+ 'angrily', 'sadly')
38
+
39
+ VIP_dict = {
40
+ 'Loc_VIP': Loc_list,
41
+ 'Body_VIP': Body_list,
42
+ 'Obj_VIP': Obj_List,
43
+ 'Act_VIP': Act_list,
44
+ 'Desc_VIP': Desc_list,
45
+ }
46
+
47
+
48
+ class WordVectorizer(object):
49
+ def __init__(self, meta_root: str, prefix: str) -> None:
50
+ vectors = np.load(pjoin(meta_root, '%s_data.npy' % prefix))
51
+ words = pickle.load(open(pjoin(meta_root, '%s_words.pkl' % prefix), 'rb'))
52
+ word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl' % prefix), 'rb'))
53
+ self.word2vec = {w: vectors[word2idx[w]] for w in words}
54
+
55
+ def _get_pos_ohot(self, pos: str) -> np.ndarray:
56
+ pos_vec = np.zeros(len(POS_enumerator))
57
+ if pos in POS_enumerator:
58
+ pos_vec[POS_enumerator[pos]] = 1
59
+ else:
60
+ pos_vec[POS_enumerator['OTHER']] = 1
61
+ return pos_vec
62
+
63
+ def __len__(self) -> int:
64
+ return len(self.word2vec)
65
+
66
+ def __getitem__(self, item: str) -> tuple:
67
+ word, pos = item.split('/')
68
+ if word in self.word2vec:
69
+ word_vec = self.word2vec[word]
70
+ vip_pos = None
71
+ for key, values in VIP_dict.items():
72
+ if word in values:
73
+ vip_pos = key
74
+ break
75
+ if vip_pos is not None:
76
+ pos_vec = self._get_pos_ohot(vip_pos)
77
+ else:
78
+ pos_vec = self._get_pos_ohot(pos)
79
+ else:
80
+ word_vec = self.word2vec['unk']
81
+ pos_vec = self._get_pos_ohot('OTHER')
82
+ return word_vec, pos_vec
mld/data/utils.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from mld.utils.temos_utils import lengths_to_mask
4
+
5
+
6
+ def collate_tensors(batch: list) -> torch.Tensor:
7
+ dims = batch[0].dim()
8
+ max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
9
+ size = (len(batch), ) + tuple(max_size)
10
+ canvas = batch[0].new_zeros(size=size)
11
+ for i, b in enumerate(batch):
12
+ sub_tensor = canvas[i]
13
+ for d in range(dims):
14
+ sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
15
+ sub_tensor.add_(b)
16
+ return canvas
17
+
18
+
19
+ def mld_collate(batch: list) -> dict:
20
+ notnone_batches = [b for b in batch if b is not None]
21
+ notnone_batches.sort(key=lambda x: x[3], reverse=True)
22
+ adapted_batch = {
23
+ "motion":
24
+ collate_tensors([torch.tensor(b[4]).float() for b in notnone_batches]),
25
+ "text": [b[2] for b in notnone_batches],
26
+ "length": [b[5] for b in notnone_batches],
27
+ "word_embs":
28
+ collate_tensors([torch.tensor(b[0]).float() for b in notnone_batches]),
29
+ "pos_ohot":
30
+ collate_tensors([torch.tensor(b[1]).float() for b in notnone_batches]),
31
+ "text_len":
32
+ collate_tensors([torch.tensor(b[3]) for b in notnone_batches]),
33
+ "tokens": [b[6] for b in notnone_batches]
34
+ }
35
+
36
+ mask = lengths_to_mask(adapted_batch['length'], adapted_batch['motion'].device, adapted_batch['motion'].shape[1])
37
+ adapted_batch['mask'] = mask
38
+
39
+ # collate trajectory
40
+ if notnone_batches[0][-1][0] is not None:
41
+ adapted_batch['hint'] = collate_tensors([torch.tensor(b[-1][0]).float() for b in notnone_batches])
42
+ adapted_batch['hint_mask'] = collate_tensors([torch.tensor(b[-1][1]).float() for b in notnone_batches])
43
+
44
+ return adapted_batch
45
+
46
+
47
+ def mld_collate_motion_only(batch: list) -> dict:
48
+ batch = {
49
+ "motion": collate_tensors([torch.tensor(b[0]).float() for b in batch]),
50
+ "length": [b[1] for b in batch]
51
+ }
52
+ return batch
mld/launch/__init__.py ADDED
File without changes
mld/launch/blender.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fix blender path
2
+ import os
3
+ import sys
4
+ from argparse import ArgumentParser
5
+
6
+ sys.path.insert(0, os.path.expanduser("~/.local/lib/python3.9/site-packages"))
7
+
8
+
9
+ # Monkey patch argparse such that
10
+ # blender / python parsing works
11
+ def parse_args(self, args=None, namespace=None):
12
+ if args is not None:
13
+ return self.parse_args_bak(args=args, namespace=namespace)
14
+ try:
15
+ idx = sys.argv.index("--")
16
+ args = sys.argv[idx + 1:] # the list after '--'
17
+ except ValueError as e: # '--' not in the list:
18
+ args = []
19
+ return self.parse_args_bak(args=args, namespace=namespace)
20
+
21
+
22
+ setattr(ArgumentParser, 'parse_args_bak', ArgumentParser.parse_args)
23
+ setattr(ArgumentParser, 'parse_args', parse_args)
mld/models/__init__.py ADDED
File without changes
mld/models/architectures/__init__.py ADDED
File without changes
mld/models/architectures/dno.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.utils.tensorboard import SummaryWriter
6
+
7
+
8
+ class DNO(object):
9
+ def __init__(
10
+ self,
11
+ optimize: bool,
12
+ max_train_steps: int,
13
+ learning_rate: float,
14
+ lr_scheduler: str,
15
+ lr_warmup_steps: int,
16
+ clip_grad: bool,
17
+ loss_hint_type: str,
18
+ loss_diff_penalty: float,
19
+ loss_correlate_penalty: float,
20
+ visualize_samples: int,
21
+ visualize_ske_steps: list[int],
22
+ output_dir: str
23
+ ) -> None:
24
+
25
+ self.optimize = optimize
26
+ self.max_train_steps = max_train_steps
27
+ self.learning_rate = learning_rate
28
+ self.lr_scheduler = lr_scheduler
29
+ self.lr_warmup_steps = lr_warmup_steps
30
+ self.clip_grad = clip_grad
31
+ self.loss_hint_type = loss_hint_type
32
+ self.loss_diff_penalty = loss_diff_penalty
33
+ self.loss_correlate_penalty = loss_correlate_penalty
34
+
35
+ if loss_hint_type == 'l1':
36
+ self.loss_hint_func = F.l1_loss
37
+ elif loss_hint_type == 'l1_smooth':
38
+ self.loss_hint_func = F.smooth_l1_loss
39
+ elif loss_hint_type == 'l2':
40
+ self.loss_hint_func = F.mse_loss
41
+ else:
42
+ raise ValueError(f'Invalid loss type: {loss_hint_type}')
43
+
44
+ self.visualize_samples = float('inf') if visualize_samples == 'inf' else visualize_samples
45
+ assert self.visualize_samples >= 0
46
+ self.visualize_samples_done = 0
47
+ self.visualize_ske_steps = visualize_ske_steps
48
+ if len(visualize_ske_steps) > 0:
49
+ self.vis_dir = os.path.join(output_dir, 'vis_optimize')
50
+ os.makedirs(self.vis_dir)
51
+
52
+ self.writer = None
53
+ self.output_dir = output_dir
54
+ if self.visualize_samples > 0:
55
+ self.writer = SummaryWriter(output_dir)
56
+
57
+ @property
58
+ def do_visualize(self):
59
+ return self.visualize_samples_done < self.visualize_samples
60
+
61
+ @staticmethod
62
+ def noise_regularize_1d(noise: torch.Tensor, stop_at: int = 2, dim: int = 1) -> torch.Tensor:
63
+ size = noise.shape[dim]
64
+ if size & (size - 1) != 0:
65
+ new_size = 2 ** (size - 1).bit_length()
66
+ pad = new_size - size
67
+ pad_shape = list(noise.shape)
68
+ pad_shape[dim] = pad
69
+ pad_noise = torch.randn(*pad_shape, device=noise.device)
70
+ noise = torch.cat([noise, pad_noise], dim=dim)
71
+ size = noise.shape[dim]
72
+
73
+ loss = torch.zeros(noise.shape[0], device=noise.device)
74
+ while size > stop_at:
75
+ rolled_noise = torch.roll(noise, shifts=1, dims=dim)
76
+ loss += (noise * rolled_noise).mean(dim=tuple(range(1, noise.ndim))).pow(2)
77
+ noise = noise.view(*noise.shape[:dim], size // 2, 2, *noise.shape[dim + 1:]).mean(dim=dim + 1)
78
+ size //= 2
79
+ return loss
mld/models/architectures/mld_clip.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import AutoModel, AutoTokenizer
5
+ from sentence_transformers import SentenceTransformer
6
+
7
+
8
+ class MldTextEncoder(nn.Module):
9
+
10
+ def __init__(self, modelpath: str, last_hidden_state: bool = False) -> None:
11
+ super().__init__()
12
+
13
+ if 't5' in modelpath:
14
+ self.text_model = SentenceTransformer(modelpath)
15
+ self.tokenizer = self.text_model.tokenizer
16
+ else:
17
+ self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
18
+ self.text_model = AutoModel.from_pretrained(modelpath)
19
+
20
+ self.max_length = self.tokenizer.model_max_length
21
+ if "clip" in modelpath:
22
+ self.text_encoded_dim = self.text_model.config.text_config.hidden_size
23
+ if last_hidden_state:
24
+ self.name = "clip_hidden"
25
+ else:
26
+ self.name = "clip"
27
+ elif "bert" in modelpath:
28
+ self.name = "bert"
29
+ self.text_encoded_dim = self.text_model.config.hidden_size
30
+ elif 't5' in modelpath:
31
+ self.name = 't5'
32
+ else:
33
+ raise ValueError(f"Model {modelpath} not supported")
34
+
35
+ def forward(self, texts: list[str]) -> torch.Tensor:
36
+ # get prompt text embeddings
37
+ if self.name in ["clip", "clip_hidden"]:
38
+ text_inputs = self.tokenizer(
39
+ texts,
40
+ padding="max_length",
41
+ truncation=True,
42
+ max_length=self.max_length,
43
+ return_tensors="pt",
44
+ )
45
+ text_input_ids = text_inputs.input_ids
46
+ # split into max length Clip can handle
47
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
48
+ text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
49
+ elif self.name == "bert":
50
+ text_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
51
+
52
+ if self.name == "clip":
53
+ # (batch_Size, text_encoded_dim)
54
+ text_embeddings = self.text_model.get_text_features(
55
+ text_input_ids.to(self.text_model.device))
56
+ # (batch_Size, 1, text_encoded_dim)
57
+ text_embeddings = text_embeddings.unsqueeze(1)
58
+ elif self.name == "clip_hidden":
59
+ # (batch_Size, seq_length , text_encoded_dim)
60
+ text_embeddings = self.text_model.text_model(
61
+ text_input_ids.to(self.text_model.device)).last_hidden_state
62
+ elif self.name == "bert":
63
+ # (batch_Size, seq_length , text_encoded_dim)
64
+ text_embeddings = self.text_model(
65
+ **text_inputs.to(self.text_model.device)).last_hidden_state
66
+ elif self.name == 't5':
67
+ text_embeddings = self.text_model.encode(texts, show_progress_bar=False, convert_to_tensor=True, batch_size=len(texts))
68
+ text_embeddings = text_embeddings.unsqueeze(1)
69
+ else:
70
+ raise NotImplementedError(f"Model {self.name} not implemented")
71
+
72
+ return text_embeddings
mld/models/architectures/mld_denoiser.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from mld.models.operator.embeddings import TimestepEmbedding, Timesteps
7
+ from mld.models.operator.attention import (SkipTransformerEncoder,
8
+ SkipTransformerDecoder,
9
+ TransformerDecoder,
10
+ TransformerDecoderLayer,
11
+ TransformerEncoder,
12
+ TransformerEncoderLayer)
13
+ from mld.models.operator.moe import MoeTransformerEncoderLayer, MoeTransformerDecoderLayer
14
+ from mld.models.operator.utils import get_clones, get_activation_fn, zero_module
15
+ from mld.models.operator.position_encoding import build_position_encoding
16
+
17
+
18
+ def load_balancing_loss_func(router_logits: tuple, num_experts: int = 4, topk: int = 2):
19
+ router_logits = torch.cat(router_logits, dim=0)
20
+ routing_weights = torch.nn.functional.softmax(router_logits, dim=-1)
21
+ _, selected_experts = torch.topk(routing_weights, topk, dim=-1)
22
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
23
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
24
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
25
+ overall_loss = num_experts * torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
26
+ return overall_loss
27
+
28
+
29
+ class MldDenoiser(nn.Module):
30
+
31
+ def __init__(self,
32
+ latent_dim: list = [1, 256],
33
+ hidden_dim: Optional[int] = None,
34
+ text_dim: int = 768,
35
+ time_dim: int = 768,
36
+ ff_size: int = 1024,
37
+ num_layers: int = 9,
38
+ num_heads: int = 4,
39
+ dropout: float = 0.1,
40
+ normalize_before: bool = False,
41
+ norm_eps: float = 1e-5,
42
+ activation: str = "gelu",
43
+ norm_post: bool = True,
44
+ activation_post: Optional[str] = None,
45
+ flip_sin_to_cos: bool = True,
46
+ freq_shift: float = 0,
47
+ time_act_fn: str = 'silu',
48
+ time_post_act_fn: Optional[str] = None,
49
+ position_embedding: str = "learned",
50
+ arch: str = "trans_enc",
51
+ add_mem_pos: bool = True,
52
+ force_pre_post_proj: bool = False,
53
+ text_act_fn: str = 'relu',
54
+ time_cond_proj_dim: Optional[int] = None,
55
+ zero_init_cond: bool = True,
56
+ is_controlnet: bool = False,
57
+ controlnet_embed_dim: Optional[int] = None,
58
+ controlnet_act_fn: str = 'silu',
59
+ moe: bool = False,
60
+ moe_num_experts: int = 4,
61
+ moe_topk: int = 2,
62
+ moe_loss_weight: float = 1e-2,
63
+ moe_jitter_noise: Optional[float] = None
64
+ ) -> None:
65
+ super(MldDenoiser, self).__init__()
66
+
67
+ self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim
68
+ add_pre_post_proj = force_pre_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1])
69
+ self.latent_pre = nn.Linear(latent_dim[-1], self.latent_dim) if add_pre_post_proj else nn.Identity()
70
+ self.latent_post = nn.Linear(self.latent_dim, latent_dim[-1]) if add_pre_post_proj else nn.Identity()
71
+
72
+ self.arch = arch
73
+ self.time_cond_proj_dim = time_cond_proj_dim
74
+
75
+ self.moe_num_experts = moe_num_experts
76
+ self.moe_topk = moe_topk
77
+ self.moe_loss_weight = moe_loss_weight
78
+
79
+ self.time_proj = Timesteps(time_dim, flip_sin_to_cos, freq_shift)
80
+ self.time_embedding = TimestepEmbedding(time_dim, self.latent_dim, time_act_fn, post_act_fn=time_post_act_fn,
81
+ cond_proj_dim=time_cond_proj_dim, zero_init_cond=zero_init_cond)
82
+ self.emb_proj = nn.Sequential(get_activation_fn(text_act_fn), nn.Linear(text_dim, self.latent_dim))
83
+
84
+ self.query_pos = build_position_encoding(self.latent_dim, position_embedding=position_embedding)
85
+ if self.arch == "trans_enc":
86
+ if moe:
87
+ encoder_layer = MoeTransformerEncoderLayer(
88
+ self.latent_dim, num_heads, moe_num_experts, moe_topk, ff_size,
89
+ dropout, activation, normalize_before, norm_eps, moe_jitter_noise)
90
+ else:
91
+ encoder_layer = TransformerEncoderLayer(
92
+ self.latent_dim, num_heads, ff_size, dropout,
93
+ activation, normalize_before, norm_eps)
94
+
95
+ encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post and not is_controlnet else None
96
+ self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post,
97
+ is_controlnet=is_controlnet, is_moe=moe)
98
+
99
+ elif self.arch == 'trans_dec':
100
+ if add_mem_pos:
101
+ self.mem_pos = build_position_encoding(self.latent_dim, position_embedding=position_embedding)
102
+ else:
103
+ self.mem_pos = None
104
+ if moe:
105
+ decoder_layer = MoeTransformerDecoderLayer(
106
+ self.latent_dim, num_heads, moe_num_experts, moe_topk, ff_size,
107
+ dropout, activation, normalize_before, norm_eps, moe_jitter_noise)
108
+ else:
109
+ decoder_layer = TransformerDecoderLayer(
110
+ self.latent_dim, num_heads, ff_size, dropout,
111
+ activation, normalize_before, norm_eps)
112
+
113
+ decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post and not is_controlnet else None
114
+ self.decoder = SkipTransformerDecoder(decoder_layer, num_layers, decoder_norm, activation_post,
115
+ is_controlnet=is_controlnet, is_moe=moe)
116
+ else:
117
+ raise ValueError(f"Not supported architecture: {self.arch}!")
118
+
119
+ self.is_controlnet = is_controlnet
120
+ if self.is_controlnet:
121
+ embed_dim = controlnet_embed_dim if controlnet_embed_dim is not None else self.latent_dim
122
+ modules = [
123
+ nn.Linear(latent_dim[-1], embed_dim),
124
+ get_activation_fn(controlnet_act_fn) if controlnet_act_fn else None,
125
+ nn.Linear(embed_dim, embed_dim),
126
+ get_activation_fn(controlnet_act_fn) if controlnet_act_fn else None,
127
+ zero_module(nn.Linear(embed_dim, latent_dim[-1]))
128
+ ]
129
+ self.controlnet_cond_embedding = nn.Sequential(*[m for m in modules if m is not None])
130
+
131
+ self.controlnet_down_mid_blocks = nn.ModuleList([
132
+ zero_module(nn.Linear(self.latent_dim, self.latent_dim)) for _ in range(num_layers)])
133
+
134
+ def forward(self,
135
+ sample: torch.Tensor,
136
+ timestep: torch.Tensor,
137
+ encoder_hidden_states: torch.Tensor,
138
+ timestep_cond: Optional[torch.Tensor] = None,
139
+ controlnet_cond: Optional[torch.Tensor] = None,
140
+ controlnet_residuals: Optional[list[torch.Tensor]] = None
141
+ ) -> tuple:
142
+
143
+ # 0. check if controlnet
144
+ if self.is_controlnet:
145
+ sample = sample + self.controlnet_cond_embedding(controlnet_cond)
146
+
147
+ # 1. dimension matching (pre)
148
+ sample = sample.permute(1, 0, 2)
149
+ sample = self.latent_pre(sample)
150
+
151
+ # 2. time_embedding
152
+ timesteps = timestep.expand(sample.shape[1]).clone()
153
+ time_emb = self.time_proj(timesteps)
154
+ time_emb = time_emb.to(dtype=sample.dtype)
155
+ # [1, bs, latent_dim] <= [bs, latent_dim]
156
+ time_emb = self.time_embedding(time_emb, timestep_cond).unsqueeze(0)
157
+
158
+ # 3. condition + time embedding
159
+ # text_emb [seq_len, batch_size, text_dim] <= [batch_size, seq_len, text_dim]
160
+ encoder_hidden_states = encoder_hidden_states.permute(1, 0, 2)
161
+ # text embedding projection
162
+ text_emb_latent = self.emb_proj(encoder_hidden_states)
163
+ emb_latent = torch.cat((time_emb, text_emb_latent), 0)
164
+
165
+ # 4. transformer
166
+ if self.arch == "trans_enc":
167
+ xseq = torch.cat((sample, emb_latent), axis=0)
168
+ xseq = self.query_pos(xseq)
169
+ tokens, intermediates, router_logits = self.encoder(xseq, controlnet_residuals=controlnet_residuals)
170
+ elif self.arch == 'trans_dec':
171
+ sample = self.query_pos(sample)
172
+ if self.mem_pos:
173
+ emb_latent = self.mem_pos(emb_latent)
174
+ tokens, intermediates, router_logits = self.decoder(sample, emb_latent,
175
+ controlnet_residuals=controlnet_residuals)
176
+ else:
177
+ raise TypeError(f"{self.arch} is not supported")
178
+
179
+ router_loss = None
180
+ if router_logits is not None:
181
+ router_loss = load_balancing_loss_func(router_logits, self.moe_num_experts, self.moe_topk)
182
+ router_loss = self.moe_loss_weight * router_loss
183
+
184
+ if self.is_controlnet:
185
+ control_res_samples = []
186
+ for res, block in zip(intermediates, self.controlnet_down_mid_blocks):
187
+ r = block(res)
188
+ control_res_samples.append(r)
189
+ return control_res_samples, router_loss
190
+ elif self.arch == "trans_enc":
191
+ sample = tokens[:sample.shape[0]]
192
+ elif self.arch == 'trans_dec':
193
+ sample = tokens
194
+ else:
195
+ raise TypeError(f"{self.arch} is not supported")
196
+
197
+ # 5. dimension matching (post)
198
+ sample = self.latent_post(sample)
199
+ sample = sample.permute(1, 0, 2)
200
+ return sample, router_loss
mld/models/architectures/mld_traj_encoder.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from mld.models.operator.attention import SkipTransformerEncoder, TransformerEncoderLayer
7
+ from mld.models.operator.position_encoding import build_position_encoding
8
+
9
+
10
+ class MldTrajEncoder(nn.Module):
11
+
12
+ def __init__(self,
13
+ nfeats: int,
14
+ latent_dim: list = [1, 256],
15
+ hidden_dim: Optional[int] = None,
16
+ force_post_proj: bool = False,
17
+ ff_size: int = 1024,
18
+ num_layers: int = 9,
19
+ num_heads: int = 4,
20
+ dropout: float = 0.1,
21
+ normalize_before: bool = False,
22
+ norm_eps: float = 1e-5,
23
+ activation: str = "gelu",
24
+ norm_post: bool = True,
25
+ activation_post: Optional[str] = None,
26
+ position_embedding: str = "learned") -> None:
27
+ super(MldTrajEncoder, self).__init__()
28
+
29
+ self.latent_size = latent_dim[0]
30
+ self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim
31
+ add_post_proj = force_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1])
32
+ self.latent_proj = nn.Linear(self.latent_dim, latent_dim[-1]) if add_post_proj else nn.Identity()
33
+
34
+ self.skel_embedding = nn.Linear(nfeats * 3, self.latent_dim)
35
+
36
+ self.query_pos_encoder = build_position_encoding(
37
+ self.latent_dim, position_embedding=position_embedding)
38
+
39
+ encoder_layer = TransformerEncoderLayer(
40
+ self.latent_dim,
41
+ num_heads,
42
+ ff_size,
43
+ dropout,
44
+ activation,
45
+ normalize_before,
46
+ norm_eps
47
+ )
48
+ encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
49
+ self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post)
50
+ self.global_motion_token = nn.Parameter(torch.randn(self.latent_size, self.latent_dim))
51
+
52
+ def forward(self, features: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
53
+ bs, nframes, nfeats = features.shape
54
+ x = self.skel_embedding(features)
55
+ x = x.permute(1, 0, 2)
56
+ dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))
57
+ dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device)
58
+ aug_mask = torch.cat((dist_masks, mask), 1)
59
+ xseq = torch.cat((dist, x), 0)
60
+ xseq = self.query_pos_encoder(xseq)
61
+ global_token = self.encoder(xseq, src_key_padding_mask=~aug_mask)[0][:dist.shape[0]]
62
+ global_token = self.latent_proj(global_token)
63
+ global_token = global_token.permute(1, 0, 2)
64
+ return global_token
mld/models/architectures/mld_vae.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.distributions.distribution import Distribution
6
+
7
+ from mld.models.operator.attention import (
8
+ SkipTransformerEncoder,
9
+ SkipTransformerDecoder,
10
+ TransformerDecoder,
11
+ TransformerDecoderLayer,
12
+ TransformerEncoder,
13
+ TransformerEncoderLayer
14
+ )
15
+ from mld.models.operator.position_encoding import build_position_encoding
16
+
17
+
18
+ class MldVae(nn.Module):
19
+
20
+ def __init__(self,
21
+ nfeats: int,
22
+ latent_dim: list = [1, 256],
23
+ hidden_dim: Optional[int] = None,
24
+ force_pre_post_proj: bool = False,
25
+ ff_size: int = 1024,
26
+ num_layers: int = 9,
27
+ num_heads: int = 4,
28
+ dropout: float = 0.1,
29
+ arch: str = "encoder_decoder",
30
+ normalize_before: bool = False,
31
+ norm_eps: float = 1e-5,
32
+ activation: str = "gelu",
33
+ norm_post: bool = True,
34
+ activation_post: Optional[str] = None,
35
+ position_embedding: str = "learned") -> None:
36
+ super(MldVae, self).__init__()
37
+
38
+ self.latent_size = latent_dim[0]
39
+ self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim
40
+ add_pre_post_proj = force_pre_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1])
41
+ self.latent_pre = nn.Linear(self.latent_dim, latent_dim[-1]) if add_pre_post_proj else nn.Identity()
42
+ self.latent_post = nn.Linear(latent_dim[-1], self.latent_dim) if add_pre_post_proj else nn.Identity()
43
+
44
+ self.arch = arch
45
+
46
+ self.query_pos_encoder = build_position_encoding(
47
+ self.latent_dim, position_embedding=position_embedding)
48
+
49
+ encoder_layer = TransformerEncoderLayer(
50
+ self.latent_dim,
51
+ num_heads,
52
+ ff_size,
53
+ dropout,
54
+ activation,
55
+ normalize_before,
56
+ norm_eps
57
+ )
58
+ encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
59
+ self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post)
60
+
61
+ if self.arch == "all_encoder":
62
+ decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
63
+ self.decoder = SkipTransformerEncoder(encoder_layer, num_layers, decoder_norm, activation_post)
64
+ elif self.arch == 'encoder_decoder':
65
+ self.query_pos_decoder = build_position_encoding(
66
+ self.latent_dim, position_embedding=position_embedding)
67
+
68
+ decoder_layer = TransformerDecoderLayer(
69
+ self.latent_dim,
70
+ num_heads,
71
+ ff_size,
72
+ dropout,
73
+ activation,
74
+ normalize_before,
75
+ norm_eps
76
+ )
77
+ decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
78
+ self.decoder = SkipTransformerDecoder(decoder_layer, num_layers, decoder_norm, activation_post)
79
+ else:
80
+ raise ValueError(f"Not support architecture: {self.arch}!")
81
+
82
+ self.global_motion_token = nn.Parameter(torch.randn(self.latent_size * 2, self.latent_dim))
83
+ self.skel_embedding = nn.Linear(nfeats, self.latent_dim)
84
+ self.final_layer = nn.Linear(self.latent_dim, nfeats)
85
+
86
+ def forward(self, features: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, Distribution]:
87
+ z, dist = self.encode(features, mask)
88
+ feats_rst = self.decode(z, mask)
89
+ return feats_rst, z, dist
90
+
91
+ def encode(self, features: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, Distribution]:
92
+ bs, nframes, nfeats = features.shape
93
+ x = self.skel_embedding(features)
94
+ x = x.permute(1, 0, 2)
95
+ dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))
96
+ dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device)
97
+ aug_mask = torch.cat((dist_masks, mask), 1)
98
+ xseq = torch.cat((dist, x), 0)
99
+
100
+ xseq = self.query_pos_encoder(xseq)
101
+ dist = self.encoder(xseq, src_key_padding_mask=~aug_mask)[0][:dist.shape[0]]
102
+ dist = self.latent_pre(dist)
103
+
104
+ mu = dist[0:self.latent_size, ...]
105
+ logvar = dist[self.latent_size:, ...]
106
+
107
+ std = logvar.exp().pow(0.5)
108
+ dist = torch.distributions.Normal(mu, std)
109
+ latent = dist.rsample()
110
+ # [latent_dim[0], batch_size, latent_dim] -> [batch_size, latent_dim[0], latent_dim[1]]
111
+ latent = latent.permute(1, 0, 2)
112
+ return latent, dist
113
+
114
+ def decode(self, z: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
115
+ # [batch_size, latent_dim[0], latent_dim[1]] -> [latent_dim[0], batch_size, latent_dim[1]]
116
+ z = self.latent_post(z)
117
+ z = z.permute(1, 0, 2)
118
+ bs, nframes = mask.shape
119
+ queries = torch.zeros(nframes, bs, self.latent_dim, device=z.device)
120
+
121
+ if self.arch == "all_encoder":
122
+ xseq = torch.cat((z, queries), axis=0)
123
+ z_mask = torch.ones((bs, self.latent_size), dtype=torch.bool, device=z.device)
124
+ aug_mask = torch.cat((z_mask, mask), axis=1)
125
+ xseq = self.query_pos_decoder(xseq)
126
+ output = self.decoder(xseq, src_key_padding_mask=~aug_mask)[0][z.shape[0]:]
127
+ elif self.arch == "encoder_decoder":
128
+ queries = self.query_pos_decoder(queries)
129
+ output = self.decoder(tgt=queries, memory=z, tgt_key_padding_mask=~mask)[0]
130
+ else:
131
+ raise ValueError(f"Not support architecture: {self.arch}!")
132
+
133
+ output = self.final_layer(output)
134
+ output[~mask.T] = 0
135
+ feats = output.permute(1, 0, 2)
136
+ return feats