ReactSeq/onmt/modules/alibi_position_bias.py
""" Alibi position bias """
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def pad_at_dim(t, pad, dim=-1, value=0.0):
dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = (0, 0) * dims_from_right
return F.pad(t, (*zeros, *pad), value=value)
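
# A minimal usage sketch of `pad_at_dim` (an illustrative addition, not part
# of the upstream module): pad a (2, 3) tensor with one trailing row of
# zeros along dim 0.
#
#   >>> pad_at_dim(torch.ones(2, 3), (0, 1), dim=0).shape
#   torch.Size([3, 3])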

class AlibiPositionalBias(nn.Module):
    """Adds ALiBi linear distance penalties to attention logits."""

    def __init__(self, heads, **kwargs):
        super().__init__()
        self.heads = heads
        # One slope per head, reshaped to (heads, 1, 1) so it broadcasts
        # over the (query, key) dimensions of the bias matrix.
        slopes = torch.Tensor(self._get_slopes(heads))
        slopes = slopes.unsqueeze(1).unsqueeze(1)
        self.register_buffer("slopes", slopes, persistent=False)
        # Cached bias, recomputed only when a longer sequence is seen.
        self.register_buffer("bias", None, persistent=False)

    def get_bias(self, i, j, device):
        # Relative-distance matrix of shape (1, i, j): entry (q, k) is
        # -|position(k) - position(q)|, with the i queries aligned to the
        # last i of the j key positions.
        i_arange = torch.arange(j - i, j, device=device)
        j_arange = torch.arange(j, device=device)
        bias = -torch.abs(
            j_arange.unsqueeze(0).unsqueeze(0) - i_arange.unsqueeze(1).unsqueeze(0)
        )
        return bias

    @staticmethod
    def _get_slopes(heads):
        # Geometric sequence of slopes from the ALiBi paper: for n heads
        # (n a power of 2), the slopes are 2^(-8/n), 2^(-16/n), ..., 2^(-8).
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        # For non-power-of-2 head counts, take the slopes of the closest
        # smaller power of 2, then fill the remaining heads with every
        # other slope from the next power of 2.
        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
                : heads - closest_power_of_2
            ]
        )

    def forward(self, qk_dots):
        h, i, j, device = *qk_dots.shape[-3:], qk_dots.device

        # Reuse the cached bias when it already covers j key positions.
        if (self.bias is not None) and self.bias.shape[-1] >= j:
            return qk_dots + self.bias[..., :i, :j]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes

        # If fewer slopes than heads were produced, leave the extra heads
        # un-biased by zero-padding along the head dimension.
        num_heads_unalibied = h - bias.shape[0]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim=0)
        self.register_buffer("bias", bias, persistent=False)

        return qk_dots + self.bias
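
if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the upstream
    # module): apply ALiBi to dummy attention logits of shape
    # (batch, heads, q_len, k_len) and inspect the result.
    heads, q_len, k_len = 8, 4, 6
    alibi = AlibiPositionalBias(heads)
    qk_dots = torch.zeros(1, heads, q_len, k_len)
    biased = alibi(qk_dots)
    print(biased.shape)  # torch.Size([1, 8, 4, 6])
    # With zero logits the output equals the bias itself; where the query
    # and key positions coincide the penalty is zero.
    print(biased[0, 0, -1, -1].item())  # 0.0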