Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- README.md +63 -11
- diffusion.py +145 -0
- utils.py +349 -0
README.md
CHANGED
@@ -1,11 +1,63 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Diffused Heads
|
2 |
+
|
3 |
+
Official repository for Diffused Heads: Diffusion Models Beat GANs on Talking-Face Generation.
|
4 |
+
|
5 |
+
### [Project](https://mstypulkowski.github.io/diffusedheads/) | [Paper](https://arxiv.org/abs/2301.03396) | [Demo](https://youtu.be/DSipIDj-5q0)
|
6 |
+
|
7 |
+
<p align="center">
|
8 |
+
<img src='./intro.gif' width=400>
|
9 |
+
</p>
|
10 |
+
|
11 |
+
## Setup
|
12 |
+
Python 3.x environment with [ffmpeg](https://www.ffmpeg.org/) is needed. The rest of the requirements can be installed using:
|
13 |
+
```
|
14 |
+
pip install -r requirements.txt
|
15 |
+
```
|
16 |
+
|
17 |
+
## Sampling
|
18 |
+
Due to LRW license agreement, we are only able to provide a checkpoint of our model trained on CREMA.
|
19 |
+
|
20 |
+
The entire test set generated by our method can be downloaded from [here](https://drive.google.com/file/d/1zWSqtV7O4WGkgh6WB55b8Mdg2lXXUudH/view?usp=drive_link).
|
21 |
+
|
22 |
+
|
23 |
+
1. Download and unpack [checkpoints](https://drive.google.com/file/d/1U90egQvzERHclTYPCjZadrEMyF7TAPa-/view?usp=drive_link) (our model and pretrained audio encoder).
|
24 |
+
|
25 |
+
2. Download and unpack preprocessed CREMA [video](https://drive.google.com/file/d/1rM0FZLGiy-bJcxpv4CTlbUf0FuROubdk/view?usp=drive_link) and [audio](https://drive.google.com/file/d/1uS7Vi8EwarJFGQhsYHDMSkQmaNuiJIVW/view?usp=drive_link) files.
|
26 |
+
|
27 |
+
3. Specify paths and options in `config_crema.yaml` (check comments in the file).
|
28 |
+
|
29 |
+
4. Run the script
|
30 |
+
```
|
31 |
+
python sample.py
|
32 |
+
```
|
33 |
+
|
34 |
+
|
35 |
+
## Using your own data
|
36 |
+
### Audio
|
37 |
+
You can use audio recordings of your choosing freely. The only requirements are 16 kHz audio rate and a single audio channel. Please note our model is able to generate videos up to 9 seconds long depending on the audio.
|
38 |
+
|
39 |
+
### Identity frame
|
40 |
+
It is highly recommended to use a frame from the provided CREMA videos. This instance of the model was trained on clips with green background only. If you want to use your identity frame anyway, please follow this [repo](https://github.com/DinoMan/face-processor) for face alignment. Additionally, you may want to try segmenting the person and replacing background to green.
|
41 |
+
|
42 |
+
## Training
|
43 |
+
A training script will be uploaded in the future (ETA December 2023).
|
44 |
+
|
45 |
+
## Citation
|
46 |
+
```
|
47 |
+
@article{stypulkowski2023diffused,
|
48 |
+
title={Diffused heads: Diffusion models beat gans on talking-face generation},
|
49 |
+
author={Stypu{\l}kowski, Micha{\l} and Vougioukas, Konstantinos and He, Sen and Zi{\k{e}}ba, Maciej and Petridis, Stavros and Pantic, Maja},
|
50 |
+
journal={arXiv preprint arXiv:2301.03396},
|
51 |
+
year={2023}
|
52 |
+
}
|
53 |
+
```
|
54 |
+
|
55 |
+
## License
|
56 |
+
This work is licensed under a
|
57 |
+
[Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa].
|
58 |
+
|
59 |
+
[![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa]
|
60 |
+
|
61 |
+
[cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/
|
62 |
+
[cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png
|
63 |
+
[cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg
|
diffusion.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from tqdm import trange
|
5 |
+
from torchvision.transforms import Compose
|
6 |
+
|
7 |
+
|
8 |
+
class Diffusion(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self, nn_backbone, device, n_timesteps=1000, in_channels=3, image_size=128, out_channels=6, motion_transforms=None):
|
11 |
+
super(Diffusion, self).__init__()
|
12 |
+
|
13 |
+
self.nn_backbone = nn_backbone
|
14 |
+
self.n_timesteps = n_timesteps
|
15 |
+
self.in_channels = in_channels
|
16 |
+
self.out_channels = out_channels
|
17 |
+
self.x_shape = (image_size, image_size)
|
18 |
+
self.device = device
|
19 |
+
|
20 |
+
self.motion_transforms = motion_transforms if motion_transforms else Compose([])
|
21 |
+
|
22 |
+
self.timesteps = torch.arange(n_timesteps)
|
23 |
+
self.beta = self.get_beta_schedule()
|
24 |
+
self.set_params()
|
25 |
+
self.device = device
|
26 |
+
|
27 |
+
def sample(self, x_cond, audio_emb, n_audio_motion_embs=2, n_motion_frames=2, motion_channels=3):
|
28 |
+
with torch.no_grad():
|
29 |
+
n_frames = audio_emb.shape[1]
|
30 |
+
|
31 |
+
xT = torch.randn(x_cond.shape[0], n_frames, self.in_channels, self.x_shape[0], self.x_shape[1]).to(x_cond.device)
|
32 |
+
|
33 |
+
audio_ids = [0] * n_audio_motion_embs
|
34 |
+
for i in range(n_audio_motion_embs + 1):
|
35 |
+
audio_ids += [i]
|
36 |
+
|
37 |
+
motion_frames = [self.motion_transforms(x_cond) for _ in range(n_motion_frames)]
|
38 |
+
motion_frames = torch.cat(motion_frames, dim=1)
|
39 |
+
|
40 |
+
samples = []
|
41 |
+
for i in trange(n_frames, desc=f'Sampling'):
|
42 |
+
sample_frame = self.sample_loop(xT[:, i].to(x_cond.device), x_cond, motion_frames, audio_emb[:, audio_ids])
|
43 |
+
samples.append(sample_frame.unsqueeze(1))
|
44 |
+
motion_frames = torch.cat([motion_frames[:, motion_channels:, :], self.motion_transforms(sample_frame)], dim=1)
|
45 |
+
audio_ids = audio_ids[1:] + [min(i + n_audio_motion_embs + 1, n_frames - 1)]
|
46 |
+
return torch.cat(samples, dim=1)
|
47 |
+
|
48 |
+
def sample_loop(self, xT, x_cond, motion_frames, audio_emb):
|
49 |
+
xt = xT
|
50 |
+
for i, t in reversed(list(enumerate(self.timesteps))):
|
51 |
+
timesteps = torch.tensor([t] * xT.shape[0]).to(xT.device)
|
52 |
+
timesteps_ids = torch.tensor([i] * xT.shape[0]).to(xT.device)
|
53 |
+
nn_out = self.nn_backbone(xt, timesteps, x_cond, motion_frames=motion_frames, audio_emb=audio_emb)
|
54 |
+
mean, logvar = self.get_p_params(xt, timesteps_ids, nn_out)
|
55 |
+
noise = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)
|
56 |
+
xt = mean + noise * torch.exp(logvar / 2)
|
57 |
+
|
58 |
+
return xt
|
59 |
+
|
60 |
+
def get_p_params(self, xt, timesteps, nn_out):
|
61 |
+
if self.in_channels == self.out_channels:
|
62 |
+
eps_pred = nn_out
|
63 |
+
p_logvar = self.expand(torch.log(self.beta[timesteps]))
|
64 |
+
else:
|
65 |
+
eps_pred, nu = nn_out.chunk(2, 1)
|
66 |
+
nu = (nu + 1) / 2
|
67 |
+
p_logvar = nu * self.expand(torch.log(self.beta[timesteps])) + (1 - nu) * self.expand(self.log_beta_tilde_clipped[timesteps])
|
68 |
+
|
69 |
+
p_mean, _ = self.get_q_params(xt, timesteps, eps_pred=eps_pred)
|
70 |
+
return p_mean, p_logvar
|
71 |
+
|
72 |
+
def get_q_params(self, xt, timesteps, eps_pred=None, x0=None):
|
73 |
+
if x0 is None:
|
74 |
+
# predict x0 from xt and eps_pred
|
75 |
+
coef1_x0 = self.expand(self.coef1_x0[timesteps])
|
76 |
+
coef2_x0 = self.expand(self.coef2_x0[timesteps])
|
77 |
+
x0 = coef1_x0 * xt - coef2_x0 * eps_pred
|
78 |
+
x0 = x0.clamp(-1, 1)
|
79 |
+
|
80 |
+
# q(x_{t-1} | x_t, x_0)
|
81 |
+
coef1_q = self.expand(self.coef1_q[timesteps])
|
82 |
+
coef2_q = self.expand(self.coef2_q[timesteps])
|
83 |
+
q_mean = coef1_q * x0 + coef2_q * xt
|
84 |
+
|
85 |
+
q_logvar = self.expand(self.log_beta_tilde_clipped[timesteps])
|
86 |
+
|
87 |
+
return q_mean, q_logvar
|
88 |
+
|
89 |
+
def get_beta_schedule(self, max_beta=0.999):
|
90 |
+
alpha_bar = lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2
|
91 |
+
betas = []
|
92 |
+
for i in range(self.n_timesteps):
|
93 |
+
t1 = i / self.n_timesteps
|
94 |
+
t2 = (i + 1) / self.n_timesteps
|
95 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
96 |
+
return torch.tensor(betas).float()
|
97 |
+
|
98 |
+
def set_params(self):
|
99 |
+
self.alpha = 1 - self.beta
|
100 |
+
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
|
101 |
+
self.alpha_bar_prev = torch.cat([torch.ones(1,), self.alpha_bar[:-1]])
|
102 |
+
|
103 |
+
self.beta_tilde = self.beta * (1.0 - self.alpha_bar_prev) / (1.0 - self.alpha_bar)
|
104 |
+
self.log_beta_tilde_clipped = torch.log(torch.cat([self.beta_tilde[1, None], self.beta_tilde[1:]]))
|
105 |
+
|
106 |
+
# to caluclate x0 from eps_pred
|
107 |
+
self.coef1_x0 = torch.sqrt(1.0 / self.alpha_bar)
|
108 |
+
self.coef2_x0 = torch.sqrt(1.0 / self.alpha_bar - 1)
|
109 |
+
|
110 |
+
# for q(x_{t-1} | x_t, x_0)
|
111 |
+
self.coef1_q = self.beta * torch.sqrt(self.alpha_bar_prev) / (1.0 - self.alpha_bar)
|
112 |
+
self.coef2_q = (1.0 - self.alpha_bar_prev) * torch.sqrt(self.alpha) / (1.0 - self.alpha_bar)
|
113 |
+
|
114 |
+
def space(self, n_timesteps_new):
|
115 |
+
# change parameters for spaced timesteps during sampling
|
116 |
+
self.timesteps = self.space_timesteps(self.n_timesteps, n_timesteps_new)
|
117 |
+
self.n_timesteps = n_timesteps_new
|
118 |
+
|
119 |
+
self.beta = self.get_spaced_beta()
|
120 |
+
self.set_params()
|
121 |
+
|
122 |
+
def space_timesteps(self, n_timesteps, target_timesteps):
|
123 |
+
all_steps = []
|
124 |
+
frac_stride = (n_timesteps - 1) / (target_timesteps - 1)
|
125 |
+
cur_idx = 0.0
|
126 |
+
taken_steps = []
|
127 |
+
for _ in range(target_timesteps):
|
128 |
+
taken_steps.append(round(cur_idx))
|
129 |
+
cur_idx += frac_stride
|
130 |
+
all_steps += taken_steps
|
131 |
+
return all_steps
|
132 |
+
|
133 |
+
def get_spaced_beta(self):
|
134 |
+
last_alpha_cumprod = 1.0
|
135 |
+
new_beta = []
|
136 |
+
for i, alpha_cumprod in enumerate(self.alpha_bar):
|
137 |
+
if i in self.timesteps:
|
138 |
+
new_beta.append(1 - alpha_cumprod / last_alpha_cumprod)
|
139 |
+
last_alpha_cumprod = alpha_cumprod
|
140 |
+
return torch.tensor(new_beta)
|
141 |
+
|
142 |
+
def expand(self, arr, dim=4):
|
143 |
+
while arr.dim() < dim:
|
144 |
+
arr = arr[:, None]
|
145 |
+
return arr.to(self.device)
|
utils.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import scipy.io.wavfile as wav
|
4 |
+
import ffmpeg
|
5 |
+
import cv2
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
import decord
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torchvision.transforms import Compose, GaussianBlur, Grayscale, Resize
|
14 |
+
import torchaudio
|
15 |
+
|
16 |
+
decord.bridge.set_bridge('torch')
|
17 |
+
torchaudio.set_audio_backend("sox_io")
|
18 |
+
|
19 |
+
|
20 |
+
class AudioEncoder(nn.Module):
|
21 |
+
"""
|
22 |
+
A PyTorch Module to encode audio data into a fixed-size vector
|
23 |
+
(also known as an "embedding"). This can be useful for various machine
|
24 |
+
learning tasks such as classification, similarity matching, etc.
|
25 |
+
"""
|
26 |
+
def __init__(self, path):
|
27 |
+
"""
|
28 |
+
Initialize the AudioEncoder object.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
path (str): The file path where the pre-trained model is stored.
|
32 |
+
"""
|
33 |
+
super().__init__()
|
34 |
+
self.model = torch.jit.load(path)
|
35 |
+
self.register_buffer('hidden', torch.zeros(2, 1, 256))
|
36 |
+
|
37 |
+
def forward(self, audio):
|
38 |
+
"""
|
39 |
+
The forward method is where the actual encoding happens. Given an
|
40 |
+
audio sample, this function returns its corresponding embedding.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
audio (Tensor): A PyTorch tensor containing the audio data.
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
Tensor: The embedding of the given audio.
|
47 |
+
"""
|
48 |
+
self.reset()
|
49 |
+
x = create_windowed_sequence(audio, 3200, cutting_stride=640, pad_samples=3200-640, cut_dim=1)
|
50 |
+
embs = []
|
51 |
+
for i in range(x.shape[1]):
|
52 |
+
emb, _, self.hidden = self.model(x[:, i], torch.LongTensor([3200]), init_state=self.hidden)
|
53 |
+
embs.append(emb)
|
54 |
+
return torch.vstack(embs)
|
55 |
+
|
56 |
+
def reset(self):
|
57 |
+
"""
|
58 |
+
Resets the hidden states in the model. Call this function
|
59 |
+
before processing a new audio sample to ensure that there is
|
60 |
+
no state carried over from the previous sample.
|
61 |
+
"""
|
62 |
+
self.hidden = torch.zeros(2, 1, 256).to(self.hidden.device)
|
63 |
+
|
64 |
+
|
65 |
+
def get_audio_emb(audio_path, checkpoint, device):
|
66 |
+
"""
|
67 |
+
This function takes the path of an audio file, loads it into a
|
68 |
+
PyTorch tensor, and returns its embedding.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
audio_path (str): The file path of the audio to be loaded.
|
72 |
+
checkpoint (str): The file path of the pre-trained model.
|
73 |
+
device (str): The computing device ('cpu' or 'cuda').
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
Tensor, Tensor: The original audio as a tensor and its corresponding embedding.
|
77 |
+
"""
|
78 |
+
audio, audio_rate = torchaudio.load(audio_path, channels_first=False)
|
79 |
+
assert audio_rate == 16000, 'Only 16 kHZ audio is supported.'
|
80 |
+
audio = audio[None, None, :, 0].to(device)
|
81 |
+
|
82 |
+
audio_encoder = AudioEncoder(checkpoint).to(device)
|
83 |
+
|
84 |
+
emb = audio_encoder(audio)
|
85 |
+
return audio, emb
|
86 |
+
|
87 |
+
|
88 |
+
def get_id_frame(path, random=False, resize=128):
|
89 |
+
"""
|
90 |
+
Retrieves a frame from either a video or image file. This frame can
|
91 |
+
serve as an identifier or reference for the video or image.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
path (str): File path to the video or image.
|
95 |
+
random (bool): Whether to randomly select a frame from the video.
|
96 |
+
resize (int): The dimensions to which the frame should be resized.
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
Tensor: The image frame as a tensor.
|
100 |
+
"""
|
101 |
+
if path.endswith('.mp4'):
|
102 |
+
vr = decord.VideoReader(path)
|
103 |
+
if random:
|
104 |
+
idx = [np.random.randint(len(vr))]
|
105 |
+
else:
|
106 |
+
idx = [0]
|
107 |
+
frame = vr.get_batch(idx).permute(0, 3, 1, 2)
|
108 |
+
else:
|
109 |
+
frame = load_image_to_torch(path).unsqueeze(0)
|
110 |
+
|
111 |
+
frame = (frame / 255) * 2 - 1
|
112 |
+
frame = Resize((resize, resize), antialias=True)(frame).float()
|
113 |
+
return frame
|
114 |
+
|
115 |
+
|
116 |
+
def get_motion_transforms(args):
|
117 |
+
"""
|
118 |
+
Applies a series of transformations like Gaussian blur and grayscale
|
119 |
+
conversion based on the provided arguments. This is commonly used for
|
120 |
+
data augmentation or preprocessing.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
args (Namespace): Arguments containing options for motion transformations.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
Compose: A composed function of transforms.
|
127 |
+
"""
|
128 |
+
motion_transforms = []
|
129 |
+
if args.motion_blur:
|
130 |
+
motion_transforms.append(GaussianBlur(5, sigma=2.0))
|
131 |
+
if args.grayscale_motion:
|
132 |
+
motion_transforms.append(Grayscale(1))
|
133 |
+
return Compose(motion_transforms)
|
134 |
+
|
135 |
+
|
136 |
+
def save_audio(path, audio, audio_rate=16000):
|
137 |
+
"""
|
138 |
+
Saves the audio data as a WAV file.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
path (str): The file path where the audio will be saved.
|
142 |
+
audio (Tensor or np.array): The audio data.
|
143 |
+
audio_rate (int): The sampling rate of the audio, defaults to 16000Hz.
|
144 |
+
"""
|
145 |
+
if torch.is_tensor(audio):
|
146 |
+
aud = audio.squeeze().detach().cpu().numpy()
|
147 |
+
else:
|
148 |
+
aud = audio.copy() # Make a copy so that we don't alter the object
|
149 |
+
|
150 |
+
aud = ((2 ** 15) * aud).astype(np.int16)
|
151 |
+
wav.write(path, audio_rate, aud)
|
152 |
+
|
153 |
+
|
154 |
+
def save_video(path, video, fps=25, scale=2, audio=None, audio_rate=16000, overlay_pts=None, ffmpeg_experimental=False):
|
155 |
+
"""
|
156 |
+
Saves the video data as an MP4 file. Optionally includes audio and overlay points.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
path (str): The file path where the video will be saved.
|
160 |
+
video (Tensor or np.array): The video data.
|
161 |
+
fps (int): Frames per second of the video.
|
162 |
+
scale (int): Scaling factor for the video dimensions.
|
163 |
+
audio (Tensor or np.array, optional): Audio data.
|
164 |
+
audio_rate (int, optional): The sampling rate for the audio.
|
165 |
+
overlay_pts (list of points, optional): Points to overlay on the video frames.
|
166 |
+
ffmpeg_experimental (bool): Whether to use experimental ffmpeg options.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
bool: Success status.
|
170 |
+
"""
|
171 |
+
if not os.path.exists(os.path.dirname(path)):
|
172 |
+
os.makedirs(os.path.dirname(path))
|
173 |
+
success = True
|
174 |
+
out_size = (scale * video.shape[-1], scale * video.shape[-2])
|
175 |
+
video_path = get_temp_path(os.path.split(path)[0], ext=".mp4")
|
176 |
+
if torch.is_tensor(video):
|
177 |
+
vid = video.squeeze().detach().cpu().numpy()
|
178 |
+
else:
|
179 |
+
vid = video.copy() # Make a copy so that we don't alter the object
|
180 |
+
|
181 |
+
if np.min(vid) < 0:
|
182 |
+
vid = 127 * vid + 127
|
183 |
+
elif np.max(vid) <= 1:
|
184 |
+
vid = 255 * vid
|
185 |
+
|
186 |
+
is_color = True
|
187 |
+
if vid.ndim == 3:
|
188 |
+
is_color = False
|
189 |
+
|
190 |
+
writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*"mp4v"), float(fps), out_size, isColor=is_color)
|
191 |
+
for i, frame in enumerate(vid):
|
192 |
+
if is_color:
|
193 |
+
frame = cv2.cvtColor(np.rollaxis(frame, 0, 3), cv2.COLOR_RGB2BGR)
|
194 |
+
|
195 |
+
if scale != 1:
|
196 |
+
frame = cv2.resize(frame, out_size)
|
197 |
+
|
198 |
+
write_frame = frame.astype('uint8')
|
199 |
+
|
200 |
+
if overlay_pts is not None:
|
201 |
+
for pt in overlay_pts[i]:
|
202 |
+
cv2.circle(write_frame, (int(scale * pt[0]), int(scale * pt[1])), 2, (0, 0, 0), -1)
|
203 |
+
|
204 |
+
writer.write(write_frame)
|
205 |
+
writer.release()
|
206 |
+
|
207 |
+
inputs = [ffmpeg.input(video_path)['v']]
|
208 |
+
|
209 |
+
if audio is not None: # Save the audio file
|
210 |
+
audio_path = swp_extension(video_path, ".wav")
|
211 |
+
save_audio(audio_path, audio, audio_rate)
|
212 |
+
inputs += [ffmpeg.input(audio_path)['a']]
|
213 |
+
|
214 |
+
try:
|
215 |
+
if ffmpeg_experimental:
|
216 |
+
out = ffmpeg.output(*inputs, path, strict='-2', loglevel="panic", vcodec='h264').overwrite_output()
|
217 |
+
else:
|
218 |
+
out = ffmpeg.output(*inputs, path, loglevel="panic", vcodec='h264').overwrite_output()
|
219 |
+
out.run(quiet=True)
|
220 |
+
except:
|
221 |
+
success = False
|
222 |
+
|
223 |
+
if audio is not None and os.path.isfile(audio_path):
|
224 |
+
os.remove(audio_path)
|
225 |
+
if os.path.isfile(video_path):
|
226 |
+
os.remove(video_path)
|
227 |
+
|
228 |
+
return success
|
229 |
+
|
230 |
+
|
231 |
+
def load_image_to_torch(dir):
|
232 |
+
"""
|
233 |
+
Load an image from disk and convert it to a PyTorch tensor.
|
234 |
+
|
235 |
+
Args:
|
236 |
+
dir (str): The directory path to the image file.
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
torch.Tensor: A tensor representation of the image.
|
240 |
+
"""
|
241 |
+
img = Image.open(dir).convert('RGB')
|
242 |
+
img = np.array(img)
|
243 |
+
return torch.from_numpy(img).permute(2, 0, 1)
|
244 |
+
|
245 |
+
|
246 |
+
def get_temp_path(tmp_dir, mode="", ext=""):
|
247 |
+
"""
|
248 |
+
Generate a temporary file path for storing data.
|
249 |
+
|
250 |
+
Args:
|
251 |
+
tmp_dir (str): The directory where the temporary file will be created.
|
252 |
+
mode (str, optional): A string to append to the file name.
|
253 |
+
ext (str, optional): The file extension.
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
str: The full path to the temporary file.
|
257 |
+
"""
|
258 |
+
file_path = next(tempfile._get_candidate_names()) + mode + ext
|
259 |
+
if not os.path.exists(tmp_dir):
|
260 |
+
os.makedirs(tmp_dir)
|
261 |
+
file_path = os.path.join(tmp_dir, file_path)
|
262 |
+
return file_path
|
263 |
+
|
264 |
+
|
265 |
+
def swp_extension(file, ext):
|
266 |
+
"""
|
267 |
+
Swap the extension of a given file name.
|
268 |
+
|
269 |
+
Args:
|
270 |
+
file (str): The original file name.
|
271 |
+
ext (str): The new extension.
|
272 |
+
|
273 |
+
Returns:
|
274 |
+
str: The file name with the new extension.
|
275 |
+
"""
|
276 |
+
return os.path.splitext(file)[0] + ext
|
277 |
+
|
278 |
+
|
279 |
+
def pad_both_ends(tensor, left, right, dim=0):
|
280 |
+
"""
|
281 |
+
Pad a tensor on both ends along a specific dimension.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
tensor (torch.Tensor): The tensor to be padded.
|
285 |
+
left (int): The padding size for the left side.
|
286 |
+
right (int): The padding size for the right side.
|
287 |
+
dim (int, optional): The dimension along which to pad.
|
288 |
+
|
289 |
+
Returns:
|
290 |
+
torch.Tensor: The padded tensor.
|
291 |
+
"""
|
292 |
+
no_dims = len(tensor.size())
|
293 |
+
if dim == -1:
|
294 |
+
dim = no_dims - 1
|
295 |
+
|
296 |
+
padding = [0] * 2 * no_dims
|
297 |
+
padding[2 * (no_dims - dim - 1)] = left
|
298 |
+
padding[2 * (no_dims - dim - 1) + 1] = right
|
299 |
+
return F.pad(tensor, padding, "constant", 0)
|
300 |
+
|
301 |
+
|
302 |
+
def cut_n_stack(seq, snip_length, cut_dim=0, cutting_stride=None, pad_samples=0):
|
303 |
+
"""
|
304 |
+
Divide a sequence tensor into smaller snips and stack them.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
seq (torch.Tensor): The original sequence tensor.
|
308 |
+
snip_length (int): The length of each snip.
|
309 |
+
cut_dim (int, optional): The dimension along which to cut.
|
310 |
+
cutting_stride (int, optional): The stride length for cutting. Defaults to snip_length.
|
311 |
+
pad_samples (int, optional): Number of samples to pad at both ends.
|
312 |
+
|
313 |
+
Returns:
|
314 |
+
torch.Tensor: A tensor containing the stacked snips.
|
315 |
+
"""
|
316 |
+
if cutting_stride is None:
|
317 |
+
cutting_stride = snip_length
|
318 |
+
|
319 |
+
pad_left = pad_samples // 2
|
320 |
+
pad_right = pad_samples - pad_samples // 2
|
321 |
+
|
322 |
+
seq = pad_both_ends(seq, pad_left, pad_right, dim=cut_dim)
|
323 |
+
|
324 |
+
stacked = seq.narrow(cut_dim, 0, snip_length).unsqueeze(0)
|
325 |
+
iterations = (seq.size()[cut_dim] - snip_length) // cutting_stride + 1
|
326 |
+
for i in range(1, iterations):
|
327 |
+
stacked = torch.cat((stacked, seq.narrow(cut_dim, i * cutting_stride, snip_length).unsqueeze(0)))
|
328 |
+
return stacked
|
329 |
+
|
330 |
+
|
331 |
+
def create_windowed_sequence(seqs, snip_length, cut_dim=0, cutting_stride=None, pad_samples=0):
|
332 |
+
"""
|
333 |
+
Create a windowed sequence from a list of sequences.
|
334 |
+
|
335 |
+
Args:
|
336 |
+
seqs (list of torch.Tensor): List of sequence tensors.
|
337 |
+
snip_length (int): The length of each snip.
|
338 |
+
cut_dim (int, optional): The dimension along which to cut.
|
339 |
+
cutting_stride (int, optional): The stride length for cutting. Defaults to snip_length.
|
340 |
+
pad_samples (int, optional): Number of samples to pad at both ends.
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
torch.Tensor: A tensor containing the windowed sequences.
|
344 |
+
"""
|
345 |
+
windowed_seqs = []
|
346 |
+
for seq in seqs:
|
347 |
+
windowed_seqs.append(cut_n_stack(seq, snip_length, cut_dim, cutting_stride, pad_samples).unsqueeze(0))
|
348 |
+
|
349 |
+
return torch.cat(windowed_seqs)
|