import os
import tempfile

import cv2
import decord
import ffmpeg
import numpy as np
import scipy.io.wavfile as wav
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from PIL import Image
from torchvision.transforms import Compose, GaussianBlur, Grayscale, Resize

decord.bridge.set_bridge('torch')
torchaudio.set_audio_backend("sox_io")

class AudioEncoder(nn.Module):
    """
    A PyTorch module that encodes audio data into fixed-size vectors
    ("embeddings"). These embeddings can be used for machine learning
    tasks such as classification or similarity matching.
    """

    def __init__(self, path):
        """
        Initialize the AudioEncoder object.

        Args:
            path (str): The file path where the pre-trained model is stored.
        """
        super().__init__()
        self.model = torch.jit.load(path)
        self.register_buffer('hidden', torch.zeros(2, 1, 256))

    def forward(self, audio):
        """
        Encode an audio sample. The audio is split into overlapping windows
        and the model produces one embedding per window.

        Args:
            audio (Tensor): A PyTorch tensor containing the audio data.

        Returns:
            Tensor: The embeddings of the given audio, one row per window.
        """
        self.reset()
        x = create_windowed_sequence(audio, 3200, cutting_stride=640, pad_samples=3200 - 640, cut_dim=1)
        embs = []
        for i in range(x.shape[1]):
            emb, _, self.hidden = self.model(x[:, i], torch.LongTensor([3200]), init_state=self.hidden)
            embs.append(emb)
        return torch.vstack(embs)

    def reset(self):
        """
        Resets the hidden state in the model. Call this function before
        processing a new audio sample to ensure that no state is carried
        over from the previous sample.
        """
        self.hidden = torch.zeros(2, 1, 256).to(self.hidden.device)

def get_audio_emb(audio_path, checkpoint, device):
    """
    Loads an audio file into a PyTorch tensor and returns its embedding.

    Args:
        audio_path (str): The file path of the audio to be loaded.
        checkpoint (str): The file path of the pre-trained model.
        device (str): The computing device ('cpu' or 'cuda').

    Returns:
        Tensor, Tensor: The original audio as a tensor and its corresponding embedding.
    """
    audio, audio_rate = torchaudio.load(audio_path, channels_first=False)
    assert audio_rate == 16000, 'Only 16 kHz audio is supported.'
    audio = audio[None, None, :, 0].to(device)  # (1, 1, num_samples), first channel only
    audio_encoder = AudioEncoder(checkpoint).to(device)
    emb = audio_encoder(audio)
    return audio, emb

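# A rough usage sketch for get_audio_emb (file names are hypothetical, and the
# exact embedding dimension depends on the TorchScript checkpoint): given a
# 16 kHz mono WAV, it returns the waveform as a (1, 1, num_samples) tensor and
# one embedding per 640-sample hop, i.e. one per video frame at 25 fps:
#
#   audio, emb = get_audio_emb("speech_16khz.wav", "audio_encoder.pt", "cpu")
#   print(audio.shape)  # torch.Size([1, 1, 48000]) for a 3 s clip
#   print(emb.shape)    # roughly (75, emb_dim) -- 75 windows for 48000 samples
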
def get_id_frame(path, random=False, resize=128):
    """
    Retrieves a frame from either a video or image file. This frame can
    serve as an identifier or reference for the video or image.

    Args:
        path (str): File path to the video or image.
        random (bool): Whether to randomly select a frame from the video.
        resize (int): The dimensions to which the frame should be resized.

    Returns:
        Tensor: The image frame as a tensor.
    """
    if path.endswith('.mp4'):
        vr = decord.VideoReader(path)
        if random:
            idx = [np.random.randint(len(vr))]
        else:
            idx = [0]
        frame = vr.get_batch(idx).permute(0, 3, 1, 2)
    else:
        frame = load_image_to_torch(path).unsqueeze(0)

    frame = (frame / 255) * 2 - 1
    frame = Resize((resize, resize), antialias=True)(frame).float()
    return frame

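# Rough usage sketch (hypothetical file name): grab the first frame of a video
# (or load a still image) as a (1, 3, resize, resize) tensor scaled to [-1, 1]:
#
#   id_frame = get_id_frame("reference.mp4", random=False, resize=128)
#   print(id_frame.shape)  # torch.Size([1, 3, 128, 128])
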
def get_motion_transforms(args):
    """
    Applies a series of transformations like Gaussian blur and grayscale
    conversion based on the provided arguments. This is commonly used for
    data augmentation or preprocessing.

    Args:
        args (Namespace): Arguments containing options for motion transformations.

    Returns:
        Compose: A composed function of transforms.
    """
    motion_transforms = []
    if args.motion_blur:
        motion_transforms.append(GaussianBlur(5, sigma=2.0))
    if args.grayscale_motion:
        motion_transforms.append(Grayscale(1))
    return Compose(motion_transforms)

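# Rough usage sketch: only the `motion_blur` and `grayscale_motion` attributes
# of `args` are read, so a plain argparse.Namespace is enough for testing:
#
#   from argparse import Namespace
#   tfm = get_motion_transforms(Namespace(motion_blur=True, grayscale_motion=True))
#   out = tfm(torch.rand(1, 3, 128, 128))  # -> shape (1, 1, 128, 128)
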
def save_audio(path, audio, audio_rate=16000):
    """
    Saves the audio data as a WAV file.

    Args:
        path (str): The file path where the audio will be saved.
        audio (Tensor or np.array): The audio data.
        audio_rate (int): The sampling rate of the audio, defaults to 16000 Hz.
    """
    if torch.is_tensor(audio):
        aud = audio.squeeze().detach().cpu().numpy()
    else:
        aud = audio.copy()  # Make a copy so that we don't alter the object

    aud = ((2 ** 15) * aud).astype(np.int16)
    wav.write(path, audio_rate, aud)

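# Rough usage sketch (hypothetical output path): floating-point audio in
# [-1, 1] is scaled to 16-bit PCM before writing, so a synthetic tone works:
#
#   tone = 0.5 * torch.sin(torch.linspace(0, 2 * np.pi * 440, 16000))
#   save_audio("/tmp/tone.wav", tone, audio_rate=16000)
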
def save_video(path, video, fps=25, scale=2, audio=None, audio_rate=16000, overlay_pts=None, ffmpeg_experimental=False):
    """
    Saves the video data as an MP4 file. Optionally includes audio and overlay points.

    Args:
        path (str): The file path where the video will be saved.
        video (Tensor or np.array): The video data.
        fps (int): Frames per second of the video.
        scale (int): Scaling factor for the video dimensions.
        audio (Tensor or np.array, optional): Audio data.
        audio_rate (int, optional): The sampling rate for the audio.
        overlay_pts (list of points, optional): Points to overlay on the video frames.
        ffmpeg_experimental (bool): Whether to use experimental ffmpeg options.

    Returns:
        bool: Success status.
    """
    if not os.path.exists(os.path.dirname(path)):
        os.makedirs(os.path.dirname(path))

    success = True
    out_size = (scale * video.shape[-1], scale * video.shape[-2])  # (width, height)
    video_path = get_temp_path(os.path.split(path)[0], ext=".mp4")

    if torch.is_tensor(video):
        vid = video.squeeze().detach().cpu().numpy()
    else:
        vid = video.copy()  # Make a copy so that we don't alter the object

    # Rescale to [0, 255]: input may be in [-1, 1], [0, 1], or already 8-bit.
    if np.min(vid) < 0:
        vid = 127 * vid + 127
    elif np.max(vid) <= 1:
        vid = 255 * vid

    is_color = True
    if vid.ndim == 3:  # (T, H, W) means grayscale frames
        is_color = False

    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*"mp4v"), float(fps), out_size, isColor=is_color)
    for i, frame in enumerate(vid):
        if is_color:
            frame = cv2.cvtColor(np.rollaxis(frame, 0, 3), cv2.COLOR_RGB2BGR)
        if scale != 1:
            frame = cv2.resize(frame, out_size)
        write_frame = frame.astype('uint8')

        if overlay_pts is not None:
            for pt in overlay_pts[i]:
                cv2.circle(write_frame, (int(scale * pt[0]), int(scale * pt[1])), 2, (0, 0, 0), -1)

        writer.write(write_frame)
    writer.release()

    inputs = [ffmpeg.input(video_path)['v']]
    if audio is not None:  # Save the audio file
        audio_path = swp_extension(video_path, ".wav")
        save_audio(audio_path, audio, audio_rate)
        inputs += [ffmpeg.input(audio_path)['a']]

    try:
        if ffmpeg_experimental:
            out = ffmpeg.output(*inputs, path, strict='-2', loglevel="panic", vcodec='h264').overwrite_output()
        else:
            out = ffmpeg.output(*inputs, path, loglevel="panic", vcodec='h264').overwrite_output()
        out.run(quiet=True)
    except Exception:
        success = False

    if audio is not None and os.path.isfile(audio_path):
        os.remove(audio_path)
    if os.path.isfile(video_path):
        os.remove(video_path)

    return success

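# Rough usage sketch (hypothetical paths and tensors): `video` may be uint8 in
# [0, 255], float in [0, 1], or float in [-1, 1], with shape (T, 3, H, W) for
# colour or (T, H, W) for grayscale; if `audio` is given it is muxed in via ffmpeg:
#
#   ok = save_video("/tmp/result.mp4", video_tensor, fps=25, scale=1,
#                   audio=audio_tensor, audio_rate=16000)
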
def load_image_to_torch(dir):
    """
    Load an image from disk and convert it to a PyTorch tensor.

    Args:
        dir (str): The path to the image file.

    Returns:
        torch.Tensor: A tensor representation of the image in (C, H, W) order.
    """
    img = Image.open(dir).convert('RGB')
    img = np.array(img)
    return torch.from_numpy(img).permute(2, 0, 1)

def get_temp_path(tmp_dir, mode="", ext=""): | |
""" | |
Generate a temporary file path for storing data. | |
Args: | |
tmp_dir (str): The directory where the temporary file will be created. | |
mode (str, optional): A string to append to the file name. | |
ext (str, optional): The file extension. | |
Returns: | |
str: The full path to the temporary file. | |
""" | |
file_path = next(tempfile._get_candidate_names()) + mode + ext | |
if not os.path.exists(tmp_dir): | |
os.makedirs(tmp_dir) | |
file_path = os.path.join(tmp_dir, file_path) | |
return file_path | |
def swp_extension(file, ext):
    """
    Swap the extension of a given file name.

    Args:
        file (str): The original file name.
        ext (str): The new extension.

    Returns:
        str: The file name with the new extension.
    """
    return os.path.splitext(file)[0] + ext

def pad_both_ends(tensor, left, right, dim=0):
    """
    Pad a tensor on both ends along a specific dimension.

    Args:
        tensor (torch.Tensor): The tensor to be padded.
        left (int): The padding size for the left side.
        right (int): The padding size for the right side.
        dim (int, optional): The dimension along which to pad.

    Returns:
        torch.Tensor: The padded tensor.
    """
    no_dims = len(tensor.size())
    if dim == -1:
        dim = no_dims - 1

    padding = [0] * 2 * no_dims
    padding[2 * (no_dims - dim - 1)] = left
    padding[2 * (no_dims - dim - 1) + 1] = right
    return F.pad(tensor, padding, "constant", 0)

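# Example: padding a 1-D tensor with two zeros on the left and three on the right:
#
#   pad_both_ends(torch.arange(5.), 2, 3)
#   # -> tensor([0., 0., 0., 1., 2., 3., 4., 0., 0., 0.])
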
def cut_n_stack(seq, snip_length, cut_dim=0, cutting_stride=None, pad_samples=0):
    """
    Divide a sequence tensor into smaller snips and stack them.

    Args:
        seq (torch.Tensor): The original sequence tensor.
        snip_length (int): The length of each snip.
        cut_dim (int, optional): The dimension along which to cut.
        cutting_stride (int, optional): The stride length for cutting. Defaults to snip_length.
        pad_samples (int, optional): Number of samples to pad at both ends.

    Returns:
        torch.Tensor: A tensor containing the stacked snips.
    """
    if cutting_stride is None:
        cutting_stride = snip_length

    pad_left = pad_samples // 2
    pad_right = pad_samples - pad_samples // 2
    seq = pad_both_ends(seq, pad_left, pad_right, dim=cut_dim)

    stacked = seq.narrow(cut_dim, 0, snip_length).unsqueeze(0)
    iterations = (seq.size()[cut_dim] - snip_length) // cutting_stride + 1
    for i in range(1, iterations):
        stacked = torch.cat((stacked, seq.narrow(cut_dim, i * cutting_stride, snip_length).unsqueeze(0)))
    return stacked

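# Example: with the 3200-sample windows and 640-sample stride used by
# AudioEncoder, one second of 16 kHz audio (padded by 3200 - 640 samples)
# yields 25 overlapping windows, i.e. one window per frame at 25 fps:
#
#   cut_n_stack(torch.randn(1, 16000), 3200, cut_dim=1,
#               cutting_stride=640, pad_samples=3200 - 640).shape
#   # -> torch.Size([25, 1, 3200])
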
def create_windowed_sequence(seqs, snip_length, cut_dim=0, cutting_stride=None, pad_samples=0):
    """
    Create a windowed sequence from a batch of sequences.

    Args:
        seqs (torch.Tensor or list of torch.Tensor): Batch of sequence tensors.
        snip_length (int): The length of each snip.
        cut_dim (int, optional): The dimension along which to cut.
        cutting_stride (int, optional): The stride length for cutting. Defaults to snip_length.
        pad_samples (int, optional): Number of samples to pad at both ends.

    Returns:
        torch.Tensor: A tensor containing the windowed sequences.
    """
    windowed_seqs = []
    for seq in seqs:
        windowed_seqs.append(cut_n_stack(seq, snip_length, cut_dim, cutting_stride, pad_samples).unsqueeze(0))
    return torch.cat(windowed_seqs)

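if __name__ == "__main__":
    # Minimal smoke test using synthetic data only; this is a usage sketch and
    # does not exercise the checkpoint-dependent AudioEncoder path.
    dummy_audio = torch.randn(1, 1, 16000)  # one second of fake 16 kHz audio
    windows = create_windowed_sequence(dummy_audio, 3200, cutting_stride=640,
                                       pad_samples=3200 - 640, cut_dim=1)
    print("windowed audio shape:", tuple(windows.shape))  # (1, 25, 1, 3200)

    # Write a short test tone to a temporary WAV file and clean it up.
    tone = 0.5 * torch.sin(torch.linspace(0, 2 * np.pi * 440, 16000))
    wav_path = get_temp_path(tempfile.gettempdir(), ext=".wav")
    save_audio(wav_path, tone)
    print("wrote test tone to", wav_path)
    os.remove(wav_path)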