Gijs Wijngaard
Fix
ab3f8fd
raw
history blame
2.2 kB
"""
Audio Captioning Model
This script implements an audio captioning model based on the Effb2-Trm architecture.
It uses a pre-trained model to generate captions for audio inputs.
The original implementation is based on:
https://github.com/wsntxxn/Effb2-Trm-AudioCaps-Captioning/
"""
from functools import partial
import gradio as gr
import spaces
import torch
from torchaudio.functional import resample
from transformers import AutoModel, PreTrainedTokenizerFast
from hf_wrapper import Effb2TrmConfig, Effb2TrmCaptioningModel
# Load the configuration
config = Effb2TrmConfig.from_pretrained("config.json")
# Build the model directly from the config (weights are loaded below).
model = Effb2TrmCaptioningModel(config)
# Load the state dict from the local pytorch_model.bin file.
# NOTE(review): torch.load unpickles the file. It is a local file shipped
# with this app; if it could ever come from an untrusted source, consider
# passing weights_only=True (torch >= 1.13) to restrict loading to tensors.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model to the appropriate device.
model = model.to(device)
# A module constructed directly (not via from_pretrained) starts in training
# mode; switch to eval mode so any dropout layers are disabled and captions
# are deterministic at inference time.
model.eval()
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "wsntxxn/audiocaps-simple-tokenizer"
)
# Sample rate the model was configured with; infer() resamples inputs to it.
target_sr = model.config.sample_rate
@spaces.GPU
def infer(input_audio):
    """Generate a text caption for a single audio clip.

    Args:
        input_audio: Value produced by ``gr.Audio`` in numpy mode: a
            ``(sample_rate, samples)`` tuple, or ``None`` when the user
            submitted without providing audio. ``samples`` may be 1-D (mono)
            or 2-D, in which case it is averaged over axis 1.

    Returns:
        The decoded caption string (special tokens stripped).

    Raises:
        gr.Error: If no audio was provided.
    """
    if input_audio is None:
        # Unpacking None below would fail with an opaque TypeError;
        # gr.Error shows a readable message in the UI instead.
        raise gr.Error("Please provide an audio input.")
    sr, wav = input_audio
    wav = torch.as_tensor(wav)
    # Scale integer PCM samples to floats in [-1, 1).
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # Down-mix multi-channel input to mono.
    # Assumes a (samples, channels) layout, as produced by gr.Audio.
    if wav.ndim > 1:
        wav = wav.mean(1)
    wav = resample(wav, sr, target_sr)
    # Length in samples after resampling, before adding the batch dimension.
    wav_len = len(wav)
    # Add a batch dimension and place the input on the model's device, so
    # the forward pass does not depend on the wrapper moving CPU input.
    wav = wav.float().unsqueeze(0).to(device)
    with torch.no_grad():
        word_idx = model(
            audio=wav,
            audio_length=[wav_len]
        )[0]
    cap = tokenizer.decode(word_idx, skip_special_tokens=True)
    return cap
# Minimal Gradio UI: one audio input, a Run button, and a text output.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight Audio Captioning")
    with gr.Row():
        gr.Markdown("""
Audio Captioning Demo
""")
    with gr.Row():
        with gr.Column():
            # type="numpy" (the gr.Audio default) yields the
            # (sample_rate, samples) tuple that infer() unpacks.
            file = gr.Audio(label="Input", type="numpy", visible=True)
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
    # infer takes exactly one argument, so it is passed directly; wrapping it
    # in functools.partial with no bound arguments added nothing.
    btn.click(
        fn=infer,
        inputs=[file],
        outputs=output
    )
demo.launch()