jbochi's picture
Fix mistake in parallel layers
06e4cba unverified
raw
history blame
24.2 kB
import copy
from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.models.t5 import modeling_t5
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import (
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from decoder_only_t5.config import DecoderOnlyT5Config
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DecoderOnlyT5Config"
class DecoderOnlyT5LayerFF(modeling_t5.T5LayerFF):
    """Feed-forward sub-layer used by the decoder-only T5 blocks.

    Holds a dense projection stack (gated or plain, chosen by
    ``config.is_gated_act``), a layer norm and a dropout module. When
    ``config.parallel_layers`` is set the layer norm is an ``nn.Identity``.
    """

    def __init__(self, config: DecoderOnlyT5Config):
        # Initialise the class that follows ``T5LayerFF`` in the MRO, bypassing
        # ``T5LayerFF.__init__``; all sub-modules are created explicitly below.
        super(modeling_t5.T5LayerFF, self).__init__()

        dense_cls = (
            modeling_t5.T5DenseGatedActDense
            if config.is_gated_act
            else modeling_t5.T5DenseActDense
        )
        self.DenseReluDense = dense_cls(config)

        if config.parallel_layers:
            # No normalisation is applied inside this layer in parallel mode.
            self.layer_norm = nn.Identity()
        else:
            self.layer_norm = modeling_t5.T5LayerNorm(
                config.d_model, eps=config.layer_norm_epsilon
            )

        self.dropout = nn.Dropout(config.dropout_rate)
# https://github.com/huggingface/transformers/blob/7ee995fd9c692761c4601ddbffa2ac2ec9f27b0b/src/transformers/models/llama/modeling_llama.py#L263
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head dimension.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    With ``n_rep == 1`` the input tensor itself is returned.
    """
    # Unpack first so that a non-4D input is rejected even when n_rep == 1.
    batch, kv_heads, seq_len, head_dim = hidden_states.size()
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold
    # it into the head axis so copies of the same head end up adjacent.
    expanded = hidden_states.unsqueeze(2).expand(
        batch, kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
class DecoderOnlyT5Attention(modeling_t5.T5Attention):
    """
    Supports both multi-head and multi-query attention.
    https://arxiv.org/abs/1911.02150
    https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/components/attention/dense_attention.py#L292

    With ``config.multi_query_attention`` a single key/value head is projected and
    shared by every query head; otherwise there is one key/value head per query head.
    """

    def __init__(self, config: DecoderOnlyT5Config, has_relative_attention_bias=False):
        # Initialise the class that follows ``T5Attention`` in the MRO; the
        # projections are created here because their sizes differ for
        # multi-query attention.
        super(modeling_t5.T5Attention, self).__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.n_kv_heads = 1 if config.multi_query_attention else self.n_heads
        # Number of query heads served by each key/value head.
        self.n_kv_groups = self.n_heads // self.n_kv_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.kv_inner_dim = self.n_kv_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(
                self.relative_attention_num_buckets, self.n_heads
            )
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Returns a tuple ``(attn_output, present_key_value_state, position_bias)`` with the
        attention weights appended when ``output_attentions`` is set.
        ``present_key_value_state`` is ``None`` unless ``self.is_decoder and use_cache``;
        otherwise it holds the key and value states with ``n_kv_heads`` heads, i.e. before
        they are repeated for the query heads, which is the layout ``past_key_value`` must
        have on the next call.

        Raises:
            ValueError: if ``past_key_value`` is given and does not contain exactly two entries.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_kv_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            if len(past_key_value) != 2:
                raise ValueError(
                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
                )
            real_seq_length += (
                past_key_value[0].shape[2] if query_length is None else query_length
            )

        key_length = (
            real_seq_length if key_value_states is None else key_value_states.shape[1]
        )

        def shape(states, n_heads):
            """projection"""
            return states.view(
                batch_size, -1, n_heads, self.key_value_proj_dim
            ).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return (
                states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
            )

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_kv_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states), self.n_kv_heads)
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_kv_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states), self.n_kv_heads)

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_kv_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                elif past_key_value.shape[2] != key_value_states.shape[1]:
                    # checking that the `sequence_length` of the `past_key_value` is the same as
                    # the provided `key_value_states` to support prefix tuning
                    # cross-attn
                    # (batch_size, n_kv_heads, seq_length, dim_per_head)
                    hidden_states = shape(proj_layer(key_value_states), self.n_kv_heads)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(
            self.q(hidden_states), self.n_heads
        )  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states, still with n_kv_heads heads
        key_states = project(
            hidden_states,
            self.k,
            key_value_states,
            past_key_value[0] if past_key_value is not None else None,
        )
        value_states = project(
            hidden_states,
            self.v,
            key_value_states,
            past_key_value[1] if past_key_value is not None else None,
        )

        # Cache the states BEFORE repeating them across query heads. `project`
        # concatenates the cached tensor with freshly projected states that have
        # n_kv_heads heads, so a cache holding n_heads heads would fail in
        # `torch.cat` on the next step under multi-query attention (and would be
        # repeated a second time on the cross-attention reuse path).
        present_key_value_state = (
            (key_states, value_states) if (self.is_decoder and use_cache) else None
        )

        # (batch_size, n_heads, key_length, dim_per_head)
        key_states = repeat_kv(key_states, self.n_kv_groups)
        value_states = repeat_kv(value_states, self.n_kv_groups)

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length),
                    device=scores.device,
                    dtype=scores.dtype,
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device
                )

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
                position_bias = (
                    position_bias + mask
                )  # (batch_size, n_heads, seq_length, key_length)

        if self.pruned_heads:
            # Drop the bias rows that belong to pruned heads.
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked
        # Softmax is computed in float32 and cast back to the dtype of `scores`.
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(
            torch.matmul(attn_weights, value_states)
        )  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
class DecoderOnlyT5LayerSelfAttention(modeling_t5.T5LayerSelfAttention):
    """Self-attention sub-layer with an optional "parallel layers" mode.

    In the default mode the input is layer-normalised, attended over, and added
    back to the input through dropout. When ``config.parallel_layers`` is set,
    ``forward`` neither normalises nor adds the residual and returns the raw
    attention output as its first element.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        # Initialise the class that follows ``T5LayerSelfAttention`` in the MRO;
        # the sub-modules are created explicitly below.
        super(modeling_t5.T5LayerSelfAttention, self).__init__()
        self.SelfAttention = DecoderOnlyT5Attention(
            config, has_relative_attention_bias=has_relative_attention_bias
        )
        # Created in both modes; `forward` only applies it in non-parallel mode.
        self.layer_norm = modeling_t5.T5LayerNorm(
            config.d_model, eps=config.layer_norm_epsilon
        )
        self.dropout = nn.Dropout(config.dropout_rate)
        self.parallel_layers = config.parallel_layers

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Run self-attention and return ``(output,) + attention_extras``.

        ``attention_extras`` is everything the attention module returned after
        its first element.
        """
        attn_input = (
            hidden_states if self.parallel_layers else self.layer_norm(hidden_states)
        )
        attn_outputs = self.SelfAttention(
            attn_input,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        if self.parallel_layers:
            # When parallel_layers is True, the residual connection is applied
            # in the decoder block instead of here.
            result = attn_outputs[0]
        else:
            result = hidden_states + self.dropout(attn_outputs[0])

        # add attentions if we output them
        return (result,) + attn_outputs[1:]
class DecoderOnlyT5Block(modeling_t5.T5Block):
    """Transformer block made of self-attention, an optional cross-attention slot and a feed-forward layer.

    ``self.layer`` holds, in order: the self-attention layer; for decoders either an
    ``nn.Identity`` placeholder (``config.is_decoder_only``) or a cross-attention layer;
    and finally the feed-forward layer.

    With ``config.parallel_layers`` the attention and MLP branches both read the same
    layer-normalised input and the block computes
    ``input + dropout((attention + mlp) * 2**-0.5)``, following the Flaxformer
    formulation referenced in ``forward``.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        # Initialise the class that follows ``T5Block`` in the MRO; the
        # sub-layers are created explicitly below.
        super(modeling_t5.T5Block, self).__init__()
        self.is_decoder = config.is_decoder
        self.is_decoder_only = config.is_decoder_only
        self.layer = nn.ModuleList()
        self.layer.append(
            DecoderOnlyT5LayerSelfAttention(
                config, has_relative_attention_bias=has_relative_attention_bias
            )
        )
        if self.is_decoder:
            if config.is_decoder_only:
                # Placeholder in the cross-attention slot; `forward` never calls it
                # when `is_decoder_only` is set.
                self.layer.append(nn.Identity())
            else:
                self.layer.append(modeling_t5.T5LayerCrossAttention(config))
        self.parallel_layers = config.parallel_layers
        self.layer.append(DecoderOnlyT5LayerFF(config))

    @staticmethod
    def _clamp_fp16(tensor):
        """Clamp a float16 tensor to the finite range; other dtypes are returned unchanged."""
        if tensor.dtype != torch.float16:
            return tensor
        # clamp inf values to enable fp16 training
        clamp_value = torch.where(
            torch.isinf(tensor).any(),
            torch.finfo(tensor.dtype).max - 1000,
            torch.finfo(tensor.dtype).max,
        )
        return torch.clamp(tensor, min=-clamp_value, max=clamp_value)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):
        """Apply the block to ``hidden_states``.

        ``past_key_value`` must contain 2 entries (self-attention keys/values), or 4 when
        ``encoder_hidden_states`` is given (plus cross-attention keys/values).
        ``return_dict`` is accepted but not used here.

        Returns:
            A tuple starting with the new hidden states, followed by the present
            key/value state when ``use_cache`` is set, followed by the remaining
            attention outputs (position biases and, if requested, attention weights).

        Raises:
            ValueError: if ``past_key_value`` has an unexpected number of entries.
        """
        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning(
                    "`past_key_values` is passed to the encoder. Please make sure this is intended."
                )
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        # Kept for the residual connection of the parallel branch.
        block_input = hidden_states

        ff_layer = self.layer[-1]
        if self.parallel_layers:
            # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L563-L568
            x = self.layer[0].layer_norm(hidden_states)
            # Run only the MLP branch. The feed-forward layer's inherited
            # `forward` is not used because the residual is added once, below.
            # NOTE(review): in the public Transformers source `T5LayerFF.forward`
            # returns `input + dropout(mlp(norm(input)))` - confirm against the
            # installed version.
            ff_output = ff_layer.DenseReluDense(ff_layer.layer_norm(x))
        else:
            x = hidden_states

        self_attention_outputs = self.layer[0](
            x,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        x, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[
            2:
        ]  # Keep self-attention outputs and relative position weights

        x = self._clamp_fp16(x)

        do_cross_attention = (
            self.is_decoder
            and not self.is_decoder_only
            and encoder_hidden_states is not None
        )
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                x,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            x = self._clamp_fp16(cross_attention_outputs[0])

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = (
                    present_key_value_state + cross_attention_outputs[1]
                )

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        if self.parallel_layers:
            # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L534-L578
            # Sum the two branches, rescale, and add the result to the *block
            # input* through dropout. The residual must not use the branch sum
            # itself, otherwise the incoming hidden states are discarded.
            branch_sum = (x + ff_output) * 2**-0.5
            hidden_states = block_input + self.layer[0].dropout(branch_sum)
        else:
            hidden_states = ff_layer(x)

        hidden_states = self._clamp_fp16(hidden_states)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
class DecoderOnlyT5Stack(modeling_t5.T5Stack):
    """Stack of ``config.num_layers`` decoder-only T5 blocks.

    Only the first block may be built with a relative attention bias, and only
    when ``config.has_relative_attention_bias`` is set. The final layer norm is
    an ``nn.Identity`` when ``config.parallel_layers`` is enabled.
    """

    def __init__(self, config, embed_tokens=None):
        # Initialise the class that follows ``T5Stack`` in the MRO; the stack's
        # sub-modules are created explicitly below.
        super(modeling_t5.T5Stack, self).__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        blocks = []
        for layer_idx in range(config.num_layers):
            use_relative_bias = config.has_relative_attention_bias and bool(
                layer_idx == 0
            )
            blocks.append(
                DecoderOnlyT5Block(
                    config, has_relative_attention_bias=use_relative_bias
                )
            )
        self.block = nn.ModuleList(blocks)

        if config.parallel_layers:
            self.final_layer_norm = nn.Identity()
        else:
            self.final_layer_norm = modeling_t5.T5LayerNorm(
                config.d_model, eps=config.layer_norm_epsilon
            )
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False
class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
    """Decoder-only T5 language model: shared token embedding, decoder stack and LM head.

    The configuration must have ``num_layers == 0`` (no encoder); the decoder is
    built with ``config.num_decoder_layers`` blocks.
    """

    def __init__(self, config: DecoderOnlyT5Config):
        # Initialise the class that follows ``T5ForConditionalGeneration`` in the
        # MRO; this model builds its own (encoder-less) sub-modules below.
        super(modeling_t5.T5ForConditionalGeneration, self).__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        assert (
            self.config.num_layers == 0
        ), "Decoder only model cannot have encoder layers"
        self.encoder = None

        # The stack reads `num_layers`, so copy the config and point that field
        # at the decoder depth without mutating the caller's config.
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = DecoderOnlyT5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings_to_model_forward(modeling_t5.T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        position_ids (`torch.LongTensor`, *optional*):
            Accepted for interface compatibility; not used by this implementation.
        labels (`torch.LongTensor`, *optional*):
            Labels for computing the language modeling loss, one label per logit position (e.g. shape
            `(batch_size, sequence_length)`). Indices should be in `[-100, 0, ..., config.vocab_size - 1]`. All
            labels set to `-100` are ignored (masked), the loss is only computed for labels in
            `[0, ..., config.vocab_size - 1]`. Labels are compared position by position with the logits; no
            shifting is performed here.

        Returns:

        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            if input_ids is not None:
                input_ids = input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)

        # Decode
        outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            head_mask=None,
            cross_attn_head_mask=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            self.lm_head = self.lm_head.to(self.decoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            # NOTE(review): logits and labels are not shifted against each other,
            # so callers presumably pass labels already aligned with the
            # next-token targets - confirm against the training code.
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )