|
""" |
|
Implementation of "Attention is All You Need" and of |
|
subsequent transformer based architectures |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
from onmt.decoders.decoder import DecoderBase |
|
from onmt.modules import MultiHeadedAttention, AverageAttention |
|
from onmt.modules.position_ffn import PositionwiseFeedForward |
|
from onmt.modules.position_ffn import ActivationFunction |
|
from onmt.utils.misc import sequence_mask |
|
from onmt.modules.rmsnorm import RMSNorm |
|
|
|
|
|
class TransformerDecoderLayerBase(nn.Module): |
|
def __init__( |
|
self, |
|
d_model, |
|
heads, |
|
d_ff, |
|
dropout, |
|
attention_dropout, |
|
self_attn_type="scaled-dot", |
|
max_relative_positions=0, |
|
relative_positions_buckets=0, |
|
aan_useffn=False, |
|
full_context_alignment=False, |
|
alignment_heads=0, |
|
pos_ffn_activation_fn=ActivationFunction.relu, |
|
add_qkvbias=False, |
|
num_kv=0, |
|
add_ffnbias=True, |
|
parallel_residual=False, |
|
shared_layer_norm=False, |
|
layer_norm="standard", |
|
norm_eps=1e-6, |
|
use_ckpting=[], |
|
parallel_gpu=1, |
|
): |
|
""" |
|
Args: |
|
d_model (int): the dimension of keys/values/queries in |
|
:class:`MultiHeadedAttention`, also the input size of |
|
the first-layer of the :class:`PositionwiseFeedForward`. |
|
heads (int): the number of heads for MultiHeadedAttention. |
|
d_ff (int): the second-layer of the |
|
:class:`PositionwiseFeedForward`. |
|
dropout (float): dropout in residual, self-attn(dot) and |
|
feed-forward |
|
attention_dropout (float): dropout in context_attn (and |
|
self-attn(avg)) |
|
self_attn_type (string): type of self-attention scaled-dot, |
|
average |
|
max_relative_positions (int): |
|
Max distance between inputs in relative positions |
|
representations |
|
aan_useffn (bool): Turn on the FFN layer in the AAN decoder |
|
full_context_alignment (bool): |
|
whether enable an extra full context decoder forward for |
|
alignment |
|
alignment_heads (int): |
|
N. of cross attention heads to use for alignment guiding |
|
pos_ffn_activation_fn (ActivationFunction): |
|
activation function choice for PositionwiseFeedForward layer |
|
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear |
|
layer_norm (string): type of layer normalization standard/rms |
|
norm_eps (float): layer norm epsilon |
|
|
|
""" |
|
super(TransformerDecoderLayerBase, self).__init__() |
|
|
|
self.self_attn_type = self_attn_type |
|
if self_attn_type == "scaled-dot": |
|
self.self_attn = MultiHeadedAttention( |
|
heads, |
|
d_model, |
|
dropout=attention_dropout, |
|
max_relative_positions=max_relative_positions, |
|
relative_positions_buckets=relative_positions_buckets, |
|
attn_type="self", |
|
add_qkvbias=add_qkvbias, |
|
num_kv=num_kv, |
|
use_ckpting=use_ckpting, |
|
parallel_gpu=parallel_gpu, |
|
) |
|
elif self_attn_type == "average": |
|
self.self_attn = AverageAttention( |
|
d_model, dropout=attention_dropout, aan_useffn=aan_useffn |
|
) |
|
|
|
self.feed_forward = PositionwiseFeedForward( |
|
d_model, |
|
d_ff, |
|
dropout, |
|
pos_ffn_activation_fn, |
|
add_ffnbias, |
|
parallel_residual, |
|
layer_norm, |
|
norm_eps, |
|
use_ckpting=use_ckpting, |
|
parallel_gpu=parallel_gpu, |
|
) |
|
self.parallel_residual = parallel_residual |
|
self.shared_layer_norm = shared_layer_norm |
|
if layer_norm == "standard": |
|
self.layer_norm_1 = nn.LayerNorm(d_model, eps=norm_eps) |
|
if parallel_residual and not shared_layer_norm: |
|
self.layer_norm_res = nn.LayerNorm(d_model, eps=norm_eps) |
|
elif layer_norm == "rms": |
|
self.layer_norm_1 = RMSNorm(d_model, eps=norm_eps) |
|
if parallel_residual and not shared_layer_norm: |
|
self.layer_norm_res = RMSNorm(d_model, eps=norm_eps) |
|
else: |
|
raise ValueError(f"{layer_norm} layer norm type is not supported") |
|
|
|
self.dropout = nn.Dropout(dropout) |
|
self.full_context_alignment = full_context_alignment |
|
self.alignment_heads = alignment_heads |
|
|
|
def forward(self, *args, **kwargs): |
|
"""Extend `_forward` for (possibly) multiple decoder pass: |
|
Always a default (future masked) decoder forward pass, |
|
Possibly a second future aware decoder pass for joint learn |
|
full context alignement, :cite:`garg2019jointly`. |
|
|
|
Args: |
|
* All arguments of _forward, of which |
|
with_align (bool): needed to compute attn_align |
|
return_attn (bool): to force MHA to return attns |
|
|
|
Returns: |
|
(FloatTensor, FloatTensor, FloatTensor or None): |
|
|
|
* layer_out ``(batch_size, T, model_dim)`` |
|
* top_attn ``(batch_size, T, src_len)`` |
|
* attn_align ``(batch_size, T, src_len)`` or None |
|
""" |
|
with_align = kwargs.pop("with_align", False) |
|
layer_out, attns = self._forward(*args, **kwargs) |
|
top_attn = None if attns is None else attns[:, 0, :, :].contiguous() |
|
attn_align = None |
|
if with_align: |
|
if self.full_context_alignment: |
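                # Run `_forward` a second time with future=True (no causal
                # mask) so the cross-attention of this full-context pass can
                # be used for alignment supervision (:cite:`garg2019jointly`).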
|
|
|
_, attns = self._forward(*args, **kwargs, future=True) |
|
|
|
if self.alignment_heads > 0: |
|
attns = attns[:, : self.alignment_heads, :, :].contiguous() |
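            # Average the (selected) heads to get the alignment distribution
            # ``(batch_size, T, src_len)``.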
|
|
|
|
|
|
|
|
|
attn_align = attns.mean(dim=1) |
|
return layer_out, top_attn, attn_align |
|
|
|
def update_dropout(self, dropout, attention_dropout): |
|
self.self_attn.update_dropout(attention_dropout) |
|
self.feed_forward.update_dropout(dropout) |
|
self.dropout.p = dropout |
|
|
|
def _forward(self, *args, **kwargs): |
|
raise NotImplementedError |
|
|
|
def _compute_dec_mask(self, tgt_pad_mask, future): |
|
tgt_len = tgt_pad_mask.size(-1) |
|
if not future: |
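            # Causal mask: True strictly above the diagonal, i.e. position i
            # may not attend to positions > i.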
|
future_mask = torch.ones( |
|
[tgt_len, tgt_len], |
|
device=tgt_pad_mask.device, |
|
dtype=torch.uint8, |
|
) |
|
future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len) |
|
|
|
try: |
|
future_mask = future_mask.bool() |
|
except AttributeError: |
|
pass |
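            # A target position is masked if it is padding or lies in the future.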
|
dec_mask = torch.gt(tgt_pad_mask + future_mask, 0) |
|
else: |
|
dec_mask = tgt_pad_mask |
|
return dec_mask |
|
|
|
def _forward_self_attn(self, norm_layer_in, dec_mask, step, return_attn=False): |
|
if self.self_attn_type == "scaled-dot": |
|
return self.self_attn( |
|
norm_layer_in, |
|
norm_layer_in, |
|
norm_layer_in, |
|
mask=dec_mask, |
|
step=step, |
|
return_attn=return_attn, |
|
) |
|
elif self.self_attn_type == "average": |
|
return self.self_attn(norm_layer_in, mask=dec_mask, step=step) |
|
else: |
|
raise ValueError(f"self attention {type(self.self_attn)} not supported") |
|
|
|
|
|
class TransformerDecoderLayer(TransformerDecoderLayerBase): |
|
"""Transformer Decoder layer block in Pre-Norm style. |
|
Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style, |
|
providing better converge speed and performance. This is also the actual |
|
implementation in tensor2tensor and also avalable in fairseq. |
|
See https://tunz.kr/post/4 and :cite:`DeeperTransformer`. |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
d_model, |
|
heads, |
|
d_ff, |
|
dropout, |
|
attention_dropout, |
|
self_attn_type="scaled-dot", |
|
max_relative_positions=0, |
|
relative_positions_buckets=0, |
|
aan_useffn=False, |
|
full_context_alignment=False, |
|
alignment_heads=0, |
|
pos_ffn_activation_fn=ActivationFunction.relu, |
|
add_qkvbias=False, |
|
num_kv=0, |
|
add_ffnbias=True, |
|
parallel_residual=False, |
|
shared_layer_norm=False, |
|
layer_norm="standard", |
|
norm_eps=1e-6, |
|
use_ckpting=[], |
|
parallel_gpu=1, |
|
): |
|
""" |
|
Args: |
|
See TransformerDecoderLayerBase |
|
""" |
|
super(TransformerDecoderLayer, self).__init__( |
|
d_model, |
|
heads, |
|
d_ff, |
|
dropout, |
|
attention_dropout, |
|
self_attn_type, |
|
max_relative_positions, |
|
relative_positions_buckets, |
|
aan_useffn, |
|
full_context_alignment, |
|
alignment_heads, |
|
pos_ffn_activation_fn=pos_ffn_activation_fn, |
|
add_qkvbias=add_qkvbias, |
|
num_kv=num_kv, |
|
add_ffnbias=add_ffnbias, |
|
parallel_residual=parallel_residual, |
|
shared_layer_norm=shared_layer_norm, |
|
layer_norm=layer_norm, |
|
norm_eps=norm_eps, |
|
use_ckpting=use_ckpting, |
|
parallel_gpu=parallel_gpu, |
|
) |
|
self.context_attn = MultiHeadedAttention( |
|
heads, |
|
d_model, |
|
dropout=attention_dropout, |
|
attn_type="context", |
|
add_qkvbias=add_qkvbias, |
|
num_kv=num_kv, |
|
use_ckpting=use_ckpting, |
|
parallel_gpu=parallel_gpu, |
|
) |
|
if layer_norm == "standard": |
|
self.layer_norm_2 = nn.LayerNorm(d_model, eps=norm_eps) |
|
elif layer_norm == "rms": |
|
self.layer_norm_2 = RMSNorm(d_model, eps=norm_eps) |
|
else: |
|
raise ValueError(f"{layer_norm} layer norm type is not supported") |
|
|
|
def update_dropout(self, dropout, attention_dropout): |
|
super(TransformerDecoderLayer, self).update_dropout(dropout, attention_dropout) |
|
self.context_attn.update_dropout(attention_dropout) |
|
|
|
def _forward( |
|
self, |
|
layer_in, |
|
enc_out, |
|
src_pad_mask, |
|
tgt_pad_mask, |
|
step=None, |
|
future=False, |
|
return_attn=False, |
|
): |
|
"""A naive forward pass for transformer decoder. |
|
|
|
# T: could be 1 in the case of stepwise decoding or tgt_len |
|
|
|
Args: |
|
layer_in (FloatTensor): ``(batch_size, T, model_dim)`` |
|
enc_out (FloatTensor): ``(batch_size, src_len, model_dim)`` |
|
src_pad_mask (bool): ``(batch_size, 1, src_len)`` |
|
tgt_pad_mask (bool): ``(batch_size, 1, T)`` |
|
step (int or None): stepwise decoding counter |
|
future (bool): If set True, do not apply future_mask. |
|
return_attn (bool) : if set True requires attns output |
|
|
|
Returns: |
|
(FloatTensor, FloatTensor): |
|
|
|
* layer_out ``(batch_size, T, model_dim)`` |
|
* attns ``(batch_size, head, T, src_len)`` |
|
|
|
""" |
|
dec_mask = None |
|
src_pad_mask = src_pad_mask.unsqueeze(1) |
|
|
|
if layer_in.size(1) > 1: |
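            # Masks are only needed when attending over more than one target
            # position (training or a full-context pass); during stepwise
            # decoding the cache already restricts attention to past positions.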
|
|
|
dec_mask = self._compute_dec_mask(tgt_pad_mask, future) |
|
dec_mask = dec_mask.unsqueeze(1) |
|
dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1) |
|
src_pad_mask = src_pad_mask.expand(-1, -1, dec_mask.size(3), -1) |
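            # Both masks are now broadcastable to ``(batch, 1, T, key_len)``.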
|
|
|
|
|
|
|
norm_layer_in = self.layer_norm_1(layer_in) |
|
|
|
self_attn, _ = self._forward_self_attn(norm_layer_in, dec_mask, step) |
|
|
|
if self.parallel_residual: |
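            # Parallel-residual variant: self-attention, cross-attention and
            # the FFN all read the same normalized input. The FFN block adds
            # its input back as its own residual, so norm_layer_in is
            # subtracted below and the true residual layer_in is added instead.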
|
ctx_attn, attns = self.context_attn( |
|
enc_out, |
|
enc_out, |
|
norm_layer_in, |
|
mask=src_pad_mask, |
|
return_attn=return_attn, |
|
) |
|
|
|
layer_out = ( |
|
self.feed_forward(norm_layer_in) |
|
- norm_layer_in |
|
+ layer_in |
|
+ self.dropout(self_attn) |
|
+ ctx_attn |
|
) |
|
else: |
|
query = self.dropout(self_attn) + layer_in |
|
norm_query = self.layer_norm_2(query) |
|
ctx_attn, attns = self.context_attn( |
|
enc_out, enc_out, norm_query, mask=src_pad_mask, return_attn=return_attn |
|
) |
|
layer_out = self.feed_forward(self.dropout(ctx_attn) + query) |
|
|
|
return layer_out, attns |
|
|
|
|
|
class TransformerDecoderBase(DecoderBase): |
|
def __init__( |
|
self, d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps |
|
): |
|
super(TransformerDecoderBase, self).__init__() |
|
|
|
self.embeddings = embeddings |
|
|
|
|
|
self.state = {} |
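        # Note: "copy" attention reuses the context attention weights
        # (see forward: attns["copy"] = attn), so no extra module is created.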
|
|
|
|
|
|
|
|
|
self._copy = copy_attn |
|
if layer_norm == "standard": |
|
self.layer_norm = nn.LayerNorm(d_model, eps=norm_eps) |
|
elif layer_norm == "rms": |
|
self.layer_norm = RMSNorm(d_model, eps=norm_eps) |
|
else: |
|
raise ValueError(f"{layer_norm} layer norm type is not supported") |
|
|
|
self.alignment_layer = alignment_layer |
|
|
|
@classmethod |
|
def from_opt(cls, opt, embeddings): |
|
"""Alternate constructor.""" |
|
return cls( |
|
opt.dec_layers, |
|
opt.dec_hid_size, |
|
opt.heads, |
|
opt.transformer_ff, |
|
opt.copy_attn, |
|
opt.self_attn_type, |
|
opt.dropout[0] if type(opt.dropout) is list else opt.dropout, |
|
opt.attention_dropout[0] |
|
if type(opt.attention_dropout) is list |
|
else opt.attention_dropout, |
|
embeddings, |
|
opt.max_relative_positions, |
|
opt.relative_positions_buckets, |
|
opt.aan_useffn, |
|
opt.full_context_alignment, |
|
opt.alignment_layer, |
|
alignment_heads=opt.alignment_heads, |
|
pos_ffn_activation_fn=opt.pos_ffn_activation_fn, |
|
add_qkvbias=opt.add_qkvbias, |
|
num_kv=opt.num_kv, |
|
add_ffnbias=opt.add_ffnbias, |
|
parallel_residual=opt.parallel_residual, |
|
shared_layer_norm=opt.shared_layer_norm, |
|
layer_norm=opt.layer_norm, |
|
norm_eps=opt.norm_eps, |
|
use_ckpting=opt.use_ckpting, |
|
parallel_gpu=opt.world_size |
|
if opt.parallel_mode == "tensor_parallel" |
|
else 1, |
|
) |
|
|
|
def init_state(self, src, enc_out, enc_final_hs): |
|
"""Initialize decoder state.""" |
|
self.state["src"] = src |
|
|
|
def map_state(self, fn): |
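        # Apply `fn` (e.g. beam-search reordering / index_select) along the
        # batch dimension of the kept source and of every non-empty cached
        # key/value (or AAN ``prev_g``) tensor.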
|
if self.state["src"] is not None: |
|
self.state["src"] = fn(self.state["src"], 0) |
|
for layer in self.transformer_layers: |
|
if hasattr(layer, "context_attn"): |
|
if layer.context_attn.layer_cache[1]["keys"].numel() != 0: |
|
x = fn(layer.context_attn.layer_cache[1]["keys"], 0) |
|
y = fn(layer.context_attn.layer_cache[1]["values"], 0) |
|
layer.context_attn.layer_cache = True, {"keys": x, "values": y} |
|
if isinstance(layer.self_attn, AverageAttention): |
|
if layer.self_attn.layer_cache[1]["prev_g"].numel() != 0: |
|
x = fn(layer.self_attn.layer_cache[1]["prev_g"], 0) |
|
layer.self_attn.layer_cache = True, {"prev_g": x} |
|
else: |
|
if layer.self_attn.layer_cache[1]["keys"].numel() != 0: |
|
x = fn(layer.self_attn.layer_cache[1]["keys"], 0) |
|
y = fn(layer.self_attn.layer_cache[1]["values"], 0) |
|
layer.self_attn.layer_cache = True, {"keys": x, "values": y} |
|
|
|
def detach_state(self): |
|
raise NotImplementedError |
|
|
|
def forward(self, *args, **kwargs): |
|
raise NotImplementedError |
|
|
|
def update_dropout(self, dropout, attention_dropout): |
|
self.embeddings.update_dropout(dropout) |
|
for layer in self.transformer_layers: |
|
layer.update_dropout(dropout, attention_dropout) |
|
|
|
|
|
class TransformerDecoder(TransformerDecoderBase): |
|
"""The Transformer decoder from "Attention is All You Need". |
|
:cite:`DBLP:journals/corr/VaswaniSPUJGKP17` |
|
|
|
Args: |
|
num_layers (int): number of decoder layers. |
|
d_model (int): size of the model |
|
heads (int): number of heads |
|
d_ff (int): size of the inner FF layer |
|
copy_attn (bool): if using a separate copy attention |
|
        self_attn_type (str): type of self-attention, "scaled-dot" or "average"
|
dropout (float): dropout in residual, self-attn(dot) and feed-forward |
|
attention_dropout (float): dropout in context_attn (and self-attn(avg)) |
|
embeddings (onmt.modules.Embeddings): |
|
embeddings to use, should have positional encodings |
|
max_relative_positions (int): |
|
Max distance between inputs in relative positions representations |
|
relative_positions_buckets (int): |
|
Number of buckets when using relative position bias |
|
aan_useffn (bool): Turn on the FFN layer in the AAN decoder |
|
full_context_alignment (bool): |
|
whether enable an extra full context decoder forward for alignment |
|
        alignment_layer (int): layer number to supervise for alignment guiding
        alignment_heads (int):
            Number of cross-attention heads to use for alignment guiding
|
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear |
|
layer_norm (string): type of layer normalization standard/rms |
|
""" |
|
|
|
def __init__( |
|
self, |
|
num_layers, |
|
d_model, |
|
heads, |
|
d_ff, |
|
copy_attn, |
|
self_attn_type, |
|
dropout, |
|
attention_dropout, |
|
embeddings, |
|
max_relative_positions, |
|
relative_positions_buckets, |
|
aan_useffn, |
|
full_context_alignment, |
|
alignment_layer, |
|
alignment_heads, |
|
pos_ffn_activation_fn=ActivationFunction.relu, |
|
add_qkvbias=False, |
|
num_kv=0, |
|
add_ffnbias=True, |
|
parallel_residual=False, |
|
shared_layer_norm=False, |
|
layer_norm="standard", |
|
norm_eps=1e-6, |
|
use_ckpting=[], |
|
parallel_gpu=1, |
|
): |
|
super(TransformerDecoder, self).__init__( |
|
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps |
|
) |
|
|
|
self.transformer_layers = nn.ModuleList( |
|
[ |
|
TransformerDecoderLayer( |
|
d_model, |
|
heads, |
|
d_ff, |
|
dropout, |
|
attention_dropout, |
|
self_attn_type=self_attn_type, |
|
max_relative_positions=max_relative_positions, |
|
relative_positions_buckets=relative_positions_buckets, |
|
aan_useffn=aan_useffn, |
|
full_context_alignment=full_context_alignment, |
|
alignment_heads=alignment_heads, |
|
pos_ffn_activation_fn=pos_ffn_activation_fn, |
|
add_qkvbias=add_qkvbias, |
|
num_kv=num_kv, |
|
add_ffnbias=add_ffnbias, |
|
parallel_residual=parallel_residual, |
|
shared_layer_norm=shared_layer_norm, |
|
layer_norm=layer_norm, |
|
norm_eps=norm_eps, |
|
use_ckpting=use_ckpting, |
|
parallel_gpu=parallel_gpu, |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
def detach_state(self): |
|
self.state["src"] = self.state["src"].detach() |
|
|
|
def forward(self, tgt, enc_out=None, step=None, **kwargs): |
|
""" |
|
Decode, possibly stepwise. |
|
when training step is always None, when decoding, step increases |
|
tgt (Tensor): batch x tlen x feats |
|
enc_out (Tensor): encoder output (batch x slen x model_dim) |
|
""" |
|
if enc_out is None: |
|
enc_out = self.embeddings(tgt) |
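        # step == 0: start of stepwise decoding, so (re)initialize the layer
        # caches; step is None: full-sequence (training) forward, so caching
        # is disabled.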
|
if step == 0: |
|
self._init_cache(enc_out) |
|
elif step is None: |
|
for layer in self.transformer_layers: |
|
if isinstance(layer.self_attn, AverageAttention): |
|
layer.self_attn.layer_cache = False, {"prev_g": torch.tensor([])} |
|
else: |
|
layer.self_attn.layer_cache = ( |
|
False, |
|
{"keys": torch.tensor([]), "values": torch.tensor([])}, |
|
) |
|
layer.context_attn.layer_cache = ( |
|
False, |
|
{"keys": torch.tensor([]), "values": torch.tensor([])}, |
|
) |
|
|
|
emb = self.embeddings(tgt, step=step) |
|
dec_out = emb |
|
assert emb.dim() == 3 |
|
|
|
pad_idx = self.embeddings.word_padding_idx |
|
src_lens = kwargs["src_len"] |
|
src_max_len = self.state["src"].shape[1] |
|
src_pad_mask = ~sequence_mask(src_lens, src_max_len) |
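        # sequence_mask marks valid positions with True, so negate it to get
        # True at padded source positions.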
|
src_pad_mask = src_pad_mask.unsqueeze(1) |
|
tgt_pad_mask = tgt[:, :, 0].eq(pad_idx).unsqueeze(1) |
|
|
|
with_align = kwargs.pop("with_align", False) |
|
return_attn = with_align or self._copy |
|
attn_aligns = [] |
|
|
|
for layer in self.transformer_layers: |
|
dec_out, attn, attn_align = layer( |
|
dec_out, |
|
enc_out, |
|
src_pad_mask, |
|
tgt_pad_mask, |
|
step=step, |
|
with_align=with_align, |
|
return_attn=return_attn, |
|
) |
|
if attn_align is not None: |
|
attn_aligns.append(attn_align) |
|
|
|
dec_out = self.layer_norm(dec_out) |
|
|
|
attns = {"std": attn} |
|
if self._copy: |
|
attns["copy"] = attn |
|
if with_align: |
|
attns["align"] = attn_aligns[self.alignment_layer] |
|
|
|
|
|
|
|
return dec_out, attns |
|
|
|
def _init_cache(self, enc_out): |
|
batch_size = enc_out.size(0) |
|
depth = enc_out.size(-1) |
|
|
|
for layer in self.transformer_layers: |
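            # Empty tensors on the right device act as "not yet filled"
            # placeholders; the attention modules populate and then reuse
            # them at each decoding step.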
|
|
|
|
|
layer.context_attn.layer_cache = ( |
|
True, |
|
{ |
|
"keys": torch.tensor([], device=enc_out.device), |
|
"values": torch.tensor([], device=enc_out.device), |
|
}, |
|
) |
|
if isinstance(layer.self_attn, AverageAttention): |
|
layer.self_attn.layer_cache = True, { |
|
"prev_g": torch.zeros( |
|
(batch_size, 1, depth), device=enc_out.device |
|
).to(enc_out.dtype) |
|
} |
|
else: |
|
layer.self_attn.layer_cache = ( |
|
True, |
|
{ |
|
"keys": torch.tensor([], device=enc_out.device), |
|
"values": torch.tensor([], device=enc_out.device), |
|
}, |
|
) |
|
|
|
|
|
class TransformerLMDecoderLayer(TransformerDecoderLayerBase): |
|
"""Transformer Decoder only layer block in GPT style. |
|
Args: |
|
See TransformerDecoderLayerBase |
|
""" |
|
|
|
def _forward( |
|
self, layer_in, tgt_pad_mask, step=None, future=False, return_attn=False |
|
): |
|
"""A naive forward pass for transformer decoder. |
|
|
|
# T: could be 1 in the case of stepwise decoding or tgt_len |
|
|
|
Args: |
|
layer_in (FloatTensor): ``(batch_size, T, model_dim)`` |
|
tgt_pad_mask (bool): ``(batch_size, 1, T)`` |
|
layer_cache (dict or None): cached layer info when stepwise decode |
|
step (int or None): stepwise decoding counter |
|
future (bool): If set True, do not apply future_mask. |
|
return_attn (bool): If set True return attn |
|
|
|
Returns: |
|
(FloatTensor, FloatTensor): |
|
|
|
* layer_out ``(batch_size, T, model_dim)`` |
|
* attns ``(batch_size, head, T, T)`` |
|
|
|
""" |
|
dec_mask = None |
|
|
|
if layer_in.size(1) > 1: |
|
|
|
dec_mask = self._compute_dec_mask(tgt_pad_mask, future) |
|
dec_mask = dec_mask.unsqueeze(1) |
|
dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1) |
|
|
|
|
|
|
|
norm_layer_in = self.layer_norm_1(layer_in) |
|
|
|
attn_output, attns = self._forward_self_attn( |
|
norm_layer_in, dec_mask, step, return_attn=return_attn |
|
) |
|
|
|
if self.parallel_residual: |
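            # Same parallel-residual arithmetic as in TransformerDecoderLayer:
            # the FFN's internal residual (ff_in) is subtracted and the true
            # residual layer_in is added instead.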
|
|
|
if not self.shared_layer_norm: |
|
norm_res_layer_in = self.layer_norm_res(layer_in) |
|
ff_in = norm_res_layer_in |
|
else: |
|
ff_in = norm_layer_in |
|
layer_out = ( |
|
self.feed_forward(ff_in) - ff_in + layer_in + self.dropout(attn_output) |
|
) |
|
else: |
|
layer_out = self.dropout(attn_output) + layer_in |
|
layer_out = self.feed_forward(layer_out) |
|
|
|
return layer_out, attns |
|
|
|
|
|
class TransformerLMDecoder(TransformerDecoderBase): |
|
"""The Transformer decoder from GPT-2 |
|
Args: |
|
num_layers (int): number of decoder layers. |
|
d_model (int): size of the model |
|
heads (int): number of heads |
|
d_ff (int): size of the inner FF layer |
|
copy_attn (bool): if using a separate copy attention |
|
        self_attn_type (str): type of self-attention, "scaled-dot" or "average"
|
dropout (float): dropout in residual, self-attn(dot) and feed-forward |
|
attention_dropout (float): dropout in context_attn (and self-attn(avg)) |
|
embeddings (onmt.modules.Embeddings): |
|
embeddings to use, should have positional encodings |
|
max_relative_positions (int): |
|
Max distance between inputs in relative positions representations |
|
relative_positions_buckets (int): |
|
            Number of buckets when using relative position bias
|
aan_useffn (bool): Turn on the FFN layer in the AAN decoder |
|
add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear |
|
""" |
|
|
|
def __init__( |
|
self, |
|
num_layers, |
|
d_model, |
|
heads, |
|
d_ff, |
|
copy_attn, |
|
self_attn_type, |
|
dropout, |
|
attention_dropout, |
|
embeddings, |
|
max_relative_positions, |
|
relative_positions_buckets, |
|
aan_useffn, |
|
full_context_alignment=None, |
|
alignment_layer=None, |
|
alignment_heads=None, |
|
pos_ffn_activation_fn=ActivationFunction.relu, |
|
add_qkvbias=False, |
|
num_kv=0, |
|
add_ffnbias=True, |
|
parallel_residual=False, |
|
shared_layer_norm=False, |
|
layer_norm="standard", |
|
norm_eps=1e-6, |
|
use_ckpting=[], |
|
parallel_gpu=1, |
|
): |
|
super(TransformerLMDecoder, self).__init__( |
|
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps |
|
) |
|
self.transformer_layers = nn.ModuleList( |
|
[ |
|
TransformerLMDecoderLayer( |
|
d_model, |
|
heads, |
|
d_ff, |
|
dropout, |
|
attention_dropout, |
|
self_attn_type=self_attn_type, |
|
max_relative_positions=max_relative_positions, |
|
relative_positions_buckets=relative_positions_buckets, |
|
aan_useffn=aan_useffn, |
|
full_context_alignment=None, |
|
alignment_heads=None, |
|
pos_ffn_activation_fn=pos_ffn_activation_fn, |
|
add_qkvbias=add_qkvbias, |
|
num_kv=num_kv, |
|
add_ffnbias=add_ffnbias, |
|
parallel_residual=parallel_residual, |
|
shared_layer_norm=shared_layer_norm, |
|
layer_norm=layer_norm, |
|
norm_eps=norm_eps, |
|
use_ckpting=use_ckpting, |
|
parallel_gpu=parallel_gpu, |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
def init_state(self, src=None, enc_out=None, enc_final_hs=None): |
|
super(TransformerLMDecoder, self).init_state(None, None, None) |
|
|
|
def detach_state(self): |
|
pass |
|
|
|
def forward(self, tgt, enc_out=None, step=None, **kwargs): |
|
"""Decode, possibly stepwise.""" |
|
if step == 0: |
|
self._init_cache(tgt) |
|
elif step is None: |
|
for layer in self.transformer_layers: |
|
layer.self_attn.layer_cache = ( |
|
False, |
|
{"keys": torch.tensor([]), "values": torch.tensor([])}, |
|
) |
|
|
|
dec_out = self.embeddings(tgt, step=step) |
|
|
|
assert dec_out.dim() == 3 |
|
|
|
pad_idx = self.embeddings.word_padding_idx |
|
tgt_pad_mask = tgt[:, :, 0].eq(pad_idx).unsqueeze(1) |
|
|
|
with_align = kwargs.pop("with_align", False) |
|
return_attn = with_align or self._copy |
|
assert not with_align, "TransformerLMDecoder does not support align" |
|
|
|
for layer in self.transformer_layers: |
|
dec_out, attn, _ = layer( |
|
dec_out, |
|
tgt_pad_mask, |
|
step=step, |
|
with_align=with_align, |
|
return_attn=return_attn, |
|
) |
|
|
|
dec_out = self.layer_norm(dec_out) |
|
|
|
attns = {"std": attn} |
|
if self._copy: |
|
attns["copy"] = attn |
|
|
|
|
|
return dec_out, attns |
|
|
|
def _init_cache(self, tgt=None): |
|
for layer in self.transformer_layers: |
|
if isinstance(layer.self_attn, AverageAttention): |
|
raise NotImplementedError |
|
else: |
|
layer.self_attn.layer_cache = ( |
|
True, |
|
{ |
|
"keys": torch.tensor([], device=tgt.device), |
|
"values": torch.tensor([], device=tgt.device), |
|
}, |
|
) |
|
|