Gijs Wijngaard commited on
Commit
47bcf45
Β·
1 Parent(s): f43e4f8
Files changed (4) hide show
  1. .gitignore +1 -0
  2. README.md +5 -4
  3. app.py +101 -0
  4. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ env
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  title: Audio Captioning Small
3
- emoji: πŸ”₯
4
- colorFrom: pink
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Audio Captioning Small
3
+ emoji: πŸ”Š
4
+ colorFrom: blue
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import gradio as gr
3
+ import spaces
4
+ import torch
5
+ from torchaudio.functional import resample
6
+ from transformers import AutoModel, PreTrainedTokenizerFast
7
+
8
+
9
+ def load_model(model_name,
10
+ device):
11
+ if model_name == "AudioCaps":
12
+ model = AutoModel.from_pretrained(
13
+ "wsntxxn/effb2-trm-audiocaps-captioning",
14
+ trust_remote_code=True
15
+ ).to(device)
16
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(
17
+ "wsntxxn/audiocaps-simple-tokenizer"
18
+ )
19
+ elif model_name == "Clotho":
20
+ model = AutoModel.from_pretrained(
21
+ "wsntxxn/effb2-trm-clotho-captioning",
22
+ trust_remote_code=True
23
+ ).to(device)
24
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(
25
+ "wsntxxn/clotho-simple-tokenizer"
26
+ )
27
+ return model, tokenizer
28
+
29
+ @spaces.GPU
30
+ def infer(file, runner):
31
+ sr, wav = file
32
+ wav = torch.as_tensor(wav)
33
+ if wav.dtype == torch.short:
34
+ wav = wav / 2 ** 15
35
+ elif wav.dtype == torch.int:
36
+ wav = wav / 2 ** 31
37
+ if wav.ndim > 1:
38
+ wav = wav.mean(1)
39
+ wav = resample(wav, sr, runner.target_sr)
40
+ wav_len = len(wav)
41
+ wav = wav.float().unsqueeze(0)
42
+ with torch.no_grad():
43
+ word_idx = runner.model(
44
+ audio=wav,
45
+ audio_length=[wav_len]
46
+ )[0]
47
+ cap = runner.tokenizer.decode(word_idx, skip_special_tokens=True)
48
+ return cap
49
+
50
+ # def input_toggle(input_type):
51
+ # if input_type == "file":
52
+ # return gr.update(visible=True), gr.update(visible=False)
53
+ # elif input_type == "mic":
54
+ # return gr.update(visible=False), gr.update(visible=True)
55
+
56
+ class InferRunner:
57
+ def __init__(self, model_name):
58
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
+ self.model, self.tokenizer = load_model(model_name, self.device)
60
+ self.target_sr = self.model.config.sample_rate
61
+
62
+ def change_model(self, model_name):
63
+ self.model, self.tokenizer = load_model(model_name, self.device)
64
+ self.target_sr = self.model.config.sample_rate
65
+
66
+
67
+ def change_model(radio):
68
+ global infer_runner
69
+ infer_runner.change_model(radio)
70
+
71
+
72
+ with gr.Blocks() as demo:
73
+ with gr.Row():
74
+ gr.Markdown("# Lightweight Audio Captioning")
75
+
76
+ with gr.Row():
77
+ gr.Markdown("""
78
+ Audio Captioning Demo
79
+ """)
80
+ with gr.Row():
81
+ with gr.Column():
82
+ radio = gr.Radio(
83
+ ["AudioCaps", "Clotho"],
84
+ value="AudioCaps",
85
+ label="Select model"
86
+ )
87
+ infer_runner = InferRunner(radio.value)
88
+ file = gr.Audio(label="Input", visible=True)
89
+ radio.change(fn=change_model, inputs=[radio,],)
90
+ btn = gr.Button("Run")
91
+ with gr.Column():
92
+ output = gr.Textbox(label="Output")
93
+ btn.click(
94
+ fn=partial(infer,
95
+ runner=infer_runner),
96
+ inputs=[file,],
97
+ outputs=output
98
+ )
99
+
100
+ demo.launch()
101
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers
2
+ efficientnet_pytorch
3
+ torchaudio
4
+ einops