|
""" |
|
input: a JSON file listing clips; each entry gives the video directory ("video_path"), the motion directory ("motion_path") and the "video_id"

output: a directed igraph graph with one node per frame; node attributes are idx, name (the video_id), motion, position, velocity, axis_angle, trans, video (the decoded frame), previous, next, frame and fps

preprocess:

1. Put the videos of a single speaker in one folder, e.g.

-- video_a.mp4

-- video_b.mp4

then run process_video.py to extract frames and audio.
|
""" |
|
|
|
import os |
|
import smplx |
|
import torch |
|
import numpy as np |
|
import cv2 |
|
import librosa |
|
import igraph |
|
import json |
|
import utils.rotation_conversions as rc |
|
from moviepy.editor import VideoClip, AudioFileClip, VideoFileClip |
|
from tqdm import tqdm |
|
import imageio |
|
import tempfile |
|
import argparse |
|
|
|
|
|
def get_motion_reps_tensor(motion_tensor, smplx_model, pose_fps=30, device='cuda'):
    """Batched, all-torch variant of ``get_motion_reps``.

    Args:
        motion_tensor: tensor of shape ``(bs, n, 165)``; it is reshaped to
            ``(bs * n, 165)`` and viewed as 55 joints x 3 axis-angle values.
        smplx_model: SMPL-X layer. Shape, expression, translation, root
            orientation, jaw and eye poses are all fed as zeros, so only the
            body (values 3:66) and hand (75:120, 120:165) rotations matter.
        pose_fps: frame rate used to turn frame differences into velocities.
        device: device on which the input and the zero tensors are placed.

    Returns:
        dict of tensors: "position" and "velocity" ``(bs, n, 55, 3)``,
        "rotation" ``(bs, n, 55, 6)``, "axis_angle" (the float input on
        ``device``), "angular_velocity" ``(bs, n, 55, 3)`` and "rep15d"
        ``(bs, n, 825)`` (the four per-joint features concatenated).
    """
    batch, frames, _ = motion_tensor.shape
    total = batch * frames

    def zeros(width):
        return torch.zeros(total, width, device=device)

    motion_tensor = motion_tensor.float().to(device)
    flat_poses = motion_tensor.reshape(total, 165)

    smplx_output = smplx_model(
        betas=zeros(300),
        transl=zeros(3),
        expression=zeros(100),
        jaw_pose=zeros(3),
        global_orient=zeros(3),
        body_pose=flat_poses[:, 3:66],
        left_hand_pose=flat_poses[:, 75:120],
        right_hand_pose=flat_poses[:, 120:165],
        return_joints=True,
        leye_pose=zeros(3),
        reye_pose=zeros(3),
    )

    # Keep only the first 55 of the 127 joints returned by the model.
    joints = smplx_output['joints'].reshape(batch, frames, 127, 3)[:, :, :55, :]
    dt = 1 / pose_fps

    def time_derivative(seq):
        # Forward difference on the first frame, central differences inside,
        # backward difference on the last frame; time runs along dim 1.
        head = (seq[:, 1:2] - seq[:, 0:1]) / dt
        body = (seq[:, 2:] - seq[:, :-2]) / (2 * dt)
        tail = (seq[:, -1:] - seq[:, -2:-1]) / dt
        return torch.cat([head, body, tail], dim=1)

    joint_velocity = time_derivative(joints)

    rotation_matrices = rc.axis_angle_to_matrix(motion_tensor.reshape(batch, frames, 55, 3))
    rotation_6d = rc.matrix_to_rotation_6d(rotation_matrices).reshape(batch, frames, 55, 6)

    angular_velocity = time_derivative(motion_tensor).reshape(batch, frames, 55, 3)

    features = [joints, joint_velocity, rotation_6d, angular_velocity]
    rep15d = torch.cat(features, dim=3).reshape(batch, frames, 55 * 15)

    return {
        "position": joints,
        "velocity": joint_velocity,
        "rotation": rotation_6d,
        "axis_angle": motion_tensor,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
    }
|
|
|
|
|
|
|
def _finite_difference_np(sequence, dt):
    """Time derivative of ``sequence`` along axis 0 with frame spacing ``dt``.

    Uses a forward difference for the first frame, central differences for
    interior frames and a backward difference for the last frame, so the
    result has the same length as the input. A one-frame sequence has no
    neighbour to difference against and yields zeros.
    """
    if sequence.shape[0] < 2:
        return np.zeros_like(sequence)
    first = (sequence[1:2] - sequence[0:1]) / dt
    middle = (sequence[2:] - sequence[:-2]) / (2 * dt)
    last = (sequence[-1:] - sequence[-2:-1]) / dt
    return np.concatenate([first, middle, last], axis=0)


def get_motion_reps(motion, smplx_model, pose_fps=30, device=None):
    """Compute per-frame joint features for one SMPL-X pose sequence.

    Args:
        motion: mapping (e.g. the ``NpzFile`` returned by ``np.load``) with
            "poses", an array reshaped to ``(n, 165)`` (55 joints x 3
            axis-angle values per frame), and "trans", passed through.
        smplx_model: SMPL-X layer. Shape, expression, translation, root
            orientation, jaw and eye poses are fed as zeros, so only body and
            hand rotations influence the joints.
        pose_fps: frame rate used to turn frame differences into velocities.
        device: torch device for the forward pass. When None, the device of
            ``smplx_model``'s first parameter (or buffer) is used, falling
            back to CPU.

    Returns:
        dict of numpy arrays: "position" ``(n, 55, 3)`` joint positions,
        "velocity" ``(n, 55, 3)``, "rotation" ``(n, 55, 6)`` 6D rotations,
        "axis_angle" (the input poses), "angular_velocity" ``(n, 55, 3)``
        (finite differences of the axis-angle values), "rep15d" ``(n, 825)``
        (the four per-joint features concatenated) and "trans".
    """
    if device is None:
        # Follow the model instead of a module-level ``device`` global, which
        # only exists when this file is run as a script.
        first_tensor = next(smplx_model.parameters(), None)
        if first_tensor is None:
            first_tensor = next(smplx_model.buffers(), None)
        device = first_tensor.device if first_tensor is not None else torch.device("cpu")

    # NpzFile loads members lazily on every key access; read the poses once.
    poses = motion["poses"]
    n = poses.shape[0]
    dt = 1 / pose_fps

    pose_tensor = torch.from_numpy(poses).float().to(device).unsqueeze(0)
    flat_poses = pose_tensor.reshape(n, 165)

    def zeros(width):
        return torch.zeros(n, width, device=device)

    # Pure feature extraction: no autograd graph is needed.
    with torch.no_grad():
        output = smplx_model(
            betas=zeros(300),
            transl=zeros(3),
            expression=zeros(100),
            jaw_pose=zeros(3),
            global_orient=zeros(3),
            body_pose=flat_poses[:, 3:21 * 3 + 3],
            left_hand_pose=flat_poses[:, 25 * 3:40 * 3],
            right_hand_pose=flat_poses[:, 40 * 3:55 * 3],
            return_joints=True,
            leye_pose=zeros(3),
            reye_pose=zeros(3),
        )
        # Keep only the first 55 of the 127 joints returned by the model.
        joints = output["joints"].detach().cpu().numpy().reshape(n, 127, 3)[:, :55, :]
        rot_matrices = rc.axis_angle_to_matrix(pose_tensor.reshape(1, n, 55, 3))[0]
        rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(n, 55, 6).cpu().numpy()

    velocity = _finite_difference_np(joints, dt)
    angular_velocity = _finite_difference_np(poses, dt).reshape(n, 55, 3)

    rep15d = np.concatenate(
        [joints, velocity, rot6d, angular_velocity],
        axis=2,
    ).reshape(n, 55 * 15)

    return {
        "position": joints,
        "velocity": velocity,
        "rotation": rot6d,
        "axis_angle": poses,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
        "trans": motion["trans"],
    }
|
|
|
def create_graph(json_path, smplx_model, fps=30):
    """Build a directed igraph with one vertex per aligned video/motion frame.

    ``json_path`` holds a JSON list; each entry needs "video_path",
    "motion_path" and "video_id". The clip's video is read from
    ``<video_path>/<video_id>.mp4`` and its motion from
    ``<motion_path>/<video_id>.npz`` (processed by ``get_motion_reps``). Each
    clip is truncated to the shorter of its video and motion lengths.

    Vertex attributes: idx (global index), name (video_id), motion (the
    clip's full motion-rep dict, shared by all its vertices), position,
    velocity, axis_angle, trans, video (decoded frame), previous / next
    (global index of the adjacent frame of the same clip, or -1 at a clip
    boundary), frame (index within the clip) and fps.

    No edges are added here; see ``create_edges``.

    Args:
        json_path: path of the JSON clip list.
        smplx_model: SMPL-X layer forwarded to ``get_motion_reps``.
        fps: frame rate stored on each vertex and used for motion velocities.

    Returns:
        The new ``igraph.Graph``.
    """
    from itertools import islice

    with open(json_path, "r") as f:
        data_meta = json.load(f)

    graph = igraph.Graph(directed=True)
    global_i = 0
    for data_item in data_meta:
        video_id = data_item["video_id"]
        video_path = os.path.join(data_item["video_path"], video_id + ".mp4")
        motion_path = os.path.join(data_item["motion_path"], video_id + ".npz")

        # NOTE(review): allow_pickle=True lets a crafted .npz run code on load;
        # only point this at trusted motion files.
        with np.load(motion_path, allow_pickle=True) as motion:
            motion_reps = get_motion_reps(motion, smplx_model, pose_fps=fps)

        position = motion_reps["position"]
        velocity = motion_reps["velocity"]
        trans = motion_reps["trans"]
        axis_angle = motion_reps["axis_angle"]
        num_motion_frames = position.shape[0]

        # Decode at most as many frames as there are motion frames; any extra
        # frames would be dropped by the truncation below anyway.
        reader = imageio.get_reader(video_path)
        try:
            video_frames = np.array(list(islice(reader, num_motion_frames)))
        finally:
            reader.close()

        min_frames = min(len(video_frames), num_motion_frames)

        for i in tqdm(range(min_frames)):
            # previous/next are computed independently so a one-frame clip gets
            # -1 on both sides instead of pointing into the following clip.
            previous = global_i - 1 if i > 0 else -1
            next_node = global_i + 1 if i < min_frames - 1 else -1
            graph.add_vertex(
                idx=global_i,
                name=video_id,
                motion=motion_reps,
                position=position[i],
                velocity=velocity[i],
                axis_angle=axis_angle[i],
                trans=trans[i],
                video=video_frames[i],
                previous=previous,
                next=next_node,
                frame=i,
                fps=fps,
            )
            global_i += 1
    return graph
|
|
|
def create_edges(graph):
    """Add continuity and similarity edges between frame vertices, in place.

    For each vertex i:

    * Thresholds are the mean distances to up to eight temporal neighbours
      (offsets -4..-1 and 1..4) that belong to the same clip: per-joint
      position distance, per-joint velocity distance, and translation
      distance. A vertex with no such neighbour gets no outgoing edges.
    * Edges i -> previous and i -> next are added with ``is_continue=1``.
    * An edge i -> j (any other vertex) is added with ``is_continue=0`` when
      the translation distance is below the translation threshold, at least
      45 joints are below their position threshold and at least 45 joints
      are below their velocity threshold.

    Edges are added in order of i, then j (ascending). Degree statistics are
    printed afterwards.

    Args:
        graph: graph from ``create_graph``; uses the vertex attributes
            position / velocity (55, 3), trans (assumed 1-D, e.g. (3,) --
            confirm against the motion files), previous, next and frame.

    Returns:
        The same graph, with edges added.
    """
    adaptive_length = [-4, -3, -2, -1, 1, 2, 3, 4]
    num_nodes = len(graph.vs)
    if num_nodes == 0:
        print(f"nodes: {len(graph.vs)}, edges: {len(graph.es)}")
        return graph

    # Gather attributes once so each vertex is compared against all others
    # with a few vectorized numpy operations.
    positions = np.stack(graph.vs['position'])
    velocities = np.stack(graph.vs['velocity'])
    translations = np.stack(graph.vs['trans'])
    frames = graph.vs['frame']
    previous_ids = graph.vs['previous']
    next_ids = graph.vs['next']

    new_edges = []
    new_flags = []
    for i in range(num_nodes):
        current_position = positions[i]
        current_velocity = velocities[i]
        current_trans = translations[i]

        avg_position = np.zeros(current_position.shape[0])
        avg_velocity = np.zeros(current_position.shape[0])
        avg_trans = 0
        count = 0
        for node_offset in adaptive_length:
            idx = i + node_offset
            if idx < 0 or idx >= num_nodes:
                continue
            # Same-clip test: frame numbers are consecutive within a clip and
            # restart at 0 in the next one, so the vertex node_offset slots
            # away is in the same clip exactly when its frame number differs
            # by node_offset. (Checking only previous/next == -1 missed
            # offsets of 2..4 that cross a clip boundary.)
            if frames[idx] - frames[i] != node_offset:
                continue
            avg_position += np.linalg.norm(current_position - positions[idx], axis=1)
            avg_velocity += np.linalg.norm(current_velocity - velocities[idx], axis=1)
            avg_trans += np.linalg.norm(current_trans - translations[idx], axis=0)
            count += 1

        if count == 0:
            continue
        threshold_position = avg_position / count
        threshold_velocity = avg_velocity / count
        threshold_trans = avg_trans / count

        position_hits = np.sum(
            np.linalg.norm(positions - current_position, axis=2) < threshold_position, axis=1
        )
        velocity_hits = np.sum(
            np.linalg.norm(velocities - current_velocity, axis=2) < threshold_velocity, axis=1
        )
        trans_distance = np.linalg.norm(translations - current_trans, axis=1)
        similar = (trans_distance < threshold_trans) & (position_hits >= 45) & (velocity_hits >= 45)

        continuity = {j for j in (previous_ids[i], next_ids[i]) if 0 <= j < num_nodes and j != i}
        targets = set(np.flatnonzero(similar).tolist()) | continuity
        targets.discard(i)
        for j in sorted(targets):
            new_edges.append((i, j))
            new_flags.append(1 if j in continuity else 0)

    # One bulk insertion instead of one add_edge call per edge.
    if new_edges:
        graph.add_edges(new_edges, attributes={'is_continue': new_flags})

    print(f"nodes: {len(graph.vs)}, edges: {len(graph.es)}")
    in_degrees = graph.indegree()
    out_degrees = graph.outdegree()
    avg_in_degree = sum(in_degrees) / len(in_degrees)
    avg_out_degree = sum(out_degrees) / len(out_degrees)
    print(f"Average In-degree: {avg_in_degree}")
    print(f"Average Out-degree: {avg_out_degree}")
    print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
    print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")

    return graph
|
|
|
def random_walk(graph, walk_length, start_node=None):
    """Sample a random walk of up to ``walk_length`` steps along out-edges.

    At each step an out-neighbour is chosen uniformly with ``np.random``; the
    walk stops early at a vertex without outgoing edges.

    Args:
        graph: graph whose edges carry an ``is_continue`` attribute.
        walk_length: maximum number of steps.
        start_node: starting vertex, either an igraph vertex or a vertex
            index; chosen uniformly at random when None.

    Returns:
        ``(walk, is_continue)``. ``walk`` holds the visited vertices,
        including the start, so it has at most ``walk_length + 1`` entries.
        ``is_continue[k]`` is the ``is_continue`` value of the edge used to
        reach ``walk[k]``; it is 1 for the start vertex.

    Raises:
        ValueError: if ``start_node`` is None and the graph has no vertices.
    """
    if start_node is None:
        num_vertices = graph.vcount()
        if num_vertices == 0:
            raise ValueError("cannot start a random walk on an empty graph")
        # Pick a uniform index rather than letting numpy build an object array
        # of every vertex.
        start_node = graph.vs[int(np.random.randint(0, num_vertices))]
    elif isinstance(start_node, (int, np.integer)):
        start_node = graph.vs[int(start_node)]

    walk = [start_node]
    is_continue = [1]
    for _ in range(walk_length):
        current_index = walk[-1].index
        neighbor_indices = graph.neighbors(current_index, mode='OUT')
        if not neighbor_indices:
            break
        next_idx = int(np.random.choice(neighbor_indices))
        edge_id = graph.get_eid(current_index, next_idx)
        walk.append(graph.vs[next_idx])
        is_continue.append(graph.es[edge_id]['is_continue'])
    return walk, is_continue
|
|
|
import subprocess |
|
def path_visualization(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
    """Render the ``video`` frames of the vertices in ``path`` to ``save_path``.

    Args:
        graph: graph whose first vertex's ``fps`` attribute sets the frame rate.
        path: sequence of vertices (e.g. the walk from ``random_walk``).
        is_continue: per-step continuity flags; only used for the printed
            discontinuity ratio.
        save_path: output video file, encoded with libx264.
        verbose_continue: print the fraction of discontinuous steps.
        audio_path: optional audio file, muxed in with ffmpeg. Audio longer
            than the video is trimmed to the video's duration first.
        return_motion: also return the path's ``axis_angle`` values.

    Returns:
        ``np.stack`` of the vertices' ``axis_angle`` when ``return_motion`` is
        true, otherwise None.

    Raises:
        ValueError: if ``path`` is empty.
        subprocess.CalledProcessError: if the ffmpeg mux step fails.
    """
    if not path:
        raise ValueError("path is empty; nothing to render")

    all_frames = [node['video'] for node in path]
    if verbose_continue:
        average_dis_continue = 1 - sum(is_continue) / len(is_continue)
        print("average_dis_continue:", average_dis_continue)

    fps = graph.vs[0]['fps']
    duration = len(all_frames) / fps

    def make_frame(t):
        idx = min(int(t * fps), len(all_frames) - 1)
        return all_frames[idx]

    video_clip = VideoClip(make_frame, duration=duration)
    try:
        if audio_path is None:
            # Nothing to mux: encode straight to the requested file.
            video_clip.write_videofile(save_path, codec='libx264', fps=fps, audio=False)
        else:
            # Intermediates live in a private temp dir so concurrent runs do
            # not clobber each other and nothing leaks if ffmpeg fails.
            with tempfile.TemporaryDirectory(prefix='path_visualization_') as tmp_dir:
                video_only_path = os.path.join(tmp_dir, 'video_only.mp4')
                video_clip.write_videofile(video_only_path, codec='libx264', fps=fps, audio=False)

                audio_clip = AudioFileClip(audio_path)
                try:
                    audio_input = audio_path
                    if audio_clip.duration > video_clip.duration:
                        audio_input = os.path.join(tmp_dir, 'trimmed_audio.aac')
                        audio_clip.subclip(0, video_clip.duration).write_audiofile(audio_input)
                finally:
                    audio_clip.close()

                ffmpeg_command = [
                    'ffmpeg', '-y',
                    '-i', video_only_path,
                    '-i', audio_input,
                    '-c:v', 'copy',
                    '-c:a', 'aac',
                    '-strict', 'experimental',
                    save_path,
                ]
                subprocess.check_call(ffmpeg_command)
    finally:
        video_clip.close()

    if return_motion:
        all_motion = [node['axis_angle'] for node in path]
        return np.stack(all_motion, 0)
|
|
|
|
|
|
|
def generate_transition_video(frame_start_path, frame_end_path, output_video_path,
                              model_path="./frame-interpolation-pytorch/film_net_fp32.pt",
                              inference_script="./frame-interpolation-pytorch/inference.py",
                              num_frames=3, fps=30, use_gpu=True):
    """Interpolate between two image files with an external inference script.

    Runs ``inference_script`` in a child process with the *current* Python
    interpreter. The arguments passed are ``model_path``, the two frame
    paths, ``--save_path output_video_path``, ``--gpu`` (when ``use_gpu``),
    ``--frames num_frames`` and ``--fps fps``. What the script writes is up to
    the script.

    Failures are printed rather than raised, so callers can skip a single
    transition instead of aborting a whole render.

    Returns:
        True if the script ran and exited with status 0, False otherwise.
    """
    import subprocess
    import sys

    # sys.executable instead of "python": PATH may resolve to a different
    # interpreter/environment, or to nothing at all.
    command = [
        sys.executable,
        inference_script,
        model_path,
        frame_start_path,
        frame_end_path,
        "--save_path", output_video_path,
    ]
    if use_gpu:
        command.append("--gpu")
    command += ["--frames", str(num_frames), "--fps", str(fps)]

    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"Error occurred while generating transition video: {e}")
        return False
    print(f"Generated transition video saved at {output_video_path}")
    return True
|
|
|
|
|
def path_visualization_v2(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
    '''
    Render ``path`` like ``path_visualization``, smoothing discontinuous jumps.

    This variant is for the Hugging Face demo and uses fast frame
    interpolation (``generate_transition_video``); the paper uses a
    diffusion-based interpolation method instead.

    A step with ``is_continue[i] == 0`` is a jump from ``path[i - 1]`` to
    ``path[i]``. For each such jump with two frames of context on both
    sides, frames ``i - 1 .. i + 1`` are replaced by frames 1..3 of a video
    interpolated between frames ``i - 2`` and ``i + 2``. Windows that would
    overlap an already chosen window are skipped. Interpolations that fail or
    yield fewer than 5 frames leave the original frames in place.

    Returns:
        ``np.stack`` of the vertices' ``axis_angle`` when ``return_motion`` is
        true, otherwise None.

    Raises:
        ValueError: if ``path`` is empty.
    '''
    if not path:
        raise ValueError("path is empty; nothing to render")

    all_frames = [node['video'] for node in path]
    if verbose_continue:
        average_dis_continue = 1 - sum(is_continue) / len(is_continue)
        print("average_dis_continue:", average_dis_continue)

    fps = graph.vs[0]['fps']
    duration = len(all_frames) / fps

    discontinuity_indices = [i for i, cont in enumerate(is_continue) if cont == 0]

    # Pick the jumps to smooth: anchors i - 2 and i + 2 must exist, and the
    # replaced frames i - 1 .. i + 1 must not overlap a window already chosen.
    blend_positions = []
    processed_frames = set()
    for i in discontinuity_indices:
        if i - 2 < 0 or i + 2 >= len(all_frames):
            continue
        window = range(i - 1, i + 2)
        if any(idx in processed_frames for idx in window):
            continue
        processed_frames.update(window)
        blend_positions.append(i)

    # Scratch PNGs and interpolated clips are removed when the block exits.
    with tempfile.TemporaryDirectory(prefix='blending_frames_') as temp_dir:
        for i in tqdm(blend_positions):
            start_frame_idx = i - 2
            end_frame_idx = i + 2
            frame_start_path = os.path.join(temp_dir, f'frame_{start_frame_idx}.png')
            frame_end_path = os.path.join(temp_dir, f'frame_{end_frame_idx}.png')
            imageio.imwrite(frame_start_path, all_frames[start_frame_idx])
            imageio.imwrite(frame_end_path, all_frames[end_frame_idx])

            generated_video_path = os.path.join(temp_dir, f'generated_{start_frame_idx}_{end_frame_idx}.mp4')
            generate_transition_video(frame_start_path, frame_end_path, generated_video_path)
            if not os.path.exists(generated_video_path):
                print(f"Frame interpolation produced no video. Skipping blending at position {i}.")
                continue

            reader = imageio.get_reader(generated_video_path)
            try:
                generated_frames = list(reader)
            finally:
                reader.close()

            total_generated_frames = len(generated_frames)
            if total_generated_frames < 5:
                print(f"Generated video has insufficient frames ({total_generated_frames}). Skipping blending at position {i}.")
                continue
            # Drop the first generated frame (the start anchor) and keep the
            # next three as the replacements for frames i - 1, i, i + 1.
            all_frames[i - 1:i + 2] = generated_frames[1:4]

    def make_frame(t):
        idx = min(int(t * fps), len(all_frames) - 1)
        return all_frames[idx]

    video_clip = VideoClip(make_frame, duration=duration)
    audio_clip = None
    try:
        if audio_path is not None:
            audio_clip = AudioFileClip(audio_path)
            video_clip = video_clip.set_audio(audio_clip)
        video_clip.write_videofile(save_path, codec='libx264', fps=fps, audio_codec='aac')
    finally:
        video_clip.close()
        if audio_clip is not None:
            audio_clip.close()

    if return_motion:
        all_motion = [node['axis_angle'] for node in path]
        return np.stack(all_motion, 0)
|
|
|
|
|
def graph_pruning(graph):
    """Return the largest strongly connected component of ``graph``.

    Prints vertex/edge counts before and after, plus degree statistics of the
    kept component. The input graph is not modified.
    """
    # python-igraph 0.10 renamed Graph.clusters() to connected_components();
    # prefer the new name and fall back on older releases.
    find_components = getattr(graph, "connected_components", None) or graph.clusters
    ascc = find_components(mode="STRONG")
    lascc = ascc.giant()
    print(f"before nodes: {len(graph.vs)}, edges: {len(graph.es)}")
    print(f"after nodes: {len(lascc.vs)}, edges: {len(lascc.es)}")

    in_degrees = lascc.indegree()
    out_degrees = lascc.outdegree()
    if not in_degrees:
        # Empty component: averages and extrema are undefined.
        return lascc

    avg_in_degree = sum(in_degrees) / len(in_degrees)
    avg_out_degree = sum(out_degrees) / len(out_degrees)
    print(f"Average In-degree: {avg_in_degree}")
    print(f"Average Out-degree: {avg_out_degree}")
    print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
    print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")
    return lascc
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Build a motion/video frame graph from a JSON clip list and pickle it."
    )
    parser.add_argument(
        "--json_save_path", type=str, required=True,
        help="JSON list of clips, each with video_path, motion_path and video_id",
    )
    parser.add_argument(
        "--graph_save_path", type=str, required=True,
        help="output path for the graph, written with igraph's write_pickle",
    )
    args = parser.parse_args()
    json_path = args.json_save_path
    graph_path = args.graph_save_path

    # Keep this at module scope: get_motion_reps may resolve `device` as a
    # module-level global.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    smplx_model = smplx.create(
        "./emage/smplx_models/",
        model_type='smplx',
        gender='NEUTRAL_2020',
        use_face_contour=False,
        num_betas=300,
        num_expression_coeffs=100,
        ext='npz',
        use_pca=False,
    ).to(device).eval()

    graph = create_graph(json_path, smplx_model)
    graph = create_edges(graph)

    # Smoke test: render a 100-step random walk (no audio) to ./test.mp4.
    walk, is_continue = random_walk(graph, 100)
    path_visualization(graph, walk, is_continue, "./test.mp4", audio_path=None, verbose_continue=True, return_motion=True)

    graph.write_pickle(fname=graph_path)
    # NOTE(review): the pruned graph is only used for its printed statistics;
    # the file saved above is the unpruned graph. Confirm this is intended.
    graph = graph_pruning(graph)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|