File size: 8,347 Bytes
589d80c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 |
# Copyright (c) 2023, Dan Fu and Simran Arora.
# Adapted from https://github.com/HazyResearch/safari/blob/main/src/models/sequence/hyena.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import opt_einsum as oe
contract = oe.contract
""" Utils for the training loop. Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py """
class OptimModule(nn.Module):
    """Base module whose tensors can carry per-tensor optimizer settings.

    Tensors attached through :meth:`register` become either buffers or
    parameters. Parameters additionally receive an ``_optim`` dict attribute
    holding the requested hyperparameters (presumably read back by the
    training loop when it builds optimizer parameter groups -- that consumer
    is not part of this file).
    """

    def register(self, name, tensor, lr=None, wd=0.0):
        """Attach ``tensor`` to this module under the attribute ``name``.

        Args:
            name: attribute name for the buffer/parameter.
            tensor: initial value.
            lr: learning rate to record. Exactly ``0.0`` means "frozen": the
                tensor is stored as a buffer and no ``_optim`` dict is set.
                ``None`` stores a parameter without an ``"lr"`` entry.
            wd: weight decay to record (default ``0.0``); ``None`` omits the
                ``"weight_decay"`` entry.
        """
        if lr == 0.0:
            # A zero learning rate can never update the tensor, so keep it
            # out of the parameter list entirely.
            self.register_buffer(name, tensor)
            return

        self.register_parameter(name, nn.Parameter(tensor))
        hyperparams = {
            key: value
            for key, value in (("lr", lr), ("weight_decay", wd))
            if value is not None
        }
        getattr(self, name)._optim = hyperparams
def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
    """Reference FFT convolution of ``u`` with kernel ``k`` plus a skip term.

    The transforms are taken over the last axis with length ``2 * seqlen``
    and only the first ``seqlen`` outputs are kept. The result is
    ``conv(u, k) + u * D``, optionally passed through GELU and multiplied by
    a dropout mask, and finally cast back to ``u``'s dtype.

    Args:
        u: input; the original comment gives its shape as ``B H L``. Inputs
            with more than three dimensions are supported by inserting an
            axis at position 1 of the kernel spectrum.
        k: convolution kernel, transformed along its last axis.
        D: multiplier of the skip connection ``u * D``; it must broadcast
            against ``u`` (no reshaping is done here).
        dropout_mask: optional 2-D ``(b, H)`` mask broadcast over the last
            axis, or ``None``.
        gelu: apply ``F.gelu`` to the output when true.
        k_rev: optional second kernel. The conjugate of its spectrum is added
            to that of ``k`` (for a real kernel, conjugating the spectrum
            corresponds to reversing the kernel in time).

    Returns:
        Tensor shaped like ``u`` with ``u``'s dtype.
    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen

    # The kernel spectrum carries the 1/fft_size factor, which is why the
    # inverse transform below uses norm="forward" (no scaling on the inverse).
    kernel_f = torch.fft.rfft(k, n=fft_size) / fft_size
    if k_rev is not None:
        rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
        kernel_f = kernel_f + rev_f.conj()

    # Run the transform in the kernel's dtype.
    input_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
    if u.dim() > 3:
        kernel_f = kernel_f.unsqueeze(1)

    conv = torch.fft.irfft(input_f * kernel_f, n=fft_size, norm="forward")
    result = conv[..., :seqlen] + u * D
    if gelu:
        result = F.gelu(result)
    if dropout_mask is not None:
        result = result * rearrange(dropout_mask, "b H -> b H 1")
    return result.to(dtype=u.dtype)
@torch.jit.script
def mul_sum(q, y):
    """Multiply ``q`` and ``y`` elementwise and sum the product over dim 1."""
    product = q * y
    return product.sum(dim=1)
class Sin(nn.Module):
    """Sine activation with a per-feature frequency: ``sin(w_mod * freq * x)``.

    Args:
        dim: number of features; ``freq`` has shape ``(1, dim)``.
        w: initial value of every entry of ``freq``.
        w_mod: fixed multiplier applied on top of ``freq``; stored as a plain
            attribute and never trained.
        train_freq: when true ``freq`` is an ``nn.Parameter``; otherwise it is
            a constant.
    """

    def __init__(self, dim, w=10, w_mod=1, train_freq=True):
        super().__init__()
        init_freq = w * torch.ones(1, dim)
        if train_freq:
            self.freq = nn.Parameter(init_freq)
        else:
            # A bare tensor attribute would be ignored by Module.to()/.cuda()/
            # .double(), leaving `freq` on the wrong device or dtype. A buffer
            # follows the module; persistent=False keeps it out of the
            # state_dict so existing checkpoints still load unchanged.
            self.register_buffer("freq", init_freq, persistent=False)
        self.w_mod = w_mod

    def forward(self, x):
        """Return ``sin(w_mod * freq * x)``; ``freq`` broadcasts against ``x``."""
        return torch.sin(self.w_mod * self.freq * x)
class PositionalEmbedding(OptimModule):
    """Time plus complex-exponential positional features for Hyena filters.

    Two tensors are registered:

    * ``t`` -- a buffer of shape ``(1, seq_len, 1)`` with time normalized to
      ``[0, 1]``.
    * ``z`` -- ``t`` concatenated with the real and imaginary parts of
      ``bands = (emb_dim - 1) // 2`` complex exponentials, giving shape
      ``(1, seq_len, 1 + 2 * bands)``. It is registered with learning rate
      ``lr_pos_emb`` (so it is a buffer when that rate is ``0.0``).
    """

    def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-5, **kwargs):
        """Complex exponential positional embeddings for Hyena filters.

        Args:
            emb_dim: requested embedding width. The last axis of ``z`` equals
                ``emb_dim`` only when ``emb_dim`` is odd.
            seq_len: number of positions to precompute.
            lr_pos_emb: learning rate recorded for ``z``.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``emb_dim`` is smaller than 1.
        """
        super().__init__()
        if emb_dim < 1:
            raise ValueError(f"emb_dim must be at least 1, got {emb_dim}")

        self.seq_len = seq_len
        # The time embedding fed to the filters is normalized so that t_f = 1
        t = torch.linspace(0, 1, self.seq_len)[None, :, None]  # 1, L, 1

        # Computed unconditionally: previously `bands` was only bound when
        # emb_dim > 1, so emb_dim == 1 failed on an unbound name.
        bands = (emb_dim - 1) // 2
        if bands > 0:
            # To compute the right embeddings we use the "proper" linspace
            t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
            w = 2 * math.pi * t_rescaled / seq_len  # 1, L, 1

            f = torch.linspace(1e-4, bands - 1, bands)[None, None]
            z = torch.exp(-1j * f * w)
            z = torch.cat([t, z.real, z.imag], dim=-1)
        else:
            # No frequency bands: the embedding is the time feature alone.
            # Clone so the `z` parameter does not share storage with `t`.
            z = t.clone()

        self.register("z", z, lr=lr_pos_emb)
        self.register("t", t, lr=0.0)

    def forward(self, L):
        """Return ``(z, t)`` truncated to the first ``L`` positions."""
        return self.z[:, :L], self.t[:, :L]
class ExponentialModulation(OptimModule):
    """Scale a signal by a per-channel exponential decay in time.

    ``d_model`` decay rates are spaced linearly between
    ``log(target) / slow_decay_pct`` and ``log(target) / fast_decay_pct`` and
    registered as ``deltas`` with learning rate ``modulation_lr`` (a buffer
    when that rate is ``0.0``, the default).
    """

    def __init__(
        self,
        d_model,
        fast_decay_pct=0.3,
        slow_decay_pct=1.5,
        target=1e-2,
        modulation_lr=0.0,
        shift: float = 0.0,
        **kwargs,
    ):
        """
        Args:
            d_model: number of decay rates (one per channel).
            fast_decay_pct: divisor of ``log(target)`` for one end of the range.
            slow_decay_pct: divisor of ``log(target)`` for the other end.
            target: value whose logarithm sets the scale of the decay rates.
            modulation_lr: learning rate recorded for ``deltas``.
            shift: constant added to the decay envelope before multiplying.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.shift = shift

        log_target = math.log(target)
        slowest = log_target / slow_decay_pct
        fastest = log_target / fast_decay_pct
        rates = torch.linspace(slowest, fastest, d_model)[None, None]
        self.register("deltas", rates, lr=modulation_lr)

    def forward(self, t, x):
        """Return ``x * (exp(-t * |deltas|) + shift)``."""
        envelope = torch.exp(-t * self.deltas.abs())
        return x * (envelope + self.shift)
class HyenaFilter(OptimModule):
    """Implicit long convolution filter with optional exponential modulation.

    A small MLP (or a single linear layer) maps positional embeddings to a
    filter value per channel; ``forward`` then convolves the input with that
    filter via ``fftconv_ref``.
    """

    def __init__(
        self,
        d_model,
        emb_dim=3,  # dim of input to MLP, augments with positional encoding
        order=16,  # width of the implicit MLP
        seq_len=1024,
        lr=1e-3,
        lr_pos_emb=1e-5,
        dropout=0.0,
        w=1,  # frequency of periodic activations
        w_mod=1,  # non-learnable modification of w
        wd=0,  # weight decay of kernel parameters
        bias=True,
        num_inner_mlps=2,
        linear_mixer=False,
        modulate: bool = True,
        normalized=False,
        bidirectional=False,
        **kwargs,
    ):
        """
        Implicit long filter with modulation.

        Args:
            d_model: number of channels in the input
            emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands
            order: width of the FFN
            seq_len: number of positions precomputed by the positional embedding
            lr: learning rate recorded on the forward filter's tensors
            lr_pos_emb: learning rate recorded on the positional embedding
            dropout: rate of ``self.dropout`` (created but not applied here)
            w: initial frequency of the ``Sin`` activation
            w_mod: fixed multiplier on the ``Sin`` frequency
            wd: weight decay recorded on the forward filter's tensors
            bias: when false, the bias is multiplied by 0 in ``forward``
            num_inner_mlps: number of inner linear layers inside filter MLP
            linear_mixer: anything other than the literal ``False`` replaces
                the forward MLP with one bias-free linear layer
            modulate: apply ``ExponentialModulation`` to the generated filter
            normalized: divide the generated filter by its L1 norm over the
                last axis
            bidirectional: also build and use a reverse filter
            **kwargs: forwarded to ``ExponentialModulation``

        Note:
            filter_dropout is not implemented
        """
        super().__init__()
        self.d_model = d_model
        self.emb_dim = emb_dim
        self.seq_len = seq_len
        self.modulate = modulate
        self.use_bias = bias
        self.bidirectional = bidirectional

        self.bias = nn.Parameter(torch.randn(self.d_model))
        self.dropout = nn.Dropout(dropout)

        # One activation instance is shared by every layer of both filters,
        # so its frequency tensor is shared as well.
        act = Sin(dim=order, w=w, w_mod=w_mod)
        assert (
            emb_dim % 2 != 0 and emb_dim >= 3
        ), "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
        self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)

        # uses a variable number of inner linear layers
        if linear_mixer is False:
            self.implicit_filter = self._build_mlp(
                emb_dim, order, d_model, num_inner_mlps, act
            )
        else:
            self.implicit_filter = nn.Sequential(
                nn.Linear(emb_dim, d_model, bias=False),
            )

        if self.bidirectional:
            # The reverse filter is always an MLP, regardless of linear_mixer.
            self.implicit_filter_rev = self._build_mlp(
                emb_dim, order, d_model, num_inner_mlps, act
            )

        self.modulation = ExponentialModulation(d_model, **kwargs)
        self.normalized = normalized

        # Record optimizer hyperparameters on every tensor of the forward
        # filter; each tensor gets its own dict.
        # NOTE(review): implicit_filter_rev is not tagged here. Only the
        # shared activation's tensor is covered, because it is also a child
        # of the forward filter. Confirm whether this is intentional.
        for c in self.implicit_filter.children():
            for name in c.state_dict():
                setattr(getattr(c, name), "_optim", {"weight_decay": wd, "lr": lr})

    @staticmethod
    def _build_mlp(emb_dim, order, d_model, num_inner_mlps, act):
        """Build Linear -> act -> (Linear -> act) * n -> bias-free Linear."""
        layers = [nn.Linear(emb_dim, order), act]
        for _ in range(num_inner_mlps):
            layers.append(nn.Linear(order, order))
            layers.append(act)
        layers.append(nn.Linear(order, d_model, bias=False))
        return nn.Sequential(*layers)

    def _compute_filter(self, implicit_filter, L):
        """Evaluate ``implicit_filter`` on the first ``L`` positions."""
        z, t = self.pos_emb(L)
        h = implicit_filter(z)
        if self.modulate:
            h = self.modulation(t, h)
        if self.normalized:
            h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
        return h

    def filter(self, L, *args, **kwargs):
        """Generate the forward filter for length ``L`` (extra args ignored)."""
        return self._compute_filter(self.implicit_filter, L)

    def filter_rev(self, L, *args, **kwargs):
        """Generate the reverse filter for length ``L`` (extra args ignored).

        Only available when the module was built with ``bidirectional=True``.
        """
        return self._compute_filter(self.implicit_filter_rev, L)

    def forward(self, x, L, k_fwd=None, k_rev=None, bias=None, *args, **kwargs):
        """Convolve ``x`` with the (optionally bidirectional) filter.

        Args:
            x: input passed to ``fftconv_ref`` as ``u``.
            L: filter length; also the amount of padding used to combine the
                forward and reverse kernels.
            k_fwd: precomputed forward kernel, or a tuple whose first item is
                the kernel; generated with ``filter(L)`` when ``None``.
            k_rev: same for the reverse kernel; only used when the module is
                bidirectional, and generated with ``filter_rev(L)`` when
                ``None``.
            bias: skip-connection multiplier; defaults to ``self.bias``. It is
                passed on unchanged, so it must already broadcast against
                ``x``.

        Returns:
            The convolution output cast to ``x``'s dtype.
        """
        # NOTE(review): filter()/filter_rev() return the implicit filter's
        # output with the position axis second (from pos_emb), whereas the
        # padding below and fftconv_ref operate on the last axis. Callers
        # appear to be expected to pass k_fwd/k_rev already arranged with the
        # length last -- confirm before relying on the generated defaults.
        if k_fwd is None:
            k_fwd = self.filter(L)
        if self.bidirectional and k_rev is None:
            # Generated whenever it is missing, so supplying only k_fwd does
            # not leave k_rev as None for the flip below.
            k_rev = self.filter_rev(L)

        # Ensure compatibility with filters that return a tuple
        if isinstance(k_fwd, tuple):
            k_fwd = k_fwd[0]

        if bias is None:
            bias = self.bias
        bias = bias if self.use_bias else 0 * bias

        if self.bidirectional:
            if isinstance(k_rev, tuple):
                k_rev = k_rev[0]
            # Forward kernel in the first L slots, flipped reverse kernel in
            # the last L slots of a length-2L kernel.
            k = F.pad(k_fwd, (0, L)) + F.pad(k_rev.flip(-1), (L, 0))
        else:
            k = k_fwd

        y = fftconv_ref(
            x,
            k,
            bias,
            dropout_mask=None,
            gelu=False,
        )
        return y.to(dtype=x.dtype)