import torch
import torchaudio
from torchaudio.pipelines import SQUIM_OBJECTIVE

import evaluate
import numpy as np
from accelerate.utils.memory import release_memory
from transformers import (
    AutoModel,
    AutoProcessor,
    pipeline,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperTokenizerFast,
)


def clap_similarity(clap_model_name_or_path, texts, audios, device, input_sampling_rate=44100):
    """Mean cosine similarity between CLAP text and audio embeddings."""
    clap = AutoModel.from_pretrained(clap_model_name_or_path)
    clap_processor = AutoProcessor.from_pretrained(clap_model_name_or_path)
    output_sampling_rate = clap_processor.feature_extractor.sampling_rate

    # Resample to the rate the CLAP feature extractor expects, if needed.
    if input_sampling_rate != output_sampling_rate:
        audios = [
            torchaudio.functional.resample(torch.from_numpy(audio), input_sampling_rate, output_sampling_rate).numpy()
            for audio in audios
        ]
    clap_inputs = clap_processor(
        text=texts, audios=audios, padding=True, return_tensors="pt", sampling_rate=output_sampling_rate
    ).to(device)

    clap.to(device)
    with torch.no_grad():
        text_features = clap.get_text_features(
            clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None)
        )
        audio_features = clap.get_audio_features(clap_inputs["input_features"])

        cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8).mean()

    cosine_sim = cosine_sim.to("cpu")

    # Move the model back to CPU and free GPU memory before the next metric runs.
    clap.to("cpu")
    clap, clap_inputs, audio_features, text_features = release_memory(clap, clap_inputs, audio_features, text_features)
    return cosine_sim
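
# A minimal usage sketch (the "laion/clap-htsat-unfused" checkpoint and the
# 1-D float32 numpy waveforms are illustrative assumptions, not fixed here):
#
#   texts = ["a woman reads calmly", "a man shouts over distant traffic"]
#   audios = [np.random.randn(44100).astype(np.float32) for _ in texts]
#   score = clap_similarity("laion/clap-htsat-unfused", texts, audios, device="cpu")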


def si_sdr(audios, device, input_sampling_rate=44100):
    """Reference-free SI-SDR estimates via TorchAudio's SQUIM objective model."""
    max_audio_length = 15 * SQUIM_OBJECTIVE.sample_rate
    model = SQUIM_OBJECTIVE.get_model().to(device)

    output_sampling_rate = SQUIM_OBJECTIVE.sample_rate
    # SQUIM expects (1, num_samples) float tensors, so convert unconditionally
    # and resample only when the sampling rates differ.
    audios = [torch.tensor(audio)[None, :].to(device).float() for audio in audios]
    if input_sampling_rate != output_sampling_rate:
        audios = [
            torchaudio.functional.resample(audio, input_sampling_rate, output_sampling_rate) for audio in audios
        ]

    def apply_squim(waveform):
        with torch.no_grad():
            # Cap inputs at 15 seconds to bound memory usage; the model returns
            # (STOI, PESQ, SI-SDR) and only SI-SDR is kept.
            waveform = waveform[:, : min(max_audio_length, waveform.shape[1])]
            _, _, sdr_sample = model(waveform)
            sdr_sample = sdr_sample.cpu()[0]
        return sdr_sample

    si_sdrs = [apply_squim(audio) for audio in audios]
    audios, model = release_memory(audios, model)
    return si_sdrs
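
# A minimal usage sketch (the white-noise input is illustrative; SQUIM runs at
# SQUIM_OBJECTIVE.sample_rate, i.e. 16 kHz, so 44.1 kHz input gets resampled):
#
#   audios = [np.random.randn(44100).astype(np.float32)]
#   scores = si_sdr(audios, device="cpu", input_sampling_rate=44100)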


def wer(
    asr_model_name_or_path,
    prompts,
    audios,
    device,
    per_device_eval_batch_size,
    sampling_rate,
    noise_level_to_compute_clean_wer,
    si_sdr_measures,
):
    """Transcribe the audios with an ASR pipeline and compute WER against the prompts.

    When an SI-SDR threshold and per-sample SI-SDR estimates are given, the WER is
    additionally split into "clean" and "noisy" subsets.
    """
    metric = evaluate.load("wer")
    asr_pipeline = pipeline(model=asr_model_name_or_path, device=device, chunk_length_s=25.0)

    # Whisper can also return the detected language, which selects the normalizer below.
    return_language = None
    if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
        return_language = True

    transcriptions = asr_pipeline(
        [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
        batch_size=int(per_device_eval_batch_size),
        return_language=return_language,
    )

    if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
        tokenizer = asr_pipeline.tokenizer
    else:
        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

    english_normalizer = tokenizer.normalize
    basic_normalizer = tokenizer.basic_normalize

    normalized_predictions = []
    normalized_references = []
    for pred, ref in zip(transcriptions, prompts):
        # Use the English normalizer only when the detected language is English.
        normalizer = (
            english_normalizer
            if isinstance(pred.get("chunks", None), list) and pred["chunks"][0].get("language", None) == "english"
            else basic_normalizer
        )
        norm_ref = normalizer(ref)
        if len(norm_ref) > 0:
            norm_pred = normalizer(pred["text"])
            normalized_predictions.append(norm_pred)
            normalized_references.append(norm_ref)

    word_error = 100
    clean_word_error = None
    noisy_word_error = None
    percent_clean_samples = 0
    if len(normalized_references) > 0:
        word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)

        if noise_level_to_compute_clean_wer and si_sdr_measures:
            si_sdr_measures = np.array(si_sdr_measures)
            mask = si_sdr_measures >= noise_level_to_compute_clean_wer
            if mask.any():
                clean_word_error = 100 * metric.compute(
                    predictions=np.array(normalized_predictions)[mask],
                    references=np.array(normalized_references)[mask],
                )
                if not mask.all():
                    noisy_word_error = 100 * metric.compute(
                        predictions=np.array(normalized_predictions)[~mask],
                        references=np.array(normalized_references)[~mask],
                    )
                else:
                    noisy_word_error = 0
                percent_clean_samples = mask.sum() / len(mask)

    asr_pipeline.model.to("cpu")
    asr_pipeline = release_memory(asr_pipeline)
    return word_error, [t["text"] for t in transcriptions], clean_word_error, noisy_word_error, percent_clean_samples
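

if __name__ == "__main__":
    # A minimal smoke test tying the metrics together on synthetic audio. The ASR
    # checkpoint and the 25 dB SI-SDR threshold are illustrative assumptions, not
    # values fixed by this module; real evaluations use generated speech.
    audios = [np.random.randn(44100).astype(np.float32) for _ in range(2)]
    prompts = ["hello world", "good morning"]

    sdrs = si_sdr(audios, device="cpu", input_sampling_rate=44100)
    word_error, transcripts, clean_wer, noisy_wer, pct_clean = wer(
        "distil-whisper/distil-large-v2",
        prompts,
        audios,
        device="cpu",
        per_device_eval_batch_size=2,
        sampling_rate=44100,
        noise_level_to_compute_clean_wer=25,  # assumed SI-SDR threshold in dB
        si_sdr_measures=sdrs,
    )
    print(f"WER: {word_error:.2f} | clean WER: {clean_wer} | noisy WER: {noisy_wer}")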