"""ALiBi position bias: a per-head linear distance penalty added to attention logits."""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def pad_at_dim(t, pad, dim=-1, value=0.0):
    """Pad tensor ``t`` along one dimension only.

    Args:
        t: Tensor to pad.
        pad: ``(before, after)`` pair of pad widths for dimension ``dim``.
            Negative widths trim instead of pad (``F.pad`` semantics).
        dim: Dimension to pad; negative values count from the end.
        value: Fill value for the padded region.

    Returns:
        The padded tensor.
    """
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    # F.pad consumes (before, after) pairs starting from the LAST dimension,
    # so every dimension to the right of ``dim`` needs an explicit (0, 0).
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)


class AlibiPositionalBias(nn.Module):
    """Adds ``-slope[h] * |key_pos - query_pos|`` to a ``(..., h, i, j)`` tensor.

    The ``i`` query rows are treated as the *last* ``i`` of the ``j`` key
    positions (see ``get_bias``). Heads beyond ``heads`` receive zero bias.
    The most recently computed bias is kept in a non-persistent buffer and
    reused for later calls that it fully covers.
    """

    def __init__(self, heads, **kwargs):
        """
        Args:
            heads: Number of heads that receive a slope.
            **kwargs: Accepted and ignored (kept for call-site compatibility).
        """
        super().__init__()
        self.heads = heads

        # Shape (heads, 1, 1) so it broadcasts against a (1, i, j) distance map.
        slopes = torch.tensor(self._get_slopes(heads)).unsqueeze(1).unsqueeze(1)
        self.register_buffer("slopes", slopes, persistent=False)
        # Lazily filled cache of the slope-scaled, head-padded bias.
        self.register_buffer("bias", None, persistent=False)

    def get_bias(self, i, j, device):
        """Return the unscaled negative distance map of shape ``(1, i, j)``.

        Row ``a`` stands for query position ``j - i + a`` and column ``b`` for
        key position ``b``; the entry is ``-|b - (j - i + a)|`` (integer dtype).
        """
        i_arange = torch.arange(j - i, j, device=device)
        j_arange = torch.arange(j, device=device)
        return -torch.abs(j_arange[None, None, :] - i_arange[None, :, None])

    @staticmethod
    def _get_slopes(heads):
        """Return a list of ``heads`` slopes.

        For a power-of-two ``n`` the slopes form the geometric sequence
        ``r, r**2, ..., r**n`` with ``r = 2 ** (-8 / n)``. For any other head
        count, the slopes of the largest power of two below ``heads`` are
        extended with every other slope of the next power of two.
        """

        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        remaining = heads - closest_power_of_2
        extra = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:remaining]
        return get_slopes_power_of_2(closest_power_of_2) + extra

    def forward(self, qk_dots):
        """Return ``qk_dots`` plus the ALiBi bias.

        Args:
            qk_dots: Tensor whose last three dimensions are
                ``(heads, i, j)`` (presumably attention logits).

        Returns:
            ``qk_dots + bias`` where ``bias`` has shape ``(heads, i, j)``.
        """
        h, i, j = qk_dots.shape[-3:]
        device = qk_dots.device

        cached = self.bias
        if (
            cached is not None
            and cached.shape[0] == h
            and cached.shape[-2] >= i
            and cached.shape[-1] >= j
        ):
            # Query rows are aligned to the END of the key axis, so a smaller
            # request corresponds to the trailing rows/columns of the cache
            # (the relative distances there are identical). Slicing from the
            # front would return distances for the wrong positions. Explicit
            # start indices are used because ``-0:`` would select everything.
            row_start = cached.shape[-2] - i
            col_start = cached.shape[-1] - j
            return qk_dots + cached[..., row_start:, col_start:]

        bias = self.get_bias(i, j, device) * self.slopes
        # Heads without a slope get a zero bias.
        num_heads_unalibied = h - bias.shape[0]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim=0)
        self.register_buffer("bias", bias, persistent=False)

        return qk_dots + self.bias