""" Multi-Head Attention module """ import math import torch from torch import Tensor from typing import Optional, Tuple from torch.nn import functional as F import torch.nn as nn from torch.utils.checkpoint import checkpoint from torch.nn.utils import skip_init from .alibi_position_bias import AlibiPositionalBias import torch.distributed as dist # Help functions for Rotary Embeddings # https://arxiv.org/pdf/2104.09864.pdf # too convoluted to make maxseqlen a parameter. # we suppose src_seq_len at training and max_length at inference # are both < 2048 tokens. def rotaryembeddings(dim: int, maxseqlen=8192, base=10000): inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) tmax = torch.arange(maxseqlen, device=inv_freq.device) rope = torch.outer(tmax, inv_freq).float() # rope is now matrix [maxseqlen, dim/2] rope = torch.polar(torch.ones_like(rope), rope) return rope def apply_rotary_emb(query, key, rope): query = query.transpose(1, 2) key = key.transpose(1, 2) query_ = query.float().reshape(*query.shape[:-1], -1, 2) query_ = torch.view_as_complex(query_) key_ = key.float().reshape(*key.shape[:-1], -1, 2) key_ = torch.view_as_complex(key_) rope = rope.view(1, query_.size(1), 1, query_.size(3)) query_out = torch.view_as_real(query_ * rope).flatten(3) key_out = torch.view_as_real(key_ * rope).flatten(3) return query_out.transpose(1, 2).type_as(query), key_out.transpose(1, 2).type_as( key ) # Help functions for max_relative positions # https://arxiv.org/abs/1803.02155 def relative_matmul(x: Tensor, z: Tensor, transpose: bool) -> Tensor: """ Helper function for relative positions attention. https://arxiv.org/pdf/1803.02155.pdf x shape [batch_size x heads x q_len x k_len] """ batch_size = x.size(0) heads = x.size(1) length = x.size(2) x_t = x.permute(2, 0, 1, 3) x_t_r = x_t.contiguous().view(length, heads * batch_size, -1) if transpose: z = z.transpose(1, 2) x_tz_matmul = torch.matmul(x_t_r, z) x_tz_matmul_r = x_tz_matmul.view(length, batch_size, heads, -1) x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3) return x_tz_matmul_r_t def gen_relative_positions( length: int, max_relative_positions: int, cache: bool = False, device: Optional[torch.device] = None, ) -> Tensor: """Generate the clipped relative positions matrix for a given length and maximum relative positions""" if cache: distance_mat = torch.arange(-length + 1, 1, 1, device=device).unsqueeze(0) else: range_vec = torch.arange(length, device=device) range_mat = range_vec.unsqueeze(-1).expand(-1, length).transpose(0, 1) distance_mat = range_mat - range_mat.transpose(0, 1) distance_mat_clipped = torch.clamp( distance_mat, min=-max_relative_positions, max=max_relative_positions ) # Shift values to be >= 0 final_mat = distance_mat_clipped + max_relative_positions return final_mat def _relative_position_bucket( relative_position, bidirectional=True, num_buckets=32, max_distance=128 ): """ Adapted from Mesh Tensorflow: https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/ mesh_tensorflow/transformer/transformer_layers.py#L593 Translate relative position to a bucket number for relative attention. The relative position is defined as memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for small absolute relative_position and larger buckets for larger absolute relative_positions. All relative positions >=max_distance map to the same bucket. 
def _relative_position_bucket(
    relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
    """
    Adapted from Mesh Tensorflow:
    https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/
    mesh_tensorflow/transformer/transformer_layers.py#L593
    Translate relative position to a bucket number for relative attention. The
    relative position is defined as memory_position - query_position, i.e. the
    distance in tokens from the attending position to the attended-to position.
    If bidirectional=False, then positive relative positions are invalid. We use
    smaller buckets for small absolute relative_position and larger buckets for
    larger absolute relative_positions. All relative positions >= max_distance
    map to the same bucket. All relative positions <= -max_distance map to the
    same bucket. This should allow for more graceful generalization to longer
    sequences than the model has been trained on.
    Args:
        relative_position: an int32 Tensor
        bidirectional: a boolean - whether the attention is bidirectional
        num_buckets: an integer
        max_distance: an integer
    Returns:
        a Tensor with the same shape as relative_position,
        containing int32 values in the range [0, num_buckets)
    """
    relative_buckets = 0
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(
            relative_position, torch.zeros_like(relative_position)
        )
    # now relative_position is in the range [0, inf)
    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # The other half of the buckets are for logarithmically bigger bins in positions
    # up to max_distance
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large,
        torch.full_like(relative_position_if_large, num_buckets - 1),
    )
    relative_buckets += torch.where(
        is_small, relative_position, relative_position_if_large
    )
    return relative_buckets


def compute_bias(
    query_length,
    key_length,
    is_decoder,
    max_relative_positions,
    relative_positions_buckets,
    device=None,
):
    """Compute binned relative position bias"""
    context_position = torch.arange(query_length, dtype=torch.long, device=device)[
        :, None
    ]
    memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
        None, :
    ]
    relative_position = (
        memory_position - context_position
    )  # shape (query_length, key_length)
    relative_position_bucket = _relative_position_bucket(
        relative_position,  # shape (query_length, key_length)
        bidirectional=(not is_decoder),
        num_buckets=relative_positions_buckets,
        max_distance=max_relative_positions,
    )
    return relative_position_bucket


# Helper functions to split the model dim per head


def shape(x: Tensor, dim_per_head: int) -> Tensor:
    """
    Projection.
    [batchsize x length x modeldim]
    -> [batchsize x heads x length x dimperhead]
    """
    return x.view(x.size(0), x.size(1), -1, dim_per_head).transpose(1, 2)


def unshape(x: Tensor) -> Tensor:
    """
    Compute context.
    [batchsize x heads x length x dimperhead]
    -> [batchsize x length x modeldim]
    """
    return x.transpose(1, 2).contiguous().view(x.size(0), -1, x.size(1) * x.size(3))
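
# Illustrative sketch only: _demo_relative_bias_shapes is a hypothetical helper, not
# called by the library. It checks the shape of the binned relative-position matrix
# produced by compute_bias() and the shape()/unshape() round trip; all sizes below
# are arbitrary example values.
def _demo_relative_bias_shapes():
    q_len, k_len = 4, 6
    buckets = compute_bias(
        q_len,
        k_len,
        is_decoder=True,
        max_relative_positions=128,
        relative_positions_buckets=32,
    )
    # one bucket id per (query position, key position) pair
    assert buckets.shape == (q_len, k_len)
    # shape() splits model_dim into heads, unshape() merges them back
    x = torch.randn(2, 5, 8 * 16)  # (batch, length, model_dim) with 8 heads of 16
    assert unshape(shape(x, 16)).shape == x.shape
    return buckets
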
class MultiHeadedAttention(nn.Module):
    # class MultiHeadedAttention(torch.jit.ScriptModule):
    """Multi-Head Attention module from "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

    Similar to standard `dot` attention but uses
    multiple attention distributions simultaneously
    to select relevant items.

    .. mermaid::

        graph BT
            A[key]
            B[value]
            C[query]
            O[output]
            subgraph Attn
              D[Attn 1]
              E[Attn 2]
              F[Attn N]
            end
            A --> D
            C --> D
            A --> E
            C --> E
            A --> F
            C --> F
            D --> O
            E --> O
            F --> O
            B --> O

    Also includes several additional tricks.

    Args:
        head_count (int): number of parallel heads
        model_dim (int): the dimension of keys/values/queries,
            must be divisible by head_count
        dropout (float): dropout parameter
        max_relative_positions (int): max relative positions
        attn_type: "self" or "context"
    """

    def __init__(
        self,
        head_count: int,
        model_dim: int,
        dropout: float = 0.1,
        is_decoder: bool = True,
        max_relative_positions: int = 0,
        relative_positions_buckets: int = 0,
        attn_type: Optional[str] = None,
        add_qkvbias=False,
        num_kv=0,
        use_ckpting=[],
        parallel_gpu=1,
    ) -> None:
        assert (
            model_dim % head_count == 0
        ), "Model dimension must be divisible by the number of heads"
        self.dim_per_head = model_dim // head_count
        super(MultiHeadedAttention, self).__init__()
        self.head_count = head_count
        self.num_kv = num_kv
        self.parallel_gpu = parallel_gpu

        if num_kv == 0:
            assert (
                model_dim % parallel_gpu == 0
            ), "Model dimension must be divisible by the number of partitions"
            self.linear_keys = skip_init(
                nn.Linear,
                in_features=model_dim,
                out_features=model_dim // parallel_gpu,
                bias=add_qkvbias,
            )
            self.linear_values = skip_init(
                nn.Linear,
                in_features=model_dim,
                out_features=model_dim // parallel_gpu,
                bias=add_qkvbias,
            )
        else:
            assert (
                self.dim_per_head * self.num_kv
            ) % parallel_gpu == 0, (
                "Model dimension must be divisible by the number of partitions"
            )
            self.linear_keys = skip_init(
                nn.Linear,
                in_features=model_dim,
                out_features=self.dim_per_head * self.num_kv // parallel_gpu,
                bias=add_qkvbias,
            )
            self.linear_values = skip_init(
                nn.Linear,
                in_features=model_dim,
                out_features=self.dim_per_head * self.num_kv // parallel_gpu,
                bias=add_qkvbias,
            )
        self.linear_query = skip_init(
            nn.Linear,
            in_features=model_dim,
            out_features=model_dim // parallel_gpu,
            bias=add_qkvbias,
        )
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.final_linear = skip_init(
            nn.Linear,
            in_features=model_dim // parallel_gpu,
            out_features=model_dim,
            bias=add_qkvbias,
        )
        self.is_decoder = is_decoder
        self.max_relative_positions = max_relative_positions
        self.relative_positions_buckets = relative_positions_buckets
        self.attn_type = attn_type
        self.layer_cache = (
            False,
            {"keys": torch.tensor([]), "values": torch.tensor([])},
        )

        if relative_positions_buckets > 0:
            self.relative_attention_bias = nn.Embedding(
                relative_positions_buckets, head_count
            )
            self.relative_positions_embeddings = None
        elif max_relative_positions > 0:
            # https://arxiv.org/pdf/1803.02155.pdf
            # The paper suggests either two embeddings (relative_key /
            # relative_value) or only relative_key. We use the same
            # embedding for both.
            vocab_size = max_relative_positions * 2 + 1
            self.relative_positions_embeddings = nn.Embedding(
                vocab_size, self.dim_per_head
            )
            self.relative_attention_bias = None
        else:
            self.relative_positions_embeddings = None
            self.relative_attention_bias = None

            if max_relative_positions == -1:  # rotary embeddings
                self.rope = rotaryembeddings(self.dim_per_head)

            if max_relative_positions == -2:  # alibi positional bias
                self.alibi = AlibiPositionalBias(head_count)

        self.maybe_ckpt = checkpoint if "mha" in use_ckpting else lambda f, x: f(x)

    def update_dropout(self, dropout: float) -> None:
        self.dropout.p = dropout
    # @torch.jit.script_method
    def forward(
        self,
        key: Tensor,
        value: Tensor,
        query: Tensor,
        mask: Optional[Tensor] = None,
        step: Optional[int] = 0,
        return_attn: Optional[bool] = False,
    ) -> Tuple[Tensor, Tensor]:
        """
        Compute the context vector and the attention vectors.

        Args:
            key (Tensor): set of `key_len`
                key vectors ``(batch, key_len, dim)``
            value (Tensor): set of `key_len`
                value vectors ``(batch, key_len, dim)``
            query (Tensor): set of `query_len`
                query vectors ``(batch, query_len, dim)``
            mask: binary mask 1/0 indicating which keys have
                zero / non-zero attention ``(batch, query_len, key_len)``
            step (int): decoding step (used for Rotary embedding)
        Returns:
            (Tensor, Tensor):

            * output context vectors ``(batch, query_len, dim)``
            * Attention vector in heads ``(batch, head, query_len, key_len)``.
        """
        # 1) Project key, value, and query.
        # as a reminder, at training layer_cache[0] remains False
        if self.layer_cache[0]:
            if self.attn_type == "self":
                query, key, value = (
                    self.linear_query(query),
                    self.linear_keys(query),
                    self.linear_values(query),
                )
                query = shape(query, self.dim_per_head)
                key = shape(key, self.dim_per_head)
                value = shape(value, self.dim_per_head)

                if self.max_relative_positions == -1:  # Rotary Embeddings
                    start_pos = step
                    seqlen = query.size(2)
                    rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
                    query, key = apply_rotary_emb(query, key, rope=rope)

                if self.layer_cache[1]["keys"].numel() != 0:
                    key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
                if self.layer_cache[1]["values"].numel() != 0:
                    value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
                self.layer_cache[1]["keys"] = key
                self.layer_cache[1]["values"] = value

            elif self.attn_type == "context":
                query = self.linear_query(query)
                query = shape(query, self.dim_per_head)
                if self.layer_cache[1]["keys"].numel() == 0:
                    key, value = self.linear_keys(key), self.linear_values(value)
                    key = shape(key, self.dim_per_head)
                    value = shape(value, self.dim_per_head)
                else:
                    key, value = (
                        self.layer_cache[1]["keys"],
                        self.layer_cache[1]["values"],
                    )
                self.layer_cache[1]["keys"] = key
                self.layer_cache[1]["values"] = value
        else:
            key = self.maybe_ckpt(self.linear_keys, key)
            value = self.maybe_ckpt(self.linear_values, value)
            query = self.maybe_ckpt(self.linear_query, query)
            key = shape(key, self.dim_per_head)
            value = shape(value, self.dim_per_head)
            query = shape(query, self.dim_per_head)

            if self.max_relative_positions == -1:  # Rotary Embeddings
                start_pos = 0
                seqlen = query.size(2)
                rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
                query, key = apply_rotary_emb(query, key, rope=rope)

        # expand key on heads dimension when it's less than query heads
        # (multi-query variant)
        key = key.view(key.size(0), -1, 1, key.size(2), key.size(3)).repeat(
            1, 1, query.size(1) // key.size(1), 1, 1
        )
        key = key.view(key.size(0), query.size(1), key.size(3), key.size(4))

        # expand value on heads dimension when it's less than query heads
        # (multi-query variant)
        value = value.view(value.size(0), -1, 1, value.size(2), value.size(3)).repeat(
            1, 1, query.size(1) // value.size(1), 1, 1
        )
        value = value.view(value.size(0), query.size(1), value.size(3), value.size(4))
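        # Note: when num_kv == 0 the two reshapes above are no-ops (key/value
        # already have as many heads as query); when num_kv > 0 each key/value
        # head is repeated query_heads // kv_heads times so the grouped
        # (multi-query) heads line up with the query heads for the matmul below.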
        # 2) When standard pos. enc. or rotary, use flash attention
        if self.max_relative_positions in [-1, 0] and not return_attn:
            if self.is_decoder and self.attn_type == "self":
                attn_output = F.scaled_dot_product_attention(
                    query, key, value, None, 0.0, is_causal=mask is not None
                )
            else:
                attn_output = F.scaled_dot_product_attention(
                    query, key, value, ~mask, 0.0, is_causal=False
                )
            x = unshape(attn_output)
            attn_output = self.maybe_ckpt(self.final_linear, x)
            if self.parallel_gpu > 1:
                dist.all_reduce(attn_output)
            return attn_output, None
        else:
            query /= math.sqrt(self.dim_per_head)
            # batch x num_heads x query_len x key_len
            scores = torch.matmul(query, key.transpose(2, 3))

            if self.relative_attention_bias is not None:
                q_len = key.size(2) if self.layer_cache[0] else query.size(2)
                relative_position_bucket = compute_bias(
                    q_len,
                    key.size(2),
                    self.is_decoder,
                    self.max_relative_positions,
                    self.relative_positions_buckets,
                    device=key.device,
                )
                values = self.relative_attention_bias(
                    relative_position_bucket
                )  # shape (query_length, key_length, num_heads)
                position_bias = values.permute([2, 0, 1]).unsqueeze(
                    0
                )  # shape (1, num_heads, query_length, key_length)
                if self.layer_cache[0]:
                    position_bias = position_bias[:, :, -query.size(2) :, :]
                scores.add_(position_bias)

            elif self.relative_positions_embeddings is not None:
                key_len = key.size(2)
                # 1 or key_len x key_len
                relative_positions_matrix = gen_relative_positions(
                    key_len,
                    self.max_relative_positions,
                    cache=self.layer_cache[0],
                    device=key.device,
                )
                # 1 or key_len x key_len x dim_per_head
                relations_keys = self.relative_positions_embeddings(
                    relative_positions_matrix
                )
                scores.add_(relative_matmul(query, relations_keys, True))

            elif self.max_relative_positions == -2:  # Alibi
                scores = self.alibi(scores)

            scores = scores.float()

            if mask is not None:
                # not 100% necessary but expand to nb of heads
                mask = mask.expand(-1, self.head_count // self.parallel_gpu, -1, -1)
                # now mask and scores have the same shape
                scores = scores.masked_fill(mask, -1e18)

            # 3) Apply attention dropout and compute context vectors.
            attn = self.softmax(scores).to(query.dtype)
            drop_attn = self.dropout(attn)

            context_original = torch.matmul(drop_attn, value)

            if self.relative_positions_embeddings is not None:
                # We use the same embeddings for key and value
                relations_values = relations_keys
                context_original.add_(
                    relative_matmul(drop_attn, relations_values, False)
                )

            context = unshape(context_original)
            attn_output = self.maybe_ckpt(self.final_linear, context)
            if self.parallel_gpu > 1:
                dist.all_reduce(attn_output)
            return attn_output, attn
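
# Minimal smoke-test sketch (illustrative only, not part of the library API).
# It builds a decoder self-attention layer and runs one forward pass on random
# data; all sizes are arbitrary example values. Because the linear layers are
# created with skip_init, their weights are explicitly initialized here. Run this
# file as a module with python -m from the package root so the relative import of
# AlibiPositionalBias at the top of the file resolves.
if __name__ == "__main__":
    torch.manual_seed(0)
    mha = MultiHeadedAttention(
        head_count=8,
        model_dim=64,
        dropout=0.0,
        is_decoder=True,
        max_relative_positions=0,
        attn_type="self",
    )
    for module in mha.modules():
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
    x = torch.randn(2, 10, 64)  # (batch, seq_len, model_dim)
    # With max_relative_positions=0 and return_attn=False this takes the
    # scaled_dot_product_attention path and returns (output, None).
    out, attn = mha(x, x, x, mask=None)
    print(out.shape)  # expected: torch.Size([2, 10, 64])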