|
from typing import Dict, Callable, Union, List |
|
import random |
|
import math |
|
import sys |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence |
|
from torchaudio import transforms |
|
from efficientnet_pytorch import EfficientNet |
|
from efficientnet_pytorch import utils as efficientnet_utils |
|
from einops import rearrange, reduce |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
|
def sort_pack_padded_sequence(input, lengths):
    """Sort a padded, batch-first batch by decreasing length and pack it.

    Args:
        input: batch-first padded tensor (batch, time, ...).
        lengths: 1-D tensor with the valid length of every sample.

    Returns:
        (packed, inv_ix): the ``PackedSequence`` built from the length-sorted
        batch, and the permutation that restores the original order, i.e.
        ``sorted_batch[inv_ix]`` is the batch in its original order.
    """
    sorted_lengths, order = torch.sort(lengths, descending=True)
    packed = pack_padded_sequence(input[order], sorted_lengths.cpu(), batch_first=True)
    # Sorting a permutation yields its inverse: inverse_order[order[j]] == j.
    inverse_order = torch.argsort(order)
    return packed, inverse_order
|
|
|
def pad_unsort_packed_sequence(input, inv_ix):
    """Unpack a ``PackedSequence`` to a batch-first padded tensor and reorder it.

    Args:
        input: packed sequence whose samples are in length-sorted order.
        inv_ix: permutation (as returned by ``sort_pack_padded_sequence``)
            that maps the sorted batch back to the original sample order.

    Returns:
        Padded tensor (batch, time, ...) in the original sample order.
    """
    padded = pad_packed_sequence(input, batch_first=True)[0]
    return padded[inv_ix]
|
|
|
def pack_wrapper(module, attn_feats, attn_feat_lens):
    """Apply ``module`` to a padded batch while skipping padded positions.

    RNN modules (``torch.nn.RNNBase``) receive the ``PackedSequence`` directly
    and their first output is used. Any other module is applied to the packed
    ``data`` tensor (valid time steps of all samples stacked together), and
    the result is re-wrapped with the original ``batch_sizes``.

    Returns:
        Batch-first padded output in the original sample order.
    """
    packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
    if isinstance(module, torch.nn.RNNBase):
        packed_out = module(packed)[0]
    else:
        packed_out = PackedSequence(module(packed.data), packed.batch_sizes)
    return pad_unsort_packed_sequence(packed_out, inv_ix)
|
|
|
def embedding_pooling(x, lens, pooling="mean"):
    """Pool a padded sequence of embeddings into one vector per sample.

    Args:
        x: padded tensor (batch, time, dim).
        lens: valid length per sample (tensor or sequence of ints).
        pooling: one of
            - "max": length-masked max over time,
            - "mean": length-masked mean over time,
            - "mean+max": sum of the two above,
            - "last": the embedding at position ``lens - 1``.

    Returns:
        Tensor (batch, dim).

    Raises:
        ValueError: for an unknown ``pooling`` (a subclass of ``Exception``,
            so existing ``except Exception`` handlers still apply).
    """
    if pooling == "max":
        return max_with_lens(x, lens)
    if pooling == "mean":
        return mean_with_lens(x, lens)
    if pooling == "mean+max":
        x_mean = mean_with_lens(x, lens)
        x_max = max_with_lens(x, lens)
        return x_mean + x_max
    if pooling == "last":
        # Accept lists / CPU tensors: gather needs a long index on x's device.
        lens = torch.as_tensor(lens, device=x.device).long()
        # Index of the last valid step for each sample, repeated over dim.
        indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
        return torch.gather(x, 1, indices).squeeze(1)
    raise ValueError(f"pooling method {pooling} not support")
|
|
|
def interpolate(x, ratio):
    """Upsample along time by repeating every frame ``ratio`` times.

    Used to compensate for the temporal resolution lost to downsampling
    inside a CNN.

    Args:
        x: tensor (batch_size, time_steps, classes_num).
        ratio: int, number of copies of each time step.

    Returns:
        Tensor (batch_size, time_steps * ratio, classes_num), where output
        steps ``k * ratio ... k * ratio + ratio - 1`` all equal input step k.
    """
    # Unpacking keeps the original requirement that x is exactly 3-D.
    _batch, _time, _classes = x.shape
    return torch.repeat_interleave(x, ratio, dim=1)
|
|
|
def pad_framewise_output(framewise_output, frames_num):
    """Extend framewise output to ``frames_num`` frames by repeating the last frame.

    Args:
        framewise_output: tensor (batch_size, n, classes_num).
        frames_num: int, target number of frames (expected to be >= n).

    Returns:
        Tensor (batch_size, frames_num, classes_num).
    """
    missing = frames_num - framewise_output.shape[1]
    # Final frame, kept 3-D so it can be tiled along the time axis.
    last_frame = framewise_output[:, -1:, :]
    padding = last_frame.repeat(1, missing, 1)
    return torch.cat((framewise_output, padding), dim=1)
|
|
|
def find_contiguous_regions(activity_array):
    """Find contiguous active regions in a 1-D boolean (or 0/1) array.

    Adapted from dcase_util's DecisionEncoder:
    https://dcase-repo.github.io/dcase_util/_modules/dcase_util/data/decisions.html#DecisionEncoder

    It is copied rather than imported because:
      1. it does not need to belong to a class, and
      2. importing DecisionEncoder pulls in sndfile and other dependencies,
         which causes problems on some clusters.

    Args:
        activity_array: 1-D array-like of truthy/falsy activity values.

    Returns:
        Integer array of shape (n_regions, 2); each row is ``[start, end)``
        with ``end`` exclusive. Empty input yields shape (0, 2).
    """
    activity_array = np.asarray(activity_array)
    if activity_array.size == 0:
        # activity_array[0] below would raise IndexError on empty input.
        return np.empty((0, 2), dtype=np.intp)

    # Positions where the value differs from the previous one.
    change_indices = np.logical_xor(activity_array[1:],
                                    activity_array[:-1]).nonzero()[0]
    # Shift so each index points at the first element after the change.
    change_indices += 1

    if activity_array[0]:
        # Active at the very start: open a region at index 0.
        change_indices = np.r_[0, change_indices]

    if activity_array[-1]:
        # Still active at the end: close the last region at the array size.
        change_indices = np.r_[change_indices, activity_array.size]

    # Alternating start/end indices -> (start, end) pairs.
    return change_indices.reshape((-1, 2))
|
|
|
def double_threshold(x, high_thres, low_thres, n_connect=1):
    """Apply ``_double_threshold`` to every 1-D slice of an up-to-3-D array.

    For 3-D input the slices run along axis 1; for 1-D/2-D input they run
    along axis 0.

    :param x: input array (ndim <= 3)
    :param high_thres: high threshold value
    :param low_thres: low threshold value
    :param n_connect: regions at a distance <= n_connect are merged
    :return: array of the same shape as ``x`` filled with 0/1
    """
    assert x.ndim <= 3, "Whoops something went wrong with the input ({}), check if its <= 3 dims".format(
        x.shape)
    apply_dim = 1 if x.ndim == 3 else 0

    def _threshold_slice(column):
        return _double_threshold(column, high_thres, low_thres, n_connect=n_connect)

    return np.apply_along_axis(_threshold_slice, axis=apply_dim, arr=x)
|
|
|
def _double_threshold(x, high_thres, low_thres, n_connect=1, return_arr=True):
    """Hysteresis (double) thresholding of a 1-D array.

    Regions where ``x > low_thres`` are kept only if they contain at least one
    sample with ``x > high_thres``. Kept regions whose gap is <= ``n_connect``
    are then merged (see ``connect_``).

    :param x: input array, must be 1-D
    :param high_thres: high threshold over the array
    :param low_thres: low threshold over the array
    :param n_connect: maximal gap between regions that are merged
    :param return_arr: if True (the default), return a 0/1 integer array the
        same size as ``x``; if False, return the list of ``(start, end)``
        region pairs (``end`` exclusive)
    """
    assert x.ndim == 1, "Input needs to be 1d"
    high_locations = np.where(x > high_thres)[0]
    candidate_regions = find_contiguous_regions(x > low_thres)

    kept = [
        region for region in candidate_regions
        if ((region[0] <= high_locations) & (high_locations <= region[1])).any()
    ]
    kept = connect_(kept, n_connect)

    if not return_arr:
        return kept

    zero_one_arr = np.zeros_like(x, dtype=int)
    for start, end in kept:
        zero_one_arr[start:end] = 1
    return zero_one_arr
|
|
|
def connect_(pairs, n=1):
    """Merge consecutive ``(start, end)`` clusters whose gap is <= n.

    The gap is ``next_start - previous_end``. Pairs are assumed to be sorted.

    :param pairs: sequence of pairs, e.g. [(1, 5), (7, 10)]
    :param n: maximal gap between two clusters that are still merged
    :return: list of merged ``(start, end)`` tuples
    """
    if len(pairs) == 0:
        return []
    merged = []
    run_start, run_end = pairs[0]
    for prev, cur in zip(pairs, pairs[1:]):
        if cur[0] - prev[1] > n:
            # Gap too large: close the current run and start a new one.
            merged.append((run_start, prev[1]))
            run_start = cur[0]
        run_end = cur[1]
    merged.append((run_start, run_end))
    return merged
|
|
|
def segments_to_temporal_tag(segments, thre=0.5):
    """Summarise the temporal relation between segments of different classes.

    Each segment is indexed as ``(class, onset, offset)``. For every ordered
    pair of segments with different classes, let
    ``overlap = offset_a - onset_b`` and ``margin = thre * shorter_duration``:

    * ``overlap < margin`` sets the "after" flag (worth 2);
    * ``onset_a < onset_b`` and ``overlap > margin`` sets the "while" flag
      (worth 1).

    Returns:
        int in {0, 1, 2, 3}: the sum of the two flags.
    """
    after_flag = 0
    while_flag = 0
    for seg_a in segments:
        for seg_b in segments:
            if seg_a[0] == seg_b[0]:
                continue
            shorter = min(seg_a[2] - seg_a[1], seg_b[2] - seg_b[1])
            overlap = seg_a[2] - seg_b[1]
            margin = thre * shorter
            if overlap < margin:
                after_flag = 2
            if overlap > margin and seg_a[1] < seg_b[1]:
                while_flag = 1
    return after_flag + while_flag
|
|
|
def decode_with_timestamps(labels, time_resolution):
    """Turn per-sample binary label matrices into temporal-relation tags.

    Args:
        labels: iterable of per-sample 2-D arrays; column ``i`` (``lab.T[i]``)
            is the activity of class ``i`` over time.
        time_resolution: multiplier converting frame indices to time
            (presumably seconds per frame — confirm with the caller).

    Returns:
        list with one ``segments_to_temporal_tag`` result per sample.
    """
    batch_results = []
    for sample_labels in labels:
        segments = [
            (class_idx, onset * time_resolution, offset * time_resolution)
            for class_idx, column in enumerate(sample_labels.T)
            for onset, offset in find_contiguous_regions(column)
        ]
        batch_results.append(segments_to_temporal_tag(segments))
    return batch_results
|
|
|
class _EffiNet(nn.Module):
    """Thin wrapper around ``EfficientNet`` producing a feature sequence.

    Input is a 3-D tensor labelled ``b f t``; it is given a singleton channel
    axis, run through ``extract_features``, and the resulting map is averaged
    over the axis labelled ``f``, giving ``(b, t, c)``.
    """

    def __init__(self, blocks_args=None, global_params=None) -> None:
        super().__init__()
        self.eff_net = EfficientNet(blocks_args=blocks_args, global_params=global_params)

    def forward(self, x: torch.Tensor):
        image = rearrange(x, 'b f t -> b 1 f t')
        feature_map = self.eff_net.extract_features(image)
        return reduce(feature_map, 'b c f t -> b t c', 'mean')
|
|
|
|
|
def get_effb2_model() -> _EffiNet:
    """Build an EfficientNet-B2 feature extractor (no classification head).

    The stem is changed to accept a single input channel, matching the
    single-channel spectrogram that ``_EffiNet.forward`` feeds it.
    """
    model_params = efficientnet_utils.get_model_params(
        'efficientnet-b2', {'include_top': False})
    blocks_args, global_params = model_params
    effnet = _EffiNet(blocks_args=blocks_args, global_params=global_params)
    effnet.eff_net._change_in_channels(1)
    return effnet
|
|
|
def merge_load_state_dict(state_dict,
                          model: torch.nn.Module,
                          output_fn: Callable = sys.stdout.write):
    """Load the compatible subset of ``state_dict`` into ``model``.

    An entry is used only if its key exists in the model and its shape
    matches; everything else is reported via ``output_fn`` and skipped.
    Model entries absent from ``state_dict`` keep their current values.

    Returns:
        The keys that were loaded (a dict keys view).
    """
    model_dict = model.state_dict()
    compatible = {
        key: value for key, value in state_dict.items()
        if key in model_dict and model_dict[key].shape == value.shape
    }
    mismatch_keys = [key for key in state_dict if key not in compatible]
    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}\n")
    model_dict.update(compatible)
    model.load_state_dict(model_dict, strict=True)
    return compatible.keys()
|
|
|
|
|
class EfficientNetB2(nn.Module):
    """EfficientNet-B2 audio encoder operating on 16 kHz waveforms.

    The waveform is turned into a mel spectrogram, converted to dB, run
    through a single-input-channel EfficientNet-B2 backbone, and returned as
    a sequence of embeddings plus their length-masked mean.

    Args:
        n_mels: number of mel bins.
        win_length: analysis window (also used as FFT size) in milliseconds.
        hop_length: hop between frames in milliseconds.
        f_min: lowest mel filter frequency.
        freeze: if True, disable gradients for every parameter.
    """

    def __init__(self,
                 n_mels: int = 64,
                 win_length: int = 32,
                 hop_length: int = 10,
                 f_min: int = 0,
                 freeze: bool = False,):
        super().__init__()
        sample_rate = 16000
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=win_length * sample_rate // 1000,
            win_length=win_length * sample_rate // 1000,
            hop_length=hop_length * sample_rate // 1000,
            f_min=f_min,
            n_mels=n_mels,
        )
        # Hop in samples. It must equal the hop used by melspec_extractor so
        # that feat_length in forward() counts the frames actually produced;
        # hard-coding 10 ms here broke that for any other hop_length.
        self.hop_length = hop_length * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB(top_db=120)
        self.backbone = get_effb2_model()
        self.fc_emb_size = self.backbone.eff_net._conv_head.out_channels
        # Divisor mapping spectrogram frame counts to backbone output lengths
        # (assumed to be the backbone's total time downsampling).
        self.downsample_ratio = 32
        if freeze:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input_dict):
        """Encode a batch of waveforms.

        Args:
            input_dict: dict with ``"wav"`` (waveform batch), ``"wav_len"``
                (valid length of each waveform in samples) and optionally
                ``"specaug"`` (bool, defaults to False).

        Returns:
            dict with ``"fc_emb"`` (length-masked mean of ``attn_emb`` over
            time), ``"attn_emb"`` (backbone output sequence) and
            ``"attn_emb_len"`` (valid steps of ``attn_emb`` per sample).
        """
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        # Optional so pure inference callers need not pass it.
        specaug = input_dict.get("specaug", False)
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)

        x = rearrange(x, 'b f t -> b 1 t f')
        if self.training and specaug:
            # NOTE(review): __init__ never creates ``spec_augmenter``; it must
            # be attached externally, otherwise this raises AttributeError.
            x = self.spec_augmenter(x)
        x = rearrange(x, 'b 1 t f -> b f t')

        x = self.backbone(x)
        attn_emb = x

        wave_length = torch.as_tensor(wave_length)
        # Spectrogram frame count (wav_len // hop + 1), then the backbone's
        # downsampling.
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        fc_emb = mean_with_lens(attn_emb, feat_length)

        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
        return output_dict
|
|
|
|
|
def generate_length_mask(lens, max_length=None):
    """Build a boolean padding mask from sequence lengths.

    Args:
        lens: 1-D tensor or sequence of lengths, one per sample.
        max_length: width of the mask (int or 0-dim tensor); defaults to
            ``max(lens)`` (0 for an empty ``lens``).

    Returns:
        Bool tensor (N, max_length) on ``lens``' device with
        ``mask[i, j] = j < lens[i]``.
    """
    lens = torch.as_tensor(lens)
    if max_length is None:
        # Tensor reduction instead of Python's max(), which iterates the
        # tensor element by element (one host sync per element on GPU).
        max_length = lens.max() if lens.numel() > 0 else 0
    if isinstance(max_length, torch.Tensor):
        max_length = max_length.item()
    positions = torch.arange(max_length, device=lens.device)
    # (1, max_length) vs (N, 1) broadcasts to (N, max_length).
    return positions.unsqueeze(0) < lens.view(-1, 1)
|
|
|
def mean_with_lens(features, lens):
    """Mean over the length axis, ignoring padded steps.

    features: [N, T, ...] (the second dimension is the length axis)
    lens: [N,] number of valid steps per sample (tensor or sequence)

    Returns a [N, ...] tensor. A sample with length 0 yields 0 / 0.
    """
    lens = torch.as_tensor(lens)
    # The mask always spans the full time axis of ``features``: whether or
    # not max(lens) equals T, the original branches built the same mask.
    mask = generate_length_mask(lens, features.size(1)).to(features.device)
    while mask.ndim < features.ndim:
        mask = mask.unsqueeze(-1)
    feature_mean = (features * mask).sum(1)
    # Give lens trailing singleton dims so it broadcasts over feature dims.
    while lens.ndim < feature_mean.ndim:
        lens = lens.unsqueeze(1)
    return feature_mean / lens.to(features.device)
|
|
|
def max_with_lens(features, lens):
    """Max over the length axis, ignoring padded steps.

    features: [N, T, ...] (the second dimension is the length axis)
    lens: [N,] number of valid steps per sample (tensor or sequence)

    Returns a [N, ...] tensor. Padded positions are set to -inf before the
    max, so a sample with length 0 yields -inf.
    """
    lens = torch.as_tensor(lens)
    # The mask always spans the full time axis of ``features`` (both
    # original branches produced exactly this mask).
    mask = generate_length_mask(lens, features.size(1)).to(features.device)
    feature_max = features.clone()
    feature_max[~mask] = float("-inf")
    feature_max, _ = feature_max.max(1)
    return feature_max
|
|
|
def repeat_tensor(x, n):
    """Stack ``n`` copies of ``x`` along a new leading axis: shape (n, *x.shape)."""
    repeats = (n,) + (1,) * x.dim()
    return x.unsqueeze(0).repeat(*repeats)
|
|
|
|
|
class CaptionMetaMixin:
    """Class-level special-token indices and default decoding length."""

    pad_idx = 0
    start_idx = 1
    end_idx = 2
    max_length = 20

    @classmethod
    def set_index(cls, start_idx, end_idx, pad_idx):
        """Set the start/end/padding token indices on ``cls`` (class-wide)."""
        cls.pad_idx = pad_idx
        cls.start_idx = start_idx
        cls.end_idx = end_idx
|
|
|
|
|
class CaptionModel(nn.Module, CaptionMetaMixin):
    """
    Encoder-decoder captioning model.

    The encoder maps ``input_dict`` to embeddings (``fc_emb``, ``attn_emb``,
    ...). The decoder is then driven with teacher forcing / scheduled
    sampling (``mode="train"``) or autoregressively (``mode="inference"``)
    by greedy/sampling decoding, beam search or diverse beam search.

    Subclasses must set ``self.compatible_decoders`` (a tuple of decoder
    classes) before calling ``super().__init__`` and implement
    ``seq_forward``, ``prepare_decoder_input``,
    ``prepare_beamsearch_decoder_input`` and, for diverse beam search,
    ``prepare_dbs_decoder_input``.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = decoder.vocab_size
        # Keys copied from the caller's input_dict for each mode.
        self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
        self.inference_forward_keys = ["sample_method", "max_length", "temp"]
        freeze_encoder = kwargs.get("freeze_encoder", False)
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
        self.check_decoder_compatibility()

    def check_decoder_compatibility(self):
        """Assert that ``self.decoder`` is one of ``self.compatible_decoders``."""
        # Entries are classes: use their own __name__ (x.__class__.__name__
        # would be the metaclass name, "type").
        compatible_decoders = [x.__name__ for x in self.compatible_decoders]
        assert isinstance(self.decoder, self.compatible_decoders), \
            f"{self.decoder.__class__.__name__} is incompatible with " \
            f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "

    def forward(self, input_dict: Dict):
        """
        input_dict: {
            (required)
            mode: train/inference,
            [spec, spec_len],
            [fc],
            [attn, attn_len],
            [wav, wav_len],
            [sample_method: greedy],
            [temp: 1.0] (in case of no teacher forcing)

            (optional, mode=train)
            cap,
            cap_len,
            ss_ratio,

            (optional, mode=inference)
            sample_method: greedy/beam/dbs/gumbel/top<K or P>,
            max_length,
            temp,
            beam_size (optional, sample_method=beam/dbs),
            n_best (optional, sample_method=beam),
        }
        """
        encoder_output_dict = self.encoder(input_dict)
        output = self.forward_decoder(input_dict, encoder_output_dict)
        return output

    def forward_decoder(self, input_dict: Dict, encoder_output_dict: Dict):
        """Build the decoding dict for the requested mode and run the decoder.

        The encoder outputs are merged into the decoding dict and also into
        the returned output dict.
        """
        if input_dict["mode"] == "train":
            forward_dict = {
                "mode": "train", "sample_method": "greedy", "temp": 1.0
            }
            for key in self.train_forward_keys:
                forward_dict[key] = input_dict[key]
            forward_dict.update(encoder_output_dict)
            output = self.train_forward(forward_dict)
        elif input_dict["mode"] == "inference":
            forward_dict = {"mode": "inference"}
            default_args = {"sample_method": "greedy", "max_length": self.max_length, "temp": 1.0}
            for key in self.inference_forward_keys:
                if key in input_dict:
                    forward_dict[key] = input_dict[key]
                else:
                    forward_dict[key] = default_args[key]

            if forward_dict["sample_method"] == "beam":
                forward_dict["beam_size"] = input_dict.get("beam_size", 3)
                forward_dict["n_best"] = input_dict.get("n_best", False)
                forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"])
            elif forward_dict["sample_method"] == "dbs":
                forward_dict["beam_size"] = input_dict.get("beam_size", 6)
                forward_dict["group_size"] = input_dict.get("group_size", 3)
                forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5)
                forward_dict["group_nbest"] = input_dict.get("group_nbest", True)

            forward_dict.update(encoder_output_dict)
            output = self.inference_forward(forward_dict)
        else:
            raise Exception("mode should be either 'train' or 'inference'")
        output.update(encoder_output_dict)
        return output

    def prepare_output(self, input_dict):
        """Allocate the buffers filled by stepwise decoding.

        ``seq`` (filled with ``end_idx``) and ``sampled_logprob`` live on the
        CPU; ``logit`` and ``embed`` live on the device of ``fc_emb``.
        """
        output = {}
        batch_size = input_dict["fc_emb"].size(0)
        if input_dict["mode"] == "train":
            # One decoding step fewer than the caption length.
            max_length = input_dict["cap"].size(1) - 1
        elif input_dict["mode"] == "inference":
            max_length = input_dict["max_length"]
        else:
            raise Exception("mode should be either 'train' or 'inference'")
        device = input_dict["fc_emb"].device
        output["seq"] = torch.full((batch_size, max_length), self.end_idx,
                                   dtype=torch.long)
        output["logit"] = torch.empty(batch_size, max_length,
                                      self.vocab_size, device=device)
        output["sampled_logprob"] = torch.zeros(batch_size, max_length)
        output["embed"] = torch.empty(batch_size, max_length,
                                      self.decoder.d_model, device=device)
        return output

    def train_forward(self, input_dict):
        """Teacher forcing when ``ss_ratio == 1``; otherwise stepwise scheduled sampling."""
        if input_dict["ss_ratio"] != 1:
            input_dict["mode"] = "train"
            return self.stepwise_forward(input_dict)
        output = self.seq_forward(input_dict)
        self.train_process(output, input_dict)
        return output

    def seq_forward(self, input_dict):
        """Decode the whole teacher-forced caption at once (subclass hook)."""
        raise NotImplementedError

    def train_process(self, output, input_dict):
        """Hook for post-processing ``seq_forward`` output; no-op by default."""
        pass

    def inference_forward(self, input_dict):
        """Dispatch to beam search, diverse beam search or stepwise decoding."""
        if input_dict["sample_method"] == "beam":
            return self.beam_search(input_dict)
        elif input_dict["sample_method"] == "dbs":
            return self.diverse_beam_search(input_dict)
        return self.stepwise_forward(input_dict)

    def stepwise_forward(self, input_dict):
        """Step-by-step decoding.

        In inference mode, once a sequence emits ``end_idx`` every later
        position of it is forced to ``end_idx``, and decoding stops as soon
        as all sequences have finished.
        """
        output = self.prepare_output(input_dict)
        max_length = output["seq"].size(1)

        for t in range(max_length):
            input_dict["t"] = t
            self.decode_step(input_dict, output)
            if input_dict["mode"] == "inference":
                unfinished_t = output["seq"][:, t] != self.end_idx
                if t == 0:
                    unfinished = unfinished_t
                else:
                    unfinished *= unfinished_t
                output["seq"][:, t][~unfinished] = self.end_idx
                if unfinished.sum() == 0:
                    break
        self.stepwise_process(output)
        return output

    def decode_step(self, input_dict, output):
        """Decoding operation of timestep t"""
        decoder_input = self.prepare_decoder_input(input_dict, output)

        output_t = self.decoder(decoder_input)
        logit_t = output_t["logit"]

        # The decoder may return the whole prefix; keep only the last step.
        if logit_t.size(1) == 1:
            logit_t = logit_t.squeeze(1)
            embed_t = output_t["embed"].squeeze(1)
        elif logit_t.size(1) > 1:
            logit_t = logit_t[:, -1, :]
            embed_t = output_t["embed"][:, -1, :]
        else:
            raise Exception("no logit output")

        sampled = self.sample_next_word(logit_t,
                                        method=input_dict["sample_method"],
                                        temp=input_dict["temp"])

        output_t.update(sampled)
        output_t["t"] = input_dict["t"]
        output_t["logit"] = logit_t
        output_t["embed"] = embed_t
        self.stepwise_process_step(output, output_t)

    def prepare_decoder_input(self, input_dict, output):
        """Prepare the input dict for the decoder"""
        raise NotImplementedError

    def stepwise_process_step(self, output, output_t):
        """Postprocessing (save output values) after each timestep t"""
        t = output_t["t"]
        output["logit"][:, t, :] = output_t["logit"]
        output["seq"][:, t] = output_t["word"]
        output["sampled_logprob"][:, t] = output_t["probs"]
        output["embed"][:, t, :] = output_t["embed"]

    def stepwise_process(self, output):
        """Postprocessing after the whole step-by-step autoregressive decoding"""
        pass

    def sample_next_word(self, logit, method, temp):
        """Choose the next token from one step of decoder logits.

        Args:
            logit: (batch, vocab) scores for the current step.
            method: "greedy"; "gumbel"; "top<K>" with K >= 1 for top-k
                sampling; "top<P>" with 0 < P < 1 for nucleus sampling; any
                other string samples from the full tempered distribution.
            temp: temperature (not used by "greedy").

        Returns:
            {"word": (batch,) long tensor,
             "probs": (batch,) log-probability of each chosen token}.
        """
        logprob = torch.log_softmax(logit, dim=1)
        if method == "greedy":
            sampled_logprob, word = torch.max(logprob.detach(), 1)
        elif method == "gumbel":
            def sample_gumbel(shape, eps=1e-20):
                U = torch.rand(shape).to(logprob.device)
                return -torch.log(-torch.log(U + eps) + eps)

            def gumbel_softmax_sample(logit, temperature):
                y = logit + sample_gumbel(logit.size())
                return torch.log_softmax(y / temperature, dim=-1)

            _logprob = gumbel_softmax_sample(logprob, temp)
            _, word = torch.max(_logprob.detach(), 1)
            # Squeeze to (batch,) like the other branches; stepwise_process_step
            # writes this into a (batch,) column.
            sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
        else:
            logprob = logprob / temp
            if method.startswith("top"):
                top_num = float(method[3:])
                if 0 < top_num < 1:
                    # Nucleus sampling over the temperature-scaled
                    # distribution (softmax of the raw logit ignored temp).
                    probs = torch.softmax(logprob, dim=1)
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
                    _cumsum = sorted_probs.cumsum(1)
                    mask = _cumsum < top_num
                    # Always keep the most probable token.
                    mask = torch.cat([torch.ones_like(mask[:, :1]), mask[:, :-1]], 1)
                    sorted_probs = sorted_probs * mask.to(sorted_probs)
                    sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
                    logprob.scatter_(1, sorted_indices, sorted_probs.log())
                else:
                    k = int(top_num)
                    tmp = torch.empty_like(logprob).fill_(float('-inf'))
                    topk, indices = torch.topk(logprob, k, dim=1)
                    tmp = tmp.scatter(1, indices, topk)
                    logprob = tmp
            word = torch.distributions.Categorical(logits=logprob.detach()).sample()
            sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
        word = word.detach().long()

        return {"word": word, "probs": sampled_logprob}

    def beam_search(self, input_dict):
        """Beam search, run independently for each sample of the batch.

        Finished hypotheses are scored by their summed log-probability
        divided by their length. With ``n_best`` the output ``seq`` is
        (batch, n_best_size, max_length); otherwise (batch, max_length).
        """
        output = self.prepare_output(input_dict)
        max_length = input_dict["max_length"]
        beam_size = input_dict["beam_size"]
        if input_dict["n_best"]:
            n_best_size = input_dict["n_best_size"]
            batch_size, max_length = output["seq"].size()
            output["seq"] = torch.full((batch_size, n_best_size, max_length),
                                       self.end_idx, dtype=torch.long)

        temp = input_dict["temp"]

        for i in range(output["seq"].size(0)):
            output_i = self.prepare_beamsearch_output(input_dict)
            input_dict["sample_idx"] = i
            for t in range(max_length):
                input_dict["t"] = t
                output_t = self.beamsearch_step(input_dict, output_i)

                logit_t = output_t["logit"]
                if logit_t.size(1) == 1:
                    logit_t = logit_t.squeeze(1)
                elif logit_t.size(1) > 1:
                    logit_t = logit_t[:, -1, :]
                else:
                    raise Exception("no logit output")
                logprob_t = torch.log_softmax(logit_t, dim=1)
                logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
                # Extend every beam's running score with each next token.
                logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
                if t == 0:
                    # Beams start from the same prefix; expand row 0 only.
                    topk_logprob, topk_words = logprob_t[0].topk(
                        beam_size, 0, True, True)
                else:
                    topk_logprob, topk_words = logprob_t.view(-1).topk(
                        beam_size, 0, True, True)
                topk_words = topk_words.cpu()
                output_i["topk_logprob"] = topk_logprob
                # Flattened index -> (source beam, token id).
                output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size,
                                                        rounding_mode='trunc')
                output_i["next_word"] = topk_words % self.vocab_size
                if t == 0:
                    output_i["seq"] = output_i["next_word"].unsqueeze(1)
                else:
                    output_i["seq"] = torch.cat([
                        output_i["seq"][output_i["prev_words_beam"]],
                        output_i["next_word"].unsqueeze(1)], dim=1)

                is_end = output_i["next_word"] == self.end_idx
                if t == max_length - 1:
                    is_end.fill_(1)

                for beam_idx in range(beam_size):
                    if is_end[beam_idx]:
                        final_beam = {
                            "seq": output_i["seq"][beam_idx].clone(),
                            "score": output_i["topk_logprob"][beam_idx].item()
                        }
                        final_beam["score"] = final_beam["score"] / (t + 1)
                        output_i["done_beams"].append(final_beam)
                # Push finished beams out of contention for later steps.
                output_i["topk_logprob"][is_end] -= 1000

                self.beamsearch_process_step(output_i, output_t)

                # Several beams can finish in one step, so the count may jump
                # past beam_size; ``==`` would miss the stop in that case.
                if len(output_i["done_beams"]) >= beam_size:
                    break

            self.beamsearch_process(output, output_i, input_dict)
        return output

    def prepare_beamsearch_output(self, input_dict):
        """Fresh per-sample beam-search state (zero running scores)."""
        beam_size = input_dict["beam_size"]
        device = input_dict["fc_emb"].device
        output = {
            "topk_logprob": torch.zeros(beam_size).to(device),
            "seq": None,
            "prev_words_beam": None,
            "next_word": None,
            "done_beams": [],
        }
        return output

    def beamsearch_step(self, input_dict, output_i):
        """Run the decoder once for the current beams and tag the output with t."""
        decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i)
        output_t = self.decoder(decoder_input)
        output_t["t"] = input_dict["t"]
        return output_t

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Build the decoder input for the current beams (subclass hook)."""
        raise NotImplementedError

    def beamsearch_process_step(self, output_i, output_t):
        """Hook called after each beam-search step; no-op by default."""
        pass

    def beamsearch_process(self, output, output_i, input_dict):
        """Write the best finished beam(s) of sample ``sample_idx`` into ``output["seq"]``."""
        i = input_dict["sample_idx"]
        done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
        if input_dict["n_best"]:
            done_beams = done_beams[:input_dict["n_best_size"]]
            for out_idx, done_beam in enumerate(done_beams):
                seq = done_beam["seq"]
                output["seq"][i][out_idx, :len(seq)] = seq
        else:
            seq = done_beams[0]["seq"]
            output["seq"][i][:len(seq)] = seq

    def diverse_beam_search(self, input_dict):
        """Diverse beam search.

        ``beam_size`` beams are split into ``group_size`` groups of
        ``beam_size // group_size`` beams. Group ``divm`` lags ``divm`` steps
        behind and its log-probabilities are reduced by ``diversity_lambda``
        times the number of times earlier groups chose each token at the same
        local step. Output ``seq`` is (batch, beam_size, max_length) when
        ``group_nbest`` else (batch, group_size, max_length).
        """

        def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprob = logprob.clone()

            if divm > 0:
                change = torch.zeros(logprob.size(-1))
                for prev_choice in range(divm):
                    prev_decisions = seq_table[prev_choice][..., local_time]
                    for prev_labels in range(bdash):
                        change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1))

                change = change.to(logprob.device)
                logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda

            return logprob, unaug_logprob

        output = self.prepare_output(input_dict)
        group_size = input_dict["group_size"]
        beam_size = input_dict["beam_size"]
        bdash = beam_size // group_size
        input_dict["bdash"] = bdash
        diversity_lambda = input_dict["diversity_lambda"]
        device = input_dict["fc_emb"].device
        temp = input_dict["temp"]
        group_nbest = input_dict["group_nbest"]
        batch_size, max_length = output["seq"].size()
        n_outputs = beam_size if group_nbest else group_size
        output["seq"] = torch.full((batch_size, n_outputs, max_length),
                                   self.end_idx, dtype=torch.long)

        for i in range(batch_size):
            input_dict["sample_idx"] = i
            seq_table = [torch.empty(bdash, 0, dtype=torch.long) for _ in range(group_size)]
            logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
            done_beams_table = [[] for _ in range(group_size)]

            output_i = {
                "prev_words_beam": [None for _ in range(group_size)],
                "next_word": [None for _ in range(group_size)],
                "state": [None for _ in range(group_size)]
            }

            for t in range(max_length + group_size - 1):
                input_dict["t"] = t
                for divm in range(group_size):
                    input_dict["divm"] = divm
                    # Group divm is active for global steps divm .. divm + max_length - 1.
                    if not (divm <= t <= max_length + divm - 1):
                        continue
                    local_time = t - divm
                    decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
                    output_t = self.decoder(decoder_input)
                    output_t["divm"] = divm
                    logit_t = output_t["logit"]
                    if logit_t.size(1) == 1:
                        logit_t = logit_t.squeeze(1)
                    elif logit_t.size(1) > 1:
                        logit_t = logit_t[:, -1, :]
                    else:
                        raise Exception("no logit output")
                    logprob_t = torch.log_softmax(logit_t, dim=1)
                    logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
                    logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash)
                    logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
                    if local_time == 0:
                        topk_logprob, topk_words = logprob_t[0].topk(
                            bdash, 0, True, True)
                    else:
                        topk_logprob, topk_words = logprob_t.view(-1).topk(
                            bdash, 0, True, True)
                    topk_words = topk_words.cpu()
                    logprob_table[divm] = topk_logprob
                    # Flattened index -> (source beam, token id).
                    output_i["prev_words_beam"][divm] = torch.div(
                        topk_words, self.vocab_size, rounding_mode="floor")
                    output_i["next_word"][divm] = topk_words % self.vocab_size
                    if local_time > 0:
                        seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
                    seq_table[divm] = torch.cat([
                        seq_table[divm],
                        output_i["next_word"][divm].unsqueeze(-1)], -1)

                    is_end = seq_table[divm][:, t - divm] == self.end_idx
                    assert seq_table[divm].shape[-1] == t - divm + 1
                    if t == max_length + divm - 1:
                        is_end.fill_(1)
                    for beam_idx in range(bdash):
                        if is_end[beam_idx]:
                            final_beam = {
                                "seq": seq_table[divm][beam_idx].clone(),
                                "score": logprob_table[divm][beam_idx].item()
                            }
                            final_beam["score"] = final_beam["score"] / (t - divm + 1)
                            done_beams_table[divm].append(final_beam)
                    # Push finished beams out of contention for later steps.
                    logprob_table[divm][is_end] -= 1000
                    self.dbs_process_step(output_i, output_t)
            done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash]
                                for divm in range(group_size)]
            if group_nbest:
                done_beams = sum(done_beams_table, [])
            else:
                done_beams = [group_beam[0] for group_beam in done_beams_table]
            for out_idx, done_beam in enumerate(done_beams):
                output["seq"][i, out_idx, :len(done_beam["seq"])] = done_beam["seq"]

        return output

    def prepare_dbs_decoder_input(self, input_dict, output_i):
        """Build the decoder input for group ``input_dict["divm"]`` (subclass hook)."""
        raise NotImplementedError

    def dbs_process_step(self, output_i, output_t):
        """Hook called after each diverse-beam-search group step; no-op by default."""
        pass
|
|
|
|
|
class TransformerModel(CaptionModel):
    """Caption model whose decoder is a ``TransformerDecoder``.

    At every step the decoder receives the whole token prefix decoded so far
    (start token first) together with the attention embeddings and a
    padding mask marking ``pad_idx`` tokens.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        # A subclass may already have chosen its own compatible decoders.
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (TransformerDecoder,)
        super().__init__(encoder, decoder, **kwargs)

    def seq_forward(self, input_dict):
        """Teacher-forced decoding: feed ``cap[:, :-1]`` in a single pass."""
        cap = input_dict["cap"]
        decoder_words = cap[:, :-1]
        padding_mask = decoder_words == self.pad_idx
        return self.decoder({
            "word": decoder_words,
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_len": input_dict["attn_emb_len"],
            "cap_padding_mask": padding_mask,
        })

    def prepare_decoder_input(self, input_dict, output):
        """Decoder input for step ``t`` of stepwise decoding.

        In train mode, with probability ``ss_ratio`` the ground-truth prefix
        ``cap[:, :t + 1]`` is used; otherwise the start token followed by the
        tokens generated so far (``output["seq"][:, :t]``).
        """
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        t = input_dict["t"]

        use_ground_truth = (input_dict["mode"] == "train"
                            and random.random() < input_dict["ss_ratio"])
        if use_ground_truth:
            word = input_dict["cap"][:, :t + 1]
        else:
            word = torch.full((attn_emb.size(0), 1), self.start_idx, dtype=torch.long)
            if t > 0:
                word = torch.cat((word, output["seq"][:, :t]), dim=-1)

        return {
            "attn_emb": attn_emb,
            "attn_emb_len": attn_emb_len,
            "word": word,
            "cap_padding_mask": (word == self.pad_idx).to(attn_emb.device),
        }

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Decoder input for all beams of sample ``sample_idx`` at step ``t``.

        At ``t == 0`` the sample's attention embeddings and length are
        repeated ``beam_size`` times and cached in ``output_i``.
        """
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]

        if t == 0:
            output_i["attn_emb"] = repeat_tensor(input_dict["attn_emb"][i], beam_size)
            output_i["attn_emb_len"] = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)

        word = torch.full((beam_size, 1), self.start_idx, dtype=torch.long)
        if t > 0:
            word = torch.cat((word, output_i["seq"]), dim=-1)

        return {
            "attn_emb": output_i["attn_emb"],
            "attn_emb_len": output_i["attn_emb_len"],
            "word": word,
            "cap_padding_mask": (word == self.pad_idx).to(input_dict["attn_emb"].device),
        }
|
|
|
|
|
class BaseDecoder(nn.Module):
    """
    Take word/audio embeddings and output the next word probs

    Base class that owns the word-embedding table and input dropout; concrete
    decoders implement ``forward``.
    """
    def __init__(self, emb_dim, vocab_size, fc_emb_dim,
                 attn_emb_dim, dropout=0.2, tie_weights=False):
        super().__init__()
        self.emb_dim = emb_dim
        self.vocab_size = vocab_size
        self.fc_emb_dim = fc_emb_dim
        self.attn_emb_dim = attn_emb_dim
        self.tie_weights = tie_weights
        self.word_embedding = nn.Embedding(vocab_size, emb_dim)
        self.in_dropout = nn.Dropout(dropout)

    def forward(self, x):
        raise NotImplementedError

    def load_word_embedding(self, weight, freeze=True):
        """Replace ``word_embedding`` with vectors loaded via ``np.load``.

        Args:
            weight: anything ``np.load`` accepts (e.g. a ``.npy`` path); the
                array must have shape (vocab_size, emb_dim).
            freeze: if True, the new embedding is excluded from training.

        NOTE(review): if a subclass tied another weight to
        ``word_embedding.weight`` (``tie_weights``), that tie is not
        re-established after the replacement.
        """
        embedding = np.load(weight)
        assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
        assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
        # nn.Embedding.from_pretrained needs a torch.Tensor (a numpy array
        # fails); match the current embedding's dtype and device.
        current = self.word_embedding.weight
        embedding = torch.as_tensor(embedding, dtype=current.dtype,
                                    device=current.device)
        self.word_embedding = nn.Embedding.from_pretrained(embedding,
                                                           freeze=freeze)
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input.

    Args:
        d_model: feature size of the input (odd sizes are supported).
        dropout: dropout probability applied after adding the encodings.
        max_len: longest sequence the precomputed table covers.
    """

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)

        # A buffer rather than a parameter: the table is constant, and keeping
        # it out of parameters() stops parameter-wide loops (xavier init,
        # requires_grad toggling in TransformerDecoder) from altering it.
        # It is still saved in state_dict under the same key "pe".
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add encodings to ``x`` of shape (seq_len, batch, d_model) and apply dropout.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}")
        x = x + self.pe[:seq_len, :]
        return self.dropout(x)
|
|
|
|
|
class TransformerDecoder(BaseDecoder):
    """Transformer caption decoder attending over projected audio embeddings.

    Extra keyword options: ``nhead`` (default ``emb_dim // 64``), ``nlayers``
    (default 2) and ``dim_feedforward`` (default ``4 * emb_dim``).

    Args:
        emb_dim: word embedding size, also used as the model width.
        vocab_size: number of output tokens.
        fc_emb_dim: clip-level audio embedding size (stored by the base class).
        attn_emb_dim: frame-level audio embedding size, projected to emb_dim.
        dropout: dropout used throughout the decoder.
        freeze: if True all parameters start with ``requires_grad=False``.
        tie_weights: share the classifier weight with the word embedding.
    """

    def __init__(self,
                 emb_dim,
                 vocab_size,
                 fc_emb_dim,
                 attn_emb_dim,
                 dropout,
                 freeze=False,
                 tie_weights=False,
                 **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout=dropout, tie_weights=tie_weights)
        self.d_model = emb_dim
        self.nhead = kwargs.get("nhead", self.d_model // 64)
        self.nlayers = kwargs.get("nlayers", 2)
        self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)

        self.pos_encoder = PositionalEncoding(self.d_model, dropout)
        layer = nn.TransformerDecoderLayer(d_model=self.d_model,
                                           nhead=self.nhead,
                                           dim_feedforward=self.dim_feedforward,
                                           dropout=dropout)
        self.model = nn.TransformerDecoder(layer, self.nlayers)
        self.classifier = nn.Linear(self.d_model, vocab_size, bias=False)
        if tie_weights:
            self.classifier.weight = self.word_embedding.weight
        self.attn_proj = nn.Sequential(
            nn.Linear(self.attn_emb_dim, self.d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(self.d_model)
        )
        self.init_params()

        self.freeze = freeze
        if freeze:
            for p in self.parameters():
                p.requires_grad = False

    def init_params(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def load_pretrained(self, pretrained, output_fn):
        """Load decoder weights from a checkpoint file.

        Accepts either a decoder-only state dict or a full-model checkpoint
        (optionally nested under "model") whose decoder keys carry the
        ``decoder.`` prefix; in the latter case only those keys are used.
        When the decoder is frozen, loaded parameters stay frozen and all
        others become trainable.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        checkpoint = torch.load(pretrained, map_location="cpu")
        if "model" in checkpoint:
            checkpoint = checkpoint["model"]

        prefix = "decoder."
        if any(key.startswith(prefix) for key in checkpoint):
            # Keep decoder weights only; stripping a fixed 8 characters from
            # every key would also turn e.g. "encoder.model.*" into "model.*".
            state_dict = {key[len(prefix):]: value
                          for key, value in checkpoint.items()
                          if key.startswith(prefix)}
        else:
            # Already a decoder-only state dict.
            state_dict = checkpoint

        loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
        if self.freeze:
            for name, param in self.named_parameters():
                param.requires_grad = name not in loaded_keys

    def generate_square_subsequent_mask(self, max_length):
        """Causal mask: 0.0 on/below the diagonal, -inf above it."""
        mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, input_dict):
        """Decode a batch of token prefixes against the audio embeddings.

        Reads "word" (token ids, batch-first), "attn_emb", "attn_emb_len" and
        "cap_padding_mask". Sequences are transposed to (seq, batch, dim) for
        ``nn.TransformerDecoder`` and transposed back for the output.

        Returns:
            dict with "embed" (decoder states) and "logit" (vocabulary scores).
        """
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)
        embed = embed.transpose(0, 1)
        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        output = {
            "embed": output,
            "logit": self.classifier(output),
        }
        return output
|
|
|
|
|
class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Wrap a captioning model with a contrastive encoder-distillation loss.

    The student clip embedding ("fc_emb") and a teacher embedding
    (``input_dict["tchr_output"]["embedding"]``) are projected to a shared
    space, L2-normalised, and matched with a symmetric cross-entropy over the
    batch (the i-th student pairs with the i-th teacher). The loss is stored
    as "enc_kd_loss" and is only computed when a teacher output is given.
    """

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 ):
        super().__init__()
        self.model = model
        self.tchr_dim = tchr_dim
        # The clip embedding size lives on the encoder when there is one.
        emb_owner = model.encoder if hasattr(model, "encoder") else model
        self.stdnt_proj = nn.Linear(emb_owner.fc_emb_size, shared_dim)
        self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        # "unsup" inputs only run the encoder (no captions to decode).
        if input_dict.get("unsup", False) is False:
            output_dict = self.model(input_dict)
        else:
            output_dict = self.model.encoder(input_dict)

        if "tchr_output" not in input_dict:
            return output_dict

        student = F.normalize(self.stdnt_proj(output_dict["fc_emb"]), dim=-1)
        teacher = F.normalize(
            self.tchr_proj(input_dict["tchr_output"]["embedding"]), dim=-1)

        logit = self.logit_scale * (student @ teacher.transpose(0, 1))
        targets = torch.arange(logit.shape[0]).to(logit.device)
        student_to_teacher = F.cross_entropy(logit, targets)
        teacher_to_student = F.cross_entropy(logit.transpose(0, 1), targets)
        output_dict["enc_kd_loss"] = (student_to_teacher + teacher_to_student) / 2
        return output_dict
|
|
|
|
|
class Effb2TrmConfig(PretrainedConfig):
    """Hyper-parameters for the EfficientNet-B2 + Transformer captioning model.

    Every named argument is stored as an attribute of the same name; any
    remaining keyword arguments go to ``PretrainedConfig``.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        tchr_dim: int = 768,
        shared_dim: int = 1024,
        fc_emb_dim: int = 1408,
        attn_emb_dim: int = 1408,
        decoder_n_layers: int = 2,
        decoder_we_tie_weights: bool = True,
        decoder_emb_dim: int = 256,
        decoder_dropout: float = 0.2,
        vocab_size: int = 4981,
        **kwargs
    ):
        model_settings = {
            "sample_rate": sample_rate,
            "tchr_dim": tchr_dim,
            "shared_dim": shared_dim,
            "fc_emb_dim": fc_emb_dim,
            "attn_emb_dim": attn_emb_dim,
            "decoder_n_layers": decoder_n_layers,
            "decoder_we_tie_weights": decoder_we_tie_weights,
            "decoder_emb_dim": decoder_emb_dim,
            "decoder_dropout": decoder_dropout,
            "vocab_size": vocab_size,
        }
        for attr_name, attr_value in model_settings.items():
            setattr(self, attr_name, attr_value)
        super().__init__(**kwargs)
|
|
|
|
|
class Effb2TrmCaptioningModel(PreTrainedModel):
    """Hugging Face wrapper: EfficientNet-B2 encoder + Transformer decoder captioner.

    The captioner is wrapped in ``ContraEncoderKdWrapper``. Its distillation
    loss is only computed when a teacher output is supplied, which the
    inference-only ``forward`` below never does.
    """

    config_class = Effb2TrmConfig

    def __init__(self, config):
        super().__init__(config)
        encoder = EfficientNetB2()
        decoder = TransformerDecoder(
            emb_dim=config.decoder_emb_dim,
            vocab_size=config.vocab_size,
            fc_emb_dim=config.fc_emb_dim,
            attn_emb_dim=config.attn_emb_dim,
            dropout=config.decoder_dropout,
            nlayers=config.decoder_n_layers,
            tie_weights=config.decoder_we_tie_weights,
        )
        captioner = TransformerModel(encoder, decoder)
        self.model = ContraEncoderKdWrapper(captioner, config.shared_dim, config.tchr_dim)

    def forward(self,
                audio: torch.Tensor,
                audio_length: Union[List, np.ndarray, torch.Tensor],
                sample_method: str = "beam",
                beam_size: int = 3,
                max_length: int = 20,
                temp: float = 1.0,):
        """Generate captions for raw waveforms.

        Args:
            audio: batch of waveforms, moved to the model's device.
            audio_length: waveform lengths in samples.
            sample_method: decoding strategy; "beam" also passes ``beam_size``.
            beam_size: beam width used when ``sample_method == "beam"``.
            max_length: maximum number of generated tokens.
            temp: sampling temperature.

        Returns:
            The generated token ids ("seq") on CPU.
        """
        decoding = {
            "sample_method": sample_method,
            "max_length": max_length,
            "temp": temp,
        }
        if sample_method == "beam":
            decoding["beam_size"] = beam_size

        input_dict = {
            "wav": audio.to(self.device),
            "wav_len": audio_length,
            "specaug": False,
            "mode": "inference",
            **decoding,
        }
        result = self.model(input_dict)
        return result["seq"].cpu()
|
|
|
|
|
class ConvBlock(nn.Module):
    """Two (3x3 conv -> batch norm -> ReLU) layers followed by 2-D pooling.

    Both convolutions use stride 1 and padding 1, so only the pooling step
    changes the spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        same_size_conv = dict(kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels, **same_size_conv)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels, **same_size_conv)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply the conv stack, then pool with 'avg', 'max' or 'avg+max' (sum of both).

        Raises:
            Exception: for any other ``pool_type``.
        """
        features = F.relu_(self.bn1(self.conv1(input)))
        features = F.relu_(self.bn2(self.conv2(features)))

        if pool_type == 'max':
            return F.max_pool2d(features, kernel_size=pool_size)
        if pool_type == 'avg':
            return F.avg_pool2d(features, kernel_size=pool_size)
        if pool_type == 'avg+max':
            averaged = F.avg_pool2d(features, kernel_size=pool_size)
            peaked = F.max_pool2d(features, kernel_size=pool_size)
            return averaged + peaked
        raise Exception('Incorrect argument!')
|
|
|
|
|
class Cnn14Encoder(nn.Module):
    """CNN14 audio encoder over log-mel spectrogram input.

    ``forward`` consumes a precomputed log-mel spectrogram ("lms"); the
    ``melspec_extractor``/``db_transform`` modules are built here but not
    used by ``forward``. Supported sample rates are 32000 and 16000 (other
    values raise KeyError).

    Outputs per clip: "attn_emb" (frame embeddings, 2048 channels),
    "attn_emb_len" (valid frames per clip) and "fc_emb" (pooled 2048-d
    clip embedding).
    """

    def __init__(self, sample_rate=32000):
        super().__init__()
        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        window = 32 * sample_rate // 1000  # 32 ms analysis window
        hop = 10 * sample_rate // 1000     # 10 ms hop

        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=window,
            win_length=window,
            hop_length=hop,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = hop
        self.db_transform = transforms.AmplitudeToDB()

        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        # Five 2x time poolings (blocks 1-5) -> 32x fewer frames.
        self.downsample_ratio = 32

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_emb_size = 2048

    def forward(self, input_dict):
        # bn0 has 64 channels, so "lms" is expected as (batch, 64 mel bins, time).
        spec = input_dict["lms"]
        wav_lengths = input_dict["wav_len"]

        x = spec.transpose(1, 2).unsqueeze(1)
        # Normalise each mel bin: move the mel axis into the channel slot.
        x = self.bn0(x.transpose(1, 3)).transpose(1, 3)

        stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)),
            (self.conv_block4, (2, 2)),
            (self.conv_block5, (2, 2)),
            (self.conv_block6, (1, 1)),
        )
        for block, pool_size in stages:
            x = block(x, pool_size=pool_size, pool_type='avg')
            x = F.dropout(x, p=0.2, training=self.training)

        # Average over the mel axis, then put time before channels.
        attn_emb = torch.mean(x, dim=3).transpose(1, 2)

        n_frames = torch.div(torch.as_tensor(wav_lengths), self.hop_length,
                             rounding_mode="floor") + 1
        feat_length = torch.div(n_frames, self.downsample_ratio,
                                rounding_mode="floor")

        clip = max_with_lens(attn_emb, feat_length) + mean_with_lens(attn_emb, feat_length)
        clip = F.dropout(clip, p=0.5, training=self.training)
        fc_emb = F.dropout(F.relu_(self.fc1(clip)), p=0.5, training=self.training)

        return {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
|
|
|
|
|
class RnnEncoder(nn.Module):
    """Recurrent encoder over a padded, batch-first batch of frame features.

    Args:
        attn_feat_dim: size of each input frame feature.
        pooling: reduction passed to ``embedding_pooling`` to build "fc_emb"
            ("max", "mean", "mean+max" or "last").
        **kwargs: optional ``hidden_size`` (512), ``bidirectional`` (False),
            ``num_layers`` (1), ``dropout`` (0.2), ``rnn_type`` ("GRU"; any
            ``torch.nn`` RNN class name) and ``in_bn`` (False).
    """

    def __init__(self,
                 attn_feat_dim,
                 pooling="mean",
                 **kwargs):
        super().__init__()
        self.pooling = pooling
        self.hidden_size = kwargs.get('hidden_size', 512)
        self.bidirectional = kwargs.get('bidirectional', False)
        self.num_layers = kwargs.get('num_layers', 1)
        self.dropout = kwargs.get('dropout', 0.2)
        self.rnn_type = kwargs.get('rnn_type', "GRU")
        self.in_bn = kwargs.get('in_bn', False)
        # RNN output size (both directions concatenated when bidirectional).
        self.embed_dim = self.hidden_size * (self.bidirectional + 1)
        self.network = getattr(nn, self.rnn_type)(
            attn_feat_dim,
            self.hidden_size,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional,
            dropout=self.dropout,
            batch_first=True)
        if self.in_bn:
            # forward() applies this norm to the *input* frames before the
            # RNN, so it must have attn_feat_dim features; sizing it with
            # embed_dim only worked when the two sizes happened to match.
            self.bn = nn.BatchNorm1d(attn_feat_dim)

    def forward(self, input_dict):
        """Encode "attn" frames with valid lengths "attn_len".

        Returns:
            dict with "attn_emb" (per-frame RNN outputs), "fc_emb" (pooled
            clip embedding) and "attn_emb_len" (the lengths as a tensor).
        """
        x = input_dict["attn"]
        lens = torch.as_tensor(input_dict["attn_len"])

        if self.in_bn:
            x = pack_wrapper(self.bn, x, lens)
        out = pack_wrapper(self.network, x, lens)

        return {
            "attn_emb": out,
            "fc_emb": embedding_pooling(out, lens, self.pooling),
            "attn_emb_len": lens
        }
|
|
|
|
|
class Cnn14RnnEncoder(nn.Module):
    """CNN14 frame encoder followed by a recurrent encoder over its 2048-d frames.

    The returned dict comes from the RNN stage ("attn_emb", "fc_emb",
    "attn_emb_len"); the CNN's own pooled "fc_emb" is passed along in the
    input to the RNN but not used by it.
    """

    def __init__(self,
                 sample_rate,
                 rnn_bidirectional,
                 rnn_hidden_size,
                 rnn_dropout,
                 rnn_num_layers):
        super().__init__()
        self.cnn = Cnn14Encoder(sample_rate=sample_rate)
        self.rnn = RnnEncoder(
            2048,
            bidirectional=rnn_bidirectional,
            hidden_size=rnn_hidden_size,
            dropout=rnn_dropout,
            num_layers=rnn_num_layers,
        )

    def forward(self, input_dict):
        cnn_out = self.cnn(input_dict)
        # RnnEncoder reads its inputs from the "attn"/"attn_len" keys.
        cnn_out["attn"] = cnn_out.pop("attn_emb")
        cnn_out["attn_len"] = cnn_out.pop("attn_emb_len")
        return self.rnn(cnn_out)
|
|
|
|
|
class Seq2SeqAttention(nn.Module):
    """Additive attention: score = v . tanh(W [h_dec; h_enc]), masked by length."""

    def __init__(self, hs_enc, hs_dec, attn_size):
        """
        Args:
            hs_enc: encoder hidden size
            hs_dec: decoder hidden size
            attn_size: attention vector size
        """
        super(Seq2SeqAttention, self).__init__()
        self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size)
        self.v = nn.Parameter(torch.randn(attn_size))

    def forward(self, h_dec, h_enc, src_lens):
        """
        Args:
            h_dec: decoder hidden (query), [N, hs_dec]
            h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
            src_lens: source (encoder memory) lengths, [N, ]; a tensor on
                any device or a sequence of ints.

        Returns:
            ctx: attended context, [N, hs_enc]
            weights: attention weights, [N, src_max_len]; zero on padding.
        """
        src_max_len = h_enc.size(1)
        # expand() broadcasts without copying; cat() below materialises it.
        query = h_dec.unsqueeze(1).expand(-1, src_max_len, -1)

        attn_out = torch.tanh(self.h2attn(torch.cat((query, h_enc), dim=-1)))
        score = torch.matmul(attn_out, self.v)  # [N, src_max_len]

        # Build the mask on the memory's device so CPU and CUDA lengths both work.
        src_lens = torch.as_tensor(src_lens, device=h_enc.device)
        positions = torch.arange(src_max_len, device=h_enc.device)
        valid = positions.unsqueeze(0) < src_lens.view(-1, 1)

        score = score.masked_fill(~valid, -1e10)
        weights = torch.softmax(score, dim=-1)
        ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1)

        return ctx, weights
|
|
|
|
|
class RnnDecoder(BaseDecoder):
    """Base for recurrent caption decoders.

    Extra keyword options: ``num_layers`` (1), ``bidirectional`` (False) and
    ``rnn_type`` ("GRU"; "LSTM" switches ``init_hidden`` to an (h, c) pair).
    The classifier maps the RNN output (``d_model`` per direction) to vocab
    logits.
    """

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout,)
        self.d_model = d_model
        self.num_layers = kwargs.get('num_layers', 1)
        self.bidirectional = kwargs.get('bidirectional', False)
        self.rnn_type = kwargs.get('rnn_type', "GRU")
        num_directions = self.bidirectional + 1
        self.classifier = nn.Linear(self.d_model * num_directions, vocab_size)

    def forward(self, x):
        raise NotImplementedError

    def init_hidden(self, bs, device):
        """Zero initial state of shape (directions * layers, bs, d_model) on ``device``."""
        state_shape = ((self.bidirectional + 1) * self.num_layers, bs, self.d_model)
        if self.rnn_type != "LSTM":
            return torch.zeros(state_shape).to(device)
        return (torch.zeros(state_shape).to(device),
                torch.zeros(state_shape).to(device))
|
|
|
|
|
class BahAttnCatFcDecoder(RnnDecoder):
    """Attention RNN decoder fed [word embedding; attended context; clip embedding].

    Each call processes one token per sample. The attention query is the
    previous hidden state with all layers/directions flattened together.
    Extra keyword option: ``attn_size`` (defaults to ``d_model``).
    """

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        """
        concatenate fc, attn, word to feed to the rnn
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        attn_size = kwargs.get("attn_size", self.d_model)
        self.model = getattr(nn, self.rnn_type)(
            input_size=self.emb_dim * 3,
            hidden_size=self.d_model,
            batch_first=True,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional)
        query_size = self.d_model * (self.bidirectional + 1) * self.num_layers
        self.attn = Seq2SeqAttention(self.attn_emb_dim, query_size, attn_size)
        self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
        self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)

    def forward(self, input_dict):
        """Run one decoding step.

        Reads "word", "fc_emb", "attn_emb", "attn_emb_len" and an optional
        "state" (zero-initialised when absent).

        Returns:
            dict with "state", "embed" (RNN output), "logit" and "attn_weight".
        """
        fc_emb = input_dict["fc_emb"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        prev_state = input_dict.get("state", None)
        tokens = input_dict["word"].to(fc_emb.device)

        embed = self.in_dropout(self.word_embedding(tokens))

        if prev_state is None:
            prev_state = self.init_hidden(tokens.size(0), fc_emb.device)
        hidden = prev_state[0] if self.rnn_type == "LSTM" else prev_state
        query = hidden.transpose(0, 1).flatten(1)
        ctx, attn_weight = self.attn(query, attn_emb, attn_emb_len)

        step_input = torch.cat(
            (embed,
             self.ctx_proj(ctx).unsqueeze(1),
             self.fc_proj(fc_emb).unsqueeze(1)),
            dim=-1)
        out, new_state = self.model(step_input, prev_state)

        return {
            "state": new_state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
|
|
|
|
|
class TemporalBahAttnDecoder(BahAttnCatFcDecoder):
    """Attention RNN decoder whose first input is a temporal-tag embedding.

    At step ``t == 0`` the embedding of ``temporal_tag`` (ids 0-3, since the
    table has 4 rows) replaces the start-token embedding. Later steps use
    the word embedding of a single token id, or take ``word`` directly as an
    embedding when its last dimension equals ``fc_emb_dim``.
    """

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        """
        concatenate fc, attn, word to feed to the rnn
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        self.temporal_embedding = nn.Embedding(4, emb_dim)

    def forward(self, input_dict):
        """Run one decoding step.

        Reads "word", "fc_emb", "attn_emb", "attn_emb_len", "temporal_tag",
        "t" and an optional "state" (zero-initialised when absent).

        Returns:
            dict with "state", "embed", "logit" and "attn_weight".

        Raises:
            Exception: if ``word`` is neither a token id nor an embedding.
        """
        word = input_dict["word"]
        state = input_dict.get("state", None)
        fc_embs = input_dict["fc_emb"]
        attn_embs = input_dict["attn_emb"]
        attn_emb_lens = input_dict["attn_emb_len"]
        temporal_tag = input_dict["temporal_tag"]

        if input_dict["t"] == 0:
            embed = self.in_dropout(
                self.temporal_embedding(temporal_tag)).unsqueeze(1)
        elif word.size(-1) == self.fc_emb_dim:
            embed = word.unsqueeze(1)
        elif word.size(-1) == 1:
            word = word.to(fc_embs.device)
            embed = self.in_dropout(self.word_embedding(word))
        else:
            raise Exception(f"problem with word input size {word.size()}")

        if state is None:
            state = self.init_hidden(word.size(0), fc_embs.device)
        if self.rnn_type == "LSTM":
            query = state[0].transpose(0, 1).flatten(1)
        else:
            query = state.transpose(0, 1).flatten(1)
        c, attn_weight = self.attn(query, attn_embs, attn_emb_lens)

        # The context projection was previously computed twice; once suffices.
        p_ctx = self.ctx_proj(c)
        p_fc_embs = self.fc_proj(fc_embs)
        rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_embs.unsqueeze(1)), dim=-1)

        out, state = self.model(rnn_input, state)

        output = {
            "state": state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
        return output
|
|
|
|
|
class Seq2SeqAttnModel(CaptionModel):
    """Captioning model for step-by-step attention RNN decoders.

    The decoder consumes "fc_emb", "attn_emb", "attn_emb_len", "word" and an
    optional recurrent "state", and returns "state", "logit" and
    "attn_weight" (e.g. BahAttnCatFcDecoder). The hooks below build each
    step's decoder input and collect states and attention weights; they are
    invoked by the ``CaptionModel`` base class, which is defined elsewhere.
    """

    def __init__(self, encoder, decoder, **kwargs):
        # Subclasses may set compatible_decoders before calling this __init__.
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                BahAttnCatFcDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)

    def seq_forward(self, input_dict):
        # The decoder's attention query is its previous hidden state, so it
        # is fed one token per call; training also uses the stepwise loop.
        return self.stepwise_forward(input_dict)

    def prepare_output(self, input_dict):
        output = super().prepare_output(input_dict)
        # (batch, attention frames, caption steps), filled one step at a time.
        output["attn_weight"] = torch.empty(output["seq"].size(0),
                                            input_dict["attn_emb"].size(1),
                                            output["seq"].size(1))
        return output

    def prepare_decoder_input(self, input_dict, output):
        decoder_input = {
            "fc_emb": input_dict["fc_emb"],
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_len": input_dict["attn_emb_len"]
        }
        t = input_dict["t"]

        # Scheduled sampling: while training, feed the ground-truth token with
        # probability ss_ratio, otherwise the model's previous prediction.
        if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]:
            word = input_dict["cap"][:, t]
        elif t == 0:
            word = torch.tensor([self.start_idx,] * input_dict["fc_emb"].size(0)).long()
        else:
            word = output["seq"][:, t - 1]
        decoder_input["word"] = word.unsqueeze(1)

        if t > 0:
            decoder_input["state"] = output["state"]
        return decoder_input

    def stepwise_process_step(self, output, output_t):
        super().stepwise_process_step(output, output_t)
        output["state"] = output_t["state"]
        t = output_t["t"]
        output["attn_weight"][:, :, t] = output_t["attn_weight"]

    def prepare_beamsearch_output(self, input_dict):
        output = super().prepare_beamsearch_output(input_dict)
        beam_size = input_dict["beam_size"]
        max_length = input_dict["max_length"]
        # Width must equal attn_emb.size(1): the decoder's attention weights
        # span the full padded memory, and prepare_output uses the same width.
        output["attn_weight"] = torch.empty(beam_size,
                                            input_dict["attn_emb"].size(1),
                                            max_length)
        return output

    def _select_beam_state(self, state, beam_idx):
        """Gather beams ``beam_idx`` along the batch axis (dim 1) of an RNN state.

        LSTM states are ``(h, c)`` tuples; other RNN types use one tensor.
        """
        if self.decoder.rnn_type == "LSTM":
            return tuple(s[:, beam_idx, :].contiguous() for s in state)
        return state[:, beam_idx, :].contiguous()

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        decoder_input = {}
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]

        # On the first step, expand sample i's encoder outputs to beam_size
        # rows and cache them in output_i for the remaining steps.
        if t == 0:
            output_i["fc_emb"] = repeat_tensor(input_dict["fc_emb"][i], beam_size)
        decoder_input["fc_emb"] = output_i["fc_emb"]

        if t == 0:
            attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
            attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
            output_i["attn_emb"] = attn_emb
            output_i["attn_emb_len"] = attn_emb_len
        decoder_input["attn_emb"] = output_i["attn_emb"]
        decoder_input["attn_emb_len"] = output_i["attn_emb_len"]

        if t == 0:
            word = torch.tensor([self.start_idx,] * beam_size).long()
        else:
            word = output_i["next_word"]
        decoder_input["word"] = word.unsqueeze(1)

        # Reorder the recurrent state to follow the beams kept last step.
        if t > 0:
            decoder_input["state"] = self._select_beam_state(
                output_i["state"], output_i["prev_words_beam"])

        return decoder_input

    def beamsearch_process_step(self, output_i, output_t):
        t = output_t["t"]
        output_i["state"] = output_t["state"]
        output_i["attn_weight"][..., t] = output_t["attn_weight"]
        # Keep the attention history aligned with the surviving beams.
        output_i["attn_weight"] = output_i["attn_weight"][output_i["prev_words_beam"], ...]

    def beamsearch_process(self, output, output_i, input_dict):
        super().beamsearch_process(output, output_i, input_dict)
        i = input_dict["sample_idx"]
        # Row 0 is taken as the final beam (presumably the best one after the
        # base class's processing — confirm against CaptionModel).
        output["attn_weight"][i] = output_i["attn_weight"][0]

    def prepare_dbs_decoder_input(self, input_dict, output_i):
        decoder_input = {}
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        bdash = input_dict["bdash"]
        divm = input_dict["divm"]

        # Diverse beam search staggers groups: group ``divm`` starts (feeds
        # the start token) when t == divm.
        local_time = t - divm

        if t == 0:
            # Same layout as beam search: no extra unsqueeze. The decoders add
            # the time axis themselves (``fc_proj(fc_emb).unsqueeze(1)``), so
            # an extra one made their concatenated RNN input 4-D and crash.
            output_i["fc_emb"] = repeat_tensor(input_dict["fc_emb"][i], bdash)
        decoder_input["fc_emb"] = output_i["fc_emb"]

        if t == 0:
            attn_emb = repeat_tensor(input_dict["attn_emb"][i], bdash)
            attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], bdash)
            output_i["attn_emb"] = attn_emb
            output_i["attn_emb_len"] = attn_emb_len
        decoder_input["attn_emb"] = output_i["attn_emb"]
        decoder_input["attn_emb_len"] = output_i["attn_emb_len"]

        if local_time == 0:
            word = torch.tensor([self.start_idx,] * bdash).long()
        else:
            word = output_i["next_word"][divm]
        decoder_input["word"] = word.unsqueeze(1)

        if local_time > 0:
            # dbs_process_step stores each group's whole decoder state at
            # output_i["state"][divm] ((h, c) for LSTM), so index the group
            # first, then select beams within that group's state.
            decoder_input["state"] = self._select_beam_state(
                output_i["state"][divm], output_i["prev_words_beam"][divm])

        return decoder_input

    def dbs_process_step(self, output_i, output_t):
        divm = output_t["divm"]
        output_i["state"][divm] = output_t["state"]
|
|
|
|
|
|
|
class TemporalSeq2SeqAttnModel(Seq2SeqAttnModel):
    """Seq2SeqAttnModel whose decoder also receives a temporal tag and step index.

    "temporal_tag" is added to the keys forwarded for both training and
    inference, and every decoder input carries "temporal_tag" and "t".
    """

    def __init__(self, encoder, decoder, **kwargs):
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                TemporalBahAttnDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        self.train_forward_keys = ["cap", "cap_len", "ss_ratio", "temporal_tag"]
        self.inference_forward_keys = ["sample_method", "max_length", "temp", "temporal_tag"]

    def prepare_decoder_input(self, input_dict, output):
        decoder_input = super().prepare_decoder_input(input_dict, output)
        decoder_input["temporal_tag"] = input_dict["temporal_tag"]
        decoder_input["t"] = input_dict["t"]
        return decoder_input

    def _attach_temporal_tag(self, decoder_input, input_dict, output_i, n_copies, step):
        """Add the row-expanded temporal tag and the decoder step to ``decoder_input``.

        The tag of sample ``input_dict["sample_idx"]`` is expanded to
        ``n_copies`` rows on the first global step and cached in ``output_i``.
        """
        if input_dict["t"] == 0:
            i = input_dict["sample_idx"]
            output_i["temporal_tag"] = repeat_tensor(input_dict["temporal_tag"][i], n_copies)
        decoder_input["temporal_tag"] = output_i["temporal_tag"]
        decoder_input["t"] = step
        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
        return self._attach_temporal_tag(decoder_input, input_dict, output_i,
                                         input_dict["beam_size"], input_dict["t"])

    def prepare_dbs_decoder_input(self, input_dict, output_i):
        # ``super()`` must be called; the bare ``super`` type has no such attribute.
        decoder_input = super().prepare_dbs_decoder_input(input_dict, output_i)
        # Groups start staggered (the parent feeds group divm its start token
        # when t - divm == 0). The decoder substitutes the temporal-tag
        # embedding at step 0, so pass each group's local step.
        local_time = input_dict["t"] - input_dict["divm"]
        return self._attach_temporal_tag(decoder_input, input_dict, output_i,
                                         input_dict["bdash"], local_time)
|
|
|
|
|
class Cnn8rnnSedModel(nn.Module):
    """Sound-event detection model: four conv blocks, a bidirectional GRU and a sigmoid head.

    ``forward_prob`` returns class probabilities per RNN segment and per
    input frame; ``forward`` binarises the frame probabilities with
    ``double_threshold`` (0.75 / 0.25) and decodes them with
    ``decode_with_timestamps`` at ``time_resolution`` seconds per frame.
    """

    def __init__(self, classes_num):
        super().__init__()

        self.time_resolution = 0.01
        # Blocks 1-2 pool time by 2 each; blocks 3-4 keep the time axis.
        self.interpolate_ratio = 4

        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)

        self.fc1 = nn.Linear(512, 512, bias=True)
        self.rnn = nn.GRU(512, 256, bidirectional=True, batch_first=True)
        self.fc_audioset = nn.Linear(512, classes_num, bias=True)

    def forward(self, lms):
        frame_probs = self.forward_prob(lms)["framewise_output"].cpu().numpy()
        binarized = double_threshold(frame_probs, 0.75, 0.25)
        return decode_with_timestamps(binarized, self.time_resolution)

    def forward_prob(self, lms):
        """
        lms: (batch_size, mel_bins, time_steps)"""
        x = lms.transpose(1, 2).unsqueeze(1)
        frames_num = x.shape[2]

        # Normalise each mel bin: move the mel axis into the channel slot.
        x = self.bn0(x.transpose(1, 3)).transpose(1, 3)

        pooling_plan = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (1, 2)),
            (self.conv_block4, (1, 2)),
        )
        for block, pool_size in pooling_plan:
            x = block(x, pool_size=pool_size, pool_type='avg+max')
            x = F.dropout(x, p=0.2, training=self.training)

        x = torch.mean(x, dim=3).transpose(1, 2)
        x = F.dropout(x, p=0.5, training=self.training)
        x, _ = self.rnn(F.relu_(self.fc1(x)))
        segmentwise_output = torch.sigmoid(self.fc_audioset(x)).clamp(1e-7, 1.)

        upsampled = interpolate(segmentwise_output, self.interpolate_ratio)
        framewise_output = pad_framewise_output(upsampled, frames_num)

        return {
            "segmentwise_output": segmentwise_output,
            'framewise_output': framewise_output,
        }
|
|
|
|
|
class Cnn14RnnTempAttnGruConfig(PretrainedConfig):
    """Hyper-parameters for the CNN14-RNN encoder + temporal attention RNN decoder model.

    Every named argument is stored as an attribute of the same name; any
    remaining keyword arguments go to ``PretrainedConfig``.
    """

    def __init__(
        self,
        sample_rate: int = 32000,
        encoder_rnn_bidirectional: bool = True,
        encoder_rnn_hidden_size: int = 256,
        encoder_rnn_dropout: float = 0.5,
        encoder_rnn_num_layers: int = 3,
        decoder_emb_dim: int = 512,
        vocab_size: int = 4981,
        fc_emb_dim: int = 512,
        attn_emb_dim: int = 512,
        decoder_rnn_type: str = "GRU",
        decoder_num_layers: int = 1,
        decoder_d_model: int = 512,
        decoder_dropout: float = 0.5,
        **kwargs
    ):
        model_settings = {
            "sample_rate": sample_rate,
            "encoder_rnn_bidirectional": encoder_rnn_bidirectional,
            "encoder_rnn_hidden_size": encoder_rnn_hidden_size,
            "encoder_rnn_dropout": encoder_rnn_dropout,
            "encoder_rnn_num_layers": encoder_rnn_num_layers,
            "decoder_emb_dim": decoder_emb_dim,
            "vocab_size": vocab_size,
            "fc_emb_dim": fc_emb_dim,
            "attn_emb_dim": attn_emb_dim,
            "decoder_rnn_type": decoder_rnn_type,
            "decoder_num_layers": decoder_num_layers,
            "decoder_d_model": decoder_d_model,
            "decoder_dropout": decoder_dropout,
        }
        for attr_name, attr_value in model_settings.items():
            setattr(self, attr_name, attr_value)
        super().__init__(**kwargs)
|
|
|
|
|
class Cnn14RnnTempAttnGruModel(PreTrainedModel):
    """Hugging Face wrapper: temporal-aware captioner plus a sound-event detector.

    ``forward`` computes a log-mel spectrogram, runs the SED model to get a
    temporal tag, optionally combines it with a user-supplied tag
    (element-wise minimum), and decodes captions conditioned on the result.
    """

    config_class = Cnn14RnnTempAttnGruConfig

    def __init__(self, config):
        super().__init__(config)
        sample_rate = config.sample_rate
        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        window = 32 * sample_rate // 1000  # 32 ms analysis window
        hop = 10 * sample_rate // 1000     # 10 ms hop
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=window,
            win_length=window,
            hop_length=hop,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.db_transform = transforms.AmplitudeToDB()

        encoder = Cnn14RnnEncoder(
            sample_rate=config.sample_rate,
            rnn_bidirectional=config.encoder_rnn_bidirectional,
            rnn_hidden_size=config.encoder_rnn_hidden_size,
            rnn_dropout=config.encoder_rnn_dropout,
            rnn_num_layers=config.encoder_rnn_num_layers
        )
        decoder = TemporalBahAttnDecoder(
            emb_dim=config.decoder_emb_dim,
            vocab_size=config.vocab_size,
            fc_emb_dim=config.fc_emb_dim,
            attn_emb_dim=config.attn_emb_dim,
            rnn_type=config.decoder_rnn_type,
            num_layers=config.decoder_num_layers,
            d_model=config.decoder_d_model,
            dropout=config.decoder_dropout,
        )
        self.cap_model = TemporalSeq2SeqAttnModel(encoder, decoder)
        self.sed_model = Cnn8rnnSedModel(classes_num=447)

    def forward(self,
                audio: torch.Tensor,
                audio_length: Union[List, np.ndarray, torch.Tensor],
                temporal_tag: Union[List, np.ndarray, torch.Tensor] = None,
                sample_method: str = "beam",
                beam_size: int = 3,
                max_length: int = 20,
                temp: float = 1.0,):
        """Generate captions for raw waveforms.

        Args:
            audio: batch of waveforms, moved to the model's device.
            audio_length: waveform lengths in samples.
            temporal_tag: optional tag; combined with the SED-derived tag by
                element-wise minimum. When omitted the SED tag is used alone.
            sample_method: decoding strategy; "beam" also passes ``beam_size``.
            beam_size: beam width used when ``sample_method == "beam"``.
            max_length: maximum number of generated tokens.
            temp: sampling temperature.

        Returns:
            The generated token ids ("seq") on CPU.
        """
        device = self.device
        log_mel_spec = self.db_transform(self.melspec_extractor(audio.to(device)))

        detected_tag = torch.as_tensor(self.sed_model(log_mel_spec)).to(device)
        if temporal_tag is None:
            temporal_tag = detected_tag
        else:
            given_tag = torch.as_tensor(temporal_tag).to(device)
            temporal_tag = torch.stack([given_tag, detected_tag], dim=0).min(dim=0).values

        input_dict = {
            "lms": log_mel_spec,
            "wav_len": audio_length,
            "temporal_tag": temporal_tag,
            "mode": "inference",
            "sample_method": sample_method,
            "max_length": max_length,
            "temp": temp,
        }
        if sample_method == "beam":
            input_dict["beam_size"] = beam_size
        return self.cap_model(input_dict)["seq"].cpu()