Sof22 committed
Commit 73eb2d5 · Parent: d44dd3f

Upload 3 files

Files changed (3)
  1. README.md +63 -11
  2. diffusion.py +145 -0
  3. utils.py +349 -0
README.md CHANGED
@@ -1,11 +1,63 @@
- ---
- title: Diffused Heads
- emoji: 💻
- colorFrom: pink
- colorTo: green
- sdk: docker
- pinned: false
- license: cc-by-nc-sa-4.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Diffused Heads
+
+ Official repository for Diffused Heads: Diffusion Models Beat GANs on Talking-Face Generation.
+
+ ### [Project](https://mstypulkowski.github.io/diffusedheads/) | [Paper](https://arxiv.org/abs/2301.03396) | [Demo](https://youtu.be/DSipIDj-5q0)
+
+ <p align="center">
+ <img src='./intro.gif' width=400>
+ </p>
+
+ ## Setup
+ A Python 3.x environment with [ffmpeg](https://www.ffmpeg.org/) is required. The remaining dependencies can be installed with:
+ ```
+ pip install -r requirements.txt
+ ```
+
+ ## Sampling
+ Due to the LRW license agreement, we can only provide a checkpoint of our model trained on CREMA.
+
+ The entire test set generated by our method can be downloaded from [here](https://drive.google.com/file/d/1zWSqtV7O4WGkgh6WB55b8Mdg2lXXUudH/view?usp=drive_link).
+
+ 1. Download and unpack the [checkpoints](https://drive.google.com/file/d/1U90egQvzERHclTYPCjZadrEMyF7TAPa-/view?usp=drive_link) (our model and the pretrained audio encoder).
+
+ 2. Download and unpack the preprocessed CREMA [video](https://drive.google.com/file/d/1rM0FZLGiy-bJcxpv4CTlbUf0FuROubdk/view?usp=drive_link) and [audio](https://drive.google.com/file/d/1uS7Vi8EwarJFGQhsYHDMSkQmaNuiJIVW/view?usp=drive_link) files.
+
+ 3. Specify paths and options in `config_crema.yaml` (see the comments in the file).
+
+ 4. Run the script (a rough sketch of what it does with these components is shown below):
+ ```
+ python sample.py
+ ```
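+
+ For orientation, here is a minimal sketch of how the pieces in `utils.py` and `diffusion.py` fit together. The file paths and the way the UNet backbone is loaded are placeholder assumptions; `sample.py` and `config_crema.yaml` define the actual options:
+ ```
+ import torch
+ from diffusion import Diffusion
+ from utils import get_audio_emb, get_id_frame, save_video
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Audio -> sequence of per-frame embeddings (uses the pretrained audio encoder checkpoint).
+ audio, audio_emb = get_audio_emb('example.wav', 'checkpoints/audio_encoder.pt', device)
+
+ # Identity frame from an image or a CREMA clip, resized to 128x128 and scaled to [-1, 1].
+ id_frame = get_id_frame('crema/example_video.mp4', random=True, resize=128).to(device)
+
+ # Placeholder backbone loading; the real script builds the UNet from the config and checkpoint.
+ unet = torch.load('checkpoints/unet.pt', map_location=device)
+
+ diffusion = Diffusion(unet, device, n_timesteps=1000, in_channels=3, image_size=128, out_channels=6).to(device)
+ diffusion.space(100)  # optional: sample with fewer, evenly spaced timesteps
+
+ video = diffusion.sample(id_frame, audio_emb.unsqueeze(0))
+ save_video('results/output.mp4', video, fps=25, audio=audio, audio_rate=16000)
+ ```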
+
+
+ ## Using your own data
+ ### Audio
+ You can freely use any audio recording of your choice. The only requirements are a 16 kHz sampling rate and a single audio channel. Please note that our model can generate videos up to 9 seconds long, depending on the audio.
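+
+ If your recording does not meet these requirements yet, a conversion along these lines should work (a minimal sketch using torchaudio; the file names are placeholders):
+ ```
+ import torchaudio
+ import torchaudio.functional as AF
+
+ waveform, sr = torchaudio.load('my_recording.wav')    # (channels, samples)
+ waveform = waveform.mean(dim=0, keepdim=True)          # downmix to a single channel
+ if sr != 16000:
+     waveform = AF.resample(waveform, orig_freq=sr, new_freq=16000)
+ torchaudio.save('my_recording_16k.wav', waveform, 16000)
+ ```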
+
+ ### Identity frame
+ It is highly recommended to use a frame from the provided CREMA videos, since this instance of the model was trained only on clips with a green background. If you want to use your own identity frame anyway, please follow this [repo](https://github.com/DinoMan/face-processor) for face alignment. Additionally, you may want to try segmenting the person and replacing the background with green.
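+
+ For example, an identity frame can be taken directly from one of the provided clips with `get_id_frame` from `utils.py` (the path below is a placeholder):
+ ```
+ from utils import get_id_frame
+
+ # Random frame from a preprocessed CREMA clip, resized to 128x128 and scaled to [-1, 1]
+ id_frame = get_id_frame('crema/example_video.mp4', random=True, resize=128)
+ ```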
+
+ ## Training
+ A training script will be uploaded in the future (ETA December 2023).
+
+ ## Citation
+ ```
+ @article{stypulkowski2023diffused,
+   title={Diffused heads: Diffusion models beat gans on talking-face generation},
+   author={Stypu{\l}kowski, Micha{\l} and Vougioukas, Konstantinos and He, Sen and Zi{\k{e}}ba, Maciej and Petridis, Stavros and Pantic, Maja},
+   journal={arXiv preprint arXiv:2301.03396},
+   year={2023}
+ }
+ ```
+
+ ## License
+ This work is licensed under a
+ [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa].
+
+ [![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa]
+
+ [cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/
+ [cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png
+ [cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg
diffusion.py ADDED
@@ -0,0 +1,145 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from tqdm import trange
+ from torchvision.transforms import Compose
+
+
+ class Diffusion(nn.Module):
+     def __init__(
+             self, nn_backbone, device, n_timesteps=1000, in_channels=3, image_size=128, out_channels=6, motion_transforms=None):
+         super(Diffusion, self).__init__()
+
+         self.nn_backbone = nn_backbone
+         self.n_timesteps = n_timesteps
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.x_shape = (image_size, image_size)
+         self.device = device
+
+         self.motion_transforms = motion_transforms if motion_transforms else Compose([])
+
+         self.timesteps = torch.arange(n_timesteps)
+         self.beta = self.get_beta_schedule()
+         self.set_params()
+
+     def sample(self, x_cond, audio_emb, n_audio_motion_embs=2, n_motion_frames=2, motion_channels=3):
+         with torch.no_grad():
+             n_frames = audio_emb.shape[1]
+
+             xT = torch.randn(x_cond.shape[0], n_frames, self.in_channels, self.x_shape[0], self.x_shape[1]).to(x_cond.device)
+
+             # Sliding window of audio embedding indices around the current frame.
+             audio_ids = [0] * n_audio_motion_embs
+             for i in range(n_audio_motion_embs + 1):
+                 audio_ids += [i]
+
+             # Motion frames are initialized with copies of the identity frame.
+             motion_frames = [self.motion_transforms(x_cond) for _ in range(n_motion_frames)]
+             motion_frames = torch.cat(motion_frames, dim=1)
+
+             samples = []
+             for i in trange(n_frames, desc='Sampling'):
+                 sample_frame = self.sample_loop(xT[:, i].to(x_cond.device), x_cond, motion_frames, audio_emb[:, audio_ids])
+                 samples.append(sample_frame.unsqueeze(1))
+                 # Shift the motion-frame and audio windows by one frame.
+                 motion_frames = torch.cat([motion_frames[:, motion_channels:, :], self.motion_transforms(sample_frame)], dim=1)
+                 audio_ids = audio_ids[1:] + [min(i + n_audio_motion_embs + 1, n_frames - 1)]
+             return torch.cat(samples, dim=1)
+
+     def sample_loop(self, xT, x_cond, motion_frames, audio_emb):
+         # Reverse diffusion: denoise xT step by step down to x0.
+         xt = xT
+         for i, t in reversed(list(enumerate(self.timesteps))):
+             timesteps = torch.tensor([t] * xT.shape[0]).to(xT.device)
+             timesteps_ids = torch.tensor([i] * xT.shape[0]).to(xT.device)
+             nn_out = self.nn_backbone(xt, timesteps, x_cond, motion_frames=motion_frames, audio_emb=audio_emb)
+             mean, logvar = self.get_p_params(xt, timesteps_ids, nn_out)
+             noise = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)
+             xt = mean + noise * torch.exp(logvar / 2)
+
+         return xt
+
+     def get_p_params(self, xt, timesteps, nn_out):
+         if self.in_channels == self.out_channels:
+             # Fixed variance: the network predicts only the noise.
+             eps_pred = nn_out
+             p_logvar = self.expand(torch.log(self.beta[timesteps]))
+         else:
+             # Learned variance: the network also predicts the interpolation coefficient nu.
+             eps_pred, nu = nn_out.chunk(2, 1)
+             nu = (nu + 1) / 2
+             p_logvar = nu * self.expand(torch.log(self.beta[timesteps])) + (1 - nu) * self.expand(self.log_beta_tilde_clipped[timesteps])
+
+         p_mean, _ = self.get_q_params(xt, timesteps, eps_pred=eps_pred)
+         return p_mean, p_logvar
+
+     def get_q_params(self, xt, timesteps, eps_pred=None, x0=None):
+         if x0 is None:
+             # predict x0 from xt and eps_pred
+             coef1_x0 = self.expand(self.coef1_x0[timesteps])
+             coef2_x0 = self.expand(self.coef2_x0[timesteps])
+             x0 = coef1_x0 * xt - coef2_x0 * eps_pred
+             x0 = x0.clamp(-1, 1)
+
+         # q(x_{t-1} | x_t, x_0)
+         coef1_q = self.expand(self.coef1_q[timesteps])
+         coef2_q = self.expand(self.coef2_q[timesteps])
+         q_mean = coef1_q * x0 + coef2_q * xt
+
+         q_logvar = self.expand(self.log_beta_tilde_clipped[timesteps])
+
+         return q_mean, q_logvar
+
+     def get_beta_schedule(self, max_beta=0.999):
+         # Cosine noise schedule.
+         alpha_bar = lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2
+         betas = []
+         for i in range(self.n_timesteps):
+             t1 = i / self.n_timesteps
+             t2 = (i + 1) / self.n_timesteps
+             betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+         return torch.tensor(betas).float()
+
+     def set_params(self):
+         self.alpha = 1 - self.beta
+         self.alpha_bar = torch.cumprod(self.alpha, dim=0)
+         self.alpha_bar_prev = torch.cat([torch.ones(1,), self.alpha_bar[:-1]])
+
+         self.beta_tilde = self.beta * (1.0 - self.alpha_bar_prev) / (1.0 - self.alpha_bar)
+         self.log_beta_tilde_clipped = torch.log(torch.cat([self.beta_tilde[1, None], self.beta_tilde[1:]]))
+
+         # to calculate x0 from eps_pred
+         self.coef1_x0 = torch.sqrt(1.0 / self.alpha_bar)
+         self.coef2_x0 = torch.sqrt(1.0 / self.alpha_bar - 1)
+
+         # for q(x_{t-1} | x_t, x_0)
+         self.coef1_q = self.beta * torch.sqrt(self.alpha_bar_prev) / (1.0 - self.alpha_bar)
+         self.coef2_q = (1.0 - self.alpha_bar_prev) * torch.sqrt(self.alpha) / (1.0 - self.alpha_bar)
+
+     def space(self, n_timesteps_new):
+         # change parameters for spaced timesteps during sampling
+         self.timesteps = self.space_timesteps(self.n_timesteps, n_timesteps_new)
+         self.n_timesteps = n_timesteps_new
+
+         self.beta = self.get_spaced_beta()
+         self.set_params()
+
+     def space_timesteps(self, n_timesteps, target_timesteps):
+         # Spread target_timesteps indices evenly over the original schedule.
+         frac_stride = (n_timesteps - 1) / (target_timesteps - 1)
+         cur_idx = 0.0
+         taken_steps = []
+         for _ in range(target_timesteps):
+             taken_steps.append(round(cur_idx))
+             cur_idx += frac_stride
+         return taken_steps
+
+     def get_spaced_beta(self):
+         last_alpha_cumprod = 1.0
+         new_beta = []
+         for i, alpha_cumprod in enumerate(self.alpha_bar):
+             if i in self.timesteps:
+                 new_beta.append(1 - alpha_cumprod / last_alpha_cumprod)
+                 last_alpha_cumprod = alpha_cumprod
+         return torch.tensor(new_beta)
+
+     def expand(self, arr, dim=4):
+         while arr.dim() < dim:
+             arr = arr[:, None]
+         return arr.to(self.device)
utils.py ADDED
@@ -0,0 +1,349 @@
+ import os
+ import tempfile
+ import scipy.io.wavfile as wav
+ import ffmpeg
+ import cv2
+ from PIL import Image
+
+ import decord
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision.transforms import Compose, GaussianBlur, Grayscale, Resize
+ import torchaudio
+
+ decord.bridge.set_bridge('torch')
+ torchaudio.set_audio_backend("sox_io")
+
+
+ class AudioEncoder(nn.Module):
+     """
+     A PyTorch module that encodes raw audio into a sequence of per-window
+     embeddings using a pretrained recurrent encoder loaded as a TorchScript
+     module. These embeddings condition the video diffusion model.
+     """
+     def __init__(self, path):
+         """
+         Initialize the AudioEncoder object.
+
+         Args:
+             path (str): The file path where the pre-trained model is stored.
+         """
+         super().__init__()
+         self.model = torch.jit.load(path)
+         self.register_buffer('hidden', torch.zeros(2, 1, 256))
+
+     def forward(self, audio):
+         """
+         Encode an audio sample into a sequence of embeddings.
+
+         Args:
+             audio (Tensor): A PyTorch tensor containing the audio data.
+
+         Returns:
+             Tensor: One embedding per 3200-sample window (stride 640).
+         """
+         self.reset()
+         # 3200-sample windows (200 ms at 16 kHz) with a 640-sample stride (40 ms, one frame at 25 fps).
+         x = create_windowed_sequence(audio, 3200, cutting_stride=640, pad_samples=3200-640, cut_dim=1)
+         embs = []
+         for i in range(x.shape[1]):
+             emb, _, self.hidden = self.model(x[:, i], torch.LongTensor([3200]), init_state=self.hidden)
+             embs.append(emb)
+         return torch.vstack(embs)
+
+     def reset(self):
+         """
+         Resets the hidden states in the model. Call this function
+         before processing a new audio sample to ensure that there is
+         no state carried over from the previous sample.
+         """
+         self.hidden = torch.zeros(2, 1, 256).to(self.hidden.device)
+
+
+ def get_audio_emb(audio_path, checkpoint, device):
+     """
+     Load an audio file into a PyTorch tensor and return its embedding sequence.
+
+     Args:
+         audio_path (str): The file path of the audio to be loaded.
+         checkpoint (str): The file path of the pre-trained model.
+         device (str): The computing device ('cpu' or 'cuda').
+
+     Returns:
+         Tensor, Tensor: The original audio as a tensor and its embedding sequence.
+     """
+     audio, audio_rate = torchaudio.load(audio_path, channels_first=False)
+     assert audio_rate == 16000, 'Only 16 kHz audio is supported.'
+     # Keep the first channel and add batch and channel dimensions: (1, 1, samples).
+     audio = audio[None, None, :, 0].to(device)
+
+     audio_encoder = AudioEncoder(checkpoint).to(device)
+
+     emb = audio_encoder(audio)
+     return audio, emb
+
+
+ def get_id_frame(path, random=False, resize=128):
+     """
+     Retrieve an identity frame from either a video or an image file.
+
+     Args:
+         path (str): File path to the video or image.
+         random (bool): Whether to randomly select a frame from the video.
+         resize (int): The dimensions to which the frame should be resized.
+
+     Returns:
+         Tensor: The image frame as a tensor scaled to [-1, 1].
+     """
+     if path.endswith('.mp4'):
+         vr = decord.VideoReader(path)
+         if random:
+             idx = [np.random.randint(len(vr))]
+         else:
+             idx = [0]
+         frame = vr.get_batch(idx).permute(0, 3, 1, 2)
+     else:
+         frame = load_image_to_torch(path).unsqueeze(0)
+
+     frame = (frame / 255) * 2 - 1
+     frame = Resize((resize, resize), antialias=True)(frame).float()
+     return frame
+
+
+ def get_motion_transforms(args):
+     """
+     Build the transforms (Gaussian blur, grayscale conversion) applied to
+     motion frames, depending on the provided arguments.
+
+     Args:
+         args (Namespace): Arguments containing options for motion transformations.
+
+     Returns:
+         Compose: A composed function of transforms.
+     """
+     motion_transforms = []
+     if args.motion_blur:
+         motion_transforms.append(GaussianBlur(5, sigma=2.0))
+     if args.grayscale_motion:
+         motion_transforms.append(Grayscale(1))
+     return Compose(motion_transforms)
+
+
+ def save_audio(path, audio, audio_rate=16000):
+     """
+     Saves the audio data as a WAV file.
+
+     Args:
+         path (str): The file path where the audio will be saved.
+         audio (Tensor or np.array): The audio data.
+         audio_rate (int): The sampling rate of the audio, defaults to 16000 Hz.
+     """
+     if torch.is_tensor(audio):
+         aud = audio.squeeze().detach().cpu().numpy()
+     else:
+         aud = audio.copy()  # Make a copy so that we don't alter the object
+
+     # Convert from [-1, 1] float to 16-bit PCM.
+     aud = ((2 ** 15) * aud).astype(np.int16)
+     wav.write(path, audio_rate, aud)
+
+
+ def save_video(path, video, fps=25, scale=2, audio=None, audio_rate=16000, overlay_pts=None, ffmpeg_experimental=False):
+     """
+     Saves the video data as an MP4 file. Optionally includes audio and overlay points.
+
+     Args:
+         path (str): The file path where the video will be saved.
+         video (Tensor or np.array): The video data.
+         fps (int): Frames per second of the video.
+         scale (int): Scaling factor for the video dimensions.
+         audio (Tensor or np.array, optional): Audio data.
+         audio_rate (int, optional): The sampling rate for the audio.
+         overlay_pts (list of points, optional): Points to overlay on the video frames.
+         ffmpeg_experimental (bool): Whether to use experimental ffmpeg options.
+
+     Returns:
+         bool: Success status.
+     """
+     if not os.path.exists(os.path.dirname(path)):
+         os.makedirs(os.path.dirname(path))
+     success = True
+     out_size = (scale * video.shape[-1], scale * video.shape[-2])
+     video_path = get_temp_path(os.path.split(path)[0], ext=".mp4")
+     if torch.is_tensor(video):
+         vid = video.squeeze().detach().cpu().numpy()
+     else:
+         vid = video.copy()  # Make a copy so that we don't alter the object
+
+     # Rescale to [0, 255] regardless of whether the input is in [-1, 1] or [0, 1].
+     if np.min(vid) < 0:
+         vid = 127 * vid + 127
+     elif np.max(vid) <= 1:
+         vid = 255 * vid
+
+     is_color = True
+     if vid.ndim == 3:
+         is_color = False
+
+     writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*"mp4v"), float(fps), out_size, isColor=is_color)
+     for i, frame in enumerate(vid):
+         if is_color:
+             frame = cv2.cvtColor(np.rollaxis(frame, 0, 3), cv2.COLOR_RGB2BGR)
+
+         if scale != 1:
+             frame = cv2.resize(frame, out_size)
+
+         write_frame = frame.astype('uint8')
+
+         if overlay_pts is not None:
+             for pt in overlay_pts[i]:
+                 cv2.circle(write_frame, (int(scale * pt[0]), int(scale * pt[1])), 2, (0, 0, 0), -1)
+
+         writer.write(write_frame)
+     writer.release()
+
+     inputs = [ffmpeg.input(video_path)['v']]
+
+     if audio is not None:  # Save the audio file
+         audio_path = swp_extension(video_path, ".wav")
+         save_audio(audio_path, audio, audio_rate)
+         inputs += [ffmpeg.input(audio_path)['a']]
+
+     try:
+         if ffmpeg_experimental:
+             out = ffmpeg.output(*inputs, path, strict='-2', loglevel="panic", vcodec='h264').overwrite_output()
+         else:
+             out = ffmpeg.output(*inputs, path, loglevel="panic", vcodec='h264').overwrite_output()
+         out.run(quiet=True)
+     except Exception:
+         success = False
+
+     # Clean up the temporary audio and video files.
+     if audio is not None and os.path.isfile(audio_path):
+         os.remove(audio_path)
+     if os.path.isfile(video_path):
+         os.remove(video_path)
+
+     return success
+
+
+ def load_image_to_torch(dir):
+     """
+     Load an image from disk and convert it to a PyTorch tensor.
+
+     Args:
+         dir (str): The path to the image file.
+
+     Returns:
+         torch.Tensor: A tensor representation of the image (C, H, W).
+     """
+     img = Image.open(dir).convert('RGB')
+     img = np.array(img)
+     return torch.from_numpy(img).permute(2, 0, 1)
+
+
+ def get_temp_path(tmp_dir, mode="", ext=""):
+     """
+     Generate a temporary file path for storing data.
+
+     Args:
+         tmp_dir (str): The directory where the temporary file will be created.
+         mode (str, optional): A string to append to the file name.
+         ext (str, optional): The file extension.
+
+     Returns:
+         str: The full path to the temporary file.
+     """
+     # Note: uses a private tempfile helper to draw a random file name.
+     file_path = next(tempfile._get_candidate_names()) + mode + ext
+     if not os.path.exists(tmp_dir):
+         os.makedirs(tmp_dir)
+     file_path = os.path.join(tmp_dir, file_path)
+     return file_path
+
+
+ def swp_extension(file, ext):
+     """
+     Swap the extension of a given file name.
+
+     Args:
+         file (str): The original file name.
+         ext (str): The new extension.
+
+     Returns:
+         str: The file name with the new extension.
+     """
+     return os.path.splitext(file)[0] + ext
+
+
+ def pad_both_ends(tensor, left, right, dim=0):
+     """
+     Pad a tensor on both ends along a specific dimension.
+
+     Args:
+         tensor (torch.Tensor): The tensor to be padded.
+         left (int): The padding size for the left side.
+         right (int): The padding size for the right side.
+         dim (int, optional): The dimension along which to pad.
+
+     Returns:
+         torch.Tensor: The padded tensor.
+     """
+     no_dims = len(tensor.size())
+     if dim == -1:
+         dim = no_dims - 1
+
+     padding = [0] * 2 * no_dims
+     padding[2 * (no_dims - dim - 1)] = left
+     padding[2 * (no_dims - dim - 1) + 1] = right
+     return F.pad(tensor, padding, "constant", 0)
+
+
+ def cut_n_stack(seq, snip_length, cut_dim=0, cutting_stride=None, pad_samples=0):
+     """
+     Divide a sequence tensor into smaller snips and stack them.
+
+     Args:
+         seq (torch.Tensor): The original sequence tensor.
+         snip_length (int): The length of each snip.
+         cut_dim (int, optional): The dimension along which to cut.
+         cutting_stride (int, optional): The stride length for cutting. Defaults to snip_length.
+         pad_samples (int, optional): Number of samples to pad at both ends.
+
+     Returns:
+         torch.Tensor: A tensor containing the stacked snips.
+     """
+     if cutting_stride is None:
+         cutting_stride = snip_length
+
+     pad_left = pad_samples // 2
+     pad_right = pad_samples - pad_samples // 2
+
+     seq = pad_both_ends(seq, pad_left, pad_right, dim=cut_dim)
+
+     stacked = seq.narrow(cut_dim, 0, snip_length).unsqueeze(0)
+     iterations = (seq.size()[cut_dim] - snip_length) // cutting_stride + 1
+     for i in range(1, iterations):
+         stacked = torch.cat((stacked, seq.narrow(cut_dim, i * cutting_stride, snip_length).unsqueeze(0)))
+     return stacked
+
+
+ def create_windowed_sequence(seqs, snip_length, cut_dim=0, cutting_stride=None, pad_samples=0):
+     """
+     Create a windowed sequence from a batch of sequences.
+
+     Args:
+         seqs (list of torch.Tensor): List of sequence tensors.
+         snip_length (int): The length of each snip.
+         cut_dim (int, optional): The dimension along which to cut.
+         cutting_stride (int, optional): The stride length for cutting. Defaults to snip_length.
+         pad_samples (int, optional): Number of samples to pad at both ends.
+
+     Returns:
+         torch.Tensor: A tensor containing the windowed sequences.
+     """
+     windowed_seqs = []
+     for seq in seqs:
+         windowed_seqs.append(cut_n_stack(seq, snip_length, cut_dim, cutting_stride, pad_samples).unsqueeze(0))
+
+     return torch.cat(windowed_seqs)