""" Audio Captioning Model This script implements an audio captioning model based on the Effb2-Trm architecture. It uses a pre-trained model to generate captions for audio inputs. The original implementation is based on: https://github.com/wsntxxn/Effb2-Trm-AudioCaps-Captioning/ """ from functools import partial import gradio as gr import spaces import torch from torchaudio.functional import resample from transformers import AutoModel, PreTrainedTokenizerFast from hf_wrapper import Effb2TrmConfig, Effb2TrmCaptioningModel # Load the configuration config = Effb2TrmConfig.from_pretrained("config.json") # Load the model model = Effb2TrmCaptioningModel(config) # Load the state dict from the local pytorch_model.bin file state_dict = torch.load("pytorch_model.bin", map_location="cpu") model.load_state_dict(state_dict) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Move the model to the appropriate device model = model.to(device) tokenizer = PreTrainedTokenizerFast.from_pretrained( "wsntxxn/audiocaps-simple-tokenizer" ) target_sr = model.config.sample_rate @spaces.GPU def infer(input_audio): sr, wav = input_audio wav = torch.as_tensor(wav) if wav.dtype == torch.short: wav = wav / 2 ** 15 elif wav.dtype == torch.int: wav = wav / 2 ** 31 if wav.ndim > 1: wav = wav.mean(1) wav = resample(wav, sr, target_sr) wav_len = len(wav) wav = wav.float().unsqueeze(0) with torch.no_grad(): word_idx = model( audio=wav, audio_length=[wav_len] )[0] cap = tokenizer.decode(word_idx, skip_special_tokens=True) return cap with gr.Blocks() as demo: with gr.Row(): gr.Markdown("# Lightweight Audio Captioning") with gr.Row(): gr.Markdown(""" Audio Captioning Demo """) with gr.Row(): with gr.Column(): file = gr.Audio(label="Input", visible=True) btn = gr.Button("Run") with gr.Column(): output = gr.Textbox(label="Output") btn.click( fn=partial(infer), inputs=[file,], outputs=output ) demo.launch()