# Author: Dionyssos
# Commit a84b206: "soundscape: discard last 1s from AudioGen - avoids splash sound"
# (copied from a web file viewer: raw | history blame | 15.2 kB)
import typing as tp
from einops import rearrange
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from xformers import ops
# Name of the attention kernel family. Only 'torch'
# (torch.nn.functional.scaled_dot_product_attention) is handled in this file.
_efficient_attention_backend: str = 'torch'


def _get_attention_time_dimension(memory_efficient: bool) -> int:
    """Return the index of the time axis of q/k/v tensors.

    With the 'torch' backend in memory-efficient mode the tensors are laid out
    head-major, ``[B, H, T, D]``, so time is axis 2; otherwise they are
    ``[B, T, H, D]`` and time is axis 1.
    """
    head_major = _efficient_attention_backend == 'torch' and memory_efficient
    return 2 if head_major else 1
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    The first ``dim // 2`` output channels hold
    ``cos(pos / max_period ** (i / (dim // 2 - 1)))`` and the last ``dim // 2``
    hold the matching ``sin`` terms, for ``i = 0 .. dim // 2 - 1``.

    Args:
        positions (torch.Tensor): Tensor of positions. It is broadcast against a
            ``[1, 1, dim // 2]`` frequency tensor, so a ``[B, T, 1]`` input gives a
            ``[B, T, dim]`` embedding.
        dim (int): Dimension of the embedding. Must be even.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
    # For dim == 2 there is a single frequency and ``half_dim - 1`` is 0. Dividing
    # by it gives a 0/0 = NaN exponent, so clamp the denominator to 1. That makes
    # the single frequency 1, i.e. phase == positions.
    denom = max(half_dim - 1, 1)
    phase = positions / (max_period_tensor ** (adim / denom))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave`` over the head axis (port of the
    xlformers helper). The head axis is 1 for the ``[B, H, T, D]`` layout used by
    the 'torch' backend in memory-efficient mode. Otherwise the input is
    ``[B, T, H, D]`` and the head axis is 2.
    """
    if n_rep == 1:
        return x
    head_major = _efficient_attention_backend == 'torch' and memory_efficient
    if head_major:
        batch, kv_heads, seq_len, head_dim = x.shape
        grown = x.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
        return grown.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
    batch, seq_len, kv_heads, head_dim = x.shape
    grown = x.unsqueeze(3).expand(batch, seq_len, kv_heads, n_rep, head_dim)
    return grown.reshape(batch, seq_len, kv_heads * n_rep, head_dim)
class StreamingMultiheadAttention(nn.Module):
    """Multi-head attention with an incremental key/value history for self-attention.

    In self-attention mode, every call projects ``query`` into q/k/v. It appends
    the new keys/values to ``self.k_history`` / ``self.v_history`` along the time
    axis and attends over the whole history. Each instance keeps its own history.
    The owner must reset both attributes to ``None`` once a generation is finished.

    In cross-attention mode, ``key`` / ``value`` are projected on every call and no
    history is kept.

    Args:
        embed_dim: Model dimension.
        num_heads: Number of attention heads.
        dropout: Attention dropout probability, applied only in training mode.
        bias: Whether the input/output projections have a bias (initialised to 0).
        causal: Only checked against ``past_context``; no causal mask is applied.
        past_context: Requires ``causal=True`` when not None; otherwise unused.
        custom: Ignored. The packed ``in_proj_weight`` implementation is always built.
        memory_efficient: Kept for API compatibility. Attention is always computed
            with ``torch.nn.functional.scaled_dot_product_attention``.
        attention_as_float32: Accepted for API compatibility; unused.
        cross_attention: If True, queries come from ``query`` and keys/values from
            ``key`` / ``value``.
        kv_repeat: Number of query heads sharing one key/value head. ``forward``
            only supports 1.
        device, dtype: Factory arguments for the projection parameters.
    """

    def __init__(self,
                 embed_dim,
                 num_heads, dropout: float = 0.0, bias: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False,
                 kv_repeat: int = 1,
                 device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if past_context is not None:
            assert causal
        self.embed_dim = embed_dim
        # Keys/values of all tokens seen so far in the current generation
        # (self-attention only). They are concatenated along the time axis in
        # forward() and must be reset to None by the owner after a generation.
        self.k_history = None
        self.v_history = None
        self.memory_efficient = memory_efficient
        self.cross_attention = cross_attention
        self.num_heads = num_heads
        self.dropout = dropout
        self.kv_repeat = kv_repeat
        # `custom` is deliberately ignored: only the packed-projection
        # implementation is built. The flag is still read by _load_from_state_dict.
        self.custom = True
        assert num_heads % kv_repeat == 0
        assert not cross_attention or kv_repeat == 1
        num_kv = num_heads // kv_repeat
        kv_dim = (embed_dim // num_heads) * num_kv
        out_dim = embed_dim + 2 * kv_dim
        in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
        # We try to follow the default PyTorch MHA convention, to easily compare results.
        self.in_proj_weight = in_proj.weight
        self.in_proj_bias = in_proj.bias
        if bias:
            self.in_proj_bias.data.zero_()  # Following Pytorch convention
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        if bias:
            self.out_proj.bias.data.zero_()

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Load weights, renaming plain-MHA keys to ``mha.*`` when not in custom mode."""
        if not self.custom:
            # Support compat with regular MHA
            keys = [n for n, _ in self.mha.named_parameters()]
            for key in keys:
                if prefix + key in state_dict:
                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self,
                query,
                key=None,  # only used when self.cross_attention is True
                value=None):
        """Attend ``query`` over keys/values and apply the output projection.

        Args:
            query: ``[B, T, embed_dim]`` tensor.
            key, value: ``[B, S, embed_dim]`` tensors. Used only in
                cross-attention mode; ignored for self-attention.
        Returns:
            ``[B, T, embed_dim]`` tensor.
        Raises:
            NotImplementedError: self-attention with ``kv_repeat != 1``.
        """
        layout = "b h t d"
        if self.cross_attention:
            # Different queries, keys and values: split the packed weights manually
            # before applying the linear projections.
            dim = self.in_proj_weight.shape[0] // 3
            if self.in_proj_bias is None:
                bias_q, bias_k, bias_v = None, None, None
            else:
                bias_q = self.in_proj_bias[:dim]
                bias_k = self.in_proj_bias[dim: 2 * dim]
                bias_v = self.in_proj_bias[2 * dim:]
            q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
            # todo: when streaming, we could actually save k, v and check the shape actually match.
            k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
            v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
            q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
        else:
            if self.kv_repeat != 1:
                # Only the equal-head packed layout is implemented below. Without
                # this check q/k/v would be unbound and fail with UnboundLocalError.
                raise NotImplementedError(
                    f"kv_repeat={self.kv_repeat} is not supported for self-attention; "
                    "only kv_repeat=1 is implemented.")
            projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
            packed = rearrange(projected, "b t (p h d) -> b h p t d", p=3, h=self.num_heads)
            q, k, v = ops.unbind(packed, dim=2)
            if self.k_history is None:
                # First call of a generation: the history starts with this step.
                self.k_history = k
                self.v_history = v
            else:
                # Append along the time axis (axis 2 of "b h t d").
                self.k_history = torch.cat([self.k_history, k], 2)
                self.v_history = torch.cat([self.v_history, v], 2)
            k = self.k_history
            v = self.v_history
        # q/k/v are always "b h t d", which is what scaled_dot_product_attention
        # expects. Previously only the memory-efficient 'torch' path assigned `x`,
        # so any other configuration crashed with UnboundLocalError. SDPA computes
        # the same softmax(q k^T / sqrt(d)) v for every configuration.
        p = self.dropout if self.training else 0
        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=False, dropout_p=p
        )
        x = x.to(q.dtype)
        x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
        x = self.out_proj(x)
        return x
class StreamingTransformerLayer(nn.Module):
    """Pre-norm transformer layer: self-attention, optional cross-attention, GELU MLP.

    This is a plain ``nn.Module`` and inherits nothing from
    ``nn.TransformerEncoderLayer``; it builds its own attention, feed-forward and
    norm sub-modules. Each sub-block is applied as a residual
    ``x = x + f(norm(x))``. No residual dropout is applied in ``forward``.

    Args:
        d_model: Model dimension.
        num_heads: Number of attention heads.
        dim_feedforward: Hidden size of the feed-forward block.
        dropout: Attention dropout, used when ``attention_dropout`` is None.
        bias_ff: Bias in the feed-forward linears.
        bias_attn: Bias in the attention projections.
        custom, memory_efficient, attention_as_float32: Forwarded to
            ``StreamingMultiheadAttention``.
        cross_attention: If True, also build a cross-attention sub-block.
        attention_dropout: Overrides ``dropout`` for the attention modules.
        kv_repeat: Forwarded to the self-attention module.
        norm: Accepted for API compatibility; LayerNorm is always used.
        device, dtype: Factory arguments for all parameters.
        **kwargs: Accepted and ignored (e.g. ``activation``; GELU is always used).
    """

    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 bias_ff: bool = True,
                 bias_attn: bool = True,
                 custom: bool = False,
                 memory_efficient: bool = False,
                 attention_as_float32: bool = False,
                 cross_attention: bool = False,
                 attention_dropout: tp.Optional[float] = None,
                 kv_repeat: int = 1,
                 norm: str = 'layer_norm',
                 device=None,
                 dtype=None,
                 **kwargs):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        attn_kwargs: tp.Dict[str, tp.Any] = {
            'embed_dim': d_model,
            'num_heads': num_heads,
            'dropout': dropout if attention_dropout is None else attention_dropout,
            'bias': bias_attn,
            'custom': custom,
            'memory_efficient': memory_efficient,
            'attention_as_float32': attention_as_float32,
        }
        self.self_attn = StreamingMultiheadAttention(
            kv_repeat=kv_repeat,
            **attn_kwargs,
            **factory_kwargs)  # type: ignore
        # Feed-forward layers, defined explicitly to expose the bias parameter.
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
        self.cross_attention = None
        if cross_attention:
            self.cross_attention = StreamingMultiheadAttention(
                cross_attention=True,
                **attn_kwargs,
                **factory_kwargs)
            # Kept as an attribute for API parity; not applied in forward().
            self.dropout_cross = nn.Dropout(dropout)
            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
        # Honour device/dtype like every other sub-module. Previously these two
        # norms were always created with default device/dtype.
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)

    def forward(self,
                src,
                cross_attention_src=None):
        """Apply self-attention, optional cross-attention and the MLP, each residually.

        Args:
            src: ``[B, T, d_model]`` input.
            cross_attention_src: Conditioning tensor used as key and value of the
                cross-attention sub-block, or None to skip it.
        Returns:
            Tensor with the same shape as ``src``.
        Raises:
            ValueError: ``cross_attention_src`` is given but the layer was built
                without cross-attention.
        """
        # NOTE(review): weights may be stored as float16; whether `src` should be
        # cast to match is left to the caller.
        x = src
        x = x + self.self_attn(self.norm1(x))
        if cross_attention_src is not None:
            if self.cross_attention is None:
                raise ValueError(
                    "cross_attention_src was given but this layer was built with "
                    "cross_attention=False")
            x = x + self.cross_attention(
                query=self.norm_cross(x),
                key=cross_attention_src,
                value=cross_attention_src)
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x
class StreamingTransformer(nn.Module):
    """Stack of ``layer_class`` layers with an optional sinusoidal positional embedding.

    Args:
        d_model: Model dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads per layer.
        num_layers: Number of stacked layers.
        dim_feedforward, dropout, bias_ff, bias_attn, custom, memory_efficient,
        attention_as_float32, cross_attention: Forwarded to every layer.
        positional_embedding: If 'sin' or 'sin_rope', a sinusoidal embedding is
            added to the input. No rotary embedding is applied in this file.
        max_period: Maximum period of the sinusoidal embedding.
        layer_class: Layer constructor, called with keyword arguments.
        checkpointing: When not 'none', each layer is tagged with
            ``_magma_checkpointed``. ``forward`` itself does not checkpoint.
        device, dtype: Factory arguments forwarded to every layer.
        **kwargs: Extra keyword arguments forwarded to every layer.
    """

    def __init__(self, d_model: int,
                 num_heads: int,
                 num_layers: int,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 bias_ff: bool = True,
                 bias_attn: bool = True,
                 custom: bool = False,
                 memory_efficient: bool = False,
                 attention_as_float32: bool = False,
                 cross_attention: bool = False,
                 positional_embedding: str = 'sin',
                 max_period: float = 10_000,
                 layer_class=StreamingTransformerLayer,
                 checkpointing: str = 'none',
                 device=None,
                 dtype=None,
                 **kwargs):
        super().__init__()
        assert d_model % num_heads == 0
        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.checkpointing = checkpointing
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
                    custom=custom,
                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
                    cross_attention=cross_attention,
                    device=device, dtype=dtype, **kwargs))
        if self.checkpointing != 'none':
            for layer in self.layers:
                # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
                # backward hook inside of FSDP...
                layer._magma_checkpointed = True  # type: ignore

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Add the positional embedding (if enabled) and run every layer.

        Args:
            x: ``[B, T, C]`` input.
            *args: Accepted and ignored.
            token_count (keyword, default 0): Offset added to every position.
                Presumably the number of tokens already processed in the current
                generation; confirm against the caller.
            cross_attention_src (keyword, default None): Conditioning tensor passed
                to every layer.
        Returns:
            Tensor with the same shape as ``x``.
        """
        _, T, C = x.shape
        if self.positional_embedding in ['sin', 'sin_rope']:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + kwargs.get('token_count', 0)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + pos_emb
        cross_attention_src = kwargs.get('cross_attention_src')
        for layer in self.layers:
            # Each layer's self-attention keeps its own k/v history across calls.
            x = layer(x, cross_attention_src=cross_attention_src)
        return x