keithhon commited on
Commit
389c7cc
·
1 Parent(s): 62bcf42

Upload synthesizer/models/tacotron.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. synthesizer/models/tacotron.py +519 -0
synthesizer/models/tacotron.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from pathlib import Path
7
+ from typing import Union
8
+
9
+
10
+ class HighwayNetwork(nn.Module):
11
+ def __init__(self, size):
12
+ super().__init__()
13
+ self.W1 = nn.Linear(size, size)
14
+ self.W2 = nn.Linear(size, size)
15
+ self.W1.bias.data.fill_(0.)
16
+
17
+ def forward(self, x):
18
+ x1 = self.W1(x)
19
+ x2 = self.W2(x)
20
+ g = torch.sigmoid(x2)
21
+ y = g * F.relu(x1) + (1. - g) * x
22
+ return y
23
+
24
+
25
+ class Encoder(nn.Module):
26
+ def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
27
+ super().__init__()
28
+ prenet_dims = (encoder_dims, encoder_dims)
29
+ cbhg_channels = encoder_dims
30
+ self.embedding = nn.Embedding(num_chars, embed_dims)
31
+ self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
32
+ dropout=dropout)
33
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
34
+ proj_channels=[cbhg_channels, cbhg_channels],
35
+ num_highways=num_highways)
36
+
37
+ def forward(self, x, speaker_embedding=None):
38
+ x = self.embedding(x)
39
+ x = self.pre_net(x)
40
+ x.transpose_(1, 2)
41
+ x = self.cbhg(x)
42
+ if speaker_embedding is not None:
43
+ x = self.add_speaker_embedding(x, speaker_embedding)
44
+ return x
45
+
46
+ def add_speaker_embedding(self, x, speaker_embedding):
47
+ # SV2TTS
48
+ # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
49
+ # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
50
+ # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
51
+ # This concats the speaker embedding for each char in the encoder output
52
+
53
+ # Save the dimensions as human-readable names
54
+ batch_size = x.size()[0]
55
+ num_chars = x.size()[1]
56
+
57
+ if speaker_embedding.dim() == 1:
58
+ idx = 0
59
+ else:
60
+ idx = 1
61
+
62
+ # Start by making a copy of each speaker embedding to match the input text length
63
+ # The output of this has size (batch_size, num_chars * tts_embed_dims)
64
+ speaker_embedding_size = speaker_embedding.size()[idx]
65
+ e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
66
+
67
+ # Reshape it and transpose
68
+ e = e.reshape(batch_size, speaker_embedding_size, num_chars)
69
+ e = e.transpose(1, 2)
70
+
71
+ # Concatenate the tiled speaker embedding with the encoder output
72
+ x = torch.cat((x, e), 2)
73
+ return x
74
+
75
+
76
+ class BatchNormConv(nn.Module):
77
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
78
+ super().__init__()
79
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
80
+ self.bnorm = nn.BatchNorm1d(out_channels)
81
+ self.relu = relu
82
+
83
+ def forward(self, x):
84
+ x = self.conv(x)
85
+ x = F.relu(x) if self.relu is True else x
86
+ return self.bnorm(x)
87
+
88
+
89
+ class CBHG(nn.Module):
90
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
91
+ super().__init__()
92
+
93
+ # List of all rnns to call `flatten_parameters()` on
94
+ self._to_flatten = []
95
+
96
+ self.bank_kernels = [i for i in range(1, K + 1)]
97
+ self.conv1d_bank = nn.ModuleList()
98
+ for k in self.bank_kernels:
99
+ conv = BatchNormConv(in_channels, channels, k)
100
+ self.conv1d_bank.append(conv)
101
+
102
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
103
+
104
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
105
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
106
+
107
+ # Fix the highway input if necessary
108
+ if proj_channels[-1] != channels:
109
+ self.highway_mismatch = True
110
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
111
+ else:
112
+ self.highway_mismatch = False
113
+
114
+ self.highways = nn.ModuleList()
115
+ for i in range(num_highways):
116
+ hn = HighwayNetwork(channels)
117
+ self.highways.append(hn)
118
+
119
+ self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
120
+ self._to_flatten.append(self.rnn)
121
+
122
+ # Avoid fragmentation of RNN parameters and associated warning
123
+ self._flatten_parameters()
124
+
125
+ def forward(self, x):
126
+ # Although we `_flatten_parameters()` on init, when using DataParallel
127
+ # the model gets replicated, making it no longer guaranteed that the
128
+ # weights are contiguous in GPU memory. Hence, we must call it again
129
+ self._flatten_parameters()
130
+
131
+ # Save these for later
132
+ residual = x
133
+ seq_len = x.size(-1)
134
+ conv_bank = []
135
+
136
+ # Convolution Bank
137
+ for conv in self.conv1d_bank:
138
+ c = conv(x) # Convolution
139
+ conv_bank.append(c[:, :, :seq_len])
140
+
141
+ # Stack along the channel axis
142
+ conv_bank = torch.cat(conv_bank, dim=1)
143
+
144
+ # dump the last padding to fit residual
145
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
146
+
147
+ # Conv1d projections
148
+ x = self.conv_project1(x)
149
+ x = self.conv_project2(x)
150
+
151
+ # Residual Connect
152
+ x = x + residual
153
+
154
+ # Through the highways
155
+ x = x.transpose(1, 2)
156
+ if self.highway_mismatch is True:
157
+ x = self.pre_highway(x)
158
+ for h in self.highways: x = h(x)
159
+
160
+ # And then the RNN
161
+ x, _ = self.rnn(x)
162
+ return x
163
+
164
+ def _flatten_parameters(self):
165
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
166
+ to improve efficiency and avoid PyTorch yelling at us."""
167
+ [m.flatten_parameters() for m in self._to_flatten]
168
+
169
+ class PreNet(nn.Module):
170
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
171
+ super().__init__()
172
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
173
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
174
+ self.p = dropout
175
+
176
+ def forward(self, x):
177
+ x = self.fc1(x)
178
+ x = F.relu(x)
179
+ x = F.dropout(x, self.p, training=True)
180
+ x = self.fc2(x)
181
+ x = F.relu(x)
182
+ x = F.dropout(x, self.p, training=True)
183
+ return x
184
+
185
+
186
+ class Attention(nn.Module):
187
+ def __init__(self, attn_dims):
188
+ super().__init__()
189
+ self.W = nn.Linear(attn_dims, attn_dims, bias=False)
190
+ self.v = nn.Linear(attn_dims, 1, bias=False)
191
+
192
+ def forward(self, encoder_seq_proj, query, t):
193
+
194
+ # print(encoder_seq_proj.shape)
195
+ # Transform the query vector
196
+ query_proj = self.W(query).unsqueeze(1)
197
+
198
+ # Compute the scores
199
+ u = self.v(torch.tanh(encoder_seq_proj + query_proj))
200
+ scores = F.softmax(u, dim=1)
201
+
202
+ return scores.transpose(1, 2)
203
+
204
+
205
+ class LSA(nn.Module):
206
+ def __init__(self, attn_dim, kernel_size=31, filters=32):
207
+ super().__init__()
208
+ self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
209
+ self.L = nn.Linear(filters, attn_dim, bias=False)
210
+ self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
211
+ self.v = nn.Linear(attn_dim, 1, bias=False)
212
+ self.cumulative = None
213
+ self.attention = None
214
+
215
+ def init_attention(self, encoder_seq_proj):
216
+ device = next(self.parameters()).device # use same device as parameters
217
+ b, t, c = encoder_seq_proj.size()
218
+ self.cumulative = torch.zeros(b, t, device=device)
219
+ self.attention = torch.zeros(b, t, device=device)
220
+
221
+ def forward(self, encoder_seq_proj, query, t, chars):
222
+
223
+ if t == 0: self.init_attention(encoder_seq_proj)
224
+
225
+ processed_query = self.W(query).unsqueeze(1)
226
+
227
+ location = self.cumulative.unsqueeze(1)
228
+ processed_loc = self.L(self.conv(location).transpose(1, 2))
229
+
230
+ u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
231
+ u = u.squeeze(-1)
232
+
233
+ # Mask zero padding chars
234
+ u = u * (chars != 0).float()
235
+
236
+ # Smooth Attention
237
+ # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
238
+ scores = F.softmax(u, dim=1)
239
+ self.attention = scores
240
+ self.cumulative = self.cumulative + self.attention
241
+
242
+ return scores.unsqueeze(-1).transpose(1, 2)
243
+
244
+
245
+ class Decoder(nn.Module):
246
+ # Class variable because its value doesn't change between classes
247
+ # yet ought to be scoped by class because its a property of a Decoder
248
+ max_r = 20
249
+ def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
250
+ dropout, speaker_embedding_size):
251
+ super().__init__()
252
+ self.register_buffer("r", torch.tensor(1, dtype=torch.int))
253
+ self.n_mels = n_mels
254
+ prenet_dims = (decoder_dims * 2, decoder_dims * 2)
255
+ self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
256
+ dropout=dropout)
257
+ self.attn_net = LSA(decoder_dims)
258
+ self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
259
+ self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
260
+ self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
261
+ self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
262
+ self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
263
+ self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
264
+
265
+ def zoneout(self, prev, current, p=0.1):
266
+ device = next(self.parameters()).device # Use same device as parameters
267
+ mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
268
+ return prev * mask + current * (1 - mask)
269
+
270
+ def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
271
+ hidden_states, cell_states, context_vec, t, chars):
272
+
273
+ # Need this for reshaping mels
274
+ batch_size = encoder_seq.size(0)
275
+
276
+ # Unpack the hidden and cell states
277
+ attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
278
+ rnn1_cell, rnn2_cell = cell_states
279
+
280
+ # PreNet for the Attention RNN
281
+ prenet_out = self.prenet(prenet_in)
282
+
283
+ # Compute the Attention RNN hidden state
284
+ attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
285
+ attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
286
+
287
+ # Compute the attention scores
288
+ scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
289
+
290
+ # Dot product to create the context vector
291
+ context_vec = scores @ encoder_seq
292
+ context_vec = context_vec.squeeze(1)
293
+
294
+ # Concat Attention RNN output w. Context Vector & project
295
+ x = torch.cat([context_vec, attn_hidden], dim=1)
296
+ x = self.rnn_input(x)
297
+
298
+ # Compute first Residual RNN
299
+ rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
300
+ if self.training:
301
+ rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
302
+ else:
303
+ rnn1_hidden = rnn1_hidden_next
304
+ x = x + rnn1_hidden
305
+
306
+ # Compute second Residual RNN
307
+ rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
308
+ if self.training:
309
+ rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
310
+ else:
311
+ rnn2_hidden = rnn2_hidden_next
312
+ x = x + rnn2_hidden
313
+
314
+ # Project Mels
315
+ mels = self.mel_proj(x)
316
+ mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
317
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
318
+ cell_states = (rnn1_cell, rnn2_cell)
319
+
320
+ # Stop token prediction
321
+ s = torch.cat((x, context_vec), dim=1)
322
+ s = self.stop_proj(s)
323
+ stop_tokens = torch.sigmoid(s)
324
+
325
+ return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
326
+
327
+
328
+ class Tacotron(nn.Module):
329
+ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
330
+ fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
331
+ dropout, stop_threshold, speaker_embedding_size):
332
+ super().__init__()
333
+ self.n_mels = n_mels
334
+ self.lstm_dims = lstm_dims
335
+ self.encoder_dims = encoder_dims
336
+ self.decoder_dims = decoder_dims
337
+ self.speaker_embedding_size = speaker_embedding_size
338
+ self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
339
+ encoder_K, num_highways, dropout)
340
+ self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
341
+ self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
342
+ dropout, speaker_embedding_size)
343
+ self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
344
+ [postnet_dims, fft_bins], num_highways)
345
+ self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
346
+
347
+ self.init_model()
348
+ self.num_params()
349
+
350
+ self.register_buffer("step", torch.zeros(1, dtype=torch.long))
351
+ self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
352
+
353
+ @property
354
+ def r(self):
355
+ return self.decoder.r.item()
356
+
357
+ @r.setter
358
+ def r(self, value):
359
+ self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
360
+
361
+ def forward(self, x, m, speaker_embedding):
362
+ device = next(self.parameters()).device # use same device as parameters
363
+
364
+ self.step += 1
365
+ batch_size, _, steps = m.size()
366
+
367
+ # Initialise all hidden states and pack into tuple
368
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
369
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
370
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
371
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
372
+
373
+ # Initialise all lstm cell states and pack into tuple
374
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
375
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
376
+ cell_states = (rnn1_cell, rnn2_cell)
377
+
378
+ # <GO> Frame for start of decoder loop
379
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
380
+
381
+ # Need an initial context vector
382
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
383
+
384
+ # SV2TTS: Run the encoder with the speaker embedding
385
+ # The projection avoids unnecessary matmuls in the decoder loop
386
+ encoder_seq = self.encoder(x, speaker_embedding)
387
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
388
+
389
+ # Need a couple of lists for outputs
390
+ mel_outputs, attn_scores, stop_outputs = [], [], []
391
+
392
+ # Run the decoder loop
393
+ for t in range(0, steps, self.r):
394
+ prenet_in = m[:, :, t - 1] if t > 0 else go_frame
395
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
396
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
397
+ hidden_states, cell_states, context_vec, t, x)
398
+ mel_outputs.append(mel_frames)
399
+ attn_scores.append(scores)
400
+ stop_outputs.extend([stop_tokens] * self.r)
401
+
402
+ # Concat the mel outputs into sequence
403
+ mel_outputs = torch.cat(mel_outputs, dim=2)
404
+
405
+ # Post-Process for Linear Spectrograms
406
+ postnet_out = self.postnet(mel_outputs)
407
+ linear = self.post_proj(postnet_out)
408
+ linear = linear.transpose(1, 2)
409
+
410
+ # For easy visualisation
411
+ attn_scores = torch.cat(attn_scores, 1)
412
+ # attn_scores = attn_scores.cpu().data.numpy()
413
+ stop_outputs = torch.cat(stop_outputs, 1)
414
+
415
+ return mel_outputs, linear, attn_scores, stop_outputs
416
+
417
+ def generate(self, x, speaker_embedding=None, steps=2000):
418
+ self.eval()
419
+ device = next(self.parameters()).device # use same device as parameters
420
+
421
+ batch_size, _ = x.size()
422
+
423
+ # Need to initialise all hidden states and pack into tuple for tidyness
424
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
425
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
426
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
427
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
428
+
429
+ # Need to initialise all lstm cell states and pack into tuple for tidyness
430
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
431
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
432
+ cell_states = (rnn1_cell, rnn2_cell)
433
+
434
+ # Need a <GO> Frame for start of decoder loop
435
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
436
+
437
+ # Need an initial context vector
438
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
439
+
440
+ # SV2TTS: Run the encoder with the speaker embedding
441
+ # The projection avoids unnecessary matmuls in the decoder loop
442
+ encoder_seq = self.encoder(x, speaker_embedding)
443
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
444
+
445
+ # Need a couple of lists for outputs
446
+ mel_outputs, attn_scores, stop_outputs = [], [], []
447
+
448
+ # Run the decoder loop
449
+ for t in range(0, steps, self.r):
450
+ prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
451
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
452
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
453
+ hidden_states, cell_states, context_vec, t, x)
454
+ mel_outputs.append(mel_frames)
455
+ attn_scores.append(scores)
456
+ stop_outputs.extend([stop_tokens] * self.r)
457
+ # Stop the loop when all stop tokens in batch exceed threshold
458
+ if (stop_tokens > 0.5).all() and t > 10: break
459
+
460
+ # Concat the mel outputs into sequence
461
+ mel_outputs = torch.cat(mel_outputs, dim=2)
462
+
463
+ # Post-Process for Linear Spectrograms
464
+ postnet_out = self.postnet(mel_outputs)
465
+ linear = self.post_proj(postnet_out)
466
+
467
+
468
+ linear = linear.transpose(1, 2)
469
+
470
+ # For easy visualisation
471
+ attn_scores = torch.cat(attn_scores, 1)
472
+ stop_outputs = torch.cat(stop_outputs, 1)
473
+
474
+ self.train()
475
+
476
+ return mel_outputs, linear, attn_scores
477
+
478
+ def init_model(self):
479
+ for p in self.parameters():
480
+ if p.dim() > 1: nn.init.xavier_uniform_(p)
481
+
482
+ def get_step(self):
483
+ return self.step.data.item()
484
+
485
+ def reset_step(self):
486
+ # assignment to parameters or buffers is overloaded, updates internal dict entry
487
+ self.step = self.step.data.new_tensor(1)
488
+
489
+ def log(self, path, msg):
490
+ with open(path, "a") as f:
491
+ print(msg, file=f)
492
+
493
+ def load(self, path, optimizer=None):
494
+ # Use device of model params as location for loaded state
495
+ device = next(self.parameters()).device
496
+ checkpoint = torch.load(str(path), map_location=device)
497
+ self.load_state_dict(checkpoint["model_state"])
498
+
499
+ if "optimizer_state" in checkpoint and optimizer is not None:
500
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
501
+
502
+ def save(self, path, optimizer=None):
503
+ if optimizer is not None:
504
+ torch.save({
505
+ "model_state": self.state_dict(),
506
+ "optimizer_state": optimizer.state_dict(),
507
+ }, str(path))
508
+ else:
509
+ torch.save({
510
+ "model_state": self.state_dict(),
511
+ }, str(path))
512
+
513
+
514
+ def num_params(self, print_out=True):
515
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
516
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
517
+ if print_out:
518
+ print("Trainable Parameters: %.3fM" % parameters)
519
+ return parameters