Gijs Wijngaard committed on
Commit
ab3f8fd
·
1 Parent(s): 47bcf45
Files changed (5)
  1. README.md +4 -0
  2. app.py +35 -56
  3. config.json +22 -0
  4. hf_wrapper.py +1964 -0
  5. pytorch_model.bin +3 -0
README.md CHANGED
@@ -1,5 +1,9 @@
  ---
+ <<<<<<< HEAD
  title: Audio Captioning Small
+ =======
+ title: Efficient Audio Captioning
+ >>>>>>> 901f564 (Test)
  emoji: 🔊
  colorFrom: blue
  colorTo: pink
app.py CHANGED
@@ -1,34 +1,44 @@
+ """
+ Audio Captioning Model
+ 
+ This script implements an audio captioning model based on the Effb2-Trm architecture.
+ It uses a pre-trained model to generate captions for audio inputs.
+ 
+ The original implementation is based on:
+ https://github.com/wsntxxn/Effb2-Trm-AudioCaps-Captioning/
+ 
+ """
+ 
  from functools import partial
  import gradio as gr
  import spaces
  import torch
  from torchaudio.functional import resample
  from transformers import AutoModel, PreTrainedTokenizerFast
+ from hf_wrapper import Effb2TrmConfig, Effb2TrmCaptioningModel
  
+ # Load the configuration
+ config = Effb2TrmConfig.from_pretrained("config.json")
+ 
+ # Load the model
+ model = Effb2TrmCaptioningModel(config)
+ 
+ # Load the state dict from the local pytorch_model.bin file
+ state_dict = torch.load("pytorch_model.bin", map_location="cpu")
+ model.load_state_dict(state_dict)
+ 
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ 
+ # Move the model to the appropriate device
+ model = model.to(device)
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(
+     "wsntxxn/audiocaps-simple-tokenizer"
+ )
+ target_sr = model.config.sample_rate
  
- def load_model(model_name,
-                device):
-     if model_name == "AudioCaps":
-         model = AutoModel.from_pretrained(
-             "wsntxxn/effb2-trm-audiocaps-captioning",
-             trust_remote_code=True
-         ).to(device)
-         tokenizer = PreTrainedTokenizerFast.from_pretrained(
-             "wsntxxn/audiocaps-simple-tokenizer"
-         )
-     elif model_name == "Clotho":
-         model = AutoModel.from_pretrained(
-             "wsntxxn/effb2-trm-clotho-captioning",
-             trust_remote_code=True
-         ).to(device)
-         tokenizer = PreTrainedTokenizerFast.from_pretrained(
-             "wsntxxn/clotho-simple-tokenizer"
-         )
-     return model, tokenizer
- 
  @spaces.GPU
- def infer(file, runner):
-     sr, wav = file
+ def infer(input_audio):
+     sr, wav = input_audio
      wav = torch.as_tensor(wav)
      if wav.dtype == torch.short:
          wav = wav / 2 ** 15
@@ -36,38 +46,17 @@ def infer(file, runner):
          wav = wav / 2 ** 31
      if wav.ndim > 1:
          wav = wav.mean(1)
-     wav = resample(wav, sr, runner.target_sr)
+     wav = resample(wav, sr, target_sr)
      wav_len = len(wav)
      wav = wav.float().unsqueeze(0)
      with torch.no_grad():
-         word_idx = runner.model(
+         word_idx = model(
              audio=wav,
              audio_length=[wav_len]
          )[0]
-     cap = runner.tokenizer.decode(word_idx, skip_special_tokens=True)
+     cap = tokenizer.decode(word_idx, skip_special_tokens=True)
      return cap
  
- # def input_toggle(input_type):
- #     if input_type == "file":
- #         return gr.update(visible=True), gr.update(visible=False)
- #     elif input_type == "mic":
- #         return gr.update(visible=False), gr.update(visible=True)
- 
- class InferRunner:
-     def __init__(self, model_name):
-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.model, self.tokenizer = load_model(model_name, self.device)
-         self.target_sr = self.model.config.sample_rate
- 
-     def change_model(self, model_name):
-         self.model, self.tokenizer = load_model(model_name, self.device)
-         self.target_sr = self.model.config.sample_rate
- 
- 
- def change_model(radio):
-     global infer_runner
-     infer_runner.change_model(radio)
- 
  
  with gr.Blocks() as demo:
      with gr.Row():
@@ -79,23 +68,13 @@ with gr.Blocks() as demo:
          """)
      with gr.Row():
          with gr.Column():
-             radio = gr.Radio(
-                 ["AudioCaps", "Clotho"],
-                 value="AudioCaps",
-                 label="Select model"
-             )
-             infer_runner = InferRunner(radio.value)
              file = gr.Audio(label="Input", visible=True)
-             radio.change(fn=change_model, inputs=[radio,],)
              btn = gr.Button("Run")
          with gr.Column():
              output = gr.Textbox(label="Output")
          btn.click(
-             fn=partial(infer,
-                        runner=infer_runner),
+             fn=partial(infer),
              inputs=[file,],
              outputs=output
          )
- 
  demo.launch()
- 
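Illustration only: the refactored app.py above exposes the model, tokenizer and target sample rate at module level, so the captioner can also be exercised outside the Gradio UI. A minimal sketch that repeats the same loading steps, assuming the three committed files sit in the working directory and that example.wav is a hypothetical local audio file:

import torch
import torchaudio
from torchaudio.functional import resample
from transformers import PreTrainedTokenizerFast
from hf_wrapper import Effb2TrmConfig, Effb2TrmCaptioningModel

# Load the captioning model exactly as app.py does
config = Effb2TrmConfig.from_pretrained("config.json")
model = Effb2TrmCaptioningModel(config)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()
tokenizer = PreTrainedTokenizerFast.from_pretrained("wsntxxn/audiocaps-simple-tokenizer")

wav, sr = torchaudio.load("example.wav")              # hypothetical input file
wav = resample(wav.mean(0), sr, model.config.sample_rate)  # mono, 16 kHz
with torch.no_grad():
    word_idx = model(audio=wav.float().unsqueeze(0), audio_length=[len(wav)])[0]
print(tokenizer.decode(word_idx, skip_special_tokens=True))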
 
config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "_name_or_path": "gijs/audio-captioning-small",
+   "architectures": [
+     "Effb2TrmCaptioningModel"
+   ],
+   "attn_emb_dim": 1408,
+   "auto_map": {
+     "AutoConfig": "hf_wrapper.Effb2TrmConfig",
+     "AutoModel": "hf_wrapper.Effb2TrmCaptioningModel"
+   },
+   "decoder_dropout": 0.2,
+   "decoder_emb_dim": 256,
+   "decoder_n_layers": 2,
+   "decoder_we_tie_weights": true,
+   "fc_emb_dim": 1408,
+   "sample_rate": 16000,
+   "shared_dim": 1024,
+   "tchr_dim": 768,
+   "torch_dtype": "float32",
+   "transformers_version": "4.30.2",
+   "vocab_size": 4981
+ }
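Side note, as an illustration rather than part of the commit: the auto_map block above registers the custom classes from hf_wrapper.py with the transformers auto classes, so the same checkpoint can also be loaded with trust_remote_code instead of being wired up by hand as app.py does. A sketch, assuming config.json, hf_wrapper.py and pytorch_model.bin sit together in the current directory (or in the gijs/audio-captioning-small repo named in _name_or_path):

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained(".", trust_remote_code=True)   # resolves to Effb2TrmConfig
model = AutoModel.from_pretrained(".", trust_remote_code=True)     # resolves to Effb2TrmCaptioningModel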
hf_wrapper.py ADDED
@@ -0,0 +1,1964 @@
1
+ from typing import Dict, Callable, Union, List
2
+ import random
3
+ import math
4
+ import sys
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
11
+ from torchaudio import transforms
12
+ from efficientnet_pytorch import EfficientNet
13
+ from efficientnet_pytorch import utils as efficientnet_utils
14
+ from einops import rearrange, reduce
15
+ from transformers import PretrainedConfig, PreTrainedModel
16
+
17
+
18
+ def sort_pack_padded_sequence(input, lengths):
19
+ sorted_lengths, indices = torch.sort(lengths, descending=True)
20
+ tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
21
+ inv_ix = indices.clone()
22
+ inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
23
+ return tmp, inv_ix
24
+
25
+ def pad_unsort_packed_sequence(input, inv_ix):
26
+ tmp, _ = pad_packed_sequence(input, batch_first=True)
27
+ tmp = tmp[inv_ix]
28
+ return tmp
29
+
30
+ def pack_wrapper(module, attn_feats, attn_feat_lens):
31
+ packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
32
+ if isinstance(module, torch.nn.RNNBase):
33
+ return pad_unsort_packed_sequence(module(packed)[0], inv_ix)
34
+ else:
35
+ return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
36
+
37
+ def embedding_pooling(x, lens, pooling="mean"):
38
+ if pooling == "max":
39
+ fc_embs = max_with_lens(x, lens)
40
+ elif pooling == "mean":
41
+ fc_embs = mean_with_lens(x, lens)
42
+ elif pooling == "mean+max":
43
+ x_mean = mean_with_lens(x, lens)
44
+ x_max = max_with_lens(x, lens)
45
+ fc_embs = x_mean + x_max
46
+ elif pooling == "last":
47
+ indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
48
+ # indices: [N, 1, hidden]
49
+ fc_embs = torch.gather(x, 1, indices).squeeze(1)
50
+ else:
51
+ raise Exception(f"pooling method {pooling} not supported")
52
+ return fc_embs
53
+
54
+ def interpolate(x, ratio):
55
+ """Interpolate data in time domain. This is used to compensate the
56
+ resolution reduction in downsampling of a CNN.
57
+
58
+ Args:
59
+ x: (batch_size, time_steps, classes_num)
60
+ ratio: int, ratio to interpolate
61
+ Returns:
62
+ upsampled: (batch_size, time_steps * ratio, classes_num)
63
+ """
64
+ (batch_size, time_steps, classes_num) = x.shape
65
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
66
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
67
+ return upsampled
68
+
69
+ def pad_framewise_output(framewise_output, frames_num):
70
+ """Pad framewise_output to the same length as input frames. The pad value
71
+ is the same as the value of the last frame.
72
+ Args:
73
+ framewise_output: (batch_size, frames_num, classes_num)
74
+ frames_num: int, number of frames to pad
75
+ Outputs:
76
+ output: (batch_size, frames_num, classes_num)
77
+ """
78
+ pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1)
79
+ """tensor for padding"""
80
+
81
+ output = torch.cat((framewise_output, pad), dim=1)
82
+ """(batch_size, frames_num, classes_num)"""
83
+
84
+ return output
85
+
86
+ def find_contiguous_regions(activity_array):
87
+ """Find contiguous regions from bool valued numpy.array.
88
+ Copy of https://dcase-repo.github.io/dcase_util/_modules/dcase_util/data/decisions.html#DecisionEncoder
89
+ Reason is:
90
+ 1. This does not belong to a class necessarily
91
+ 2. Importing DecisionEncoder requires sndfile over some other imports, which causes some problems on clusters
92
+ """
93
+
94
+ # Find the changes in the activity_array
95
+ change_indices = np.logical_xor(activity_array[1:],
96
+ activity_array[:-1]).nonzero()[0]
97
+
98
+ # Shift change_index with one, focus on frame after the change.
99
+ change_indices += 1
100
+
101
+ if activity_array[0]:
102
+ # If the first element of activity_array is True add 0 at the beginning
103
+ change_indices = np.r_[0, change_indices]
104
+
105
+ if activity_array[-1]:
106
+ # If the last element of activity_array is True, add the length of the array
107
+ change_indices = np.r_[change_indices, activity_array.size]
108
+
109
+ # Reshape the result into two columns
110
+ return change_indices.reshape((-1, 2))
111
+
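Illustration (not part of the file): find_contiguous_regions turns a boolean activity array into half-open [onset, offset) index pairs, one row per run of True values. A tiny worked example, assuming hf_wrapper is importable:

import numpy as np
from hf_wrapper import find_contiguous_regions

activity = np.array([False, True, True, False, True])
print(find_contiguous_regions(activity))
# [[1 3]
#  [4 5]]   <- True runs at indices 1-2 and at index 4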
112
+ def double_threshold(x, high_thres, low_thres, n_connect=1):
113
+ """double_threshold
114
+ Helper function to calculate double threshold for n-dim arrays
115
+ :param x: input array
116
+ :param high_thres: high threshold value
117
+ :param low_thres: Low threshold value
118
+ :param n_connect: Distance of <= n clusters will be merged
119
+ """
120
+ assert x.ndim <= 3, "Whoops something went wrong with the input ({}), check if it has <= 3 dims".format(
121
+ x.shape)
122
+ if x.ndim == 3:
123
+ apply_dim = 1
124
+ elif x.ndim < 3:
125
+ apply_dim = 0
126
+ # x is assumed to be 3d: (batch, time, dim)
127
+ # Assumed to be 2d : (time, dim)
128
+ # Assumed to be 1d : (time)
129
+ # time axis is therefore at 1 for 3d and 0 for 2d (
130
+ return np.apply_along_axis(lambda x: _double_threshold(
131
+ x, high_thres, low_thres, n_connect=n_connect),
132
+ axis=apply_dim,
133
+ arr=x)
134
+
135
+ def _double_threshold(x, high_thres, low_thres, n_connect=1, return_arr=True):
136
+ """_double_threshold
137
+ Computes a double threshold over the input array
138
+ :param x: input array, needs to be 1d
139
+ :param high_thres: High threshold over the array
140
+ :param low_thres: Low threshold over the array
141
+ :param n_connect: Postprocessing, maximal distance between clusters to connect
142
+ :param return_arr: By default this function returns the filtered indices, but if return_arr = True it returns an array of the same size as x filled with ones and zeros.
143
+ """
144
+ assert x.ndim == 1, "Input needs to be 1d"
145
+ high_locations = np.where(x > high_thres)[0]
146
+ locations = x > low_thres
147
+ encoded_pairs = find_contiguous_regions(locations)
148
+
149
+ filtered_list = list(
150
+ filter(
151
+ lambda pair:
152
+ ((pair[0] <= high_locations) & (high_locations <= pair[1])).any(),
153
+ encoded_pairs))
154
+
155
+ filtered_list = connect_(filtered_list, n_connect)
156
+ if return_arr:
157
+ zero_one_arr = np.zeros_like(x, dtype=int)
158
+ for sl in filtered_list:
159
+ zero_one_arr[sl[0]:sl[1]] = 1
160
+ return zero_one_arr
161
+ return filtered_list
162
+
163
+ def connect_(pairs, n=1):
164
+ """connect_
165
+ Connects two adjacent clusters if their distance is <= n
166
+ :param pairs: Clusters of iterables, e.g., [(1,5),(7,10)]
167
+ :param n: distance between two clusters
168
+ """
169
+ if len(pairs) == 0:
170
+ return []
171
+ start_, end_ = pairs[0]
172
+ new_pairs = []
173
+ for i, (next_item, cur_item) in enumerate(zip(pairs[1:], pairs[0:])):
174
+ end_ = next_item[1]
175
+ if next_item[0] - cur_item[1] <= n:
176
+ pass
177
+ else:
178
+ new_pairs.append((start_, cur_item[1]))
179
+ start_ = next_item[0]
180
+ new_pairs.append((start_, end_))
181
+ return new_pairs
182
+
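A small worked example (illustration only) of the double-threshold post-processing defined above, assuming hf_wrapper is importable: frames above the low threshold are kept only if their region also crosses the high threshold.

import numpy as np
from hf_wrapper import double_threshold

scores = np.array([0.1, 0.8, 0.6, 0.2, 0.55, 0.1])
# the run at indices 1-2 crosses the 0.75 high threshold, the run at index 4 does not
print(double_threshold(scores, high_thres=0.75, low_thres=0.5))
# [0 1 1 0 0 0]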
183
+ def segments_to_temporal_tag(segments, thre=0.5):
184
+ after_flag, while_flag = 0, 0
185
+ for j in range(len(segments)):
186
+ for k in range(len(segments)):
187
+ if segments[j][0] == segments[k][0]:
188
+ continue
189
+ min_duration = min(segments[j][2] - segments[j][1], segments[k][2] - segments[k][1])
190
+ overlap = segments[j][2] - segments[k][1]
191
+ if overlap < thre * min_duration:
192
+ after_flag = 2
193
+ if segments[j][1] < segments[k][1] and overlap > thre * min_duration:
194
+ while_flag = 1
195
+ return after_flag + while_flag
196
+
197
+ def decode_with_timestamps(labels, time_resolution):
198
+ batch_results = []
199
+ for lab in labels:
200
+ segments = []
201
+ for i, label_column in enumerate(lab.T):
202
+ change_indices = find_contiguous_regions(label_column)
203
+ # append [onset, offset] in the result list
204
+ for row in change_indices:
205
+ segments.append((i, row[0] * time_resolution, row[1] * time_resolution))
206
+ temporal_tag = segments_to_temporal_tag(segments)
207
+ batch_results.append(temporal_tag)
208
+ return batch_results
209
+
210
+ class _EffiNet(nn.Module):
211
+ """A proxy for efficient net models"""
212
+ def __init__(self,
213
+ blocks_args=None,
214
+ global_params=None,
215
+ ) -> None:
216
+ super().__init__()
217
+ self.eff_net = EfficientNet(blocks_args=blocks_args,
218
+ global_params=global_params)
219
+
220
+
221
+ def forward(self, x: torch.Tensor):
222
+ x = rearrange(x, 'b f t -> b 1 f t')
223
+ x = self.eff_net.extract_features(x)
224
+ return reduce(x, 'b c f t -> b t c', 'mean')
225
+
226
+
227
+ def get_effb2_model() -> _EffiNet:
228
+ blocks_args, global_params = efficientnet_utils.get_model_params(
229
+ 'efficientnet-b2', {'include_top': False})
230
+ model = _EffiNet(blocks_args=blocks_args,
231
+ global_params=global_params)
232
+ model.eff_net._change_in_channels(1)
233
+ return model
234
+
235
+ def merge_load_state_dict(state_dict,
236
+ model: torch.nn.Module,
237
+ output_fn: Callable = sys.stdout.write):
238
+ model_dict = model.state_dict()
239
+ pretrained_dict = {}
240
+ mismatch_keys = []
241
+ for key, value in state_dict.items():
242
+ if key in model_dict and model_dict[key].shape == value.shape:
243
+ pretrained_dict[key] = value
244
+ else:
245
+ mismatch_keys.append(key)
246
+ output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}\n")
247
+ model_dict.update(pretrained_dict)
248
+ model.load_state_dict(model_dict, strict=True)
249
+ return pretrained_dict.keys()
250
+
251
+
252
+ class EfficientNetB2(nn.Module):
253
+
254
+ def __init__(self,
255
+ n_mels: int = 64,
256
+ win_length: int = 32,
257
+ hop_length: int = 10,
258
+ f_min: int = 0,
259
+ freeze: bool = False,):
260
+ super().__init__()
261
+ sample_rate = 16000
262
+ self.melspec_extractor = transforms.MelSpectrogram(
263
+ sample_rate=sample_rate,
264
+ n_fft=win_length * sample_rate // 1000,
265
+ win_length=win_length * sample_rate // 1000,
266
+ hop_length=hop_length * sample_rate // 1000,
267
+ f_min=f_min,
268
+ n_mels=n_mels,
269
+ )
270
+ self.hop_length = 10 * sample_rate // 1000
271
+ self.db_transform = transforms.AmplitudeToDB(top_db=120)
272
+ self.backbone = get_effb2_model()
273
+ self.fc_emb_size = self.backbone.eff_net._conv_head.out_channels
274
+ self.downsample_ratio = 32
275
+ if freeze:
276
+ for param in self.parameters():
277
+ param.requires_grad = False
278
+
279
+ def forward(self, input_dict):
280
+
281
+ waveform = input_dict["wav"]
282
+ wave_length = input_dict["wav_len"]
283
+ specaug = input_dict["specaug"]
284
+ x = self.melspec_extractor(waveform)
285
+ x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
286
+
287
+ x = rearrange(x, 'b f t -> b 1 t f')
288
+ if self.training and specaug:
289
+ x = self.spec_augmenter(x)
290
+ x = rearrange(x, 'b 1 t f -> b f t')
291
+
292
+ x = self.backbone(x)
293
+ attn_emb = x
294
+
295
+ wave_length = torch.as_tensor(wave_length)
296
+ feat_length = torch.div(wave_length, self.hop_length,
297
+ rounding_mode="floor") + 1
298
+ feat_length = torch.div(feat_length, self.downsample_ratio,
299
+ rounding_mode="floor")
300
+ fc_emb = mean_with_lens(attn_emb, feat_length)
301
+
302
+ output_dict = {
303
+ 'fc_emb': fc_emb,
304
+ 'attn_emb': attn_emb,
305
+ 'attn_emb_len': feat_length
306
+ }
307
+ return output_dict
308
+
309
+
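Quick sanity check (illustration only) of the attn_emb_len computed in EfficientNetB2.forward: with the 16 kHz front end the hop length is 160 samples and the EfficientNet backbone downsamples time by 32, so a 10-second clip yields 31 attention frames.

sample_rate = 16000
hop_length = 10 * sample_rate // 1000       # 160 samples per mel frame
downsample_ratio = 32
wav_len = 10 * sample_rate                  # 10-second clip
n_frames = wav_len // hop_length + 1        # 1001 mel frames
print(n_frames // downsample_ratio)         # 31 attention time steps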
310
+ def generate_length_mask(lens, max_length=None):
311
+ lens = torch.as_tensor(lens)
312
+ N = lens.size(0)
313
+ if max_length is None:
314
+ max_length = max(lens)
315
+ if isinstance(max_length, torch.Tensor):
316
+ max_length = max_length.item()
317
+ idxs = torch.arange(max_length).repeat(N).view(N, max_length)
318
+ idxs = idxs.to(lens.device)
319
+ mask = (idxs < lens.view(-1, 1))
320
+ return mask
321
+
322
+ def mean_with_lens(features, lens):
323
+ """
324
+ features: [N, T, ...] (assume the second dimension represents length)
325
+ lens: [N,]
326
+ """
327
+ lens = torch.as_tensor(lens)
328
+ if max(lens) != features.size(1):
329
+ max_length = features.size(1)
330
+ mask = generate_length_mask(lens, max_length)
331
+ else:
332
+ mask = generate_length_mask(lens)
333
+ mask = mask.to(features.device) # [N, T]
334
+
335
+ while mask.ndim < features.ndim:
336
+ mask = mask.unsqueeze(-1)
337
+ feature_mean = features * mask
338
+ feature_mean = feature_mean.sum(1)
339
+ while lens.ndim < feature_mean.ndim:
340
+ lens = lens.unsqueeze(1)
341
+ feature_mean = feature_mean / lens.to(features.device)
342
+ # feature_mean = features * mask.unsqueeze(-1)
343
+ # feature_mean = feature_mean.sum(1) / lens.unsqueeze(1).to(features.device)
344
+ return feature_mean
345
+
346
+ def max_with_lens(features, lens):
347
+ """
348
+ features: [N, T, ...] (assume the second dimension represents length)
349
+ lens: [N,]
350
+ """
351
+ lens = torch.as_tensor(lens)
352
+ if max(lens) != features.size(1):
353
+ max_length = features.size(1)
354
+ mask = generate_length_mask(lens, max_length)
355
+ else:
356
+ mask = generate_length_mask(lens)
357
+ mask = mask.to(features.device) # [N, T]
358
+
359
+ feature_max = features.clone()
360
+ feature_max[~mask] = float("-inf")
361
+ feature_max, _ = feature_max.max(1)
362
+ return feature_max
363
+
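Illustration (not part of the file): the helpers above average or max-pool padded batches without letting the padding leak into the statistics. A minimal sketch, assuming hf_wrapper is importable:

import torch
from hf_wrapper import generate_length_mask, mean_with_lens

feats = torch.tensor([[[1.], [2.], [3.]],
                      [[4.], [5.], [0.]]])      # second sequence is padded to length 3
lens = torch.tensor([3, 2])
print(generate_length_mask(lens))               # [[True, True, True], [True, True, False]]
print(mean_with_lens(feats, lens).squeeze(-1))  # tensor([2.0000, 4.5000])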
364
+ def repeat_tensor(x, n):
365
+ return x.unsqueeze(0).repeat(n, *([1] * len(x.shape)))
366
+
367
+
368
+ class CaptionMetaMixin:
369
+ pad_idx = 0
370
+ start_idx = 1
371
+ end_idx = 2
372
+ max_length = 20
373
+
374
+ @classmethod
375
+ def set_index(cls, start_idx, end_idx, pad_idx):
376
+ cls.start_idx = start_idx
377
+ cls.end_idx = end_idx
378
+ cls.pad_idx = pad_idx
379
+
380
+
381
+ class CaptionModel(nn.Module, CaptionMetaMixin):
382
+ """
383
+ Encoder-decoder captioning model.
384
+ """
385
+
386
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
387
+ super().__init__()
388
+ self.encoder = encoder
389
+ self.decoder = decoder
390
+ self.vocab_size = decoder.vocab_size
391
+ self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
392
+ self.inference_forward_keys = ["sample_method", "max_length", "temp"]
393
+ freeze_encoder = kwargs.get("freeze_encoder", False)
394
+ if freeze_encoder:
395
+ for param in self.encoder.parameters():
396
+ param.requires_grad = False
397
+ self.check_decoder_compatibility()
398
+
399
+ def check_decoder_compatibility(self):
400
+ compatible_decoders = [x.__class__.__name__ for x in self.compatible_decoders]
401
+ assert isinstance(self.decoder, self.compatible_decoders), \
402
+ f"{self.decoder.__class__.__name__} is incompatible with " \
403
+ f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "
404
+
405
+ def forward(self, input_dict: Dict):
406
+ """
407
+ input_dict: {
408
+ (required)
409
+ mode: train/inference,
410
+ [spec, spec_len],
411
+ [fc],
412
+ [attn, attn_len],
413
+ [wav, wav_len],
414
+ [sample_method: greedy],
415
+ [temp: 1.0] (in case of no teacher forcing)
416
+ (optional, mode=train)
417
+ cap,
418
+ cap_len,
419
+ ss_ratio,
420
+ (optional, mode=inference)
421
+ sample_method: greedy/beam,
422
+ max_length,
423
+ temp,
424
+ beam_size (optional, sample_method=beam),
425
+ n_best (optional, sample_method=beam),
426
+ }
427
+ """
428
+ encoder_output_dict = self.encoder(input_dict)
429
+ output = self.forward_decoder(input_dict, encoder_output_dict)
430
+ return output
431
+
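Usage note (illustration only): for inference, the input_dict documented above needs the encoder inputs plus the sampling options; the Effb2TrmCaptioningModel wrapper near the end of this file builds exactly this dictionary. A hypothetical beam-search call on an assembled TransformerModel instance, here called caption_model, would be:

import torch

wav = torch.randn(1, 160000)          # placeholder for 10 s of 16 kHz audio
input_dict = {
    "wav": wav,
    "wav_len": [wav.shape[1]],
    "specaug": False,
    "mode": "inference",
    "sample_method": "beam",
    "beam_size": 3,
    "max_length": 20,
    "temp": 1.0,
}
seq = caption_model(input_dict)["seq"]   # decoded token indices, one sequence per input clip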
432
+ def forward_decoder(self, input_dict: Dict, encoder_output_dict: Dict):
433
+ if input_dict["mode"] == "train":
434
+ forward_dict = {
435
+ "mode": "train", "sample_method": "greedy", "temp": 1.0
436
+ }
437
+ for key in self.train_forward_keys:
438
+ forward_dict[key] = input_dict[key]
439
+ forward_dict.update(encoder_output_dict)
440
+ output = self.train_forward(forward_dict)
441
+ elif input_dict["mode"] == "inference":
442
+ forward_dict = {"mode": "inference"}
443
+ default_args = { "sample_method": "greedy", "max_length": self.max_length, "temp": 1.0 }
444
+ for key in self.inference_forward_keys:
445
+ if key in input_dict:
446
+ forward_dict[key] = input_dict[key]
447
+ else:
448
+ forward_dict[key] = default_args[key]
449
+
450
+ if forward_dict["sample_method"] == "beam":
451
+ forward_dict["beam_size"] = input_dict.get("beam_size", 3)
452
+ forward_dict["n_best"] = input_dict.get("n_best", False)
453
+ forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"])
454
+ elif forward_dict["sample_method"] == "dbs":
455
+ forward_dict["beam_size"] = input_dict.get("beam_size", 6)
456
+ forward_dict["group_size"] = input_dict.get("group_size", 3)
457
+ forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5)
458
+ forward_dict["group_nbest"] = input_dict.get("group_nbest", True)
459
+
460
+ forward_dict.update(encoder_output_dict)
461
+ output = self.inference_forward(forward_dict)
462
+ else:
463
+ raise Exception("mode should be either 'train' or 'inference'")
464
+ output.update(encoder_output_dict)
465
+ return output
466
+
467
+ def prepare_output(self, input_dict):
468
+ output = {}
469
+ batch_size = input_dict["fc_emb"].size(0)
470
+ if input_dict["mode"] == "train":
471
+ max_length = input_dict["cap"].size(1) - 1
472
+ elif input_dict["mode"] == "inference":
473
+ max_length = input_dict["max_length"]
474
+ else:
475
+ raise Exception("mode should be either 'train' or 'inference'")
476
+ device = input_dict["fc_emb"].device
477
+ output["seq"] = torch.full((batch_size, max_length), self.end_idx,
478
+ dtype=torch.long)
479
+ output["logit"] = torch.empty(batch_size, max_length,
480
+ self.vocab_size).to(device)
481
+ output["sampled_logprob"] = torch.zeros(batch_size, max_length)
482
+ output["embed"] = torch.empty(batch_size, max_length,
483
+ self.decoder.d_model).to(device)
484
+ return output
485
+
486
+ def train_forward(self, input_dict):
487
+ if input_dict["ss_ratio"] != 1: # scheduled sampling training
488
+ input_dict["mode"] = "train"
489
+ return self.stepwise_forward(input_dict)
490
+ output = self.seq_forward(input_dict)
491
+ self.train_process(output, input_dict)
492
+ return output
493
+
494
+ def seq_forward(self, input_dict):
495
+ raise NotImplementedError
496
+
497
+ def train_process(self, output, input_dict):
498
+ pass
499
+
500
+ def inference_forward(self, input_dict):
501
+ if input_dict["sample_method"] == "beam":
502
+ return self.beam_search(input_dict)
503
+ elif input_dict["sample_method"] == "dbs":
504
+ return self.diverse_beam_search(input_dict)
505
+ return self.stepwise_forward(input_dict)
506
+
507
+ def stepwise_forward(self, input_dict):
508
+ """Step-by-step decoding"""
509
+ output = self.prepare_output(input_dict)
510
+ max_length = output["seq"].size(1)
511
+ # start sampling
512
+ for t in range(max_length):
513
+ input_dict["t"] = t
514
+ self.decode_step(input_dict, output)
515
+ if input_dict["mode"] == "inference": # decide whether to stop when sampling
516
+ unfinished_t = output["seq"][:, t] != self.end_idx
517
+ if t == 0:
518
+ unfinished = unfinished_t
519
+ else:
520
+ unfinished *= unfinished_t
521
+ output["seq"][:, t][~unfinished] = self.end_idx
522
+ if unfinished.sum() == 0:
523
+ break
524
+ self.stepwise_process(output)
525
+ return output
526
+
527
+ def decode_step(self, input_dict, output):
528
+ """Decoding operation of timestep t"""
529
+ decoder_input = self.prepare_decoder_input(input_dict, output)
530
+ # feed to the decoder to get logit
531
+ output_t = self.decoder(decoder_input)
532
+ logit_t = output_t["logit"]
533
+ # assert logit_t.ndim == 3
534
+ if logit_t.size(1) == 1:
535
+ logit_t = logit_t.squeeze(1)
536
+ embed_t = output_t["embed"].squeeze(1)
537
+ elif logit_t.size(1) > 1:
538
+ logit_t = logit_t[:, -1, :]
539
+ embed_t = output_t["embed"][:, -1, :]
540
+ else:
541
+ raise Exception("no logit output")
542
+ # sample the next input word and get the corresponding logit
543
+ sampled = self.sample_next_word(logit_t,
544
+ method=input_dict["sample_method"],
545
+ temp=input_dict["temp"])
546
+
547
+ output_t.update(sampled)
548
+ output_t["t"] = input_dict["t"]
549
+ output_t["logit"] = logit_t
550
+ output_t["embed"] = embed_t
551
+ self.stepwise_process_step(output, output_t)
552
+
553
+ def prepare_decoder_input(self, input_dict, output):
554
+ """Prepare the inp ut dict for the decoder"""
555
+ """Prepare the input dict for the decoder"""
556
+
557
+ def stepwise_process_step(self, output, output_t):
558
+ """Postprocessing (save output values) after each timestep t"""
559
+ t = output_t["t"]
560
+ output["logit"][:, t, :] = output_t["logit"]
561
+ output["seq"][:, t] = output_t["word"]
562
+ output["sampled_logprob"][:, t] = output_t["probs"]
563
+ output["embed"][:, t, :] = output_t["embed"]
564
+
565
+ def stepwise_process(self, output):
566
+ """Postprocessing after the whole step-by-step autoregressive decoding"""
567
+ pass
568
+
569
+ def sample_next_word(self, logit, method, temp):
570
+ """Sample the next word, given probs output by the decoder"""
571
+ logprob = torch.log_softmax(logit, dim=1)
572
+ if method == "greedy":
573
+ sampled_logprob, word = torch.max(logprob.detach(), 1)
574
+ elif method == "gumbel":
575
+ def sample_gumbel(shape, eps=1e-20):
576
+ U = torch.rand(shape).to(logprob.device)
577
+ return -torch.log(-torch.log(U + eps) + eps)
578
+ def gumbel_softmax_sample(logit, temperature):
579
+ y = logit + sample_gumbel(logit.size())
580
+ return torch.log_softmax(y / temperature, dim=-1)
581
+ _logprob = gumbel_softmax_sample(logprob, temp)
582
+ _, word = torch.max(_logprob.data, 1)
583
+ sampled_logprob = logprob.gather(1, word.unsqueeze(-1))
584
+ else:
585
+ logprob = logprob / temp
586
+ if method.startswith("top"):
587
+ top_num = float(method[3:])
588
+ if 0 < top_num < 1: # top-p sampling
589
+ probs = torch.softmax(logit, dim=1)
590
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
591
+ _cumsum = sorted_probs.cumsum(1)
592
+ mask = _cumsum < top_num
593
+ mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
594
+ sorted_probs = sorted_probs * mask.to(sorted_probs)
595
+ sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
596
+ logprob.scatter_(1, sorted_indices, sorted_probs.log())
597
+ else: # top-k sampling
598
+ k = int(top_num)
599
+ tmp = torch.empty_like(logprob).fill_(float('-inf'))
600
+ topk, indices = torch.topk(logprob, k, dim=1)
601
+ tmp = tmp.scatter(1, indices, topk)
602
+ logprob = tmp
603
+ word = torch.distributions.Categorical(logits=logprob.detach()).sample()
604
+ sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
605
+ word = word.detach().long()
606
+ # sampled_logprob: [N,], word: [N,]
607
+ return {"word": word, "probs": sampled_logprob}
608
+
609
+ def beam_search(self, input_dict):
610
+ output = self.prepare_output(input_dict)
611
+ max_length = input_dict["max_length"]
612
+ beam_size = input_dict["beam_size"]
613
+ if input_dict["n_best"]:
614
+ n_best_size = input_dict["n_best_size"]
615
+ batch_size, max_length = output["seq"].size()
616
+ output["seq"] = torch.full((batch_size, n_best_size, max_length),
617
+ self.end_idx, dtype=torch.long)
618
+
619
+ temp = input_dict["temp"]
620
+ # instance by instance beam search
621
+ for i in range(output["seq"].size(0)):
622
+ output_i = self.prepare_beamsearch_output(input_dict)
623
+ input_dict["sample_idx"] = i
624
+ for t in range(max_length):
625
+ input_dict["t"] = t
626
+ output_t = self.beamsearch_step(input_dict, output_i)
627
+ #######################################
628
+ # merge with previous beam and select the current max prob beam
629
+ #######################################
630
+ logit_t = output_t["logit"]
631
+ if logit_t.size(1) == 1:
632
+ logit_t = logit_t.squeeze(1)
633
+ elif logit_t.size(1) > 1:
634
+ logit_t = logit_t[:, -1, :]
635
+ else:
636
+ raise Exception("no logit output")
637
+ logprob_t = torch.log_softmax(logit_t, dim=1)
638
+ logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
639
+ logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
640
+ if t == 0: # for the first step, all k seq will have the same probs
641
+ topk_logprob, topk_words = logprob_t[0].topk(
642
+ beam_size, 0, True, True)
643
+ else: # unroll and find top logprob, and their unrolled indices
644
+ topk_logprob, topk_words = logprob_t.view(-1).topk(
645
+ beam_size, 0, True, True)
646
+ topk_words = topk_words.cpu()
647
+ output_i["topk_logprob"] = topk_logprob
648
+ # output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,]
649
+ output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size,
650
+ rounding_mode='trunc')
651
+ output_i["next_word"] = topk_words % self.vocab_size # [beam_size,]
652
+ if t == 0:
653
+ output_i["seq"] = output_i["next_word"].unsqueeze(1)
654
+ else:
655
+ output_i["seq"] = torch.cat([
656
+ output_i["seq"][output_i["prev_words_beam"]],
657
+ output_i["next_word"].unsqueeze(1)], dim=1)
658
+
659
+ # add finished beams to results
660
+ is_end = output_i["next_word"] == self.end_idx
661
+ if t == max_length - 1:
662
+ is_end.fill_(1)
663
+
664
+ for beam_idx in range(beam_size):
665
+ if is_end[beam_idx]:
666
+ final_beam = {
667
+ "seq": output_i["seq"][beam_idx].clone(),
668
+ "score": output_i["topk_logprob"][beam_idx].item()
669
+ }
670
+ final_beam["score"] = final_beam["score"] / (t + 1)
671
+ output_i["done_beams"].append(final_beam)
672
+ output_i["topk_logprob"][is_end] -= 1000
673
+
674
+ self.beamsearch_process_step(output_i, output_t)
675
+
676
+ if len(output_i["done_beams"]) == beam_size:
677
+ break
678
+
679
+ self.beamsearch_process(output, output_i, input_dict)
680
+ return output
681
+
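For reference, finished beams above are ranked by a length-normalized log-probability, which is what final_beam["score"] / (t + 1) computes:

score(w_{1:T}) = (1 / T) * \sum_{t=1}^{T} \log p(w_t \mid w_{<t}, \text{audio})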
682
+ def prepare_beamsearch_output(self, input_dict):
683
+ beam_size = input_dict["beam_size"]
684
+ device = input_dict["fc_emb"].device
685
+ output = {
686
+ "topk_logprob": torch.zeros(beam_size).to(device),
687
+ "seq": None,
688
+ "prev_words_beam": None,
689
+ "next_word": None,
690
+ "done_beams": [],
691
+ }
692
+ return output
693
+
694
+ def beamsearch_step(self, input_dict, output_i):
695
+ decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i)
696
+ output_t = self.decoder(decoder_input)
697
+ output_t["t"] = input_dict["t"]
698
+ return output_t
699
+
700
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
701
+ raise NotImplementedError
702
+
703
+ def beamsearch_process_step(self, output_i, output_t):
704
+ pass
705
+
706
+ def beamsearch_process(self, output, output_i, input_dict):
707
+ i = input_dict["sample_idx"]
708
+ done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
709
+ if input_dict["n_best"]:
710
+ done_beams = done_beams[:input_dict["n_best_size"]]
711
+ for out_idx, done_beam in enumerate(done_beams):
712
+ seq = done_beam["seq"]
713
+ output["seq"][i][out_idx, :len(seq)] = seq
714
+ else:
715
+ seq = done_beams[0]["seq"]
716
+ output["seq"][i][:len(seq)] = seq
717
+
718
+ def diverse_beam_search(self, input_dict):
719
+
720
+ def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
721
+ local_time = t - divm
722
+ unaug_logprob = logprob.clone()
723
+
724
+ if divm > 0:
725
+ change = torch.zeros(logprob.size(-1))
726
+ for prev_choice in range(divm):
727
+ prev_decisions = seq_table[prev_choice][..., local_time]
728
+ for prev_labels in range(bdash):
729
+ change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1))
730
+
731
+ change = change.to(logprob.device)
732
+ logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda
733
+
734
+ return logprob, unaug_logprob
735
+
736
+ output = self.prepare_output(input_dict)
737
+ group_size = input_dict["group_size"]
738
+ batch_size = output["seq"].size(0)
739
+ beam_size = input_dict["beam_size"]
740
+ bdash = beam_size // group_size
741
+ input_dict["bdash"] = bdash
742
+ diversity_lambda = input_dict["diversity_lambda"]
743
+ device = input_dict["fc_emb"].device
744
+ max_length = input_dict["max_length"]
745
+ temp = input_dict["temp"]
746
+ group_nbest = input_dict["group_nbest"]
747
+ batch_size, max_length = output["seq"].size()
748
+ if group_nbest:
749
+ output["seq"] = torch.full((batch_size, beam_size, max_length),
750
+ self.end_idx, dtype=torch.long)
751
+ else:
752
+ output["seq"] = torch.full((batch_size, group_size, max_length),
753
+ self.end_idx, dtype=torch.long)
754
+
755
+
756
+ for i in range(batch_size):
757
+ input_dict["sample_idx"] = i
758
+ seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0]
759
+ logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
760
+ done_beams_table = [[] for _ in range(group_size)]
761
+
762
+ output_i = {
763
+ "prev_words_beam": [None for _ in range(group_size)],
764
+ "next_word": [None for _ in range(group_size)],
765
+ "state": [None for _ in range(group_size)]
766
+ }
767
+
768
+ for t in range(max_length + group_size - 1):
769
+ input_dict["t"] = t
770
+ for divm in range(group_size):
771
+ input_dict["divm"] = divm
772
+ if t >= divm and t <= max_length + divm - 1:
773
+ local_time = t - divm
774
+ decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
775
+ output_t = self.decoder(decoder_input)
776
+ output_t["divm"] = divm
777
+ logit_t = output_t["logit"]
778
+ if logit_t.size(1) == 1:
779
+ logit_t = logit_t.squeeze(1)
780
+ elif logit_t.size(1) > 1:
781
+ logit_t = logit_t[:, -1, :]
782
+ else:
783
+ raise Exception("no logit output")
784
+ logprob_t = torch.log_softmax(logit_t, dim=1)
785
+ logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
786
+ logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash)
787
+ logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
788
+ if local_time == 0: # for the first step, all k seq will have the same probs
789
+ topk_logprob, topk_words = logprob_t[0].topk(
790
+ bdash, 0, True, True)
791
+ else: # unroll and find top logprob, and their unrolled indices
792
+ topk_logprob, topk_words = logprob_t.view(-1).topk(
793
+ bdash, 0, True, True)
794
+ topk_words = topk_words.cpu()
795
+ logprob_table[divm] = topk_logprob
796
+ output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,]
797
+ output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,]
798
+ if local_time > 0:
799
+ seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
800
+ seq_table[divm] = torch.cat([
801
+ seq_table[divm],
802
+ output_i["next_word"][divm].unsqueeze(-1)], -1)
803
+
804
+ is_end = seq_table[divm][:, t-divm] == self.end_idx
805
+ assert seq_table[divm].shape[-1] == t - divm + 1
806
+ if t == max_length + divm - 1:
807
+ is_end.fill_(1)
808
+ for beam_idx in range(bdash):
809
+ if is_end[beam_idx]:
810
+ final_beam = {
811
+ "seq": seq_table[divm][beam_idx].clone(),
812
+ "score": logprob_table[divm][beam_idx].item()
813
+ }
814
+ final_beam["score"] = final_beam["score"] / (t - divm + 1)
815
+ done_beams_table[divm].append(final_beam)
816
+ logprob_table[divm][is_end] -= 1000
817
+ self.dbs_process_step(output_i, output_t)
818
+ done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)]
819
+ if group_nbest:
820
+ done_beams = sum(done_beams_table, [])
821
+ else:
822
+ done_beams = [group_beam[0] for group_beam in done_beams_table]
823
+ for _, done_beam in enumerate(done_beams):
824
+ output["seq"][i, _, :len(done_beam["seq"])] = done_beam["seq"]
825
+
826
+ return output
827
+
828
+ def prepare_dbs_decoder_input(self, input_dict, output_i):
829
+ raise NotImplementedError
830
+
831
+ def dbs_process_step(self, output_i, output_t):
832
+ pass
833
+
834
+
835
+ class TransformerModel(CaptionModel):
836
+
837
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
838
+ if not hasattr(self, "compatible_decoders"):
839
+ self.compatible_decoders = (
840
+ TransformerDecoder,
841
+ )
842
+ super().__init__(encoder, decoder, **kwargs)
843
+
844
+ def seq_forward(self, input_dict):
845
+ cap = input_dict["cap"]
846
+ cap_padding_mask = (cap == self.pad_idx).to(cap.device)
847
+ cap_padding_mask = cap_padding_mask[:, :-1]
848
+ output = self.decoder(
849
+ {
850
+ "word": cap[:, :-1],
851
+ "attn_emb": input_dict["attn_emb"],
852
+ "attn_emb_len": input_dict["attn_emb_len"],
853
+ "cap_padding_mask": cap_padding_mask
854
+ }
855
+ )
856
+ return output
857
+
858
+ def prepare_decoder_input(self, input_dict, output):
859
+ decoder_input = {
860
+ "attn_emb": input_dict["attn_emb"],
861
+ "attn_emb_len": input_dict["attn_emb_len"]
862
+ }
863
+ t = input_dict["t"]
864
+
865
+ ###############
866
+ # determine input word
867
+ ################
868
+ if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
869
+ word = input_dict["cap"][:, :t+1]
870
+ else:
871
+ start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
872
+ if t == 0:
873
+ word = start_word
874
+ else:
875
+ word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
876
+ # word: [N, T]
877
+ decoder_input["word"] = word
878
+
879
+ cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
880
+ decoder_input["cap_padding_mask"] = cap_padding_mask
881
+ return decoder_input
882
+
883
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
884
+ decoder_input = {}
885
+ t = input_dict["t"]
886
+ i = input_dict["sample_idx"]
887
+ beam_size = input_dict["beam_size"]
888
+ ###############
889
+ # prepare attn embeds
890
+ ################
891
+ if t == 0:
892
+ attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
893
+ attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
894
+ output_i["attn_emb"] = attn_emb
895
+ output_i["attn_emb_len"] = attn_emb_len
896
+ decoder_input["attn_emb"] = output_i["attn_emb"]
897
+ decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
898
+ ###############
899
+ # determine input word
900
+ ################
901
+ start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
902
+ if t == 0:
903
+ word = start_word
904
+ else:
905
+ word = torch.cat((start_word, output_i["seq"]), dim=-1)
906
+ decoder_input["word"] = word
907
+ cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
908
+ decoder_input["cap_padding_mask"] = cap_padding_mask
909
+
910
+ return decoder_input
911
+
912
+
913
+ class BaseDecoder(nn.Module):
914
+ """
915
+ Take word/audio embeddings and output the next word probs
916
+ """
917
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim,
918
+ attn_emb_dim, dropout=0.2, tie_weights=False):
919
+ super().__init__()
920
+ self.emb_dim = emb_dim
921
+ self.vocab_size = vocab_size
922
+ self.fc_emb_dim = fc_emb_dim
923
+ self.attn_emb_dim = attn_emb_dim
924
+ self.tie_weights = tie_weights
925
+ self.word_embedding = nn.Embedding(vocab_size, emb_dim)
926
+ self.in_dropout = nn.Dropout(dropout)
927
+
928
+ def forward(self, x):
929
+ raise NotImplementedError
930
+
931
+ def load_word_embedding(self, weight, freeze=True):
932
+ embedding = np.load(weight)
933
+ assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
934
+ assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
935
+
936
+ # embeddings = torch.as_tensor(embeddings).float()
937
+ # self.word_embeddings.weight = nn.Parameter(embeddings)
938
+ # for para in self.word_embeddings.parameters():
939
+ # para.requires_grad = tune
940
+ self.word_embedding = nn.Embedding.from_pretrained(embedding,
941
+ freeze=freeze)
942
+
943
+
944
+ class PositionalEncoding(nn.Module):
945
+
946
+ def __init__(self, d_model, dropout=0.1, max_len=100):
947
+ super(PositionalEncoding, self).__init__()
948
+ self.dropout = nn.Dropout(p=dropout)
949
+
950
+ pe = torch.zeros(max_len, d_model)
951
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
952
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * \
953
+ (-math.log(10000.0) / d_model))
954
+ pe[:, 0::2] = torch.sin(position * div_term)
955
+ pe[:, 1::2] = torch.cos(position * div_term)
956
+ pe = pe.unsqueeze(0).transpose(0, 1)
957
+ # self.register_buffer("pe", pe)
958
+ self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))
959
+
960
+ def forward(self, x):
961
+ # x: [T, N, E]
962
+ x = x + self.pe[:x.size(0), :]
963
+ return self.dropout(x)
964
+
965
+
966
+ class TransformerDecoder(BaseDecoder):
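For reference, the table registered above is the standard sinusoidal positional encoding; in the code's notation, with pos the time step and d_model the embedding width:

PE(pos, 2i)   = \sin( pos / 10000^{2i / d_model} )
PE(pos, 2i+1) = \cos( pos / 10000^{2i / d_model} )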
967
+
968
+ def __init__(self,
969
+ emb_dim,
970
+ vocab_size,
971
+ fc_emb_dim,
972
+ attn_emb_dim,
973
+ dropout,
974
+ freeze=False,
975
+ tie_weights=False,
976
+ **kwargs):
977
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
978
+ dropout=dropout, tie_weights=tie_weights)
979
+ self.d_model = emb_dim
980
+ self.nhead = kwargs.get("nhead", self.d_model // 64)
981
+ self.nlayers = kwargs.get("nlayers", 2)
982
+ self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
983
+
984
+ self.pos_encoder = PositionalEncoding(self.d_model, dropout)
985
+ layer = nn.TransformerDecoderLayer(d_model=self.d_model,
986
+ nhead=self.nhead,
987
+ dim_feedforward=self.dim_feedforward,
988
+ dropout=dropout)
989
+ self.model = nn.TransformerDecoder(layer, self.nlayers)
990
+ self.classifier = nn.Linear(self.d_model, vocab_size, bias=False)
991
+ if tie_weights:
992
+ self.classifier.weight = self.word_embedding.weight
993
+ self.attn_proj = nn.Sequential(
994
+ nn.Linear(self.attn_emb_dim, self.d_model),
995
+ nn.ReLU(),
996
+ nn.Dropout(dropout),
997
+ nn.LayerNorm(self.d_model)
998
+ )
999
+ self.init_params()
1000
+
1001
+ self.freeze = freeze
1002
+ if freeze:
1003
+ for p in self.parameters():
1004
+ p.requires_grad = False
1005
+
1006
+ def init_params(self):
1007
+ for p in self.parameters():
1008
+ if p.dim() > 1:
1009
+ nn.init.xavier_uniform_(p)
1010
+
1011
+ def load_pretrained(self, pretrained, output_fn):
1012
+ checkpoint = torch.load(pretrained, map_location="cpu")
1013
+
1014
+ if "model" in checkpoint:
1015
+ checkpoint = checkpoint["model"]
1016
+ if next(iter(checkpoint)).startswith("decoder."):
1017
+ state_dict = {}
1018
+ for k, v in checkpoint.items():
1019
+ state_dict[k[8:]] = v
1020
+
1021
+ loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
1022
+ if self.freeze:
1023
+ for name, param in self.named_parameters():
1024
+ if name in loaded_keys:
1025
+ param.requires_grad = False
1026
+ else:
1027
+ param.requires_grad = True
1028
+
1029
+
1030
+ def generate_square_subsequent_mask(self, max_length):
1031
+ mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
1032
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
1033
+ return mask
1034
+
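Illustration only: for max_length = 3 the mask built above is the usual causal mask, where 0.0 means the position may be attended to and -inf blocks attention to future positions (decoder here is a hypothetical TransformerDecoder instance):

mask = decoder.generate_square_subsequent_mask(3)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])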
1035
+ def forward(self, input_dict):
1036
+ word = input_dict["word"]
1037
+ attn_emb = input_dict["attn_emb"]
1038
+ attn_emb_len = input_dict["attn_emb_len"]
1039
+ cap_padding_mask = input_dict["cap_padding_mask"]
1040
+
1041
+ p_attn_emb = self.attn_proj(attn_emb)
1042
+ p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
1043
+ word = word.to(attn_emb.device)
1044
+ embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
1045
+ embed = embed.transpose(0, 1) # [T, N, emb_dim]
1046
+ embed = self.pos_encoder(embed)
1047
+
1048
+ tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
1049
+ memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
1050
+ output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
1051
+ tgt_key_padding_mask=cap_padding_mask,
1052
+ memory_key_padding_mask=memory_key_padding_mask)
1053
+ output = output.transpose(0, 1)
1054
+ output = {
1055
+ "embed": output,
1056
+ "logit": self.classifier(output),
1057
+ }
1058
+ return output
1059
+
1060
+
1061
+ class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin):
1062
+
1063
+ def __init__(self,
1064
+ model: nn.Module,
1065
+ shared_dim: int,
1066
+ tchr_dim: int,
1067
+ ):
1068
+ super().__init__()
1069
+ self.model = model
1070
+ self.tchr_dim = tchr_dim
1071
+ if hasattr(model, "encoder"):
1072
+ self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
1073
+ shared_dim)
1074
+ else:
1075
+ self.stdnt_proj = nn.Linear(model.fc_emb_size,
1076
+ shared_dim)
1077
+ self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
1078
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
1079
+
1080
+ def forward(self, input_dict: Dict):
1081
+ unsup = input_dict.get("unsup", False)
1082
+ if unsup is False:
1083
+ output_dict = self.model(input_dict)
1084
+ else:
1085
+ output_dict = self.model.encoder(input_dict)
1086
+ if "tchr_output" in input_dict:
1087
+ stdnt_emb = output_dict["fc_emb"]
1088
+ stdnt_emb = self.stdnt_proj(stdnt_emb)
1089
+ tchr_emb = input_dict["tchr_output"]["embedding"]
1090
+ thcr_emb = self.tchr_proj(tchr_emb)
1091
+
1092
+ stdnt_emb = F.normalize(stdnt_emb, dim=-1)
1093
+ thcr_emb = F.normalize(thcr_emb, dim=-1)
1094
+
1095
+ unscaled_logit = stdnt_emb @ thcr_emb.transpose(0, 1)
1096
+ logit = self.logit_scale * unscaled_logit
1097
+ label = torch.arange(logit.shape[0]).to(logit.device)
1098
+ loss1 = F.cross_entropy(logit, label)
1099
+ loss2 = F.cross_entropy(logit.transpose(0, 1), label)
1100
+ loss = (loss1 + loss2) / 2
1101
+ output_dict["enc_kd_loss"] = loss
1102
+ return output_dict
1103
+
1104
+
1105
+ class Effb2TrmConfig(PretrainedConfig):
1106
+
1107
+ def __init__(
1108
+ self,
1109
+ sample_rate: int = 16000,
1110
+ tchr_dim: int = 768,
1111
+ shared_dim: int = 1024,
1112
+ fc_emb_dim: int = 1408,
1113
+ attn_emb_dim: int = 1408,
1114
+ decoder_n_layers: int = 2,
1115
+ decoder_we_tie_weights: bool = True,
1116
+ decoder_emb_dim: int = 256,
1117
+ decoder_dropout: float = 0.2,
1118
+ vocab_size: int = 4981,
1119
+ **kwargs
1120
+ ):
1121
+ self.sample_rate = sample_rate
1122
+ self.tchr_dim = tchr_dim
1123
+ self.shared_dim = shared_dim
1124
+ self.fc_emb_dim = fc_emb_dim
1125
+ self.attn_emb_dim = attn_emb_dim
1126
+ self.decoder_n_layers = decoder_n_layers
1127
+ self.decoder_we_tie_weights = decoder_we_tie_weights
1128
+ self.decoder_emb_dim = decoder_emb_dim
1129
+ self.decoder_dropout = decoder_dropout
1130
+ self.vocab_size = vocab_size
1131
+ super().__init__(**kwargs)
1132
+
1133
+
1134
+ class Effb2TrmCaptioningModel(PreTrainedModel):
1135
+ config_class = Effb2TrmConfig
1136
+
1137
+ def __init__(self, config):
1138
+ super().__init__(config)
1139
+ encoder = EfficientNetB2()
1140
+ decoder = TransformerDecoder(
1141
+ emb_dim=config.decoder_emb_dim,
1142
+ vocab_size=config.vocab_size,
1143
+ fc_emb_dim=config.fc_emb_dim,
1144
+ attn_emb_dim=config.attn_emb_dim,
1145
+ dropout=config.decoder_dropout,
1146
+ nlayers=config.decoder_n_layers,
1147
+ tie_weights=config.decoder_we_tie_weights
1148
+ )
1149
+ model = TransformerModel(encoder, decoder)
1150
+ self.model = ContraEncoderKdWrapper(model, config.shared_dim, config.tchr_dim)
1151
+
1152
+ def forward(self,
1153
+ audio: torch.Tensor,
1154
+ audio_length: Union[List, np.ndarray, torch.Tensor],
1155
+ sample_method: str = "beam",
1156
+ beam_size: int = 3,
1157
+ max_length: int = 20,
1158
+ temp: float = 1.0,):
1159
+ device = self.device
1160
+ input_dict = {
1161
+ "wav": audio.to(device),
1162
+ "wav_len": audio_length,
1163
+ "specaug": False,
1164
+ "mode": "inference",
1165
+ "sample_method": sample_method,
1166
+ "max_length": max_length,
1167
+ "temp": temp,
1168
+ }
1169
+ if sample_method == "beam":
1170
+ input_dict["beam_size"] = beam_size
1171
+ return self.model(input_dict)["seq"].cpu()
1172
+
1173
+
1174
+ class ConvBlock(nn.Module):
1175
+
1176
+ def __init__(self, in_channels, out_channels):
1177
+
1178
+ super(ConvBlock, self).__init__()
1179
+
1180
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
1181
+ out_channels=out_channels,
1182
+ kernel_size=(3, 3), stride=(1, 1),
1183
+ padding=(1, 1), bias=False)
1184
+
1185
+ self.conv2 = nn.Conv2d(in_channels=out_channels,
1186
+ out_channels=out_channels,
1187
+ kernel_size=(3, 3), stride=(1, 1),
1188
+ padding=(1, 1), bias=False)
1189
+
1190
+ self.bn1 = nn.BatchNorm2d(out_channels)
1191
+ self.bn2 = nn.BatchNorm2d(out_channels)
1192
+
1193
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
1194
+
1195
+ x = input
1196
+ x = F.relu_(self.bn1(self.conv1(x)))
1197
+ x = F.relu_(self.bn2(self.conv2(x)))
1198
+ if pool_type == 'max':
1199
+ x = F.max_pool2d(x, kernel_size=pool_size)
1200
+ elif pool_type == 'avg':
1201
+ x = F.avg_pool2d(x, kernel_size=pool_size)
1202
+ elif pool_type == 'avg+max':
1203
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
1204
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
1205
+ x = x1 + x2
1206
+ else:
1207
+ raise Exception('Incorrect argument!')
1208
+
1209
+ return x
1210
+
1211
+
1212
+ class Cnn14Encoder(nn.Module):
1213
+
1214
+ def __init__(self, sample_rate=32000):
1215
+ super().__init__()
1216
+ sr_to_fmax = {
1217
+ 32000: 14000,
1218
+ 16000: 8000
1219
+ }
1220
+ # Logmel spectrogram extractor
1221
+ self.melspec_extractor = transforms.MelSpectrogram(
1222
+ sample_rate=sample_rate,
1223
+ n_fft=32 * sample_rate // 1000,
1224
+ win_length=32 * sample_rate // 1000,
1225
+ hop_length=10 * sample_rate // 1000,
1226
+ f_min=50,
1227
+ f_max=sr_to_fmax[sample_rate],
1228
+ n_mels=64,
1229
+ norm="slaney",
1230
+ mel_scale="slaney"
1231
+ )
1232
+ self.hop_length = 10 * sample_rate // 1000
1233
+ self.db_transform = transforms.AmplitudeToDB()
1234
+
1235
+ self.bn0 = nn.BatchNorm2d(64)
1236
+
1237
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
1238
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
1239
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
1240
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
1241
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
1242
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
1243
+
1244
+ self.downsample_ratio = 32
1245
+
1246
+ self.fc1 = nn.Linear(2048, 2048, bias=True)
1247
+ self.fc_emb_size = 2048
1248
+
1249
+ def forward(self, input_dict):
1250
+ lms = input_dict["lms"]
1251
+ wave_length = input_dict["wav_len"]
1252
+
1253
+ x = lms # (batch_size, mel_bins, time_steps)
1254
+ x = x.transpose(1, 2)
1255
+ x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
1256
+
1257
+ x = x.transpose(1, 3)
1258
+ x = self.bn0(x)
1259
+ x = x.transpose(1, 3)
1260
+
1261
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
1262
+ x = F.dropout(x, p=0.2, training=self.training)
1263
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
1264
+ x = F.dropout(x, p=0.2, training=self.training)
1265
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
1266
+ x = F.dropout(x, p=0.2, training=self.training)
1267
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
1268
+ x = F.dropout(x, p=0.2, training=self.training)
1269
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
1270
+ x = F.dropout(x, p=0.2, training=self.training)
1271
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
1272
+ x = F.dropout(x, p=0.2, training=self.training)
1273
+ x = torch.mean(x, dim=3)
1274
+ attn_emb = x.transpose(1, 2)
1275
+
1276
+ wave_length = torch.as_tensor(wave_length)
1277
+ feat_length = torch.div(wave_length, self.hop_length,
1278
+ rounding_mode="floor") + 1
1279
+ feat_length = torch.div(feat_length, self.downsample_ratio,
1280
+ rounding_mode="floor")
1281
+ x_max = max_with_lens(attn_emb, feat_length)
1282
+ x_mean = mean_with_lens(attn_emb, feat_length)
1283
+ x = x_max + x_mean
1284
+ x = F.dropout(x, p=0.5, training=self.training)
1285
+ x = F.relu_(self.fc1(x))
1286
+ fc_emb = F.dropout(x, p=0.5, training=self.training)
1287
+
1288
+ output_dict = {
1289
+ 'fc_emb': fc_emb,
1290
+ 'attn_emb': attn_emb,
1291
+ 'attn_emb_len': feat_length
1292
+ }
1293
+
1294
+ return output_dict
1295
+
1296
+
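The attn_emb_len computed above follows from the 10 ms hop and the 32x time downsampling of the five stride-2 pooled blocks; a worked example for 10 s of 32 kHz audio:

hop_length = 10 * 32000 // 1000       # 320 samples per log-mel frame
n_frames = 320000 // hop_length + 1   # 1001 frames
attn_len = n_frames // 32             # 31 attention time steps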
1297
+ class RnnEncoder(nn.Module):
1298
+
1299
+ def __init__(self,
1300
+ attn_feat_dim,
1301
+ pooling="mean",
1302
+ **kwargs):
1303
+ super().__init__()
1304
+ self.pooling = pooling
1305
+ self.hidden_size = kwargs.get('hidden_size', 512)
1306
+ self.bidirectional = kwargs.get('bidirectional', False)
1307
+ self.num_layers = kwargs.get('num_layers', 1)
1308
+ self.dropout = kwargs.get('dropout', 0.2)
1309
+ self.rnn_type = kwargs.get('rnn_type', "GRU")
1310
+ self.in_bn = kwargs.get('in_bn', False)
1311
+ self.embed_dim = self.hidden_size * (self.bidirectional + 1)
1312
+ self.network = getattr(nn, self.rnn_type)(
1313
+ attn_feat_dim,
1314
+ self.hidden_size,
1315
+ num_layers=self.num_layers,
1316
+ bidirectional=self.bidirectional,
1317
+ dropout=self.dropout,
1318
+ batch_first=True)
1319
+ if self.in_bn:
1320
+ self.bn = nn.BatchNorm1d(self.embed_dim)
1321
+
1322
+ def forward(self, input_dict):
1323
+ x = input_dict["attn"]
1324
+ lens = input_dict["attn_len"]
1325
+ lens = torch.as_tensor(lens)
1326
+ # x: [N, T, E]
1327
+ if self.in_bn:
1328
+ x = pack_wrapper(self.bn, x, lens)
1329
+ out = pack_wrapper(self.network, x, lens)
1330
+ # out: [N, T, hidden]
1331
+ attn_emb = out
1332
+ fc_emb = embedding_pooling(out, lens, self.pooling)
1333
+ return {
1334
+ "attn_emb": attn_emb,
1335
+ "fc_emb": fc_emb,
1336
+ "attn_emb_len": lens
1337
+ }
1338
+
1339
+
1340
+ class Cnn14RnnEncoder(nn.Module):
1341
+
1342
+ def __init__(self,
1343
+ sample_rate,
1344
+ rnn_bidirectional,
1345
+ rnn_hidden_size,
1346
+ rnn_dropout,
1347
+ rnn_num_layers):
1348
+ super().__init__()
1349
+ self.cnn = Cnn14Encoder(sample_rate=sample_rate)
1350
+ self.rnn = RnnEncoder(
1351
+ 2048,
1352
+ bidirectional=rnn_bidirectional,
1353
+ hidden_size=rnn_hidden_size,
1354
+ dropout=rnn_dropout,
1355
+ num_layers=rnn_num_layers,
1356
+ )
1357
+
1358
+ def forward(self, input_dict):
1359
+ output_dict = self.cnn(input_dict)
1360
+ output_dict["attn"] = output_dict["attn_emb"]
1361
+ output_dict["attn_len"] = output_dict["attn_emb_len"]
1362
+ del output_dict["attn_emb"], output_dict["attn_emb_len"]
1363
+ output_dict = self.rnn(output_dict)
1364
+ return output_dict
1365
+
1366
+
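A dimensionality sketch for the combined encoder: with a bidirectional GRU of hidden size 256, each output frame embedding has size 512, which matches attn_emb_dim in the config further below.

enc = Cnn14RnnEncoder(sample_rate=32000, rnn_bidirectional=True,
                      rnn_hidden_size=256, rnn_dropout=0.5, rnn_num_layers=3)
print(enc.rnn.embed_dim)   # 512 = 256 * 2 directions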
1367
+ class Seq2SeqAttention(nn.Module):
1368
+
1369
+ def __init__(self, hs_enc, hs_dec, attn_size):
1370
+ """
1371
+ Args:
1372
+ hs_enc: encoder hidden size
1373
+ hs_dec: decoder hidden size
1374
+ attn_size: attention vector size
1375
+ """
1376
+ super(Seq2SeqAttention, self).__init__()
1377
+ self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size)
1378
+ self.v = nn.Parameter(torch.randn(attn_size))
1379
+
1380
+ def forward(self, h_dec, h_enc, src_lens):
1381
+ """
1382
+ Args:
1383
+ h_dec: decoder hidden (query), [N, hs_dec]
1384
+ h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
1385
+ src_lens: source (encoder memory) lengths, [N, ]
1386
+ """
1387
+ N = h_enc.size(0)
1388
+ src_max_len = h_enc.size(1)
1389
+ h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec]
1390
+
1391
+ attn_input = torch.cat((h_dec, h_enc), dim=-1)
1392
+ attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size]
1393
+
1394
+ v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size]
1395
+ score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len]
1396
+
1397
+ idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len)
1398
+ mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device)
1399
+
1400
+ score = score.masked_fill(mask == 0, -1e10)
1401
+ weights = torch.softmax(score, dim=-1) # [N, src_max_len]
1402
+ ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc]
1403
+
1404
+ return ctx, weights
1405
+
1406
+
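An illustrative shape check for the additive attention above: a query of size hs_dec attends over a padded memory whose per-sample lengths are given by src_lens.

import torch

attn = Seq2SeqAttention(hs_enc=512, hs_dec=512, attn_size=512)
h_dec = torch.randn(4, 512)               # decoder query
h_enc = torch.randn(4, 7, 512)            # encoder memory, padded to length 7
src_lens = torch.tensor([7, 5, 3, 7])
ctx, weights = attn(h_dec, h_enc, src_lens)
print(ctx.shape, weights.shape)           # torch.Size([4, 512]) torch.Size([4, 7])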
1407
+ class RnnDecoder(BaseDecoder):
1408
+
1409
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
1410
+ dropout, d_model, **kwargs):
1411
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
1412
+ dropout,)
1413
+ self.d_model = d_model
1414
+ self.num_layers = kwargs.get('num_layers', 1)
1415
+ self.bidirectional = kwargs.get('bidirectional', False)
1416
+ self.rnn_type = kwargs.get('rnn_type', "GRU")
1417
+ self.classifier = nn.Linear(
1418
+ self.d_model * (self.bidirectional + 1), vocab_size)
1419
+
1420
+ def forward(self, x):
1421
+ raise NotImplementedError
1422
+
1423
+ def init_hidden(self, bs, device):
1424
+ num_dire = self.bidirectional + 1
1425
+ n_layer = self.num_layers
1426
+ hid_dim = self.d_model
1427
+ if self.rnn_type == "LSTM":
1428
+ return (torch.zeros(num_dire * n_layer, bs, hid_dim).to(device),
1429
+ torch.zeros(num_dire * n_layer, bs, hid_dim).to(device))
1430
+ else:
1431
+ return torch.zeros(num_dire * n_layer, bs, hid_dim).to(device)
1432
+
1433
+
1434
+ class BahAttnCatFcDecoder(RnnDecoder):
1435
+
1436
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
1437
+ dropout, d_model, **kwargs):
1438
+ """
1439
+         Concatenate the fc embedding, attention context, and word embedding as the RNN input.
1440
+ """
1441
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
1442
+ dropout, d_model, **kwargs)
1443
+ attn_size = kwargs.get("attn_size", self.d_model)
1444
+ self.model = getattr(nn, self.rnn_type)(
1445
+ input_size=self.emb_dim * 3,
1446
+ hidden_size=self.d_model,
1447
+ batch_first=True,
1448
+ num_layers=self.num_layers,
1449
+ bidirectional=self.bidirectional)
1450
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
1451
+ self.d_model * (self.bidirectional + 1) * \
1452
+ self.num_layers,
1453
+ attn_size)
1454
+ self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
1455
+ self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
1456
+
1457
+ def forward(self, input_dict):
1458
+ word = input_dict["word"]
1459
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
1460
+ fc_emb = input_dict["fc_emb"]
1461
+ attn_emb = input_dict["attn_emb"]
1462
+ attn_emb_len = input_dict["attn_emb_len"]
1463
+
1464
+ word = word.to(fc_emb.device)
1465
+ embed = self.in_dropout(self.word_embedding(word))
1466
+
1467
+ # embed: [N, 1, embed_size]
1468
+ if state is None:
1469
+ state = self.init_hidden(word.size(0), fc_emb.device)
1470
+ if self.rnn_type == "LSTM":
1471
+ query = state[0].transpose(0, 1).flatten(1)
1472
+ else:
1473
+ query = state.transpose(0, 1).flatten(1)
1474
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
1475
+
1476
+ p_fc_emb = self.fc_proj(fc_emb)
1477
+ p_ctx = self.ctx_proj(c)
1478
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_emb.unsqueeze(1)),
1479
+ dim=-1)
1480
+
1481
+ out, state = self.model(rnn_input, state)
1482
+
1483
+ output = {
1484
+ "state": state,
1485
+ "embed": out,
1486
+ "logit": self.classifier(out),
1487
+ "attn_weight": attn_weight
1488
+ }
1489
+ return output
1490
+
1491
+
1492
+ class TemporalBahAttnDecoder(BahAttnCatFcDecoder):
1493
+
1494
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
1495
+ dropout, d_model, **kwargs):
1496
+ """
1497
+         Same as BahAttnCatFcDecoder, but the first decoding step is fed a temporal tag embedding.
1498
+ """
1499
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
1500
+ dropout, d_model, **kwargs)
1501
+ self.temporal_embedding = nn.Embedding(4, emb_dim)
1502
+
1503
+ def forward(self, input_dict):
1504
+ word = input_dict["word"]
1505
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
1506
+ fc_embs = input_dict["fc_emb"]
1507
+ attn_embs = input_dict["attn_emb"]
1508
+ attn_emb_lens = input_dict["attn_emb_len"]
1509
+ temporal_tag = input_dict["temporal_tag"]
1510
+
1511
+ if input_dict["t"] == 0:
1512
+ embed = self.in_dropout(
1513
+ self.temporal_embedding(temporal_tag)).unsqueeze(1)
1514
+ elif word.size(-1) == self.fc_emb_dim: # fc_embs
1515
+ embed = word.unsqueeze(1)
1516
+ elif word.size(-1) == 1: # word
1517
+ word = word.to(fc_embs.device)
1518
+ embed = self.in_dropout(self.word_embedding(word))
1519
+ else:
1520
+             raise ValueError(f"Unexpected word input size: {word.size()}")
1521
+
1522
+ # embed: [N, 1, embed_size]
1523
+ if state is None:
1524
+ state = self.init_hidden(word.size(0), fc_embs.device)
1525
+ if self.rnn_type == "LSTM":
1526
+ query = state[0].transpose(0, 1).flatten(1)
1527
+ else:
1528
+ query = state.transpose(0, 1).flatten(1)
1529
+ c, attn_weight = self.attn(query, attn_embs, attn_emb_lens)
1530
+
1531
+ p_ctx = self.ctx_proj(c)
1532
+ p_fc_embs = self.fc_proj(fc_embs)
1534
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_embs.unsqueeze(1)), dim=-1)
1535
+
1536
+ out, state = self.model(rnn_input, state)
1537
+
1538
+ output = {
1539
+ "state": state,
1540
+ "embed": out,
1541
+ "logit": self.classifier(out),
1542
+ "attn_weight": attn_weight
1543
+ }
1544
+ return output
1545
+
1546
+
1547
+ class Seq2SeqAttnModel(CaptionModel):
1548
+
1549
+ def __init__(self, encoder, decoder, **kwargs):
1550
+ if not hasattr(self, "compatible_decoders"):
1551
+ self.compatible_decoders = (
1552
+ BahAttnCatFcDecoder,
1553
+ )
1554
+ super().__init__(encoder, decoder, **kwargs)
1555
+
1556
+
1557
+ def seq_forward(self, input_dict):
1558
+         # Bahdanau attention only supports step-by-step decoding, so forward runs
1559
+         # step by step in both training and evaluation
1560
+ return self.stepwise_forward(input_dict)
1561
+
1562
+ def prepare_output(self, input_dict):
1563
+ output = super().prepare_output(input_dict)
1564
+ attn_weight = torch.empty(output["seq"].size(0),
1565
+ input_dict["attn_emb"].size(1), output["seq"].size(1))
1566
+ output["attn_weight"] = attn_weight
1567
+ return output
1568
+
1569
+ def prepare_decoder_input(self, input_dict, output):
1570
+ decoder_input = {
1571
+ "fc_emb": input_dict["fc_emb"],
1572
+ "attn_emb": input_dict["attn_emb"],
1573
+ "attn_emb_len": input_dict["attn_emb_len"]
1574
+ }
1575
+ t = input_dict["t"]
1576
+ ###############
1577
+ # determine input word
1578
+ ################
1579
+ if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
1580
+ word = input_dict["cap"][:, t]
1581
+ else:
1582
+ if t == 0:
1583
+ word = torch.tensor([self.start_idx,] * input_dict["fc_emb"].size(0)).long()
1584
+ else:
1585
+ word = output["seq"][:, t-1]
1586
+ # word: [N,]
1587
+ decoder_input["word"] = word.unsqueeze(1)
1588
+
1589
+ ################
1590
+ # prepare rnn state
1591
+ ################
1592
+ if t > 0:
1593
+ decoder_input["state"] = output["state"]
1594
+ return decoder_input
1595
+
1596
+ def stepwise_process_step(self, output, output_t):
1597
+ super().stepwise_process_step(output, output_t)
1598
+ output["state"] = output_t["state"]
1599
+ t = output_t["t"]
1600
+ output["attn_weight"][:, :, t] = output_t["attn_weight"]
1601
+
1602
+ def prepare_beamsearch_output(self, input_dict):
1603
+ output = super().prepare_beamsearch_output(input_dict)
1604
+ beam_size = input_dict["beam_size"]
1605
+ max_length = input_dict["max_length"]
1606
+ output["attn_weight"] = torch.empty(beam_size,
1607
+ max(input_dict["attn_emb_len"]), max_length)
1608
+ return output
1609
+
1610
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
1611
+ decoder_input = {}
1612
+ t = input_dict["t"]
1613
+ i = input_dict["sample_idx"]
1614
+ beam_size = input_dict["beam_size"]
1615
+ ###############
1616
+ # prepare fc embeds
1617
+ ################
1618
+ if t == 0:
1619
+ fc_emb = repeat_tensor(input_dict["fc_emb"][i], beam_size)
1620
+ output_i["fc_emb"] = fc_emb
1621
+ decoder_input["fc_emb"] = output_i["fc_emb"]
1622
+
1623
+ ###############
1624
+ # prepare attn embeds
1625
+ ################
1626
+ if t == 0:
1627
+ attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
1628
+ attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
1629
+ output_i["attn_emb"] = attn_emb
1630
+ output_i["attn_emb_len"] = attn_emb_len
1631
+ decoder_input["attn_emb"] = output_i["attn_emb"]
1632
+ decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
1633
+
1634
+ ###############
1635
+ # determine input word
1636
+ ################
1637
+ if t == 0:
1638
+ word = torch.tensor([self.start_idx,] * beam_size).long()
1639
+ else:
1640
+ word = output_i["next_word"]
1641
+ decoder_input["word"] = word.unsqueeze(1)
1642
+
1643
+ ################
1644
+ # prepare rnn state
1645
+ ################
1646
+ if t > 0:
1647
+ if self.decoder.rnn_type == "LSTM":
1648
+ decoder_input["state"] = (output_i["state"][0][:, output_i["prev_words_beam"], :].contiguous(),
1649
+ output_i["state"][1][:, output_i["prev_words_beam"], :].contiguous())
1650
+ else:
1651
+ decoder_input["state"] = output_i["state"][:, output_i["prev_words_beam"], :].contiguous()
1652
+
1653
+ return decoder_input
1654
+
1655
+ def beamsearch_process_step(self, output_i, output_t):
1656
+ t = output_t["t"]
1657
+ output_i["state"] = output_t["state"]
1658
+ output_i["attn_weight"][..., t] = output_t["attn_weight"]
1659
+ output_i["attn_weight"] = output_i["attn_weight"][output_i["prev_words_beam"], ...]
1660
+
1661
+ def beamsearch_process(self, output, output_i, input_dict):
1662
+ super().beamsearch_process(output, output_i, input_dict)
1663
+ i = input_dict["sample_idx"]
1664
+ output["attn_weight"][i] = output_i["attn_weight"][0]
1665
+
1666
+ def prepare_dbs_decoder_input(self, input_dict, output_i):
1667
+ decoder_input = {}
1668
+ t = input_dict["t"]
1669
+ i = input_dict["sample_idx"]
1670
+ bdash = input_dict["bdash"]
1671
+ divm = input_dict["divm"]
1672
+
1673
+ local_time = t - divm
1674
+ ###############
1675
+ # prepare fc embeds
1676
+ ################
1677
+         # repeat only at the first timestep to save memory
1678
+ if t == 0:
1679
+ fc_emb = repeat_tensor(input_dict["fc_emb"][i], bdash).unsqueeze(1)
1680
+ output_i["fc_emb"] = fc_emb
1681
+ decoder_input["fc_emb"] = output_i["fc_emb"]
1682
+
1683
+ ###############
1684
+ # prepare attn embeds
1685
+ ################
1686
+ if t == 0:
1687
+ attn_emb = repeat_tensor(input_dict["attn_emb"][i], bdash)
1688
+ attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], bdash)
1689
+ output_i["attn_emb"] = attn_emb
1690
+ output_i["attn_emb_len"] = attn_emb_len
1691
+ decoder_input["attn_emb"] = output_i["attn_emb"]
1692
+ decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
1693
+
1694
+ ###############
1695
+ # determine input word
1696
+ ################
1697
+ if local_time == 0:
1698
+ word = torch.tensor([self.start_idx,] * bdash).long()
1699
+ else:
1700
+ word = output_i["next_word"][divm]
1701
+ decoder_input["word"] = word.unsqueeze(1)
1702
+
1703
+ ################
1704
+ # prepare rnn state
1705
+ ################
1706
+ if local_time > 0:
1707
+ if self.decoder.rnn_type == "LSTM":
1708
+ decoder_input["state"] = (
1709
+ output_i["state"][0][divm][
1710
+ :, output_i["prev_words_beam"][divm], :].contiguous(),
1711
+ output_i["state"][1][divm][
1712
+ :, output_i["prev_words_beam"][divm], :].contiguous()
1713
+ )
1714
+ else:
1715
+ decoder_input["state"] = output_i["state"][divm][
1716
+ :, output_i["prev_words_beam"][divm], :].contiguous()
1717
+
1718
+ return decoder_input
1719
+
1720
+ def dbs_process_step(self, output_i, output_t):
1721
+ divm = output_t["divm"]
1722
+ output_i["state"][divm] = output_t["state"]
1723
+ # TODO attention weight
1724
+
1725
+
1726
+ class TemporalSeq2SeqAttnModel(Seq2SeqAttnModel):
1727
+
1728
+ def __init__(self, encoder, decoder, **kwargs):
1729
+ if not hasattr(self, "compatible_decoders"):
1730
+ self.compatible_decoders = (
1731
+ TemporalBahAttnDecoder,
1732
+ )
1733
+ super().__init__(encoder, decoder, **kwargs)
1734
+ self.train_forward_keys = ["cap", "cap_len", "ss_ratio", "temporal_tag"]
1735
+ self.inference_forward_keys = ["sample_method", "max_length", "temp", "temporal_tag"]
1736
+
1737
+
1738
+ def prepare_decoder_input(self, input_dict, output):
1739
+ decoder_input = super().prepare_decoder_input(input_dict, output)
1740
+ decoder_input["temporal_tag"] = input_dict["temporal_tag"]
1741
+ decoder_input["t"] = input_dict["t"]
1742
+
1743
+ return decoder_input
1744
+
1745
+
1746
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
1747
+ decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
1748
+ t = input_dict["t"]
1749
+ i = input_dict["sample_idx"]
1750
+ beam_size = input_dict["beam_size"]
1751
+ ###############
1752
+ # prepare temporal_tag
1753
+ ################
1754
+ if t == 0:
1755
+ temporal_tag = repeat_tensor(input_dict["temporal_tag"][i], beam_size)
1756
+ output_i["temporal_tag"] = temporal_tag
1757
+ decoder_input["temporal_tag"] = output_i["temporal_tag"]
1758
+ decoder_input["t"] = input_dict["t"]
1759
+
1760
+ return decoder_input
1761
+
1762
+ def prepare_dbs_decoder_input(self, input_dict, output_i):
1763
+         decoder_input = super().prepare_dbs_decoder_input(input_dict, output_i)
1764
+ t = input_dict["t"]
1765
+ i = input_dict["sample_idx"]
1766
+ bdash = input_dict["bdash"]
1767
+
1768
+ ###############
1769
+ # prepare temporal tag
1770
+ ################
1771
+         # repeat only at the first timestep to save memory
1772
+ if t == 0:
1773
+ temporal_tag = repeat_tensor(input_dict["temporal_tag"][i], bdash)
1774
+ output_i["temporal_tag"] = temporal_tag
1775
+ decoder_input["temporal_tag"] = output_i["temporal_tag"]
1776
+ decoder_input["t"] = input_dict["t"]
1777
+
1778
+ return decoder_input
1779
+
1780
+
1781
+ class Cnn8rnnSedModel(nn.Module):
1782
+ def __init__(self, classes_num):
1783
+
1784
+ super().__init__()
1785
+
1786
+ self.time_resolution = 0.01
1787
+         self.interpolate_ratio = 4  # Downsampling ratio of the segment-wise output
1788
+
1789
+ self.bn0 = nn.BatchNorm2d(64)
1790
+
1791
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
1792
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
1793
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
1794
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
1795
+
1796
+ self.fc1 = nn.Linear(512, 512, bias=True)
1797
+ self.rnn = nn.GRU(512, 256, bidirectional=True, batch_first=True)
1798
+ self.fc_audioset = nn.Linear(512, classes_num, bias=True)
1799
+
1800
+ def forward(self, lms):
1801
+ output = self.forward_prob(lms)
1802
+ framewise_output = output["framewise_output"].cpu().numpy()
1803
+ thresholded_predictions = double_threshold(
1804
+ framewise_output, 0.75, 0.25)
1805
+ decoded_tags = decode_with_timestamps(
1806
+ thresholded_predictions, self.time_resolution
1807
+ )
1808
+ return decoded_tags
1809
+
1810
+ def forward_prob(self, lms):
1811
+ """
1812
+ lms: (batch_size, mel_bins, time_steps)"""
1813
+
1814
+ x = lms
1815
+ x = x.transpose(1, 2)
1816
+ x = x.unsqueeze(1)
1817
+
1818
+ frames_num = x.shape[2]
1819
+
1820
+ x = x.transpose(1, 3)
1821
+ x = self.bn0(x)
1822
+ x = x.transpose(1, 3)
1823
+
1824
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg+max')
1825
+ x = F.dropout(x, p=0.2, training=self.training)
1826
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg+max')
1827
+ x = F.dropout(x, p=0.2, training=self.training)
1828
+ x = self.conv_block3(x, pool_size=(1, 2), pool_type='avg+max')
1829
+ x = F.dropout(x, p=0.2, training=self.training)
1830
+ x = self.conv_block4(x, pool_size=(1, 2), pool_type='avg+max')
1831
+         x = F.dropout(x, p=0.2, training=self.training)  # (batch_size, 512, time_steps / 4, mel_bins / 16)
1832
+ x = torch.mean(x, dim=3)
1833
+
1834
+ x = x.transpose(1, 2)
1835
+ x = F.dropout(x, p=0.5, training=self.training)
1836
+ x = F.relu_(self.fc1(x))
1837
+ x, _ = self.rnn(x)
1838
+ segmentwise_output = torch.sigmoid(self.fc_audioset(x)).clamp(1e-7, 1.)
1839
+
1840
+ framewise_output = interpolate(segmentwise_output,
1841
+ self.interpolate_ratio)
1842
+ framewise_output = pad_framewise_output(framewise_output, frames_num)
1843
+
1844
+ output_dict = {
1845
+ "segmentwise_output": segmentwise_output,
1846
+ 'framewise_output': framewise_output,
1847
+ }
1848
+
1849
+ return output_dict
1850
+
1851
+
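An illustrative call into the probability path of the SED model; interpolate, pad_framewise_output, double_threshold, and decode_with_timestamps are helpers defined earlier in this file, and forward_prob returns per-frame event probabilities before thresholding.

import torch

sed = Cnn8rnnSedModel(classes_num=447).eval()
lms = torch.randn(1, 64, 1001)            # (batch, mel_bins, time_steps), ~10 s at a 10 ms hop
with torch.no_grad():
    probs = sed.forward_prob(lms)["framewise_output"]   # (batch, time_steps, 447) event probabilities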
1852
+ class Cnn14RnnTempAttnGruConfig(PretrainedConfig):
1853
+
1854
+ def __init__(
1855
+ self,
1856
+ sample_rate: int = 32000,
1857
+ encoder_rnn_bidirectional: bool = True,
1858
+ encoder_rnn_hidden_size: int = 256,
1859
+ encoder_rnn_dropout: float = 0.5,
1860
+ encoder_rnn_num_layers: int = 3,
1861
+ decoder_emb_dim: int = 512,
1862
+ vocab_size: int = 4981,
1863
+ fc_emb_dim: int = 512,
1864
+ attn_emb_dim: int = 512,
1865
+ decoder_rnn_type: str = "GRU",
1866
+ decoder_num_layers: int = 1,
1867
+ decoder_d_model: int = 512,
1868
+ decoder_dropout: float = 0.5,
1869
+ **kwargs
1870
+ ):
1871
+ self.sample_rate = sample_rate
1872
+ self.encoder_rnn_bidirectional = encoder_rnn_bidirectional
1873
+ self.encoder_rnn_hidden_size = encoder_rnn_hidden_size
1874
+ self.encoder_rnn_dropout = encoder_rnn_dropout
1875
+ self.encoder_rnn_num_layers = encoder_rnn_num_layers
1876
+ self.decoder_emb_dim = decoder_emb_dim
1877
+ self.vocab_size = vocab_size
1878
+ self.fc_emb_dim = fc_emb_dim
1879
+ self.attn_emb_dim = attn_emb_dim
1880
+ self.decoder_rnn_type = decoder_rnn_type
1881
+ self.decoder_num_layers = decoder_num_layers
1882
+ self.decoder_d_model = decoder_d_model
1883
+ self.decoder_dropout = decoder_dropout
1884
+ super().__init__(**kwargs)
1885
+
1886
+
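A minimal sketch of instantiating the config together with the wrapper model defined below; the values shown are the defaults declared above, and the base classes (BaseDecoder, CaptionModel) come from earlier in this file.

config = Cnn14RnnTempAttnGruConfig(sample_rate=32000, vocab_size=4981,
                                   encoder_rnn_hidden_size=256, decoder_d_model=512)
model = Cnn14RnnTempAttnGruModel(config).eval()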
1887
+ class Cnn14RnnTempAttnGruModel(PreTrainedModel):
1888
+ config_class = Cnn14RnnTempAttnGruConfig
1889
+
1890
+ def __init__(self, config):
1891
+ super().__init__(config)
1892
+ sample_rate = config.sample_rate
1893
+ sr_to_fmax = {
1894
+ 32000: 14000,
1895
+ 16000: 8000
1896
+ }
1897
+ self.melspec_extractor = transforms.MelSpectrogram(
1898
+ sample_rate=sample_rate,
1899
+ n_fft=32 * sample_rate // 1000,
1900
+ win_length=32 * sample_rate // 1000,
1901
+ hop_length=10 * sample_rate // 1000,
1902
+ f_min=50,
1903
+ f_max=sr_to_fmax[sample_rate],
1904
+ n_mels=64,
1905
+ norm="slaney",
1906
+ mel_scale="slaney"
1907
+ )
1908
+ self.db_transform = transforms.AmplitudeToDB()
1909
+
1910
+ encoder = Cnn14RnnEncoder(
1911
+ sample_rate=config.sample_rate,
1912
+ rnn_bidirectional=config.encoder_rnn_bidirectional,
1913
+ rnn_hidden_size=config.encoder_rnn_hidden_size,
1914
+ rnn_dropout=config.encoder_rnn_dropout,
1915
+ rnn_num_layers=config.encoder_rnn_num_layers
1916
+ )
1917
+ decoder = TemporalBahAttnDecoder(
1918
+ emb_dim=config.decoder_emb_dim,
1919
+ vocab_size=config.vocab_size,
1920
+ fc_emb_dim=config.fc_emb_dim,
1921
+ attn_emb_dim=config.attn_emb_dim,
1922
+ rnn_type=config.decoder_rnn_type,
1923
+ num_layers=config.decoder_num_layers,
1924
+ d_model=config.decoder_d_model,
1925
+ dropout=config.decoder_dropout,
1926
+ )
1927
+ cap_model = TemporalSeq2SeqAttnModel(encoder, decoder)
1928
+ sed_model = Cnn8rnnSedModel(classes_num=447)
1929
+ self.cap_model = cap_model
1930
+ self.sed_model = sed_model
1931
+
1932
+ def forward(self,
1933
+ audio: torch.Tensor,
1934
+ audio_length: Union[List, np.ndarray, torch.Tensor],
1935
+ temporal_tag: Union[List, np.ndarray, torch.Tensor] = None,
1936
+ sample_method: str = "beam",
1937
+ beam_size: int = 3,
1938
+ max_length: int = 20,
1939
+ temp: float = 1.0,):
1940
+ device = self.device
1941
+ mel_spec = self.melspec_extractor(audio.to(device))
1942
+ log_mel_spec = self.db_transform(mel_spec)
1943
+
1944
+ sed_tag = self.sed_model(log_mel_spec)
1945
+ sed_tag = torch.as_tensor(sed_tag).to(device)
1946
+ if temporal_tag is not None:
1947
+ temporal_tag = torch.as_tensor(temporal_tag).to(device)
1948
+ temporal_tag = torch.stack([temporal_tag, sed_tag], dim=0)
1949
+ temporal_tag = torch.min(temporal_tag, dim=0).values
1950
+ else:
1951
+ temporal_tag = sed_tag
1952
+
1953
+ input_dict = {
1954
+ "lms": log_mel_spec,
1955
+ "wav_len": audio_length,
1956
+ "temporal_tag": temporal_tag,
1957
+ "mode": "inference",
1958
+ "sample_method": sample_method,
1959
+ "max_length": max_length,
1960
+ "temp": temp,
1961
+ }
1962
+ if sample_method == "beam":
1963
+ input_dict["beam_size"] = beam_size
1964
+ return self.cap_model(input_dict)["seq"].cpu()
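An end-to-end inference sketch for the model above: the SED branch derives a temporal tag automatically, and an explicitly passed temporal_tag is combined with it via an element-wise minimum. Weights here are random and the helpers come from earlier in the file, so the output is only illustrative.

import torch

model = Cnn14RnnTempAttnGruModel(Cnn14RnnTempAttnGruConfig()).eval()
wav = torch.randn(1, 32000 * 10)          # 10 s of audio at the 32 kHz default sample rate
with torch.no_grad():
    seq = model(audio=wav,
                audio_length=[wav.shape[1]],
                sample_method="beam", beam_size=3, max_length=20)
# seq holds token indices; decode them with a tokenizer matching vocab_size=4981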
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d5d51984ac220288b130d04f53652a2aec21e7f2cd275c0a48ed1648f6ace16
3
+ size 55324025