audio-captioning-small / hf_wrapper.py
Gijs Wijngaard
Fix
ab3f8fd
raw
history blame
77.2 kB
from typing import Dict, Callable, Union, List
import random
import math
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
from torchaudio import transforms
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch import utils as efficientnet_utils
from einops import rearrange, reduce
from transformers import PretrainedConfig, PreTrainedModel
def sort_pack_padded_sequence(input, lengths):
    """Pack a padded batch after ordering its rows by decreasing length.

    Args:
        input: batch-first padded tensor, indexed by sample along dim 0.
        lengths: 1-D tensor with the valid length of every sample.

    Returns:
        A ``(packed, inv_ix)`` pair: the ``PackedSequence`` built from the
        reordered batch and the index tensor that restores the original
        sample order.
    """
    lengths_desc, order = torch.sort(lengths, descending=True)
    packed = pack_padded_sequence(input[order], lengths_desc.cpu(), batch_first=True)
    # Invert the permutation: the row that moved to position k came from order[k].
    restore = order.clone()
    restore[order] = torch.arange(0, len(order)).type_as(restore)
    return packed, restore
def pad_unsort_packed_sequence(input, inv_ix):
    """Unpack a ``PackedSequence`` and put its rows back in the original order.

    Args:
        input: packed sequence (batch-first layout is requested on unpacking).
        inv_ix: index tensor applied along dim 0 of the padded result.

    Returns:
        The padded tensor reordered by ``inv_ix``.
    """
    padded = pad_packed_sequence(input, batch_first=True)[0]
    return padded[inv_ix]
def pack_wrapper(module, attn_feats, attn_feat_lens):
    """Run ``module`` on the valid (non-padded) part of a padded batch.

    The batch is sorted and packed, processed, then padded again and
    restored to its original order.

    Args:
        module: an RNN (receives the packed sequence) or any other callable
            (receives the packed data tensor only).
        attn_feats: padded batch-first features.
        attn_feat_lens: valid length of each sample.

    Returns:
        Padded module output in the original sample order.
    """
    packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
    if not isinstance(module, torch.nn.RNNBase):
        # Non-recurrent modules work on the flat data; the batch sizes are reused.
        processed = PackedSequence(module(packed[0]), packed[1])
    else:
        processed = module(packed)[0]
    return pad_unsort_packed_sequence(processed, inv_ix)
def embedding_pooling(x, lens, pooling="mean"):
    """Pool a padded sequence of embeddings into one vector per sample.

    Args:
        x: tensor indexed as [N, T, hidden].
        lens: valid length of each sample.
        pooling: one of "mean", "max", "mean+max" or "last".

    Returns:
        The pooled embeddings, one row per sample.

    Raises:
        Exception: if ``pooling`` is not one of the supported names.
    """
    if pooling == "mean":
        return mean_with_lens(x, lens)
    if pooling == "max":
        return max_with_lens(x, lens)
    if pooling == "mean+max":
        return mean_with_lens(x, lens) + max_with_lens(x, lens)
    if pooling == "last":
        # Index of the last valid step, broadcast over the hidden dim: [N, 1, hidden]
        last_step = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
        return torch.gather(x, 1, last_step).squeeze(1)
    raise Exception(f"pooling method {pooling} not support")
def interpolate(x, ratio):
    """Upsample along the time axis by repeating every frame ``ratio`` times.

    Used to compensate for the resolution loss of a downsampling CNN.

    Args:
        x: (batch_size, time_steps, classes_num)
        ratio: int, number of copies of each frame

    Returns:
        upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    batch_size, time_steps, classes_num = x.shape
    # Insert a repeat axis after time, tile it, then fold it back into time.
    tiled = x.unsqueeze(2).repeat(1, 1, ratio, 1)
    return tiled.reshape(batch_size, time_steps * ratio, classes_num)
def pad_framewise_output(framewise_output, frames_num):
    """Extend a framewise output to ``frames_num`` frames.

    The padding frames are copies of the last frame.

    Args:
        framewise_output: (batch_size, frames, classes_num)
        frames_num: int, target number of frames

    Returns:
        output: (batch_size, frames_num, classes_num)
    """
    missing = frames_num - framewise_output.shape[1]
    last_frame = framewise_output[:, -1:, :]
    tail = last_frame.repeat(1, missing, 1)
    return torch.cat((framewise_output, tail), dim=1)
def find_contiguous_regions(activity_array):
    """Return the [onset, offset) index pairs of the active runs in an array.

    Copy of https://dcase-repo.github.io/dcase_util/_modules/dcase_util/data/decisions.html#DecisionEncoder
    kept local because:
    1. it does not need to live in a class
    2. importing DecisionEncoder pulls in sndfile, which causes trouble on clusters

    Args:
        activity_array: 1-D bool-valued numpy array.

    Returns:
        Integer array with two columns: first active index and the index
        right after the last active one, one row per run.
    """
    # A boundary sits wherever two neighbouring entries differ; +1 points at
    # the frame right after the change.
    flips = np.logical_xor(activity_array[1:], activity_array[:-1])
    boundaries = flips.nonzero()[0] + 1

    if activity_array[0]:
        # Active from the very start: the first run opens at index 0.
        boundaries = np.r_[0, boundaries]
    if activity_array[-1]:
        # Active until the end: the last run closes at the array length.
        boundaries = np.r_[boundaries, activity_array.size]

    return boundaries.reshape((-1, 2))
def double_threshold(x, high_thres, low_thres, n_connect=1):
    """Apply double thresholding along the time axis of an n-dim array.

    The time axis is taken to be axis 1 for 3-D input (batch, time, dim) and
    axis 0 for 2-D (time, dim) or 1-D (time) input.

    :param x: input array with at most 3 dimensions
    :param high_thres: high threshold value
    :param low_thres: low threshold value
    :param n_connect: clusters at distance <= n_connect are merged
    """
    assert x.ndim <= 3, "Whoops something went wrong with the input ({}), check if its <= 3 dims".format(
        x.shape)
    if x.ndim < 3:
        time_axis = 0
    elif x.ndim == 3:
        time_axis = 1

    def _threshold_sequence(sequence):
        return _double_threshold(
            sequence, high_thres, low_thres, n_connect=n_connect)

    return np.apply_along_axis(_threshold_sequence, axis=time_axis, arr=x)
def _double_threshold(x, high_thres, low_thres, n_connect=1, return_arr=True):
"""_double_threshold
Computes a double threshold over the input array
:param x: input array, needs to be 1d
:param high_thres: High threshold over the array
:param low_thres: Low threshold over the array
:param n_connect: Postprocessing, maximal distance between clusters to connect
:param return_arr: By default this function returns the filtered indiced, but if return_arr = True it returns an array of tsame size as x filled with ones and zeros.
"""
assert x.ndim == 1, "Input needs to be 1d"
high_locations = np.where(x > high_thres)[0]
locations = x > low_thres
encoded_pairs = find_contiguous_regions(locations)
filtered_list = list(
filter(
lambda pair:
((pair[0] <= high_locations) & (high_locations <= pair[1])).any(),
encoded_pairs))
filtered_list = connect_(filtered_list, n_connect)
if return_arr:
zero_one_arr = np.zeros_like(x, dtype=int)
for sl in filtered_list:
zero_one_arr[sl[0]:sl[1]] = 1
return zero_one_arr
return filtered_list
def connect_(pairs, n=1):
    """Merge neighbouring clusters whose gap is at most ``n``.

    :param pairs: ordered clusters given as (start, end), e.g. [(1,5),(7,10)]
    :param n: largest gap between two clusters that still gets merged
    :return: list of merged (start, end) tuples; empty list for empty input
    """
    if len(pairs) == 0:
        return []
    merged = []
    run_start, run_end = pairs[0]
    for previous, following in zip(pairs[0:], pairs[1:]):
        run_end = following[1]
        gap = following[0] - previous[1]
        if not (gap <= n):
            # Gap too wide: close the current run and open a new one.
            merged.append((run_start, previous[1]))
            run_start = following[0]
    merged.append((run_start, run_end))
    return merged
def segments_to_temporal_tag(segments, thre=0.5):
    """Summarise the temporal relation between segments of different events.

    Args:
        segments: sequence of (event_id, onset, offset) entries.
        thre: fraction of the shorter segment's duration used as the
            overlap threshold.

    Returns:
        Sum of two flags: 2 if some ordered pair of segments with different
        ids overlaps by less than the threshold, plus 1 if some pair overlaps
        by more than the threshold with the first segment starting earlier.
    """
    sequential_flag = 0
    simultaneous_flag = 0
    for first in segments:
        for second in segments:
            if first[0] == second[0]:
                continue  # same event id, nothing to compare
            shortest = min(first[2] - first[1], second[2] - second[1])
            overlap = first[2] - second[1]
            if overlap < thre * shortest:
                sequential_flag = 2
            if first[1] < second[1] and overlap > thre * shortest:
                simultaneous_flag = 1
    return sequential_flag + simultaneous_flag
def decode_with_timestamps(labels, time_resolution):
    """Turn framewise label matrices into one temporal tag per clip.

    Args:
        labels: iterable of per-clip arrays; each is transposed so that the
            contiguous-region search runs over one column at a time.
        time_resolution: factor converting frame indices to time.

    Returns:
        List with the ``segments_to_temporal_tag`` result of every clip.
    """
    tags = []
    for clip_labels in labels:
        # One (class index, onset, offset) entry per active run of each column.
        segments = [
            (class_idx, region[0] * time_resolution, region[1] * time_resolution)
            for class_idx, column in enumerate(clip_labels.T)
            for region in find_contiguous_regions(column)
        ]
        tags.append(segments_to_temporal_tag(segments))
    return tags
class _EffiNet(nn.Module):
    """Thin proxy around an ``EfficientNet`` used as a feature extractor."""

    def __init__(self, blocks_args=None, global_params=None) -> None:
        """Build the wrapped network from EfficientNet block/global params."""
        super().__init__()
        self.eff_net = EfficientNet(
            blocks_args=blocks_args, global_params=global_params)

    def forward(self, x: torch.Tensor):
        """Extract features from a 'b f t' input and return them as 'b t c'.

        A singleton channel axis is added before the backbone, and the
        frequency axis of its feature map is averaged away afterwards.
        """
        with_channel = rearrange(x, 'b f t -> b 1 f t')
        feature_map = self.eff_net.extract_features(with_channel)
        return reduce(feature_map, 'b c f t -> b t c', 'mean')
def get_effb2_model() -> _EffiNet:
    """Create an EfficientNet-B2 proxy without the top layers.

    The input convolution is switched to a single input channel via
    ``_change_in_channels(1)``.
    """
    params = efficientnet_utils.get_model_params(
        'efficientnet-b2', {'include_top': False})
    blocks_args, global_params = params
    effnet = _EffiNet(blocks_args=blocks_args, global_params=global_params)
    effnet.eff_net._change_in_channels(1)
    return effnet
def merge_load_state_dict(state_dict,
                          model: torch.nn.Module,
                          output_fn: Callable = sys.stdout.write):
    """Load the entries of ``state_dict`` that fit ``model`` and report the rest.

    An entry fits when its key exists in the model's state dict with the
    same shape. All other keys are reported through ``output_fn``; the model
    keeps its current values for everything that was not loaded.

    Args:
        state_dict: mapping from parameter name to tensor.
        model: module to update in place.
        output_fn: callable receiving the one-line report.

    Returns:
        Keys view of the entries that were loaded.
    """
    current = model.state_dict()

    def _fits(name, tensor):
        return name in current and current[name].shape == tensor.shape

    pretrained_dict = {
        name: tensor for name, tensor in state_dict.items() if _fits(name, tensor)
    }
    mismatch_keys = [name for name in state_dict if name not in pretrained_dict]
    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}\n")
    current.update(pretrained_dict)
    model.load_state_dict(current, strict=True)
    return pretrained_dict.keys()
class EfficientNetB2(nn.Module):
    """Audio encoder: log-mel front end followed by an EfficientNet-B2 backbone.

    Args:
        n_mels: number of mel bins.
        win_length: analysis window in milliseconds (also used as ``n_fft``).
        hop_length: hop size in milliseconds.
        f_min: lowest frequency of the mel filterbank.
        freeze: when True, all parameters of the module are frozen.
    """

    def __init__(self,
                 n_mels: int = 64,
                 win_length: int = 32,
                 hop_length: int = 10,
                 f_min: int = 0,
                 freeze: bool = False,):
        super().__init__()
        sample_rate = 16000
        window_samples = win_length * sample_rate // 1000
        hop_samples = hop_length * sample_rate // 1000
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=window_samples,
            win_length=window_samples,
            hop_length=hop_samples,
            f_min=f_min,
            n_mels=n_mels,
        )
        # Must match the hop used by the mel extractor, otherwise the feature
        # lengths computed in forward() drift from the real frame count.
        # (Previously hard-coded to 10 ms regardless of `hop_length`.)
        self.hop_length = hop_samples
        self.db_transform = transforms.AmplitudeToDB(top_db=120)
        self.backbone = get_effb2_model()
        self.fc_emb_size = self.backbone.eff_net._conv_head.out_channels
        # Time reduction factor used to convert frame counts to feature counts.
        self.downsample_ratio = 32
        if freeze:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input_dict):
        """Encode a batch of waveforms.

        Args:
            input_dict: dict with keys "wav" (waveform batch), "wav_len"
                (number of samples per waveform) and "specaug" (whether to
                apply spectrogram augmentation while training).

        Returns:
            dict with "attn_emb" (backbone output), "attn_emb_len" (valid
            length of each sequence in "attn_emb") and "fc_emb" (mean of
            "attn_emb" over the valid steps).
        """
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)  # (batch_size, mel_bins, time_steps)
        x = rearrange(x, 'b f t -> b 1 t f')
        if self.training and specaug:
            # NOTE(review): `spec_augmenter` is not created in __init__ of
            # this class; it has to be attached elsewhere before training
            # with specaug enabled.
            x = self.spec_augmenter(x)
        x = rearrange(x, 'b 1 t f -> b f t')

        attn_emb = self.backbone(x)

        # samples -> spectrogram frames -> backbone time steps
        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        fc_emb = mean_with_lens(attn_emb, feat_length)

        return {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
def generate_length_mask(lens, max_length=None):
    """Build a boolean mask marking the valid positions of each sequence.

    Args:
        lens: [N,] sequence lengths (anything ``torch.as_tensor`` accepts).
        max_length: mask width; defaults to the largest entry of ``lens``.

    Returns:
        Bool tensor [N, max_length], True where position < length, on the
        device of ``lens``.
    """
    lens = torch.as_tensor(lens)
    batch = lens.size(0)
    if max_length is None:
        max_length = max(lens)
    if isinstance(max_length, torch.Tensor):
        max_length = max_length.item()
    positions = torch.arange(max_length).unsqueeze(0).expand(batch, max_length)
    positions = positions.to(lens.device)
    return positions < lens.view(-1, 1)
def mean_with_lens(features, lens):
    """Average ``features`` over the valid steps of each sequence.

    features: [N, T, ...] (the second dimension is the length axis)
    lens: [N,]

    Returns the per-sample mean with the length axis removed.
    """
    lens = torch.as_tensor(lens)
    if max(lens) == features.size(1):
        mask = generate_length_mask(lens)
    else:
        mask = generate_length_mask(lens, features.size(1))
    mask = mask.to(features.device)  # [N, T]

    # Give the mask trailing singleton dims so it broadcasts over features.
    while mask.ndim < features.ndim:
        mask = mask.unsqueeze(-1)
    summed = (features * mask).sum(1)

    # Same for the divisor.
    while lens.ndim < summed.ndim:
        lens = lens.unsqueeze(1)
    return summed / lens.to(features.device)
def max_with_lens(features, lens):
    """Take the maximum of ``features`` over the valid steps of each sequence.

    features: [N, T, ...] (the second dimension is the length axis)
    lens: [N,]

    Padded positions are set to -inf in a copy before the max, so the input
    tensor is left untouched.
    """
    lens = torch.as_tensor(lens)
    if max(lens) == features.size(1):
        mask = generate_length_mask(lens)
    else:
        mask = generate_length_mask(lens, features.size(1))
    mask = mask.to(features.device)  # [N, T]

    masked = features.clone()
    masked[~mask] = float("-inf")
    return masked.max(1)[0]
def repeat_tensor(x, n):
    """Stack ``n`` copies of ``x`` along a new leading dimension."""
    keep_dims = [1] * len(x.shape)
    return x.unsqueeze(0).repeat(n, *keep_dims)
class CaptionMetaMixin:
    """Class-level vocabulary indices and default caption length."""

    # Special token indices and the default maximum decoding length.
    pad_idx = 0
    start_idx = 1
    end_idx = 2
    max_length = 20

    @classmethod
    def set_index(cls, start_idx, end_idx, pad_idx):
        """Override the special token indices on the class."""
        cls.start_idx, cls.end_idx, cls.pad_idx = start_idx, end_idx, pad_idx
class CaptionModel(nn.Module, CaptionMetaMixin):
"""
Encoder-decoder captioning model.
"""
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
    """Store encoder/decoder and set up the keys forwarded to decoding.

    Args:
        encoder: module producing the encoder output dict.
        decoder: module exposing ``vocab_size``; its type is checked by
            ``check_decoder_compatibility``.
        **kwargs: ``freeze_encoder`` (default False) freezes all encoder
            parameters.
    """
    super().__init__()
    self.encoder, self.decoder = encoder, decoder
    self.vocab_size = decoder.vocab_size
    # Keys copied from the input dict in each mode (see forward_decoder).
    self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
    self.inference_forward_keys = ["sample_method", "max_length", "temp"]
    if kwargs.get("freeze_encoder", False):
        for weight in self.encoder.parameters():
            weight.requires_grad = False
    self.check_decoder_compatibility()
def check_decoder_compatibility(self):
    """Assert that ``self.decoder`` is an instance of a compatible decoder.

    ``self.compatible_decoders`` holds the accepted decoder classes.

    Raises:
        AssertionError: if the decoder is of another type; the message
            lists the names of the accepted classes.
    """
    # The entries are classes, so their own __name__ is what should be shown
    # (x.__class__.__name__ would print "type" for every entry).
    compatible_decoders = [x.__name__ for x in self.compatible_decoders]
    assert isinstance(self.decoder, self.compatible_decoders), \
        f"{self.decoder.__class__.__name__} is incompatible with " \
        f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "
def forward(self, input_dict: Dict):
    """Encode the input and decode a caption.

    input_dict: {
        (required)
        mode: train/inference,
        [spec, spec_len],
        [fc],
        [attn, attn_len],
        [wav, wav_len],
        [sample_method: greedy],
        [temp: 1.0] (in case of no teacher forcing)

        (optional, mode=train)
        cap,
        cap_len,
        ss_ratio,

        (optional, mode=inference)
        sample_method: greedy/beam,
        max_length,
        temp,
        beam_size (optional, sample_method=beam),
        n_best (optional, sample_method=beam),
    }

    Returns the dict produced by ``forward_decoder``.
    """
    return self.forward_decoder(input_dict, self.encoder(input_dict))
def forward_decoder(self, input_dict: Dict, encoder_output_dict: Dict):
    """Assemble the decoding arguments for the requested mode and decode.

    In "train" mode every key of ``self.train_forward_keys`` is required in
    ``input_dict``. In "inference" mode the keys of
    ``self.inference_forward_keys`` are optional and fall back to greedy
    sampling, ``self.max_length`` and temperature 1.0; beam search ("beam")
    and diverse beam search ("dbs") add their own optional settings. The
    encoder outputs are merged into both the decoding arguments and the
    returned dict.

    Raises:
        Exception: if ``input_dict["mode"]`` is neither "train" nor "inference".
    """
    mode = input_dict["mode"]
    if mode == "train":
        forward_dict = {"mode": "train", "sample_method": "greedy", "temp": 1.0}
        for key in self.train_forward_keys:
            forward_dict[key] = input_dict[key]
        forward_dict.update(encoder_output_dict)
        output = self.train_forward(forward_dict)
    elif mode == "inference":
        fallback = {"sample_method": "greedy", "max_length": self.max_length, "temp": 1.0}
        forward_dict = {"mode": "inference"}
        for key in self.inference_forward_keys:
            forward_dict[key] = input_dict[key] if key in input_dict else fallback[key]
        method = forward_dict["sample_method"]
        if method == "beam":
            forward_dict["beam_size"] = input_dict.get("beam_size", 3)
            forward_dict["n_best"] = input_dict.get("n_best", False)
            # By default return as many hypotheses as there are beams.
            forward_dict["n_best_size"] = input_dict.get(
                "n_best_size", forward_dict["beam_size"])
        elif method == "dbs":
            forward_dict["beam_size"] = input_dict.get("beam_size", 6)
            forward_dict["group_size"] = input_dict.get("group_size", 3)
            forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5)
            forward_dict["group_nbest"] = input_dict.get("group_nbest", True)
        forward_dict.update(encoder_output_dict)
        output = self.inference_forward(forward_dict)
    else:
        raise Exception("mode should be either 'train' or 'inference'")
    output.update(encoder_output_dict)
    return output
def prepare_output(self, input_dict):
    """Allocate the output buffers filled during decoding.

    The number of steps is ``cap.size(1) - 1`` in "train" mode and
    ``max_length`` in "inference" mode.

    Returns:
        dict with "seq" (long, filled with ``end_idx``), "logit", "embed"
        (both uninitialised, on the device of ``fc_emb``) and
        "sampled_logprob" (zeros).

    Raises:
        Exception: if the mode is neither "train" nor "inference".
    """
    batch_size = input_dict["fc_emb"].size(0)
    mode = input_dict["mode"]
    if mode == "train":
        max_length = input_dict["cap"].size(1) - 1
    elif mode == "inference":
        max_length = input_dict["max_length"]
    else:
        raise Exception("mode should be either 'train' or 'inference'")
    device = input_dict["fc_emb"].device
    seq = torch.full((batch_size, max_length), self.end_idx, dtype=torch.long)
    logit = torch.empty(batch_size, max_length, self.vocab_size).to(device)
    sampled_logprob = torch.zeros(batch_size, max_length)
    embed = torch.empty(batch_size, max_length, self.decoder.d_model).to(device)
    return {
        "seq": seq,
        "logit": logit,
        "sampled_logprob": sampled_logprob,
        "embed": embed,
    }
def train_forward(self, input_dict):
    """Run the training pass.

    With ``ss_ratio == 1`` the whole caption is decoded at once via
    ``seq_forward`` followed by ``train_process``. Any other ratio switches
    to scheduled sampling: the mode is set to "train" and decoding goes
    step by step.
    """
    if input_dict["ss_ratio"] == 1:
        output = self.seq_forward(input_dict)
        self.train_process(output, input_dict)
        return output
    # scheduled sampling training
    input_dict["mode"] = "train"
    return self.stepwise_forward(input_dict)
def seq_forward(self, input_dict):
    """Decode the whole caption in one pass; to be provided by subclasses."""
    raise NotImplementedError
def train_process(self, output, input_dict):
    """Hook called after ``seq_forward`` during training; no-op by default."""
    return None
def inference_forward(self, input_dict):
    """Dispatch inference to the decoding routine named by ``sample_method``.

    "beam" selects beam search and "dbs" diverse beam search; every other
    method is handled by step-by-step sampling.
    """
    method = input_dict["sample_method"]
    if method == "beam":
        return self.beam_search(input_dict)
    if method == "dbs":
        return self.diverse_beam_search(input_dict)
    return self.stepwise_forward(input_dict)
def stepwise_forward(self, input_dict):
    """Decode one token at a time.

    In "inference" mode a sample is finished once it has produced
    ``end_idx``; all its later positions are forced to ``end_idx`` and the
    loop stops early when every sample is finished.
    """
    output = self.prepare_output(input_dict)
    num_steps = output["seq"].size(1)
    for step in range(num_steps):
        input_dict["t"] = step
        self.decode_step(input_dict, output)
        if input_dict["mode"] != "inference":
            continue
        # decide whether to stop when sampling
        still_running = output["seq"][:, step] != self.end_idx
        if step == 0:
            unfinished = still_running
        else:
            unfinished *= still_running
        output["seq"][:, step][~unfinished] = self.end_idx
        if unfinished.sum() == 0:
            break
    self.stepwise_process(output)
    return output
def decode_step(self, input_dict, output):
    """Run the decoder for timestep ``t`` and store the sampled word.

    Only the last position of the decoder's "logit" and "embed" outputs is
    used; the sampled word and its log-probability are added to the step
    output before it is handed to ``stepwise_process_step``.

    Raises:
        Exception: if the decoder returns a logit with an empty length axis.
    """
    step_output = self.decoder(self.prepare_decoder_input(input_dict, output))
    logit = step_output["logit"]
    num_positions = logit.size(1)
    if num_positions == 1:
        logit = logit.squeeze(1)
        embed = step_output["embed"].squeeze(1)
    elif num_positions > 1:
        logit = logit[:, -1, :]
        embed = step_output["embed"][:, -1, :]
    else:
        raise Exception("no logit output")
    # sample the next input word and get the corresponding logit
    step_output.update(self.sample_next_word(
        logit, method=input_dict["sample_method"], temp=input_dict["temp"]))
    step_output["t"] = input_dict["t"]
    step_output["logit"] = logit
    step_output["embed"] = embed
    self.stepwise_process_step(output, step_output)
def prepare_decoder_input(self, input_dict, output):
    """Build the decoder input dict for one step; to be provided by subclasses."""
    raise NotImplementedError
def stepwise_process_step(self, output, output_t):
    """Copy the results of timestep ``t`` into the output buffers."""
    step = output_t["t"]
    output["logit"][:, step, :] = output_t["logit"]
    # The sampled word goes to "seq", its log-probability to "sampled_logprob".
    output["seq"][:, step] = output_t["word"]
    output["sampled_logprob"][:, step] = output_t["probs"]
    output["embed"][:, step, :] = output_t["embed"]
def stepwise_process(self, output):
    """Hook run once after step-by-step decoding has ended; no-op by default."""
    return None
def sample_next_word(self, logit, method, temp):
    """Sample the next word from the decoder logits.

    Args:
        logit: [N, vocab] logits of the current step.
        method: "greedy", "gumbel", "top<k>" with integer k >= 1 (top-k),
            "top<p>" with 0 < p < 1 (nucleus sampling); any other value
            samples from the temperature-scaled distribution.
        temp: temperature; not used by the greedy method.

    Returns:
        dict with "word" ([N,] long) and "probs" ([N,] log-probability of
        the chosen word).
    """
    logprob = torch.log_softmax(logit, dim=1)
    if method == "greedy":
        sampled_logprob, word = torch.max(logprob.detach(), 1)
    elif method == "gumbel":
        def sample_gumbel(shape, eps=1e-20):
            U = torch.rand(shape).to(logprob.device)
            return -torch.log(-torch.log(U + eps) + eps)

        def gumbel_softmax_sample(logit, temperature):
            y = logit + sample_gumbel(logit.size())
            return torch.log_softmax(y / temperature, dim=-1)

        _logprob = gumbel_softmax_sample(logprob, temp)
        _, word = torch.max(_logprob.data, 1)
        # squeeze so that the result is [N,] like in the other branches;
        # without it the per-step assignment into [:, t] cannot broadcast.
        sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
    else:
        logprob = logprob / temp
        if method.startswith("top"):
            top_num = float(method[3:])
            if 0 < top_num < 1:  # top-p sampling
                probs = torch.softmax(logit, dim=1)
                sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
                _cumsum = sorted_probs.cumsum(1)
                mask = _cumsum < top_num
                # Shift right by one so the word crossing the threshold is kept.
                mask = torch.cat([torch.ones_like(mask[:, :1]), mask[:, :-1]], 1)
                sorted_probs = sorted_probs * mask.to(sorted_probs)
                sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
                logprob.scatter_(1, sorted_indices, sorted_probs.log())
            else:  # top-k sampling
                k = int(top_num)
                tmp = torch.empty_like(logprob).fill_(float('-inf'))
                topk, indices = torch.topk(logprob, k, dim=1)
                tmp = tmp.scatter(1, indices, topk)
                logprob = tmp
        word = torch.distributions.Categorical(logits=logprob.detach()).sample()
        sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
    word = word.detach().long()
    # sampled_logprob: [N,], word: [N,]
    return {"word": word, "probs": sampled_logprob}
def beam_search(self, input_dict):
# Beam-search decoding. Each batch sample is decoded on its own, keeping
# `beam_size` hypotheses per step; finished hypotheses are collected in
# output_i["done_beams"] and written to output["seq"] by
# self.beamsearch_process.
# Reads input_dict["max_length"], ["beam_size"], ["n_best"], ["temp"] and,
# when n_best is truthy, ["n_best_size"]; writes input_dict["sample_idx"]
# and input_dict["t"] while looping. Returns the `output` dict.
output = self.prepare_output(input_dict)
max_length = input_dict["max_length"]
beam_size = input_dict["beam_size"]
if input_dict["n_best"]:
n_best_size = input_dict["n_best_size"]
batch_size, max_length = output["seq"].size()
# n-best output keeps several sequences per sample: [batch, n_best, length]
output["seq"] = torch.full((batch_size, n_best_size, max_length),
self.end_idx, dtype=torch.long)
temp = input_dict["temp"]
# instance by instance beam search
for i in range(output["seq"].size(0)):
output_i = self.prepare_beamsearch_output(input_dict)
input_dict["sample_idx"] = i
for t in range(max_length):
input_dict["t"] = t
output_t = self.beamsearch_step(input_dict, output_i)
#######################################
# merge with previous beam and select the current max prob beam
#######################################
logit_t = output_t["logit"]
# keep only the logits of the last decoded position
if logit_t.size(1) == 1:
logit_t = logit_t.squeeze(1)
elif logit_t.size(1) > 1:
logit_t = logit_t[:, -1, :]
else:
raise Exception("no logit output")
# log-probabilities, rescaled by the temperature and renormalised
logprob_t = torch.log_softmax(logit_t, dim=1)
logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
# add the accumulated score of each beam to its candidate words
logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
if t == 0: # for the first step, all k seq will have the same probs
topk_logprob, topk_words = logprob_t[0].topk(
beam_size, 0, True, True)
else: # unroll and find top logprob, and their unrolled indices
topk_logprob, topk_words = logprob_t.view(-1).topk(
beam_size, 0, True, True)
topk_words = topk_words.cpu()
output_i["topk_logprob"] = topk_logprob
# output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,]
# flat index over (beam, vocab): quotient is the source beam, remainder the word
output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size,
rounding_mode='trunc')
output_i["next_word"] = topk_words % self.vocab_size # [beam_size,]
if t == 0:
output_i["seq"] = output_i["next_word"].unsqueeze(1)
else:
# reorder the stored sequences by source beam, then append the new word
output_i["seq"] = torch.cat([
output_i["seq"][output_i["prev_words_beam"]],
output_i["next_word"].unsqueeze(1)], dim=1)
# add finished beams to results
is_end = output_i["next_word"] == self.end_idx
if t == max_length - 1:
# last step: every beam counts as finished
is_end.fill_(1)
for beam_idx in range(beam_size):
if is_end[beam_idx]:
final_beam = {
"seq": output_i["seq"][beam_idx].clone(),
"score": output_i["topk_logprob"][beam_idx].item()
}
# score is divided by the number of decoded steps
final_beam["score"] = final_beam["score"] / (t + 1)
output_i["done_beams"].append(final_beam)
# large penalty on the scores of ended beams
# (presumably so they are not expanded again -- TODO confirm)
output_i["topk_logprob"][is_end] -= 1000
self.beamsearch_process_step(output_i, output_t)
# NOTE(review): this is an equality test; if the number of finished beams
# can jump past beam_size within one step the early exit would be missed
# -- confirm whether ">=" is intended.
if len(output_i["done_beams"]) == beam_size:
break
self.beamsearch_process(output, output_i, input_dict)
return output
def prepare_beamsearch_output(self, input_dict):
    """Create the per-sample beam-search state.

    Returns a dict with zeroed beam scores ("topk_logprob", one per beam, on
    the device of ``fc_emb``), empty placeholders for "seq",
    "prev_words_beam" and "next_word", and an empty "done_beams" list.
    """
    initial_scores = torch.zeros(input_dict["beam_size"]).to(
        input_dict["fc_emb"].device)
    return {
        "topk_logprob": initial_scores,
        "seq": None,
        "prev_words_beam": None,
        "next_word": None,
        "done_beams": [],
    }
def beamsearch_step(self, input_dict, output_i):
    """Run the decoder for one beam-search step and tag the result with ``t``."""
    step_output = self.decoder(
        self.prepare_beamsearch_decoder_input(input_dict, output_i))
    step_output["t"] = input_dict["t"]
    return step_output
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
    """Build the decoder input for one beam-search step; to be provided by subclasses."""
    raise NotImplementedError
def beamsearch_process_step(self, output_i, output_t):
    """Hook run after every beam-search step; no-op by default."""
    return None
def beamsearch_process(self, output, output_i, input_dict):
    """Write the finished beams of sample ``sample_idx`` into ``output["seq"]``.

    Beams are ranked by decreasing score. With ``n_best`` the first
    ``n_best_size`` of them are stored, otherwise only the best one.
    """
    sample = input_dict["sample_idx"]
    ranked = sorted(output_i["done_beams"], key=lambda beam: -beam["score"])
    if not input_dict["n_best"]:
        best = ranked[0]["seq"]
        output["seq"][sample][:len(best)] = best
        return
    for rank, beam in enumerate(ranked[:input_dict["n_best_size"]]):
        tokens = beam["seq"]
        output["seq"][sample][rank, :len(tokens)] = tokens
def diverse_beam_search(self, input_dict):
# Diverse beam search: the beams are split into `group_size` groups of
# `bdash = beam_size // group_size` beams each. Group `divm` starts
# decoding at global step `divm`, and its word scores are penalised by how
# often the earlier groups chose each word at the same local step.
# Reads input_dict["group_size"], ["beam_size"], ["diversity_lambda"],
# ["max_length"], ["temp"], ["group_nbest"] and ["fc_emb"]; writes
# input_dict["bdash"], ["sample_idx"], ["t"] and ["divm"].
# Returns the `output` dict with output["seq"] replaced by a
# [batch, beam_size or group_size, length] tensor.
def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
# Subtract diversity_lambda times the number of earlier-group beams that
# picked each word at this local step; also return the unpenalised scores.
local_time = t - divm
unaug_logprob = logprob.clone()
if divm > 0:
change = torch.zeros(logprob.size(-1))
for prev_choice in range(divm):
prev_decisions = seq_table[prev_choice][..., local_time]
for prev_labels in range(bdash):
change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1))
change = change.to(logprob.device)
logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda
return logprob, unaug_logprob
output = self.prepare_output(input_dict)
group_size = input_dict["group_size"]
batch_size = output["seq"].size(0)
beam_size = input_dict["beam_size"]
bdash = beam_size // group_size
input_dict["bdash"] = bdash
diversity_lambda = input_dict["diversity_lambda"]
device = input_dict["fc_emb"].device
max_length = input_dict["max_length"]
temp = input_dict["temp"]
group_nbest = input_dict["group_nbest"]
batch_size, max_length = output["seq"].size()
# group_nbest keeps one slot per beam, otherwise one slot per group
if group_nbest:
output["seq"] = torch.full((batch_size, beam_size, max_length),
self.end_idx, dtype=torch.long)
else:
output["seq"] = torch.full((batch_size, group_size, max_length),
self.end_idx, dtype=torch.long)
for i in range(batch_size):
input_dict["sample_idx"] = i
seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0]
logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
done_beams_table = [[] for _ in range(group_size)]
output_i = {
"prev_words_beam": [None for _ in range(group_size)],
"next_word": [None for _ in range(group_size)],
"state": [None for _ in range(group_size)]
}
# groups are staggered by one step each, hence the extra group_size - 1 steps
for t in range(max_length + group_size - 1):
input_dict["t"] = t
for divm in range(group_size):
input_dict["divm"] = divm
# group divm is active for global steps divm .. max_length + divm - 1
if t >= divm and t <= max_length + divm - 1:
local_time = t - divm
decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
output_t = self.decoder(decoder_input)
output_t["divm"] = divm
logit_t = output_t["logit"]
# keep only the logits of the last decoded position
if logit_t.size(1) == 1:
logit_t = logit_t.squeeze(1)
elif logit_t.size(1) > 1:
logit_t = logit_t[:, -1, :]
else:
raise Exception("no logit output")
# log-probabilities, rescaled by the temperature and renormalised
logprob_t = torch.log_softmax(logit_t, dim=1)
logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash)
# add the accumulated score of each beam of this group
logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
if local_time == 0: # for the first step, all k seq will have the same probs
topk_logprob, topk_words = logprob_t[0].topk(
bdash, 0, True, True)
else: # unroll and find top logprob, and their unrolled indices
topk_logprob, topk_words = logprob_t.view(-1).topk(
bdash, 0, True, True)
topk_words = topk_words.cpu()
logprob_table[divm] = topk_logprob
# NOTE(review): beam_search uses torch.div(..., rounding_mode='trunc')
# for the same computation; `//` here is the older spelling.
output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,]
output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,]
if local_time > 0:
# reorder the stored sequences of this group by source beam
seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
seq_table[divm] = torch.cat([
seq_table[divm],
output_i["next_word"][divm].unsqueeze(-1)], -1)
is_end = seq_table[divm][:, t-divm] == self.end_idx
assert seq_table[divm].shape[-1] == t - divm + 1
if t == max_length + divm - 1:
# last step of this group: every beam counts as finished
is_end.fill_(1)
for beam_idx in range(bdash):
if is_end[beam_idx]:
final_beam = {
"seq": seq_table[divm][beam_idx].clone(),
"score": logprob_table[divm][beam_idx].item()
}
# score is divided by the number of steps decoded by this group
final_beam["score"] = final_beam["score"] / (t - divm + 1)
done_beams_table[divm].append(final_beam)
# large penalty on the scores of ended beams
# (presumably so they are not expanded again -- TODO confirm)
logprob_table[divm][is_end] -= 1000
self.dbs_process_step(output_i, output_t)
# keep the bdash best finished beams of every group
done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)]
if group_nbest:
done_beams = sum(done_beams_table, [])
else:
done_beams = [group_beam[0] for group_beam in done_beams_table]
# `_` is used as the slot index of each kept beam
for _, done_beam in enumerate(done_beams):
output["seq"][i, _, :len(done_beam["seq"])] = done_beam["seq"]
return output
def prepare_dbs_decoder_input(self, input_dict, output_i):
    """Build the decoder input for one diverse-beam-search step; to be provided by subclasses."""
    raise NotImplementedError
def dbs_process_step(self, output_i, output_t):
    """Hook run after every diverse-beam-search step; no-op by default."""
    return None
class TransformerModel(CaptionModel):
    """Captioning model whose decoder is a ``TransformerDecoder``.

    At every decoding step the decoder receives the whole word prefix
    generated so far, starting with the start token.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        # Only set the default when nothing has provided the attribute yet;
        # it has to exist before the base constructor checks the decoder.
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (TransformerDecoder,)
        super().__init__(encoder, decoder, **kwargs)

    def seq_forward(self, input_dict):
        """Teacher-forced pass: feed the caption without its last token."""
        cap = input_dict["cap"]
        padding = (cap == self.pad_idx).to(cap.device)
        padding = padding[:, :-1]
        decoder_input = {
            "word": cap[:, :-1],
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_len": input_dict["attn_emb_len"],
            "cap_padding_mask": padding,
        }
        return self.decoder(decoder_input)

    def prepare_decoder_input(self, input_dict, output):
        """Build the decoder input for step ``t`` of step-by-step decoding.

        In "train" mode the ground-truth prefix is used with probability
        ``ss_ratio`` (scheduled sampling); otherwise the prefix is the start
        token followed by the words generated so far.
        """
        decoder_input = {
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_len": input_dict["attn_emb_len"],
        }
        step = input_dict["t"]

        # determine input word
        use_ground_truth = (input_dict["mode"] == "train"
                            and random.random() < input_dict["ss_ratio"])
        if use_ground_truth:
            word = input_dict["cap"][:, :step + 1]
        else:
            batch = input_dict["attn_emb"].size(0)
            start_word = torch.tensor([self.start_idx,] * batch).unsqueeze(1).long()
            if step == 0:
                word = start_word
            else:
                word = torch.cat((start_word, output["seq"][:, :step]), dim=-1)
        # word: [N, T]
        decoder_input["word"] = word
        decoder_input["cap_padding_mask"] = (word == self.pad_idx).to(
            input_dict["attn_emb"].device)
        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Build the decoder input for one beam-search step of sample ``sample_idx``.

        At ``t == 0`` the sample's attention embeddings are repeated once
        per beam and cached in ``output_i`` for the following steps.
        """
        step = input_dict["t"]
        sample = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]

        # prepare attn embeds
        if step == 0:
            output_i["attn_emb"] = repeat_tensor(
                input_dict["attn_emb"][sample], beam_size)
            output_i["attn_emb_len"] = repeat_tensor(
                input_dict["attn_emb_len"][sample], beam_size)
        decoder_input = {
            "attn_emb": output_i["attn_emb"],
            "attn_emb_len": output_i["attn_emb_len"],
        }

        # determine input word
        start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
        if step == 0:
            word = start_word
        else:
            word = torch.cat((start_word, output_i["seq"]), dim=-1)
        decoder_input["word"] = word
        decoder_input["cap_padding_mask"] = (word == self.pad_idx).to(
            input_dict["attn_emb"].device)
        return decoder_input
class BaseDecoder(nn.Module):
    """
    Take word/audio embeddings and output the next word probs.

    Args:
        emb_dim: size of the word embeddings.
        vocab_size: number of words in the vocabulary.
        fc_emb_dim: size of the clip-level audio embedding.
        attn_emb_dim: size of the frame-level audio embeddings.
        dropout: dropout rate applied to the input embeddings.
        tie_weights: stored on the instance for use by subclasses.
    """

    def __init__(self, emb_dim, vocab_size, fc_emb_dim,
                 attn_emb_dim, dropout=0.2, tie_weights=False):
        super().__init__()
        self.emb_dim = emb_dim
        self.vocab_size = vocab_size
        self.fc_emb_dim = fc_emb_dim
        self.attn_emb_dim = attn_emb_dim
        self.tie_weights = tie_weights
        self.word_embedding = nn.Embedding(vocab_size, emb_dim)
        self.in_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Decode one batch; to be provided by subclasses."""
        raise NotImplementedError

    def load_word_embedding(self, weight, freeze=True):
        """Replace the word embedding with weights stored in a ``.npy`` file.

        Args:
            weight: path (or file object) accepted by ``np.load``; the array
                must have shape [vocab_size, emb_dim].
            freeze: whether the loaded embedding is kept fixed.

        Raises:
            AssertionError: if the array shape does not match the decoder.
        """
        embedding = np.load(weight)
        assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
        assert embedding.shape[1] == self.emb_dim, "embed size mismatch"

        # from_pretrained needs a float tensor; np.load returns a numpy array.
        embedding = torch.as_tensor(embedding).float()
        self.word_embedding = nn.Embedding.from_pretrained(embedding,
                                                           freeze=freeze)
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by dropout.

    Args:
        d_model: embedding size (odd sizes are supported).
        dropout: dropout rate applied after adding the encoding.
        max_len: number of positions for which the encoding is precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so the angle table is trimmed to fit.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        # Registered as a non-trainable parameter (not a buffer), so "pe" is
        # part of the module's parameters and state dict.
        self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))

    def forward(self, x):
        """Add the encoding of the first ``x.size(0)`` positions.

        x: [T, N, E]
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
class TransformerDecoder(BaseDecoder):
    """Transformer caption decoder attending over projected audio embeddings.

    Args:
        emb_dim: word embedding size, also used as the model width.
        vocab_size: number of words in the vocabulary.
        fc_emb_dim: size of the clip-level audio embedding.
        attn_emb_dim: size of the frame-level audio embeddings.
        dropout: dropout rate used throughout the decoder.
        freeze: when True, all parameters are frozen after construction.
        tie_weights: share the classifier weight with the word embedding.
        **kwargs: optional ``nhead`` (default ``max(1, emb_dim // 64)``),
            ``nlayers`` (default 2) and ``dim_feedforward``
            (default ``4 * emb_dim``).
    """

    def __init__(self,
                 emb_dim,
                 vocab_size,
                 fc_emb_dim,
                 attn_emb_dim,
                 dropout,
                 freeze=False,
                 tie_weights=False,
                 **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout=dropout, tie_weights=tie_weights)
        self.d_model = emb_dim
        # At least one head, so that widths below 64 still yield a valid default.
        self.nhead = kwargs.get("nhead", max(1, self.d_model // 64))
        self.nlayers = kwargs.get("nlayers", 2)
        self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)

        self.pos_encoder = PositionalEncoding(self.d_model, dropout)
        layer = nn.TransformerDecoderLayer(d_model=self.d_model,
                                           nhead=self.nhead,
                                           dim_feedforward=self.dim_feedforward,
                                           dropout=dropout)
        self.model = nn.TransformerDecoder(layer, self.nlayers)
        self.classifier = nn.Linear(self.d_model, vocab_size, bias=False)
        if tie_weights:
            self.classifier.weight = self.word_embedding.weight
        # Maps the audio embeddings to the model width.
        self.attn_proj = nn.Sequential(
            nn.Linear(self.attn_emb_dim, self.d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(self.d_model)
        )
        self.init_params()

        self.freeze = freeze
        if freeze:
            for p in self.parameters():
                p.requires_grad = False

    def init_params(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def load_pretrained(self, pretrained, output_fn):
        """Load matching weights from a checkpoint file.

        The checkpoint may wrap its weights in a "model" entry. If the first
        key starts with "decoder.", the first 8 characters of every key are
        stripped; otherwise the keys are used as they are. When the decoder
        was built with ``freeze=True`` the loaded parameters are frozen and
        all others stay trainable.

        Args:
            pretrained: path passed to ``torch.load``.
            output_fn: callable receiving the report of mismatched keys.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(pretrained, map_location="cpu")

        if "model" in checkpoint:
            checkpoint = checkpoint["model"]
        # Default to the checkpoint itself so `state_dict` is always defined,
        # also for checkpoints without the "decoder." prefix.
        state_dict = checkpoint
        if checkpoint and next(iter(checkpoint)).startswith("decoder."):
            state_dict = {k[8:]: v for k, v in checkpoint.items()}

        loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
        if self.freeze:
            for name, param in self.named_parameters():
                param.requires_grad = name not in loaded_keys

    def generate_square_subsequent_mask(self, max_length):
        """Return a [max_length, max_length] causal mask.

        Entries are 0.0 where attention is allowed (on and below the
        diagonal) and -inf elsewhere.
        """
        mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, input_dict):
        """Decode a batch of word prefixes.

        Args:
            input_dict: dict with "word" ([N, T] word ids), "attn_emb"
                ([N, T_src, attn_emb_dim]), "attn_emb_len" ([N,]) and
                "cap_padding_mask" (padding mask for "word").

        Returns:
            dict with "embed" ([N, T, d_model] decoder states) and "logit"
            ([N, T, vocab_size]).
        """
        word = input_dict["word"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        cap_padding_mask = input_dict["cap_padding_mask"]

        p_attn_emb = self.attn_proj(attn_emb)
        p_attn_emb = p_attn_emb.transpose(0, 1)  # [T_src, N, emb_dim]
        word = word.to(attn_emb.device)
        embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim)  # [N, T, emb_dim]
        embed = embed.transpose(0, 1)  # [T, N, emb_dim]
        embed = self.pos_encoder(embed)

        tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
        # True marks padded audio steps that must not be attended to.
        memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
        output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
                            tgt_key_padding_mask=cap_padding_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        output = output.transpose(0, 1)
        return {
            "embed": output,
            "logit": self.classifier(output),
        }
class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Wrap a model and add a contrastive encoder-distillation loss.

    The student clip embedding (``fc_emb``) and a teacher embedding are
    projected into a shared space, L2-normalized and compared with a
    symmetric cross-entropy over the scaled similarity matrix.

    Args:
        model: wrapped model; either exposes ``encoder.fc_emb_size`` or
            ``fc_emb_size`` itself.
        shared_dim: size of the shared projection space.
        tchr_dim: size of the teacher embedding.
    """

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 ):
        super().__init__()
        self.model = model
        self.tchr_dim = tchr_dim
        # The student embedding size lives on the encoder when there is one.
        size_owner = model.encoder if hasattr(model, "encoder") else model
        self.stdnt_proj = nn.Linear(size_owner.fc_emb_size, shared_dim)
        self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        """Run the wrapped model and, when a teacher output is given, add
        ``enc_kd_loss`` to the returned dict.

        Only the encoder is run when ``input_dict["unsup"]`` is anything
        other than ``False``.
        """
        if input_dict.get("unsup", False) is False:
            output_dict = self.model(input_dict)
        else:
            output_dict = self.model.encoder(input_dict)

        if "tchr_output" not in input_dict:
            return output_dict

        student = self.stdnt_proj(output_dict["fc_emb"])
        teacher = self.tchr_proj(input_dict["tchr_output"]["embedding"])
        student = F.normalize(student, dim=-1)
        teacher = F.normalize(teacher, dim=-1)

        similarity = self.logit_scale * (student @ teacher.transpose(0, 1))
        # Matching student/teacher pairs sit on the diagonal.
        target = torch.arange(similarity.shape[0]).to(similarity.device)
        student_to_teacher = F.cross_entropy(similarity, target)
        teacher_to_student = F.cross_entropy(similarity.transpose(0, 1), target)
        output_dict["enc_kd_loss"] = (student_to_teacher + teacher_to_student) / 2
        return output_dict
class Effb2TrmConfig(PretrainedConfig):
    """Hyper-parameters consumed by ``Effb2TrmCaptioningModel``.

    Every argument is stored as an attribute of the same name; remaining
    keyword arguments are handed to ``PretrainedConfig``.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        tchr_dim: int = 768,
        shared_dim: int = 1024,
        fc_emb_dim: int = 1408,
        attn_emb_dim: int = 1408,
        decoder_n_layers: int = 2,
        decoder_we_tie_weights: bool = True,
        decoder_emb_dim: int = 256,
        decoder_dropout: float = 0.2,
        vocab_size: int = 4981,
        **kwargs
    ):
        # Audio input.
        self.sample_rate = sample_rate

        # Distillation projection sizes.
        self.tchr_dim = tchr_dim
        self.shared_dim = shared_dim

        # Encoder output sizes seen by the decoder.
        self.fc_emb_dim = fc_emb_dim
        self.attn_emb_dim = attn_emb_dim

        # Transformer decoder.
        self.decoder_n_layers = decoder_n_layers
        self.decoder_we_tie_weights = decoder_we_tie_weights
        self.decoder_emb_dim = decoder_emb_dim
        self.decoder_dropout = decoder_dropout
        self.vocab_size = vocab_size

        super().__init__(**kwargs)
class Effb2TrmCaptioningModel(PreTrainedModel):
    """HuggingFace wrapper around an EfficientNet-B2 encoder and a
    transformer decoder, exposed for caption inference."""

    config_class = Effb2TrmConfig

    def __init__(self, config):
        super().__init__(config)
        encoder = EfficientNetB2()
        decoder = TransformerDecoder(
            emb_dim=config.decoder_emb_dim,
            vocab_size=config.vocab_size,
            fc_emb_dim=config.fc_emb_dim,
            attn_emb_dim=config.attn_emb_dim,
            dropout=config.decoder_dropout,
            nlayers=config.decoder_n_layers,
            tie_weights=config.decoder_we_tie_weights
        )
        captioner = TransformerModel(encoder, decoder)
        self.model = ContraEncoderKdWrapper(captioner,
                                            config.shared_dim,
                                            config.tchr_dim)

    def forward(self,
                audio: torch.Tensor,
                audio_length: Union[List, np.ndarray, torch.Tensor],
                sample_method: str = "beam",
                beam_size: int = 3,
                max_length: int = 20,
                temp: float = 1.0,):
        """Generate captions for a batch of waveforms.

        Args:
            audio: waveform tensor, moved to the model device.
            audio_length: per-sample lengths, passed on unchanged as ``wav_len``.
            sample_method: decoding strategy name passed to the wrapped model.
            beam_size: only forwarded when ``sample_method == "beam"``.
            max_length: passed on as ``max_length``.
            temp: passed on as ``temp``.

        Returns:
            The ``"seq"`` entry of the wrapped model output, moved to CPU.
        """
        request = {
            "wav": audio.to(self.device),
            "wav_len": audio_length,
            "specaug": False,
            "mode": "inference",
            "sample_method": sample_method,
            "max_length": max_length,
            "temp": temp,
        }
        if sample_method == "beam":
            request["beam_size"] = beam_size
        result = self.model(request)
        return result["seq"].cpu()
class ConvBlock(nn.Module):
    """Two 3x3 conv + batch-norm + ReLU stages followed by 2-D pooling.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        # Both convolutions keep the spatial size (3x3, stride 1, padding 1).
        shared = dict(kernel_size=(3, 3), stride=(1, 1),
                      padding=(1, 1), bias=False)
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels, **shared)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels, **shared)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply the block.

        Args:
            input: 4-D feature map accepted by ``nn.Conv2d``.
            pool_size: pooling kernel size.
            pool_type: ``'max'``, ``'avg'`` or ``'avg+max'`` (sum of both).

        Raises:
            Exception: for any other ``pool_type`` (after the convolutions
                have already run).
        """
        feat = F.relu_(self.bn1(self.conv1(input)))
        feat = F.relu_(self.bn2(self.conv2(feat)))

        if pool_type == 'max':
            return F.max_pool2d(feat, kernel_size=pool_size)
        if pool_type == 'avg':
            return F.avg_pool2d(feat, kernel_size=pool_size)
        if pool_type == 'avg+max':
            averaged = F.avg_pool2d(feat, kernel_size=pool_size)
            peak = F.max_pool2d(feat, kernel_size=pool_size)
            return averaged + peak
        raise Exception('Incorrect argument!')
class Cnn14Encoder(nn.Module):
    """Six-block CNN encoder operating on log-mel spectrograms.

    Args:
        sample_rate: 32000 or 16000; any other value raises ``KeyError``
            when the mel filterbank upper frequency is looked up.
    """

    def __init__(self, sample_rate=32000):
        super().__init__()
        fmax_by_rate = {
            32000: 14000,
            16000: 8000
        }
        samples_per_ms = sample_rate // 1000 if False else None  # placeholder removed below
        del samples_per_ms
        # Log-mel spectrogram extractor (32 ms window, 10 ms hop).
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=32 * sample_rate // 1000,
            win_length=32 * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=fmax_by_rate[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB()

        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        # Five 2x2 poolings along time: 2 ** 5.
        self.downsample_ratio = 32

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_emb_size = 2048

    def forward(self, input_dict):
        """Encode a batch of log-mel spectrograms.

        Args:
            input_dict: dict with ``lms`` (batch_size, mel_bins, time_steps)
                and ``wav_len`` (waveform lengths in samples).

        Returns:
            Dict with ``fc_emb`` (clip-level embedding), ``attn_emb``
            (frame-level embeddings) and ``attn_emb_len`` (valid frames).
        """
        wave_length = input_dict["wav_len"]

        # (batch_size, mel_bins, time_steps) -> (batch_size, 1, time_steps, mel_bins)
        x = input_dict["lms"].transpose(1, 2).unsqueeze(1)
        # bn0 normalizes over mel bins, so they are moved to the channel axis.
        x = self.bn0(x.transpose(1, 3)).transpose(1, 3)

        stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)),
            (self.conv_block4, (2, 2)),
            (self.conv_block5, (2, 2)),
            (self.conv_block6, (1, 1)),
        )
        for conv_block, pool_size in stages:
            x = conv_block(x, pool_size=pool_size, pool_type='avg')
            x = F.dropout(x, p=0.2, training=self.training)

        # Average out the frequency axis, then go batch-first/time-second.
        attn_emb = torch.mean(x, dim=3).transpose(1, 2)

        # Frames produced by the spectrogram, then by the CNN downsampling.
        n_frames = torch.div(torch.as_tensor(wave_length), self.hop_length,
                             rounding_mode="floor") + 1
        feat_length = torch.div(n_frames, self.downsample_ratio,
                                rounding_mode="floor")

        pooled_max = max_with_lens(attn_emb, feat_length)
        pooled_mean = mean_with_lens(attn_emb, feat_length)
        x = pooled_max + pooled_mean
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        fc_emb = F.dropout(x, p=0.5, training=self.training)

        return {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
class RnnEncoder(nn.Module):
    """Recurrent encoder over padded frame-level features.

    Args:
        attn_feat_dim: feature size of the input sequence.
        pooling: pooling name handed to ``embedding_pooling`` to build the
            clip-level embedding.
        **kwargs: optional ``hidden_size`` (512), ``bidirectional`` (False),
            ``num_layers`` (1), ``dropout`` (0.2), ``rnn_type`` ("GRU", any
            recurrent class name in ``torch.nn``) and ``in_bn`` (False,
            batch-normalize the input features first).
    """

    def __init__(self,
                 attn_feat_dim,
                 pooling="mean",
                 **kwargs):
        super().__init__()
        self.pooling = pooling
        self.hidden_size = kwargs.get('hidden_size', 512)
        self.bidirectional = kwargs.get('bidirectional', False)
        self.num_layers = kwargs.get('num_layers', 1)
        self.dropout = kwargs.get('dropout', 0.2)
        self.rnn_type = kwargs.get('rnn_type', "GRU")
        self.in_bn = kwargs.get('in_bn', False)
        self.embed_dim = self.hidden_size * (self.bidirectional + 1)
        self.network = getattr(nn, self.rnn_type)(
            attn_feat_dim,
            self.hidden_size,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional,
            dropout=self.dropout,
            batch_first=True)
        if self.in_bn:
            # The norm runs on the *input* features (see forward), so it must
            # be sized with attn_feat_dim; sizing it with embed_dim only
            # worked when the two happened to be equal.
            self.bn = nn.BatchNorm1d(attn_feat_dim)

    def forward(self, input_dict):
        """Encode ``input_dict["attn"]`` ([N, T, E]) with valid lengths
        ``input_dict["attn_len"]``.

        Returns:
            Dict with ``attn_emb`` (RNN outputs, [N, T, hidden]), ``fc_emb``
            (pooled embedding) and ``attn_emb_len`` (the lengths as a tensor).
        """
        x = input_dict["attn"]
        lens = input_dict["attn_len"]
        lens = torch.as_tensor(lens)
        # x: [N, T, E]
        if self.in_bn:
            x = pack_wrapper(self.bn, x, lens)
        out = pack_wrapper(self.network, x, lens)
        # out: [N, T, hidden]
        attn_emb = out
        fc_emb = embedding_pooling(out, lens, self.pooling)
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": lens
        }
class Cnn14RnnEncoder(nn.Module):
    """``Cnn14Encoder`` followed by an ``RnnEncoder`` over its frame
    embeddings (2048 features per frame)."""

    def __init__(self,
                 sample_rate,
                 rnn_bidirectional,
                 rnn_hidden_size,
                 rnn_dropout,
                 rnn_num_layers):
        super().__init__()
        self.cnn = Cnn14Encoder(sample_rate=sample_rate)
        rnn_options = {
            "bidirectional": rnn_bidirectional,
            "hidden_size": rnn_hidden_size,
            "dropout": rnn_dropout,
            "num_layers": rnn_num_layers,
        }
        self.rnn = RnnEncoder(2048, **rnn_options)

    def forward(self, input_dict):
        """Run the CNN, rename its frame-level outputs to the keys the RNN
        encoder reads, and return the RNN encoder output."""
        features = self.cnn(input_dict)
        # "attn_emb"/"attn_emb_len" become "attn"/"attn_len"; the old keys
        # are removed from the dict.
        features["attn"] = features.pop("attn_emb")
        features["attn_len"] = features.pop("attn_emb_len")
        return self.rnn(features)
class Seq2SeqAttention(nn.Module):
    """Additive (Bahdanau-style) attention over a padded encoder memory."""

    def __init__(self, hs_enc, hs_dec, attn_size):
        """
        Args:
            hs_enc: encoder hidden size
            hs_dec: decoder hidden size
            attn_size: attention vector size
        """
        super(Seq2SeqAttention, self).__init__()
        self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size)
        self.v = nn.Parameter(torch.randn(attn_size))

    def forward(self, h_dec, h_enc, src_lens):
        """
        Args:
            h_dec: decoder hidden (query), [N, hs_dec]
            h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
            src_lens: source (encoder memory) lengths, [N, ]; may live on
                any device.

        Returns:
            ctx: attention-weighted sum of ``h_enc``, [N, hs_enc]
            weights: attention weights, [N, src_max_len]; positions at or
                beyond each source length receive (numerically) zero weight.
        """
        N = h_enc.size(0)
        src_max_len = h_enc.size(1)
        h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1)  # [N, src_max_len, hs_dec]

        attn_input = torch.cat((h_dec, h_enc), dim=-1)
        attn_out = torch.tanh(self.h2attn(attn_input))  # [N, src_max_len, attn_size]

        v = self.v.repeat(N, 1).unsqueeze(1)  # [N, 1, attn_size]
        score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1)  # [N, src_max_len]

        # Build the validity mask directly on the score's device so that the
        # comparison never mixes devices, wherever `src_lens` lives.
        idxs = torch.arange(src_max_len, device=score.device).unsqueeze(0)  # [1, src_max_len]
        mask = idxs < src_lens.view(-1, 1).to(score.device)  # [N, src_max_len]

        score = score.masked_fill(mask == 0, -1e10)
        weights = torch.softmax(score, dim=-1)  # [N, src_max_len]
        ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1)  # [N, hs_enc]

        return ctx, weights
class RnnDecoder(BaseDecoder):
    """Common base for recurrent caption decoders.

    Holds the recurrent hyper-parameters and the output classifier;
    subclasses provide the recurrent network and ``forward``.

    Args:
        d_model: hidden size of the recurrent network.
        **kwargs: optional ``num_layers`` (1), ``bidirectional`` (False) and
            ``rnn_type`` ("GRU").
    """

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout,)
        self.d_model = d_model
        self.num_layers = kwargs.get('num_layers', 1)
        self.bidirectional = kwargs.get('bidirectional', False)
        self.rnn_type = kwargs.get('rnn_type', "GRU")
        classifier_in = self.d_model * (self.bidirectional + 1)
        self.classifier = nn.Linear(classifier_in, vocab_size)

    def forward(self, x):
        raise NotImplementedError

    def init_hidden(self, bs, device):
        """Return an all-zero initial state for a batch of size ``bs``.

        An LSTM gets a ``(h, c)`` tuple, every other type a single tensor of
        shape [num_directions * num_layers, bs, d_model].
        """
        state_shape = ((self.bidirectional + 1) * self.num_layers,
                       bs,
                       self.d_model)
        if self.rnn_type != "LSTM":
            return torch.zeros(*state_shape).to(device)
        return (torch.zeros(*state_shape).to(device),
                torch.zeros(*state_shape).to(device))
class BahAttnCatFcDecoder(RnnDecoder):
    """Single-step recurrent decoder with additive attention.

    At every step the word embedding, the projected attention context and
    the projected clip embedding are concatenated and fed to the RNN.
    """

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        """
        concatenate fc, attn, word to feed to the rnn
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        attn_size = kwargs.get("attn_size", self.d_model)
        rnn_cls = getattr(nn, self.rnn_type)
        self.model = rnn_cls(
            input_size=self.emb_dim * 3,  # word + context + fc, emb_dim each
            hidden_size=self.d_model,
            batch_first=True,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional)
        # The attention query is the flattened state of all layers/directions.
        query_size = self.d_model * (self.bidirectional + 1) * self.num_layers
        self.attn = Seq2SeqAttention(self.attn_emb_dim, query_size, attn_size)
        self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
        self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)

    def forward(self, input_dict):
        """Run one decoding step.

        Args:
            input_dict: dict with ``word``, ``fc_emb``, ``attn_emb``,
                ``attn_emb_len`` and optionally ``state``
                ([n_layer * n_dire, bs, d_model]).

        Returns:
            Dict with the new ``state``, the RNN output ``embed``, the
            classifier ``logit`` and the ``attn_weight`` of this step.
        """
        fc_emb = input_dict["fc_emb"]
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        state = input_dict.get("state", None)

        word = input_dict["word"].to(fc_emb.device)
        embed = self.in_dropout(self.word_embedding(word))  # [N, 1, embed_size]

        if state is None:
            state = self.init_hidden(word.size(0), fc_emb.device)
        # For an LSTM only the hidden part of (h, c) is used as the query.
        hidden = state[0] if self.rnn_type == "LSTM" else state
        query = hidden.transpose(0, 1).flatten(1)
        c, attn_weight = self.attn(query, attn_emb, attn_emb_len)

        p_fc_emb = self.fc_proj(fc_emb)
        p_ctx = self.ctx_proj(c)
        step_input = torch.cat(
            (embed, p_ctx.unsqueeze(1), p_fc_emb.unsqueeze(1)), dim=-1)

        out, state = self.model(step_input, state)

        return {
            "state": state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
class TemporalBahAttnDecoder(BahAttnCatFcDecoder):
    """``BahAttnCatFcDecoder`` whose first input is a temporal-tag embedding.

    At step ``t == 0`` the embedding of ``temporal_tag`` (one of 4 tags)
    replaces the word embedding.
    """

    def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                 dropout, d_model, **kwargs):
        """
        concatenate fc, attn, word to feed to the rnn
        """
        super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout, d_model, **kwargs)
        self.temporal_embedding = nn.Embedding(4, emb_dim)

    def forward(self, input_dict):
        """Run one decoding step.

        Args:
            input_dict: dict with ``word``, ``fc_emb``, ``attn_emb``,
                ``attn_emb_len``, ``temporal_tag``, ``t`` and optionally
                ``state`` ([n_layer * n_dire, bs, d_model]).

        Returns:
            Dict with the new ``state``, the RNN output ``embed``, the
            classifier ``logit`` and the ``attn_weight`` of this step.

        Raises:
            Exception: when ``t != 0`` and the last dimension of ``word`` is
                neither ``fc_emb_dim`` nor 1.
        """
        word = input_dict["word"]
        state = input_dict.get("state", None)  # [n_layer * n_dire, bs, d_model]
        fc_embs = input_dict["fc_emb"]
        attn_embs = input_dict["attn_emb"]
        attn_emb_lens = input_dict["attn_emb_len"]
        temporal_tag = input_dict["temporal_tag"]

        if input_dict["t"] == 0:
            embed = self.in_dropout(
                self.temporal_embedding(temporal_tag)).unsqueeze(1)
        elif word.size(-1) == self.fc_emb_dim:  # fc_embs
            embed = word.unsqueeze(1)
        elif word.size(-1) == 1:  # word
            word = word.to(fc_embs.device)
            embed = self.in_dropout(self.word_embedding(word))
        else:
            raise Exception(f"problem with word input size {word.size()}")

        # embed: [N, 1, embed_size]
        if state is None:
            state = self.init_hidden(word.size(0), fc_embs.device)
        if self.rnn_type == "LSTM":
            query = state[0].transpose(0, 1).flatten(1)
        else:
            query = state.transpose(0, 1).flatten(1)
        c, attn_weight = self.attn(query, attn_embs, attn_emb_lens)

        # The context projection used to be computed twice with identical
        # inputs; one deterministic linear projection is enough.
        p_ctx = self.ctx_proj(c)
        p_fc_embs = self.fc_proj(fc_embs)
        rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_embs.unsqueeze(1)), dim=-1)

        out, state = self.model(rnn_input, state)

        output = {
            "state": state,
            "embed": out,
            "logit": self.classifier(out),
            "attn_weight": attn_weight
        }
        return output
class Seq2SeqAttnModel(CaptionModel):
    """Caption model for step-by-step attention RNN decoders.

    The methods below are hooks around the decoder call; they are presumably
    invoked by the decoding loops of ``CaptionModel`` (not shown here).
    """

    def __init__(self, encoder, decoder, **kwargs):
        # Subclasses may have declared their own compatible decoders already.
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                BahAttnCatFcDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)

    def seq_forward(self, input_dict):
        # Bahdanau attention needs the previous state at every step, so both
        # training and evaluation go through the step-by-step path.
        return self.stepwise_forward(input_dict)

    def prepare_output(self, input_dict):
        """Extend the parent output with an uninitialized ``attn_weight``
        buffer of shape [seq.size(0), attn_emb.size(1), seq.size(1)]."""
        output = super().prepare_output(input_dict)
        n_samples = output["seq"].size(0)
        n_steps = output["seq"].size(1)
        n_frames = input_dict["attn_emb"].size(1)
        output["attn_weight"] = torch.empty(n_samples, n_frames, n_steps)
        return output

    def prepare_decoder_input(self, input_dict, output):
        """Build the decoder input for step ``input_dict["t"]``."""
        decoder_input = {
            key: input_dict[key]
            for key in ("fc_emb", "attn_emb", "attn_emb_len")
        }
        t = input_dict["t"]

        # Input word: ground truth under scheduled sampling (training only),
        # otherwise <start> at t == 0 and the previous prediction afterwards.
        if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]:
            word = input_dict["cap"][:, t]
        elif t == 0:
            word = torch.tensor([self.start_idx,] * input_dict["fc_emb"].size(0)).long()
        else:
            word = output["seq"][:, t-1]
        # word: [N,]
        decoder_input["word"] = word.unsqueeze(1)

        # The RNN state only exists after the first step.
        if t > 0:
            decoder_input["state"] = output["state"]
        return decoder_input

    def stepwise_process_step(self, output, output_t):
        super().stepwise_process_step(output, output_t)
        output["state"] = output_t["state"]
        step = output_t["t"]
        output["attn_weight"][:, :, step] = output_t["attn_weight"]

    def prepare_beamsearch_output(self, input_dict):
        """Extend the parent output with an uninitialized ``attn_weight``
        buffer of shape [beam_size, max(attn_emb_len), max_length]."""
        output = super().prepare_beamsearch_output(input_dict)
        beam_size = input_dict["beam_size"]
        max_length = input_dict["max_length"]
        longest = max(input_dict["attn_emb_len"])
        output["attn_weight"] = torch.empty(beam_size, longest, max_length)
        return output

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Build the decoder input for one beam-search step of sample
        ``input_dict["sample_idx"]``."""
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]

        # Encoder outputs are repeated once (first step) and cached.
        if t == 0:
            output_i["fc_emb"] = repeat_tensor(input_dict["fc_emb"][i], beam_size)
            output_i["attn_emb"] = repeat_tensor(input_dict["attn_emb"][i], beam_size)
            output_i["attn_emb_len"] = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
        decoder_input = {
            "fc_emb": output_i["fc_emb"],
            "attn_emb": output_i["attn_emb"],
            "attn_emb_len": output_i["attn_emb_len"],
        }

        # Input word: <start> for every beam first, then the chosen words.
        if t == 0:
            word = torch.tensor([self.start_idx,] * beam_size).long()
        else:
            word = output_i["next_word"]
        decoder_input["word"] = word.unsqueeze(1)

        # Re-order the RNN state so it follows the surviving beams.
        if t > 0:
            previous = output_i["state"]
            if self.decoder.rnn_type == "LSTM":
                decoder_input["state"] = (
                    previous[0][:, output_i["prev_words_beam"], :].contiguous(),
                    previous[1][:, output_i["prev_words_beam"], :].contiguous(),
                )
            else:
                decoder_input["state"] = previous[:, output_i["prev_words_beam"], :].contiguous()

        return decoder_input

    def beamsearch_process_step(self, output_i, output_t):
        step = output_t["t"]
        output_i["state"] = output_t["state"]
        # Store this step's weights, then re-order the history by beam.
        output_i["attn_weight"][..., step] = output_t["attn_weight"]
        output_i["attn_weight"] = output_i["attn_weight"][output_i["prev_words_beam"], ...]

    def beamsearch_process(self, output, output_i, input_dict):
        super().beamsearch_process(output, output_i, input_dict)
        sample = input_dict["sample_idx"]
        # Only the weights of the first beam are kept for the sample.
        output["attn_weight"][sample] = output_i["attn_weight"][0]

    def prepare_dbs_decoder_input(self, input_dict, output_i):
        """Build the decoder input for one diverse-beam-search step of
        group ``input_dict["divm"]``."""
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        bdash = input_dict["bdash"]
        divm = input_dict["divm"]
        # Each group starts `divm` steps later than the global clock.
        local_time = t - divm

        # Repeat encoder outputs only at the first timestep to save memory.
        # NOTE(review): unlike the beam-search path, fc_emb gets an extra
        # unsqueeze(1) here -- confirm against the decoder's expectations.
        if t == 0:
            output_i["fc_emb"] = repeat_tensor(input_dict["fc_emb"][i], bdash).unsqueeze(1)
            output_i["attn_emb"] = repeat_tensor(input_dict["attn_emb"][i], bdash)
            output_i["attn_emb_len"] = repeat_tensor(input_dict["attn_emb_len"][i], bdash)
        decoder_input = {
            "fc_emb": output_i["fc_emb"],
            "attn_emb": output_i["attn_emb"],
            "attn_emb_len": output_i["attn_emb_len"],
        }

        # Input word for this group.
        if local_time == 0:
            word = torch.tensor([self.start_idx,] * bdash).long()
        else:
            word = output_i["next_word"][divm]
        decoder_input["word"] = word.unsqueeze(1)

        # RNN state of this group, re-ordered by its surviving beams.
        if local_time > 0:
            previous = output_i["state"]
            if self.decoder.rnn_type == "LSTM":
                decoder_input["state"] = (
                    previous[0][divm][
                        :, output_i["prev_words_beam"][divm], :].contiguous(),
                    previous[1][divm][
                        :, output_i["prev_words_beam"][divm], :].contiguous()
                )
            else:
                decoder_input["state"] = previous[divm][
                    :, output_i["prev_words_beam"][divm], :].contiguous()

        return decoder_input

    def dbs_process_step(self, output_i, output_t):
        group = output_t["divm"]
        output_i["state"][group] = output_t["state"]
        # TODO attention weight
class TemporalSeq2SeqAttnModel(Seq2SeqAttnModel):
    """``Seq2SeqAttnModel`` that also feeds a temporal tag (and the step
    index ``t``) to the decoder."""

    def __init__(self, encoder, decoder, **kwargs):
        if not hasattr(self, "compatible_decoders"):
            self.compatible_decoders = (
                TemporalBahAttnDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        self.train_forward_keys = ["cap", "cap_len", "ss_ratio", "temporal_tag"]
        self.inference_forward_keys = ["sample_method", "max_length", "temp", "temporal_tag"]

    def prepare_decoder_input(self, input_dict, output):
        """Parent decoder input plus ``temporal_tag`` and ``t``."""
        decoder_input = super().prepare_decoder_input(input_dict, output)
        decoder_input["temporal_tag"] = input_dict["temporal_tag"]
        decoder_input["t"] = input_dict["t"]

        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Parent beam-search input plus the per-beam ``temporal_tag`` and ``t``."""
        decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]
        ###############
        # prepare temporal_tag
        ################
        # repeated once at the first step and cached in output_i
        if t == 0:
            temporal_tag = repeat_tensor(input_dict["temporal_tag"][i], beam_size)
            output_i["temporal_tag"] = temporal_tag
        decoder_input["temporal_tag"] = output_i["temporal_tag"]
        decoder_input["t"] = input_dict["t"]

        return decoder_input

    def prepare_dbs_decoder_input(self, input_dict, output_i):
        """Parent diverse-beam-search input plus ``temporal_tag`` and ``t``."""
        # `super` must be called: the bare built-in type has no
        # `prepare_dbs_decoder_input` attribute and raised AttributeError.
        decoder_input = super().prepare_dbs_decoder_input(input_dict, output_i)
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        bdash = input_dict["bdash"]

        ###############
        # prepare temporal tag
        ################
        # repeat only at the first timestep to save consumption
        if t == 0:
            temporal_tag = repeat_tensor(input_dict["temporal_tag"][i], bdash)
            output_i["temporal_tag"] = temporal_tag
        decoder_input["temporal_tag"] = output_i["temporal_tag"]
        decoder_input["t"] = input_dict["t"]

        return decoder_input
class Cnn8rnnSedModel(nn.Module):
    """Four-block CNN followed by a bidirectional GRU for sound event
    detection on log-mel spectrograms.

    Args:
        classes_num: number of event classes predicted per frame.
    """

    def __init__(self, classes_num):
        super().__init__()
        self.time_resolution = 0.01
        self.interpolate_ratio = 4  # Downsampled ratio
        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)

        self.fc1 = nn.Linear(512, 512, bias=True)
        self.rnn = nn.GRU(512, 256, bidirectional=True, batch_first=True)
        self.fc_audioset = nn.Linear(512, classes_num, bias=True)

    def forward(self, lms):
        """Return the decoded tags for ``lms``.

        Frame-level probabilities are binarized with ``double_threshold``
        (0.75 / 0.25) and decoded by ``decode_with_timestamps``.
        """
        probabilities = self.forward_prob(lms)["framewise_output"]
        binarized = double_threshold(probabilities.cpu().numpy(), 0.75, 0.25)
        return decode_with_timestamps(binarized, self.time_resolution)

    def forward_prob(self, lms):
        """
        lms: (batch_size, mel_bins, time_steps)

        Returns a dict with ``segmentwise_output`` (probabilities at the
        downsampled time resolution) and ``framewise_output`` (interpolated
        and padded back to the input number of frames)."""
        x = lms.transpose(1, 2).unsqueeze(1)
        frames_num = x.shape[2]

        # bn0 normalizes over mel bins, so they are moved to the channel axis.
        x = self.bn0(x.transpose(1, 3)).transpose(1, 3)

        # Time is pooled only by the first two blocks (total factor 4);
        # frequency is pooled by all four (total factor 16).
        stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (1, 2)),
            (self.conv_block4, (1, 2)),
        )
        for conv_block, pool_size in stages:
            x = conv_block(x, pool_size=pool_size, pool_type='avg+max')
            x = F.dropout(x, p=0.2, training=self.training)
        # x: (batch_size, 512, time_steps / 4, mel_bins / 16)

        x = torch.mean(x, dim=3).transpose(1, 2)
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        x, _ = self.rnn(x)

        segmentwise_output = torch.sigmoid(self.fc_audioset(x)).clamp(1e-7, 1.)
        upsampled = interpolate(segmentwise_output, self.interpolate_ratio)
        framewise_output = pad_framewise_output(upsampled, frames_num)

        return {
            "segmentwise_output": segmentwise_output,
            'framewise_output': framewise_output,
        }
class Cnn14RnnTempAttnGruConfig(PretrainedConfig):
    """Hyper-parameters consumed by ``Cnn14RnnTempAttnGruModel``.

    Every argument is stored as an attribute of the same name; remaining
    keyword arguments are handed to ``PretrainedConfig``.
    """

    def __init__(
        self,
        sample_rate: int = 32000,
        encoder_rnn_bidirectional: bool = True,
        encoder_rnn_hidden_size: int = 256,
        encoder_rnn_dropout: float = 0.5,
        encoder_rnn_num_layers: int = 3,
        decoder_emb_dim: int = 512,
        vocab_size: int = 4981,
        fc_emb_dim: int = 512,
        attn_emb_dim: int = 512,
        decoder_rnn_type: str = "GRU",
        decoder_num_layers: int = 1,
        decoder_d_model: int = 512,
        decoder_dropout: float = 0.5,
        **kwargs
    ):
        # Audio input.
        self.sample_rate = sample_rate

        # Recurrent part of the encoder.
        self.encoder_rnn_bidirectional = encoder_rnn_bidirectional
        self.encoder_rnn_hidden_size = encoder_rnn_hidden_size
        self.encoder_rnn_dropout = encoder_rnn_dropout
        self.encoder_rnn_num_layers = encoder_rnn_num_layers

        # Decoder embedding / vocabulary and encoder output sizes.
        self.decoder_emb_dim = decoder_emb_dim
        self.vocab_size = vocab_size
        self.fc_emb_dim = fc_emb_dim
        self.attn_emb_dim = attn_emb_dim

        # Recurrent part of the decoder.
        self.decoder_rnn_type = decoder_rnn_type
        self.decoder_num_layers = decoder_num_layers
        self.decoder_d_model = decoder_d_model
        self.decoder_dropout = decoder_dropout

        super().__init__(**kwargs)
class Cnn14RnnTempAttnGruModel(PreTrainedModel):
    """HuggingFace wrapper combining a temporal-tag captioning model with a
    sound-event-detection model that supplies the tag."""

    config_class = Cnn14RnnTempAttnGruConfig

    def __init__(self, config):
        super().__init__(config)
        sample_rate = config.sample_rate
        # Only these two sample rates are supported; others raise KeyError.
        fmax_by_rate = {
            32000: 14000,
            16000: 8000
        }
        # Log-mel front-end shared by the captioning and SED models.
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=32 * sample_rate // 1000,
            win_length=32 * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=fmax_by_rate[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.db_transform = transforms.AmplitudeToDB()

        encoder = Cnn14RnnEncoder(
            sample_rate=config.sample_rate,
            rnn_bidirectional=config.encoder_rnn_bidirectional,
            rnn_hidden_size=config.encoder_rnn_hidden_size,
            rnn_dropout=config.encoder_rnn_dropout,
            rnn_num_layers=config.encoder_rnn_num_layers
        )
        decoder = TemporalBahAttnDecoder(
            emb_dim=config.decoder_emb_dim,
            vocab_size=config.vocab_size,
            fc_emb_dim=config.fc_emb_dim,
            attn_emb_dim=config.attn_emb_dim,
            rnn_type=config.decoder_rnn_type,
            num_layers=config.decoder_num_layers,
            d_model=config.decoder_d_model,
            dropout=config.decoder_dropout,
        )
        self.cap_model = TemporalSeq2SeqAttnModel(encoder, decoder)
        self.sed_model = Cnn8rnnSedModel(classes_num=447)

    def forward(self,
                audio: torch.Tensor,
                audio_length: Union[List, np.ndarray, torch.Tensor],
                temporal_tag: Union[List, np.ndarray, torch.Tensor] = None,
                sample_method: str = "beam",
                beam_size: int = 3,
                max_length: int = 20,
                temp: float = 1.0,):
        """Generate captions for a batch of waveforms.

        Args:
            audio: waveform tensor, moved to the model device.
            audio_length: per-sample lengths, passed on as ``wav_len``.
            temporal_tag: optional user-provided tags; when given, the
                element-wise minimum with the SED-predicted tag is used.
            sample_method: decoding strategy name passed to the caption model.
            beam_size: only forwarded when ``sample_method == "beam"``.
            max_length: passed on as ``max_length``.
            temp: passed on as ``temp``.

        Returns:
            The ``"seq"`` entry of the caption model output, moved to CPU.
        """
        device = self.device
        log_mel_spec = self.db_transform(self.melspec_extractor(audio.to(device)))

        # Tag predicted from the audio itself.
        sed_tag = torch.as_tensor(self.sed_model(log_mel_spec)).to(device)
        if temporal_tag is None:
            temporal_tag = sed_tag
        else:
            user_tag = torch.as_tensor(temporal_tag).to(device)
            stacked = torch.stack([user_tag, sed_tag], dim=0)
            temporal_tag = torch.min(stacked, dim=0).values

        request = {
            "lms": log_mel_spec,
            "wav_len": audio_length,
            "temporal_tag": temporal_tag,
            "mode": "inference",
            "sample_method": sample_method,
            "max_length": max_length,
            "temp": temp,
        }
        if sample_method == "beam":
            request["beam_size"] = beam_size
        return self.cap_model(request)["seq"].cpu()