import math
import logging
from typing import Optional, Tuple, Union, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import loralib as lora
import audiotools as at
from .activations import get_activation
from .layers import CodebookEmbedding
from .layers import FiLM
from .layers import SequentialWithFiLM
from .layers import WNConv1d
from ..util import scalar_to_batch_tensor, codebook_flatten, codebook_unflatten
from ..mask import _gamma
LORA_R = 8
# def log(t, eps=1e-20):
# return torch.log(t + eps)
def gumbel_noise_like(t):
noise = torch.zeros_like(t).uniform_(1e-20, 1)
return -torch.log(-torch.log(noise))
def gumbel_sample(t, temperature=1.0, dim=-1):
return ((t / max(temperature, 1e-10)) + gumbel_noise_like(t)).argmax(dim=dim)
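
# A minimal illustrative sketch of `gumbel_sample`: adding Gumbel noise to scaled logits
# before the argmax is equivalent to drawing a sample from softmax(logits / temperature).
# The helper name `_demo_gumbel_sample` is illustrative and is not used elsewhere.
def _demo_gumbel_sample():
    logits = torch.tensor([[2.0, 0.5, 0.1]])
    # a vanishing temperature approaches a plain argmax; temperature 1.0 samples stochastically
    near_greedy = gumbel_sample(logits, temperature=1e-10)
    stochastic = gumbel_sample(logits, temperature=1.0)
    return near_greedy, stochastic
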
class RMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.var_eps = eps
    def forward(self, x):
        """Returns the root-mean-square-normalized version of the input `x`.
        # T5 uses a layer norm that only scales and does not shift, also known as
        # Root Mean Square Layer Normalization (https://arxiv.org/abs/1910.07467);
        # thus the variance is computed without subtracting the mean, and there is no bias.
Parameters
----------
x : Tensor[B x T x D]
Returns
-------
Tensor[B x T x D]
"""
var = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(var + self.var_eps)
return self.weight * x
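
# A small illustrative sketch of RMSNorm: the output keeps the input shape, and with the
# default weight of ones each (batch, time) vector is rescaled to roughly unit RMS.
# The helper name `_demo_rmsnorm` is illustrative and is not used elsewhere.
def _demo_rmsnorm():
    norm = RMSNorm(hidden_size=512)
    x = torch.randn(2, 16, 512)  # B x T x D
    y = norm(x)
    assert y.shape == x.shape
    rms = y.pow(2).mean(-1).sqrt()  # ~1.0 everywhere while the learned scale stays at 1
    return rms
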
class FeedForward(nn.Module):
def __init__(
self, d_model: int = 512, dropout: float = 0.1, activation: str = "geglu"
):
super().__init__()
factor = 2 if activation == "geglu" else 1
self.w_1 = lora.Linear(d_model, d_model * 4, bias=False, r=LORA_R)
self.w_2 = lora.Linear(d_model * 4 // factor, d_model, bias=False, r=LORA_R)
self.drop = nn.Dropout(dropout)
self.act = get_activation(activation)()
def forward(self, x):
"""Computes position-wise feed-forward layer
Parameters
----------
x : Tensor[B x T x D]
Returns
-------
Tensor[B x T x D]
"""
x = self.w_1(x)
x = self.act(x)
x = self.drop(x)
x = self.w_2(x)
return x
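
# A short illustrative sketch of why `factor` is 2 for "geglu": the GEGLU activation splits
# the 4 * d_model hidden projection into a value half and a gate half, so `w_2` receives
# 4 * d_model // 2 features. The helper name `_demo_feed_forward` is illustrative.
def _demo_feed_forward():
    ff = FeedForward(d_model=512, activation="geglu")
    x = torch.randn(2, 16, 512)  # B x T x D
    y = ff(x)  # projected to 2048, gated down to 1024, projected back to 512
    assert y.shape == x.shape
    return y
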
class MultiHeadRelativeAttention(nn.Module):
def __init__(
self,
n_head: int = 8,
d_model: int = 512,
dropout: float = 0.1,
bidirectional: bool = True,
has_relative_attention_bias: bool = True,
attention_num_buckets: int = 32,
attention_max_distance: int = 128,
):
super().__init__()
d_head = d_model // n_head
self.n_head = n_head
self.d_head = d_head
self.bidirectional = bidirectional
self.has_relative_attention_bias = has_relative_attention_bias
self.attention_num_buckets = attention_num_buckets
self.attention_max_distance = attention_max_distance
# Create linear query, key, value projections
self.w_qs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
self.w_ks = nn.Linear(d_model, d_model, bias=False)
self.w_vs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
# Create linear final output projection
self.fc = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
# Dropout for attention output weights
self.dropout = nn.Dropout(dropout)
# Create relative positional embeddings (if turned on)
if has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(attention_num_buckets, n_head)
    def _relative_position_bucket(self, relative_position):
        """Converts an unbounded relative position into a bounded set of buckets,
        with half "exact" buckets (1 position = 1 bucket) and half "log-spaced"
        buckets.
Parameters
----------
relative_position : Tensor[T_q x T_kv]
Relative positions between queries and key_value items
Returns
-------
Tensor[T_q x T_kv]
Input relative positions converted into buckets
"""
relative_buckets = 0
num_buckets = self.attention_num_buckets
max_distance = self.attention_max_distance
        # Convert relative position from (-inf, inf) to [0, inf]
# Negative relative positions correspond to past
# Positive relative positions correspond to future
if self.bidirectional:
# use half buckets for each side (past / future)
num_buckets //= 2
            # Shift positive (future) positions into the upper half of the buckets,
            # so that past and future positions map to disjoint bucket ranges
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
# If not bidirectional, ignore positive positions and wrap
# negative positions to positive
relative_position = -torch.min(
relative_position, torch.zeros_like(relative_position)
)
        # Half of the buckets are for exact increments in position
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in
# positions up to `max_distance`
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        # Clip the max relative position to `num_buckets - 1`
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )
        # Choose relative buckets based on small or large positions
        relative_buckets += torch.where(
            is_small, relative_position, relative_position_if_large
        )
return relative_buckets
def compute_bias(self, query_length, key_length):
"""Computes a position bias scalar for each index in query_length x key_length
Parameters
----------
query_length : int
key_length : int
Returns
-------
Tensor[heads x 1 x T_q x T_kv]
Position bias to be applied on attention logits
"""
query_position = torch.arange(query_length, dtype=torch.long)[:, None]
key_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = key_position - query_position
# Convert relative position to buckets
relative_position_bucket = self._relative_position_bucket(relative_position)
relative_position_bucket = relative_position_bucket.to(
self.relative_attention_bias.weight.device
)
# Index attention bias values
values = self.relative_attention_bias(relative_position_bucket)
values = rearrange(values, "q k h -> h 1 q k")
return values
def forward(self, q, k, v, mask=None, position_bias=None):
"""Computes attention over (keys, values) for every timestep in query
Parameters
----------
q : Tensor[B x T_q x d_model]
Query vectors
k : Tensor[B x T_kv x d_model]
Key vectors to compute attention over
v : Tensor[B x T_kv x d_model]
Value vectors corresponding to the keys
mask : Tensor[B x T_q x T_kv], optional
position_bias: Tensor[head x 1 x T_q x T_kv]
Returns
-------
Tensor[B x T_q x d_model]
Outputs after attending (key, value) using queries
"""
# Compute query, key, value projections
q = rearrange(self.w_qs(q), "b l (head k) -> head b l k", head=self.n_head)
k = rearrange(self.w_ks(k), "b t (head k) -> head b t k", head=self.n_head)
v = rearrange(self.w_vs(v), "b t (head k) -> head b t k", head=self.n_head)
# Compute attention matrix
attn = torch.einsum("hblk,hbtk->hblt", [q, k]) / np.sqrt(q.shape[-1])
# Add relative position bias to attention scores
if position_bias is None:
if self.has_relative_attention_bias:
position_bias = self.compute_bias(q.size(-2), k.size(-2))
else:
position_bias = torch.zeros_like(attn)
attn += position_bias
# Apply mask to attention scores to prevent looking up invalid locations
if mask is not None:
attn = attn.masked_fill(mask[None] == 0, -1e9)
# Normalize attention scores and add dropout
attn = torch.softmax(attn, dim=3)
attn = self.dropout(attn)
# Compute attended outputs (product of attention matrix and values)
output = torch.einsum("hblt,hbtv->hblv", [attn, v])
output = rearrange(output, "head b l v -> b l (head v)")
output = self.fc(output)
return output, position_bias
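
# An illustrative sketch of the relative-attention interface: the first call computes a
# T5-style position bias of shape (heads x 1 x T_q x T_kv) that can be passed to later
# layers built with has_relative_attention_bias=False. The helper name
# `_demo_relative_attention` is illustrative and is not used elsewhere.
def _demo_relative_attention():
    attn = MultiHeadRelativeAttention(n_head=8, d_model=512, has_relative_attention_bias=True)
    x = torch.randn(2, 16, 512)  # B x T x d_model
    out, position_bias = attn(x, x, x, mask=None, position_bias=None)
    assert out.shape == x.shape
    assert position_bias.shape == (8, 1, 16, 16)  # heads x 1 x T_q x T_kv
    return out, position_bias
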
class TransformerLayer(nn.Module):
def __init__(
self,
d_model: int = 512,
d_cond: int = 64,
n_heads: int = 8,
bidirectional: bool = True,
is_decoder: bool = False,
has_relative_attention_bias: bool = False,
flash_attn: bool = False,
dropout: float = 0.1,
):
super().__init__()
# Store args
self.is_decoder = is_decoder
# Create self-attention layer
self.norm_1 = RMSNorm(d_model)
self.film_1 = FiLM(d_cond, d_model)
self.flash_attn = flash_attn
if flash_attn:
from flash_attn.flash_attention import FlashMHA
self.self_attn = FlashMHA(
embed_dim=d_model,
num_heads=n_heads,
attention_dropout=dropout,
causal=False,
)
else:
self.self_attn = MultiHeadRelativeAttention(
n_heads, d_model, dropout, bidirectional, has_relative_attention_bias
)
# (Optional) Create cross-attention layer
if is_decoder:
self.norm_2 = RMSNorm(d_model)
self.film_2 = FiLM(d_cond, d_model)
self.cross_attn = MultiHeadRelativeAttention(
n_heads,
d_model,
dropout,
bidirectional=True,
has_relative_attention_bias=False,
)
# Create last feed-forward layer
self.norm_3 = RMSNorm(d_model)
self.film_3 = FiLM(d_cond, d_model)
self.feed_forward = FeedForward(d_model=d_model, dropout=dropout)
# Create dropout
self.dropout = nn.Dropout(dropout)
def forward(
self,
x,
x_mask,
cond,
src=None,
src_mask=None,
position_bias=None,
encoder_decoder_position_bias=None,
    ):
        """Computes one transformer layer consisting of self-attention, (optional)
        cross-attention, and a feed-forward layer.
Parameters
----------
x : Tensor[B x T_q x D]
x_mask : Tensor[B x T_q]
src : Tensor[B x T_kv x D], optional
src_mask : Tensor[B x T_kv x D], optional
        position_bias : Tensor[heads x 1 x T_q x T_q], optional
            Relative position bias for the self-attention layer
        encoder_decoder_position_bias : Tensor[heads x 1 x T_q x T_kv], optional
            Relative position bias for the cross-attention layer
Returns
-------
Tensor[B x T_q x D]
"""
y = self.norm_1(x)
y = self.film_1(y.permute(0, 2, 1), cond).permute(0, 2, 1)
if self.flash_attn:
with torch.autocast(y.device.type, dtype=torch.bfloat16):
y = self.self_attn(y)[0]
else:
y, position_bias = self.self_attn(y, y, y, x_mask, position_bias)
x = x + self.dropout(y)
if self.is_decoder:
y = self.norm_2(x)
y = self.film_2(y.permute(0, 2, 1), cond).permute(0, 2, 1)
y, encoder_decoder_position_bias = self.cross_attn(
y, src, src, src_mask, encoder_decoder_position_bias
)
x = x + self.dropout(y)
y = self.norm_3(x)
        y = self.film_3(y.permute(0, 2, 1), cond).permute(0, 2, 1)
y = self.feed_forward(y)
x = x + self.dropout(y)
return x, position_bias, encoder_decoder_position_bias
class TransformerStack(nn.Module):
def __init__(
self,
d_model: int = 512,
d_cond: int = 64,
n_heads: int = 8,
n_layers: int = 8,
last_layer: bool = True,
bidirectional: bool = True,
flash_attn: bool = False,
is_decoder: bool = False,
dropout: float = 0.1,
):
super().__init__()
# Store args
self.bidirectional = bidirectional
self.is_decoder = is_decoder
# Create transformer layers
# In T5, relative attention bias is shared by all layers in the stack
self.layers = nn.ModuleList(
[
TransformerLayer(
d_model,
d_cond,
n_heads,
bidirectional,
is_decoder,
                    has_relative_attention_bias=(i == 0),
flash_attn=flash_attn,
dropout=dropout,
)
for i in range(n_layers)
]
)
# Perform last normalization
self.norm = RMSNorm(d_model) if last_layer else None
def subsequent_mask(self, size):
return torch.ones(1, size, size).tril().bool()
    def forward(
        self, x, x_mask, cond=None, src=None, src_mask=None,
        return_activations: bool = False,
    ):
"""Computes a full transformer stack
Parameters
----------
x : Tensor[B x T_q x D]
x_mask : Tensor[B x T_q]
src : Tensor[B x T_kv x D], optional
src_mask : Tensor[B x T_kv], optional
Returns
-------
Tensor[B x T_q x D]
"""
# Convert `src_mask` to (B x T_q x T_kv) shape for cross attention masking
if self.is_decoder:
src_mask = x_mask.unsqueeze(-1) * src_mask.unsqueeze(-2)
# Convert `x_mask` to (B x T_q x T_q) shape for self attention masking
x_mask = x_mask.unsqueeze(-2)
if not self.bidirectional:
x_mask = x_mask * self.subsequent_mask(x.size(1)).to(x_mask.device)
# Initialize position biases
position_bias = None
encoder_decoder_position_bias = None
# Compute transformer layers
if return_activations:
activations = []
for layer in self.layers:
x, position_bias, encoder_decoder_position_bias = layer(
x=x,
x_mask=x_mask,
cond=cond,
src=src,
src_mask=src_mask,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
)
if return_activations:
activations.append(x.detach())
out = self.norm(x) if self.norm is not None else x
if return_activations:
return out, torch.stack(activations)
else:
return out
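
# A minimal sketch of the mask plumbing in TransformerStack.forward: a per-token mask of
# shape (B x T) broadcasts to (B x T x T) for self-attention, and `subsequent_mask` turns
# it into a causal (lower-triangular) pattern when the stack is not bidirectional.
# The helper name `_demo_attention_masks` is illustrative and is not used elsewhere.
def _demo_attention_masks():
    B, T = 2, 6
    x_mask = torch.ones(B, T, dtype=torch.bool)
    x_mask[1, 4:] = False  # pretend the second item is padded after 4 tokens
    self_attn_mask = x_mask.unsqueeze(-2)  # B x 1 x T, broadcast over query positions
    causal = torch.ones(1, T, T).tril().bool()  # same as TransformerStack.subsequent_mask(T)
    causal_self_attn_mask = self_attn_mask * causal  # B x T x T
    return self_attn_mask, causal_self_attn_mask
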
class VampNet(at.ml.BaseModel):
def __init__(
self,
n_heads: int = 20,
n_layers: int = 16,
r_cond_dim: int = 0,
n_codebooks: int = 9,
n_conditioning_codebooks: int = 0,
latent_dim: int = 8,
embedding_dim: int = 1280,
vocab_size: int = 1024,
flash_attn: bool = True,
noise_mode: str = "mask",
dropout: float = 0.1
):
super().__init__()
assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
self.n_heads = n_heads
self.n_layers = n_layers
self.r_cond_dim = r_cond_dim
self.n_codebooks = n_codebooks
self.n_conditioning_codebooks = n_conditioning_codebooks
self.embedding_dim = embedding_dim
self.vocab_size = vocab_size
self.latent_dim = latent_dim
self.flash_attn = flash_attn
self.noise_mode = noise_mode
        assert self.noise_mode == "mask", "only noise_mode='mask' is supported; other modes are deprecated"
self.embedding = CodebookEmbedding(
latent_dim=latent_dim,
n_codebooks=n_codebooks,
vocab_size=vocab_size,
emb_dim=embedding_dim,
special_tokens=["MASK"],
)
self.mask_token = self.embedding.special_idxs["MASK"]
self.transformer = TransformerStack(
d_model=embedding_dim,
d_cond=r_cond_dim,
n_heads=n_heads,
n_layers=n_layers,
last_layer=True,
bidirectional=True,
flash_attn=flash_attn,
is_decoder=False,
dropout=dropout,
)
# Add final conv layer
self.n_predict_codebooks = n_codebooks - n_conditioning_codebooks
self.classifier = SequentialWithFiLM(
WNConv1d(
embedding_dim,
vocab_size * self.n_predict_codebooks,
kernel_size=1,
padding="same",
# groups=self.n_predict_codebooks,
),
)
def forward(self, x, return_activations: bool = False):
x = self.embedding(x)
x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
x = rearrange(x, "b d n -> b n d")
out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
if return_activations:
out, activations = out
out = rearrange(out, "b n d -> b d n")
out = self.classifier(out, None) # no cond here!
out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
if return_activations:
return out, activations
else:
return out
def r_embed(self, r, max_positions=10000):
if self.r_cond_dim > 0:
dtype = r.dtype
r = _gamma(r) * max_positions
half_dim = self.r_cond_dim // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
emb = r[:, None] * emb[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
if self.r_cond_dim % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), mode="constant")
return emb.to(dtype)
else:
return r
@torch.no_grad()
def decode(self, z, codec):
"""
        Convert a sequence of tokens `z` into an audio signal.
"""
assert z.ndim == 3
# remove mask token
z = z.masked_fill(z == self.mask_token, 0)
signal = at.AudioSignal(
codec.decode(
codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
)["audio"],
codec.sample_rate,
)
# find where the mask token is and replace it with silence in the audio
for tstep in range(z.shape[-1]):
if torch.all(z[:, :, tstep] == self.mask_token):
sample_idx_0 = tstep * codec.hop_length
sample_idx_1 = sample_idx_0 + codec.hop_length
signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
return signal
@torch.inference_mode()
def generate(
self,
codec,
time_steps: int = 300,
_sampling_steps: int = 12,
start_tokens: Optional[torch.Tensor] = None,
temperature: float = 1.0,
mask: Optional[torch.Tensor] = None,
mask_temperature: float = 10.5,
        typical_filtering: bool = True,
        typical_mass: float = 0.15,
        typical_min_tokens: int = 64,
        top_p: Optional[float] = None,
        seed: Optional[int] = None,
        sample_cutoff: float = 1.0,
        return_signal: bool = True,
        debug: bool = False,
        causal_weight: float = 0.0,
        cfg_guidance: Optional[float] = None,
):
if seed is not None:
at.util.seed(seed)
sampling_steps = _sampling_steps
logging.debug(f"beginning generation with {sampling_steps} steps")
#####################
# resolve initial z #
#####################
        z = start_tokens

        if z is None:
            z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
                self.device
            )

        nb = z.shape[0]
#################
# resolve mask #
#################
if mask is None:
mask = torch.ones_like(z).to(self.device).int()
mask[:, : self.n_conditioning_codebooks, :] = 0.0
if mask.ndim == 2:
mask = mask[:, None, :].repeat(1, z.shape[1], 1)
# init_mask = mask.clone()
        ##########
        # set up #
        ##########
# apply the mask to z
z_masked = z.masked_fill(mask.bool(), self.mask_token)
# logging.debug(f"z_masked: {z_masked}")
# how many mask tokens to begin with?
num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
# how many codebooks are we inferring vs conditioning on?
n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
        if cfg_guidance is not None:
            # for classifier-free guidance, run a fully-masked (unconditional) copy
            # of the batch alongside the conditional batch
            z_uncond = torch.full_like(z, self.mask_token)
            z_masked = torch.cat((z_masked, z_uncond), dim=0)
            z = torch.cat((z, z_uncond), dim=0)
            mask = torch.cat((mask, torch.full_like(mask, 1)), dim=0)
#################
# begin sampling #
#################
for i in range(sampling_steps):
# our current schedule step
r = scalar_to_batch_tensor(
(i + 1) / sampling_steps,
z.shape[0]
).to(z.device)
# get latents
latents = self.embedding.from_codes(z_masked, codec)
# infer from latents
# NOTE: this collapses the codebook dimension into the sequence dimension
logits = self.forward(latents) # b, prob, seq
            if cfg_guidance is not None:
                logits_cond, logits_uncond = logits[:nb], logits[nb:]
                # classifier-free guidance: steer the conditional logits away from the
                # unconditional ones; the unconditional half keeps its raw logits so it
                # remains a purely unconditional stream across sampling steps
                guided = logits_uncond + cfg_guidance * (logits_cond - logits_uncond)
                logits = torch.cat((guided, logits_uncond), dim=0)
logits = logits.permute(0, 2, 1) # b, seq, prob
b = logits.shape[0]
sampled_z, selected_probs = sample_from_logits(
logits, sample=(
(i / sampling_steps) <= sample_cutoff
),
temperature=temperature,
typical_filtering=typical_filtering, typical_mass=typical_mass,
typical_min_tokens=typical_min_tokens,
top_k=None, top_p=top_p, return_probs=True,
)
# flatten z_masked and mask, so we can deal with the sampling logic
# we'll unflatten them at the end of the loop for the next forward pass
# remove conditioning codebooks, we'll add them back at the end
z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
mask = (z_masked == self.mask_token).int()
# update the mask, remove conditioning codebooks from the mask
# add z back into sampled z where the mask was false
sampled_z = torch.where(
mask.bool(), sampled_z, z_masked
)
# ignore any tokens that weren't masked
selected_probs = torch.where(
mask.bool(), selected_probs, torch.inf
)
# get the num tokens to mask, according to the schedule
num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
logging.debug(f"num to mask: {num_to_mask}")
if i != (sampling_steps - 1):
num_to_mask = torch.maximum(
torch.tensor(1),
torch.minimum(
mask.sum(dim=-1, keepdim=True) - 1,
num_to_mask
)
)
# get our new mask
mask = mask_by_random_topk(
num_to_mask, selected_probs, mask_temperature * (1-r)
)
# update the mask
z_masked = torch.where(
mask.bool(), self.mask_token, sampled_z
)
z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
mask = codebook_unflatten(mask, n_infer_codebooks)
# add conditioning codebooks back to z_masked
z_masked = torch.cat(
(z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
)
# add conditioning codebooks back to sampled_z
sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
sampled_z = torch.cat(
(z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
)
if cfg_guidance is not None:
sampled_z = sampled_z[:nb]
if return_signal:
return self.decode(sampled_z, codec)
else:
return sampled_z
def sample_from_logits(
logits,
sample: bool = True,
temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
typical_filtering: bool = False,
typical_mass: float = 0.2,
typical_min_tokens: int = 1,
return_probs: bool = False
):
    """Convenience function to sample from a categorical distribution given
    unnormalized logits.

    Parameters
    ----------
    logits : Tensor[..., vocab_size]
        Unnormalized logits
    sample : bool, optional
        Whether to perform multinomial sampling, by default True
    temperature : float, optional
        Scaling parameter for multinomial sampling, by default 1.0
    top_k : int, optional
        Restricts sampling to the `top_k` most probable values, by default None
    top_p : float, optional
        Restricts sampling to values whose cumulative probability is at most
        `top_p` (nucleus sampling), by default None
    typical_filtering : bool, optional
        Whether to apply typical filtering to the logits, by default False
    typical_mass : float, optional
        Target probability mass for typical filtering, by default 0.2
    typical_min_tokens : int, optional
        Minimum number of tokens kept by typical filtering, by default 1
    return_probs : bool, optional
        Whether to also return the probability of each sampled token, by default False

    Returns
    -------
    Tensor[...]
        Sampled tokens (and their probabilities if `return_probs` is True)
    """
shp = logits.shape[:-1]
    if typical_filtering:
        logits = typical_filter(
            logits,
            typical_mass=typical_mass,
            typical_min_tokens=typical_min_tokens,
        )
# Apply top_k sampling
if top_k is not None:
v, _ = logits.topk(top_k)
logits[logits < v[..., [-1]]] = -float("inf")
# Apply top_p (nucleus) sampling
if top_p is not None and top_p < 1.0:
v, sorted_indices = logits.sort(descending=True)
cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
# Right shift indices_to_remove to keep 1st token over threshold
sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
..., :-1
]
# Compute indices_to_remove in unsorted array
indices_to_remove = sorted_indices_to_remove.scatter(
-1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = -float("inf")
# Perform multinomial sampling after normalizing logits
probs = (
F.softmax(logits / temperature, dim=-1)
if temperature > 0
else logits.softmax(dim=-1)
)
token = (
probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
if sample
else logits.argmax(-1)
)
if return_probs:
token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
return token, token_probs
else:
return token
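
# A short usage sketch for `sample_from_logits`: greedy decoding versus temperature sampling
# with nucleus (top_p) filtering over a (batch x time x vocab) logit tensor. The `.clone()`
# calls are there because top_k / top_p filtering edits the logits in place. The helper name
# `_demo_sample_from_logits` is illustrative and is not used elsewhere.
def _demo_sample_from_logits():
    logits = torch.randn(2, 10, 1024)  # B x T x vocab_size
    greedy = sample_from_logits(logits.clone(), sample=False)
    tokens, probs = sample_from_logits(
        logits.clone(), sample=True, temperature=0.8, top_p=0.9, return_probs=True
    )
    assert greedy.shape == tokens.shape == (2, 10)
    assert probs.shape == (2, 10)
    return tokens, probs
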
def mask_by_random_topk(
    num_to_mask: torch.Tensor,
    probs: torch.Tensor,
    temperature: torch.Tensor,
):
    """
    Args:
        num_to_mask (torch.Tensor): number of tokens to mask per batch item, shape (batch, 1)
        probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
        temperature (torch.Tensor): mask temperature per batch item, shape (batch,)
    """
    logging.debug("masking by random topk")
logging.debug(f"num to mask: {num_to_mask}")
logging.debug(f"probs shape: {probs.shape}")
logging.debug(f"temperature: {temperature}")
logging.debug("")
noise = gumbel_noise_like(probs)
temperature = temperature.unsqueeze(-1)
confidence = torch.log(probs) + temperature * noise
logging.debug(f"confidence shape: {confidence.shape}")
sorted_confidence, sorted_idx = confidence.sort(dim=-1)
logging.debug(f"sorted confidence shape: {sorted_confidence.shape}")
logging.debug(f"sorted idx shape: {sorted_idx.shape}")
# get the cut off threshold, given the mask length
    cut_off = torch.take_along_dim(
        sorted_confidence, num_to_mask, dim=-1
    )
logging.debug(f"cut off shape: {cut_off.shape}")
# mask out the tokens
mask = confidence < cut_off
logging.debug(f"mask shape: {mask.shape}")
return mask
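
# An illustrative sketch of `mask_by_random_topk`: given per-token confidences, it re-masks
# the `num_to_mask` least-confident positions of each batch item, with Gumbel noise scaled by
# `temperature` adding exploration. The helper name `_demo_mask_by_random_topk` is illustrative.
def _demo_mask_by_random_topk():
    probs = torch.rand(2, 12)               # confidence of each sampled token
    num_to_mask = torch.tensor([[4], [4]])  # re-mask 4 tokens per batch item
    temperature = torch.tensor([0.5, 0.5])  # per-item mask temperature, shape (batch,)
    mask = mask_by_random_topk(num_to_mask, probs, temperature)
    assert mask.shape == probs.shape
    assert (mask.sum(dim=-1) == 4).all()    # the 4 least-confident positions are re-masked
    return mask
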
def typical_filter(
    logits,
    typical_mass: float = 0.95,
    typical_min_tokens: int = 1,
):
    nb, nt, _ = logits.shape
    x_flat = rearrange(logits, "b t l -> (b t) l")
x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
x_flat_norm_p = torch.exp(x_flat_norm)
entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
x_flat_cumsum = (
x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
)
last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
1, last_ind.view(-1, 1)
)
if typical_min_tokens > 1:
sorted_indices_to_remove[..., :typical_min_tokens] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
1, x_flat_indices, sorted_indices_to_remove
)
x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
return logits
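
# A small sketch of `typical_filter`, a locally-typical-style filter: tokens whose information
# content is far from the distribution's entropy are pushed to -inf, keeping roughly
# `typical_mass` of the probability and at least `typical_min_tokens` candidates per position.
# The helper name `_demo_typical_filter` is illustrative and is not used elsewhere.
def _demo_typical_filter():
    logits = torch.randn(2, 10, 1024)  # B x T x vocab_size
    filtered = typical_filter(logits.clone(), typical_mass=0.5, typical_min_tokens=8)
    assert filtered.shape == logits.shape
    kept = torch.isfinite(filtered).sum(dim=-1)  # surviving candidates per position
    assert (kept >= 8).all()
    return filtered
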
if __name__ == "__main__":
    import argbind

    from .layers import num_params

    VampNet = argbind.bind(VampNet)
@argbind.bind(without_prefix=True)
def try_model(device: str = "cuda", batch_size: int = 2, seq_len_s: float = 10.0):
seq_len = int(32000 / 512 * seq_len_s)
model = VampNet().to(device)
z = torch.randint(
0, model.vocab_size, size=(batch_size, model.n_codebooks, seq_len)
).to(device)
r = torch.zeros(batch_size).to(device)
z_mask_latent = torch.rand(
batch_size, model.latent_dim * model.n_codebooks, seq_len
).to(device)
z_hat = model(z_mask_latent)
pred = z_hat.argmax(dim=1)
pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
logging.debug(f"model has {num_params(model)/1e6:<.3f}M parameters")
logging.debug(f"prediction has shape {pred.shape}")
args = argbind.parse_args()
with argbind.scope(args):
try_model()