Spaces:
Sleeping
Sleeping
Hugo Flores
committed on
Commit
·
04c5b94
1
Parent(s):
534a89c
remove wavenet, readability
Browse files
- vampnet/modules/layers.py +18 -0
- vampnet/modules/transformer.py +1 -1
- vampnet/modules/wavenet.py +0 -90
vampnet/modules/layers.py
CHANGED
@@ -8,6 +8,24 @@ import torch.nn.functional as F
|
|
8 |
from einops import rearrange
|
9 |
from torch.nn.utils import weight_norm
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
def num_params(model):
|
13 |
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
|
8 |
from einops import rearrange
|
9 |
from torch.nn.utils import weight_norm
|
10 |
|
11 |
+
# Scripting this brings model speed up 1.4x
|
12 |
+
@torch.jit.script
|
13 |
+
def snake(x, alpha):
|
14 |
+
shape = x.shape
|
15 |
+
x = x.reshape(shape[0], shape[1], -1)
|
16 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
17 |
+
x = x.reshape(shape)
|
18 |
+
return x
|
19 |
+
|
20 |
+
|
21 |
+
class Snake1d(nn.Module):
|
22 |
+
def __init__(self, channels):
|
23 |
+
super().__init__()
|
24 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
return snake(x, self.alpha)
|
28 |
+
|
29 |
|
30 |
def num_params(model):
|
31 |
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
vampnet/modules/transformer.py
CHANGED
@@ -377,7 +377,7 @@ class TransformerStack(nn.Module):
|
|
377 |
n_heads,
|
378 |
bidirectional,
|
379 |
is_decoder,
|
380 |
-
has_relative_attention_bias=(i == 0),
|
381 |
flash_attn=flash_attn,
|
382 |
dropout=dropout,
|
383 |
)
|
|
|
377 |
n_heads,
|
378 |
bidirectional,
|
379 |
is_decoder,
|
380 |
+
has_relative_attention_bias=True if (i == 0) else False,
|
381 |
flash_attn=flash_attn,
|
382 |
dropout=dropout,
|
383 |
)
|
vampnet/modules/wavenet.py
DELETED
@@ -1,90 +0,0 @@
|
|
1 |
-
import torch.nn as nn
|
2 |
-
from einops import rearrange
|
3 |
-
|
4 |
-
from voicegpt.nn import WaveNet
|
5 |
-
|
6 |
-
class AutoregMLP(nn.Module):
|
7 |
-
"""Implements an autoregressive ConvNet decoder
|
8 |
-
Refer to SampleRNN (https://arxiv.org/abs/1612.07837) for motivation
|
9 |
-
"""
|
10 |
-
|
11 |
-
def __init__(
|
12 |
-
self,
|
13 |
-
vocab_size: int,
|
14 |
-
d_model: int,
|
15 |
-
n_layers: int,
|
16 |
-
n_fine_tokens: int = 6,
|
17 |
-
n_tokens: int = 9,
|
18 |
-
dropout: float = 0.1,
|
19 |
-
activation: str = "gelu",
|
20 |
-
causal: bool = True,
|
21 |
-
):
|
22 |
-
super().__init__()
|
23 |
-
self.n_fine = n_fine_tokens
|
24 |
-
self.n_layers = n_layers
|
25 |
-
self.upsampler = nn.Linear(d_model, d_model * n_fine_tokens)
|
26 |
-
|
27 |
-
self.wavenet = WaveNet(
|
28 |
-
d_model,
|
29 |
-
d_model,
|
30 |
-
d_model,
|
31 |
-
n_layers,
|
32 |
-
n_fine_tokens,
|
33 |
-
dropout=dropout,
|
34 |
-
activation=activation,
|
35 |
-
causal=causal,
|
36 |
-
)
|
37 |
-
self.ff_output = nn.Linear(d_model, vocab_size * n_tokens, bias=False)
|
38 |
-
|
39 |
-
def time_upsample(self, h_t_coarse):
|
40 |
-
"""Upsamples the conditioning hidden states to match the time resolution
|
41 |
-
of output tokens
|
42 |
-
Parameters
|
43 |
-
----------
|
44 |
-
h_t_coarse : Tensor[B x T_coarse x D]
|
45 |
-
Conditioning hidden states in coarse time-scale
|
46 |
-
Returns
|
47 |
-
-------
|
48 |
-
Tensor[B x T_fine x D]
|
49 |
-
Conditioning hidden states in fine time-scale
|
50 |
-
"""
|
51 |
-
# Upsample the transformer hidden states to fine scale
|
52 |
-
h_t_fine = rearrange(
|
53 |
-
self.upsampler(h_t_coarse), "b t (n d) -> b (t n) d", n=self.n_fine
|
54 |
-
)
|
55 |
-
return h_t_fine
|
56 |
-
|
57 |
-
def decode_logits(self, x_tm1, h_t_fine):
|
58 |
-
"""Decodes output logits conditioned on previous output
|
59 |
-
tokens (upto timestep t-1) and conditioning hidden states
|
60 |
-
using an autoregressive WaveNet
|
61 |
-
Parameters
|
62 |
-
----------
|
63 |
-
x_tm1 : Tensor[B x T x D]
|
64 |
-
h_t_fine : Tensor[B x T x D]
|
65 |
-
Returns
|
66 |
-
-------
|
67 |
-
Tensor[B x T x vocab_size]
|
68 |
-
Predicted logits
|
69 |
-
"""
|
70 |
-
|
71 |
-
# Compute wavenet layers and predict logits
|
72 |
-
o_t = self.wavenet(x_tm1, h_t_fine)
|
73 |
-
return self.ff_output(o_t)
|
74 |
-
|
75 |
-
def forward(self, x_tm1, h_t_coarse):
|
76 |
-
"""Computes autoregressive conditional probability distribution
|
77 |
-
using a WaveNet decoder
|
78 |
-
Parameters
|
79 |
-
----------
|
80 |
-
x_tm1 : Tensor[B x T_fine x D]
|
81 |
-
Embeddings of tokens at fine time-scale
|
82 |
-
h_t_coarse : Tensor[B x T_coarse x D]
|
83 |
-
Hidden states at coarse time scale
|
84 |
-
Returns
|
85 |
-
-------
|
86 |
-
Tensor[B x T_fine x vocab_size]
|
87 |
-
Predicted logits at fine time-scale
|
88 |
-
"""
|
89 |
-
h_t_fine = self.time_upsample(h_t_coarse)
|
90 |
-
return self.decode_logits(x_tm1, h_t_fine)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|