# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import os
import yaml
import gradio as gr
import librosa
from pydub import AudioSegment
import soundfile as sf
import numpy as np
import torch
import laion_clap

from inference_utils import prepare_tokenizer, prepare_model, inference
from data import AudioTextDataProcessor
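

# LAION-CLAP is loaded alongside the chat model and used to rerank the language
# model's candidate responses by audio-text similarity (see inference_item below).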
def load_laionclap():
    model = laion_clap.CLAP_Module(enable_fusion=True, amodel='HTSAT-tiny').cuda()
    model.load_ckpt(ckpt='630k-audioset-fusion-best.pt')
    model.eval()
    return model


def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)
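

# Decode an audio file into a mono float32 waveform in [-1, 1]: MP3s are decoded
# with pydub, other formats are read with soundfile and resampled with librosa.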
def load_audio(file_path, target_sr=44100, duration=33.25, start=0.0):
    if file_path.endswith('.mp3'):
        # Decode with pydub: crop to the requested window, resample, downmix to mono.
        audio = AudioSegment.from_file(file_path)
        if len(audio) > (start + duration) * 1000:
            audio = audio[start * 1000:(start + duration) * 1000]
        if audio.frame_rate != target_sr:
            audio = audio.set_frame_rate(target_sr)
        if audio.channels > 1:
            audio = audio.set_channels(1)
        data = np.array(audio.get_array_of_samples())
        # Scale integer samples to float according to the sample width.
        if audio.sample_width == 2:
            data = data.astype(np.float32) / np.iinfo(np.int16).max
        elif audio.sample_width == 4:
            data = data.astype(np.float32) / np.iinfo(np.int32).max
        else:
            raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
    else:
        # Read the requested window with soundfile, then resample/downmix as needed.
        with sf.SoundFile(file_path) as audio:
            original_sr = audio.samplerate
            channels = audio.channels
            max_frames = int((start + duration) * original_sr)
            audio.seek(int(start * original_sr))
            frames_to_read = min(max_frames, len(audio))
            data = audio.read(frames_to_read)
            if data.max() > 1 or data.min() < -1:
                data = data / max(abs(data.max()), abs(data.min()))
        if original_sr != target_sr:
            if channels == 1:
                data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
            else:
                data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
        else:
            if channels != 1:
                data = data.T[0]
    # Normalize the waveform to [-1, 1].
    if data.min() >= 0:
        data = 2 * data / abs(data.max()) - 1.0
    else:
        data = data / max(abs(data.max()), abs(data.min()))
    return data
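

# Score each candidate text output against the input audio via LAION-CLAP cosine
# similarity; if the audio cannot be loaded, return zeros so generation still succeeds.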
def compute_laionclap_text_audio_sim(audio_file, laionclap_model, outputs):
    try:
        data = load_audio(audio_file, target_sr=48000)
    except Exception as e:
        print(audio_file, 'unsuccessful due to', e)
        return [0.0] * len(outputs)

    audio_data = data.reshape(1, -1)
    # Round-trip through int16, as in LAION-CLAP's example preprocessing.
    audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().cuda()

    audio_embed = laionclap_model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True)
    text_embed = laionclap_model.get_text_embedding(outputs, use_tensor=True)

    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    cos_similarity = cos(audio_embed.repeat(text_embed.shape[0], 1), text_embed)
    return cos_similarity.squeeze().cpu().numpy()
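

# Sample 10 candidate responses per query; the highest CLAP-scored one is returned to the user.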
inference_kwargs = {
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.95,
    "num_return_sequences": 10
}

with open('chat.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
clap_config = config['clap_config']
model_config = config['model_config']

text_tokenizer = prepare_tokenizer(model_config)
DataProcessor = AudioTextDataProcessor(
    data_root='./',
    clap_config=clap_config,
    tokenizer=text_tokenizer,
    max_tokens=512,
)

laionclap_model = load_laionclap()

model = prepare_model(
    model_config=model_config,
    clap_config=clap_config,
    checkpoint_path='chat.pt'
)
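

# Run generation for an (audio file, prompt) pair and return the candidate response
# with the highest LAION-CLAP audio-text similarity.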
def inference_item(name, prompt):
    item = {
        'name': str(name),
        'prefix': 'The task is dialog.',
        'prompt': str(prompt)
    }
    processed_item = DataProcessor.process(item)

    outputs = inference(
        model, text_tokenizer, item, processed_item,
        inference_kwargs,
    )

    laionclap_scores = compute_laionclap_text_audio_sim(
        item["name"],
        laionclap_model,
        outputs
    )
    outputs_joint = [(output, score) for (output, score) in zip(outputs, laionclap_scores)]
    outputs_joint.sort(key=lambda x: -x[1])
    return outputs_joint[0][0]
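

# Gradio front end: an audio file path and an instruction in, audio playback and
# Audio Flamingo's response out.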
with gr.Blocks(title="Audio Flamingo - Demo") as ui:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 900px; margin: 0 auto;">
            <div
                style="
                    display: inline-flex;
                    align-items: center;
                    gap: 0.8rem;
                    font-size: 1.5rem;
                "
            >
                <h1 style="font-weight: 700; margin-bottom: 7px; line-height: normal;">
                    Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities
                </h1>
            </div>
            <p style="margin-bottom: 10px; font-size: 125%">
                <a href="https://arxiv.org/abs/2402.01831">[Paper]</a> <a href="https://github.com/NVIDIA/audio-flamingo">[Code]</a> <a href="https://audioflamingo.github.io/">[Demo]</a>
            </p>
        </div>
        """
    )

    gr.HTML(
        """
        <div>
            <h3>Model Overview</h3>
            Audio Flamingo is an audio language model that understands sounds beyond speech
            and can answer questions about them in natural language.
            Example questions include:
            "Can you briefly describe what you hear in this audio?",
            "What is the emotion conveyed in this music?",
            "Where is this audio usually heard?",
            or "What place is this music usually played at?".
        </div>
        """
    )

    name = gr.Textbox(
        label="Audio file path (choose one from: audio/wav{1--6}.wav)",
        value="audio/wav5.wav"
    )
    prompt = gr.Textbox(
        label="Instruction",
        value='Can you briefly describe what you hear in this audio?'
    )

    with gr.Row():
        play_audio_button = gr.Button("Play Audio")
        audio_output = gr.Audio(label="Playback")
        play_audio_button.click(fn=lambda x: x, inputs=name, outputs=audio_output)

    inference_button = gr.Button("Inference")
    output_text = gr.Textbox(label="Audio Flamingo output")
    inference_button.click(
        fn=inference_item,
        inputs=[name, prompt],
        outputs=output_text
    )

ui.queue()
ui.launch()