wsntxxn commited on
Commit
9372475
·
verified ·
1 Parent(s): a6295f4

Upload model

Browse files
Files changed (3) hide show
  1. config.json +5 -0
  2. hf_wrapper.py +1974 -0
  3. pytorch_model.bin +2 -2
config.json CHANGED
@@ -1,8 +1,13 @@
1
  {
 
2
  "architectures": [
3
  "Effb2TrmCaptioningModel"
4
  ],
5
  "attn_emb_dim": 1408,
 
 
 
 
6
  "decoder_dropout": 0.2,
7
  "decoder_emb_dim": 256,
8
  "decoder_n_layers": 2,
 
1
  {
2
+ "_name_or_path": "/mnt/cloudstorfs/sjtu_home/xuenan.xu/hf_cache/hub/models--wsntxxn--effb2-trm-clotho-captioning/snapshots/a6295f4d7e0f2314bd3fc02b03512814c63fc132/",
3
  "architectures": [
4
  "Effb2TrmCaptioningModel"
5
  ],
6
  "attn_emb_dim": 1408,
7
+ "auto_map": {
8
+ "AutoConfig": "hf_wrapper.Effb2TrmConfig",
9
+ "AutoModel": "hf_wrapper.Effb2TrmCaptioningModel"
10
+ },
11
  "decoder_dropout": 0.2,
12
  "decoder_emb_dim": 256,
13
  "decoder_n_layers": 2,
hf_wrapper.py ADDED
@@ -0,0 +1,1974 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Callable, Union, List
2
+ import random
3
+ import math
4
+ import sys
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
11
+ from torchaudio import transforms
12
+ from efficientnet_pytorch import EfficientNet
13
+ from efficientnet_pytorch import utils as efficientnet_utils
14
+ from einops import rearrange, reduce
15
+ from transformers import PretrainedConfig, PreTrainedModel
16
+
17
+
18
+ def sort_pack_padded_sequence(input, lengths):
19
+ sorted_lengths, indices = torch.sort(lengths, descending=True)
20
+ tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
21
+ inv_ix = indices.clone()
22
+ inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
23
+ return tmp, inv_ix
24
+
25
+ def pad_unsort_packed_sequence(input, inv_ix):
26
+ tmp, _ = pad_packed_sequence(input, batch_first=True)
27
+ tmp = tmp[inv_ix]
28
+ return tmp
29
+
30
+ def pack_wrapper(module, attn_feats, attn_feat_lens):
31
+ packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
32
+ if isinstance(module, torch.nn.RNNBase):
33
+ return pad_unsort_packed_sequence(module(packed)[0], inv_ix)
34
+ else:
35
+ return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
36
+
37
+ def embedding_pooling(x, lens, pooling="mean"):
38
+ if pooling == "max":
39
+ fc_embs = max_with_lens(x, lens)
40
+ elif pooling == "mean":
41
+ fc_embs = mean_with_lens(x, lens)
42
+ elif pooling == "mean+max":
43
+ x_mean = mean_with_lens(x, lens)
44
+ x_max = max_with_lens(x, lens)
45
+ fc_embs = x_mean + x_max
46
+ elif pooling == "last":
47
+ indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
48
+ # indices: [N, 1, hidden]
49
+ fc_embs = torch.gather(x, 1, indices).squeeze(1)
50
+ else:
51
+ raise Exception(f"pooling method {pooling} not support")
52
+ return fc_embs
53
+
54
+ def interpolate(x, ratio):
55
+ """Interpolate data in time domain. This is used to compensate the
56
+ resolution reduction in downsampling of a CNN.
57
+
58
+ Args:
59
+ x: (batch_size, time_steps, classes_num)
60
+ ratio: int, ratio to interpolate
61
+
62
+ Returns:
63
+ upsampled: (batch_size, time_steps * ratio, classes_num)
64
+ """
65
+ (batch_size, time_steps, classes_num) = x.shape
66
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
67
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
68
+ return upsampled
69
+
70
+ def pad_framewise_output(framewise_output, frames_num):
71
+ """Pad framewise_output to the same length as input frames. The pad value
72
+ is the same as the value of the last frame.
73
+
74
+ Args:
75
+ framewise_output: (batch_size, frames_num, classes_num)
76
+ frames_num: int, number of frames to pad
77
+
78
+ Outputs:
79
+ output: (batch_size, frames_num, classes_num)
80
+ """
81
+ pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1)
82
+ """tensor for padding"""
83
+
84
+ output = torch.cat((framewise_output, pad), dim=1)
85
+ """(batch_size, frames_num, classes_num)"""
86
+
87
+ return output
88
+
89
+ def find_contiguous_regions(activity_array):
90
+ """Find contiguous regions from bool valued numpy.array.
91
+ Copy of https://dcase-repo.github.io/dcase_util/_modules/dcase_util/data/decisions.html#DecisionEncoder
92
+
93
+ Reason is:
94
+ 1. This does not belong to a class necessarily
95
+ 2. Import DecisionEncoder requires sndfile over some other imports..which causes some problems on clusters
96
+
97
+ """
98
+
99
+ # Find the changes in the activity_array
100
+ change_indices = np.logical_xor(activity_array[1:],
101
+ activity_array[:-1]).nonzero()[0]
102
+
103
+ # Shift change_index with one, focus on frame after the change.
104
+ change_indices += 1
105
+
106
+ if activity_array[0]:
107
+ # If the first element of activity_array is True add 0 at the beginning
108
+ change_indices = np.r_[0, change_indices]
109
+
110
+ if activity_array[-1]:
111
+ # If the last element of activity_array is True, add the length of the array
112
+ change_indices = np.r_[change_indices, activity_array.size]
113
+
114
+ # Reshape the result into two columns
115
+ return change_indices.reshape((-1, 2))
116
+
117
+ def double_threshold(x, high_thres, low_thres, n_connect=1):
118
+ """double_threshold
119
+ Helper function to calculate double threshold for n-dim arrays
120
+
121
+ :param x: input array
122
+ :param high_thres: high threshold value
123
+ :param low_thres: Low threshold value
124
+ :param n_connect: Distance of <= n clusters will be merged
125
+ """
126
+ assert x.ndim <= 3, "Whoops something went wrong with the input ({}), check if its <= 3 dims".format(
127
+ x.shape)
128
+ if x.ndim == 3:
129
+ apply_dim = 1
130
+ elif x.ndim < 3:
131
+ apply_dim = 0
132
+ # x is assumed to be 3d: (batch, time, dim)
133
+ # Assumed to be 2d : (time, dim)
134
+ # Assumed to be 1d : (time)
135
+ # time axis is therefore at 1 for 3d and 0 for 2d (
136
+ return np.apply_along_axis(lambda x: _double_threshold(
137
+ x, high_thres, low_thres, n_connect=n_connect),
138
+ axis=apply_dim,
139
+ arr=x)
140
+
141
+ def _double_threshold(x, high_thres, low_thres, n_connect=1, return_arr=True):
142
+ """_double_threshold
143
+ Computes a double threshold over the input array
144
+
145
+ :param x: input array, needs to be 1d
146
+ :param high_thres: High threshold over the array
147
+ :param low_thres: Low threshold over the array
148
+ :param n_connect: Postprocessing, maximal distance between clusters to connect
149
+ :param return_arr: By default this function returns the filtered indiced, but if return_arr = True it returns an array of tsame size as x filled with ones and zeros.
150
+ """
151
+ assert x.ndim == 1, "Input needs to be 1d"
152
+ high_locations = np.where(x > high_thres)[0]
153
+ locations = x > low_thres
154
+ encoded_pairs = find_contiguous_regions(locations)
155
+
156
+ filtered_list = list(
157
+ filter(
158
+ lambda pair:
159
+ ((pair[0] <= high_locations) & (high_locations <= pair[1])).any(),
160
+ encoded_pairs))
161
+
162
+ filtered_list = connect_(filtered_list, n_connect)
163
+ if return_arr:
164
+ zero_one_arr = np.zeros_like(x, dtype=int)
165
+ for sl in filtered_list:
166
+ zero_one_arr[sl[0]:sl[1]] = 1
167
+ return zero_one_arr
168
+ return filtered_list
169
+
170
+ def connect_(pairs, n=1):
171
+ """connect_
172
+ Connects two adjacent clusters if their distance is <= n
173
+
174
+ :param pairs: Clusters of iterateables e.g., [(1,5),(7,10)]
175
+ :param n: distance between two clusters
176
+ """
177
+ if len(pairs) == 0:
178
+ return []
179
+ start_, end_ = pairs[0]
180
+ new_pairs = []
181
+ for i, (next_item, cur_item) in enumerate(zip(pairs[1:], pairs[0:])):
182
+ end_ = next_item[1]
183
+ if next_item[0] - cur_item[1] <= n:
184
+ pass
185
+ else:
186
+ new_pairs.append((start_, cur_item[1]))
187
+ start_ = next_item[0]
188
+ new_pairs.append((start_, end_))
189
+ return new_pairs
190
+
191
+ def segments_to_temporal_tag(segments, thre=0.5):
192
+ after_flag, while_flag = 0, 0
193
+ for j in range(len(segments)):
194
+ for k in range(len(segments)):
195
+ if segments[j][0] == segments[k][0]:
196
+ continue
197
+ min_duration = min(segments[j][2] - segments[j][1], segments[k][2] - segments[k][1])
198
+ overlap = segments[j][2] - segments[k][1]
199
+ if overlap < thre * min_duration:
200
+ after_flag = 2
201
+ if segments[j][1] < segments[k][1] and overlap > thre * min_duration:
202
+ while_flag = 1
203
+ return after_flag + while_flag
204
+
205
+ def decode_with_timestamps(labels, time_resolution):
206
+ batch_results = []
207
+ for lab in labels:
208
+ segments = []
209
+ for i, label_column in enumerate(lab.T):
210
+ change_indices = find_contiguous_regions(label_column)
211
+ # append [onset, offset] in the result list
212
+ for row in change_indices:
213
+ segments.append((i, row[0] * time_resolution, row[1] * time_resolution))
214
+ temporal_tag = segments_to_temporal_tag(segments)
215
+ batch_results.append(temporal_tag)
216
+ return batch_results
217
+
218
+ class _EffiNet(nn.Module):
219
+ """A proxy for efficient net models"""
220
+ def __init__(self,
221
+ blocks_args=None,
222
+ global_params=None,
223
+ ) -> None:
224
+ super().__init__()
225
+ self.eff_net = EfficientNet(blocks_args=blocks_args,
226
+ global_params=global_params)
227
+
228
+
229
+ def forward(self, x: torch.Tensor):
230
+ x = rearrange(x, 'b f t -> b 1 f t')
231
+ x = self.eff_net.extract_features(x)
232
+ return reduce(x, 'b c f t -> b t c', 'mean')
233
+
234
+
235
+ def get_effb2_model() -> _EffiNet:
236
+ blocks_args, global_params = efficientnet_utils.get_model_params(
237
+ 'efficientnet-b2', {'include_top': False})
238
+ model = _EffiNet(blocks_args=blocks_args,
239
+ global_params=global_params)
240
+ model.eff_net._change_in_channels(1)
241
+ return model
242
+
243
+ def merge_load_state_dict(state_dict,
244
+ model: torch.nn.Module,
245
+ output_fn: Callable = sys.stdout.write):
246
+ model_dict = model.state_dict()
247
+ pretrained_dict = {}
248
+ mismatch_keys = []
249
+ for key, value in state_dict.items():
250
+ if key in model_dict and model_dict[key].shape == value.shape:
251
+ pretrained_dict[key] = value
252
+ else:
253
+ mismatch_keys.append(key)
254
+ output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}\n")
255
+ model_dict.update(pretrained_dict)
256
+ model.load_state_dict(model_dict, strict=True)
257
+ return pretrained_dict.keys()
258
+
259
+
260
+ class EfficientNetB2(nn.Module):
261
+
262
+ def __init__(self,
263
+ n_mels: int = 64,
264
+ win_length: int = 32,
265
+ hop_length: int = 10,
266
+ f_min: int = 0,
267
+ freeze: bool = False,):
268
+ super().__init__()
269
+ sample_rate = 16000
270
+ self.melspec_extractor = transforms.MelSpectrogram(
271
+ sample_rate=sample_rate,
272
+ n_fft=win_length * sample_rate // 1000,
273
+ win_length=win_length * sample_rate // 1000,
274
+ hop_length=hop_length * sample_rate // 1000,
275
+ f_min=f_min,
276
+ n_mels=n_mels,
277
+ )
278
+ self.hop_length = 10 * sample_rate // 1000
279
+ self.db_transform = transforms.AmplitudeToDB(top_db=120)
280
+ self.backbone = get_effb2_model()
281
+ self.fc_emb_size = self.backbone.eff_net._conv_head.out_channels
282
+ self.downsample_ratio = 32
283
+ if freeze:
284
+ for param in self.parameters():
285
+ param.requires_grad = False
286
+
287
+ def forward(self, input_dict):
288
+
289
+ waveform = input_dict["wav"]
290
+ wave_length = input_dict["wav_len"]
291
+ specaug = input_dict["specaug"]
292
+ x = self.melspec_extractor(waveform)
293
+ x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
294
+
295
+ x = rearrange(x, 'b f t -> b 1 t f')
296
+ if self.training and specaug:
297
+ x = self.spec_augmenter(x)
298
+ x = rearrange(x, 'b 1 t f -> b f t')
299
+
300
+ x = self.backbone(x)
301
+ attn_emb = x
302
+
303
+ wave_length = torch.as_tensor(wave_length)
304
+ feat_length = torch.div(wave_length, self.hop_length,
305
+ rounding_mode="floor") + 1
306
+ feat_length = torch.div(feat_length, self.downsample_ratio,
307
+ rounding_mode="floor")
308
+ fc_emb = mean_with_lens(attn_emb, feat_length)
309
+
310
+ output_dict = {
311
+ 'fc_emb': fc_emb,
312
+ 'attn_emb': attn_emb,
313
+ 'attn_emb_len': feat_length
314
+ }
315
+ return output_dict
316
+
317
+
318
+ def generate_length_mask(lens, max_length=None):
319
+ lens = torch.as_tensor(lens)
320
+ N = lens.size(0)
321
+ if max_length is None:
322
+ max_length = max(lens)
323
+ if isinstance(max_length, torch.Tensor):
324
+ max_length = max_length.item()
325
+ idxs = torch.arange(max_length).repeat(N).view(N, max_length)
326
+ idxs = idxs.to(lens.device)
327
+ mask = (idxs < lens.view(-1, 1))
328
+ return mask
329
+
330
+ def mean_with_lens(features, lens):
331
+ """
332
+ features: [N, T, ...] (assume the second dimension represents length)
333
+ lens: [N,]
334
+ """
335
+ lens = torch.as_tensor(lens)
336
+ if max(lens) != features.size(1):
337
+ max_length = features.size(1)
338
+ mask = generate_length_mask(lens, max_length)
339
+ else:
340
+ mask = generate_length_mask(lens)
341
+ mask = mask.to(features.device) # [N, T]
342
+
343
+ while mask.ndim < features.ndim:
344
+ mask = mask.unsqueeze(-1)
345
+ feature_mean = features * mask
346
+ feature_mean = feature_mean.sum(1)
347
+ while lens.ndim < feature_mean.ndim:
348
+ lens = lens.unsqueeze(1)
349
+ feature_mean = feature_mean / lens.to(features.device)
350
+ # feature_mean = features * mask.unsqueeze(-1)
351
+ # feature_mean = feature_mean.sum(1) / lens.unsqueeze(1).to(features.device)
352
+ return feature_mean
353
+
354
+ def max_with_lens(features, lens):
355
+ """
356
+ features: [N, T, ...] (assume the second dimension represents length)
357
+ lens: [N,]
358
+ """
359
+ lens = torch.as_tensor(lens)
360
+ if max(lens) != features.size(1):
361
+ max_length = features.size(1)
362
+ mask = generate_length_mask(lens, max_length)
363
+ else:
364
+ mask = generate_length_mask(lens)
365
+ mask = mask.to(features.device) # [N, T]
366
+
367
+ feature_max = features.clone()
368
+ feature_max[~mask] = float("-inf")
369
+ feature_max, _ = feature_max.max(1)
370
+ return feature_max
371
+
372
+ def repeat_tensor(x, n):
373
+ return x.unsqueeze(0).repeat(n, *([1] * len(x.shape)))
374
+
375
+
376
+ class CaptionMetaMixin:
377
+ pad_idx = 0
378
+ start_idx = 1
379
+ end_idx = 2
380
+ max_length = 20
381
+
382
+ @classmethod
383
+ def set_index(cls, start_idx, end_idx, pad_idx):
384
+ cls.start_idx = start_idx
385
+ cls.end_idx = end_idx
386
+ cls.pad_idx = pad_idx
387
+
388
+
389
+ class CaptionModel(nn.Module, CaptionMetaMixin):
390
+ """
391
+ Encoder-decoder captioning model.
392
+ """
393
+
394
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
395
+ super().__init__()
396
+ self.encoder = encoder
397
+ self.decoder = decoder
398
+ self.vocab_size = decoder.vocab_size
399
+ self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
400
+ self.inference_forward_keys = ["sample_method", "max_length", "temp"]
401
+ freeze_encoder = kwargs.get("freeze_encoder", False)
402
+ if freeze_encoder:
403
+ for param in self.encoder.parameters():
404
+ param.requires_grad = False
405
+ self.check_decoder_compatibility()
406
+
407
+ def check_decoder_compatibility(self):
408
+ compatible_decoders = [x.__class__.__name__ for x in self.compatible_decoders]
409
+ assert isinstance(self.decoder, self.compatible_decoders), \
410
+ f"{self.decoder.__class__.__name__} is incompatible with " \
411
+ f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "
412
+
413
+ def forward(self, input_dict: Dict):
414
+ """
415
+ input_dict: {
416
+ (required)
417
+ mode: train/inference,
418
+ [spec, spec_len],
419
+ [fc],
420
+ [attn, attn_len],
421
+ [wav, wav_len],
422
+ [sample_method: greedy],
423
+ [temp: 1.0] (in case of no teacher forcing)
424
+
425
+ (optional, mode=train)
426
+ cap,
427
+ cap_len,
428
+ ss_ratio,
429
+
430
+ (optional, mode=inference)
431
+ sample_method: greedy/beam,
432
+ max_length,
433
+ temp,
434
+ beam_size (optional, sample_method=beam),
435
+ n_best (optional, sample_method=beam),
436
+ }
437
+ """
438
+ encoder_output_dict = self.encoder(input_dict)
439
+ output = self.forward_decoder(input_dict, encoder_output_dict)
440
+ return output
441
+
442
+ def forward_decoder(self, input_dict: Dict, encoder_output_dict: Dict):
443
+ if input_dict["mode"] == "train":
444
+ forward_dict = {
445
+ "mode": "train", "sample_method": "greedy", "temp": 1.0
446
+ }
447
+ for key in self.train_forward_keys:
448
+ forward_dict[key] = input_dict[key]
449
+ forward_dict.update(encoder_output_dict)
450
+ output = self.train_forward(forward_dict)
451
+ elif input_dict["mode"] == "inference":
452
+ forward_dict = {"mode": "inference"}
453
+ default_args = { "sample_method": "greedy", "max_length": self.max_length, "temp": 1.0 }
454
+ for key in self.inference_forward_keys:
455
+ if key in input_dict:
456
+ forward_dict[key] = input_dict[key]
457
+ else:
458
+ forward_dict[key] = default_args[key]
459
+
460
+ if forward_dict["sample_method"] == "beam":
461
+ forward_dict["beam_size"] = input_dict.get("beam_size", 3)
462
+ forward_dict["n_best"] = input_dict.get("n_best", False)
463
+ forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"])
464
+ elif forward_dict["sample_method"] == "dbs":
465
+ forward_dict["beam_size"] = input_dict.get("beam_size", 6)
466
+ forward_dict["group_size"] = input_dict.get("group_size", 3)
467
+ forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5)
468
+ forward_dict["group_nbest"] = input_dict.get("group_nbest", True)
469
+
470
+ forward_dict.update(encoder_output_dict)
471
+ output = self.inference_forward(forward_dict)
472
+ else:
473
+ raise Exception("mode should be either 'train' or 'inference'")
474
+ output.update(encoder_output_dict)
475
+ return output
476
+
477
+ def prepare_output(self, input_dict):
478
+ output = {}
479
+ batch_size = input_dict["fc_emb"].size(0)
480
+ if input_dict["mode"] == "train":
481
+ max_length = input_dict["cap"].size(1) - 1
482
+ elif input_dict["mode"] == "inference":
483
+ max_length = input_dict["max_length"]
484
+ else:
485
+ raise Exception("mode should be either 'train' or 'inference'")
486
+ device = input_dict["fc_emb"].device
487
+ output["seq"] = torch.full((batch_size, max_length), self.end_idx,
488
+ dtype=torch.long)
489
+ output["logit"] = torch.empty(batch_size, max_length,
490
+ self.vocab_size).to(device)
491
+ output["sampled_logprob"] = torch.zeros(batch_size, max_length)
492
+ output["embed"] = torch.empty(batch_size, max_length,
493
+ self.decoder.d_model).to(device)
494
+ return output
495
+
496
+ def train_forward(self, input_dict):
497
+ if input_dict["ss_ratio"] != 1: # scheduled sampling training
498
+ input_dict["mode"] = "train"
499
+ return self.stepwise_forward(input_dict)
500
+ output = self.seq_forward(input_dict)
501
+ self.train_process(output, input_dict)
502
+ return output
503
+
504
+ def seq_forward(self, input_dict):
505
+ raise NotImplementedError
506
+
507
+ def train_process(self, output, input_dict):
508
+ pass
509
+
510
+ def inference_forward(self, input_dict):
511
+ if input_dict["sample_method"] == "beam":
512
+ return self.beam_search(input_dict)
513
+ elif input_dict["sample_method"] == "dbs":
514
+ return self.diverse_beam_search(input_dict)
515
+ return self.stepwise_forward(input_dict)
516
+
517
+ def stepwise_forward(self, input_dict):
518
+ """Step-by-step decoding"""
519
+ output = self.prepare_output(input_dict)
520
+ max_length = output["seq"].size(1)
521
+ # start sampling
522
+ for t in range(max_length):
523
+ input_dict["t"] = t
524
+ self.decode_step(input_dict, output)
525
+ if input_dict["mode"] == "inference": # decide whether to stop when sampling
526
+ unfinished_t = output["seq"][:, t] != self.end_idx
527
+ if t == 0:
528
+ unfinished = unfinished_t
529
+ else:
530
+ unfinished *= unfinished_t
531
+ output["seq"][:, t][~unfinished] = self.end_idx
532
+ if unfinished.sum() == 0:
533
+ break
534
+ self.stepwise_process(output)
535
+ return output
536
+
537
+ def decode_step(self, input_dict, output):
538
+ """Decoding operation of timestep t"""
539
+ decoder_input = self.prepare_decoder_input(input_dict, output)
540
+ # feed to the decoder to get logit
541
+ output_t = self.decoder(decoder_input)
542
+ logit_t = output_t["logit"]
543
+ # assert logit_t.ndim == 3
544
+ if logit_t.size(1) == 1:
545
+ logit_t = logit_t.squeeze(1)
546
+ embed_t = output_t["embed"].squeeze(1)
547
+ elif logit_t.size(1) > 1:
548
+ logit_t = logit_t[:, -1, :]
549
+ embed_t = output_t["embed"][:, -1, :]
550
+ else:
551
+ raise Exception("no logit output")
552
+ # sample the next input word and get the corresponding logit
553
+ sampled = self.sample_next_word(logit_t,
554
+ method=input_dict["sample_method"],
555
+ temp=input_dict["temp"])
556
+
557
+ output_t.update(sampled)
558
+ output_t["t"] = input_dict["t"]
559
+ output_t["logit"] = logit_t
560
+ output_t["embed"] = embed_t
561
+ self.stepwise_process_step(output, output_t)
562
+
563
+ def prepare_decoder_input(self, input_dict, output):
564
+ """Prepare the inp ut dict for the decoder"""
565
+ raise NotImplementedError
566
+
567
+ def stepwise_process_step(self, output, output_t):
568
+ """Postprocessing (save output values) after each timestep t"""
569
+ t = output_t["t"]
570
+ output["logit"][:, t, :] = output_t["logit"]
571
+ output["seq"][:, t] = output_t["word"]
572
+ output["sampled_logprob"][:, t] = output_t["probs"]
573
+ output["embed"][:, t, :] = output_t["embed"]
574
+
575
+ def stepwise_process(self, output):
576
+ """Postprocessing after the whole step-by-step autoregressive decoding"""
577
+ pass
578
+
579
+ def sample_next_word(self, logit, method, temp):
580
+ """Sample the next word, given probs output by the decoder"""
581
+ logprob = torch.log_softmax(logit, dim=1)
582
+ if method == "greedy":
583
+ sampled_logprob, word = torch.max(logprob.detach(), 1)
584
+ elif method == "gumbel":
585
+ def sample_gumbel(shape, eps=1e-20):
586
+ U = torch.rand(shape).to(logprob.device)
587
+ return -torch.log(-torch.log(U + eps) + eps)
588
+ def gumbel_softmax_sample(logit, temperature):
589
+ y = logit + sample_gumbel(logit.size())
590
+ return torch.log_softmax(y / temperature, dim=-1)
591
+ _logprob = gumbel_softmax_sample(logprob, temp)
592
+ _, word = torch.max(_logprob.data, 1)
593
+ sampled_logprob = logprob.gather(1, word.unsqueeze(-1))
594
+ else:
595
+ logprob = logprob / temp
596
+ if method.startswith("top"):
597
+ top_num = float(method[3:])
598
+ if 0 < top_num < 1: # top-p sampling
599
+ probs = torch.softmax(logit, dim=1)
600
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
601
+ _cumsum = sorted_probs.cumsum(1)
602
+ mask = _cumsum < top_num
603
+ mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
604
+ sorted_probs = sorted_probs * mask.to(sorted_probs)
605
+ sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
606
+ logprob.scatter_(1, sorted_indices, sorted_probs.log())
607
+ else: # top-k sampling
608
+ k = int(top_num)
609
+ tmp = torch.empty_like(logprob).fill_(float('-inf'))
610
+ topk, indices = torch.topk(logprob, k, dim=1)
611
+ tmp = tmp.scatter(1, indices, topk)
612
+ logprob = tmp
613
+ word = torch.distributions.Categorical(logits=logprob.detach()).sample()
614
+ sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
615
+ word = word.detach().long()
616
+ # sampled_logprob: [N,], word: [N,]
617
+ return {"word": word, "probs": sampled_logprob}
618
+
619
+ def beam_search(self, input_dict):
620
+ output = self.prepare_output(input_dict)
621
+ max_length = input_dict["max_length"]
622
+ beam_size = input_dict["beam_size"]
623
+ if input_dict["n_best"]:
624
+ n_best_size = input_dict["n_best_size"]
625
+ batch_size, max_length = output["seq"].size()
626
+ output["seq"] = torch.full((batch_size, n_best_size, max_length),
627
+ self.end_idx, dtype=torch.long)
628
+
629
+ temp = input_dict["temp"]
630
+ # instance by instance beam seach
631
+ for i in range(output["seq"].size(0)):
632
+ output_i = self.prepare_beamsearch_output(input_dict)
633
+ input_dict["sample_idx"] = i
634
+ for t in range(max_length):
635
+ input_dict["t"] = t
636
+ output_t = self.beamsearch_step(input_dict, output_i)
637
+ #######################################
638
+ # merge with previous beam and select the current max prob beam
639
+ #######################################
640
+ logit_t = output_t["logit"]
641
+ if logit_t.size(1) == 1:
642
+ logit_t = logit_t.squeeze(1)
643
+ elif logit_t.size(1) > 1:
644
+ logit_t = logit_t[:, -1, :]
645
+ else:
646
+ raise Exception("no logit output")
647
+ logprob_t = torch.log_softmax(logit_t, dim=1)
648
+ logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
649
+ logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
650
+ if t == 0: # for the first step, all k seq will have the same probs
651
+ topk_logprob, topk_words = logprob_t[0].topk(
652
+ beam_size, 0, True, True)
653
+ else: # unroll and find top logprob, and their unrolled indices
654
+ topk_logprob, topk_words = logprob_t.view(-1).topk(
655
+ beam_size, 0, True, True)
656
+ topk_words = topk_words.cpu()
657
+ output_i["topk_logprob"] = topk_logprob
658
+ # output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,]
659
+ output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size,
660
+ rounding_mode='trunc')
661
+ output_i["next_word"] = topk_words % self.vocab_size # [beam_size,]
662
+ if t == 0:
663
+ output_i["seq"] = output_i["next_word"].unsqueeze(1)
664
+ else:
665
+ output_i["seq"] = torch.cat([
666
+ output_i["seq"][output_i["prev_words_beam"]],
667
+ output_i["next_word"].unsqueeze(1)], dim=1)
668
+
669
+ # add finished beams to results
670
+ is_end = output_i["next_word"] == self.end_idx
671
+ if t == max_length - 1:
672
+ is_end.fill_(1)
673
+
674
+ for beam_idx in range(beam_size):
675
+ if is_end[beam_idx]:
676
+ final_beam = {
677
+ "seq": output_i["seq"][beam_idx].clone(),
678
+ "score": output_i["topk_logprob"][beam_idx].item()
679
+ }
680
+ final_beam["score"] = final_beam["score"] / (t + 1)
681
+ output_i["done_beams"].append(final_beam)
682
+ output_i["topk_logprob"][is_end] -= 1000
683
+
684
+ self.beamsearch_process_step(output_i, output_t)
685
+
686
+ if len(output_i["done_beams"]) == beam_size:
687
+ break
688
+
689
+ self.beamsearch_process(output, output_i, input_dict)
690
+ return output
691
+
692
+ def prepare_beamsearch_output(self, input_dict):
693
+ beam_size = input_dict["beam_size"]
694
+ device = input_dict["fc_emb"].device
695
+ output = {
696
+ "topk_logprob": torch.zeros(beam_size).to(device),
697
+ "seq": None,
698
+ "prev_words_beam": None,
699
+ "next_word": None,
700
+ "done_beams": [],
701
+ }
702
+ return output
703
+
704
+ def beamsearch_step(self, input_dict, output_i):
705
+ decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i)
706
+ output_t = self.decoder(decoder_input)
707
+ output_t["t"] = input_dict["t"]
708
+ return output_t
709
+
710
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
711
+ raise NotImplementedError
712
+
713
+ def beamsearch_process_step(self, output_i, output_t):
714
+ pass
715
+
716
+ def beamsearch_process(self, output, output_i, input_dict):
717
+ i = input_dict["sample_idx"]
718
+ done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
719
+ if input_dict["n_best"]:
720
+ done_beams = done_beams[:input_dict["n_best_size"]]
721
+ for out_idx, done_beam in enumerate(done_beams):
722
+ seq = done_beam["seq"]
723
+ output["seq"][i][out_idx, :len(seq)] = seq
724
+ else:
725
+ seq = done_beams[0]["seq"]
726
+ output["seq"][i][:len(seq)] = seq
727
+
728
+ def diverse_beam_search(self, input_dict):
729
+
730
+ def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
731
+ local_time = t - divm
732
+ unaug_logprob = logprob.clone()
733
+
734
+ if divm > 0:
735
+ change = torch.zeros(logprob.size(-1))
736
+ for prev_choice in range(divm):
737
+ prev_decisions = seq_table[prev_choice][..., local_time]
738
+ for prev_labels in range(bdash):
739
+ change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1))
740
+
741
+ change = change.to(logprob.device)
742
+ logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda
743
+
744
+ return logprob, unaug_logprob
745
+
746
+ output = self.prepare_output(input_dict)
747
+ group_size = input_dict["group_size"]
748
+ batch_size = output["seq"].size(0)
749
+ beam_size = input_dict["beam_size"]
750
+ bdash = beam_size // group_size
751
+ input_dict["bdash"] = bdash
752
+ diversity_lambda = input_dict["diversity_lambda"]
753
+ device = input_dict["fc_emb"].device
754
+ max_length = input_dict["max_length"]
755
+ temp = input_dict["temp"]
756
+ group_nbest = input_dict["group_nbest"]
757
+ batch_size, max_length = output["seq"].size()
758
+ if group_nbest:
759
+ output["seq"] = torch.full((batch_size, beam_size, max_length),
760
+ self.end_idx, dtype=torch.long)
761
+ else:
762
+ output["seq"] = torch.full((batch_size, group_size, max_length),
763
+ self.end_idx, dtype=torch.long)
764
+
765
+
766
+ for i in range(batch_size):
767
+ input_dict["sample_idx"] = i
768
+ seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0]
769
+ logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
770
+ done_beams_table = [[] for _ in range(group_size)]
771
+
772
+ output_i = {
773
+ "prev_words_beam": [None for _ in range(group_size)],
774
+ "next_word": [None for _ in range(group_size)],
775
+ "state": [None for _ in range(group_size)]
776
+ }
777
+
778
+ for t in range(max_length + group_size - 1):
779
+ input_dict["t"] = t
780
+ for divm in range(group_size):
781
+ input_dict["divm"] = divm
782
+ if t >= divm and t <= max_length + divm - 1:
783
+ local_time = t - divm
784
+ decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
785
+ output_t = self.decoder(decoder_input)
786
+ output_t["divm"] = divm
787
+ logit_t = output_t["logit"]
788
+ if logit_t.size(1) == 1:
789
+ logit_t = logit_t.squeeze(1)
790
+ elif logit_t.size(1) > 1:
791
+ logit_t = logit_t[:, -1, :]
792
+ else:
793
+ raise Exception("no logit output")
794
+ logprob_t = torch.log_softmax(logit_t, dim=1)
795
+ logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
796
+ logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash)
797
+ logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
798
+ if local_time == 0: # for the first step, all k seq will have the same probs
799
+ topk_logprob, topk_words = logprob_t[0].topk(
800
+ bdash, 0, True, True)
801
+ else: # unroll and find top logprob, and their unrolled indices
802
+ topk_logprob, topk_words = logprob_t.view(-1).topk(
803
+ bdash, 0, True, True)
804
+ topk_words = topk_words.cpu()
805
+ logprob_table[divm] = topk_logprob
806
+ output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,]
807
+ output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,]
808
+ if local_time > 0:
809
+ seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
810
+ seq_table[divm] = torch.cat([
811
+ seq_table[divm],
812
+ output_i["next_word"][divm].unsqueeze(-1)], -1)
813
+
814
+ is_end = seq_table[divm][:, t-divm] == self.end_idx
815
+ assert seq_table[divm].shape[-1] == t - divm + 1
816
+ if t == max_length + divm - 1:
817
+ is_end.fill_(1)
818
+ for beam_idx in range(bdash):
819
+ if is_end[beam_idx]:
820
+ final_beam = {
821
+ "seq": seq_table[divm][beam_idx].clone(),
822
+ "score": logprob_table[divm][beam_idx].item()
823
+ }
824
+ final_beam["score"] = final_beam["score"] / (t - divm + 1)
825
+ done_beams_table[divm].append(final_beam)
826
+ logprob_table[divm][is_end] -= 1000
827
+ self.dbs_process_step(output_i, output_t)
828
+ done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)]
829
+ if group_nbest:
830
+ done_beams = sum(done_beams_table, [])
831
+ else:
832
+ done_beams = [group_beam[0] for group_beam in done_beams_table]
833
+ for _, done_beam in enumerate(done_beams):
834
+ output["seq"][i, _, :len(done_beam["seq"])] = done_beam["seq"]
835
+
836
+ return output
837
+
838
+ def prepare_dbs_decoder_input(self, input_dict, output_i):
839
+ raise NotImplementedError
840
+
841
+ def dbs_process_step(self, output_i, output_t):
842
+ pass
843
+
844
+
845
+ class TransformerModel(CaptionModel):
846
+
847
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
848
+ if not hasattr(self, "compatible_decoders"):
849
+ self.compatible_decoders = (
850
+ TransformerDecoder,
851
+ )
852
+ super().__init__(encoder, decoder, **kwargs)
853
+
854
+ def seq_forward(self, input_dict):
855
+ cap = input_dict["cap"]
856
+ cap_padding_mask = (cap == self.pad_idx).to(cap.device)
857
+ cap_padding_mask = cap_padding_mask[:, :-1]
858
+ output = self.decoder(
859
+ {
860
+ "word": cap[:, :-1],
861
+ "attn_emb": input_dict["attn_emb"],
862
+ "attn_emb_len": input_dict["attn_emb_len"],
863
+ "cap_padding_mask": cap_padding_mask
864
+ }
865
+ )
866
+ return output
867
+
868
+ def prepare_decoder_input(self, input_dict, output):
869
+ decoder_input = {
870
+ "attn_emb": input_dict["attn_emb"],
871
+ "attn_emb_len": input_dict["attn_emb_len"]
872
+ }
873
+ t = input_dict["t"]
874
+
875
+ ###############
876
+ # determine input word
877
+ ################
878
+ if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
879
+ word = input_dict["cap"][:, :t+1]
880
+ else:
881
+ start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
882
+ if t == 0:
883
+ word = start_word
884
+ else:
885
+ word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
886
+ # word: [N, T]
887
+ decoder_input["word"] = word
888
+
889
+ cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
890
+ decoder_input["cap_padding_mask"] = cap_padding_mask
891
+ return decoder_input
892
+
893
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
894
+ decoder_input = {}
895
+ t = input_dict["t"]
896
+ i = input_dict["sample_idx"]
897
+ beam_size = input_dict["beam_size"]
898
+ ###############
899
+ # prepare attn embeds
900
+ ################
901
+ if t == 0:
902
+ attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
903
+ attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
904
+ output_i["attn_emb"] = attn_emb
905
+ output_i["attn_emb_len"] = attn_emb_len
906
+ decoder_input["attn_emb"] = output_i["attn_emb"]
907
+ decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
908
+ ###############
909
+ # determine input word
910
+ ################
911
+ start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
912
+ if t == 0:
913
+ word = start_word
914
+ else:
915
+ word = torch.cat((start_word, output_i["seq"]), dim=-1)
916
+ decoder_input["word"] = word
917
+ cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
918
+ decoder_input["cap_padding_mask"] = cap_padding_mask
919
+
920
+ return decoder_input
921
+
922
+
923
+ class BaseDecoder(nn.Module):
924
+ """
925
+ Take word/audio embeddings and output the next word probs
926
+ """
927
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim,
928
+ attn_emb_dim, dropout=0.2, tie_weights=False):
929
+ super().__init__()
930
+ self.emb_dim = emb_dim
931
+ self.vocab_size = vocab_size
932
+ self.fc_emb_dim = fc_emb_dim
933
+ self.attn_emb_dim = attn_emb_dim
934
+ self.tie_weights = tie_weights
935
+ self.word_embedding = nn.Embedding(vocab_size, emb_dim)
936
+ self.in_dropout = nn.Dropout(dropout)
937
+
938
+ def forward(self, x):
939
+ raise NotImplementedError
940
+
941
+ def load_word_embedding(self, weight, freeze=True):
942
+ embedding = np.load(weight)
943
+ assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
944
+ assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
945
+
946
+ # embeddings = torch.as_tensor(embeddings).float()
947
+ # self.word_embeddings.weight = nn.Parameter(embeddings)
948
+ # for para in self.word_embeddings.parameters():
949
+ # para.requires_grad = tune
950
+ self.word_embedding = nn.Embedding.from_pretrained(embedding,
951
+ freeze=freeze)
952
+
953
+
954
+ class PositionalEncoding(nn.Module):
955
+
956
+ def __init__(self, d_model, dropout=0.1, max_len=100):
957
+ super(PositionalEncoding, self).__init__()
958
+ self.dropout = nn.Dropout(p=dropout)
959
+
960
+ pe = torch.zeros(max_len, d_model)
961
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
962
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * \
963
+ (-math.log(10000.0) / d_model))
964
+ pe[:, 0::2] = torch.sin(position * div_term)
965
+ pe[:, 1::2] = torch.cos(position * div_term)
966
+ pe = pe.unsqueeze(0).transpose(0, 1)
967
+ # self.register_buffer("pe", pe)
968
+ self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))
969
+
970
+ def forward(self, x):
971
+ # x: [T, N, E]
972
+ x = x + self.pe[:x.size(0), :]
973
+ return self.dropout(x)
974
+
975
+
976
+ class TransformerDecoder(BaseDecoder):
977
+
978
+ def __init__(self,
979
+ emb_dim,
980
+ vocab_size,
981
+ fc_emb_dim,
982
+ attn_emb_dim,
983
+ dropout,
984
+ freeze=False,
985
+ tie_weights=False,
986
+ **kwargs):
987
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
988
+ dropout=dropout, tie_weights=tie_weights)
989
+ self.d_model = emb_dim
990
+ self.nhead = kwargs.get("nhead", self.d_model // 64)
991
+ self.nlayers = kwargs.get("nlayers", 2)
992
+ self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
993
+
994
+ self.pos_encoder = PositionalEncoding(self.d_model, dropout)
995
+ layer = nn.TransformerDecoderLayer(d_model=self.d_model,
996
+ nhead=self.nhead,
997
+ dim_feedforward=self.dim_feedforward,
998
+ dropout=dropout)
999
+ self.model = nn.TransformerDecoder(layer, self.nlayers)
1000
+ self.classifier = nn.Linear(self.d_model, vocab_size, bias=False)
1001
+ if tie_weights:
1002
+ self.classifier.weight = self.word_embedding.weight
1003
+ self.attn_proj = nn.Sequential(
1004
+ nn.Linear(self.attn_emb_dim, self.d_model),
1005
+ nn.ReLU(),
1006
+ nn.Dropout(dropout),
1007
+ nn.LayerNorm(self.d_model)
1008
+ )
1009
+ self.init_params()
1010
+
1011
+ self.freeze = freeze
1012
+ if freeze:
1013
+ for p in self.parameters():
1014
+ p.requires_grad = False
1015
+
1016
+ def init_params(self):
1017
+ for p in self.parameters():
1018
+ if p.dim() > 1:
1019
+ nn.init.xavier_uniform_(p)
1020
+
1021
+ def load_pretrained(self, pretrained, output_fn):
1022
+ checkpoint = torch.load(pretrained, map_location="cpu")
1023
+
1024
+ if "model" in checkpoint:
1025
+ checkpoint = checkpoint["model"]
1026
+ if next(iter(checkpoint)).startswith("decoder."):
1027
+ state_dict = {}
1028
+ for k, v in checkpoint.items():
1029
+ state_dict[k[8:]] = v
1030
+
1031
+ loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
1032
+ if self.freeze:
1033
+ for name, param in self.named_parameters():
1034
+ if name in loaded_keys:
1035
+ param.requires_grad = False
1036
+ else:
1037
+ param.requires_grad = True
1038
+
1039
+
1040
+ def generate_square_subsequent_mask(self, max_length):
1041
+ mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
1042
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
1043
+ return mask
1044
+
1045
+ def forward(self, input_dict):
1046
+ word = input_dict["word"]
1047
+ attn_emb = input_dict["attn_emb"]
1048
+ attn_emb_len = input_dict["attn_emb_len"]
1049
+ cap_padding_mask = input_dict["cap_padding_mask"]
1050
+
1051
+ p_attn_emb = self.attn_proj(attn_emb)
1052
+ p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
1053
+ word = word.to(attn_emb.device)
1054
+ embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
1055
+ embed = embed.transpose(0, 1) # [T, N, emb_dim]
1056
+ embed = self.pos_encoder(embed)
1057
+
1058
+ tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
1059
+ memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
1060
+ output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
1061
+ tgt_key_padding_mask=cap_padding_mask,
1062
+ memory_key_padding_mask=memory_key_padding_mask)
1063
+ output = output.transpose(0, 1)
1064
+ output = {
1065
+ "embed": output,
1066
+ "logit": self.classifier(output),
1067
+ }
1068
+ return output
1069
+
1070
+
1071
+ class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin):
1072
+
1073
+ def __init__(self,
1074
+ model: nn.Module,
1075
+ shared_dim: int,
1076
+ tchr_dim: int,
1077
+ ):
1078
+ super().__init__()
1079
+ self.model = model
1080
+ self.tchr_dim = tchr_dim
1081
+ if hasattr(model, "encoder"):
1082
+ self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
1083
+ shared_dim)
1084
+ else:
1085
+ self.stdnt_proj = nn.Linear(model.fc_emb_size,
1086
+ shared_dim)
1087
+ self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
1088
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
1089
+
1090
+ def forward(self, input_dict: Dict):
1091
+ unsup = input_dict.get("unsup", False)
1092
+ if unsup is False:
1093
+ output_dict = self.model(input_dict)
1094
+ else:
1095
+ output_dict = self.model.encoder(input_dict)
1096
+ if "tchr_output" in input_dict:
1097
+ stdnt_emb = output_dict["fc_emb"]
1098
+ stdnt_emb = self.stdnt_proj(stdnt_emb)
1099
+ tchr_emb = input_dict["tchr_output"]["embedding"]
1100
+ thcr_emb = self.tchr_proj(tchr_emb)
1101
+
1102
+ stdnt_emb = F.normalize(stdnt_emb, dim=-1)
1103
+ thcr_emb = F.normalize(thcr_emb, dim=-1)
1104
+
1105
+ unscaled_logit = stdnt_emb @ thcr_emb.transpose(0, 1)
1106
+ logit = self.logit_scale * unscaled_logit
1107
+ label = torch.arange(logit.shape[0]).to(logit.device)
1108
+ loss1 = F.cross_entropy(logit, label)
1109
+ loss2 = F.cross_entropy(logit.transpose(0, 1), label)
1110
+ loss = (loss1 + loss2) / 2
1111
+ output_dict["enc_kd_loss"] = loss
1112
+ return output_dict
1113
+
1114
+
1115
+ class Effb2TrmConfig(PretrainedConfig):
1116
+
1117
+ def __init__(
1118
+ self,
1119
+ sample_rate: int = 16000,
1120
+ tchr_dim: int = 768,
1121
+ shared_dim: int = 1024,
1122
+ fc_emb_dim: int = 1408,
1123
+ attn_emb_dim: int = 1408,
1124
+ decoder_n_layers: int = 2,
1125
+ decoder_we_tie_weights: bool = True,
1126
+ decoder_emb_dim: int = 256,
1127
+ decoder_dropout: float = 0.2,
1128
+ vocab_size: int = 4981,
1129
+ **kwargs
1130
+ ):
1131
+ self.sample_rate = sample_rate
1132
+ self.tchr_dim = tchr_dim
1133
+ self.shared_dim = shared_dim
1134
+ self.fc_emb_dim = fc_emb_dim
1135
+ self.attn_emb_dim = attn_emb_dim
1136
+ self.decoder_n_layers = decoder_n_layers
1137
+ self.decoder_we_tie_weights = decoder_we_tie_weights
1138
+ self.decoder_emb_dim = decoder_emb_dim
1139
+ self.decoder_dropout = decoder_dropout
1140
+ self.vocab_size = vocab_size
1141
+ super().__init__(**kwargs)
1142
+
1143
+
1144
+ class Effb2TrmCaptioningModel(PreTrainedModel):
1145
+ config_class = Effb2TrmConfig
1146
+
1147
+ def __init__(self, config):
1148
+ super().__init__(config)
1149
+ encoder = EfficientNetB2()
1150
+ decoder = TransformerDecoder(
1151
+ emb_dim=config.decoder_emb_dim,
1152
+ vocab_size=config.vocab_size,
1153
+ fc_emb_dim=config.fc_emb_dim,
1154
+ attn_emb_dim=config.attn_emb_dim,
1155
+ dropout=config.decoder_dropout,
1156
+ nlayers=config.decoder_n_layers,
1157
+ tie_weights=config.decoder_we_tie_weights
1158
+ )
1159
+ model = TransformerModel(encoder, decoder)
1160
+ self.model = ContraEncoderKdWrapper(model, config.shared_dim, config.tchr_dim)
1161
+
1162
+ def forward(self,
1163
+ audio: torch.Tensor,
1164
+ audio_length: Union[List, np.ndarray, torch.Tensor],
1165
+ sample_method: str = "beam",
1166
+ beam_size: int = 3,
1167
+ max_length: int = 20,
1168
+ temp: float = 1.0,):
1169
+ device = self.device
1170
+ input_dict = {
1171
+ "wav": audio.to(device),
1172
+ "wav_len": audio_length,
1173
+ "specaug": False,
1174
+ "mode": "inference",
1175
+ "sample_method": sample_method,
1176
+ "max_length": max_length,
1177
+ "temp": temp,
1178
+ }
1179
+ if sample_method == "beam":
1180
+ input_dict["beam_size"] = beam_size
1181
+ return self.model(input_dict)["seq"].cpu()
1182
+
1183
+
1184
+ class ConvBlock(nn.Module):
1185
+
1186
+ def __init__(self, in_channels, out_channels):
1187
+
1188
+ super(ConvBlock, self).__init__()
1189
+
1190
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
1191
+ out_channels=out_channels,
1192
+ kernel_size=(3, 3), stride=(1, 1),
1193
+ padding=(1, 1), bias=False)
1194
+
1195
+ self.conv2 = nn.Conv2d(in_channels=out_channels,
1196
+ out_channels=out_channels,
1197
+ kernel_size=(3, 3), stride=(1, 1),
1198
+ padding=(1, 1), bias=False)
1199
+
1200
+ self.bn1 = nn.BatchNorm2d(out_channels)
1201
+ self.bn2 = nn.BatchNorm2d(out_channels)
1202
+
1203
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
1204
+
1205
+ x = input
1206
+ x = F.relu_(self.bn1(self.conv1(x)))
1207
+ x = F.relu_(self.bn2(self.conv2(x)))
1208
+ if pool_type == 'max':
1209
+ x = F.max_pool2d(x, kernel_size=pool_size)
1210
+ elif pool_type == 'avg':
1211
+ x = F.avg_pool2d(x, kernel_size=pool_size)
1212
+ elif pool_type == 'avg+max':
1213
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
1214
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
1215
+ x = x1 + x2
1216
+ else:
1217
+ raise Exception('Incorrect argument!')
1218
+
1219
+ return x
1220
+
1221
+
1222
+ class Cnn14Encoder(nn.Module):
1223
+
1224
+ def __init__(self, sample_rate=32000):
1225
+ super().__init__()
1226
+ sr_to_fmax = {
1227
+ 32000: 14000,
1228
+ 16000: 8000
1229
+ }
1230
+ # Logmel spectrogram extractor
1231
+ self.melspec_extractor = transforms.MelSpectrogram(
1232
+ sample_rate=sample_rate,
1233
+ n_fft=32 * sample_rate // 1000,
1234
+ win_length=32 * sample_rate // 1000,
1235
+ hop_length=10 * sample_rate // 1000,
1236
+ f_min=50,
1237
+ f_max=sr_to_fmax[sample_rate],
1238
+ n_mels=64,
1239
+ norm="slaney",
1240
+ mel_scale="slaney"
1241
+ )
1242
+ self.hop_length = 10 * sample_rate // 1000
1243
+ self.db_transform = transforms.AmplitudeToDB()
1244
+
1245
+ self.bn0 = nn.BatchNorm2d(64)
1246
+
1247
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
1248
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
1249
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
1250
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
1251
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
1252
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
1253
+
1254
+ self.downsample_ratio = 32
1255
+
1256
+ self.fc1 = nn.Linear(2048, 2048, bias=True)
1257
+ self.fc_emb_size = 2048
1258
+
1259
+ def forward(self, input_dict):
1260
+ lms = input_dict["lms"]
1261
+ wave_length = input_dict["wav_len"]
1262
+
1263
+ x = lms # (batch_size, mel_bins, time_steps)
1264
+ x = x.transpose(1, 2)
1265
+ x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
1266
+
1267
+ x = x.transpose(1, 3)
1268
+ x = self.bn0(x)
1269
+ x = x.transpose(1, 3)
1270
+
1271
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
1272
+ x = F.dropout(x, p=0.2, training=self.training)
1273
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
1274
+ x = F.dropout(x, p=0.2, training=self.training)
1275
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
1276
+ x = F.dropout(x, p=0.2, training=self.training)
1277
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
1278
+ x = F.dropout(x, p=0.2, training=self.training)
1279
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
1280
+ x = F.dropout(x, p=0.2, training=self.training)
1281
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
1282
+ x = F.dropout(x, p=0.2, training=self.training)
1283
+ x = torch.mean(x, dim=3)
1284
+ attn_emb = x.transpose(1, 2)
1285
+
1286
+ wave_length = torch.as_tensor(wave_length)
1287
+ feat_length = torch.div(wave_length, self.hop_length,
1288
+ rounding_mode="floor") + 1
1289
+ feat_length = torch.div(feat_length, self.downsample_ratio,
1290
+ rounding_mode="floor")
1291
+ x_max = max_with_lens(attn_emb, feat_length)
1292
+ x_mean = mean_with_lens(attn_emb, feat_length)
1293
+ x = x_max + x_mean
1294
+ x = F.dropout(x, p=0.5, training=self.training)
1295
+ x = F.relu_(self.fc1(x))
1296
+ fc_emb = F.dropout(x, p=0.5, training=self.training)
1297
+
1298
+ output_dict = {
1299
+ 'fc_emb': fc_emb,
1300
+ 'attn_emb': attn_emb,
1301
+ 'attn_emb_len': feat_length
1302
+ }
1303
+
1304
+ return output_dict
1305
+
1306
+
1307
+ class RnnEncoder(nn.Module):
1308
+
1309
+ def __init__(self,
1310
+ attn_feat_dim,
1311
+ pooling="mean",
1312
+ **kwargs):
1313
+ super().__init__()
1314
+ self.pooling = pooling
1315
+ self.hidden_size = kwargs.get('hidden_size', 512)
1316
+ self.bidirectional = kwargs.get('bidirectional', False)
1317
+ self.num_layers = kwargs.get('num_layers', 1)
1318
+ self.dropout = kwargs.get('dropout', 0.2)
1319
+ self.rnn_type = kwargs.get('rnn_type', "GRU")
1320
+ self.in_bn = kwargs.get('in_bn', False)
1321
+ self.embed_dim = self.hidden_size * (self.bidirectional + 1)
1322
+ self.network = getattr(nn, self.rnn_type)(
1323
+ attn_feat_dim,
1324
+ self.hidden_size,
1325
+ num_layers=self.num_layers,
1326
+ bidirectional=self.bidirectional,
1327
+ dropout=self.dropout,
1328
+ batch_first=True)
1329
+ if self.in_bn:
1330
+ self.bn = nn.BatchNorm1d(self.embed_dim)
1331
+
1332
+ def forward(self, input_dict):
1333
+ x = input_dict["attn"]
1334
+ lens = input_dict["attn_len"]
1335
+ lens = torch.as_tensor(lens)
1336
+ # x: [N, T, E]
1337
+ if self.in_bn:
1338
+ x = pack_wrapper(self.bn, x, lens)
1339
+ out = pack_wrapper(self.network, x, lens)
1340
+ # out: [N, T, hidden]
1341
+ attn_emb = out
1342
+ fc_emb = embedding_pooling(out, lens, self.pooling)
1343
+ return {
1344
+ "attn_emb": attn_emb,
1345
+ "fc_emb": fc_emb,
1346
+ "attn_emb_len": lens
1347
+ }
1348
+
1349
+
1350
+ class Cnn14RnnEncoder(nn.Module):
1351
+
1352
+ def __init__(self,
1353
+ sample_rate,
1354
+ rnn_bidirectional,
1355
+ rnn_hidden_size,
1356
+ rnn_dropout,
1357
+ rnn_num_layers):
1358
+ super().__init__()
1359
+ self.cnn = Cnn14Encoder(sample_rate=sample_rate)
1360
+ self.rnn = RnnEncoder(
1361
+ 2048,
1362
+ bidirectional=rnn_bidirectional,
1363
+ hidden_size=rnn_hidden_size,
1364
+ dropout=rnn_dropout,
1365
+ num_layers=rnn_num_layers,
1366
+ )
1367
+
1368
+ def forward(self, input_dict):
1369
+ output_dict = self.cnn(input_dict)
1370
+ output_dict["attn"] = output_dict["attn_emb"]
1371
+ output_dict["attn_len"] = output_dict["attn_emb_len"]
1372
+ del output_dict["attn_emb"], output_dict["attn_emb_len"]
1373
+ output_dict = self.rnn(output_dict)
1374
+ return output_dict
1375
+
1376
+
1377
+ class Seq2SeqAttention(nn.Module):
1378
+
1379
+ def __init__(self, hs_enc, hs_dec, attn_size):
1380
+ """
1381
+ Args:
1382
+ hs_enc: encoder hidden size
1383
+ hs_dec: decoder hidden size
1384
+ attn_size: attention vector size
1385
+ """
1386
+ super(Seq2SeqAttention, self).__init__()
1387
+ self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size)
1388
+ self.v = nn.Parameter(torch.randn(attn_size))
1389
+
1390
+ def forward(self, h_dec, h_enc, src_lens):
1391
+ """
1392
+ Args:
1393
+ h_dec: decoder hidden (query), [N, hs_dec]
1394
+ h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
1395
+ src_lens: source (encoder memory) lengths, [N, ]
1396
+ """
1397
+ N = h_enc.size(0)
1398
+ src_max_len = h_enc.size(1)
1399
+ h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec]
1400
+
1401
+ attn_input = torch.cat((h_dec, h_enc), dim=-1)
1402
+ attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size]
1403
+
1404
+ v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size]
1405
+ score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len]
1406
+
1407
+ idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len)
1408
+ mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device)
1409
+
1410
+ score = score.masked_fill(mask == 0, -1e10)
1411
+ weights = torch.softmax(score, dim=-1) # [N, src_max_len]
1412
+ ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc]
1413
+
1414
+ return ctx, weights
1415
+
1416
+
1417
+ class RnnDecoder(BaseDecoder):
1418
+
1419
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
1420
+ dropout, d_model, **kwargs):
1421
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
1422
+ dropout,)
1423
+ self.d_model = d_model
1424
+ self.num_layers = kwargs.get('num_layers', 1)
1425
+ self.bidirectional = kwargs.get('bidirectional', False)
1426
+ self.rnn_type = kwargs.get('rnn_type', "GRU")
1427
+ self.classifier = nn.Linear(
1428
+ self.d_model * (self.bidirectional + 1), vocab_size)
1429
+
1430
+ def forward(self, x):
1431
+ raise NotImplementedError
1432
+
1433
+ def init_hidden(self, bs, device):
1434
+ num_dire = self.bidirectional + 1
1435
+ n_layer = self.num_layers
1436
+ hid_dim = self.d_model
1437
+ if self.rnn_type == "LSTM":
1438
+ return (torch.zeros(num_dire * n_layer, bs, hid_dim).to(device),
1439
+ torch.zeros(num_dire * n_layer, bs, hid_dim).to(device))
1440
+ else:
1441
+ return torch.zeros(num_dire * n_layer, bs, hid_dim).to(device)
1442
+
1443
+
1444
+ class BahAttnCatFcDecoder(RnnDecoder):
1445
+
1446
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
1447
+ dropout, d_model, **kwargs):
1448
+ """
1449
+ concatenate fc, attn, word to feed to the rnn
1450
+ """
1451
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
1452
+ dropout, d_model, **kwargs)
1453
+ attn_size = kwargs.get("attn_size", self.d_model)
1454
+ self.model = getattr(nn, self.rnn_type)(
1455
+ input_size=self.emb_dim * 3,
1456
+ hidden_size=self.d_model,
1457
+ batch_first=True,
1458
+ num_layers=self.num_layers,
1459
+ bidirectional=self.bidirectional)
1460
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
1461
+ self.d_model * (self.bidirectional + 1) * \
1462
+ self.num_layers,
1463
+ attn_size)
1464
+ self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
1465
+ self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
1466
+
1467
+ def forward(self, input_dict):
1468
+ word = input_dict["word"]
1469
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
1470
+ fc_emb = input_dict["fc_emb"]
1471
+ attn_emb = input_dict["attn_emb"]
1472
+ attn_emb_len = input_dict["attn_emb_len"]
1473
+
1474
+ word = word.to(fc_emb.device)
1475
+ embed = self.in_dropout(self.word_embedding(word))
1476
+
1477
+ # embed: [N, 1, embed_size]
1478
+ if state is None:
1479
+ state = self.init_hidden(word.size(0), fc_emb.device)
1480
+ if self.rnn_type == "LSTM":
1481
+ query = state[0].transpose(0, 1).flatten(1)
1482
+ else:
1483
+ query = state.transpose(0, 1).flatten(1)
1484
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
1485
+
1486
+ p_fc_emb = self.fc_proj(fc_emb)
1487
+ p_ctx = self.ctx_proj(c)
1488
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_emb.unsqueeze(1)),
1489
+ dim=-1)
1490
+
1491
+ out, state = self.model(rnn_input, state)
1492
+
1493
+ output = {
1494
+ "state": state,
1495
+ "embed": out,
1496
+ "logit": self.classifier(out),
1497
+ "attn_weight": attn_weight
1498
+ }
1499
+ return output
1500
+
1501
+
1502
+ class TemporalBahAttnDecoder(BahAttnCatFcDecoder):
1503
+
1504
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
1505
+ dropout, d_model, **kwargs):
1506
+ """
1507
+ concatenate fc, attn, word to feed to the rnn
1508
+ """
1509
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
1510
+ dropout, d_model, **kwargs)
1511
+ self.temporal_embedding = nn.Embedding(4, emb_dim)
1512
+
1513
+ def forward(self, input_dict):
1514
+ word = input_dict["word"]
1515
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
1516
+ fc_embs = input_dict["fc_emb"]
1517
+ attn_embs = input_dict["attn_emb"]
1518
+ attn_emb_lens = input_dict["attn_emb_len"]
1519
+ temporal_tag = input_dict["temporal_tag"]
1520
+
1521
+ if input_dict["t"] == 0:
1522
+ embed = self.in_dropout(
1523
+ self.temporal_embedding(temporal_tag)).unsqueeze(1)
1524
+ elif word.size(-1) == self.fc_emb_dim: # fc_embs
1525
+ embed = word.unsqueeze(1)
1526
+ elif word.size(-1) == 1: # word
1527
+ word = word.to(fc_embs.device)
1528
+ embed = self.in_dropout(self.word_embedding(word))
1529
+ else:
1530
+ raise Exception(f"problem with word input size {word.size()}")
1531
+
1532
+ # embed: [N, 1, embed_size]
1533
+ if state is None:
1534
+ state = self.init_hidden(word.size(0), fc_embs.device)
1535
+ if self.rnn_type == "LSTM":
1536
+ query = state[0].transpose(0, 1).flatten(1)
1537
+ else:
1538
+ query = state.transpose(0, 1).flatten(1)
1539
+ c, attn_weight = self.attn(query, attn_embs, attn_emb_lens)
1540
+
1541
+ p_ctx = self.ctx_proj(c)
1542
+ p_fc_embs = self.fc_proj(fc_embs)
1543
+ p_ctx = self.ctx_proj(c)
1544
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_embs.unsqueeze(1)), dim=-1)
1545
+
1546
+ out, state = self.model(rnn_input, state)
1547
+
1548
+ output = {
1549
+ "state": state,
1550
+ "embed": out,
1551
+ "logit": self.classifier(out),
1552
+ "attn_weight": attn_weight
1553
+ }
1554
+ return output
1555
+
1556
+
1557
+ class Seq2SeqAttnModel(CaptionModel):
1558
+
1559
+ def __init__(self, encoder, decoder, **kwargs):
1560
+ if not hasattr(self, "compatible_decoders"):
1561
+ self.compatible_decoders = (
1562
+ BahAttnCatFcDecoder,
1563
+ )
1564
+ super().__init__(encoder, decoder, **kwargs)
1565
+
1566
+
1567
+ def seq_forward(self, input_dict):
1568
+ # Bahdanau attention only supports step-by-step implementation, so we implement forward in
1569
+ # step-by-step manner whether in training or evaluation
1570
+ return self.stepwise_forward(input_dict)
1571
+
1572
+ def prepare_output(self, input_dict):
1573
+ output = super().prepare_output(input_dict)
1574
+ attn_weight = torch.empty(output["seq"].size(0),
1575
+ input_dict["attn_emb"].size(1), output["seq"].size(1))
1576
+ output["attn_weight"] = attn_weight
1577
+ return output
1578
+
1579
+ def prepare_decoder_input(self, input_dict, output):
1580
+ decoder_input = {
1581
+ "fc_emb": input_dict["fc_emb"],
1582
+ "attn_emb": input_dict["attn_emb"],
1583
+ "attn_emb_len": input_dict["attn_emb_len"]
1584
+ }
1585
+ t = input_dict["t"]
1586
+ ###############
1587
+ # determine input word
1588
+ ################
1589
+ if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
1590
+ word = input_dict["cap"][:, t]
1591
+ else:
1592
+ if t == 0:
1593
+ word = torch.tensor([self.start_idx,] * input_dict["fc_emb"].size(0)).long()
1594
+ else:
1595
+ word = output["seq"][:, t-1]
1596
+ # word: [N,]
1597
+ decoder_input["word"] = word.unsqueeze(1)
1598
+
1599
+ ################
1600
+ # prepare rnn state
1601
+ ################
1602
+ if t > 0:
1603
+ decoder_input["state"] = output["state"]
1604
+ return decoder_input
1605
+
1606
+ def stepwise_process_step(self, output, output_t):
1607
+ super().stepwise_process_step(output, output_t)
1608
+ output["state"] = output_t["state"]
1609
+ t = output_t["t"]
1610
+ output["attn_weight"][:, :, t] = output_t["attn_weight"]
1611
+
1612
+ def prepare_beamsearch_output(self, input_dict):
1613
+ output = super().prepare_beamsearch_output(input_dict)
1614
+ beam_size = input_dict["beam_size"]
1615
+ max_length = input_dict["max_length"]
1616
+ output["attn_weight"] = torch.empty(beam_size,
1617
+ max(input_dict["attn_emb_len"]), max_length)
1618
+ return output
1619
+
1620
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
1621
+ decoder_input = {}
1622
+ t = input_dict["t"]
1623
+ i = input_dict["sample_idx"]
1624
+ beam_size = input_dict["beam_size"]
1625
+ ###############
1626
+ # prepare fc embeds
1627
+ ################
1628
+ if t == 0:
1629
+ fc_emb = repeat_tensor(input_dict["fc_emb"][i], beam_size)
1630
+ output_i["fc_emb"] = fc_emb
1631
+ decoder_input["fc_emb"] = output_i["fc_emb"]
1632
+
1633
+ ###############
1634
+ # prepare attn embeds
1635
+ ################
1636
+ if t == 0:
1637
+ attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
1638
+ attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
1639
+ output_i["attn_emb"] = attn_emb
1640
+ output_i["attn_emb_len"] = attn_emb_len
1641
+ decoder_input["attn_emb"] = output_i["attn_emb"]
1642
+ decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
1643
+
1644
+ ###############
1645
+ # determine input word
1646
+ ################
1647
+ if t == 0:
1648
+ word = torch.tensor([self.start_idx,] * beam_size).long()
1649
+ else:
1650
+ word = output_i["next_word"]
1651
+ decoder_input["word"] = word.unsqueeze(1)
1652
+
1653
+ ################
1654
+ # prepare rnn state
1655
+ ################
1656
+ if t > 0:
1657
+ if self.decoder.rnn_type == "LSTM":
1658
+ decoder_input["state"] = (output_i["state"][0][:, output_i["prev_words_beam"], :].contiguous(),
1659
+ output_i["state"][1][:, output_i["prev_words_beam"], :].contiguous())
1660
+ else:
1661
+ decoder_input["state"] = output_i["state"][:, output_i["prev_words_beam"], :].contiguous()
1662
+
1663
+ return decoder_input
1664
+
1665
+ def beamsearch_process_step(self, output_i, output_t):
1666
+ t = output_t["t"]
1667
+ output_i["state"] = output_t["state"]
1668
+ output_i["attn_weight"][..., t] = output_t["attn_weight"]
1669
+ output_i["attn_weight"] = output_i["attn_weight"][output_i["prev_words_beam"], ...]
1670
+
1671
+ def beamsearch_process(self, output, output_i, input_dict):
1672
+ super().beamsearch_process(output, output_i, input_dict)
1673
+ i = input_dict["sample_idx"]
1674
+ output["attn_weight"][i] = output_i["attn_weight"][0]
1675
+
1676
+ def prepare_dbs_decoder_input(self, input_dict, output_i):
1677
+ decoder_input = {}
1678
+ t = input_dict["t"]
1679
+ i = input_dict["sample_idx"]
1680
+ bdash = input_dict["bdash"]
1681
+ divm = input_dict["divm"]
1682
+
1683
+ local_time = t - divm
1684
+ ###############
1685
+ # prepare fc embeds
1686
+ ################
1687
+ # repeat only at the first timestep to save consumption
1688
+ if t == 0:
1689
+ fc_emb = repeat_tensor(input_dict["fc_emb"][i], bdash).unsqueeze(1)
1690
+ output_i["fc_emb"] = fc_emb
1691
+ decoder_input["fc_emb"] = output_i["fc_emb"]
1692
+
1693
+ ###############
1694
+ # prepare attn embeds
1695
+ ################
1696
+ if t == 0:
1697
+ attn_emb = repeat_tensor(input_dict["attn_emb"][i], bdash)
1698
+ attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], bdash)
1699
+ output_i["attn_emb"] = attn_emb
1700
+ output_i["attn_emb_len"] = attn_emb_len
1701
+ decoder_input["attn_emb"] = output_i["attn_emb"]
1702
+ decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
1703
+
1704
+ ###############
1705
+ # determine input word
1706
+ ################
1707
+ if local_time == 0:
1708
+ word = torch.tensor([self.start_idx,] * bdash).long()
1709
+ else:
1710
+ word = output_i["next_word"][divm]
1711
+ decoder_input["word"] = word.unsqueeze(1)
1712
+
1713
+ ################
1714
+ # prepare rnn state
1715
+ ################
1716
+ if local_time > 0:
1717
+ if self.decoder.rnn_type == "LSTM":
1718
+ decoder_input["state"] = (
1719
+ output_i["state"][0][divm][
1720
+ :, output_i["prev_words_beam"][divm], :].contiguous(),
1721
+ output_i["state"][1][divm][
1722
+ :, output_i["prev_words_beam"][divm], :].contiguous()
1723
+ )
1724
+ else:
1725
+ decoder_input["state"] = output_i["state"][divm][
1726
+ :, output_i["prev_words_beam"][divm], :].contiguous()
1727
+
1728
+ return decoder_input
1729
+
1730
+ def dbs_process_step(self, output_i, output_t):
1731
+ divm = output_t["divm"]
1732
+ output_i["state"][divm] = output_t["state"]
1733
+ # TODO attention weight
1734
+
1735
+
1736
+ class TemporalSeq2SeqAttnModel(Seq2SeqAttnModel):
1737
+
1738
+ def __init__(self, encoder, decoder, **kwargs):
1739
+ if not hasattr(self, "compatible_decoders"):
1740
+ self.compatible_decoders = (
1741
+ TemporalBahAttnDecoder,
1742
+ )
1743
+ super().__init__(encoder, decoder, **kwargs)
1744
+ self.train_forward_keys = ["cap", "cap_len", "ss_ratio", "temporal_tag"]
1745
+ self.inference_forward_keys = ["sample_method", "max_length", "temp", "temporal_tag"]
1746
+
1747
+
1748
+ def prepare_decoder_input(self, input_dict, output):
1749
+ decoder_input = super().prepare_decoder_input(input_dict, output)
1750
+ decoder_input["temporal_tag"] = input_dict["temporal_tag"]
1751
+ decoder_input["t"] = input_dict["t"]
1752
+
1753
+ return decoder_input
1754
+
1755
+
1756
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
1757
+ decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
1758
+ t = input_dict["t"]
1759
+ i = input_dict["sample_idx"]
1760
+ beam_size = input_dict["beam_size"]
1761
+ ###############
1762
+ # prepare temporal_tag
1763
+ ################
1764
+ if t == 0:
1765
+ temporal_tag = repeat_tensor(input_dict["temporal_tag"][i], beam_size)
1766
+ output_i["temporal_tag"] = temporal_tag
1767
+ decoder_input["temporal_tag"] = output_i["temporal_tag"]
1768
+ decoder_input["t"] = input_dict["t"]
1769
+
1770
+ return decoder_input
1771
+
1772
+ def prepare_dbs_decoder_input(self, input_dict, output_i):
1773
+ decoder_input = super.prepare_dbs_decoder_input(input_dict, output_i)
1774
+ t = input_dict["t"]
1775
+ i = input_dict["sample_idx"]
1776
+ bdash = input_dict["bdash"]
1777
+
1778
+ ###############
1779
+ # prepare temporal tag
1780
+ ################
1781
+ # repeat only at the first timestep to save consumption
1782
+ if t == 0:
1783
+ temporal_tag = repeat_tensor(input_dict["temporal_tag"][i], bdash)
1784
+ output_i["temporal_tag"] = temporal_tag
1785
+ decoder_input["temporal_tag"] = output_i["temporal_tag"]
1786
+ decoder_input["t"] = input_dict["t"]
1787
+
1788
+ return decoder_input
1789
+
1790
+
1791
+ class Cnn8rnnSedModel(nn.Module):
1792
+ def __init__(self, classes_num):
1793
+
1794
+ super().__init__()
1795
+
1796
+ self.time_resolution = 0.01
1797
+ self.interpolate_ratio = 4 # Downsampled ratio
1798
+
1799
+ self.bn0 = nn.BatchNorm2d(64)
1800
+
1801
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
1802
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
1803
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
1804
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
1805
+
1806
+ self.fc1 = nn.Linear(512, 512, bias=True)
1807
+ self.rnn = nn.GRU(512, 256, bidirectional=True, batch_first=True)
1808
+ self.fc_audioset = nn.Linear(512, classes_num, bias=True)
1809
+
1810
+ def forward(self, lms):
1811
+ output = self.forward_prob(lms)
1812
+ framewise_output = output["framewise_output"].cpu().numpy()
1813
+ thresholded_predictions = double_threshold(
1814
+ framewise_output, 0.75, 0.25)
1815
+ decoded_tags = decode_with_timestamps(
1816
+ thresholded_predictions, self.time_resolution
1817
+ )
1818
+ return decoded_tags
1819
+
1820
+ def forward_prob(self, lms):
1821
+ """
1822
+ lms: (batch_size, mel_bins, time_steps)"""
1823
+
1824
+ x = lms
1825
+ x = x.transpose(1, 2)
1826
+ x = x.unsqueeze(1)
1827
+
1828
+ frames_num = x.shape[2]
1829
+
1830
+ x = x.transpose(1, 3)
1831
+ x = self.bn0(x)
1832
+ x = x.transpose(1, 3)
1833
+
1834
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg+max')
1835
+ x = F.dropout(x, p=0.2, training=self.training)
1836
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg+max')
1837
+ x = F.dropout(x, p=0.2, training=self.training)
1838
+ x = self.conv_block3(x, pool_size=(1, 2), pool_type='avg+max')
1839
+ x = F.dropout(x, p=0.2, training=self.training)
1840
+ x = self.conv_block4(x, pool_size=(1, 2), pool_type='avg+max')
1841
+ x = F.dropout(x, p=0.2, training=self.training) # (batch_size, 256, time_steps / 4, mel_bins / 16)
1842
+ x = torch.mean(x, dim=3)
1843
+
1844
+ x = x.transpose(1, 2)
1845
+ x = F.dropout(x, p=0.5, training=self.training)
1846
+ x = F.relu_(self.fc1(x))
1847
+ x, _ = self.rnn(x)
1848
+ segmentwise_output = torch.sigmoid(self.fc_audioset(x)).clamp(1e-7, 1.)
1849
+
1850
+ framewise_output = interpolate(segmentwise_output,
1851
+ self.interpolate_ratio)
1852
+ framewise_output = pad_framewise_output(framewise_output, frames_num)
1853
+
1854
+ output_dict = {
1855
+ "segmentwise_output": segmentwise_output,
1856
+ 'framewise_output': framewise_output,
1857
+ }
1858
+
1859
+ return output_dict
1860
+
1861
+
1862
+ class Cnn14RnnTempAttnGruConfig(PretrainedConfig):
1863
+
1864
+ def __init__(
1865
+ self,
1866
+ sample_rate: int = 32000,
1867
+ encoder_rnn_bidirectional: bool = True,
1868
+ encoder_rnn_hidden_size: int = 256,
1869
+ encoder_rnn_dropout: float = 0.5,
1870
+ encoder_rnn_num_layers: int = 3,
1871
+ decoder_emb_dim: int = 512,
1872
+ vocab_size: int = 4981,
1873
+ fc_emb_dim: int = 512,
1874
+ attn_emb_dim: int = 512,
1875
+ decoder_rnn_type: str = "GRU",
1876
+ decoder_num_layers: int = 1,
1877
+ decoder_d_model: int = 512,
1878
+ decoder_dropout: float = 0.5,
1879
+ **kwargs
1880
+ ):
1881
+ self.sample_rate = sample_rate
1882
+ self.encoder_rnn_bidirectional = encoder_rnn_bidirectional
1883
+ self.encoder_rnn_hidden_size = encoder_rnn_hidden_size
1884
+ self.encoder_rnn_dropout = encoder_rnn_dropout
1885
+ self.encoder_rnn_num_layers = encoder_rnn_num_layers
1886
+ self.decoder_emb_dim = decoder_emb_dim
1887
+ self.vocab_size = vocab_size
1888
+ self.fc_emb_dim = fc_emb_dim
1889
+ self.attn_emb_dim = attn_emb_dim
1890
+ self.decoder_rnn_type = decoder_rnn_type
1891
+ self.decoder_num_layers = decoder_num_layers
1892
+ self.decoder_d_model = decoder_d_model
1893
+ self.decoder_dropout = decoder_dropout
1894
+ super().__init__(**kwargs)
1895
+
1896
+
1897
+ class Cnn14RnnTempAttnGruModel(PreTrainedModel):
1898
+ config_class = Cnn14RnnTempAttnGruConfig
1899
+
1900
+ def __init__(self, config):
1901
+ super().__init__(config)
1902
+ sample_rate = config.sample_rate
1903
+ sr_to_fmax = {
1904
+ 32000: 14000,
1905
+ 16000: 8000
1906
+ }
1907
+ self.melspec_extractor = transforms.MelSpectrogram(
1908
+ sample_rate=sample_rate,
1909
+ n_fft=32 * sample_rate // 1000,
1910
+ win_length=32 * sample_rate // 1000,
1911
+ hop_length=10 * sample_rate // 1000,
1912
+ f_min=50,
1913
+ f_max=sr_to_fmax[sample_rate],
1914
+ n_mels=64,
1915
+ norm="slaney",
1916
+ mel_scale="slaney"
1917
+ )
1918
+ self.db_transform = transforms.AmplitudeToDB()
1919
+
1920
+ encoder = Cnn14RnnEncoder(
1921
+ sample_rate=config.sample_rate,
1922
+ rnn_bidirectional=config.encoder_rnn_bidirectional,
1923
+ rnn_hidden_size=config.encoder_rnn_hidden_size,
1924
+ rnn_dropout=config.encoder_rnn_dropout,
1925
+ rnn_num_layers=config.encoder_rnn_num_layers
1926
+ )
1927
+ decoder = TemporalBahAttnDecoder(
1928
+ emb_dim=config.decoder_emb_dim,
1929
+ vocab_size=config.vocab_size,
1930
+ fc_emb_dim=config.fc_emb_dim,
1931
+ attn_emb_dim=config.attn_emb_dim,
1932
+ rnn_type=config.decoder_rnn_type,
1933
+ num_layers=config.decoder_num_layers,
1934
+ d_model=config.decoder_d_model,
1935
+ dropout=config.decoder_dropout,
1936
+ )
1937
+ cap_model = TemporalSeq2SeqAttnModel(encoder, decoder)
1938
+ sed_model = Cnn8rnnSedModel(classes_num=447)
1939
+ self.cap_model = cap_model
1940
+ self.sed_model = sed_model
1941
+
1942
+ def forward(self,
1943
+ audio: torch.Tensor,
1944
+ audio_length: Union[List, np.ndarray, torch.Tensor],
1945
+ temporal_tag: Union[List, np.ndarray, torch.Tensor] = None,
1946
+ sample_method: str = "beam",
1947
+ beam_size: int = 3,
1948
+ max_length: int = 20,
1949
+ temp: float = 1.0,):
1950
+ device = self.device
1951
+ mel_spec = self.melspec_extractor(audio.to(device))
1952
+ log_mel_spec = self.db_transform(mel_spec)
1953
+
1954
+ sed_tag = self.sed_model(log_mel_spec)
1955
+ sed_tag = torch.as_tensor(sed_tag).to(device)
1956
+ if temporal_tag is not None:
1957
+ temporal_tag = torch.as_tensor(temporal_tag).to(device)
1958
+ temporal_tag = torch.stack([temporal_tag, sed_tag], dim=0)
1959
+ temporal_tag = torch.min(temporal_tag, dim=0).values
1960
+ else:
1961
+ temporal_tag = sed_tag
1962
+
1963
+ input_dict = {
1964
+ "lms": log_mel_spec,
1965
+ "wav_len": audio_length,
1966
+ "temporal_tag": temporal_tag,
1967
+ "mode": "inference",
1968
+ "sample_method": sample_method,
1969
+ "max_length": max_length,
1970
+ "temp": temp,
1971
+ }
1972
+ if sample_method == "beam":
1973
+ input_dict["beam_size"] = beam_size
1974
+ return self.cap_model(input_dict)["seq"].cpu()
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b24b19128b094a9fc5589df15486289a9d7812bd3e103b02ec48b7cce0f57a65
3
- size 54766265
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49356c485fbafb022c03b351009a4634d91bd89eee01baf3c31e82e836b0fa3a
3
+ size 54696313