from typing import Dict, Callable, Union, List | |
import random | |
import math | |
import sys | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | |
from torchaudio import transforms | |
from efficientnet_pytorch import EfficientNet | |
from efficientnet_pytorch import utils as efficientnet_utils | |
from einops import rearrange, reduce | |
from transformers import PretrainedConfig, PreTrainedModel | |
def sort_pack_padded_sequence(input, lengths): | |
sorted_lengths, indices = torch.sort(lengths, descending=True) | |
tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True) | |
inv_ix = indices.clone() | |
inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix) | |
return tmp, inv_ix | |
def pad_unsort_packed_sequence(input, inv_ix): | |
tmp, _ = pad_packed_sequence(input, batch_first=True) | |
tmp = tmp[inv_ix] | |
return tmp | |
def pack_wrapper(module, attn_feats, attn_feat_lens): | |
packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens) | |
if isinstance(module, torch.nn.RNNBase): | |
return pad_unsort_packed_sequence(module(packed)[0], inv_ix) | |
else: | |
return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix) | |
def embedding_pooling(x, lens, pooling="mean"): | |
if pooling == "max": | |
fc_embs = max_with_lens(x, lens) | |
elif pooling == "mean": | |
fc_embs = mean_with_lens(x, lens) | |
elif pooling == "mean+max": | |
x_mean = mean_with_lens(x, lens) | |
x_max = max_with_lens(x, lens) | |
fc_embs = x_mean + x_max | |
elif pooling == "last": | |
indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1)) | |
# indices: [N, 1, hidden] | |
fc_embs = torch.gather(x, 1, indices).squeeze(1) | |
else: | |
raise Exception(f"pooling method {pooling} not support") | |
return fc_embs | |
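# Illustrative sketch of the pooling modes above (shapes and values are hypothetical):
# for x of shape [N, T, D] = [2, 4, 3] and lens = [4, 2], every mode returns a [2, 3]
# tensor; "mean"/"max" pool only over the first lens[i] frames, "mean+max" sums the two,
# and "last" picks the frame at index lens[i] - 1.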
def interpolate(x, ratio): | |
"""Interpolate data in time domain. This is used to compensate the | |
resolution reduction in downsampling of a CNN. | |
Args: | |
x: (batch_size, time_steps, classes_num) | |
ratio: int, ratio to interpolate | |
Returns: | |
upsampled: (batch_size, time_steps * ratio, classes_num) | |
""" | |
(batch_size, time_steps, classes_num) = x.shape | |
upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) | |
upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) | |
return upsampled | |
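# Example of the shape change (illustrative values): a CNN output of shape
# (8, 100, 527) interpolated with ratio=4 becomes (8, 400, 527), each frame
# simply repeated 4 times along the time axis.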
def pad_framewise_output(framewise_output, frames_num): | |
"""Pad framewise_output to the same length as input frames. The pad value | |
is the same as the value of the last frame. | |
Args: | |
framewise_output: (batch_size, frames_num, classes_num) | |
      frames_num: int, target number of frames after padding
Outputs: | |
output: (batch_size, frames_num, classes_num) | |
""" | |
    # pad: tensor filled with the value of the last frame
    pad = framewise_output[:, -1:, :].repeat(1, frames_num - framewise_output.shape[1], 1)
    # output: (batch_size, frames_num, classes_num)
    output = torch.cat((framewise_output, pad), dim=1)
return output | |
def find_contiguous_regions(activity_array): | |
"""Find contiguous regions from bool valued numpy.array. | |
Copy of https://dcase-repo.github.io/dcase_util/_modules/dcase_util/data/decisions.html#DecisionEncoder | |
    Reasons for copying:
    1. This function does not necessarily belong to a class.
    2. Importing DecisionEncoder pulls in sndfile among other dependencies, which causes problems on clusters.
""" | |
# Find the changes in the activity_array | |
change_indices = np.logical_xor(activity_array[1:], | |
activity_array[:-1]).nonzero()[0] | |
# Shift change_index with one, focus on frame after the change. | |
change_indices += 1 | |
if activity_array[0]: | |
# If the first element of activity_array is True add 0 at the beginning | |
change_indices = np.r_[0, change_indices] | |
if activity_array[-1]: | |
# If the last element of activity_array is True, add the length of the array | |
change_indices = np.r_[change_indices, activity_array.size] | |
# Reshape the result into two columns | |
return change_indices.reshape((-1, 2)) | |
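# Illustrative example: for activity_array = np.array([False, True, True, False, True]),
# find_contiguous_regions returns [[1, 3], [4, 5]]; each row is an [onset, offset)
# frame-index pair of a contiguous active region.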
def double_threshold(x, high_thres, low_thres, n_connect=1): | |
"""double_threshold | |
Helper function to calculate double threshold for n-dim arrays | |
:param x: input array | |
:param high_thres: high threshold value | |
:param low_thres: Low threshold value | |
    :param n_connect: Clusters within a distance of <= n_connect frames are merged
    """
    assert x.ndim <= 3, "Whoops, something went wrong with the input ({}), check if it has <= 3 dims".format(
        x.shape)
if x.ndim == 3: | |
apply_dim = 1 | |
elif x.ndim < 3: | |
apply_dim = 0 | |
    # x is assumed to be 3d: (batch, time, dim),
    # 2d: (time, dim), or 1d: (time,);
    # the time axis is therefore at 1 for 3d input and at 0 for 1d/2d input
return np.apply_along_axis(lambda x: _double_threshold( | |
x, high_thres, low_thres, n_connect=n_connect), | |
axis=apply_dim, | |
arr=x) | |
def _double_threshold(x, high_thres, low_thres, n_connect=1, return_arr=True): | |
"""_double_threshold | |
Computes a double threshold over the input array | |
:param x: input array, needs to be 1d | |
:param high_thres: High threshold over the array | |
:param low_thres: Low threshold over the array | |
:param n_connect: Postprocessing, maximal distance between clusters to connect | |
    :param return_arr: If True (default), return an array of the same size as x filled with ones and zeros; if False, return the filtered index pairs.
""" | |
assert x.ndim == 1, "Input needs to be 1d" | |
high_locations = np.where(x > high_thres)[0] | |
locations = x > low_thres | |
encoded_pairs = find_contiguous_regions(locations) | |
filtered_list = list( | |
filter( | |
lambda pair: | |
((pair[0] <= high_locations) & (high_locations <= pair[1])).any(), | |
encoded_pairs)) | |
filtered_list = connect_(filtered_list, n_connect) | |
if return_arr: | |
zero_one_arr = np.zeros_like(x, dtype=int) | |
for sl in filtered_list: | |
zero_one_arr[sl[0]:sl[1]] = 1 | |
return zero_one_arr | |
return filtered_list | |
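# Illustrative example (hypothetical thresholds): for
# x = np.array([0.1, 0.6, 0.9, 0.55, 0.2]) with high_thres=0.75 and low_thres=0.5,
# the low-threshold region [1, 4) contains a frame above the high threshold, so
# double_threshold(x, 0.75, 0.5) returns [0, 1, 1, 1, 0].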
def connect_(pairs, n=1): | |
"""connect_ | |
Connects two adjacent clusters if their distance is <= n | |
    :param pairs: Clusters of iterables, e.g., [(1,5),(7,10)]
:param n: distance between two clusters | |
""" | |
if len(pairs) == 0: | |
return [] | |
start_, end_ = pairs[0] | |
new_pairs = [] | |
for i, (next_item, cur_item) in enumerate(zip(pairs[1:], pairs[0:])): | |
end_ = next_item[1] | |
if next_item[0] - cur_item[1] <= n: | |
pass | |
else: | |
new_pairs.append((start_, cur_item[1])) | |
start_ = next_item[0] | |
new_pairs.append((start_, end_)) | |
return new_pairs | |
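# Illustrative example: connect_([(1, 5), (7, 10)], n=2) merges the two clusters into
# [(1, 10)] because their gap (7 - 5 = 2) is <= n, while n=1 keeps them separate as
# [(1, 5), (7, 10)].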
def segments_to_temporal_tag(segments, thre=0.5): | |
after_flag, while_flag = 0, 0 | |
for j in range(len(segments)): | |
for k in range(len(segments)): | |
if segments[j][0] == segments[k][0]: | |
continue | |
min_duration = min(segments[j][2] - segments[j][1], segments[k][2] - segments[k][1]) | |
overlap = segments[j][2] - segments[k][1] | |
if overlap < thre * min_duration: | |
after_flag = 2 | |
if segments[j][1] < segments[k][1] and overlap > thre * min_duration: | |
while_flag = 1 | |
return after_flag + while_flag | |
def decode_with_timestamps(labels, time_resolution): | |
batch_results = [] | |
for lab in labels: | |
segments = [] | |
for i, label_column in enumerate(lab.T): | |
change_indices = find_contiguous_regions(label_column) | |
# append [onset, offset] in the result list | |
for row in change_indices: | |
segments.append((i, row[0] * time_resolution, row[1] * time_resolution)) | |
temporal_tag = segments_to_temporal_tag(segments) | |
batch_results.append(temporal_tag) | |
return batch_results | |
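# As the code above reads, each returned temporal tag is after_flag + while_flag, i.e.
# one of {0, 1, 2, 3}: 0 = no pairwise relation detected, 1 = events largely overlapping
# ("while"), 2 = events largely sequential ("after"), 3 = both occur. This matches the
# 4-entry temporal embedding used by TemporalBahAttnDecoder below.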
class _EffiNet(nn.Module): | |
"""A proxy for efficient net models""" | |
def __init__(self, | |
blocks_args=None, | |
global_params=None, | |
) -> None: | |
super().__init__() | |
self.eff_net = EfficientNet(blocks_args=blocks_args, | |
global_params=global_params) | |
def forward(self, x: torch.Tensor): | |
x = rearrange(x, 'b f t -> b 1 f t') | |
x = self.eff_net.extract_features(x) | |
return reduce(x, 'b c f t -> b t c', 'mean') | |
def get_effb2_model() -> _EffiNet: | |
blocks_args, global_params = efficientnet_utils.get_model_params( | |
'efficientnet-b2', {'include_top': False}) | |
model = _EffiNet(blocks_args=blocks_args, | |
global_params=global_params) | |
model.eff_net._change_in_channels(1) | |
return model | |
def merge_load_state_dict(state_dict, | |
model: torch.nn.Module, | |
output_fn: Callable = sys.stdout.write): | |
model_dict = model.state_dict() | |
pretrained_dict = {} | |
mismatch_keys = [] | |
for key, value in state_dict.items(): | |
if key in model_dict and model_dict[key].shape == value.shape: | |
pretrained_dict[key] = value | |
else: | |
mismatch_keys.append(key) | |
output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}\n") | |
model_dict.update(pretrained_dict) | |
model.load_state_dict(model_dict, strict=True) | |
return pretrained_dict.keys() | |
class EfficientNetB2(nn.Module): | |
def __init__(self, | |
n_mels: int = 64, | |
win_length: int = 32, | |
hop_length: int = 10, | |
f_min: int = 0, | |
freeze: bool = False,): | |
super().__init__() | |
sample_rate = 16000 | |
self.melspec_extractor = transforms.MelSpectrogram( | |
sample_rate=sample_rate, | |
n_fft=win_length * sample_rate // 1000, | |
win_length=win_length * sample_rate // 1000, | |
hop_length=hop_length * sample_rate // 1000, | |
f_min=f_min, | |
n_mels=n_mels, | |
) | |
        self.hop_length = hop_length * sample_rate // 1000
self.db_transform = transforms.AmplitudeToDB(top_db=120) | |
self.backbone = get_effb2_model() | |
self.fc_emb_size = self.backbone.eff_net._conv_head.out_channels | |
self.downsample_ratio = 32 | |
if freeze: | |
for param in self.parameters(): | |
param.requires_grad = False | |
def forward(self, input_dict): | |
waveform = input_dict["wav"] | |
wave_length = input_dict["wav_len"] | |
specaug = input_dict["specaug"] | |
x = self.melspec_extractor(waveform) | |
x = self.db_transform(x) # (batch_size, mel_bins, time_steps) | |
x = rearrange(x, 'b f t -> b 1 t f') | |
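        # NOTE: self.spec_augmenter is not defined in this class; SpecAugment is only
        # applied if such a module is attached externally and specaug is True.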
if self.training and specaug: | |
x = self.spec_augmenter(x) | |
x = rearrange(x, 'b 1 t f -> b f t') | |
x = self.backbone(x) | |
attn_emb = x | |
wave_length = torch.as_tensor(wave_length) | |
feat_length = torch.div(wave_length, self.hop_length, | |
rounding_mode="floor") + 1 | |
feat_length = torch.div(feat_length, self.downsample_ratio, | |
rounding_mode="floor") | |
fc_emb = mean_with_lens(attn_emb, feat_length) | |
output_dict = { | |
'fc_emb': fc_emb, | |
'attn_emb': attn_emb, | |
'attn_emb_len': feat_length | |
} | |
return output_dict | |
def generate_length_mask(lens, max_length=None): | |
lens = torch.as_tensor(lens) | |
N = lens.size(0) | |
if max_length is None: | |
max_length = max(lens) | |
if isinstance(max_length, torch.Tensor): | |
max_length = max_length.item() | |
idxs = torch.arange(max_length).repeat(N).view(N, max_length) | |
idxs = idxs.to(lens.device) | |
mask = (idxs < lens.view(-1, 1)) | |
return mask | |
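# Illustrative example: generate_length_mask(torch.tensor([2, 4])) returns
# [[True, True, False, False],
#  [True, True, True,  True]]
# i.e. True marks valid (non-padded) positions up to each sequence length.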
def mean_with_lens(features, lens): | |
""" | |
features: [N, T, ...] (assume the second dimension represents length) | |
lens: [N,] | |
""" | |
lens = torch.as_tensor(lens) | |
if max(lens) != features.size(1): | |
max_length = features.size(1) | |
mask = generate_length_mask(lens, max_length) | |
else: | |
mask = generate_length_mask(lens) | |
mask = mask.to(features.device) # [N, T] | |
while mask.ndim < features.ndim: | |
mask = mask.unsqueeze(-1) | |
feature_mean = features * mask | |
feature_mean = feature_mean.sum(1) | |
while lens.ndim < feature_mean.ndim: | |
lens = lens.unsqueeze(1) | |
feature_mean = feature_mean / lens.to(features.device) | |
# feature_mean = features * mask.unsqueeze(-1) | |
# feature_mean = feature_mean.sum(1) / lens.unsqueeze(1).to(features.device) | |
return feature_mean | |
def max_with_lens(features, lens): | |
""" | |
features: [N, T, ...] (assume the second dimension represents length) | |
lens: [N,] | |
""" | |
lens = torch.as_tensor(lens) | |
if max(lens) != features.size(1): | |
max_length = features.size(1) | |
mask = generate_length_mask(lens, max_length) | |
else: | |
mask = generate_length_mask(lens) | |
mask = mask.to(features.device) # [N, T] | |
feature_max = features.clone() | |
feature_max[~mask] = float("-inf") | |
feature_max, _ = feature_max.max(1) | |
return feature_max | |
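# Illustrative sketch: for features of shape [N, T, D] and lens of shape [N],
# mean_with_lens averages only over the first lens[i] frames of each sequence and
# max_with_lens takes the maximum with padded frames masked to -inf; both return a
# [N, D] tensor, which is how the encoders below derive clip-level embeddings.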
def repeat_tensor(x, n): | |
return x.unsqueeze(0).repeat(n, *([1] * len(x.shape))) | |
class CaptionMetaMixin: | |
pad_idx = 0 | |
start_idx = 1 | |
end_idx = 2 | |
max_length = 20 | |
    @classmethod
    def set_index(cls, start_idx, end_idx, pad_idx):
cls.start_idx = start_idx | |
cls.end_idx = end_idx | |
cls.pad_idx = pad_idx | |
class CaptionModel(nn.Module, CaptionMetaMixin): | |
""" | |
Encoder-decoder captioning model. | |
""" | |
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs): | |
super().__init__() | |
self.encoder = encoder | |
self.decoder = decoder | |
self.vocab_size = decoder.vocab_size | |
self.train_forward_keys = ["cap", "cap_len", "ss_ratio"] | |
self.inference_forward_keys = ["sample_method", "max_length", "temp"] | |
freeze_encoder = kwargs.get("freeze_encoder", False) | |
if freeze_encoder: | |
for param in self.encoder.parameters(): | |
param.requires_grad = False | |
self.check_decoder_compatibility() | |
def check_decoder_compatibility(self): | |
compatible_decoders = [x.__class__.__name__ for x in self.compatible_decoders] | |
assert isinstance(self.decoder, self.compatible_decoders), \ | |
f"{self.decoder.__class__.__name__} is incompatible with " \ | |
f"{self.__class__.__name__}, please use decoder in {compatible_decoders} " | |
def forward(self, input_dict: Dict): | |
""" | |
input_dict: { | |
(required) | |
mode: train/inference, | |
[spec, spec_len], | |
[fc], | |
[attn, attn_len], | |
[wav, wav_len], | |
[sample_method: greedy], | |
[temp: 1.0] (in case of no teacher forcing) | |
(optional, mode=train) | |
cap, | |
cap_len, | |
ss_ratio, | |
(optional, mode=inference) | |
sample_method: greedy/beam, | |
max_length, | |
temp, | |
beam_size (optional, sample_method=beam), | |
n_best (optional, sample_method=beam), | |
} | |
""" | |
encoder_output_dict = self.encoder(input_dict) | |
output = self.forward_decoder(input_dict, encoder_output_dict) | |
return output | |
def forward_decoder(self, input_dict: Dict, encoder_output_dict: Dict): | |
if input_dict["mode"] == "train": | |
forward_dict = { | |
"mode": "train", "sample_method": "greedy", "temp": 1.0 | |
} | |
for key in self.train_forward_keys: | |
forward_dict[key] = input_dict[key] | |
forward_dict.update(encoder_output_dict) | |
output = self.train_forward(forward_dict) | |
elif input_dict["mode"] == "inference": | |
forward_dict = {"mode": "inference"} | |
default_args = { "sample_method": "greedy", "max_length": self.max_length, "temp": 1.0 } | |
for key in self.inference_forward_keys: | |
if key in input_dict: | |
forward_dict[key] = input_dict[key] | |
else: | |
forward_dict[key] = default_args[key] | |
if forward_dict["sample_method"] == "beam": | |
forward_dict["beam_size"] = input_dict.get("beam_size", 3) | |
forward_dict["n_best"] = input_dict.get("n_best", False) | |
forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"]) | |
elif forward_dict["sample_method"] == "dbs": | |
forward_dict["beam_size"] = input_dict.get("beam_size", 6) | |
forward_dict["group_size"] = input_dict.get("group_size", 3) | |
forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5) | |
forward_dict["group_nbest"] = input_dict.get("group_nbest", True) | |
forward_dict.update(encoder_output_dict) | |
output = self.inference_forward(forward_dict) | |
else: | |
raise Exception("mode should be either 'train' or 'inference'") | |
output.update(encoder_output_dict) | |
return output | |
def prepare_output(self, input_dict): | |
output = {} | |
batch_size = input_dict["fc_emb"].size(0) | |
if input_dict["mode"] == "train": | |
max_length = input_dict["cap"].size(1) - 1 | |
elif input_dict["mode"] == "inference": | |
max_length = input_dict["max_length"] | |
else: | |
raise Exception("mode should be either 'train' or 'inference'") | |
device = input_dict["fc_emb"].device | |
output["seq"] = torch.full((batch_size, max_length), self.end_idx, | |
dtype=torch.long) | |
output["logit"] = torch.empty(batch_size, max_length, | |
self.vocab_size).to(device) | |
output["sampled_logprob"] = torch.zeros(batch_size, max_length) | |
output["embed"] = torch.empty(batch_size, max_length, | |
self.decoder.d_model).to(device) | |
return output | |
def train_forward(self, input_dict): | |
if input_dict["ss_ratio"] != 1: # scheduled sampling training | |
input_dict["mode"] = "train" | |
return self.stepwise_forward(input_dict) | |
output = self.seq_forward(input_dict) | |
self.train_process(output, input_dict) | |
return output | |
def seq_forward(self, input_dict): | |
raise NotImplementedError | |
def train_process(self, output, input_dict): | |
pass | |
def inference_forward(self, input_dict): | |
if input_dict["sample_method"] == "beam": | |
return self.beam_search(input_dict) | |
elif input_dict["sample_method"] == "dbs": | |
return self.diverse_beam_search(input_dict) | |
return self.stepwise_forward(input_dict) | |
def stepwise_forward(self, input_dict): | |
"""Step-by-step decoding""" | |
output = self.prepare_output(input_dict) | |
max_length = output["seq"].size(1) | |
# start sampling | |
for t in range(max_length): | |
input_dict["t"] = t | |
self.decode_step(input_dict, output) | |
if input_dict["mode"] == "inference": # decide whether to stop when sampling | |
unfinished_t = output["seq"][:, t] != self.end_idx | |
if t == 0: | |
unfinished = unfinished_t | |
else: | |
unfinished *= unfinished_t | |
output["seq"][:, t][~unfinished] = self.end_idx | |
if unfinished.sum() == 0: | |
break | |
self.stepwise_process(output) | |
return output | |
def decode_step(self, input_dict, output): | |
"""Decoding operation of timestep t""" | |
decoder_input = self.prepare_decoder_input(input_dict, output) | |
# feed to the decoder to get logit | |
output_t = self.decoder(decoder_input) | |
logit_t = output_t["logit"] | |
# assert logit_t.ndim == 3 | |
if logit_t.size(1) == 1: | |
logit_t = logit_t.squeeze(1) | |
embed_t = output_t["embed"].squeeze(1) | |
elif logit_t.size(1) > 1: | |
logit_t = logit_t[:, -1, :] | |
embed_t = output_t["embed"][:, -1, :] | |
else: | |
raise Exception("no logit output") | |
# sample the next input word and get the corresponding logit | |
sampled = self.sample_next_word(logit_t, | |
method=input_dict["sample_method"], | |
temp=input_dict["temp"]) | |
output_t.update(sampled) | |
output_t["t"] = input_dict["t"] | |
output_t["logit"] = logit_t | |
output_t["embed"] = embed_t | |
self.stepwise_process_step(output, output_t) | |
def prepare_decoder_input(self, input_dict, output): | |
"""Prepare the inp ut dict for the decoder""" | |
raise NotImplementedError | |
def stepwise_process_step(self, output, output_t): | |
"""Postprocessing (save output values) after each timestep t""" | |
t = output_t["t"] | |
output["logit"][:, t, :] = output_t["logit"] | |
output["seq"][:, t] = output_t["word"] | |
output["sampled_logprob"][:, t] = output_t["probs"] | |
output["embed"][:, t, :] = output_t["embed"] | |
def stepwise_process(self, output): | |
"""Postprocessing after the whole step-by-step autoregressive decoding""" | |
pass | |
def sample_next_word(self, logit, method, temp): | |
"""Sample the next word, given probs output by the decoder""" | |
logprob = torch.log_softmax(logit, dim=1) | |
if method == "greedy": | |
sampled_logprob, word = torch.max(logprob.detach(), 1) | |
elif method == "gumbel": | |
def sample_gumbel(shape, eps=1e-20): | |
U = torch.rand(shape).to(logprob.device) | |
return -torch.log(-torch.log(U + eps) + eps) | |
def gumbel_softmax_sample(logit, temperature): | |
y = logit + sample_gumbel(logit.size()) | |
return torch.log_softmax(y / temperature, dim=-1) | |
_logprob = gumbel_softmax_sample(logprob, temp) | |
_, word = torch.max(_logprob.data, 1) | |
            sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
else: | |
logprob = logprob / temp | |
if method.startswith("top"): | |
top_num = float(method[3:]) | |
if 0 < top_num < 1: # top-p sampling | |
probs = torch.softmax(logit, dim=1) | |
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1) | |
_cumsum = sorted_probs.cumsum(1) | |
mask = _cumsum < top_num | |
mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1) | |
sorted_probs = sorted_probs * mask.to(sorted_probs) | |
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True) | |
logprob.scatter_(1, sorted_indices, sorted_probs.log()) | |
else: # top-k sampling | |
k = int(top_num) | |
tmp = torch.empty_like(logprob).fill_(float('-inf')) | |
topk, indices = torch.topk(logprob, k, dim=1) | |
tmp = tmp.scatter(1, indices, topk) | |
logprob = tmp | |
word = torch.distributions.Categorical(logits=logprob.detach()).sample() | |
sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1) | |
word = word.detach().long() | |
# sampled_logprob: [N,], word: [N,] | |
return {"word": word, "probs": sampled_logprob} | |
def beam_search(self, input_dict): | |
output = self.prepare_output(input_dict) | |
max_length = input_dict["max_length"] | |
beam_size = input_dict["beam_size"] | |
if input_dict["n_best"]: | |
n_best_size = input_dict["n_best_size"] | |
batch_size, max_length = output["seq"].size() | |
output["seq"] = torch.full((batch_size, n_best_size, max_length), | |
self.end_idx, dtype=torch.long) | |
temp = input_dict["temp"] | |
        # instance-by-instance beam search
for i in range(output["seq"].size(0)): | |
output_i = self.prepare_beamsearch_output(input_dict) | |
input_dict["sample_idx"] = i | |
for t in range(max_length): | |
input_dict["t"] = t | |
output_t = self.beamsearch_step(input_dict, output_i) | |
####################################### | |
# merge with previous beam and select the current max prob beam | |
####################################### | |
logit_t = output_t["logit"] | |
if logit_t.size(1) == 1: | |
logit_t = logit_t.squeeze(1) | |
elif logit_t.size(1) > 1: | |
logit_t = logit_t[:, -1, :] | |
else: | |
raise Exception("no logit output") | |
logprob_t = torch.log_softmax(logit_t, dim=1) | |
logprob_t = torch.log_softmax(logprob_t / temp, dim=1) | |
logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t | |
if t == 0: # for the first step, all k seq will have the same probs | |
topk_logprob, topk_words = logprob_t[0].topk( | |
beam_size, 0, True, True) | |
else: # unroll and find top logprob, and their unrolled indices | |
topk_logprob, topk_words = logprob_t.view(-1).topk( | |
beam_size, 0, True, True) | |
topk_words = topk_words.cpu() | |
output_i["topk_logprob"] = topk_logprob | |
# output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,] | |
output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size, | |
rounding_mode='trunc') | |
output_i["next_word"] = topk_words % self.vocab_size # [beam_size,] | |
if t == 0: | |
output_i["seq"] = output_i["next_word"].unsqueeze(1) | |
else: | |
output_i["seq"] = torch.cat([ | |
output_i["seq"][output_i["prev_words_beam"]], | |
output_i["next_word"].unsqueeze(1)], dim=1) | |
# add finished beams to results | |
is_end = output_i["next_word"] == self.end_idx | |
if t == max_length - 1: | |
is_end.fill_(1) | |
for beam_idx in range(beam_size): | |
if is_end[beam_idx]: | |
final_beam = { | |
"seq": output_i["seq"][beam_idx].clone(), | |
"score": output_i["topk_logprob"][beam_idx].item() | |
} | |
final_beam["score"] = final_beam["score"] / (t + 1) | |
output_i["done_beams"].append(final_beam) | |
output_i["topk_logprob"][is_end] -= 1000 | |
self.beamsearch_process_step(output_i, output_t) | |
if len(output_i["done_beams"]) == beam_size: | |
break | |
self.beamsearch_process(output, output_i, input_dict) | |
return output | |
def prepare_beamsearch_output(self, input_dict): | |
beam_size = input_dict["beam_size"] | |
device = input_dict["fc_emb"].device | |
output = { | |
"topk_logprob": torch.zeros(beam_size).to(device), | |
"seq": None, | |
"prev_words_beam": None, | |
"next_word": None, | |
"done_beams": [], | |
} | |
return output | |
def beamsearch_step(self, input_dict, output_i): | |
decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i) | |
output_t = self.decoder(decoder_input) | |
output_t["t"] = input_dict["t"] | |
return output_t | |
def prepare_beamsearch_decoder_input(self, input_dict, output_i): | |
raise NotImplementedError | |
def beamsearch_process_step(self, output_i, output_t): | |
pass | |
def beamsearch_process(self, output, output_i, input_dict): | |
i = input_dict["sample_idx"] | |
done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"]) | |
if input_dict["n_best"]: | |
done_beams = done_beams[:input_dict["n_best_size"]] | |
for out_idx, done_beam in enumerate(done_beams): | |
seq = done_beam["seq"] | |
output["seq"][i][out_idx, :len(seq)] = seq | |
else: | |
seq = done_beams[0]["seq"] | |
output["seq"][i][:len(seq)] = seq | |
def diverse_beam_search(self, input_dict): | |
def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash): | |
local_time = t - divm | |
unaug_logprob = logprob.clone() | |
if divm > 0: | |
change = torch.zeros(logprob.size(-1)) | |
for prev_choice in range(divm): | |
prev_decisions = seq_table[prev_choice][..., local_time] | |
for prev_labels in range(bdash): | |
change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1)) | |
change = change.to(logprob.device) | |
logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda | |
return logprob, unaug_logprob | |
output = self.prepare_output(input_dict) | |
group_size = input_dict["group_size"] | |
batch_size = output["seq"].size(0) | |
beam_size = input_dict["beam_size"] | |
bdash = beam_size // group_size | |
input_dict["bdash"] = bdash | |
diversity_lambda = input_dict["diversity_lambda"] | |
device = input_dict["fc_emb"].device | |
max_length = input_dict["max_length"] | |
temp = input_dict["temp"] | |
group_nbest = input_dict["group_nbest"] | |
batch_size, max_length = output["seq"].size() | |
if group_nbest: | |
output["seq"] = torch.full((batch_size, beam_size, max_length), | |
self.end_idx, dtype=torch.long) | |
else: | |
output["seq"] = torch.full((batch_size, group_size, max_length), | |
self.end_idx, dtype=torch.long) | |
for i in range(batch_size): | |
input_dict["sample_idx"] = i | |
seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0] | |
logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)] | |
done_beams_table = [[] for _ in range(group_size)] | |
output_i = { | |
"prev_words_beam": [None for _ in range(group_size)], | |
"next_word": [None for _ in range(group_size)], | |
"state": [None for _ in range(group_size)] | |
} | |
for t in range(max_length + group_size - 1): | |
input_dict["t"] = t | |
for divm in range(group_size): | |
input_dict["divm"] = divm | |
if t >= divm and t <= max_length + divm - 1: | |
local_time = t - divm | |
decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i) | |
output_t = self.decoder(decoder_input) | |
output_t["divm"] = divm | |
logit_t = output_t["logit"] | |
if logit_t.size(1) == 1: | |
logit_t = logit_t.squeeze(1) | |
elif logit_t.size(1) > 1: | |
logit_t = logit_t[:, -1, :] | |
else: | |
raise Exception("no logit output") | |
logprob_t = torch.log_softmax(logit_t, dim=1) | |
logprob_t = torch.log_softmax(logprob_t / temp, dim=1) | |
logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash) | |
logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t | |
if local_time == 0: # for the first step, all k seq will have the same probs | |
topk_logprob, topk_words = logprob_t[0].topk( | |
bdash, 0, True, True) | |
else: # unroll and find top logprob, and their unrolled indices | |
topk_logprob, topk_words = logprob_t.view(-1).topk( | |
bdash, 0, True, True) | |
topk_words = topk_words.cpu() | |
logprob_table[divm] = topk_logprob | |
output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,] | |
output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,] | |
if local_time > 0: | |
seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]] | |
seq_table[divm] = torch.cat([ | |
seq_table[divm], | |
output_i["next_word"][divm].unsqueeze(-1)], -1) | |
is_end = seq_table[divm][:, t-divm] == self.end_idx | |
assert seq_table[divm].shape[-1] == t - divm + 1 | |
if t == max_length + divm - 1: | |
is_end.fill_(1) | |
for beam_idx in range(bdash): | |
if is_end[beam_idx]: | |
final_beam = { | |
"seq": seq_table[divm][beam_idx].clone(), | |
"score": logprob_table[divm][beam_idx].item() | |
} | |
final_beam["score"] = final_beam["score"] / (t - divm + 1) | |
done_beams_table[divm].append(final_beam) | |
logprob_table[divm][is_end] -= 1000 | |
self.dbs_process_step(output_i, output_t) | |
done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)] | |
if group_nbest: | |
done_beams = sum(done_beams_table, []) | |
else: | |
done_beams = [group_beam[0] for group_beam in done_beams_table] | |
for _, done_beam in enumerate(done_beams): | |
output["seq"][i, _, :len(done_beam["seq"])] = done_beam["seq"] | |
return output | |
def prepare_dbs_decoder_input(self, input_dict, output_i): | |
raise NotImplementedError | |
def dbs_process_step(self, output_i, output_t): | |
pass | |
class TransformerModel(CaptionModel): | |
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs): | |
if not hasattr(self, "compatible_decoders"): | |
self.compatible_decoders = ( | |
TransformerDecoder, | |
) | |
super().__init__(encoder, decoder, **kwargs) | |
def seq_forward(self, input_dict): | |
cap = input_dict["cap"] | |
cap_padding_mask = (cap == self.pad_idx).to(cap.device) | |
cap_padding_mask = cap_padding_mask[:, :-1] | |
output = self.decoder( | |
{ | |
"word": cap[:, :-1], | |
"attn_emb": input_dict["attn_emb"], | |
"attn_emb_len": input_dict["attn_emb_len"], | |
"cap_padding_mask": cap_padding_mask | |
} | |
) | |
return output | |
def prepare_decoder_input(self, input_dict, output): | |
decoder_input = { | |
"attn_emb": input_dict["attn_emb"], | |
"attn_emb_len": input_dict["attn_emb_len"] | |
} | |
t = input_dict["t"] | |
############### | |
# determine input word | |
################ | |
if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling | |
word = input_dict["cap"][:, :t+1] | |
else: | |
start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long() | |
if t == 0: | |
word = start_word | |
else: | |
word = torch.cat((start_word, output["seq"][:, :t]), dim=-1) | |
# word: [N, T] | |
decoder_input["word"] = word | |
cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device) | |
decoder_input["cap_padding_mask"] = cap_padding_mask | |
return decoder_input | |
def prepare_beamsearch_decoder_input(self, input_dict, output_i): | |
decoder_input = {} | |
t = input_dict["t"] | |
i = input_dict["sample_idx"] | |
beam_size = input_dict["beam_size"] | |
############### | |
# prepare attn embeds | |
################ | |
if t == 0: | |
attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size) | |
attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size) | |
output_i["attn_emb"] = attn_emb | |
output_i["attn_emb_len"] = attn_emb_len | |
decoder_input["attn_emb"] = output_i["attn_emb"] | |
decoder_input["attn_emb_len"] = output_i["attn_emb_len"] | |
############### | |
# determine input word | |
################ | |
start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long() | |
if t == 0: | |
word = start_word | |
else: | |
word = torch.cat((start_word, output_i["seq"]), dim=-1) | |
decoder_input["word"] = word | |
cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device) | |
decoder_input["cap_padding_mask"] = cap_padding_mask | |
return decoder_input | |
class BaseDecoder(nn.Module): | |
""" | |
Take word/audio embeddings and output the next word probs | |
""" | |
def __init__(self, emb_dim, vocab_size, fc_emb_dim, | |
attn_emb_dim, dropout=0.2, tie_weights=False): | |
super().__init__() | |
self.emb_dim = emb_dim | |
self.vocab_size = vocab_size | |
self.fc_emb_dim = fc_emb_dim | |
self.attn_emb_dim = attn_emb_dim | |
self.tie_weights = tie_weights | |
self.word_embedding = nn.Embedding(vocab_size, emb_dim) | |
self.in_dropout = nn.Dropout(dropout) | |
def forward(self, x): | |
raise NotImplementedError | |
def load_word_embedding(self, weight, freeze=True): | |
        embedding = np.load(weight)
        assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
        assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
        # nn.Embedding.from_pretrained expects a FloatTensor, not a numpy array
        embedding = torch.as_tensor(embedding).float()
        self.word_embedding = nn.Embedding.from_pretrained(embedding,
                                                           freeze=freeze)
class PositionalEncoding(nn.Module): | |
def __init__(self, d_model, dropout=0.1, max_len=100): | |
super(PositionalEncoding, self).__init__() | |
self.dropout = nn.Dropout(p=dropout) | |
pe = torch.zeros(max_len, d_model) | |
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) | |
div_term = torch.exp(torch.arange(0, d_model, 2).float() * \ | |
(-math.log(10000.0) / d_model)) | |
pe[:, 0::2] = torch.sin(position * div_term) | |
pe[:, 1::2] = torch.cos(position * div_term) | |
pe = pe.unsqueeze(0).transpose(0, 1) | |
# self.register_buffer("pe", pe) | |
self.register_parameter("pe", nn.Parameter(pe, requires_grad=False)) | |
def forward(self, x): | |
# x: [T, N, E] | |
x = x + self.pe[:x.size(0), :] | |
return self.dropout(x) | |
class TransformerDecoder(BaseDecoder): | |
def __init__(self, | |
emb_dim, | |
vocab_size, | |
fc_emb_dim, | |
attn_emb_dim, | |
dropout, | |
freeze=False, | |
tie_weights=False, | |
**kwargs): | |
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, | |
dropout=dropout, tie_weights=tie_weights) | |
self.d_model = emb_dim | |
self.nhead = kwargs.get("nhead", self.d_model // 64) | |
self.nlayers = kwargs.get("nlayers", 2) | |
self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4) | |
self.pos_encoder = PositionalEncoding(self.d_model, dropout) | |
layer = nn.TransformerDecoderLayer(d_model=self.d_model, | |
nhead=self.nhead, | |
dim_feedforward=self.dim_feedforward, | |
dropout=dropout) | |
self.model = nn.TransformerDecoder(layer, self.nlayers) | |
self.classifier = nn.Linear(self.d_model, vocab_size, bias=False) | |
if tie_weights: | |
self.classifier.weight = self.word_embedding.weight | |
self.attn_proj = nn.Sequential( | |
nn.Linear(self.attn_emb_dim, self.d_model), | |
nn.ReLU(), | |
nn.Dropout(dropout), | |
nn.LayerNorm(self.d_model) | |
) | |
self.init_params() | |
self.freeze = freeze | |
if freeze: | |
for p in self.parameters(): | |
p.requires_grad = False | |
def init_params(self): | |
for p in self.parameters(): | |
if p.dim() > 1: | |
nn.init.xavier_uniform_(p) | |
def load_pretrained(self, pretrained, output_fn): | |
checkpoint = torch.load(pretrained, map_location="cpu") | |
if "model" in checkpoint: | |
checkpoint = checkpoint["model"] | |
        if next(iter(checkpoint)).startswith("decoder."):
            state_dict = {}
            for k, v in checkpoint.items():
                state_dict[k[len("decoder."):]] = v
        else:
            state_dict = checkpoint
loaded_keys = merge_load_state_dict(state_dict, self, output_fn) | |
if self.freeze: | |
for name, param in self.named_parameters(): | |
if name in loaded_keys: | |
param.requires_grad = False | |
else: | |
param.requires_grad = True | |
def generate_square_subsequent_mask(self, max_length): | |
mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1) | |
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) | |
return mask | |
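    # Illustrative example: generate_square_subsequent_mask(3) produces the causal mask
    # [[0., -inf, -inf],
    #  [0.,   0., -inf],
    #  [0.,   0.,   0.]]
    # so that position t can only attend to positions <= t during decoding.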
def forward(self, input_dict): | |
word = input_dict["word"] | |
attn_emb = input_dict["attn_emb"] | |
attn_emb_len = input_dict["attn_emb_len"] | |
cap_padding_mask = input_dict["cap_padding_mask"] | |
p_attn_emb = self.attn_proj(attn_emb) | |
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim] | |
word = word.to(attn_emb.device) | |
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim] | |
embed = embed.transpose(0, 1) # [T, N, emb_dim] | |
embed = self.pos_encoder(embed) | |
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device) | |
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device) | |
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask, | |
tgt_key_padding_mask=cap_padding_mask, | |
memory_key_padding_mask=memory_key_padding_mask) | |
output = output.transpose(0, 1) | |
output = { | |
"embed": output, | |
"logit": self.classifier(output), | |
} | |
return output | |
class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin): | |
def __init__(self, | |
model: nn.Module, | |
shared_dim: int, | |
tchr_dim: int, | |
): | |
super().__init__() | |
self.model = model | |
self.tchr_dim = tchr_dim | |
if hasattr(model, "encoder"): | |
self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size, | |
shared_dim) | |
else: | |
self.stdnt_proj = nn.Linear(model.fc_emb_size, | |
shared_dim) | |
self.tchr_proj = nn.Linear(tchr_dim, shared_dim) | |
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |
def forward(self, input_dict: Dict): | |
unsup = input_dict.get("unsup", False) | |
if unsup is False: | |
output_dict = self.model(input_dict) | |
else: | |
output_dict = self.model.encoder(input_dict) | |
if "tchr_output" in input_dict: | |
stdnt_emb = output_dict["fc_emb"] | |
stdnt_emb = self.stdnt_proj(stdnt_emb) | |
tchr_emb = input_dict["tchr_output"]["embedding"] | |
            tchr_emb = self.tchr_proj(tchr_emb)
            stdnt_emb = F.normalize(stdnt_emb, dim=-1)
            tchr_emb = F.normalize(tchr_emb, dim=-1)
            unscaled_logit = stdnt_emb @ tchr_emb.transpose(0, 1)
logit = self.logit_scale * unscaled_logit | |
label = torch.arange(logit.shape[0]).to(logit.device) | |
loss1 = F.cross_entropy(logit, label) | |
loss2 = F.cross_entropy(logit.transpose(0, 1), label) | |
loss = (loss1 + loss2) / 2 | |
output_dict["enc_kd_loss"] = loss | |
return output_dict | |
class Effb2TrmConfig(PretrainedConfig): | |
def __init__( | |
self, | |
sample_rate: int = 16000, | |
tchr_dim: int = 768, | |
shared_dim: int = 1024, | |
fc_emb_dim: int = 1408, | |
attn_emb_dim: int = 1408, | |
decoder_n_layers: int = 2, | |
decoder_we_tie_weights: bool = True, | |
decoder_emb_dim: int = 256, | |
decoder_dropout: float = 0.2, | |
vocab_size: int = 4981, | |
**kwargs | |
): | |
self.sample_rate = sample_rate | |
self.tchr_dim = tchr_dim | |
self.shared_dim = shared_dim | |
self.fc_emb_dim = fc_emb_dim | |
self.attn_emb_dim = attn_emb_dim | |
self.decoder_n_layers = decoder_n_layers | |
self.decoder_we_tie_weights = decoder_we_tie_weights | |
self.decoder_emb_dim = decoder_emb_dim | |
self.decoder_dropout = decoder_dropout | |
self.vocab_size = vocab_size | |
super().__init__(**kwargs) | |
class Effb2TrmCaptioningModel(PreTrainedModel): | |
config_class = Effb2TrmConfig | |
def __init__(self, config): | |
super().__init__(config) | |
encoder = EfficientNetB2() | |
decoder = TransformerDecoder( | |
emb_dim=config.decoder_emb_dim, | |
vocab_size=config.vocab_size, | |
fc_emb_dim=config.fc_emb_dim, | |
attn_emb_dim=config.attn_emb_dim, | |
dropout=config.decoder_dropout, | |
nlayers=config.decoder_n_layers, | |
tie_weights=config.decoder_we_tie_weights | |
) | |
model = TransformerModel(encoder, decoder) | |
self.model = ContraEncoderKdWrapper(model, config.shared_dim, config.tchr_dim) | |
def forward(self, | |
audio: torch.Tensor, | |
audio_length: Union[List, np.ndarray, torch.Tensor], | |
sample_method: str = "beam", | |
beam_size: int = 3, | |
max_length: int = 20, | |
temp: float = 1.0,): | |
device = self.device | |
input_dict = { | |
"wav": audio.to(device), | |
"wav_len": audio_length, | |
"specaug": False, | |
"mode": "inference", | |
"sample_method": sample_method, | |
"max_length": max_length, | |
"temp": temp, | |
} | |
if sample_method == "beam": | |
input_dict["beam_size"] = beam_size | |
return self.model(input_dict)["seq"].cpu() | |
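# Minimal usage sketch (illustrative, not executed on import). The checkpoint id is a
# placeholder, and the vocabulary used to turn token ids back into words is assumed to
# be shipped alongside the checkpoint:
#
#   model = Effb2TrmCaptioningModel.from_pretrained("<checkpoint-or-repo-id>")
#   model.eval()
#   waveform = torch.randn(1, 16000 * 10)          # 10 s of 16 kHz audio
#   lengths = torch.tensor([waveform.shape[1]])
#   with torch.no_grad():
#       token_ids = model(waveform, lengths, sample_method="beam", beam_size=3)
#   # token_ids: [1, max_length] word indices, terminated by the end token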
class ConvBlock(nn.Module): | |
def __init__(self, in_channels, out_channels): | |
super(ConvBlock, self).__init__() | |
self.conv1 = nn.Conv2d(in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=(3, 3), stride=(1, 1), | |
padding=(1, 1), bias=False) | |
self.conv2 = nn.Conv2d(in_channels=out_channels, | |
out_channels=out_channels, | |
kernel_size=(3, 3), stride=(1, 1), | |
padding=(1, 1), bias=False) | |
self.bn1 = nn.BatchNorm2d(out_channels) | |
self.bn2 = nn.BatchNorm2d(out_channels) | |
def forward(self, input, pool_size=(2, 2), pool_type='avg'): | |
x = input | |
x = F.relu_(self.bn1(self.conv1(x))) | |
x = F.relu_(self.bn2(self.conv2(x))) | |
if pool_type == 'max': | |
x = F.max_pool2d(x, kernel_size=pool_size) | |
elif pool_type == 'avg': | |
x = F.avg_pool2d(x, kernel_size=pool_size) | |
elif pool_type == 'avg+max': | |
x1 = F.avg_pool2d(x, kernel_size=pool_size) | |
x2 = F.max_pool2d(x, kernel_size=pool_size) | |
x = x1 + x2 | |
else: | |
            raise Exception(f'Incorrect pool_type argument: {pool_type}')
return x | |
class Cnn14Encoder(nn.Module): | |
def __init__(self, sample_rate=32000): | |
super().__init__() | |
sr_to_fmax = { | |
32000: 14000, | |
16000: 8000 | |
} | |
# Logmel spectrogram extractor | |
self.melspec_extractor = transforms.MelSpectrogram( | |
sample_rate=sample_rate, | |
n_fft=32 * sample_rate // 1000, | |
win_length=32 * sample_rate // 1000, | |
hop_length=10 * sample_rate // 1000, | |
f_min=50, | |
f_max=sr_to_fmax[sample_rate], | |
n_mels=64, | |
norm="slaney", | |
mel_scale="slaney" | |
) | |
self.hop_length = 10 * sample_rate // 1000 | |
self.db_transform = transforms.AmplitudeToDB() | |
self.bn0 = nn.BatchNorm2d(64) | |
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) | |
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) | |
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) | |
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) | |
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) | |
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) | |
self.downsample_ratio = 32 | |
self.fc1 = nn.Linear(2048, 2048, bias=True) | |
self.fc_emb_size = 2048 | |
def forward(self, input_dict): | |
lms = input_dict["lms"] | |
wave_length = input_dict["wav_len"] | |
x = lms # (batch_size, mel_bins, time_steps) | |
x = x.transpose(1, 2) | |
x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins) | |
x = x.transpose(1, 3) | |
x = self.bn0(x) | |
x = x.transpose(1, 3) | |
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = torch.mean(x, dim=3) | |
attn_emb = x.transpose(1, 2) | |
wave_length = torch.as_tensor(wave_length) | |
feat_length = torch.div(wave_length, self.hop_length, | |
rounding_mode="floor") + 1 | |
feat_length = torch.div(feat_length, self.downsample_ratio, | |
rounding_mode="floor") | |
x_max = max_with_lens(attn_emb, feat_length) | |
x_mean = mean_with_lens(attn_emb, feat_length) | |
x = x_max + x_mean | |
x = F.dropout(x, p=0.5, training=self.training) | |
x = F.relu_(self.fc1(x)) | |
fc_emb = F.dropout(x, p=0.5, training=self.training) | |
output_dict = { | |
'fc_emb': fc_emb, | |
'attn_emb': attn_emb, | |
'attn_emb_len': feat_length | |
} | |
return output_dict | |
class RnnEncoder(nn.Module): | |
def __init__(self, | |
attn_feat_dim, | |
pooling="mean", | |
**kwargs): | |
super().__init__() | |
self.pooling = pooling | |
self.hidden_size = kwargs.get('hidden_size', 512) | |
self.bidirectional = kwargs.get('bidirectional', False) | |
self.num_layers = kwargs.get('num_layers', 1) | |
self.dropout = kwargs.get('dropout', 0.2) | |
self.rnn_type = kwargs.get('rnn_type', "GRU") | |
self.in_bn = kwargs.get('in_bn', False) | |
self.embed_dim = self.hidden_size * (self.bidirectional + 1) | |
self.network = getattr(nn, self.rnn_type)( | |
attn_feat_dim, | |
self.hidden_size, | |
num_layers=self.num_layers, | |
bidirectional=self.bidirectional, | |
dropout=self.dropout, | |
batch_first=True) | |
if self.in_bn: | |
self.bn = nn.BatchNorm1d(self.embed_dim) | |
def forward(self, input_dict): | |
x = input_dict["attn"] | |
lens = input_dict["attn_len"] | |
lens = torch.as_tensor(lens) | |
# x: [N, T, E] | |
if self.in_bn: | |
x = pack_wrapper(self.bn, x, lens) | |
out = pack_wrapper(self.network, x, lens) | |
# out: [N, T, hidden] | |
attn_emb = out | |
fc_emb = embedding_pooling(out, lens, self.pooling) | |
return { | |
"attn_emb": attn_emb, | |
"fc_emb": fc_emb, | |
"attn_emb_len": lens | |
} | |
class Cnn14RnnEncoder(nn.Module): | |
def __init__(self, | |
sample_rate, | |
rnn_bidirectional, | |
rnn_hidden_size, | |
rnn_dropout, | |
rnn_num_layers): | |
super().__init__() | |
self.cnn = Cnn14Encoder(sample_rate=sample_rate) | |
self.rnn = RnnEncoder( | |
2048, | |
bidirectional=rnn_bidirectional, | |
hidden_size=rnn_hidden_size, | |
dropout=rnn_dropout, | |
num_layers=rnn_num_layers, | |
) | |
def forward(self, input_dict): | |
output_dict = self.cnn(input_dict) | |
output_dict["attn"] = output_dict["attn_emb"] | |
output_dict["attn_len"] = output_dict["attn_emb_len"] | |
del output_dict["attn_emb"], output_dict["attn_emb_len"] | |
output_dict = self.rnn(output_dict) | |
return output_dict | |
class Seq2SeqAttention(nn.Module): | |
def __init__(self, hs_enc, hs_dec, attn_size): | |
""" | |
Args: | |
hs_enc: encoder hidden size | |
hs_dec: decoder hidden size | |
attn_size: attention vector size | |
""" | |
super(Seq2SeqAttention, self).__init__() | |
self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size) | |
self.v = nn.Parameter(torch.randn(attn_size)) | |
def forward(self, h_dec, h_enc, src_lens): | |
""" | |
Args: | |
h_dec: decoder hidden (query), [N, hs_dec] | |
h_enc: encoder memory (key/value), [N, src_max_len, hs_enc] | |
src_lens: source (encoder memory) lengths, [N, ] | |
""" | |
N = h_enc.size(0) | |
src_max_len = h_enc.size(1) | |
h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec] | |
attn_input = torch.cat((h_dec, h_enc), dim=-1) | |
attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size] | |
v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size] | |
score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len] | |
idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len) | |
mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device) | |
score = score.masked_fill(mask == 0, -1e10) | |
weights = torch.softmax(score, dim=-1) # [N, src_max_len] | |
ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc] | |
return ctx, weights | |
class RnnDecoder(BaseDecoder): | |
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, | |
dropout, d_model, **kwargs): | |
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, | |
dropout,) | |
self.d_model = d_model | |
self.num_layers = kwargs.get('num_layers', 1) | |
self.bidirectional = kwargs.get('bidirectional', False) | |
self.rnn_type = kwargs.get('rnn_type', "GRU") | |
self.classifier = nn.Linear( | |
self.d_model * (self.bidirectional + 1), vocab_size) | |
def forward(self, x): | |
raise NotImplementedError | |
def init_hidden(self, bs, device): | |
num_dire = self.bidirectional + 1 | |
n_layer = self.num_layers | |
hid_dim = self.d_model | |
if self.rnn_type == "LSTM": | |
return (torch.zeros(num_dire * n_layer, bs, hid_dim).to(device), | |
torch.zeros(num_dire * n_layer, bs, hid_dim).to(device)) | |
else: | |
return torch.zeros(num_dire * n_layer, bs, hid_dim).to(device) | |
class BahAttnCatFcDecoder(RnnDecoder): | |
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, | |
dropout, d_model, **kwargs): | |
""" | |
concatenate fc, attn, word to feed to the rnn | |
""" | |
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, | |
dropout, d_model, **kwargs) | |
attn_size = kwargs.get("attn_size", self.d_model) | |
self.model = getattr(nn, self.rnn_type)( | |
input_size=self.emb_dim * 3, | |
hidden_size=self.d_model, | |
batch_first=True, | |
num_layers=self.num_layers, | |
bidirectional=self.bidirectional) | |
self.attn = Seq2SeqAttention(self.attn_emb_dim, | |
self.d_model * (self.bidirectional + 1) * \ | |
self.num_layers, | |
attn_size) | |
self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim) | |
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim) | |
def forward(self, input_dict): | |
word = input_dict["word"] | |
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model] | |
fc_emb = input_dict["fc_emb"] | |
attn_emb = input_dict["attn_emb"] | |
attn_emb_len = input_dict["attn_emb_len"] | |
word = word.to(fc_emb.device) | |
embed = self.in_dropout(self.word_embedding(word)) | |
# embed: [N, 1, embed_size] | |
if state is None: | |
state = self.init_hidden(word.size(0), fc_emb.device) | |
if self.rnn_type == "LSTM": | |
query = state[0].transpose(0, 1).flatten(1) | |
else: | |
query = state.transpose(0, 1).flatten(1) | |
c, attn_weight = self.attn(query, attn_emb, attn_emb_len) | |
p_fc_emb = self.fc_proj(fc_emb) | |
p_ctx = self.ctx_proj(c) | |
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_emb.unsqueeze(1)), | |
dim=-1) | |
out, state = self.model(rnn_input, state) | |
output = { | |
"state": state, | |
"embed": out, | |
"logit": self.classifier(out), | |
"attn_weight": attn_weight | |
} | |
return output | |
class TemporalBahAttnDecoder(BahAttnCatFcDecoder): | |
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, | |
dropout, d_model, **kwargs): | |
""" | |
concatenate fc, attn, word to feed to the rnn | |
""" | |
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, | |
dropout, d_model, **kwargs) | |
self.temporal_embedding = nn.Embedding(4, emb_dim) | |
def forward(self, input_dict): | |
word = input_dict["word"] | |
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model] | |
fc_embs = input_dict["fc_emb"] | |
attn_embs = input_dict["attn_emb"] | |
attn_emb_lens = input_dict["attn_emb_len"] | |
temporal_tag = input_dict["temporal_tag"] | |
if input_dict["t"] == 0: | |
embed = self.in_dropout( | |
self.temporal_embedding(temporal_tag)).unsqueeze(1) | |
elif word.size(-1) == self.fc_emb_dim: # fc_embs | |
embed = word.unsqueeze(1) | |
elif word.size(-1) == 1: # word | |
word = word.to(fc_embs.device) | |
embed = self.in_dropout(self.word_embedding(word)) | |
else: | |
raise Exception(f"problem with word input size {word.size()}") | |
# embed: [N, 1, embed_size] | |
if state is None: | |
state = self.init_hidden(word.size(0), fc_embs.device) | |
if self.rnn_type == "LSTM": | |
query = state[0].transpose(0, 1).flatten(1) | |
else: | |
query = state.transpose(0, 1).flatten(1) | |
c, attn_weight = self.attn(query, attn_embs, attn_emb_lens) | |
        p_ctx = self.ctx_proj(c)
        p_fc_embs = self.fc_proj(fc_embs)
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_embs.unsqueeze(1)), dim=-1) | |
out, state = self.model(rnn_input, state) | |
output = { | |
"state": state, | |
"embed": out, | |
"logit": self.classifier(out), | |
"attn_weight": attn_weight | |
} | |
return output | |
class Seq2SeqAttnModel(CaptionModel): | |
def __init__(self, encoder, decoder, **kwargs): | |
if not hasattr(self, "compatible_decoders"): | |
self.compatible_decoders = ( | |
BahAttnCatFcDecoder, | |
) | |
super().__init__(encoder, decoder, **kwargs) | |
def seq_forward(self, input_dict): | |
        # Bahdanau attention only supports step-by-step decoding, so forward is
        # implemented in a step-by-step manner during both training and evaluation
return self.stepwise_forward(input_dict) | |
def prepare_output(self, input_dict): | |
output = super().prepare_output(input_dict) | |
attn_weight = torch.empty(output["seq"].size(0), | |
input_dict["attn_emb"].size(1), output["seq"].size(1)) | |
output["attn_weight"] = attn_weight | |
return output | |
def prepare_decoder_input(self, input_dict, output): | |
decoder_input = { | |
"fc_emb": input_dict["fc_emb"], | |
"attn_emb": input_dict["attn_emb"], | |
"attn_emb_len": input_dict["attn_emb_len"] | |
} | |
t = input_dict["t"] | |
############### | |
# determine input word | |
################ | |
if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling | |
word = input_dict["cap"][:, t] | |
else: | |
if t == 0: | |
word = torch.tensor([self.start_idx,] * input_dict["fc_emb"].size(0)).long() | |
else: | |
word = output["seq"][:, t-1] | |
# word: [N,] | |
decoder_input["word"] = word.unsqueeze(1) | |
################ | |
# prepare rnn state | |
################ | |
if t > 0: | |
decoder_input["state"] = output["state"] | |
return decoder_input | |
def stepwise_process_step(self, output, output_t): | |
super().stepwise_process_step(output, output_t) | |
output["state"] = output_t["state"] | |
t = output_t["t"] | |
output["attn_weight"][:, :, t] = output_t["attn_weight"] | |
def prepare_beamsearch_output(self, input_dict): | |
output = super().prepare_beamsearch_output(input_dict) | |
beam_size = input_dict["beam_size"] | |
max_length = input_dict["max_length"] | |
output["attn_weight"] = torch.empty(beam_size, | |
max(input_dict["attn_emb_len"]), max_length) | |
return output | |
def prepare_beamsearch_decoder_input(self, input_dict, output_i): | |
decoder_input = {} | |
t = input_dict["t"] | |
i = input_dict["sample_idx"] | |
beam_size = input_dict["beam_size"] | |
############### | |
# prepare fc embeds | |
################ | |
if t == 0: | |
fc_emb = repeat_tensor(input_dict["fc_emb"][i], beam_size) | |
output_i["fc_emb"] = fc_emb | |
decoder_input["fc_emb"] = output_i["fc_emb"] | |
############### | |
# prepare attn embeds | |
################ | |
if t == 0: | |
attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size) | |
attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size) | |
output_i["attn_emb"] = attn_emb | |
output_i["attn_emb_len"] = attn_emb_len | |
decoder_input["attn_emb"] = output_i["attn_emb"] | |
decoder_input["attn_emb_len"] = output_i["attn_emb_len"] | |
############### | |
# determine input word | |
################ | |
if t == 0: | |
word = torch.tensor([self.start_idx,] * beam_size).long() | |
else: | |
word = output_i["next_word"] | |
decoder_input["word"] = word.unsqueeze(1) | |
################ | |
# prepare rnn state | |
################ | |
if t > 0: | |
if self.decoder.rnn_type == "LSTM": | |
decoder_input["state"] = (output_i["state"][0][:, output_i["prev_words_beam"], :].contiguous(), | |
output_i["state"][1][:, output_i["prev_words_beam"], :].contiguous()) | |
else: | |
decoder_input["state"] = output_i["state"][:, output_i["prev_words_beam"], :].contiguous() | |
return decoder_input | |
def beamsearch_process_step(self, output_i, output_t): | |
t = output_t["t"] | |
output_i["state"] = output_t["state"] | |
output_i["attn_weight"][..., t] = output_t["attn_weight"] | |
output_i["attn_weight"] = output_i["attn_weight"][output_i["prev_words_beam"], ...] | |
def beamsearch_process(self, output, output_i, input_dict): | |
super().beamsearch_process(output, output_i, input_dict) | |
i = input_dict["sample_idx"] | |
output["attn_weight"][i] = output_i["attn_weight"][0] | |
def prepare_dbs_decoder_input(self, input_dict, output_i): | |
decoder_input = {} | |
t = input_dict["t"] | |
i = input_dict["sample_idx"] | |
bdash = input_dict["bdash"] | |
divm = input_dict["divm"] | |
local_time = t - divm | |
############### | |
# prepare fc embeds | |
################ | |
# repeat only at the first timestep to save consumption | |
if t == 0: | |
fc_emb = repeat_tensor(input_dict["fc_emb"][i], bdash).unsqueeze(1) | |
output_i["fc_emb"] = fc_emb | |
decoder_input["fc_emb"] = output_i["fc_emb"] | |
############### | |
# prepare attn embeds | |
################ | |
if t == 0: | |
attn_emb = repeat_tensor(input_dict["attn_emb"][i], bdash) | |
attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], bdash) | |
output_i["attn_emb"] = attn_emb | |
output_i["attn_emb_len"] = attn_emb_len | |
decoder_input["attn_emb"] = output_i["attn_emb"] | |
decoder_input["attn_emb_len"] = output_i["attn_emb_len"] | |
############### | |
# determine input word | |
################ | |
if local_time == 0: | |
word = torch.tensor([self.start_idx,] * bdash).long() | |
else: | |
word = output_i["next_word"][divm] | |
decoder_input["word"] = word.unsqueeze(1) | |
################ | |
# prepare rnn state | |
################ | |
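# same beam-wise state reindexing as in standard beam search, restricted to the hidden state of group divm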
if local_time > 0: | |
if self.decoder.rnn_type == "LSTM": | |
decoder_input["state"] = ( | |
output_i["state"][0][divm][ | |
:, output_i["prev_words_beam"][divm], :].contiguous(), | |
output_i["state"][1][divm][ | |
:, output_i["prev_words_beam"][divm], :].contiguous() | |
) | |
else: | |
decoder_input["state"] = output_i["state"][divm][ | |
:, output_i["prev_words_beam"][divm], :].contiguous() | |
return decoder_input | |
def dbs_process_step(self, output_i, output_t): | |
divm = output_t["divm"] | |
output_i["state"][divm] = output_t["state"] | |
# TODO attention weight | |
class TemporalSeq2SeqAttnModel(Seq2SeqAttnModel): | |
def __init__(self, encoder, decoder, **kwargs): | |
if not hasattr(self, "compatible_decoders"): | |
self.compatible_decoders = ( | |
TemporalBahAttnDecoder, | |
) | |
super().__init__(encoder, decoder, **kwargs) | |
self.train_forward_keys = ["cap", "cap_len", "ss_ratio", "temporal_tag"] | |
self.inference_forward_keys = ["sample_method", "max_length", "temp", "temporal_tag"] | |
def prepare_decoder_input(self, input_dict, output): | |
decoder_input = super().prepare_decoder_input(input_dict, output) | |
decoder_input["temporal_tag"] = input_dict["temporal_tag"] | |
decoder_input["t"] = input_dict["t"] | |
return decoder_input | |
def prepare_beamsearch_decoder_input(self, input_dict, output_i): | |
decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i) | |
t = input_dict["t"] | |
i = input_dict["sample_idx"] | |
beam_size = input_dict["beam_size"] | |
############### | |
# prepare temporal_tag | |
################ | |
if t == 0: | |
temporal_tag = repeat_tensor(input_dict["temporal_tag"][i], beam_size) | |
output_i["temporal_tag"] = temporal_tag | |
decoder_input["temporal_tag"] = output_i["temporal_tag"] | |
decoder_input["t"] = input_dict["t"] | |
return decoder_input | |
def prepare_dbs_decoder_input(self, input_dict, output_i): | |
decoder_input = super().prepare_dbs_decoder_input(input_dict, output_i)
t = input_dict["t"] | |
i = input_dict["sample_idx"] | |
bdash = input_dict["bdash"] | |
############### | |
# prepare temporal tag | |
################ | |
# repeat only at the first timestep to avoid redundant computation
if t == 0: | |
temporal_tag = repeat_tensor(input_dict["temporal_tag"][i], bdash) | |
output_i["temporal_tag"] = temporal_tag | |
decoder_input["temporal_tag"] = output_i["temporal_tag"] | |
decoder_input["t"] = input_dict["t"] | |
return decoder_input | |
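# A minimal inference sketch for TemporalSeq2SeqAttnModel (illustrative only).
# The input dict mirrors the one assembled in Cnn14RnnTempAttnGruModel.forward
# below; "lms" and "wav_len" are the fields consumed by the Cnn14RnnEncoder
# pipeline in this file, and the remaining keys come from
# inference_forward_keys plus "beam_size" for beam search.
#
#   input_dict = {
#       "lms": log_mel_spec,           # (batch, n_mels, time) log-mel features
#       "wav_len": audio_length,       # original waveform lengths
#       "temporal_tag": temporal_tag,  # temporal tags (see forward below)
#       "mode": "inference",
#       "sample_method": "beam",
#       "beam_size": 3,
#       "max_length": 20,
#       "temp": 1.0,
#   }
#   seq = cap_model(input_dict)["seq"]  # (batch, <=max_length) token ids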
class Cnn8rnnSedModel(nn.Module): | |
def __init__(self, classes_num): | |
super().__init__() | |
self.time_resolution = 0.01 | |
self.interpolate_ratio = 4 # time downsampling factor of the CNN (compensated by interpolate)
self.bn0 = nn.BatchNorm2d(64) | |
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) | |
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) | |
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) | |
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) | |
self.fc1 = nn.Linear(512, 512, bias=True) | |
self.rnn = nn.GRU(512, 256, bidirectional=True, batch_first=True) | |
self.fc_audioset = nn.Linear(512, classes_num, bias=True) | |
def forward(self, lms): | |
output = self.forward_prob(lms) | |
framewise_output = output["framewise_output"].cpu().numpy() | |
thresholded_predictions = double_threshold( | |
framewise_output, 0.75, 0.25) | |
decoded_tags = decode_with_timestamps( | |
thresholded_predictions, self.time_resolution | |
) | |
return decoded_tags | |
def forward_prob(self, lms): | |
""" | |
lms: (batch_size, mel_bins, time_steps)""" | |
x = lms | |
x = x.transpose(1, 2) | |
x = x.unsqueeze(1) | |
frames_num = x.shape[2] | |
x = x.transpose(1, 3) | |
x = self.bn0(x) | |
x = x.transpose(1, 3) | |
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg+max') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg+max') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block3(x, pool_size=(1, 2), pool_type='avg+max') | |
x = F.dropout(x, p=0.2, training=self.training) | |
x = self.conv_block4(x, pool_size=(1, 2), pool_type='avg+max') | |
x = F.dropout(x, p=0.2, training=self.training) # (batch_size, 512, time_steps / 4, mel_bins / 16)
x = torch.mean(x, dim=3) | |
x = x.transpose(1, 2) | |
x = F.dropout(x, p=0.5, training=self.training) | |
x = F.relu_(self.fc1(x)) | |
x, _ = self.rnn(x) | |
segmentwise_output = torch.sigmoid(self.fc_audioset(x)).clamp(1e-7, 1.) | |
framewise_output = interpolate(segmentwise_output, | |
self.interpolate_ratio) | |
framewise_output = pad_framewise_output(framewise_output, frames_num) | |
output_dict = { | |
"segmentwise_output": segmentwise_output, | |
'framewise_output': framewise_output, | |
} | |
return output_dict | |
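# Shape sketch for Cnn8rnnSedModel (illustrative; frame counts assume the
# 64-mel, 10 ms-hop front-end defined in Cnn14RnnTempAttnGruModel below):
#   lms                 (batch, 64, T)         log-mel input
#   segmentwise_output  (batch, ~T/4, classes) time is downsampled 4x by the CNN
#   framewise_output    (batch, T, classes)    interpolated back and padded to
#                                              the input frame count
# forward() then binarizes framewise_output with double_threshold (0.75 / 0.25)
# and converts the activity into timestamped tags via decode_with_timestamps
# at a 0.01 s resolution.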
class Cnn14RnnTempAttnGruConfig(PretrainedConfig): | |
def __init__( | |
self, | |
sample_rate: int = 32000, | |
encoder_rnn_bidirectional: bool = True, | |
encoder_rnn_hidden_size: int = 256, | |
encoder_rnn_dropout: float = 0.5, | |
encoder_rnn_num_layers: int = 3, | |
decoder_emb_dim: int = 512, | |
vocab_size: int = 4981, | |
fc_emb_dim: int = 512, | |
attn_emb_dim: int = 512, | |
decoder_rnn_type: str = "GRU", | |
decoder_num_layers: int = 1, | |
decoder_d_model: int = 512, | |
decoder_dropout: float = 0.5, | |
**kwargs | |
): | |
self.sample_rate = sample_rate | |
self.encoder_rnn_bidirectional = encoder_rnn_bidirectional | |
self.encoder_rnn_hidden_size = encoder_rnn_hidden_size | |
self.encoder_rnn_dropout = encoder_rnn_dropout | |
self.encoder_rnn_num_layers = encoder_rnn_num_layers | |
self.decoder_emb_dim = decoder_emb_dim | |
self.vocab_size = vocab_size | |
self.fc_emb_dim = fc_emb_dim | |
self.attn_emb_dim = attn_emb_dim | |
self.decoder_rnn_type = decoder_rnn_type | |
self.decoder_num_layers = decoder_num_layers | |
self.decoder_d_model = decoder_d_model | |
self.decoder_dropout = decoder_dropout | |
super().__init__(**kwargs) | |
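# Example of building and round-tripping the config (illustrative; the save
# path is a placeholder). save_pretrained / from_pretrained are inherited from
# transformers.PretrainedConfig.
#
#   config = Cnn14RnnTempAttnGruConfig(sample_rate=16000, decoder_rnn_type="GRU")
#   config.save_pretrained("./caption_config")
#   config = Cnn14RnnTempAttnGruConfig.from_pretrained("./caption_config")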
class Cnn14RnnTempAttnGruModel(PreTrainedModel): | |
config_class = Cnn14RnnTempAttnGruConfig | |
def __init__(self, config): | |
super().__init__(config) | |
sample_rate = config.sample_rate | |
sr_to_fmax = { | |
32000: 14000, | |
16000: 8000 | |
} | |
self.melspec_extractor = transforms.MelSpectrogram( | |
sample_rate=sample_rate, | |
n_fft=32 * sample_rate // 1000, | |
win_length=32 * sample_rate // 1000, | |
hop_length=10 * sample_rate // 1000, | |
f_min=50, | |
f_max=sr_to_fmax[sample_rate], | |
n_mels=64, | |
norm="slaney", | |
mel_scale="slaney" | |
) | |
self.db_transform = transforms.AmplitudeToDB() | |
encoder = Cnn14RnnEncoder( | |
sample_rate=config.sample_rate, | |
rnn_bidirectional=config.encoder_rnn_bidirectional, | |
rnn_hidden_size=config.encoder_rnn_hidden_size, | |
rnn_dropout=config.encoder_rnn_dropout, | |
rnn_num_layers=config.encoder_rnn_num_layers | |
) | |
decoder = TemporalBahAttnDecoder( | |
emb_dim=config.decoder_emb_dim, | |
vocab_size=config.vocab_size, | |
fc_emb_dim=config.fc_emb_dim, | |
attn_emb_dim=config.attn_emb_dim, | |
rnn_type=config.decoder_rnn_type, | |
num_layers=config.decoder_num_layers, | |
d_model=config.decoder_d_model, | |
dropout=config.decoder_dropout, | |
) | |
cap_model = TemporalSeq2SeqAttnModel(encoder, decoder) | |
sed_model = Cnn8rnnSedModel(classes_num=447) | |
self.cap_model = cap_model | |
self.sed_model = sed_model | |
def forward(self, | |
audio: torch.Tensor, | |
audio_length: Union[List, np.ndarray, torch.Tensor], | |
temporal_tag: Union[List, np.ndarray, torch.Tensor] = None, | |
sample_method: str = "beam", | |
beam_size: int = 3, | |
max_length: int = 20, | |
temp: float = 1.0):
device = self.device | |
mel_spec = self.melspec_extractor(audio.to(device)) | |
log_mel_spec = self.db_transform(mel_spec) | |
sed_tag = self.sed_model(log_mel_spec) | |
sed_tag = torch.as_tensor(sed_tag).to(device) | |
if temporal_tag is not None: | |
temporal_tag = torch.as_tensor(temporal_tag).to(device) | |
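# combine the caller-provided tag with the SED prediction by taking the element-wise minimum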
temporal_tag = torch.stack([temporal_tag, sed_tag], dim=0) | |
temporal_tag = torch.min(temporal_tag, dim=0).values | |
else: | |
temporal_tag = sed_tag | |
input_dict = { | |
"lms": log_mel_spec, | |
"wav_len": audio_length, | |
"temporal_tag": temporal_tag, | |
"mode": "inference", | |
"sample_method": sample_method, | |
"max_length": max_length, | |
"temp": temp, | |
} | |
if sample_method == "beam": | |
input_dict["beam_size"] = beam_size | |
return self.cap_model(input_dict)["seq"].cpu() |
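# End-to-end usage sketch (illustrative; the checkpoint directory and wav path
# are placeholders, not artifacts shipped with this file):
#
#   import torchaudio
#   model = Cnn14RnnTempAttnGruModel.from_pretrained("<checkpoint-dir>")
#   model.eval()
#   waveform, sr = torchaudio.load("example.wav")   # (channels, num_samples)
#   waveform = waveform.mean(0, keepdim=True)       # mix down to mono
#   assert sr == model.config.sample_rate, "resample the audio first"
#   with torch.no_grad():
#       token_ids = model(
#           audio=waveform,
#           audio_length=[waveform.shape[1]],
#           sample_method="beam",
#           beam_size=3,
#           max_length=20,
#       )
#   # token_ids: (1, <=max_length) vocabulary indices; mapping them back to
#   # words requires the tokenizer/vocabulary used during training.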