""" | |
Audio Captioning Model | |
This script implements an audio captioning model based on the Effb2-Trm architecture. | |
It uses a pre-trained model to generate captions for audio inputs. | |
The original implementation is based on: | |
https://github.com/wsntxxn/Effb2-Trm-AudioCaps-Captioning/ | |
""" | |
from functools import partial

import gradio as gr
import spaces  # Hugging Face Spaces SDK; kept from the original Space
import torch
from torchaudio.functional import resample
from transformers import PreTrainedTokenizerFast

from hf_wrapper import Effb2TrmConfig, Effb2TrmCaptioningModel
# Load the configuration
config = Effb2TrmConfig.from_pretrained("config.json")

# Load the model
model = Effb2TrmCaptioningModel(config)

# Load the state dict from the local pytorch_model.bin file
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model to the appropriate device and switch to evaluation mode for inference
model = model.to(device)
model.eval()

# Tokenizer used to decode predicted word indices back into a caption string
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "wsntxxn/audiocaps-simple-tokenizer"
)

# Sample rate the model expects; inputs are resampled to this rate before inference
target_sr = model.config.sample_rate
def infer(input_audio):
    """Generate a caption for a Gradio audio input given as a (sample_rate, waveform) tuple."""
    sr, wav = input_audio
    wav = torch.as_tensor(wav)
    # Normalize integer PCM samples to the [-1, 1] float range
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # Downmix multi-channel audio to mono
    if wav.ndim > 1:
        wav = wav.mean(1)
    # Resample to the sample rate the model was trained with
    wav = resample(wav, sr, target_sr)
    wav_len = len(wav)
    # Add a batch dimension and move the input to the same device as the model
    wav = wav.float().unsqueeze(0).to(device)
    with torch.no_grad():
        word_idx = model(
            audio=wav,
            audio_length=[wav_len]
        )[0]
    cap = tokenizer.decode(word_idx, skip_special_tokens=True)
    return cap
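
# Minimal usage sketch (assumption, not part of the original Space): infer()
# takes the (sample_rate, waveform) tuple that gr.Audio produces, so it can
# also be smoke-tested directly with a synthetic signal, e.g.:
#
#   import numpy as np
#   sr = 16000
#   wav = np.zeros(sr, dtype=np.int16)  # one second of silence
#   print(infer((sr, wav)))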
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight Audio Captioning")
    with gr.Row():
        gr.Markdown("Audio Captioning Demo")
    with gr.Row():
        with gr.Column():
            file = gr.Audio(label="Input", visible=True)
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
    btn.click(
        fn=partial(infer),
        inputs=[file],
        outputs=output
    )

demo.launch()