""" Multi Step Attention for CNN """ import torch import torch.nn as nn import torch.nn.functional as F SCALE_WEIGHT = 0.5**0.5 def seq_linear(linear, x): """linear transform for 3-d tensor""" batch, hidden_size, length, _ = x.size() h = linear(torch.transpose(x, 1, 2).contiguous().view(batch * length, hidden_size)) return torch.transpose(h.view(batch, length, hidden_size, 1), 1, 2) class ConvMultiStepAttention(nn.Module): """ Conv attention takes a key matrix, a value matrix and a query vector. Attention weight is calculated by key matrix with the query vector and sum on the value matrix. And the same operation is applied in each decode conv layer. """ def __init__(self, input_size): super(ConvMultiStepAttention, self).__init__() self.linear_in = nn.Linear(input_size, input_size) self.mask = None def apply_mask(self, mask): """Apply mask""" self.mask = mask def forward( self, base_target_emb, input_from_dec, encoder_out_top, encoder_out_combine ): """ Args: base_target_emb: target emb tensor ``(batch, channel, height, width)`` input_from_dec: output of dec conv ``(batch, channel, height, width)`` encoder_out_top: the key matrix for calc of attention weight, which is the top output of encode conv encoder_out_combine: the value matrix for the attention-weighted sum, which is the combination of base emb and top output of encode """ preatt = seq_linear(self.linear_in, input_from_dec) target = (base_target_emb + preatt) * SCALE_WEIGHT target = torch.squeeze(target, 3) target = torch.transpose(target, 1, 2) pre_attn = torch.bmm(target, encoder_out_top) if self.mask is not None: pre_attn.data.masked_fill_(self.mask, -float("inf")) attn = F.softmax(pre_attn, dim=2) context_output = torch.bmm(attn, torch.transpose(encoder_out_combine, 1, 2)) context_output = torch.transpose(torch.unsqueeze(context_output, 3), 1, 2) return context_output, attn