saksham209 commited on
Commit
397bbeb
·
1 Parent(s): 6eab4f7

Upload 6 files

Browse files
binarizer.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from scipy.misc import face
4
+ import torch
5
+ from tqdm import trange
6
+ import pickle
7
+ from copy import deepcopy
8
+
9
+ from data_util.face3d_helper import Face3DHelper
10
+ from utils.commons.indexed_datasets import IndexedDataset, IndexedDatasetBuilder
11
+
12
+
13
+ def load_video_npy(fn):
14
+ assert fn.endswith(".npy")
15
+ ret_dict = np.load(fn,allow_pickle=True).item()
16
+ video_dict = {
17
+ 'coeff': ret_dict['coeff'], # [T, h]
18
+ 'lm68': ret_dict['lm68'], # [T, 68, 2]
19
+ 'lm5': ret_dict['lm5'], # [T, 5, 2]
20
+ }
21
+ return video_dict
22
+
23
+ def cal_lm3d_in_video_dict(video_dict, face3d_helper):
24
+ coeff = torch.from_numpy(video_dict['coeff']).float()
25
+ identity = coeff[:, 0:80]
26
+ exp = coeff[:, 80:144]
27
+ idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d(identity, exp).cpu().numpy()
28
+ video_dict['idexp_lm3d'] = idexp_lm3d
29
+
30
+ def load_audio_npy(fn):
31
+ assert fn.endswith(".npy")
32
+ ret_dict = np.load(fn,allow_pickle=True).item()
33
+ audio_dict = {
34
+ "mel": ret_dict['mel'], # [T, 80]
35
+ "f0": ret_dict['f0'], # [T,1]
36
+ }
37
+ return audio_dict
38
+
39
+
40
+ if __name__ == '__main__':
41
+ face3d_helper = Face3DHelper(use_gpu=False)
42
+
43
+ import glob,tqdm
44
+ prefixs = ['val', 'train']
45
+ binarized_ds_path = "data/binary/lrs3"
46
+ os.makedirs(binarized_ds_path, exist_ok=True)
47
+ for prefix in prefixs:
48
+ databuilder = IndexedDatasetBuilder(os.path.join(binarized_ds_path, prefix), gzip=False)
49
+ raw_base_dir = '/home/yezhenhui/datasets/raw/lrs3_raw'
50
+ spk_ids = sorted([dir_name.split("/")[-1] for dir_name in glob.glob(raw_base_dir + "/*")])
51
+ spk_id2spk_idx = {spk_id : i for i,spk_id in enumerate(spk_ids) }
52
+ np.save(os.path.join(binarized_ds_path, "spk_id2spk_idx.npy"), spk_id2spk_idx, allow_pickle=True)
53
+ mp4_names = glob.glob(raw_base_dir + "/*/*.mp4")
54
+ cnt = 0
55
+ for i, mp4_name in tqdm.tqdm(enumerate(mp4_names), total=len(mp4_names)):
56
+ if prefix == 'train':
57
+ if i % 100 == 0:
58
+ continue
59
+ else:
60
+ if i % 100 != 0:
61
+ continue
62
+ lst = mp4_name.split("/")
63
+ spk_id = lst[-2]
64
+ clip_id = lst[-1][:-4]
65
+ audio_npy_name = os.path.join(raw_base_dir, spk_id, clip_id+"_audio.npy")
66
+ hubert_npy_name = os.path.join(raw_base_dir, spk_id, clip_id+"_hubert.npy")
67
+ video_npy_name = os.path.join(raw_base_dir, spk_id, clip_id+"_coeff_pt.npy")
68
+ if (not os.path.exists(audio_npy_name)) or (not os.path.exists(video_npy_name)):
69
+ print(f"Skip item for not found.")
70
+ continue
71
+ if (not os.path.exists(hubert_npy_name)):
72
+ print(f"Skip item for hubert_npy not found.")
73
+ continue
74
+ audio_dict = load_audio_npy(audio_npy_name)
75
+ hubert = np.load(hubert_npy_name)
76
+ video_dict = load_video_npy(video_npy_name)
77
+ cal_lm3d_in_video_dict(video_dict, face3d_helper)
78
+ mel = audio_dict['mel']
79
+ if mel.shape[0] < 64: # the video is shorter than 0.6s
80
+ print(f"Skip item for too short.")
81
+ continue
82
+ audio_dict.update(video_dict)
83
+ audio_dict['spk_id'] = spk_id
84
+ audio_dict['spk_idx'] = spk_id2spk_idx[spk_id]
85
+ audio_dict['item_id'] = spk_id + "_" + clip_id
86
+
87
+ audio_dict['hubert'] = hubert # [T_x, hid=1024]
88
+ databuilder.add_item(audio_dict)
89
+ cnt += 1
90
+ databuilder.finalize()
91
+ print(f"{prefix} set has {cnt} samples!")
process_audio_hubert.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Wav2Vec2Processor, HubertModel
2
+ import soundfile as sf
3
+ import numpy as np
4
+ import torch
5
+
6
+ print("Loading the Wav2Vec2 Processor...")
7
+ wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
8
+ print("Loading the HuBERT Model...")
9
+ hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
10
+
11
+
12
+ def get_hubert_from_16k_wav(wav_16k_name):
13
+ speech_16k, _ = sf.read(wav_16k_name)
14
+ hubert = get_hubert_from_16k_speech(speech_16k)
15
+ return hubert
16
+
17
+ @torch.no_grad()
18
+ def get_hubert_from_16k_speech(speech, device="cuda:0"):
19
+ global hubert_model
20
+ hubert_model = hubert_model.to(device)
21
+ if speech.ndim ==2:
22
+ speech = speech[:, 0] # [T, 2] ==> [T,]
23
+ input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T]
24
+ input_values_all = input_values_all.to(device)
25
+ # For long audio sequence, due to the memory limitation, we cannot process them in one run
26
+ # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320
27
+ # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step.
28
+ # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320
29
+ # We have the equation to calculate out time step: T = floor((t-k)/s)
30
+ # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip
31
+ # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N
32
+ kernel = 400
33
+ stride = 320
34
+ clip_length = stride * 1000
35
+ num_iter = input_values_all.shape[1] // clip_length
36
+ expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride
37
+ res_lst = []
38
+ for i in range(num_iter):
39
+ if i == 0:
40
+ start_idx = 0
41
+ end_idx = clip_length - stride + kernel
42
+ else:
43
+ start_idx = clip_length * i
44
+ end_idx = start_idx + (clip_length - stride + kernel)
45
+ input_values = input_values_all[:, start_idx: end_idx]
46
+ hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
47
+ res_lst.append(hidden_states[0])
48
+ if num_iter > 0:
49
+ input_values = input_values_all[:, clip_length * num_iter:]
50
+ else:
51
+ input_values = input_values_all
52
+ # if input_values.shape[1] != 0:
53
+ if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it
54
+ hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
55
+ res_lst.append(hidden_states[0])
56
+ ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]
57
+ # assert ret.shape[0] == expected_T
58
+ assert abs(ret.shape[0] - expected_T) <= 1
59
+ if ret.shape[0] < expected_T:
60
+ ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0]))
61
+ else:
62
+ ret = ret[:expected_T]
63
+ return ret
64
+
65
+
66
+ if __name__ == '__main__':
67
+ ### Process Single Long Audio for NeRF dataset
68
+ # person_id = 'May'
69
+ # wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
70
+ # hubert_npy_name = f"data/processed/videos/{person_id}/hubert.npy"
71
+ # speech_16k, _ = sf.read(wav_16k_name)
72
+ # hubert_hidden = get_hubert_from_16k_speech(speech_16k)
73
+ # np.save(hubert_npy_name, hubert_hidden.detach().numpy())
74
+
75
+ ### Process short audio clips for LRS3 dataset
76
+ import glob, os, tqdm
77
+ lrs3_dir = '/home/yezhenhui/datasets/raw/lrs3_raw/'
78
+ wav_16k_names = glob.glob(os.path.join(lrs3_dir, '*/*.wav'))
79
+ for wav_16k_name in tqdm.tqdm(wav_16k_names, total=len(wav_16k_names)):
80
+ spk_id = wav_16k_name.split("/")[-2]
81
+ clip_id = wav_16k_name.split("/")[-1][:-4]
82
+ out_name = os.path.join(lrs3_dir, spk_id, clip_id+'_hubert.npy')
83
+ if os.path.exists(out_name):
84
+ continue
85
+ speech_16k, _ = sf.read(wav_16k_name)
86
+ hubert_hidden = get_hubert_from_16k_speech(speech_16k)
87
+ np.save(out_name, hubert_hidden.detach().numpy())
process_audio_mel_f0.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import glob
4
+ import os
5
+ import tqdm
6
+ import librosa
7
+ import parselmouth
8
+ from utils.commons.pitch_utils import f0_to_coarse
9
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm
10
+
11
+
12
+ def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
13
+ '''compute right padding (final frame) or both sides padding (first and final frames)
14
+ '''
15
+ assert pad_sides in (1, 2)
16
+ # return int(fsize // 2)
17
+ pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
18
+ if pad_sides == 1:
19
+ return 0, pad
20
+ else:
21
+ return pad // 2, pad // 2 + pad % 2
22
+
23
+ def extract_mel_from_fname(wav_path,
24
+ fft_size=512,
25
+ hop_size=320,
26
+ win_length=512,
27
+ window="hann",
28
+ num_mels=80,
29
+ fmin=80,
30
+ fmax=7600,
31
+ eps=1e-6,
32
+ sample_rate=16000,
33
+ min_level_db=-100):
34
+ if isinstance(wav_path, str):
35
+ wav, _ = librosa.core.load(wav_path, sr=sample_rate)
36
+ else:
37
+ wav = wav_path
38
+
39
+ # get amplitude spectrogram
40
+ x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
41
+ win_length=win_length, window=window, center=False)
42
+ spc = np.abs(x_stft) # (n_bins, T)
43
+
44
+ # get mel basis
45
+ fmin = 0 if fmin == -1 else fmin
46
+ fmax = sample_rate / 2 if fmax == -1 else fmax
47
+ mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
48
+ mel = mel_basis @ spc
49
+
50
+ mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T)
51
+ mel = mel.T
52
+
53
+ l_pad, r_pad = librosa_pad_lr(wav, fft_size, hop_size, 1)
54
+ wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
55
+
56
+ return wav.T, mel
57
+
58
+ def extract_f0_from_wav_and_mel(wav, mel,
59
+ hop_size=320,
60
+ audio_sample_rate=16000,
61
+ ):
62
+ time_step = hop_size / audio_sample_rate * 1000
63
+ f0_min = 80
64
+ f0_max = 750
65
+ f0 = parselmouth.Sound(wav, audio_sample_rate).to_pitch_ac(
66
+ time_step=time_step / 1000, voicing_threshold=0.6,
67
+ pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
68
+
69
+ delta_l = len(mel) - len(f0)
70
+ assert np.abs(delta_l) <= 8
71
+ if delta_l > 0:
72
+ f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
73
+ f0 = f0[:len(mel)]
74
+ pitch_coarse = f0_to_coarse(f0)
75
+ return f0, pitch_coarse
76
+
77
+ def extract_mel_f0_from_fname(fname, out_name=None):
78
+ assert fname.endswith(".wav")
79
+ if out_name is None:
80
+ out_name = fname[:-4] + '_audio.npy'
81
+
82
+ wav, mel = extract_mel_from_fname(fname)
83
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
84
+ out_dict = {
85
+ "mel": mel, # [T, 80]
86
+ "f0": f0,
87
+ }
88
+ np.save(out_name, out_dict)
89
+ return True
90
+
91
+ if __name__ == '__main__':
92
+ import os, glob
93
+ lrs3_dir = "/home/yezhenhui/datasets/raw/lrs3_raw"
94
+ wav_name_pattern = os.path.join(lrs3_dir, "*/*.wav")
95
+ wav_names = glob.glob(wav_name_pattern)
96
+ wav_names = sorted(wav_names)
97
+ for _ in multiprocess_run_tqdm(extract_mel_f0_from_fname, args=wav_names, num_workers=32,desc='extracting Mel and f0'):
98
+ pass
process_video_3dmm.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import cv2
3
+ import numpy as np
4
+ from time import time
5
+ from scipy.io import savemat
6
+ import argparse
7
+ from tqdm import tqdm, trange
8
+ import torch
9
+ import face_alignment
10
+ import deep_3drecon
11
+ from moviepy.editor import VideoFileClip
12
+ import copy
13
+
14
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
15
+
16
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, network_size=4, device='cuda')
17
+ face_reconstructor = deep_3drecon.Reconstructor()
18
+
19
+ # landmark detection in Deep3DRecon
20
+ def lm68_2_lm5(in_lm):
21
+ # in_lm: shape=[68,2]
22
+ lm_idx = np.array([31,37,40,43,46,49,55]) - 1
23
+ # 将上述特殊角点的数据取出,得到5个新的角点数据,拼接起来。
24
+ lm = np.stack([in_lm[lm_idx[0],:],np.mean(in_lm[lm_idx[[1,2]],:],0),np.mean(in_lm[lm_idx[[3,4]],:],0),in_lm[lm_idx[5],:],in_lm[lm_idx[6],:]], axis = 0)
25
+ # 将第一个角点放在了第三个位置
26
+ lm = lm[[1,2,0,3,4],:2]
27
+ return lm
28
+
29
+ def process_video(fname, out_name=None):
30
+ assert fname.endswith(".mp4")
31
+ if out_name is None:
32
+ out_name = fname[:-4] + '.npy'
33
+ tmp_name = out_name[:-4] + '.doi'
34
+ # if os.path.exists(tmp_name):
35
+ # print("tmp exist, skip")
36
+ # return
37
+ # if os.path.exists(out_name):
38
+ # print("out exisit, skip")
39
+ # return
40
+ os.system(f"touch {tmp_name}")
41
+ cap = cv2.VideoCapture(fname)
42
+ lm68_lst = []
43
+ lm5_lst = []
44
+ frame_rgb_lst = []
45
+ cnt = 0
46
+ while cap.isOpened():
47
+ ret, frame_bgr = cap.read()
48
+ if frame_bgr is None:
49
+ break
50
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
51
+ try:
52
+ lm68 = fa.get_landmarks(frame_rgb)[0] # 识别图片中的人脸,获得角点, shape=[68,2]
53
+ except:
54
+ print(f"Skip Item: Caught errors when fa.get_landmarks, maybe No face detected in some frames in {fname}!")
55
+ # print(f"Caught error at {cnt}")
56
+ cnt +=1
57
+ return None
58
+ # continue
59
+ lm5 = lm68_2_lm5(lm68)
60
+ lm68_lst.append(lm68)
61
+ lm5_lst.append(lm5)
62
+ frame_rgb_lst.append(frame_rgb)
63
+ cnt += 1
64
+ video_rgb = np.stack(frame_rgb_lst) # [t, 224,224, 3]
65
+ lm68_arr = np.stack(lm68_lst).reshape([cnt, 68, 2])
66
+ lm5_arr = np.stack(lm5_lst).reshape([cnt, 5, 2])
67
+ num_frames = cnt
68
+ batch_size = 32
69
+ iter_times = num_frames // batch_size
70
+ last_bs = num_frames % batch_size
71
+ coeff_lst = []
72
+ for i_iter in range(iter_times):
73
+ start_idx = i_iter * batch_size
74
+ batched_images = video_rgb[start_idx: start_idx + batch_size]
75
+ batched_lm5 = lm5_arr[start_idx: start_idx + batch_size]
76
+ coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True)
77
+ coeff_lst.append(coeff)
78
+ if last_bs != 0:
79
+ batched_images = video_rgb[-last_bs:]
80
+ batched_lm5 = lm5_arr[-last_bs:]
81
+ coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True)
82
+ coeff_lst.append(coeff)
83
+ coeff_arr = np.concatenate(coeff_lst,axis=0)
84
+ result_dict = {
85
+ 'coeff': coeff_arr.reshape([cnt, -1]),
86
+ 'lm68': lm68_arr.reshape([cnt, 68, 2]),
87
+ 'lm5': lm5_arr.reshape([cnt, 5, 2]),
88
+ }
89
+ np.save(out_name, result_dict)
90
+ os.system(f"rm {tmp_name}")
91
+
92
+
93
+ def split_wav(mp4_name):
94
+ wav_name = mp4_name[:-4] + '.wav'
95
+ if os.path.exists(wav_name):
96
+ return
97
+ video = VideoFileClip(mp4_name,verbose=False)
98
+ dur = video.duration
99
+ audio = video.audio
100
+ assert audio is not None
101
+ audio.write_audiofile(wav_name,fps=16000,verbose=False,logger=None)
102
+
103
+ if __name__ == '__main__':
104
+ ### Process Single Long video for NeRF dataset
105
+ # video_id = 'May'
106
+ # video_fname = f"data/raw/videos/{video_id}.mp4"
107
+ # out_fname = f"data/processed/videos/{video_id}/coeff.npy"
108
+ # process_video(video_fname, out_fname)
109
+
110
+ ### Process short video clips for LRS3 dataset
111
+ from argparse import ArgumentParser
112
+ parser = ArgumentParser()
113
+ parser.add_argument('--lrs3_path', type=int, default='/home/yezhenhui/datasets/raw/lrs3_raw', help='')
114
+ parser.add_argument('--process_id', type=int, default=0, help='')
115
+ parser.add_argument('--total_process', type=int, default=1, help='')
116
+ args = parser.parse_args()
117
+
118
+ import os, glob
119
+ lrs3_dir = parser.lrs3_path
120
+ mp4_name_pattern = os.path.join(lrs3_dir, "*/*.mp4")
121
+ mp4_names = glob.glob(mp4_name_pattern)
122
+ mp4_names = sorted(mp4_names)
123
+ if args.total_process > 1:
124
+ assert args.process_id <= args.total_process-1
125
+ num_samples_per_process = len(mp4_names) // args.total_process
126
+ if args.process_id == args.total_process-1:
127
+ mp4_names = mp4_names[args.process_id * num_samples_per_process : ]
128
+ else:
129
+ mp4_names = mp4_names[args.process_id * num_samples_per_process : (args.process_id+1) * num_samples_per_process]
130
+ for mp4_name in tqdm(mp4_names, desc='extracting 3DMM...'):
131
+ split_wav(mp4_name)
132
+ process_video(mp4_name,out_name=mp4_name.replace(".mp4", "_coeff_pt.npy"))
133
+
process_video_3dmm_th1kh.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import cv2
3
+ import numpy as np
4
+ from time import time
5
+ from scipy.io import savemat
6
+ import argparse
7
+ from tqdm import tqdm, trange
8
+ import torch
9
+ import face_alignment
10
+ import deep_3drecon
11
+ from moviepy.editor import VideoFileClip
12
+ import copy
13
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm, multiprocess_run
14
+ from utils.commons.meters import Timer
15
+ from decord import VideoReader
16
+ from decord import cpu, gpu
17
+ from utils.commons.face_alignment_utils import mediapipe_lm478_to_face_alignment_lm68
18
+ import mediapipe
19
+
20
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
21
+
22
+ # fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, network_size=4, device='cuda')
23
+ mp_face_mesh = mediapipe.solutions.face_mesh
24
+ face_reconstructor = deep_3drecon.Reconstructor()
25
+
26
+
27
+ def chunk(iterable, chunk_size):
28
+ final_ret = []
29
+ cnt = 0
30
+ ret = []
31
+ for record in iterable:
32
+ if cnt == 0:
33
+ ret = []
34
+ ret.append(record)
35
+ cnt += 1
36
+ if len(ret) == chunk_size:
37
+ final_ret.append(ret)
38
+ ret = []
39
+ if len(final_ret[-1]) != chunk_size:
40
+ final_ret.append(ret)
41
+ return final_ret
42
+
43
+ # landmark detection in Deep3DRecon
44
+ def lm68_2_lm5(in_lm):
45
+ assert in_lm.ndim == 2
46
+ # in_lm: shape=[68,2]
47
+ lm_idx = np.array([31,37,40,43,46,49,55]) - 1
48
+ # 将上述特殊角点的数据取出,得到5个新的角点数据,拼接起来。
49
+ lm = np.stack([in_lm[lm_idx[0],:],np.mean(in_lm[lm_idx[[1,2]],:],0),np.mean(in_lm[lm_idx[[3,4]],:],0),in_lm[lm_idx[5],:],in_lm[lm_idx[6],:]], axis = 0)
50
+ # 将第一个角点放在了第三个位置
51
+ lm = lm[[1,2,0,3,4],:2]
52
+ return lm
53
+
54
+ def extract_frames_job(fname):
55
+ out_name=fname.replace(".mp4", "_coeff_pt.npy").replace("datasets/raw/cropped_clips", "datasets/processed/coeff")
56
+ if os.path.exists(out_name):
57
+ return None
58
+ video_reader = VideoReader(fname, ctx=cpu(0))
59
+ frame_rgb_lst = video_reader.get_batch(list(range(0,len(video_reader)))).asnumpy()
60
+ return frame_rgb_lst
61
+
62
+ def extract_lms_mediapipe_job(frames):
63
+ if frames is None:
64
+ return None
65
+ with mp_face_mesh.FaceMesh(
66
+ static_image_mode=False,
67
+ max_num_faces=1,
68
+ refine_landmarks=True,
69
+ min_detection_confidence=0.5) as face_mesh:
70
+ ldms_normed = []
71
+ frame_i = 0
72
+ frame_ids = []
73
+ for i in range(len(frames)):
74
+ # Convert the BGR image to RGB before processing.
75
+ ret = face_mesh.process(frames[i])
76
+ # Print and draw face mesh landmarks on the image.
77
+ if not ret.multi_face_landmarks:
78
+ print(f"Skip Item: Caught errors when mediapipe get face_mesh, maybe No face detected in some frames!")
79
+ return None
80
+ else:
81
+ myFaceLandmarks = []
82
+ lms = ret.multi_face_landmarks[0]
83
+ for lm in lms.landmark:
84
+ myFaceLandmarks.append([lm.x, lm.y, lm.z])
85
+ ldms_normed.append(myFaceLandmarks)
86
+ frame_ids.append(frame_i)
87
+ frame_i += 1
88
+ bs, H, W, _ = frames.shape
89
+ ldms478 = np.array(ldms_normed)
90
+ lm68 = mediapipe_lm478_to_face_alignment_lm68(ldms478, H, W, return_2d=True)
91
+ lm5_lst = [lm68_2_lm5(lm68[i]) for i in range(lm68.shape[0])]
92
+ lm5 = np.stack(lm5_lst)
93
+ return ldms478, lm68, lm5
94
+
95
+ def process_video_batch(fname_lst, out_name_lst=None):
96
+ frames_lst = []
97
+ with Timer("load_frames", True):
98
+ for (i, res) in multiprocess_run_tqdm(extract_frames_job, fname_lst, num_workers=2, desc="decord is loading frames in the batch videos..."):
99
+ frames_lst.append(res)
100
+
101
+ lm478s_lst = []
102
+ lm68s_lst = []
103
+ lm5s_lst = []
104
+ with Timer("mediapipe_faceAlign", True):
105
+ for (i, res) in multiprocess_run_tqdm(extract_lms_mediapipe_job, frames_lst, num_workers=2, desc="mediapipe is predicting face mesh in batch videos..."):
106
+ if res is None:
107
+ res = (None, None, None)
108
+ lm478s, lm68s, lm5s = res
109
+ lm478s_lst.append(lm478s)
110
+ lm68s_lst.append(lm68s)
111
+ lm5s_lst.append(lm5s)
112
+
113
+ processed_cnt_in_this_batch = 0
114
+ with Timer("deep_3drecon_pytorch", True):
115
+ for i, fname in tqdm(enumerate(fname_lst), total=len(fname_lst), desc="extracting 3DMM in the batch videos..."):
116
+ video_rgb = frames_lst[i] # [t, 224,224, 3]
117
+ lm478_arr = lm478s_lst[i]
118
+ lm68_arr = lm68s_lst[i]
119
+ lm5_arr = lm5s_lst[i]
120
+ if lm5_arr is None:
121
+ continue
122
+ num_frames = len(video_rgb)
123
+ batch_size = 32
124
+ iter_times = num_frames // batch_size
125
+ last_bs = num_frames % batch_size
126
+
127
+ coeff_lst = []
128
+ for i_iter in range(iter_times):
129
+ start_idx = i_iter * batch_size
130
+ batched_images = video_rgb[start_idx: start_idx + batch_size]
131
+ batched_lm5 = lm5_arr[start_idx: start_idx + batch_size]
132
+ coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True)
133
+ coeff_lst.append(coeff)
134
+ if last_bs != 0:
135
+ batched_images = video_rgb[-last_bs:]
136
+ batched_lm5 = lm5_arr[-last_bs:]
137
+ coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True)
138
+ coeff_lst.append(coeff)
139
+ coeff_arr = np.concatenate(coeff_lst,axis=0)
140
+ result_dict = {
141
+ 'coeff': coeff_arr.reshape([num_frames, -1]).astype(np.float32),
142
+ 'lm478': lm478_arr.reshape([num_frames, 478, 3]).astype(np.float32),
143
+ 'lm68': lm68_arr.reshape([num_frames, 68, 2]).astype(np.int16),
144
+ 'lm5': lm5_arr.reshape([num_frames, 5, 2]).astype(np.int16),
145
+ }
146
+ np.save(out_name_lst[i], result_dict)
147
+ processed_cnt_in_this_batch +=1
148
+
149
+ print(f"In this batch {processed_cnt_in_this_batch} files are processed")
150
+
151
+
152
+
153
+ def split_wav(mp4_name):
154
+ wav_name = mp4_name[:-4] + '.wav'
155
+ if os.path.exists(wav_name):
156
+ return
157
+ video = VideoFileClip(mp4_name,verbose=False)
158
+ dur = video.duration
159
+ audio = video.audio
160
+ assert audio is not None
161
+ audio.write_audiofile(wav_name,fps=16000,verbose=False,logger=None)
162
+
163
+ if __name__ == '__main__':
164
+ ### Process Single Long video for NeRF dataset
165
+ # video_id = 'May'
166
+ # video_fname = f"data/raw/videos/{video_id}.mp4"
167
+ # out_fname = f"data/processed/videos/{video_id}/coeff.npy"
168
+ # process_video(video_fname, out_fname)
169
+
170
+ ### Process short video clips for LRS3 dataset
171
+ import random
172
+
173
+ from argparse import ArgumentParser
174
+ parser = ArgumentParser()
175
+ parser.add_argument('--lrs3_path', type=str, default='/home/yezhenhui/projects/TalkingHead-1KH/datasets/raw/cropped_clips', help='')
176
+ parser.add_argument('--process_id', type=int, default=0, help='')
177
+ parser.add_argument('--total_process', type=int, default=1, help='')
178
+ args = parser.parse_args()
179
+
180
+ import os, glob
181
+ lrs3_dir = args.lrs3_path
182
+ out_dir = lrs3_dir.replace("raw/cropped_clips", "processed/coeff")
183
+ os.makedirs(out_dir, exist_ok=True)
184
+ # mp4_name_pattern = os.path.join(lrs3_dir, "*.mp4")
185
+ # mp4_names = glob.glob(mp4_name_pattern)
186
+ with open('/home/yezhenhui/projects/LDMAvatar/clean.txt', 'r') as f:
187
+ txt = f.read()
188
+ mp4_names = txt.split("\n")
189
+ mp4_names = sorted(mp4_names)
190
+ if args.total_process > 1:
191
+ assert args.process_id <= args.total_process-1
192
+ num_samples_per_process = len(mp4_names) // args.total_process
193
+ if args.process_id == args.total_process-1:
194
+ mp4_names = mp4_names[args.process_id * num_samples_per_process : ]
195
+ else:
196
+ mp4_names = mp4_names[args.process_id * num_samples_per_process : (args.process_id+1) * num_samples_per_process]
197
+ random.seed(111)
198
+ random.shuffle(mp4_names)
199
+ batched_mp4_names_lst = chunk(mp4_names, chunk_size=8)
200
+ for batch_mp4_names in tqdm(batched_mp4_names_lst, desc='[ROOT]: extracting face mesh and 3DMM in batches...'):
201
+ try:
202
+ for mp4_name in batch_mp4_names:
203
+ split_wav(mp4_name)
204
+ out_names = [mp4_name.replace(".mp4", "_coeff_pt.npy").replace("datasets/raw/cropped_clips", "datasets/processed/coeff") for mp4_name in batch_mp4_names]
205
+ process_video_batch(batch_mp4_names, out_names)
206
+ # process_video(mp4_name,out_name=mp4_name.replace(".mp4", "_coeff_pt.npy").replace("datasets/raw/cropped_clips", "datasets/processed/coeff"))
207
+ except Exception as e:
208
+ print(e)
209
+ continue
process_video_3dmm_vox2.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import numpy as np
3
+ from tqdm import tqdm, trange
4
+ import deep_3drecon
5
+ from moviepy.editor import VideoFileClip
6
+ from utils.commons.multiprocess_utils import multiprocess_run_tqdm, multiprocess_run
7
+ from utils.commons.meters import Timer
8
+ from decord import VideoReader
9
+ from decord import cpu, gpu
10
+ from utils.commons.face_alignment_utils import mediapipe_lm478_to_face_alignment_lm68
11
+ import mediapipe
12
+ import cv2
13
+
14
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
15
+
16
+ # fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, network_size=4, device='cuda')
17
+ mp_face_mesh = mediapipe.solutions.face_mesh
18
+ face_reconstructor = deep_3drecon.Reconstructor()
19
+
20
+
21
+ def chunk(iterable, chunk_size):
22
+ final_ret = []
23
+ cnt = 0
24
+ ret = []
25
+ for record in iterable:
26
+ if cnt == 0:
27
+ ret = []
28
+ ret.append(record)
29
+ cnt += 1
30
+ if len(ret) == chunk_size:
31
+ final_ret.append(ret)
32
+ ret = []
33
+ if len(final_ret[-1]) != chunk_size:
34
+ final_ret.append(ret)
35
+ return final_ret
36
+
37
+ # landmark detection in Deep3DRecon
38
+ def lm68_2_lm5(in_lm):
39
+ assert in_lm.ndim == 2
40
+ # in_lm: shape=[68,2]
41
+ lm_idx = np.array([31,37,40,43,46,49,55]) - 1
42
+ # 将上述特殊角点的数据取出,得到5个新的角点数据,拼接起来。
43
+ lm = np.stack([in_lm[lm_idx[0],:],np.mean(in_lm[lm_idx[[1,2]],:],0),np.mean(in_lm[lm_idx[[3,4]],:],0),in_lm[lm_idx[5],:],in_lm[lm_idx[6],:]], axis = 0)
44
+ # 将第一个角点放在了第三个位置
45
+ lm = lm[[1,2,0,3,4],:2]
46
+ return lm
47
+
48
+ def extract_frames_job(fname):
49
+ try:
50
+ out_name=fname.replace(".mp4", "_coeff_pt.npy").replace("/dev/", "/coeff/")
51
+ if os.path.exists(out_name):
52
+ return None
53
+ cap = cv2.VideoCapture(fname)
54
+ frames = []
55
+ while cap.isOpened():
56
+ ret, frame_bgr = cap.read()
57
+ if frame_bgr is None:
58
+ break
59
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
60
+ frames.append(frame_rgb)
61
+ return np.stack(frames)
62
+ # out_name=fname.replace(".mp4", "_coeff_pt.npy").replace("/dev/", "/coeff/")
63
+ # if os.path.exists(out_name):
64
+ # return None
65
+ # video_reader = VideoReader(fname, ctx=cpu(0))
66
+ # frame_rgb_lst = video_reader.get_batch(list(range(0,len(video_reader)))).asnumpy()
67
+ # return frame_rgb_lst
68
+ except Exception as e:
69
+ print(e)
70
+ return None
71
+
72
+ def extract_lms_mediapipe_job(frames):
73
+ try:
74
+ if frames is None:
75
+ return None
76
+ with mp_face_mesh.FaceMesh(
77
+ static_image_mode=False,
78
+ max_num_faces=1,
79
+ refine_landmarks=True,
80
+ min_detection_confidence=0.5) as face_mesh:
81
+ ldms_normed = []
82
+ frame_i = 0
83
+ frame_ids = []
84
+ for i in range(len(frames)):
85
+ # Convert the BGR image to RGB before processing.
86
+ ret = face_mesh.process(frames[i])
87
+ # Print and draw face mesh landmarks on the image.
88
+ if not ret.multi_face_landmarks:
89
+ print(f"Skip Item: Caught errors when mediapipe get face_mesh, maybe No face detected in some frames!")
90
+ return None
91
+ else:
92
+ myFaceLandmarks = []
93
+ lms = ret.multi_face_landmarks[0]
94
+ for lm in lms.landmark:
95
+ myFaceLandmarks.append([lm.x, lm.y, lm.z])
96
+ ldms_normed.append(myFaceLandmarks)
97
+ frame_ids.append(frame_i)
98
+ frame_i += 1
99
+ bs, H, W, _ = frames.shape
100
+ ldms478 = np.array(ldms_normed)
101
+ lm68 = mediapipe_lm478_to_face_alignment_lm68(ldms478, H, W, return_2d=True)
102
+ lm5_lst = [lm68_2_lm5(lm68[i]) for i in range(lm68.shape[0])]
103
+ lm5 = np.stack(lm5_lst)
104
+ return ldms478, lm68, lm5
105
+ except Exception as e:
106
+ print(e)
107
+ return None
108
+
109
+ def process_video_batch(fname_lst, out_name_lst=None):
110
+ frames_lst = []
111
+ with Timer("load_frames", True):
112
+ for fname in tqdm(fname_lst, desc="decord is loading frames in the batch videos..."):
113
+ res = extract_frames_job(fname)
114
+ frames_lst.append(res)
115
+ # for (i, res) in multiprocess_run_tqdm(extract_frames_job, fname_lst, num_workers=1, desc="decord is loading frames in the batch videos..."):
116
+ # frames_lst.append(res)
117
+
118
+ lm478s_lst = []
119
+ lm68s_lst = []
120
+ lm5s_lst = []
121
+ with Timer("mediapipe_faceAlign", True):
122
+ # for (i, res) in multiprocess_run_tqdm(extract_lms_mediapipe_job, frames_lst, num_workers=2, desc="mediapipe is predicting face mesh in batch videos..."):
123
+ for i, frames in tqdm(enumerate(frames_lst),total=len(fname_lst), desc="mediapipe is predicting face mesh in batch videos..."):
124
+ res = extract_lms_mediapipe_job(frames)
125
+ if res is None:
126
+ res = (None, None, None)
127
+ lm478s, lm68s, lm5s = res
128
+ lm478s_lst.append(lm478s)
129
+ lm68s_lst.append(lm68s)
130
+ lm5s_lst.append(lm5s)
131
+
132
+ processed_cnt_in_this_batch = 0
133
+ with Timer("deep_3drecon_pytorch", True):
134
+ for i, fname in tqdm(enumerate(fname_lst), total=len(fname_lst), desc="extracting 3DMM in the batch videos..."):
135
+ video_rgb = frames_lst[i] # [t, 224,224, 3]
136
+ lm478_arr = lm478s_lst[i]
137
+ lm68_arr = lm68s_lst[i]
138
+ lm5_arr = lm5s_lst[i]
139
+ if lm5_arr is None:
140
+ continue
141
+ num_frames = len(video_rgb)
142
+ batch_size = 32
143
+ iter_times = num_frames // batch_size
144
+ last_bs = num_frames % batch_size
145
+
146
+ coeff_lst = []
147
+ for i_iter in range(iter_times):
148
+ start_idx = i_iter * batch_size
149
+ batched_images = video_rgb[start_idx: start_idx + batch_size]
150
+ batched_lm5 = lm5_arr[start_idx: start_idx + batch_size]
151
+ coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True)
152
+ coeff_lst.append(coeff)
153
+ if last_bs != 0:
154
+ batched_images = video_rgb[-last_bs:]
155
+ batched_lm5 = lm5_arr[-last_bs:]
156
+ coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True)
157
+ coeff_lst.append(coeff)
158
+ coeff_arr = np.concatenate(coeff_lst,axis=0)
159
+ result_dict = {
160
+ 'coeff': coeff_arr.reshape([num_frames, -1]).astype(np.float32),
161
+ 'lm478': lm478_arr.reshape([num_frames, 478, 3]).astype(np.float32),
162
+ 'lm68': lm68_arr.reshape([num_frames, 68, 2]).astype(np.int16),
163
+ 'lm5': lm5_arr.reshape([num_frames, 5, 2]).astype(np.int16),
164
+ }
165
+ os.makedirs(os.path.dirname(out_name_lst[i]),exist_ok=True)
166
+ np.save(out_name_lst[i], result_dict)
167
+ processed_cnt_in_this_batch +=1
168
+
169
+ print(f"In this batch {processed_cnt_in_this_batch} files are processed")
170
+
171
+
172
+
173
+ def split_wav(mp4_name):
174
+ try:
175
+ wav_name = mp4_name[:-4] + '.wav'
176
+ if os.path.exists(wav_name):
177
+ return
178
+ video = VideoFileClip(mp4_name,verbose=False)
179
+ dur = video.duration
180
+ audio = video.audio
181
+ assert audio is not None
182
+ audio.write_audiofile(wav_name,fps=16000,verbose=False,logger=None)
183
+ except Exception as e:
184
+ print(e)
185
+ return None
186
+
187
+ if __name__ == '__main__':
188
+ ### Process Single Long video for NeRF dataset
189
+ # video_id = 'May'
190
+ # video_fname = f"data/raw/videos/{video_id}.mp4"
191
+ # out_fname = f"data/processed/videos/{video_id}/coeff.npy"
192
+ # process_video(video_fname, out_fname)
193
+
194
+ ### Process short video clips for LRS3 dataset
195
+ import random
196
+
197
+ from argparse import ArgumentParser
198
+ parser = ArgumentParser()
199
+ parser.add_argument('--lrs3_path', type=str, default='/mnt/sda/yezhenhui/datasets/voxceleb2', help='')
200
+ parser.add_argument('--process_id', type=int, default=0, help='')
201
+ parser.add_argument('--total_process', type=int, default=1, help='')
202
+ args = parser.parse_args()
203
+
204
+ import os, glob
205
+ lrs3_dir = args.lrs3_path
206
+ mp4_name_pattern = os.path.join(lrs3_dir, "dev/id*/*/*.mp4")
207
+ mp4_names = glob.glob(mp4_name_pattern)
208
+
209
+ if args.total_process > 1:
210
+ assert args.process_id <= args.total_process-1
211
+ num_samples_per_process = len(mp4_names) // args.total_process
212
+ if args.process_id == args.total_process-1:
213
+ mp4_names = mp4_names[args.process_id * num_samples_per_process : ]
214
+ else:
215
+ mp4_names = mp4_names[args.process_id * num_samples_per_process : (args.process_id+1) * num_samples_per_process]
216
+ random.seed(111)
217
+ random.shuffle(mp4_names)
218
+ batched_mp4_names_lst = chunk(mp4_names, chunk_size=1)
219
+ for batch_mp4_names in tqdm(batched_mp4_names_lst, desc='[ROOT]: extracting face mesh and 3DMM in batches...'):
220
+ try:
221
+ for mp4_name in batch_mp4_names:
222
+ split_wav(mp4_name)
223
+ out_names = [mp4_name.replace(".mp4", "_coeff_pt.npy").replace("/dev/", "/coeff/") for mp4_name in batch_mp4_names]
224
+ process_video_batch(batch_mp4_names, out_names)
225
+ except Exception as e:
226
+ print(e)
227
+ continue