"""ALiBi (Attention with Linear Biases) positional bias.

Penalises attention logits in proportion to the query-key distance, using a
fixed slope per attention head instead of learned positional embeddings.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def pad_at_dim(t, pad, dim=-1, value=0.0):
    """Pad tensor `t` along dimension `dim` by the (front, back) amounts in `pad`."""
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)


class AlibiPositionalBias(nn.Module):
    """ALiBi positional bias with one fixed, non-learned slope per head."""

    def __init__(self, heads, **kwargs):
        super().__init__()
        self.heads = heads

        # per-head slopes, shaped (heads, 1, 1) so they broadcast over (i, j)
        slopes = torch.Tensor(self._get_slopes(heads))
        slopes = slopes.unsqueeze(1).unsqueeze(1)
        self.register_buffer("slopes", slopes, persistent=False)

        # bias cache, filled lazily on the first forward pass
        self.register_buffer("bias", None, persistent=False)

    def get_bias(self, i, j, device):
        # queries occupy positions [j - i, j) and keys positions [0, j);
        # the bias is the negative absolute distance between them
        i_arange = torch.arange(j - i, j, device=device)
        j_arange = torch.arange(j, device=device)
        bias = -torch.abs(
            j_arange.unsqueeze(0).unsqueeze(0) - i_arange.unsqueeze(1).unsqueeze(0)
        )
        return bias

    @staticmethod
    def _get_slopes(heads):
        # geometric sequence of slopes 2^(-8/n), 2^(-16/n), ... as in the ALiBi paper
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        # for head counts that are not a power of two, take the slopes for the
        # closest power of two and interleave slopes from the next one up
        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
                : heads - closest_power_of_2
            ]
        )

    def forward(self, qk_dots):
        h, i, j, device = *qk_dots.shape[-3:], qk_dots.device

        # reuse the cached bias when it is large enough, slicing from the
        # bottom-right corner so relative distances stay aligned
        if (
            self.bias is not None
            and self.bias.shape[-1] >= j
            and self.bias.shape[-2] >= i
        ):
            return qk_dots + self.bias[..., -i:, -j:]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes

        # zero-bias any heads beyond the number of slopes
        num_heads_unalibied = h - bias.shape[0]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim=0)
        self.register_buffer("bias", bias, persistent=False)

        return qk_dots + self.bias
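

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds attention
# logits of shape (batch, heads, i, j), adds the ALiBi bias, then applies a
# causal mask and softmax. The batch/head/sequence sizes and the masking
# scheme below are illustrative assumptions, not requirements of
# AlibiPositionalBias itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, heads, seq_len, dim_head = 2, 8, 16, 64

    alibi = AlibiPositionalBias(heads)

    q = torch.randn(batch, heads, seq_len, dim_head)
    k = torch.randn(batch, heads, seq_len, dim_head)

    # raw attention logits: (batch, heads, i, j)
    qk_dots = torch.einsum("bhid,bhjd->bhij", q, k) / math.sqrt(dim_head)

    # add the per-head linear distance penalty
    biased = alibi(qk_dots)

    # causal mask and softmax, as in a standard decoder attention layer
    causal_mask = torch.ones(seq_len, seq_len).triu(1).bool()
    attn = biased.masked_fill(causal_mask, float("-inf")).softmax(dim=-1)

    print(attn.shape)  # torch.Size([2, 8, 16, 16])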