# Wav2lip / audio_processor.py
# Uploaded by Eraldo123897 ("Upload 46 files", commit a593b7a, verified).
# pylint: disable=C0301
'''
This module contains the AudioProcessor class and related functions for processing audio data.
It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
and audio separation. The class is initialized with configuration parameters and can process
audio files using the provided models.
'''
import math
import os
import librosa
import numpy as np
import torch
from audio_separator.separator import Separator
from einops import rearrange
from transformers import Wav2Vec2FeatureExtractor
from hallo.models.wav2vec import Wav2VecModel
from hallo.utils.util import resample_audio
class AudioProcessor:
    """
    AudioProcessor handles the processing of audio files.

    It optionally separates the vocal stem with ``audio_separator``, loads the
    audio at ``sample_rate``, runs the wav2vec2 feature extractor and encodes
    the result with a wav2vec model.

    :param sample_rate: Sampling rate the audio is loaded/resampled at
    :param fps: Frames per second for the extracted features
    :param wav2vec_model_path: Path to the wav2vec model (loaded with local_files_only=True)
    :param only_last_features: If True return only the last hidden state,
        otherwise stack all hidden states after the first one
    :param audio_separator_model_path: Directory containing the audio separator model
    :param audio_separator_model_name: Name of the audio separator model; when None,
        no vocal separation is performed
    :param cache_dir: Directory where separated / resampled audio files are written
    :param device: Torch device to run the wav2vec encoder on
    """

    def __init__(
        self,
        sample_rate,
        fps,
        wav2vec_model_path,
        only_last_features,
        audio_separator_model_path: str = None,
        audio_separator_model_name: str = None,
        cache_dir: str = '',
        device="cuda:0",
    ) -> None:
        self.sample_rate = sample_rate
        self.fps = fps
        self.device = device

        self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device)
        # Freeze the encoder's convolutional feature extractor.
        self.audio_encoder.feature_extractor._freeze_parameters()
        self.only_last_features = only_last_features

        if audio_separator_model_name is not None:
            # An empty cache_dir means "current directory": nothing to create,
            # and os.makedirs('') would always fail.
            if cache_dir:
                try:
                    os.makedirs(cache_dir, exist_ok=True)
                except OSError:
                    # Best effort: keep going, the separator reports its own errors.
                    print("Fail to create the output cache dir.")
            self.audio_separator = Separator(
                output_dir=cache_dir,
                output_single_stem="vocals",
                model_file_dir=audio_separator_model_path,
            )
            self.audio_separator.load_model(audio_separator_model_name)
            # Explicit raise instead of assert so the check survives `python -O`.
            if self.audio_separator.model_instance is None:
                raise RuntimeError("Fail to load audio separate model.")
        else:
            self.audio_separator = None
            print("Use audio directly without vocals seperator.")

        self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)

    def _samples_to_frames(self, num_samples):
        """Return how many frames at ``self.fps`` cover ``num_samples`` samples at ``self.sample_rate`` (rounded up)."""
        return math.ceil(num_samples / self.sample_rate * self.fps)

    def _extract_input_values(self, speech_array, sampling_rate):
        """
        Run the wav2vec2 feature extractor on one waveform.

        Returns:
            np.ndarray: flat 1-D array of extractor input values.

        Raises:
            RuntimeError: if the waveform is empty.
        """
        input_values = self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values
        # reshape(-1) instead of np.squeeze: squeeze turns a 1-sample clip into a
        # 0-d array on which len() fails.
        audio_feature = np.asarray(input_values).reshape(-1)
        if audio_feature.size == 0:
            raise RuntimeError("Audio is empty; cannot extract wav2vec features.")
        return audio_feature

    def _encode(self, audio_feature, seq_len):
        """
        Encode a 1-D feature tensor with the wav2vec model.

        ``seq_len`` is forwarded to ``self.audio_encoder`` (presumably the number of
        output frames - confirm in hallo.models.wav2vec).

        Returns:
            torch.Tensor on CPU: last hidden state when ``only_last_features`` is set,
            otherwise hidden states 1..N stacked and rearranged to (seq, layers, dim).

        Raises:
            RuntimeError: if the encoder returns no output.
        """
        with torch.no_grad():
            embeddings = self.audio_encoder(audio_feature.unsqueeze(0), seq_len=seq_len, output_hidden_states=True)
        if embeddings is None or len(embeddings) == 0:
            raise RuntimeError("Fail to extract audio embedding")

        if self.only_last_features:
            # Drop only the batch axis (assumed dim 0, as in the branch below); a bare
            # squeeze() would also drop the sequence axis when seq_len == 1.
            audio_emb = embeddings.last_hidden_state.squeeze(0)
        else:
            audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
            audio_emb = rearrange(audio_emb, "b s d -> s b d")
        return audio_emb.cpu().detach()

    def preprocess(self, wav_file: str, clip_length: int):
        """
        Turn an audio file into a wav2vec embedding padded to whole clips.

        When an audio separator is configured, the vocal stem is separated first
        and resampled to ``self.sample_rate``; otherwise ``wav_file`` is used as is.
        The feature sequence is zero-padded so its frame count is a multiple of
        ``clip_length``.

        Args:
            wav_file (str): Path to the audio file to process.
            clip_length (int): Number of frames per clip; must be positive.

        Raises:
            ValueError: if ``clip_length`` is not positive.
            RuntimeError: if separation fails, the audio is empty, or no embedding
                is produced.

        Returns:
            tuple[torch.Tensor, int]: the audio embedding (on CPU) and the frame
            count of the audio *before* padding.
        """
        if clip_length <= 0:
            raise ValueError(f"clip_length must be positive, got {clip_length}")

        if self.audio_separator is not None:
            # 1. separate vocals
            # TODO: process in memory
            outputs = self.audio_separator.separate(wav_file)
            if not outputs:
                raise RuntimeError("Audio separate failed.")
            vocal_audio_file = outputs[0]
            vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
            vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
            vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate)
        else:
            vocal_audio_file = wav_file

        # 2. extract wav2vec features
        speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate)
        audio_feature = self._extract_input_values(speech_array, sampling_rate)
        seq_len = self._samples_to_frames(len(audio_feature))
        audio_length = seq_len  # unpadded frame count, returned to the caller
        audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)

        # 3. pad to a whole number of clips
        remainder = seq_len % clip_length
        if remainder != 0:
            pad_frames = clip_length - remainder
            # NOTE: integer samples-per-frame; if sample_rate is not divisible by fps
            # the sample padding is slightly shorter than pad_frames frames.
            audio_feature = torch.nn.functional.pad(audio_feature, (0, pad_frames * (self.sample_rate // self.fps)), 'constant', 0.0)
            seq_len += pad_frames

        audio_emb = self._encode(audio_feature, seq_len)
        return audio_emb, audio_length

    def get_embedding(self, wav_file: str):
        """
        Load an audio file and convert it to a wav2vec embedding (no separation, no padding).

        Args:
            wav_file (str): Path to the audio file to process.

        Raises:
            ValueError: if the audio is not loaded at 16000 Hz.
            RuntimeError: if the audio is empty or no embedding is produced.

        Returns:
            torch.Tensor: the audio embedding, on CPU.
        """
        speech_array, sampling_rate = librosa.load(wav_file, sr=self.sample_rate)
        # Explicit raise instead of assert so the check survives `python -O`.
        if sampling_rate != 16000:
            raise ValueError("The audio sample rate must be 16000")
        audio_feature = self._extract_input_values(speech_array, sampling_rate)
        seq_len = self._samples_to_frames(len(audio_feature))
        audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
        return self._encode(audio_feature, seq_len)

    def close(self):
        """
        Release resources held by the processor.

        TODO: to be implemented - currently a no-op that returns ``self``.
        """
        return self

    def __enter__(self):
        return self

    def __exit__(self, _exc_type, _exc_val, _exc_tb):
        self.close()