Upload synthesizer/models/tacotron.py with huggingface_hub
synthesizer/models/tacotron.py
ADDED
@@ -0,0 +1,519 @@
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from typing import Union


class HighwayNetwork(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.W1 = nn.Linear(size, size)
        self.W2 = nn.Linear(size, size)
        self.W1.bias.data.fill_(0.)

    def forward(self, x):
        x1 = self.W1(x)
        x2 = self.W2(x)
        g = torch.sigmoid(x2)
        y = g * F.relu(x1) + (1. - g) * x
        return y
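The highway layer blends a transformed path with an identity path: the gate g = sigmoid(W2 x) weights relu(W1 x) against the untouched input, so the output always keeps the input's shape. A minimal sketch of that behaviour, with an assumed feature size of 128 (purely illustrative, not a value taken from this file):

# Illustrative only -- the size and batch below are assumptions, not repo settings.
hw = HighwayNetwork(size=128)
x = torch.randn(4, 128)      # (batch, features)
y = hw(x)
print(y.shape)               # torch.Size([4, 128]) -- gating never changes the shape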
class Encoder(nn.Module):
    def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
        super().__init__()
        prenet_dims = (encoder_dims, encoder_dims)
        cbhg_channels = encoder_dims
        self.embedding = nn.Embedding(num_chars, embed_dims)
        self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
                              dropout=dropout)
        self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
                         proj_channels=[cbhg_channels, cbhg_channels],
                         num_highways=num_highways)

    def forward(self, x, speaker_embedding=None):
        x = self.embedding(x)
        x = self.pre_net(x)
        x.transpose_(1, 2)
        x = self.cbhg(x)
        if speaker_embedding is not None:
            x = self.add_speaker_embedding(x, speaker_embedding)
        return x

    def add_speaker_embedding(self, x, speaker_embedding):
        # SV2TTS
        # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
        # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
        # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
        # This concats the speaker embedding for each char in the encoder output

        # Save the dimensions as human-readable names
        batch_size = x.size()[0]
        num_chars = x.size()[1]

        if speaker_embedding.dim() == 1:
            idx = 0
        else:
            idx = 1

        # Start by making a copy of each speaker embedding to match the input text length
        # The output of this has size (batch_size, num_chars * tts_embed_dims)
        speaker_embedding_size = speaker_embedding.size()[idx]
        e = speaker_embedding.repeat_interleave(num_chars, dim=idx)

        # Reshape it and transpose
        e = e.reshape(batch_size, speaker_embedding_size, num_chars)
        e = e.transpose(1, 2)

        # Concatenate the tiled speaker embedding with the encoder output
        x = torch.cat((x, e), 2)
        return x
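add_speaker_embedding tiles one utterance-level speaker vector across every character position and concatenates it to the encoder output along the feature axis. A small sketch of the tiling arithmetic with assumed sizes (batch 2, 5 characters, 3-dimensional embedding):

# Illustrative only -- tiny sizes chosen so the tiling is easy to inspect.
speaker_embedding = torch.arange(6, dtype=torch.float32).view(2, 3)  # (batch, emb)
num_chars = 5
e = speaker_embedding.repeat_interleave(num_chars, dim=1)            # (2, 15)
e = e.reshape(2, 3, num_chars).transpose(1, 2)                       # (2, 5, 3)
print(e[0])  # each of the 5 rows equals speaker_embedding[0], ready to cat with x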
class BatchNormConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel, relu=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
        self.bnorm = nn.BatchNorm1d(out_channels)
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x) if self.relu is True else x
        return self.bnorm(x)

class CBHG(nn.Module):
    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
        super().__init__()

        # List of all rnns to call `flatten_parameters()` on
        self._to_flatten = []

        self.bank_kernels = [i for i in range(1, K + 1)]
        self.conv1d_bank = nn.ModuleList()
        for k in self.bank_kernels:
            conv = BatchNormConv(in_channels, channels, k)
            self.conv1d_bank.append(conv)

        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
        self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)

        # Fix the highway input if necessary
        if proj_channels[-1] != channels:
            self.highway_mismatch = True
            self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
        else:
            self.highway_mismatch = False

        self.highways = nn.ModuleList()
        for i in range(num_highways):
            hn = HighwayNetwork(channels)
            self.highways.append(hn)

        self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
        self._to_flatten.append(self.rnn)

        # Avoid fragmentation of RNN parameters and associated warning
        self._flatten_parameters()

    def forward(self, x):
        # Although we `_flatten_parameters()` on init, when using DataParallel
        # the model gets replicated, making it no longer guaranteed that the
        # weights are contiguous in GPU memory. Hence, we must call it again
        self._flatten_parameters()

        # Save these for later
        residual = x
        seq_len = x.size(-1)
        conv_bank = []

        # Convolution Bank
        for conv in self.conv1d_bank:
            c = conv(x)  # Convolution
            conv_bank.append(c[:, :, :seq_len])

        # Stack along the channel axis
        conv_bank = torch.cat(conv_bank, dim=1)

        # dump the last padding to fit residual
        x = self.maxpool(conv_bank)[:, :, :seq_len]

        # Conv1d projections
        x = self.conv_project1(x)
        x = self.conv_project2(x)

        # Residual Connect
        x = x + residual

        # Through the highways
        x = x.transpose(1, 2)
        if self.highway_mismatch is True:
            x = self.pre_highway(x)
        for h in self.highways: x = h(x)

        # And then the RNN
        x, _ = self.rnn(x)
        return x

    def _flatten_parameters(self):
        """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
        to improve efficiency and avoid PyTorch yelling at us."""
        [m.flatten_parameters() for m in self._to_flatten]

class PreNet(nn.Module):
    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(in_dims, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        self.p = dropout

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = F.dropout(x, self.p, training=True)
        x = self.fc2(x)
        x = F.relu(x)
        x = F.dropout(x, self.p, training=True)
        return x

class Attention(nn.Module):
    def __init__(self, attn_dims):
        super().__init__()
        self.W = nn.Linear(attn_dims, attn_dims, bias=False)
        self.v = nn.Linear(attn_dims, 1, bias=False)

    def forward(self, encoder_seq_proj, query, t):

        # print(encoder_seq_proj.shape)
        # Transform the query vector
        query_proj = self.W(query).unsqueeze(1)

        # Compute the scores
        u = self.v(torch.tanh(encoder_seq_proj + query_proj))
        scores = F.softmax(u, dim=1)

        return scores.transpose(1, 2)

class LSA(nn.Module):
    def __init__(self, attn_dim, kernel_size=31, filters=32):
        super().__init__()
        self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
        self.L = nn.Linear(filters, attn_dim, bias=False)
        self.W = nn.Linear(attn_dim, attn_dim, bias=True)  # Include the attention bias in this term
        self.v = nn.Linear(attn_dim, 1, bias=False)
        self.cumulative = None
        self.attention = None

    def init_attention(self, encoder_seq_proj):
        device = next(self.parameters()).device  # use same device as parameters
        b, t, c = encoder_seq_proj.size()
        self.cumulative = torch.zeros(b, t, device=device)
        self.attention = torch.zeros(b, t, device=device)

    def forward(self, encoder_seq_proj, query, t, chars):

        if t == 0: self.init_attention(encoder_seq_proj)

        processed_query = self.W(query).unsqueeze(1)

        location = self.cumulative.unsqueeze(1)
        processed_loc = self.L(self.conv(location).transpose(1, 2))

        u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
        u = u.squeeze(-1)

        # Mask zero padding chars
        u = u * (chars != 0).float()

        # Smooth Attention
        # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
        scores = F.softmax(u, dim=1)
        self.attention = scores
        self.cumulative = self.cumulative + self.attention

        return scores.unsqueeze(-1).transpose(1, 2)
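LSA is location-sensitive attention: the cumulative attention weights are convolved and added to the usual content term before the softmax, which discourages the decoder from revisiting characters it has already attended to. A hedged usage sketch with assumed dimensions (attn_dim 128, batch 2, 50 characters); the tensors are random and only the shapes matter:

# Illustrative only -- random inputs and assumed sizes, not repo hyperparameters.
attn = LSA(attn_dim=128)
encoder_seq_proj = torch.randn(2, 50, 128)     # projected encoder output
query = torch.randn(2, 128)                    # attention-RNN hidden state
chars = torch.randint(1, 60, (2, 50))          # non-zero ids, so nothing is masked
scores = attn(encoder_seq_proj, query, t=0, chars=chars)
print(scores.shape)                            # torch.Size([2, 1, 50]), one weight per char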
class Decoder(nn.Module):
    # Class variable because its value doesn't change between instances,
    # yet ought to be scoped by class because it's a property of a Decoder
    max_r = 20
    def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
                 dropout, speaker_embedding_size):
        super().__init__()
        self.register_buffer("r", torch.tensor(1, dtype=torch.int))
        self.n_mels = n_mels
        prenet_dims = (decoder_dims * 2, decoder_dims * 2)
        self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
                             dropout=dropout)
        self.attn_net = LSA(decoder_dims)
        self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
        self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
        self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
        self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
        self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
        self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)

    def zoneout(self, prev, current, p=0.1):
        device = next(self.parameters()).device  # Use same device as parameters
        mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
        return prev * mask + current * (1 - mask)

    def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
                hidden_states, cell_states, context_vec, t, chars):

        # Need this for reshaping mels
        batch_size = encoder_seq.size(0)

        # Unpack the hidden and cell states
        attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
        rnn1_cell, rnn2_cell = cell_states

        # PreNet for the Attention RNN
        prenet_out = self.prenet(prenet_in)

        # Compute the Attention RNN hidden state
        attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
        attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)

        # Compute the attention scores
        scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)

        # Dot product to create the context vector
        context_vec = scores @ encoder_seq
        context_vec = context_vec.squeeze(1)

        # Concat Attention RNN output w. Context Vector & project
        x = torch.cat([context_vec, attn_hidden], dim=1)
        x = self.rnn_input(x)

        # Compute first Residual RNN
        rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
        if self.training:
            rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
        else:
            rnn1_hidden = rnn1_hidden_next
        x = x + rnn1_hidden

        # Compute second Residual RNN
        rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
        if self.training:
            rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
        else:
            rnn2_hidden = rnn2_hidden_next
        x = x + rnn2_hidden

        # Project Mels
        mels = self.mel_proj(x)
        mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
        cell_states = (rnn1_cell, rnn2_cell)

        # Stop token prediction
        s = torch.cat((x, context_vec), dim=1)
        s = self.stop_proj(s)
        stop_tokens = torch.sigmoid(s)

        return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
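The decoder predicts r mel frames per step (the reduction factor), up to max_r = 20: mel_proj emits n_mels * max_r values, which are reshaped and then sliced down to the current r. A small sketch of that slicing with assumed sizes:

# Illustrative only -- reproduces the view-and-slice at the end of Decoder.forward.
batch_size, n_mels, max_r, r = 2, 80, 20, 2    # assumed example values
mels = torch.randn(batch_size, n_mels * max_r) # what mel_proj would produce
mels = mels.view(batch_size, n_mels, max_r)[:, :, :r]
print(mels.shape)                              # torch.Size([2, 80, 2]) -- r frames per step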
class Tacotron(nn.Module):
    def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
                 fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
                 dropout, stop_threshold, speaker_embedding_size):
        super().__init__()
        self.n_mels = n_mels
        self.lstm_dims = lstm_dims
        self.encoder_dims = encoder_dims
        self.decoder_dims = decoder_dims
        self.speaker_embedding_size = speaker_embedding_size
        self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
                               encoder_K, num_highways, dropout)
        self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
        self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
                               dropout, speaker_embedding_size)
        self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
                            [postnet_dims, fft_bins], num_highways)
        self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)

        self.init_model()
        self.num_params()

        self.register_buffer("step", torch.zeros(1, dtype=torch.long))
        self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))

    @property
    def r(self):
        return self.decoder.r.item()

    @r.setter
    def r(self, value):
        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
    def forward(self, x, m, speaker_embedding):
        device = next(self.parameters()).device  # use same device as parameters

        self.step += 1
        batch_size, _, steps = m.size()

        # Initialise all hidden states and pack into tuple
        attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
        rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)

        # Initialise all lstm cell states and pack into tuple
        rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        cell_states = (rnn1_cell, rnn2_cell)

        # <GO> Frame for start of decoder loop
        go_frame = torch.zeros(batch_size, self.n_mels, device=device)

        # Need an initial context vector
        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)

        # SV2TTS: Run the encoder with the speaker embedding
        # The projection avoids unnecessary matmuls in the decoder loop
        encoder_seq = self.encoder(x, speaker_embedding)
        encoder_seq_proj = self.encoder_proj(encoder_seq)

        # Need a couple of lists for outputs
        mel_outputs, attn_scores, stop_outputs = [], [], []

        # Run the decoder loop
        for t in range(0, steps, self.r):
            prenet_in = m[:, :, t - 1] if t > 0 else go_frame
            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
                             hidden_states, cell_states, context_vec, t, x)
            mel_outputs.append(mel_frames)
            attn_scores.append(scores)
            stop_outputs.extend([stop_tokens] * self.r)

        # Concat the mel outputs into sequence
        mel_outputs = torch.cat(mel_outputs, dim=2)

        # Post-Process for Linear Spectrograms
        postnet_out = self.postnet(mel_outputs)
        linear = self.post_proj(postnet_out)
        linear = linear.transpose(1, 2)

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)
        # attn_scores = attn_scores.cpu().data.numpy()
        stop_outputs = torch.cat(stop_outputs, 1)

        return mel_outputs, linear, attn_scores, stop_outputs
    def generate(self, x, speaker_embedding=None, steps=2000):
        self.eval()
        device = next(self.parameters()).device  # use same device as parameters

        batch_size, _ = x.size()

        # Need to initialise all hidden states and pack into tuple for tidiness
        attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
        rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)

        # Need to initialise all lstm cell states and pack into tuple for tidiness
        rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        cell_states = (rnn1_cell, rnn2_cell)

        # Need a <GO> Frame for start of decoder loop
        go_frame = torch.zeros(batch_size, self.n_mels, device=device)

        # Need an initial context vector
        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)

        # SV2TTS: Run the encoder with the speaker embedding
        # The projection avoids unnecessary matmuls in the decoder loop
        encoder_seq = self.encoder(x, speaker_embedding)
        encoder_seq_proj = self.encoder_proj(encoder_seq)

        # Need a couple of lists for outputs
        mel_outputs, attn_scores, stop_outputs = [], [], []

        # Run the decoder loop
        for t in range(0, steps, self.r):
            prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
                             hidden_states, cell_states, context_vec, t, x)
            mel_outputs.append(mel_frames)
            attn_scores.append(scores)
            stop_outputs.extend([stop_tokens] * self.r)
            # Stop the loop when all stop tokens in batch exceed threshold
            if (stop_tokens > 0.5).all() and t > 10: break

        # Concat the mel outputs into sequence
        mel_outputs = torch.cat(mel_outputs, dim=2)

        # Post-Process for Linear Spectrograms
        postnet_out = self.postnet(mel_outputs)
        linear = self.post_proj(postnet_out)
        linear = linear.transpose(1, 2)

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)
        stop_outputs = torch.cat(stop_outputs, 1)

        self.train()

        return mel_outputs, linear, attn_scores
    def init_model(self):
        for p in self.parameters():
            if p.dim() > 1: nn.init.xavier_uniform_(p)

    def get_step(self):
        return self.step.data.item()

    def reset_step(self):
        # assignment to parameters or buffers is overloaded, updates internal dict entry
        self.step = self.step.data.new_tensor(1)

    def log(self, path, msg):
        with open(path, "a") as f:
            print(msg, file=f)

    def load(self, path, optimizer=None):
        # Use device of model params as location for loaded state
        device = next(self.parameters()).device
        checkpoint = torch.load(str(path), map_location=device)
        self.load_state_dict(checkpoint["model_state"])

        if "optimizer_state" in checkpoint and optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_state"])

    def save(self, path, optimizer=None):
        if optimizer is not None:
            torch.save({
                "model_state": self.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, str(path))
        else:
            torch.save({
                "model_state": self.state_dict(),
            }, str(path))

    def num_params(self, print_out=True):
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        if print_out:
            print("Trainable Parameters: %.3fM" % parameters)
        return parameters
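A hedged end-to-end sketch of driving the model for inference. Every hyperparameter value below is an assumption chosen only so the tensor sizes line up (they are not read from this repo's configuration), and the weights are untrained, so the output is noise; the point is the call signature and the output shapes.

# Illustrative only -- all hyperparameter values are assumptions, not repo settings.
model = Tacotron(embed_dims=512, num_chars=66, encoder_dims=256, decoder_dims=128,
                 n_mels=80, fft_bins=80, postnet_dims=512, encoder_K=5,
                 lstm_dims=1024, postnet_K=5, num_highways=4, dropout=0.5,
                 stop_threshold=-3.4, speaker_embedding_size=256)
chars = torch.randint(1, 66, (1, 30))          # one "sentence" of 30 character ids
speaker_embedding = torch.randn(1, 256)        # one utterance-level speaker vector
mels, linear, attn = model.generate(chars, speaker_embedding, steps=200)
print(mels.shape, linear.shape, attn.shape)    # (1, 80, T), (1, 80, T), (1, T, 30), T = frames generated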