ORI-Muchim committed on
Commit
924e8f7
·
verified ·
1 Parent(s): 0ce22b3

Upload 25 files

Browse files
.gitattributes CHANGED
@@ -25,7 +25,6 @@
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
  *.wasm filter=lfs diff=lfs merge=lfs -text
 
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
28
  *.tflite filter=lfs diff=lfs merge=lfs -text
29
  *.tgz filter=lfs diff=lfs merge=lfs -text
30
  *.wasm filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: HioriTTS
3
- emoji: 🏃
4
- colorFrom: yellow
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.19.2
8
  app_file: app.py
 
1
  ---
2
  title: HioriTTS
3
+ emoji: 📊
4
+ colorFrom: indigo
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.19.2
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import re
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import torch
8
+ from torch import no_grad, LongTensor
9
+ import commons
10
+ import utils
11
+ import gradio as gr
12
+ from models import SynthesizerTrn
13
+ from text import text_to_sequence, _clean_text
14
+ from mel_processing import spectrogram_torch
15
+
16
+ from text.symbols import symbols
17
+
18
+ limitation = os.getenv("SYSTEM") == "spaces" # limit text and audio length in huggingface spaces
19
+
20
+ device = 'cpu'
21
+
22
def get_text(text, hps):
    """Turn raw text into the LongTensor of symbol ids fed to the model.

    The text is converted by ``text_to_sequence`` using the cleaners listed
    in ``hps.data.text_cleaners``. When ``hps.data.add_blank`` is truthy,
    id 0 is interspersed between and around the resulting ids.
    """
    sequence = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        sequence = commons.intersperse(sequence, 0)
    return LongTensor(sequence)
28
+
29
+
30
def create_tts_fn(model, hps, speaker_ids):
    """Build the Gradio callback that synthesises speech with ``model``.

    Args:
        model: object exposing ``infer(x, x_lengths, sid=..., noise_scale=...,
            noise_scale_w=..., length_scale=...)``.
        hps: hyper-parameters; ``hps.data.text_cleaners`` and
            ``hps.data.sampling_rate`` are read here.
        speaker_ids: sequence mapping the dropdown index to a speaker id.

    Returns:
        ``tts_fn(text, speaker, speed)`` returning ``(message, audio)`` where
        ``audio`` is ``(sampling_rate, waveform)`` on success and ``None``
        when the text is rejected for being too long.
    """
    # Raw string: "\[" inside a normal literal is an invalid escape sequence.
    language_tags = re.compile(r"(\[ZH\]|\[JA\])")
    max_len = 500

    def tts_fn(text, speaker, speed):
        print(speaker, text)
        if limitation:
            text_len = len(text)
            cleaners = hps.data.text_cleaners
            if len(cleaners) > 0 and cleaners[0] == "zh_ja_mixture_cleaners":
                # Language markers do not count toward the length limit.
                text_len = len(language_tags.sub("", text))
            if text_len > max_len:
                return "Error: Text is too long", None

        speaker_id = speaker_ids[speaker]
        stn_tst = get_text(text, hps)
        with no_grad():
            x_tst = stn_tst.unsqueeze(0)
            x_tst_lengths = LongTensor([stn_tst.size(0)])
            sid = LongTensor([speaker_id])
            # A larger ``speed`` gives a smaller length_scale.
            audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8,
                                length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
        del stn_tst, x_tst, x_tst_lengths, sid
        return "Success", (hps.data.sampling_rate, audio)

    return tts_fn
53
+
54
+
55
def create_to_phoneme_fn(hps):
    """Return a callback that runs ``_clean_text`` with the configured cleaners.

    An empty string is returned unchanged without calling the cleaner.
    """
    def to_phoneme_fn(text):
        if text == "":
            return ""
        return _clean_text(text, hps.data.text_cleaners)

    return to_phoneme_fn
60
+
61
+
62
+ css = """
63
+ #advanced-btn {
64
+ color: white;
65
+ border-color: black;
66
+ background: black;
67
+ font-size: .7rem !important;
68
+ line-height: 19px;
69
+ margin-top: 24px;
70
+ margin-bottom: 12px;
71
+ padding: 2px 8px;
72
+ border-radius: 14px !important;
73
+ }
74
+ #advanced-options {
75
+ display: none;
76
+ margin-bottom: 20px;
77
+ }
78
+ """
79
+
80
if __name__ == '__main__':
    # Single-model demo: load the config and checkpoint from saved_model/,
    # build the synthesiser, then serve a Gradio UI around it.
    models_tts = []
    name = 'HioriTTS'
    lang = '日本語 (Japanese)'
    example = 'プロデューサー、今日も良い一日を!'
    config_path = "saved_model/config.json"
    model_path = "saved_model/model.pth"
    cover_path = "saved_model/cover.png"
    hps = utils.get_hparams_from_file(config_path)

    # The posterior encoder input width depends on the model flavour.
    if "use_mel_posterior_encoder" in hps.model.keys() and hps.model.use_mel_posterior_encoder == True:
        print("Using mel posterior encoder for VITS2")
        posterior_channels = 80  # vits2
        hps.data.use_mel_posterior_encoder = True
    else:
        print("Using lin posterior encoder for VITS1")
        posterior_channels = hps.data.filter_length // 2 + 1
        hps.data.use_mel_posterior_encoder = False

    model = SynthesizerTrn(
        len(symbols),
        posterior_channels,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,  # >0 for multi speaker
        **hps.model)
    utils.load_checkpoint(model_path, model, None)
    model.eval()

    # Entries named "None" are skipped; the dropdown index (type="index")
    # therefore indexes into ``speaker_ids``, not directly into hps.speakers.
    speaker_ids = [sid for sid, speaker_name in enumerate(hps.speakers) if speaker_name != "None"]
    speakers = [speaker_name for speaker_name in hps.speakers if speaker_name != "None"]

    models_tts.append((name, cover_path, speakers, lang, example,
                       symbols, create_tts_fn(model, hps, speaker_ids),
                       create_to_phoneme_fn(hps)))

    app = gr.Blocks(css=css)

    with app:
        gr.Markdown("# HioriTTS Using VITS2 Model\n\n"
                    "## Model Updated: VITS -> VITS2\n\n"
                    "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=ORI-Muchim.HioriTTS)\n\n")
        with gr.Tabs():
            with gr.TabItem("TTS"):
                with gr.Tabs():
                    # Distinct loop names so the module-level ``name`` and
                    # ``symbols`` are not rebound by the tuple unpacking.
                    for i, (model_name, model_cover, model_speakers, model_lang, model_example,
                            _model_symbols, tts_fn, to_phoneme_fn) in enumerate(models_tts):
                        with gr.TabItem("Hiori"):
                            with gr.Column():
                                gr.Markdown(f"## {model_name}\n\n"
                                            f"![cover](file/{model_cover})\n\n"
                                            f"lang: {model_lang}")
                                tts_input1 = gr.TextArea(label="Text (500 words limitation)", value=model_example,
                                                         elem_id=f"tts-input{i}")
                                tts_input2 = gr.Dropdown(label="Speaker", choices=model_speakers,
                                                         type="index", value=model_speakers[0])
                                tts_input3 = gr.Slider(label="Speed", value=1.2, minimum=0.1, maximum=2, step=0.1)
                                tts_submit = gr.Button("Generate", variant="primary")
                                tts_output1 = gr.Textbox(label="Output Message")
                                tts_output2 = gr.Audio(label="Output Audio")
                                tts_submit.click(tts_fn, [tts_input1, tts_input2, tts_input3],
                                                 [tts_output1, tts_output2])

    # Gradio 4.x rejects ``queue(concurrency_count=...)``; the per-event
    # default limit is its replacement (the Space pins sdk_version 4.19.2).
    app.queue(default_concurrency_limit=3).launch(show_api=False)
attentions.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from torch.nn.utils import remove_weight_norm, weight_norm
8
+
9
+ import commons
10
+ import modules
11
+ from modules import LayerNorm
12
+
13
class Encoder(nn.Module):  # backward compatible vits2 encoder
    """Stack of self-attention + feed-forward layers with optional speaker conditioning.

    When ``gin_channels`` is passed in ``kwargs`` and is non-zero, a linear
    layer projects the conditioning input ``g`` to ``hidden_channels`` and the
    result is added to ``x`` just before layer ``cond_layer_idx`` (default 2).
    Without it, ``cond_layer_idx`` equals ``n_layers`` and no conditioning is
    ever applied.
    """

    def __init__(
        self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()

        # Default index is out of range, i.e. conditioning is disabled.
        self.cond_layer_idx = self.n_layers
        if 'gin_channels' in kwargs:
            self.gin_channels = kwargs['gin_channels']
            if self.gin_channels != 0:
                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
                # vits2 says 3rd block, so idx is 2 by default
                self.cond_layer_idx = kwargs.get('cond_layer_idx', 2)
                print(self.gin_channels, self.cond_layer_idx)
                assert self.cond_layer_idx < self.n_layers, 'cond_layer_idx should be less than n_layers'

        # Sub-modules are created layer by layer, in the original order.
        for _ in range(self.n_layers):
            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        """Run all layers on ``x``, masking with ``x_mask``; ``g`` is the optional conditioning input."""
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for layer_idx in range(self.n_layers):
            if layer_idx == self.cond_layer_idx and g is not None:
                g = self.spk_emb_linear(g.transpose(1, 2)).transpose(1, 2)
                x = (x + g) * x_mask

            attended = self.drop(self.attn_layers[layer_idx](x, x, attn_mask))
            x = self.norm_layers_1[layer_idx](x + attended)

            transformed = self.drop(self.ffn_layers[layer_idx](x, x_mask))
            x = self.norm_layers_2[layer_idx](x + transformed)
        return x * x_mask
66
+
67
class Decoder(nn.Module):
    """Stack of masked self-attention, encoder-decoder attention and causal FFN layers."""

    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        # Sub-modules are created layer by layer, in the original order.
        for _ in range(self.n_layers):
            self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        # Lower-triangular mask for self-attention.
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for layer_idx in range(self.n_layers):
            self_out = self.drop(self.self_attn_layers[layer_idx](x, x, self_attn_mask))
            x = self.norm_layers_0[layer_idx](x + self_out)

            cross_out = self.drop(self.encdec_attn_layers[layer_idx](x, h, encdec_attn_mask))
            x = self.norm_layers_1[layer_idx](x + cross_out)

            ffn_out = self.drop(self.ffn_layers[layer_idx](x, x_mask))
            x = self.norm_layers_2[layer_idx](x + ffn_out)
        return x * x_mask
116
+
117
+
118
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from 1x1 convolutions.

    Optional features, each restricted to self-attention by an assert:
    windowed relative-position embeddings (``window_size``), a proximal bias
    on the scores (``proximal_bias``) and a band mask (``block_length``).
    The most recent attention weights are kept in ``self.attn``.
    """

    def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            # One embedding table shared by all heads, or one per head.
            rel_heads = 1 if heads_share else n_heads
            rel_std = self.k_channels ** -0.5
            self.emb_rel_k = nn.Parameter(torch.randn(rel_heads, window_size * 2 + 1, self.k_channels) * rel_std)
            self.emb_rel_v = nn.Parameter(torch.randn(rel_heads, window_size * 2 + 1, self.k_channels) * rel_std)

        for projection in (self.conv_q, self.conv_k, self.conv_v):
            nn.init.xavier_uniform_(projection.weight)
        if proximal_init:
            # Start with the key projection equal to the query projection.
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        """Attend from ``x`` (queries) to ``c`` (keys and values)."""
        query = self.conv_q(x)
        key = self.conv_k(c)
        value = self.conv_v(c)

        attended, self.attn = self.attention(query, key, value, mask=attn_mask)
        return self.conv_o(attended)

    def attention(self, query, key, value, mask=None):
        """Scaled dot-product attention; returns ``(output, attention_weights)``."""
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s = key.size()
        t_t = query.size(2)
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scaled_query = query / math.sqrt(self.k_channels)
        scores = torch.matmul(scaled_query, key.transpose(-2, -1))
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            rel_keys = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(scaled_query, rel_keys)
            scores = scores + self._relative_position_to_absolute_position(rel_logits)
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert t_s == t_t, "Local attention is only available for self-attention."
                band = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                scores = scores.masked_fill(band == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            rel_weights = self._absolute_position_to_relative_position(p_attn)
            rel_values = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(rel_weights, rel_values)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        return torch.matmul(x, y.unsqueeze(0))

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Pad and/or slice the embedding table to ``2 * length - 1`` positions."""
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        start = max((self.window_size + 1) - length, 0)
        end = start + 2 * length - 1
        if pad_length > 0:
            padded = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
        else:
            padded = relative_embeddings
        return padded[:, start:end]

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        flat = x.view([batch, heads, length * 2 * length])
        flat = F.pad(flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the padded elements.
        return flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        flat = F.pad(flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        return flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        positions = torch.arange(length, dtype=torch.float32)
        distance = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
        return (-torch.log1p(distance)).unsqueeze(0).unsqueeze(0)
272
+
273
+
274
class FFN(nn.Module):
    """Two-convolution feed-forward block with masking and dropout.

    ``causal=True`` pads only on the left of the time axis; otherwise the
    padding is split between both sides. ``activation == "gelu"`` selects the
    ``x * sigmoid(1.702 * x)`` form; anything else uses ReLU.
    """

    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Bound method used to pad the input before each convolution.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        """Apply conv -> activation -> dropout -> conv, masking before each conv and at the end."""
        hidden = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            hidden = hidden * torch.sigmoid(1.702 * hidden)
        else:
            hidden = torch.relu(hidden)
        hidden = self.drop(hidden)
        out = self.conv_2(self.padding(hidden * x_mask))
        return out * x_mask

    def _causal_padding(self, x):
        """Left-pad the last axis by ``kernel_size - 1`` (no-op for kernel size 1)."""
        if self.kernel_size == 1:
            return x
        pad_spec = [[0, 0], [0, 0], [self.kernel_size - 1, 0]]
        return F.pad(x, commons.convert_pad_shape(pad_spec))

    def _same_padding(self, x):
        """Pad both sides of the last axis so the length is preserved (no-op for kernel size 1)."""
        if self.kernel_size == 1:
            return x
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size // 2
        return F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [left, right]]))
321
+
322
+
323
class Depthwise_Separable_Conv1D(nn.Module):
    """Depthwise ``Conv1d`` (``groups=in_channels``) followed by a 1x1 pointwise ``Conv1d``."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode='zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ):
        super().__init__()
        self.depth_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        self.point_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        """Apply the depthwise convolution, then the pointwise convolution."""
        depthwise_out = self.depth_conv(input)
        return self.point_conv(depthwise_out)

    def weight_norm(self):
        """Wrap both convolutions with weight normalisation on ``weight``."""
        self.depth_conv = weight_norm(self.depth_conv, name='weight')
        self.point_conv = weight_norm(self.point_conv, name='weight')

    def remove_weight_norm(self):
        """Undo weight normalisation on both convolutions."""
        self.depth_conv = remove_weight_norm(self.depth_conv, name='weight')
        self.point_conv = remove_weight_norm(self.point_conv, name='weight')
351
+
352
class Depthwise_Separable_TransposeConv1D(nn.Module):
    """Depthwise ``ConvTranspose1d`` (``groups=in_channels``) followed by a 1x1 pointwise ``Conv1d``."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        bias=True,
        dilation=1,
        padding_mode='zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ):
        super().__init__()
        self.depth_conv = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            output_padding=output_padding,
            padding=padding,
            dilation=dilation,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        self.point_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        """Apply the depthwise transposed convolution, then the pointwise convolution."""
        upsampled = self.depth_conv(input)
        return self.point_conv(upsampled)

    def weight_norm(self):
        """Wrap both convolutions with weight normalisation on ``weight``."""
        self.depth_conv = weight_norm(self.depth_conv, name='weight')
        self.point_conv = weight_norm(self.point_conv, name='weight')

    def remove_weight_norm(self):
        """Undo weight normalisation on both convolutions (in place)."""
        for conv in (self.depth_conv, self.point_conv):
            remove_weight_norm(conv, name='weight')
381
+
382
+
383
def weight_norm_modules(module, name='weight', dim=0):
    """Apply weight normalisation to ``module`` and return the normalised module.

    The depthwise-separable wrappers normalise their own inner convolutions
    (``name`` and ``dim`` are not forwarded to them); any other module goes
    through ``torch.nn.utils.weight_norm``.
    """
    separable_types = (Depthwise_Separable_Conv1D, Depthwise_Separable_TransposeConv1D)
    if not isinstance(module, separable_types):
        return weight_norm(module, name, dim)
    module.weight_norm()
    return module
389
+
390
def remove_weight_norm_modules(module, name='weight'):
    """Remove weight normalisation from ``module``; returns ``None``.

    The depthwise-separable wrappers strip their own inner convolutions
    (``name`` is not forwarded to them); any other module goes through
    ``torch.nn.utils.remove_weight_norm``.
    """
    separable_types = (Depthwise_Separable_Conv1D, Depthwise_Separable_TransposeConv1D)
    if isinstance(module, separable_types):
        module.remove_weight_norm()
        return
    remove_weight_norm(module, name)
395
+
396
class FFT(nn.Module):
    """Stack of causal self-attention + causal FFN layers with optional conditioning.

    The conditioning layers (``cond_layer``, ``cond_pre``) exist only when
    ``isflow`` is true and ``kwargs['gin_channels'] > 0``; ``forward`` must
    only be given a non-``None`` ``g`` in that configuration.
    """

    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0.,
                 proximal_bias=False, proximal_init=True, isflow=False, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        if isflow and 'gin_channels' in kwargs and kwargs["gin_channels"] > 0:
            # One conv emits the gate inputs for every layer; they are sliced
            # per layer in forward().
            cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
            self.cond_layer = weight_norm_modules(cond_layer, name='weight')
            self.gin_channels = kwargs["gin_channels"]
        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias,
                                   proximal_init=proximal_init))
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
            self.norm_layers_1.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        """
        x: input sequence
        x_mask: mask multiplied into x
        g: optional conditioning input, passed through ``cond_layer``
        """
        n_channels = None
        if g is not None:
            g = self.cond_layer(g)
            # Loop-invariant split point for the gated activation; built once
            # instead of once per layer.
            n_channels = torch.IntTensor([self.hidden_channels])

        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
        x = x * x_mask
        for i in range(self.n_layers):
            if g is not None:
                x = self.cond_pre(x)
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
                x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, n_channels)
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)
        x = x * x_mask
        return x
commons.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise ``m.weight`` from N(mean, std) when the class name of ``m`` contains "Conv"."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
12
+
13
+
14
def get_padding(kernel_size, dilation=1):
    """Return ``(kernel_size * dilation - dilation) / 2`` truncated to an int."""
    receptive_span = kernel_size * dilation - dilation
    return int(receptive_span / 2)
16
+
17
+
18
def convert_pad_shape(pad_shape):
    """Flatten per-dimension ``[before, after]`` pairs, last dimension first.

    The flat, reversed ordering is what the ``F.pad`` calls in this file pass
    as their padding argument.
    """
    return [amount for pair in pad_shape[::-1] for amount in pair]
22
+
23
+
24
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after the elements of ``lst``."""
    result = [item]
    for element in lst:
        result.append(element)
        result.append(item)
    return result
28
+
29
+
30
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q) computed element-wise from means and log standard deviations."""
    divergence = (logs_q - logs_p) - 0.5
    spread = torch.exp(2. * logs_p) + ((m_p - m_q) ** 2)
    divergence += 0.5 * spread * torch.exp(-2. * logs_q)
    return divergence
35
+
36
+
37
def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    # Map uniform samples into [0.00001, 0.99999) so neither log sees 0 or 1.
    uniform = torch.rand(shape) * 0.99998 + 0.00001
    inner = -torch.log(uniform)
    return -torch.log(inner)
41
+
42
+
43
def rand_gumbel_like(x):
    """Gumbel samples with the size, dtype and device of ``x``."""
    samples = rand_gumbel(x.size())
    return samples.to(dtype=x.dtype, device=x.device)
46
+
47
+
48
def slice_segments(x, ids_str, segment_size=4):
    """Cut one window of ``segment_size`` steps along the last axis per batch item.

    ``ids_str[i]`` is the start index of the window taken from ``x[i]``.
    """
    segments = torch.zeros_like(x[:, :, :segment_size])
    for batch_idx in range(x.size(0)):
        start = ids_str[batch_idx]
        segments[batch_idx] = x[batch_idx, :, start:start + segment_size]
    return segments
55
+
56
+
57
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Slice a randomly positioned window per batch item; returns ``(segments, start_indices)``.

    When ``x_lengths`` is ``None`` the full length of the last axis is used
    as the length of every item.
    """
    batch, _, total = x.size()
    lengths = total if x_lengths is None else x_lengths
    max_start = lengths - segment_size + 1
    ids_str = (torch.rand([batch]).to(device=x.device) * max_start).to(dtype=torch.long)
    return slice_segments(x, ids_str, segment_size), ids_str
65
+
66
+
67
def get_timing_signal_1d(
        length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Build a sinusoidal timing signal of shape ``[1, channels, length]``.

    The first ``channels // 2`` rows are sines and the next ``channels // 2``
    rows are cosines over geometrically spaced timescales between
    ``min_timescale`` and ``max_timescale``; an odd ``channels`` gets one
    extra zero row.
    """
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    # Clamp the denominator: with a single timescale (channels of 2 or 3)
    # ``num_timescales - 1`` is 0 and the division would raise.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        max(num_timescales - 1, 1))
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    # Zero-pad one row when ``channels`` is odd.
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal
81
+
82
+
83
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add the sinusoidal timing signal to ``x`` (a 3-D tensor unpacked as batch, channels, length)."""
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return x + timing
87
+
88
+
89
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate the sinusoidal timing signal to ``x`` along ``axis``."""
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return torch.cat([x, timing], axis)
93
+
94
+
95
def subsequent_mask(length):
    """Lower-triangular mask of ones with shape ``[1, 1, length, length]``."""
    full = torch.ones(length, length)
    lower = torch.tril(full)
    return lower.unsqueeze(0).unsqueeze(0)
98
+
99
+
100
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # Gated activation: tanh over the first n_channels[0] channels of
    # (input_a + input_b), multiplied by sigmoid over the remaining channels.
    split_at = n_channels[0]
    summed = input_a + input_b
    tanh_part = torch.tanh(summed[:, :split_at, :])
    gate_part = torch.sigmoid(summed[:, split_at:, :])
    return tanh_part * gate_part
108
+
109
+
110
def convert_pad_shape(pad_shape):
    """Flatten per-dimension ``[before, after]`` pairs, last dimension first.

    NOTE(review): this repeats the definition of the same name earlier in
    this module; being later, it is the one bound at import time.
    """
    flattened = []
    for pair in pad_shape[::-1]:
        flattened.extend(pair)
    return flattened
114
+
115
+
116
def shift_1d(x):
    """Shift ``x`` one step right along the last axis, zero-filling the first position."""
    # [1, 0, 0, 0, 0, 0] is convert_pad_shape([[0, 0], [0, 0], [1, 0]]) written out.
    padded = F.pad(x, [1, 0, 0, 0, 0, 0])
    return padded[:, :, :-1]
119
+
120
+
121
def sequence_mask(length, max_length=None):
    """Boolean mask whose entry ``[i, j]`` is true where ``j < length[i]``.

    ``max_length`` defaults to the largest value in ``length``.
    """
    limit = length.max() if max_length is None else max_length
    positions = torch.arange(limit, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
126
+
127
+
128
def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    b, _, t_y, t_x = mask.shape
    cum_duration_flat = torch.cumsum(duration, -1).view(b * t_x)

    # Row k covers every output step below its cumulative duration;
    # subtracting the previous row leaves only the steps owned by step k.
    covered = sequence_mask(cum_duration_flat, t_y).to(mask.dtype).view(b, t_x, t_y)
    previous = F.pad(covered, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = covered - previous
    return path.unsqueeze(1).transpose(2, 3) * mask
144
+
145
+
146
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients to ``[-clip_value, clip_value]`` in place and return their total norm.

    The norm is accumulated from the gradients before each one is clamped.
    Parameters without a gradient are ignored; ``clip_value=None`` only
    computes the norm.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grad = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    accumulated = 0
    for param in with_grad:
        grad = param.grad.data
        accumulated += grad.norm(norm_type).item() ** norm_type
        if clip_value is not None:
            grad.clamp_(min=-clip_value, max=clip_value)
    return accumulated ** (1. / norm_type)
export_model.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
if __name__ == '__main__':
    # Write a copy of the checkpoint with the "optimizer" entry removed.
    model_path = "saved_model/model.pth"
    output_path = "saved_model/model1.pth"
    checkpoint = torch.load(model_path, map_location='cpu')
    if "optimizer" in checkpoint:
        print("remove optimizer")
    stripped = {key: value for key, value in checkpoint.items() if key != "optimizer"}
    torch.save(stripped, output_path)
mel_processing.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from packaging import version
4
+ import random
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ import torch.utils.data
9
+ import numpy as np
10
+ import librosa
11
+ import librosa.util as librosa_util
12
+ from librosa.util import normalize, pad_center, tiny
13
+ from scipy.signal import get_window
14
+ from scipy.io.wavfile import read
15
+ from librosa.filters import mel as librosa_mel_fn
16
+
17
# 2 ** 15. Presumably the int16 full-scale value used to normalise PCM audio;
# it is not referenced by the functions visible in this module -- TODO confirm.
MAX_WAV_VALUE = 32768.0
18
+
19
+
20
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` after flooring it at ``clip_val``.

    PARAMS
    ------
    C: compression factor
    clip_val: lower bound applied to ``x`` before the logarithm
    """
    floored = x.clamp(min=clip_val)
    return (floored * C).log()
27
+
28
+
29
def dynamic_range_decompression_torch(x, C=1):
    """Undo the log compression: exponentiate ``x`` and divide by ``C``.

    PARAMS
    ------
    C: compression factor used to compress
    """
    return x.exp() / C
36
+
37
+
38
def spectral_normalize_torch(magnitudes):
    """Apply log dynamic-range compression with the default parameters."""
    return dynamic_range_compression_torch(magnitudes)
41
+
42
+
43
def spectral_de_normalize_torch(magnitudes):
    """Invert ``spectral_normalize_torch`` using the default parameters."""
    return dynamic_range_decompression_torch(magnitudes)
46
+
47
+
48
# Module-level cache of mel filterbank tensors, filled lazily by
# spec_to_mel_torch / mel_spectrogram_torch and keyed by a string built from
# the filterbank settings plus the tensor dtype and device.
mel_basis = {}
# Module-level cache of Hann windows, keyed by "<win_size>_<dtype>_<device>".
hann_window = {}
50
+
51
+
52
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    """Compute the linear magnitude spectrogram of a waveform batch.

    Args:
        y: waveform tensor; it is unsqueezed on dim 1 for padding, so it is
            presumably shaped [batch, samples] in [-1, 1] -- TODO confirm.
        n_fft: FFT size.
        sampling_rate: unused here; kept for a signature parallel to
            ``mel_spectrogram_torch``.
        hop_size: hop length between frames.
        win_size: Hann window length.
        center: forwarded to ``torch.stft``.

    Returns:
        ``sqrt(re**2 + im**2 + 1e-6)`` of the STFT (one-sided).
    """
    # Out-of-range samples are only reported, never clipped.
    lowest, highest = torch.min(y), torch.max(y)
    if lowest < -1.:
        print('min value is ', lowest)
    if highest > 1.:
        print('max value is ', highest)

    dtype_device = str(y.dtype) + '_' + str(y.device)
    wnsize_dtype_device = str(win_size) + '_' + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
    window = hann_window[wnsize_dtype_device]

    # Manual reflect padding (center=False by default), symmetric on both sides.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    if version.parse(torch.__version__) >= version.parse("2"):
        # return_complex=False is deprecated for torch.stft; request the
        # complex result and view it as (..., 2) real/imag pairs instead.
        spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=window,
                          center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
        spec = torch.view_as_real(spec)
    else:
        spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=window,
                          center=center, pad_mode='reflect', normalized=False, onesided=True)

    # Magnitude with a small epsilon so the sqrt gradient stays finite.
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
76
+
77
+
78
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    """Project a linear spectrogram onto a mel filterbank and log-compress it.

    Args:
        spec: linear-frequency spectrogram; it is left-multiplied by the
            filterbank, so its second-to-last axis must be the frequency axis.
        n_fft, num_mels, sampling_rate, fmin, fmax: filterbank settings
            forwarded to ``librosa.filters.mel``.

    Returns:
        The log-compressed mel spectrogram.
    """
    dtype_device = str(spec.dtype) + '_' + str(spec.device)
    # Key on every filterbank setting: keying on fmax alone would hand back a
    # stale basis when the function is called again with a different
    # sampling rate, n_fft, num_mels or fmin.
    basis_key = '_'.join(str(p) for p in (sampling_rate, n_fft, num_mels, fmin, fmax)) + '_' + dtype_device
    if basis_key not in mel_basis:
        # Keyword arguments work on old and new librosa alike; recent
        # releases reject the positional form.
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
    spec = torch.matmul(mel_basis[basis_key], spec)
    spec = spectral_normalize_torch(spec)
    return spec
88
+
89
+
90
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute the log-mel spectrogram of a waveform batch.

    Args:
        y: waveform tensor; it is unsqueezed on dim 1 for padding, so it is
            presumably shaped [batch, samples] in [-1, 1] -- TODO confirm.
        n_fft: FFT size.
        num_mels, sampling_rate, fmin, fmax: mel filterbank settings forwarded
            to ``librosa.filters.mel``.
        hop_size: hop length between frames.
        win_size: Hann window length.
        center: forwarded to ``torch.stft``.

    Returns:
        The log-compressed mel spectrogram.
    """
    # Out-of-range samples are only reported, never clipped.
    lowest, highest = torch.min(y), torch.max(y)
    if lowest < -1.:
        print('min value is ', lowest)
    if highest > 1.:
        print('max value is ', highest)

    dtype_device = str(y.dtype) + '_' + str(y.device)
    # Key on every filterbank setting so a call with different settings never
    # reuses a stale basis.
    basis_key = '_'.join(str(p) for p in (sampling_rate, n_fft, num_mels, fmin, fmax)) + '_' + dtype_device
    wnsize_dtype_device = str(win_size) + '_' + dtype_device
    if basis_key not in mel_basis:
        # Keyword arguments work on old and new librosa alike; recent
        # releases reject the positional form.
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
    window = hann_window[wnsize_dtype_device]

    # Manual reflect padding (center=False by default), symmetric on both sides.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    # NOTE: for mixed-precision training, upstream discussion
    # (https://github.com/jaywalnut310/vits/issues/15#issuecomment-1084148441)
    # suggests running the STFT on y.float() under autocast(enabled=False).
    if version.parse(torch.__version__) >= version.parse("2"):
        # return_complex=False is deprecated for torch.stft; request the
        # complex result and view it as (..., 2) real/imag pairs instead.
        spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=window,
                          center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
        spec = torch.view_as_real(spec)
    else:
        spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=window,
                          center=center, pad_mode='reflect', normalized=False, onesided=True)

    # Magnitude with a small epsilon so the sqrt gradient stays finite.
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    spec = torch.matmul(mel_basis[basis_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
models.py ADDED
@@ -0,0 +1,1464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ import commons
8
+ import modules
9
+ import attentions
10
+ import monotonic_align
11
+
12
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
13
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14
+ from commons import init_weights, get_padding
15
+
16
+ from pqmf import PQMF
17
+ from stft import TorchSTFT, OnnxSTFT
18
+
19
# Values of `transformer_flow_type` that ResidualCouplingTransformersBlock
# has a construction branch for.
AVAILABLE_FLOW_TYPES = ["pre_conv", "pre_conv2", "fft", "mono_layer_inter_residual", "mono_layer_post_residual"]
# Maps a duration-discriminator type name to the name of the class in this
# module that implements it.
AVAILABLE_DURATION_DISCRIMINATOR_TYPES = {"dur_disc_1": "DurationDiscriminator", "dur_disc_2": "DurationDiscriminator2"}
21
+
22
+
23
class StochasticDurationPredictor(nn.Module):
    """Flow-based duration predictor.

    In the forward direction (``reverse=False``) it returns a per-sample
    training loss (``nll + logq``) for the given durations ``w``; in the
    reverse direction it samples noise and maps it through the inverted flows
    to produce ``logw``.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
        super().__init__()
        # The filter_channels argument is overridden, so every internal layer
        # is in_channels wide regardless of what the caller passes.
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        # Main flow stack: [ElementwiseAffine, (ConvFlow, Flip) * n_flows].
        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(modules.Flip())

        # Posterior branch: encodes the durations w and owns its own flow
        # stack (always 4 ConvFlow/Flip pairs, independent of n_flows).
        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.post_flows.append(modules.Flip())

        # Conditioning branch applied to the input x.
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        if gin_channels != 0:
            # Only created when global conditioning is configured; forward()
            # uses it whenever g is passed.
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        """Return the loss for durations ``w`` or, if ``reverse``, sample ``logw``.

        Args:
            x: conditioning features; detached, so no gradient reaches the
                caller through this module.
            x_mask: mask multiplied into the intermediate tensors.
            w: durations, required when ``reverse`` is False.
            g: optional global conditioning (also detached).
            reverse: False for the training loss, True for sampling.
            noise_scale: multiplier of the sampled noise in reverse mode.
        """
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            # Posterior: push noise e_q through post_flows, conditioned on
            # x + h_w, and accumulate the log-determinants.
            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            # u lies in (0, 1) and is subtracted from w.
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            # log-derivative of the sigmoid: logsigmoid(z) + logsigmoid(-z).
            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
            logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q

            # Main flow: Log transform on z0, then the flow stack on [z0, z1].
            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            # The reversed list ends with [..., ConvFlow_1, ElementwiseAffine];
            # this drops ConvFlow_1 and keeps the ElementwiseAffine.
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            # Only the first channel is returned; z1 is discarded.
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw
102
+
103
+
104
class DurationPredictor(nn.Module):
    """Deterministic duration predictor: two conv/ReLU/norm/dropout stages
    followed by a 1x1 projection to a single channel."""

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        same_pad = kernel_size // 2
        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=same_pad)
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=same_pad)
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        # Conditioning projection exists only when global conditioning is on.
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        """Return the masked single-channel prediction for ``x``.

        Both ``x`` and ``g`` are detached, so gradients stop at this module.
        """
        h = torch.detach(x)
        if g is not None:
            h = h + self.cond(torch.detach(g))
        stages = ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2))
        for conv, norm in stages:
            h = self.drop(norm(torch.relu(conv(h * x_mask))))
        return self.proj(h * x_mask) * x_mask
139
+
140
+
141
class DurationDiscriminator(nn.Module):  # vits2
    """Scores durations against encoder features.

    In this variant the ReLU / LayerNorm / dropout steps between the
    convolutions are disabled: the layers are still constructed but
    ``forward`` and ``forward_probability`` apply only the convolutions.
    """
    # TODO : not using "spk conditioning" for now according to the paper.
    # Can be a better discriminator if we use it.

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        same_pad = kernel_size // 2
        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=same_pad)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=same_pad)
        self.dur_proj = nn.Conv1d(1, filter_channels, 1)

        self.pre_out_conv_1 = nn.Conv1d(2 * filter_channels, filter_channels, kernel_size, padding=same_pad)
        self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
        self.pre_out_conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=same_pad)
        self.pre_out_norm_2 = modules.LayerNorm(filter_channels)

        # Per-position probability head.
        self.output_layer = nn.Sequential(
            nn.Linear(filter_channels, 1),
            nn.Sigmoid()
        )

    def forward_probability(self, x, x_mask, dur, g=None):
        """Score one duration tensor against the encoded features ``x``."""
        joint = torch.cat([x, self.dur_proj(dur)], dim=1)
        joint = self.pre_out_conv_1(joint * x_mask)
        joint = self.pre_out_conv_2(joint * x_mask)
        joint = (joint * x_mask).transpose(1, 2)
        return self.output_layer(joint)

    def forward(self, x, x_mask, dur_r, dur_hat, g=None):
        """Return ``[prob(dur_r), prob(dur_hat)]``; ``x`` is detached first."""
        h = self.conv_1(torch.detach(x) * x_mask)
        h = self.conv_2(h * x_mask)
        return [self.forward_probability(h, x_mask, dur, g) for dur in (dur_r, dur_hat)]
209
+
210
+
211
class DurationDiscriminator2(nn.Module):  # vits2 - DurationDiscriminator2
    """Duration discriminator with ReLU + LayerNorm after every convolution.

    Unlike ``DurationDiscriminator``, each probability returned by
    ``forward`` is wrapped in its own single-element list.
    """
    # TODO : not using "spk conditioning" for now according to the paper.
    # Can be a better discriminator if we use it.

    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        same_pad = kernel_size // 2
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=same_pad)
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=same_pad)
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.dur_proj = nn.Conv1d(1, filter_channels, 1)

        self.pre_out_conv_1 = nn.Conv1d(2 * filter_channels, filter_channels, kernel_size, padding=same_pad)
        self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
        self.pre_out_conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=same_pad)
        self.pre_out_norm_2 = modules.LayerNorm(filter_channels)

        # Per-position probability head.
        self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())

    def forward_probability(self, x, x_mask, dur, g=None):
        """Score one duration tensor against the encoded features ``x``."""
        joint = torch.cat([x, self.dur_proj(dur)], dim=1)
        head = ((self.pre_out_conv_1, self.pre_out_norm_1), (self.pre_out_conv_2, self.pre_out_norm_2))
        for conv, norm in head:
            joint = norm(torch.relu(conv(joint * x_mask)))
        joint = (joint * x_mask).transpose(1, 2)
        return self.output_layer(joint)

    def forward(self, x, x_mask, dur_r, dur_hat, g=None):
        """Return ``[[prob(dur_r)], [prob(dur_hat)]]``; ``x`` is detached first."""
        h = torch.detach(x)
        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            h = norm(torch.relu(conv(h * x_mask)))
        return [[self.forward_probability(h, x_mask, dur, g)] for dur in (dur_r, dur_hat)]
281
+
282
+
283
class TextEncoder(nn.Module):
    """Embeds token ids, runs them through the attention encoder and projects
    the result to a mean and a log-scale of ``out_channels`` each."""

    def __init__(self,
                 n_vocab,
                 out_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 gin_channels=0):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        # Embedding initialised with std = hidden_channels ** -0.5; forward()
        # rescales it by sqrt(hidden_channels).
        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)

        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.gin_channels)
        # Twice out_channels: one half for the mean, one for the log-scale.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Return ``(hidden, m, logs, x_mask)`` for token ids ``x``."""
        embedded = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        hidden = embedded.transpose(1, -1)  # [b, h, t]
        valid = commons.sequence_mask(x_lengths, hidden.size(2))
        x_mask = valid.unsqueeze(1).to(hidden.dtype)

        hidden = self.encoder(hidden * x_mask, x_mask, g=g)
        m, logs = torch.split(self.proj(hidden) * x_mask, self.out_channels, dim=1)
        return hidden, m, logs, x_mask
327
+
328
+
329
class ResidualCouplingTransformersLayer2(nn.Module):  # vits2
    """Affine coupling layer: the first half of the channels predicts a shift
    (and, unless ``mean_only``, a log-scale) for the second half through a
    1x1 conv, a residual transformer encoder and a WN stack."""

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.pre_transformer = attentions.Encoder(
            hidden_channels,
            hidden_channels,
            n_heads=2,
            n_layers=1,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
            # window_size=None,
        )
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            p_dropout=p_dropout,
            gin_channels=gin_channels,
        )

        # Zero-initialised output projection: the layer starts as identity.
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the coupling; returns ``(x, logdet)`` or, if ``reverse``, ``x``."""
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = h + self.pre_transformer(h * x_mask, x_mask)  # vits2 residual connection
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)
        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x

    def remove_weight_norm(self):
        """Strip weight norm from the WN stack.

        Mirrors ``ResidualCouplingTransformersLayer.remove_weight_norm`` so
        ``ResidualCouplingTransformersBlock.remove_weight_norm`` can call it
        on "pre_conv2" flows.
        """
        self.enc.remove_weight_norm()
394
+
395
+
396
class ResidualCouplingTransformersLayer(nn.Module):  # vits2
    """Affine coupling layer whose conditioning half first passes through a
    residual transformer encoder, then a 1x1 conv and a WN stack.

    ``post_transformer`` is constructed but not applied in ``forward``
    (marked experimental upstream).
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        half = channels // 2
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = half
        self.mean_only = mean_only

        # vits2: transformer over the conditioning half (fixed hyper-parameters).
        self.pre_transformer = attentions.Encoder(
            half, half, n_heads=2, n_layers=2, kernel_size=3, p_dropout=0.1, window_size=None
        )

        self.pre = nn.Conv1d(half, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            p_dropout=p_dropout,
            gin_channels=gin_channels,
        )

        # vits2: experimental post-transformer, currently unused by forward().
        self.post_transformer = attentions.Encoder(
            hidden_channels, hidden_channels, n_heads=2, n_layers=2, kernel_size=3, p_dropout=0.1,
            window_size=None
        )

        # Zero-initialised output projection: the layer starts as identity.
        self.post = nn.Conv1d(hidden_channels, half * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the coupling; returns ``(x, logdet)`` or, if ``reverse``, ``x``."""
        keep, moved = torch.split(x, [self.half_channels] * 2, 1)
        # Residual transformer output feeds the conditioner only; `keep`
        # itself is passed through untouched so the flow stays invertible.
        cond = self.pre_transformer(keep * x_mask, x_mask) + keep
        hidden = self.enc(self.pre(cond) * x_mask, x_mask, g=g)

        stats = self.post(hidden) * x_mask
        if self.mean_only:
            m, logs = stats, torch.zeros_like(stats)
        else:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)

        if reverse:
            moved = (moved - m) * torch.exp(-logs) * x_mask
            return torch.cat([keep, moved], 1)

        moved = m + moved * torch.exp(logs) * x_mask
        return torch.cat([keep, moved], 1), torch.sum(logs, [1, 2])

    def remove_weight_norm(self):  # !
        """Strip weight norm from the WN stack."""
        self.enc.remove_weight_norm()
481
+
482
+
483
class FFTransformerCouplingLayer(nn.Module):  # vits2
    """Affine coupling layer whose conditioner is a 1x1 conv followed by a
    residual ``attentions.FFT`` block."""

    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 n_layers,
                 n_heads,
                 p_dropout=0,
                 filter_channels=768,
                 mean_only=False,
                 gin_channels=0
                 ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        half = channels // 2
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.half_channels = half
        self.mean_only = mean_only

        self.pre = nn.Conv1d(half, hidden_channels, 1)
        self.enc = attentions.FFT(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            isflow=True,
            gin_channels=gin_channels
        )
        # Zero-initialised output projection: the layer starts as identity.
        self.post = nn.Conv1d(hidden_channels, half * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the coupling; returns ``(x, logdet)`` or, if ``reverse``, ``x``."""
        keep, moved = torch.split(x, [self.half_channels] * 2, 1)
        hidden = self.pre(keep) * x_mask
        hidden = self.enc(hidden, x_mask, g=g) + hidden  # residual around the FFT block

        stats = self.post(hidden) * x_mask
        if self.mean_only:
            m, logs = stats, torch.zeros_like(stats)
        else:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)

        if reverse:
            moved = (moved - m) * torch.exp(-logs) * x_mask
            return torch.cat([keep, moved], 1)

        moved = m + moved * torch.exp(logs) * x_mask
        return torch.cat([keep, moved], 1), torch.sum(logs, [1, 2])
540
+
541
+
542
class MonoTransformerFlowLayer(nn.Module):  # vits2
    """Coupling layer conditioned by a transformer encoder only.

    With ``residual_connection=False`` this is a plain affine coupling. With
    ``residual_connection=True`` the coupled output is added to the input, so
    the first half is doubled and the second half is scaled by
    ``1 + exp(logs)``; the reverse pass undoes exactly that.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        mean_only=False,
        residual_connection=False,
        # according to VITS-2 paper fig 1B set residual_connection=True
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.half_channels = channels // 2
        self.mean_only = mean_only
        self.residual_connection = residual_connection
        # vits2
        self.pre_transformer = attentions.Encoder(
            self.half_channels,
            self.half_channels,
            n_heads=2,
            n_layers=2,
            kernel_size=3,
            p_dropout=0.1,
            window_size=None
        )

        # Zero-initialised output projection: m and logs start at zero.
        self.post = nn.Conv1d(self.half_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def _shift_and_log_scale(self, h, x_mask):
        """Project conditioner features to the masked (m, logs) pair."""
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)
        return m, logs

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the flow; returns ``(x, logdet)`` or, if ``reverse``, ``x``.

        ``g`` is accepted for interface parity with the other flows and is
        not used.
        """
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)

        if self.residual_connection:
            if not reverse:
                m, logs = self._shift_and_log_scale(self.pre_transformer(x0, x_mask), x_mask)
                x1 = m + x1 * torch.exp(logs) * x_mask
                x_ = torch.cat([x0, x1], 1)
                x = x + x_
                # d(out1)/d(x1) = 1 + exp(logs); d(out0)/d(x0) = 2.
                logdet = torch.sum(torch.log(torch.exp(logs) + 1), [1, 2])
                logdet = logdet + torch.log(torch.tensor(2)) * (x0.shape[1] * x0.shape[2])
                return x, logdet

            # Undo the doubling of the first half, then recompute the same
            # (m, logs) the forward pass used.
            x0 = x0 / 2
            m, logs = self._shift_and_log_scale(self.pre_transformer(x0, x_mask), x_mask)
            # Forward scaled x1 by (1 + exp(logs)), so divide by that same
            # factor; exp(-logs) here would only be correct when logs == 0.
            x1_ = ((x1 - m) / (1 + torch.exp(logs))) * x_mask
            return torch.cat([x0, x1_], 1)

        x0_ = self.pre_transformer(x0 * x_mask, x_mask)  # vits2
        h = x0_ + x0  # vits2
        m, logs = self._shift_and_log_scale(h, x_mask)
        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        x1 = (x1 - m) * torch.exp(-logs) * x_mask
        return torch.cat([x0, x1], 1)
625
+
626
+
627
+ class ResidualCouplingTransformersBlock(nn.Module): # vits2
628
+ def __init__(self,
629
+ channels,
630
+ hidden_channels,
631
+ kernel_size,
632
+ dilation_rate,
633
+ n_layers,
634
+ n_flows=4,
635
+ gin_channels=0,
636
+ use_transformer_flows=False,
637
+ transformer_flow_type="pre_conv",
638
+ ):
639
+ super().__init__()
640
+ self.channels = channels
641
+ self.hidden_channels = hidden_channels
642
+ self.kernel_size = kernel_size
643
+ self.dilation_rate = dilation_rate
644
+ self.n_layers = n_layers
645
+ self.n_flows = n_flows
646
+ self.gin_channels = gin_channels
647
+
648
+ self.flows = nn.ModuleList()
649
+ # TODO : clean up this mess
650
+ if use_transformer_flows:
651
+ if transformer_flow_type == "pre_conv":
652
+ for i in range(n_flows):
653
+ self.flows.append(
654
+ ResidualCouplingTransformersLayer(
655
+ channels,
656
+ hidden_channels,
657
+ kernel_size,
658
+ dilation_rate,
659
+ n_layers,
660
+ gin_channels=gin_channels,
661
+ mean_only=True
662
+ )
663
+ )
664
+ self.flows.append(modules.Flip())
665
+ elif transformer_flow_type == "pre_conv2":
666
+ for i in range(n_flows):
667
+ self.flows.append(
668
+ ResidualCouplingTransformersLayer2(
669
+ channels,
670
+ hidden_channels,
671
+ kernel_size,
672
+ dilation_rate,
673
+ n_layers,
674
+ gin_channels=gin_channels,
675
+ mean_only=True,
676
+ )
677
+ )
678
+ self.flows.append(modules.Flip())
679
+ elif transformer_flow_type == "fft":
680
+ for i in range(n_flows):
681
+ self.flows.append(
682
+ FFTransformerCouplingLayer(
683
+ channels,
684
+ hidden_channels,
685
+ kernel_size,
686
+ dilation_rate,
687
+ n_layers,
688
+ gin_channels=gin_channels,
689
+ mean_only=True
690
+ )
691
+ )
692
+ self.flows.append(modules.Flip())
693
+ elif transformer_flow_type == "mono_layer_inter_residual":
694
+ for i in range(n_flows):
695
+ self.flows.append(
696
+ modules.ResidualCouplingLayer(
697
+ channels,
698
+ hidden_channels,
699
+ kernel_size,
700
+ dilation_rate,
701
+ n_layers,
702
+ gin_channels=gin_channels,
703
+ mean_only=True
704
+ )
705
+ )
706
+ self.flows.append(modules.Flip())
707
+ self.flows.append(
708
+ MonoTransformerFlowLayer(
709
+ channels, hidden_channels, mean_only=True
710
+ )
711
+ )
712
+ elif transformer_flow_type == "mono_layer_post_residual":
713
+ for i in range(n_flows):
714
+ self.flows.append(
715
+ modules.ResidualCouplingLayer(
716
+ channels,
717
+ hidden_channels,
718
+ kernel_size,
719
+ dilation_rate,
720
+ n_layers,
721
+ gin_channels=gin_channels,
722
+ mean_only=True,
723
+ )
724
+ )
725
+ self.flows.append(modules.Flip())
726
+ self.flows.append(
727
+ MonoTransformerFlowLayer(
728
+ channels, hidden_channels, mean_only=True,
729
+ residual_connection=True
730
+ )
731
+ )
732
+ else:
733
+ for i in range(n_flows):
734
+ self.flows.append(
735
+ modules.ResidualCouplingLayer(
736
+ channels,
737
+ hidden_channels,
738
+ kernel_size,
739
+ dilation_rate,
740
+ n_layers,
741
+ gin_channels=gin_channels,
742
+ mean_only=True
743
+ )
744
+ )
745
+ self.flows.append(modules.Flip())
746
+
747
def forward(self, x, x_mask, g=None, reverse=False):
    """Pass ``x`` through the flow stack, forwards or in reverse.

    Args:
        x: Input handed to the first flow layer.
        x_mask: Mask forwarded unchanged to every flow layer.
        g: Optional conditioning input forwarded to every flow layer.
        reverse: When falsy the layers run in registration order and the
            second element of each layer's return pair is discarded; when
            truthy the layers run in reversed order and each layer's return
            value is used directly.

    Returns:
        The output of the last layer that was applied.
    """
    if reverse:
        for layer in reversed(self.flows):
            x = layer(x, x_mask, g=g, reverse=reverse)
        return x

    for layer in self.flows:
        x, _ = layer(x, x_mask, g=g, reverse=reverse)
    return x
755
+
756
def remove_weight_norm(self):  # !
    """Strip weight normalisation from every flow layer that supports it.

    The flow list is not always laid out as alternating (coupling, Flip)
    pairs: the ``mono_layer_*`` flow types append three modules per flow.
    Selecting layers by even index therefore reaches layers that may not
    define ``remove_weight_norm`` and skips some coupling layers. Each layer
    is probed for the method instead, and layers without it are left
    untouched.
    """
    for layer in self.flows:
        strip = getattr(layer, "remove_weight_norm", None)
        if callable(strip):
            strip()
760
+
761
+
762
class ResidualCouplingBlock(nn.Module):
    """Stack of mean-only residual coupling layers, each followed by a ``Flip``.

    ``n_flows`` (coupling layer, ``Flip``) pairs are registered in
    ``self.flows``; coupling layers therefore sit at the even indices.
    """

    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 n_flows=4,
                 gin_channels=0):
        super().__init__()
        # Keep the hyper-parameters around for inspection.
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the flows in order, or in reversed order when ``reverse`` is truthy.

        In the forward direction each layer returns a pair whose second
        element is discarded; in the reverse direction each layer's return
        value is used as-is.
        """
        if reverse:
            for layer in reversed(self.flows):
                x = layer(x, x_mask, g=g, reverse=reverse)
            return x

        for layer in self.flows:
            x, _ = layer(x, x_mask, g=g, reverse=reverse)
        return x

    def remove_weight_norm(self):  # !
        """Call ``remove_weight_norm`` on the even-indexed (coupling) layers."""
        for position, layer in enumerate(self.flows):
            if position % 2:
                continue
            layer.remove_weight_norm()
808
+
809
+
810
class PosteriorEncoder(nn.Module):
    """Encode an input sequence into a sampled latent plus its mean and log-scale.

    The pipeline is a 1x1 convolution, a ``modules.WN`` stack and a 1x1
    projection to ``2 * out_channels`` channels that is split into the mean
    and the log-scale.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 gin_channels=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Return ``(z, m, logs, x_mask)`` for the input ``x``.

        Args:
            x: Input tensor; its size along dim 2 sets the mask length.
            x_lengths: Lengths passed to ``commons.sequence_mask``.
            g: Optional conditioning input forwarded to the WN encoder.

        Returns:
            ``z`` sampled as ``m + noise * exp(logs)`` and masked, the mean
            ``m``, the log-scale ``logs`` and the mask cast to ``x.dtype``.
        """
        lengths_mask = commons.sequence_mask(x_lengths, x.size(2))
        x_mask = torch.unsqueeze(lengths_mask, 1).to(x.dtype)

        hidden = self.pre(x) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.proj(hidden) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        noise = torch.randn_like(m)
        z = (m + noise * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
840
+
841
+
842
class Generator(torch.nn.Module):
    """Upsampling decoder: pre-conv, transposed-conv stages with averaged
    residual blocks, then a single-channel post-conv squashed by ``tanh``.

    Unlike the iSTFT generators in this file, ``forward`` returns one tensor
    rather than a pair.
    """

    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
                 upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        block_cls = modules.ResBlock1 if resblock == '1' else modules.ResBlock2

        # Each upsampling stage halves the channel count.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            in_ch = upsample_initial_channel // (2 ** i)
            out_ch = upsample_initial_channel // (2 ** (i + 1))
            up = ConvTranspose1d(in_ch, out_ch, k, u, padding=(k - u) // 2)
            self.ups.append(weight_norm(up))

        # One group of residual blocks per upsampling stage.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(block_cls(ch, k, d))

        # `ch` is the channel count of the last stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Decode ``x`` into the ``tanh``-bounded output tensor.

        ``g``, when given, is projected by ``self.cond`` and added after the
        pre-convolution.
        """
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)

            # Average the outputs of this stage's residual blocks.
            first = i * self.num_kernels
            xs = None
            for j in range(self.num_kernels):
                branch = self.resblocks[first + j](x)
                if xs is None:
                    xs = branch
                else:
                    xs += branch
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Remove weight norm from the upsampling layers and residual blocks."""
        print('Removing weight norm...')
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
896
+
897
+
898
class iSTFT_Generator(torch.nn.Module):
    """Upsampling decoder whose post-conv emits ``n_fft + 2`` channels that
    are split into a magnitude part and a phase part and inverted with
    ``self.stft.inverse``.
    """

    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
                 upsample_initial_channel, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size,
                 gin_channels=0, is_onnx=False):
        super().__init__()
        self.gen_istft_n_fft = gen_istft_n_fft
        self.gen_istft_hop_size = gen_istft_hop_size

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3))
        block_cls = modules.ResBlock1 if resblock == '1' else modules.ResBlock2

        # Each upsampling stage halves the channel count.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            in_ch = upsample_initial_channel // (2 ** i)
            out_ch = upsample_initial_channel // (2 ** (i + 1))
            up = ConvTranspose1d(in_ch, out_ch, k, u, padding=(k - u) // 2)
            self.ups.append(weight_norm(up))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(block_cls(ch, k, d))

        self.post_n_fft = self.gen_istft_n_fft
        self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))

        # - for onnx: pick the STFT implementation matching the export target.
        stft_cls = OnnxSTFT if is_onnx == True else TorchSTFT
        self.stft = stft_cls(filter_length=self.gen_istft_n_fft, hop_length=self.gen_istft_hop_size,
                             win_length=self.gen_istft_n_fft)

    def forward(self, x, g=None):
        """Decode ``x`` and return ``(waveform, None)``.

        ``g`` is accepted for interface parity with the other generators but
        is not used here.
        """
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)

            # Average the outputs of this stage's residual blocks.
            first = i * self.num_kernels
            xs = None
            for j in range(self.num_kernels):
                branch = self.resblocks[first + j](x)
                if xs is None:
                    xs = branch
                else:
                    xs += branch
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.reflection_pad(x)
        x = self.conv_post(x)

        # First n_fft // 2 + 1 channels -> magnitude, the rest -> phase.
        n_bins = self.post_n_fft // 2 + 1
        spec = torch.exp(x[:, :n_bins, :])
        phase = math.pi * torch.sin(x[:, n_bins:, :])
        out = self.stft.inverse(spec, phase).to(x.device)
        return out, None

    def remove_weight_norm(self):
        """Remove weight norm from the upsamplers, residual blocks and pre/post convs."""
        print('Removing weight norm...')
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
967
+ remove_weight_norm(self.conv_pre)
968
+ remove_weight_norm(self.conv_post)
969
+
970
+
971
class Multiband_iSTFT_Generator(torch.nn.Module):  # !
    """Multi-band iSTFT decoder.

    The post-conv predicts ``subbands * (n_fft + 2)`` channels, which are
    reshaped per sub-band, split into magnitude and phase, inverted with
    ``self.stft.inverse`` and finally merged by ``PQMF.synthesis``.
    """

    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
                 upsample_initial_channel, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, subbands,
                 gin_channels=0, is_onnx=False):
        super(Multiband_iSTFT_Generator, self).__init__()
        self.subbands = subbands
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3))
        resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2

        # Each upsampling stage halves the channel count.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)),
                                k, u, padding=(k - u) // 2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d))

        self.post_n_fft = gen_istft_n_fft
        self.ups.apply(init_weights)
        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
        self.reshape_pixelshuffle = []

        self.subband_conv_post = weight_norm(Conv1d(ch, self.subbands * (self.post_n_fft + 2), 7, 1, padding=3))

        self.subband_conv_post.apply(init_weights)

        self.gen_istft_n_fft = gen_istft_n_fft
        self.gen_istft_hop_size = gen_istft_hop_size

        # - for onnx: pick the STFT implementation matching the export target.
        if is_onnx == True:
            self.stft = OnnxSTFT(filter_length=self.gen_istft_n_fft, hop_length=self.gen_istft_hop_size, win_length=self.gen_istft_n_fft)
        else:
            self.stft = TorchSTFT(filter_length=self.gen_istft_n_fft, hop_length=self.gen_istft_hop_size, win_length=self.gen_istft_n_fft)

    def forward(self, x, g=None):
        """Decode ``x`` and return ``(y_g_hat, y_mb_hat)``.

        ``y_mb_hat`` holds the per-sub-band signals and ``y_g_hat`` is their
        PQMF synthesis. ``g`` is accepted for interface parity but unused.
        """
        stft = self.stft.to(x.device)
        pqmf = PQMF(x.device)

        x = self.conv_pre(x)  # [B, ch, length]

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)

            # Average the outputs of this stage's residual blocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.reflection_pad(x)
        x = self.subband_conv_post(x)
        # Split the channel axis into (subbands, channels-per-subband).
        x = torch.reshape(x, (x.shape[0], self.subbands, x.shape[1] // self.subbands, x.shape[-1]))

        spec = torch.exp(x[:, :, :self.post_n_fft // 2 + 1, :])
        phase = math.pi * torch.sin(x[:, :, self.post_n_fft // 2 + 1:, :])

        # Fold the sub-band axis into the batch axis for the inverse STFT.
        y_mb_hat = stft.inverse(
            torch.reshape(spec, (spec.shape[0] * self.subbands, self.gen_istft_n_fft // 2 + 1, spec.shape[-1])),
            torch.reshape(phase, (phase.shape[0] * self.subbands, self.gen_istft_n_fft // 2 + 1, phase.shape[-1])))
        y_mb_hat = torch.reshape(y_mb_hat, (x.shape[0], self.subbands, 1, y_mb_hat.shape[-1]))
        y_mb_hat = y_mb_hat.squeeze(-2)

        y_g_hat = pqmf.synthesis(y_mb_hat)

        return y_g_hat, y_mb_hat

    def remove_weight_norm(self):
        """Remove weight norm from every weight-normed layer of the decoder.

        ``conv_pre`` and ``subband_conv_post`` are wrapped in ``weight_norm``
        in ``__init__`` just like the upsamplers, so they are stripped too,
        matching what ``iSTFT_Generator.remove_weight_norm`` does for its
        pre/post convolutions.
        """
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.subband_conv_post)
1059
+
1060
+
1061
class Multistream_iSTFT_Generator(torch.nn.Module):
    """Multi-stream iSTFT decoder.

    Like the multi-band decoder, the post-conv predicts per-sub-band
    magnitude/phase channels that are inverted with ``self.stft.inverse``.
    The sub-band signals are then upsampled with a fixed transposed-conv
    filter and merged by the learned ``multistream_conv_post``.
    """

    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
                 upsample_initial_channel, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, subbands,
                 gin_channels=0, is_onnx=False):
        super(Multistream_iSTFT_Generator, self).__init__()
        self.subbands = subbands
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3))
        resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2

        # Each upsampling stage halves the channel count.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)),
                                k, u, padding=(k - u) // 2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d))

        self.post_n_fft = gen_istft_n_fft
        self.ups.apply(init_weights)
        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
        self.reshape_pixelshuffle = []

        self.subband_conv_post = weight_norm(Conv1d(ch, self.subbands * (self.post_n_fft + 2), 7, 1, padding=3))

        self.subband_conv_post.apply(init_weights)

        self.gen_istft_n_fft = gen_istft_n_fft
        self.gen_istft_hop_size = gen_istft_hop_size

        # Fixed filter with a single 1.0 tap on each (k, k, 0) entry; used
        # with stride ``subbands`` in ``forward`` to upsample every sub-band.
        updown_filter = torch.zeros((self.subbands, self.subbands, self.subbands)).float()
        for k in range(self.subbands):
            updown_filter[k, k, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.multistream_conv_post = weight_norm(Conv1d(self.subbands, 1, kernel_size=63, bias=False, padding=get_padding(63, 1)))  # from MB-iSTFT-VITS-44100-Ja
        self.multistream_conv_post.apply(init_weights)

        # - for onnx: pick the STFT implementation matching the export target.
        if is_onnx == True:
            self.stft = OnnxSTFT(filter_length=self.gen_istft_n_fft, hop_length=self.gen_istft_hop_size, win_length=self.gen_istft_n_fft)
        else:
            self.stft = TorchSTFT(filter_length=self.gen_istft_n_fft, hop_length=self.gen_istft_hop_size, win_length=self.gen_istft_n_fft)

    def forward(self, x, g=None):
        """Decode ``x`` and return ``(y_g_hat, y_mb_hat)``.

        ``y_mb_hat`` is the upsampled per-sub-band signal and ``y_g_hat`` the
        output of ``multistream_conv_post`` applied to it. ``g`` is accepted
        for interface parity but unused.
        """
        stft = self.stft.to(x.device)

        x = self.conv_pre(x)  # [B, ch, length]

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)

            # Average the outputs of this stage's residual blocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.reflection_pad(x)
        x = self.subband_conv_post(x)
        # Split the channel axis into (subbands, channels-per-subband).
        x = torch.reshape(x, (x.shape[0], self.subbands, x.shape[1] // self.subbands, x.shape[-1]))

        spec = torch.exp(x[:, :, :self.post_n_fft // 2 + 1, :])
        phase = math.pi * torch.sin(x[:, :, self.post_n_fft // 2 + 1:, :])

        # Fold the sub-band axis into the batch axis for the inverse STFT.
        y_mb_hat = stft.inverse(
            torch.reshape(spec, (spec.shape[0] * self.subbands, self.gen_istft_n_fft // 2 + 1, spec.shape[-1])),
            torch.reshape(phase, (phase.shape[0] * self.subbands, self.gen_istft_n_fft // 2 + 1, phase.shape[-1])))
        y_mb_hat = torch.reshape(y_mb_hat, (x.shape[0], self.subbands, 1, y_mb_hat.shape[-1]))
        y_mb_hat = y_mb_hat.squeeze(-2)

        # `.to(x.device)` rather than `.cuda(...)` so CPU inference works too.
        y_mb_hat = F.conv_transpose1d(y_mb_hat, self.updown_filter.to(x.device) * self.subbands, stride=self.subbands)

        y_g_hat = self.multistream_conv_post(y_mb_hat)

        return y_g_hat, y_mb_hat

    def remove_weight_norm(self):
        """Remove weight norm from every weight-normed layer of the decoder.

        ``conv_pre``, ``subband_conv_post`` and ``multistream_conv_post`` are
        wrapped in ``weight_norm`` in ``__init__`` just like the upsamplers,
        so they are stripped too, matching what
        ``iSTFT_Generator.remove_weight_norm`` does for its pre/post
        convolutions.
        """
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.subband_conv_post)
        remove_weight_norm(self.multistream_conv_post)
1162
+
1163
+
1164
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the 1-D input into a 2-D grid of width
    ``period`` and scores it with a stack of 2-D convolutions.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm != False else weight_norm

        pad = (get_padding(kernel_size, 1), 0)
        # Four strided convolutions followed by one with stride 1.
        strided_channels = [(1, 32), (32, 128), (128, 512), (512, 1024)]
        layers = [
            norm_f(Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1), padding=pad))
            for c_in, c_out in strided_channels
        ]
        layers.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=pad)))
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(scores, fmap)``: the flattened post-conv output and the
        list of intermediate feature maps (one per conv, plus the post-conv).
        """
        fmap = []

        # 1d to 2d: right-pad (reflect) to a multiple of the period, then fold.
        b, c, t = x.shape
        remainder = t % self.period
        if remainder != 0:
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
1199
+
1200
+
1201
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of (grouped) 1-D convolutions applied
    directly to the input signal.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm != False else weight_norm

        # (in_channels, out_channels, kernel_size, stride, groups, padding)
        layer_specs = [
            (1, 16, 15, 1, 1, 7),
            (16, 64, 41, 4, 4, 20),
            (64, 256, 41, 4, 16, 20),
            (256, 1024, 41, 4, 64, 20),
            (1024, 1024, 41, 4, 256, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList([
            norm_f(Conv1d(c_in, c_out, k, s, groups=groups, padding=pad))
            for c_in, c_out, k, s, groups, pad in layer_specs
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(scores, fmap)``: the flattened post-conv output and the
        list of intermediate feature maps (one per conv, plus the post-conv).
        """
        fmap = []

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
1227
+
1228
+
1229
class MultiPeriodDiscriminator(torch.nn.Module):
    """Ensemble of one ``DiscriminatorS`` plus one ``DiscriminatorP`` for each
    of the periods 2, 3, 5, 7 and 11.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        periods = [2, 3, 5, 7, 11]

        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        for period in periods:
            members.append(DiscriminatorP(period, use_spectral_norm=use_spectral_norm))
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        """Run every discriminator on ``y`` and then on ``y_hat``.

        Returns:
            Four lists with one entry per discriminator: outputs for ``y``,
            outputs for ``y_hat``, feature maps for ``y`` and feature maps
            for ``y_hat``.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            score_r, feats_r = disc(y)
            score_g, feats_g = disc(y_hat)
            y_d_rs.append(score_r)
            fmap_rs.append(feats_r)
            y_d_gs.append(score_g)
            fmap_gs.append(feats_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1252
+
1253
+
1254
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training

    Wires together the text encoder (``enc_p``), posterior encoder
    (``enc_q``), flow (``flow``), duration predictor (``dp``), speaker
    embedding (``emb_g``) and one of four waveform decoders (``dec``) chosen
    by the ``mb_istft_vits`` / ``ms_istft_vits`` / ``istft_vits`` flags.

    Optional settings read from ``**kwargs``: ``use_spk_conditioned_encoder``,
    ``use_transformer_flows``, ``transformer_flow_type``,
    ``use_noise_scaled_mas``, ``mas_noise_scale_initial`` and
    ``noise_scale_delta``.
    """

    def __init__(self,
                 n_vocab,
                 spec_channels,
                 segment_size,
                 inter_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 resblock,
                 resblock_kernel_sizes,
                 resblock_dilation_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 upsample_kernel_sizes,
                 gen_istft_n_fft,
                 gen_istft_hop_size,
                 n_speakers=0,
                 gin_channels=0,
                 use_sdp=True,
                 ms_istft_vits=False,
                 mb_istft_vits=False,
                 subbands=False,
                 istft_vits=False,
                 is_onnx=False,
                 **kwargs):

        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.ms_istft_vits = ms_istft_vits
        self.mb_istft_vits = mb_istft_vits
        self.istft_vits = istft_vits
        self.use_spk_conditioned_encoder = kwargs.get("use_spk_conditioned_encoder", False)
        self.use_transformer_flows = kwargs.get("use_transformer_flows", False)
        self.transformer_flow_type = kwargs.get("transformer_flow_type", "mono_layer_post_residual")
        if self.use_transformer_flows:
            assert self.transformer_flow_type in AVAILABLE_FLOW_TYPES, f"transformer_flow_type must be one of {AVAILABLE_FLOW_TYPES}"
        self.use_sdp = use_sdp
        self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
        self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
        self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)

        self.current_mas_noise_scale = self.mas_noise_scale_initial
        # The text encoder is only speaker-conditioned when explicitly enabled.
        if self.use_spk_conditioned_encoder and gin_channels > 0:
            self.enc_gin_channels = gin_channels
        else:
            self.enc_gin_channels = 0
        self.enc_p = TextEncoder(n_vocab,
                                 inter_channels,
                                 hidden_channels,
                                 filter_channels,
                                 n_heads,
                                 n_layers,
                                 kernel_size,
                                 p_dropout,
                                 gin_channels=self.enc_gin_channels)

        # Decoder selection; the first matching flag wins.
        if mb_istft_vits == True:
            print('Multi-band iSTFT VITS2')
            self.dec = Multiband_iSTFT_Generator(inter_channels, resblock, resblock_kernel_sizes,
                                                 resblock_dilation_sizes,
                                                 upsample_rates, upsample_initial_channel, upsample_kernel_sizes,
                                                 gen_istft_n_fft, gen_istft_hop_size, subbands,
                                                 gin_channels=gin_channels, is_onnx=is_onnx)
        elif ms_istft_vits == True:
            print('Multi-stream iSTFT VITS2')
            self.dec = Multistream_iSTFT_Generator(inter_channels, resblock, resblock_kernel_sizes,
                                                   resblock_dilation_sizes,
                                                   upsample_rates, upsample_initial_channel, upsample_kernel_sizes,
                                                   gen_istft_n_fft, gen_istft_hop_size, subbands,
                                                   gin_channels=gin_channels, is_onnx=is_onnx)
        elif istft_vits == True:
            print('iSTFT-VITS2')
            self.dec = iSTFT_Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes,
                                       upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gen_istft_n_fft,
                                       gen_istft_hop_size, gin_channels=gin_channels, is_onnx=is_onnx)
        else:
            print('No iSTFT arguments found in json file')
            self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes,
                                 upsample_rates,
                                 upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)  # vits 2

        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
                                      gin_channels=gin_channels)
        self.flow = ResidualCouplingTransformersBlock(
            inter_channels,
            hidden_channels,
            5,
            1,
            4,
            gin_channels=gin_channels,
            use_transformer_flows=self.use_transformer_flows,
            transformer_flow_type=self.transformer_flow_type
        )

        if use_sdp:
            self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
        else:
            self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)

        # Created unconditionally so the parameter layout of existing
        # checkpoints is unchanged.
        self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def _speaker_embedding(self, sid):
        """Return the speaker conditioning ``[b, h, 1]``, or ``None`` without speakers.

        With ``n_speakers == 0`` the embedding table is empty and any lookup
        would fail, so conditioning is skipped instead.
        """
        if self.n_speakers > 0:
            return self.emb_g(sid).unsqueeze(-1)
        return None

    def _decode(self, z, g):
        """Run the decoder and always return an ``(o, o_mb)`` pair.

        The iSTFT decoders return a pair, whereas the plain ``Generator``
        returns a single tensor; that case is normalised to ``(o, None)`` so
        callers can unpack uniformly.
        """
        out = self.dec(z, g=g)
        if isinstance(out, torch.Tensor):
            return out, None
        o, o_mb = out
        return o, o_mb

    def forward(self, x, x_lengths, y, y_lengths, sid=None):
        """Training pass.

        Returns:
            ``(o, o_mb, l_length, attn, ids_slice, x_mask, y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q), (x, logw, logw_))``.
            ``o_mb`` is ``None`` when the decoder produces no sub-band output.
        """
        g = self._speaker_embedding(sid)  # [b, h, 1]

        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, g=g)  # vits2?
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True)  # [b, 1, t_s]
            neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2),
                                     s_p_sq_r)  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r))  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True)  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            if self.use_noise_scaled_mas:
                # Perturb the alignment scores with noise scaled by their std.
                epsilon = torch.std(neg_cent) * torch.randn_like(neg_cent) * self.current_mas_noise_scale
                neg_cent = neg_cent + epsilon

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()

        # Per-token durations implied by the alignment.
        w = attn.sum(2)
        if self.use_sdp:
            l_length = self.dp(x, x_mask, w, g=g)
            l_length = l_length / torch.sum(x_mask)
            logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=1.)
            logw_ = torch.log(w + 1e-6) * x_mask
        else:
            logw_ = torch.log(w + 1e-6) * x_mask
            logw = self.dp(x, x_mask, g=g)
            l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)  # for averaging

        # expand prior
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
        o, o_mb = self._decode(z_slice, g)
        return o, o_mb, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), (x, logw, logw_)

    def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
        """Synthesis pass.

        Returns:
            ``(o, o_mb, attn, y_mask, (z, z_p, m_p, logs_p))``. ``o_mb`` is
            ``None`` when the decoder produces no sub-band output.
        """
        g = self._speaker_embedding(sid)  # [b, h, 1]

        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, g=g)
        if self.use_sdp:
            logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
        else:
            logw = self.dp(x, x_mask, g=g)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        # At least one output frame per batch item.
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1,
                                                                                 2)  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)

        o, o_mb = self._decode((z * y_mask)[:, :, :max_len], g)
        return o, o_mb, attn, y_mask, (z, z_p, m_p, logs_p)

    ## currently vits-2 is not capable of voice conversion
    # comment - choihkk : Assuming the use of the ResidualCouplingTransformersLayer2 module, it seems that voice conversion is possible
    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
        """Convert ``y`` from speaker ``sid_src`` to speaker ``sid_tgt``.

        Encodes with the source speaker embedding, maps through the flow, and
        inverts the flow and decodes with the target speaker embedding.

        Returns:
            ``(o_hat, o_hat_mb, y_mask, (z, z_p, z_hat))``.

        Raises:
            AssertionError: if the model was built with ``n_speakers == 0``.
        """
        assert self.n_speakers > 0, "n_speakers have to be larger than 0."
        g_src = self.emb_g(sid_src).unsqueeze(-1)
        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat, o_hat_mb = self._decode(z_hat * y_mask, g_tgt)
        return o_hat, o_hat_mb, y_mask, (z, z_p, z_hat)
modules.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm
11
+
12
+ import commons
13
+ from commons import init_weights, get_padding
14
+ from transforms import piecewise_rational_quadratic_transform
15
+
16
+
17
+ LRELU_SLOPE = 0.1
18
+
19
+
20
class LayerNorm(nn.Module):
    """Layer normalization over dimension 1 of the input, with learnable scale and shift."""

    def __init__(self, channels, eps=1e-5):
        """Create the per-channel scale (ones) and shift (zeros) parameters.

        Args:
            channels: Size of dimension 1 of the tensors passed to ``forward``.
            eps: Epsilon forwarded to ``F.layer_norm``.
        """
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Normalize ``x`` along dimension 1 and return a tensor of the same layout."""
        # F.layer_norm works on the trailing dimension, so swap dim 1 to the end and back.
        channels_last = x.transpose(1, -1)
        normalized = F.layer_norm(channels_last, (self.channels,), self.gamma, self.beta, self.eps)
        return normalized.transpose(1, -1)
33
+
34
+
35
class ConvReluNorm(nn.Module):
    """Stack of Conv1d -> LayerNorm -> ReLU -> Dropout layers with a residual 1x1 projection."""

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        """Build the convolution stack.

        Args:
            in_channels: Channels of the input tensor.
            hidden_channels: Channels used by every convolution in the stack.
            out_channels: Channels produced by the final 1x1 projection.
            kernel_size: Kernel size of the stacked convolutions (padding is ``kernel_size // 2``).
            n_layers: Number of conv/norm layers; must be at least 2.
            p_dropout: Dropout probability applied after each ReLU.

        Raises:
            AssertionError: If ``n_layers`` is not greater than 1.
        """
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # The check rejects n_layers == 1, so the message states the real bound
        # (it previously said "larger than 0").
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p_dropout))
        for _ in range(n_layers-1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        # The projection starts at zero, so a freshly built module returns x * x_mask.
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Apply the stack and add the projected result to the input.

        Args:
            x: Input tensor; the residual add requires ``out_channels`` to match its channels.
            x_mask: Mask multiplied into the activations before every convolution and at the end.

        Returns:
            ``(x + proj(stack(x))) * x_mask``.
        """
        x_org = x
        for conv, norm in zip(self.conv_layers, self.norm_layers):
            x = conv(x * x_mask)
            x = norm(x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
68
+
69
+
70
class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
        """Build ``n_layers`` residual layers of depthwise conv + 1x1 conv.

        Args:
            channels: Channels of input and output.
            kernel_size: Kernel size of the depthwise convolutions; layer ``i`` uses
                dilation ``kernel_size ** i``.
            n_layers: Number of residual layers.
            p_dropout: Dropout probability applied at the end of every layer.
        """
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for layer_idx in range(n_layers):
            rate = kernel_size ** layer_idx
            pad = (kernel_size * rate - rate) // 2
            depthwise = nn.Conv1d(channels, channels, kernel_size,
                                  groups=channels, dilation=rate, padding=pad)
            self.convs_sep.append(depthwise)
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        """Run the residual layers.

        Args:
            x: Input tensor.
            x_mask: Mask multiplied into the input of each depthwise conv and into the output.
            g: Optional conditioning tensor added to ``x`` before the first layer.

        Returns:
            The accumulated residual output multiplied by ``x_mask``.
        """
        if g is not None:
            x = x + g
        layers = zip(self.convs_sep, self.norms_1, self.convs_1x1, self.norms_2)
        for depthwise, norm_a, pointwise, norm_b in layers:
            branch = F.gelu(norm_a(depthwise(x * x_mask)))
            branch = F.gelu(norm_b(pointwise(branch)))
            x = x + self.drop(branch)
        return x * x_mask
109
+
110
+
111
class WN(torch.nn.Module):
    """Stack of weight-normalized dilated convolutions with gated activations and skip outputs."""

    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
        """Build the layer stack.

        Args:
            hidden_channels: Channels of the input, the residual stream and the output.
            kernel_size: Odd kernel size of the dilated convolutions.
            dilation_rate: Base of the dilation; layer ``i`` uses ``dilation_rate ** i``.
            n_layers: Number of layers.
            gin_channels: Channels of the conditioning input; 0 builds no conditioning layer.
            p_dropout: Dropout probability applied to the gated activations.

        Raises:
            AssertionError: If ``kernel_size`` is even.
        """
        super(WN, self).__init__()
        assert(kernel_size % 2 == 1)
        self.hidden_channels = hidden_channels
        # Stored as a plain int; a trailing comma used to turn this into a 1-tuple.
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # One 1x1 conv produces the conditioning slices for every layer at once.
            cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            dilation = dilation_rate ** i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            # The last layer only feeds the skip output, so it needs no residual half.
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        """Run the stack and return the masked sum of the skip outputs.

        Args:
            x: Input tensor with ``hidden_channels`` channels.
            x_mask: Mask multiplied into the residual stream and the output.
            g: Optional conditioning tensor; requires the module to have been built
                with ``gin_channels != 0``.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor shaped like ``x``.
        """
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                # Slice out this layer's 2 * hidden_channels conditioning channels.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset:cond_offset+2*self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(
                x_in,
                g_l,
                n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                # First half updates the residual stream, second half goes to the output.
                res_acts = res_skip_acts[:, :self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels:, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip weight normalization from every convolution in the stack."""
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
185
+
186
+
187
class ResBlock1(torch.nn.Module):
    """Residual block of three (dilated conv, undilated conv) pairs with leaky-ReLU activations."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        """Build the two weight-normalized convolution lists.

        Args:
            channels: Channels of input and output.
            kernel_size: Kernel size of every convolution.
            dilation: Dilations of the first convolution in each pair; the first three
                entries are used.
        """
        super(ResBlock1, self).__init__()

        def make_conv(rate):
            # Weight-normalized, stride-1 convolution padded via get_padding.
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                                      padding=get_padding(kernel_size, rate)))

        self.convs1 = nn.ModuleList([
            make_conv(dilation[0]),
            make_conv(dilation[1]),
            make_conv(dilation[2]),
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([make_conv(1), make_conv(1), make_conv(1)])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply the three residual pairs, masking before each convolution when a mask is given."""
        use_mask = x_mask is not None
        for dilated, plain in zip(self.convs1, self.convs2):
            branch = F.leaky_relu(x, LRELU_SLOPE)
            if use_mask:
                branch = branch * x_mask
            branch = F.leaky_relu(dilated(branch), LRELU_SLOPE)
            if use_mask:
                branch = branch * x_mask
            x = plain(branch) + x
        if use_mask:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all convolutions."""
        for layer in self.convs1:
            remove_weight_norm(layer)
        for layer in self.convs2:
            remove_weight_norm(layer)
230
+
231
+
232
class ResBlock2(torch.nn.Module):
    """Residual block of two dilated convolutions with leaky-ReLU activations."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        """Build the weight-normalized convolutions.

        Args:
            channels: Channels of input and output.
            kernel_size: Kernel size of every convolution.
            dilation: Dilations of the two convolutions; the first two entries are used.
        """
        super(ResBlock2, self).__init__()

        def make_conv(rate):
            # Weight-normalized, stride-1 convolution padded via get_padding.
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                                      padding=get_padding(kernel_size, rate)))

        self.convs = nn.ModuleList([make_conv(dilation[0]), make_conv(dilation[1])])
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply both residual convolutions, masking before each one when a mask is given."""
        use_mask = x_mask is not None
        for conv in self.convs:
            branch = F.leaky_relu(x, LRELU_SLOPE)
            if use_mask:
                branch = branch * x_mask
            x = conv(branch) + x
        if use_mask:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all convolutions."""
        for layer in self.convs:
            remove_weight_norm(layer)
257
+
258
+
259
class Log(nn.Module):
    """Flow step that takes an element-wise log forward and an exp in reverse."""

    def forward(self, x, x_mask, reverse=False, **kwargs):
        """Apply the transform.

        Args:
            x: Input tensor.
            x_mask: Mask multiplied into the result.
            reverse: When true, apply ``exp`` and return only the tensor.
            **kwargs: Accepted and ignored.

        Returns:
            ``(y, logdet)`` in the forward direction, or the transformed tensor alone in reverse.
        """
        if reverse:
            return torch.exp(x) * x_mask

        # Inputs are clamped to 1e-5 before the log.
        y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
        logdet = torch.sum(-y, [1, 2])
        return y, logdet
268
+
269
+
270
class Flip(nn.Module):
    """Flow step that reverses the order of dimension 1."""

    def forward(self, x, *args, reverse=False, **kwargs):
        """Flip ``x`` along dimension 1.

        Args:
            x: Input tensor.
            *args: Accepted and ignored.
            reverse: When true, return only the flipped tensor.
            **kwargs: Accepted and ignored.

        Returns:
            ``(flipped, logdet)`` with an all-zero ``logdet`` of size ``x.size(0)`` in the
            forward direction, or the flipped tensor alone in reverse.
        """
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped
        logdet = torch.zeros(flipped.size(0)).to(dtype=flipped.dtype, device=flipped.device)
        return flipped, logdet
278
+
279
+
280
class ElementwiseAffine(nn.Module):
    """Flow step applying a learnable per-channel shift ``m`` and log-scale ``logs``."""

    def __init__(self, channels):
        """Create zero-initialized shift and log-scale parameters of shape ``(channels, 1)``."""
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        """Apply the affine map or its inverse.

        Args:
            x: Input tensor.
            x_mask: Mask multiplied into the result and into the log-determinant terms.
            reverse: When true, invert the map and return only the tensor.
            **kwargs: Accepted and ignored.

        Returns:
            ``(y, logdet)`` in the forward direction, or the inverted tensor alone in reverse.
        """
        if reverse:
            return (x - self.m) * torch.exp(-self.logs) * x_mask

        y = (self.m + torch.exp(self.logs) * x) * x_mask
        logdet = torch.sum(self.logs * x_mask, [1, 2])
        return y, logdet
296
+
297
+
298
class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer: the first channel half predicts a shift/log-scale for the second."""

    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 p_dropout=0,
                 gin_channels=0,
                 mean_only=False):
        """Build the pre-projection, the WN encoder and the zero-initialized post-projection.

        Args:
            channels: Channels of the input; must be even.
            hidden_channels: Channels used inside the WN encoder.
            kernel_size: Kernel size forwarded to WN.
            dilation_rate: Dilation rate forwarded to WN.
            n_layers: Number of WN layers.
            p_dropout: Dropout probability forwarded to WN.
            gin_channels: Conditioning channels forwarded to WN.
            mean_only: When true, only a shift is predicted and the log-scale is zero.

        Raises:
            AssertionError: If ``channels`` is odd.
        """
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
        # Zero-initialized, so a freshly built layer predicts zero shift and zero log-scale.
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Transform the second channel half conditioned on the first.

        Args:
            x: Input tensor with ``channels`` channels.
            x_mask: Mask multiplied into intermediate and transformed values.
            g: Optional conditioning tensor forwarded to the WN encoder.
            reverse: When true, invert the transform and return only the tensor.

        Returns:
            ``(x, logdet)`` in the forward direction, or the inverted tensor alone in reverse.
        """
        halves = [self.half_channels] * 2
        x0, x1 = torch.split(x, halves, 1)

        hidden = self.pre(x0) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.post(hidden) * x_mask

        if self.mean_only:
            m = stats
            logs = torch.zeros_like(m)
        else:
            m, logs = torch.split(stats, halves, 1)

        if reverse:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            return torch.cat([x0, x1], 1)

        x1 = m + x1 * torch.exp(logs) * x_mask
        logdet = torch.sum(logs, [1, 2])
        return torch.cat([x0, x1], 1), logdet
344
+
345
+
346
class ConvFlow(nn.Module):
    """Coupling flow that transforms the second channel half with a rational-quadratic spline."""

    def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
        """Build the network that predicts the spline parameters.

        Args:
            in_channels: Channels of the input; the split in ``forward`` uses two halves.
            filter_channels: Channels used inside the DDSConv stack.
            kernel_size: Kernel size forwarded to DDSConv.
            n_layers: Number of DDSConv layers.
            num_bins: Number of spline bins.
            tail_bound: Tail bound forwarded to the spline transform.
        """
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
        # Per transformed channel: num_bins widths + num_bins heights + (num_bins - 1) derivatives.
        self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the spline coupling or its inverse.

        Args:
            x: Input tensor with ``in_channels`` channels.
            x_mask: Mask multiplied into the spline parameters, the output and the log-determinant.
            g: Optional conditioning tensor forwarded to DDSConv.
            reverse: When true, run the inverse transform and return only the tensor.

        Returns:
            ``(x, logdet)`` in the forward direction, or the transformed tensor alone in reverse.
        """
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        params = self.pre(x0)
        params = self.convs(params, x_mask, g=g)
        params = self.proj(params) * x_mask

        batch, chans, frames = x0.shape
        params = params.reshape(batch, chans, -1, frames).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]

        bins = self.num_bins
        norm = math.sqrt(self.filter_channels)
        unnormalized_widths = params[..., :bins] / norm
        unnormalized_heights = params[..., bins:2 * bins] / norm
        unnormalized_derivatives = params[..., 2 * bins:]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails='linear',
            tail_bound=self.tail_bound,
        )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if reverse:
            return x
        return x, logdet
monotonic_align/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numpy import zeros, int32, float32
2
+ from torch import from_numpy
3
+
4
+ from .core import maximum_path_jit
5
+
6
+
7
def maximum_path(neg_cent, mask):
    """ numba optimized version.
    neg_cent: [b, t_t, t_s]
    mask: [b, t_t, t_s]

    Copies the inputs to NumPy, lets ``maximum_path_jit`` fill a 0/1 path array,
    and returns it as a tensor on the device and with the dtype of ``neg_cent``.
    """
    target_device, target_dtype = neg_cent.device, neg_cent.dtype

    values = neg_cent.data.cpu().numpy().astype(float32)
    path = zeros(values.shape, dtype=int32)

    # Valid extent per batch item along each axis, read from the first row/column of the mask sums.
    t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
    t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)

    maximum_path_jit(path, values, t_t_max, t_s_max)
    return from_numpy(path).to(device=target_device, dtype=target_dtype)
21
+
monotonic_align/core.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numba
2
+
3
+
4
@numba.jit(numba.void(numba.int32[:, :, ::1], numba.float32[:, :, ::1], numba.int32[::1], numba.int32[::1]),
           nopython=True, nogil=True)
def maximum_path_jit(paths, values, t_ys, t_xs):
    """Fill ``paths`` in place with a 0/1 path per batch item.

    For each item, ``values`` is first overwritten with cumulative scores by a forward
    dynamic-programming pass, then one cell per row is marked in ``paths`` while walking
    back from the last row.

    Args:
        paths: int32 array [b, t_y, t_x], written in place.
        values: float32 array [b, t_y, t_x], overwritten with cumulative scores.
        t_ys: int32 array [b] with the number of rows to process per item.
        t_xs: int32 array [b] with the number of columns to process per item.
    """
    batch = paths.shape[0]
    max_neg_val = -1e9
    for n in range(int(batch)):
        path = paths[n]
        value = values[n]
        t_y = t_ys[n]
        t_x = t_xs[n]

        # Pre-typed as floats so both branches below assign the same variable type.
        from_diag = from_same = 0.0

        # Forward pass: each cell adds the better of "same column" and "previous column"
        # from the row above.
        for y in range(t_y):
            for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
                if x == y:
                    from_same = max_neg_val
                else:
                    from_same = value[y - 1, x]
                if x != 0:
                    from_diag = value[y - 1, x - 1]
                elif y == 0:
                    from_diag = 0.
                else:
                    from_diag = max_neg_val
                value[y, x] += max(from_diag, from_same)

        # Backward pass: start in the last column and step left when forced or when better.
        index = t_x - 1
        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
                index = index - 1
pqmf.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2020 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ """Pseudo QMF modules."""
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ from scipy.signal.windows import kaiser
13
+
14
+
15
def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
    """Design prototype filter for PQMF.
    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.
    Args:
        taps (int): The number of filter taps (must be even).
        cutoff_ratio (float): Cut-off frequency ratio, strictly between 0 and 1.
        beta (float): Beta coefficient for kaiser window.
    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).
    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427
    """
    # validate the arguments
    assert taps % 2 == 0, "The number of taps mush be even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # ideal low-pass response sampled around the centre tap
    omega_c = np.pi * cutoff_ratio
    offsets = np.arange(taps + 1) - 0.5 * taps
    with np.errstate(invalid='ignore'):
        h_i = np.sin(omega_c * offsets) / (np.pi * offsets)
    # the centre tap is 0/0 above; replace the resulting nan with the limit value
    h_i[taps // 2] = np.cos(0) * cutoff_ratio

    # shape the response with a kaiser window
    return h_i * kaiser(taps + 1, beta)
44
+
45
+
46
class PQMF(torch.nn.Module):
    """PQMF module.
    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122
    """

    def __init__(self, device, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
        """Initialize PQMF module.
        Args:
            device: Not used by this constructor.
            subbands (int): The number of subbands.
            taps (int): The number of filter taps.
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.
        """
        super(PQMF, self).__init__()

        # cosine-modulate the prototype filter into per-band analysis/synthesis filters
        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
        h_analysis = np.zeros((subbands, len(h_proto)))
        h_synthesis = np.zeros((subbands, len(h_proto)))
        positions = np.arange(taps + 1) - ((taps - 1) / 2)
        for k in range(subbands):
            carrier = (2 * k + 1) * (np.pi / (2 * subbands)) * positions
            phase = (-1) ** k * np.pi / 4
            h_analysis[k] = 2 * h_proto * np.cos(carrier + phase)
            h_synthesis[k] = 2 * h_proto * np.cos(carrier - phase)

        # convert to tensor
        analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
        synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)

        # register coefficients as buffers
        self.register_buffer("analysis_filter", analysis_filter)
        self.register_buffer("synthesis_filter", synthesis_filter)

        # filter for downsampling & upsampling: band k passes through its first tap only
        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
        band = torch.arange(subbands)
        updown_filter[band, band, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.subbands = subbands

        # keep padding info
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Analysis with PQMF.
        Args:
            x (Tensor): Input tensor (B, 1, T).
        Returns:
            Tensor: Output tensor (B, subbands, T // subbands).
        """
        banded = F.conv1d(self.pad_fn(x), self.analysis_filter)
        return F.conv1d(banded, self.updown_filter, stride=self.subbands)

    def synthesis(self, x):
        """Synthesis with PQMF.
        Args:
            x (Tensor): Input tensor (B, subbands, T // subbands).
        Returns:
            Tensor: Output tensor (B, 1, T).
        """
        # NOTE(kan-bayashi): Power will be decreased, so here multiply by # subbands.
        #   Not sure this is the correct way, it is better to check again.
        # TODO(kan-bayashi): Understand the reconstruction procedure
        upsampled = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands)
        return F.conv1d(self.pad_fn(upsampled), self.synthesis_filter)
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numba
2
+ librosa
3
+ matplotlib
4
+ numpy
5
+ phonemizer
6
+ scipy
7
+ tensorboard
8
+ torch
9
+ torchvision
10
+ torchaudio
11
+ Unidecode
12
+ pyopenjtalk
13
+ jamo
14
+ pypinyin
15
+ ko_pron
16
+ jieba
17
+ cn2an
18
+ gradio==3.50.2
19
+ monotonic_align
20
+ httpx==0.24.1
saved_model/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8ced49ea7591ac054d578e597386f26827dc35a757bae228123b5ed59d8d3bb
3
+ size 2275
saved_model/cover.png ADDED

Git LFS Details

  • SHA256: 4be366d6f5a71284f97c2ebb484cfb6e5bcfe50ceed72a46b0e43e2d06620415
  • Pointer size: 131 Bytes
  • Size of remote file: 259 kB
saved_model/model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c2a297cbe07aa674bf524592799bd56b70735fb510496e8ccd739644c76f9ce
3
+ size 162174033
stft.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BSD 3-Clause License
3
+ Copyright (c) 2017, Prem Seetharaman
4
+ All rights reserved.
5
+ * Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+ * Redistributions of source code must retain the above copyright notice,
8
+ this list of conditions and the following disclaimer.
9
+ * Redistributions in binary form must reproduce the above copyright notice, this
10
+ list of conditions and the following disclaimer in the
11
+ documentation and/or other materials provided with the distribution.
12
+ * Neither the name of the copyright holder nor the names of its
13
+ contributors may be used to endorse or promote products derived from this
14
+ software without specific prior written permission.
15
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
19
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
22
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ """
26
+
27
+ import torch
28
+ import numpy as np
29
+ import torch.nn.functional as F
30
+ from torch.autograd import Variable
31
+ from scipy.signal import get_window
32
+ from librosa.util import pad_center, tiny
33
+ import librosa.util as librosa_util
34
+
35
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.
    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.
    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : [optional]
        The length of the window function. `None` falls back to `n_fft`.
    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output
    norm : [optional]
        Normalization mode forwarded to `librosa.util.normalize`.
    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
    # `size` is passed by keyword: recent librosa releases reject it positionally,
    # while older releases accept the keyword form as well.
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x
77
+
78
+
79
class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    STFT and inverse STFT implemented as 1-D convolutions against fixed Fourier bases.
    """
    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        """Precompute the windowed forward and inverse bases and register them as buffers.

        Args:
            filter_length: FFT size.
            hop_length: Stride between frames.
            win_length: Window length; must not exceed ``filter_length``.
            window: Window specification for ``scipy.signal.get_window``, or ``None``
                to leave the bases unwindowed.
        """
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep the non-redundant bins and stack real over imaginary parts.
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            # `size` by keyword: recent librosa releases reject it positionally.
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        """Compute magnitude and phase of ``input_data``.

        Args:
            input_data: Tensor indexed as (batch, samples).

        Returns:
            ``(magnitude, phase)``; the phase is computed from detached data.
        """
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        # The basis is a registered buffer, so it can be used as the kernel directly.
        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.atan2(imag_part.data, real_part.data)

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Reconstruct a signal from magnitude and phase.

        Args:
            magnitude: Magnitude tensor as returned by ``transform``.
            phase: Phase tensor as returned by ``transform``.

        Returns:
            Reconstructed signal with ``filter_length // 2`` samples trimmed from each end.
        """
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            # `.device` is an attribute; calling it raised TypeError for CUDA inputs.
            # Both the envelope and the index tensor follow the signal's device.
            target_device = inverse_transform.device
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]).to(target_device)
            window_sum = torch.from_numpy(window_sum).to(target_device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
        inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]

        return inverse_transform

    def forward(self, input_data):
        """Transform then invert ``input_data``, caching magnitude and phase on the module."""
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
179
+
180
+
181
class OnnxSTFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    Convolution-based STFT whose ``inverse`` applies no window-envelope correction.
    """
    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        """Precompute the windowed forward and inverse bases and register them as buffers.

        Args:
            filter_length: FFT size.
            hop_length: Stride between frames.
            win_length: Window length; must not exceed ``filter_length``.
            window: Window specification for ``scipy.signal.get_window``, or ``None``
                to leave the bases unwindowed.
        """
        super(OnnxSTFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep the non-redundant bins and stack real over imaginary parts.
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            # `size` by keyword: recent librosa releases reject it positionally.
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        """Compute magnitude and phase of ``input_data``.

        Args:
            input_data: Tensor indexed as (batch, samples).

        Returns:
            ``(magnitude, phase)``; the phase is computed from detached data.
        """
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        forward_transform = F.conv1d(
            input_data,
            Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(
            torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Reconstruct a signal from magnitude and phase.

        Args:
            magnitude: Magnitude tensor as returned by ``transform``.
            phase: Phase tensor as returned by ``transform``.

        Returns:
            Reconstructed signal with ``filter_length // 2`` samples trimmed from each end.
        """
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
        inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]

        return inverse_transform

    def forward(self, input_data):
        """Transform then invert ``input_data``, caching magnitude and phase on the module."""
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
265
+
266
+
267
class TorchSTFT(torch.nn.Module):
    """STFT / inverse STFT pair built on ``torch.stft`` and ``torch.istft``."""

    def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
        """Store the transform sizes and build the window tensor.

        Args:
            filter_length: FFT size.
            hop_length: Stride between frames.
            win_length: Window length.
            window: Window specification for ``scipy.signal.get_window``.
        """
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        # Plain attribute (not a buffer), so it does not follow module.to(); both
        # transform() and inverse() move it to the device of their input.
        self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))

    def transform(self, input_data):
        """Return ``(magnitude, phase)`` of the complex STFT of ``input_data``."""
        forward_transform = torch.stft(
            input_data,
            self.filter_length, self.hop_length, self.win_length,
            # Match the input's device, as inverse() already does for its input.
            window=self.window.to(input_data.device),
            return_complex=True)

        return torch.abs(forward_transform), torch.angle(forward_transform)

    def inverse(self, magnitude, phase):
        """Rebuild the signal from magnitude and phase; a dimension is inserted at -2."""
        inverse_transform = torch.istft(
            magnitude * torch.exp(phase * 1j),
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))

        return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation

    def forward(self, input_data):
        """Transform then invert ``input_data``, caching magnitude and phase on the module."""
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
294
+
295
+
stft_loss.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2019 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ """STFT-based Loss modules."""
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.
    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor; it is moved to the device of ``x``.
    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    # Current PyTorch requires ``return_complex`` to be given explicitly and
    # no longer returns a trailing (real, imag) axis, so request a complex
    # result and read the parts from it.
    x_stft = torch.stft(x, fft_size, hop_size, win_length, window.to(x.device),
                        return_complex=True)
    real = x_stft.real
    imag = x_stft.imag

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
29
+
30
+
31
class SpectralConvergengeLoss(torch.nn.Module):
    """Spectral convergence loss module."""

    def __init__(self):
        """Initialize spectral convergence loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Compute the spectral convergence between two magnitude spectrograms.
        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Frobenius norm of the difference divided by the Frobenius norm of ``y_mag``.
        """
        error_norm = torch.norm(y_mag - x_mag, p="fro")
        reference_norm = torch.norm(y_mag, p="fro")
        return error_norm / reference_norm
47
+
48
+
49
class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module."""

    def __init__(self):
        """Initialize log STFT magnitude loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Compute the L1 distance between log-magnitude spectrograms.
        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        log_reference = torch.log(y_mag)
        log_predicted = torch.log(x_mag)
        return F.l1_loss(log_reference, log_predicted)
65
+
66
+
67
class STFTLoss(torch.nn.Module):
    """STFT loss module."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
        """Initialize STFT loss module.

        ``window`` names a ``torch`` window factory (looked up with
        ``getattr(torch, window)``) that is called with ``win_length``.
        """
        super().__init__()
        window_factory = getattr(torch, window)
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.window = window_factory(win_length)
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Compute both single-resolution STFT losses.
        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).
        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.
        """
        stft_args = (self.fft_size, self.shift_size, self.win_length, self.window)
        x_mag = stft(x, *stft_args)
        y_mag = stft(y, *stft_args)
        return (self.spectral_convergenge_loss(x_mag, y_mag),
                self.log_stft_magnitude_loss(x_mag, y_mag))
95
+
96
+
97
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""

    def __init__(self,
                 fft_sizes=(1024, 2048, 512),
                 hop_sizes=(120, 240, 50),
                 win_lengths=(600, 1200, 240),
                 window="hann_window"):
        """Initialize Multi resolution STFT loss module.
        Args:
            fft_sizes (sequence): FFT sizes, one per resolution.
            hop_sizes (sequence): Hop sizes, one per resolution.
            win_lengths (sequence): Window lengths, one per resolution.
            window (str): Window function type.
        """
        super().__init__()
        # Defaults are tuples so no mutable object is shared between calls;
        # lists are still accepted.
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList()
        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses.append(STFTLoss(fs, ss, wl, window))

    def forward(self, x, y):
        """Average the per-resolution STFT losses.
        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).
        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        """
        sc_loss = 0.0
        mag_loss = 0.0
        for f in self.stft_losses:
            sc_l, mag_l = f(x, y)
            sc_loss += sc_l
            mag_loss += mag_l
        sc_loss /= len(self.stft_losses)
        mag_loss /= len(self.stft_losses)

        return sc_loss, mag_loss
text/LICENSE ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2017 Keith Ito
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in
11
+ all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19
+ THE SOFTWARE.
text/__init__.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/keithito/tacotron """
2
+ from text import cleaners
3
+ from text.symbols import symbols
4
+
5
+
6
# Lookup tables between symbols and their integer IDs (position in `symbols`):
_id_to_symbol = dict(enumerate(symbols))
_symbol_to_id = {sym: idx for idx, sym in _id_to_symbol.items()}
9
+
10
+
11
def text_to_sequence(text, cleaner_names):
    '''Clean a string and convert it to a list of symbol IDs.
      Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through
      Returns:
        List of integers, one per character of the cleaned text that is a
        known symbol; unknown characters are skipped.
    '''
    cleaned = _clean_text(text, cleaner_names)
    return [_symbol_to_id[ch] for ch in cleaned if ch in _symbol_to_id]
28
+
29
+
30
def cleaned_text_to_sequence(cleaned_text):
    '''Convert already-cleaned text to a list of symbol IDs.
      Args:
        cleaned_text: string to convert to a sequence
      Returns:
        List of integers, one per character that is a known symbol;
        unknown characters are skipped.
    '''
    ids = []
    for ch in cleaned_text:
        if ch in _symbol_to_id:
            ids.append(_symbol_to_id[ch])
    return ids
39
+
40
+
41
def sequence_to_text(sequence):
    '''Convert a sequence of symbol IDs back to a string.'''
    return ''.join(_id_to_symbol[symbol_id] for symbol_id in sequence)
48
+
49
+
50
def _clean_text(text, cleaner_names):
    '''Run text through each named cleaner from the `cleaners` module, in order.
      Args:
        text: string to clean
        cleaner_names: names of functions in `cleaners` to apply
      Returns:
        The text after all cleaners have been applied.
      Raises:
        Exception: if a name does not refer to a cleaner in `cleaners`.
    '''
    for name in cleaner_names:
        # Without a default, getattr raises AttributeError and the
        # "Unknown cleaner" error below could never be reached.
        cleaner = getattr(cleaners, name, None)
        if not cleaner:
            raise Exception('Unknown cleaner: %s' % name)
        text = cleaner(text)
    return text
text/cleaners.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from text.japanese import japanese_to_romaji_with_accent, japanese_to_ipa, japanese_to_ipa2, japanese_to_ipa3
3
+
4
+
5
def japanese_cleaners(text):
    '''Romanize Japanese text with accent marks and terminate it with a period.

    A "." is appended only when the romanized text ends in an ASCII letter.
    '''
    romaji = japanese_to_romaji_with_accent(text)
    return re.sub(r'([A-Za-z])$', r'\1.', romaji)
9
+
10
+
11
def japanese_cleaners2(text):
    '''Like `japanese_cleaners`, but writes "ts" as "ʦ" and "..." as "…".'''
    cleaned = japanese_cleaners(text)
    cleaned = cleaned.replace('ts', 'ʦ')
    return cleaned.replace('...', '…')
text/japanese.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from unidecode import unidecode
3
+ import pyopenjtalk
4
+
5
+
6
# Body of the character class shared by the two patterns below.
_japanese_class = r'A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d'

# Matches one character of Japanese text (letters/digits included, marks excluded):
_japanese_characters = re.compile('[' + _japanese_class + ']')

# Matches one character outside that class, i.e. punctuation and other marks:
_japanese_marks = re.compile('[^' + _japanese_class + ']')

# (compiled symbol pattern, Japanese reading) pairs:
_symbols_to_japanese = [(re.compile(pattern), reading) for pattern, reading in [
    ('%', 'パーセント')
]]

# (compiled romaji pattern, ipa replacement) pairs, applied in order:
_romaji_to_ipa = [(re.compile(pattern), ipa) for pattern, ipa in [
    ('ts', 'ʦ'),
    ('u', 'ɯ'),
    ('j', 'ʥ'),
    ('y', 'j'),
    ('ni', 'n^i'),
    ('nj', 'n^'),
    ('hi', 'çi'),
    ('hj', 'ç'),
    ('f', 'ɸ'),
    ('I', 'i*'),
    ('U', 'ɯ*'),
    ('r', 'ɾ')
]]

# (compiled romaji pattern, ipa2 replacement) pairs, applied in order:
_romaji_to_ipa2 = [(re.compile(pattern), ipa) for pattern, ipa in [
    ('u', 'ɯ'),
    ('ʧ', 'tʃ'),
    ('j', 'dʑ'),
    ('y', 'j'),
    ('ni', 'n^i'),
    ('nj', 'n^'),
    ('hi', 'çi'),
    ('hj', 'ç'),
    ('f', 'ɸ'),
    ('I', 'i*'),
    ('U', 'ɯ*'),
    ('r', 'ɾ')
]]

# (compiled "Q + following consonant" pattern, sokuon replacement) pairs:
_real_sokuon = [(re.compile(pattern), repl) for pattern, repl in [
    (r'Q([↑↓]*[kg])', r'k#\1'),
    (r'Q([↑↓]*[tdjʧ])', r't#\1'),
    (r'Q([↑↓]*[sʃ])', r's\1'),
    (r'Q([↑↓]*[pb])', r'p#\1')
]]

# (compiled "N + following consonant" pattern, hatsuon replacement) pairs:
_real_hatsuon = [(re.compile(pattern), repl) for pattern, repl in [
    (r'N([↑↓]*[pbm])', r'm\1'),
    (r'N([↑↓]*[ʧʥj])', r'n^\1'),
    (r'N([↑↓]*[tdn])', r'n\1'),
    (r'N([↑↓]*[kg])', r'ŋ\1')
]]
66
+
67
+
68
def symbols_to_japanese(text):
    '''Replace every symbol listed in `_symbols_to_japanese` with its Japanese reading.'''
    for pattern, reading in _symbols_to_japanese:
        text = pattern.sub(reading, text)
    return text
72
+
73
+
74
def japanese_to_romaji_with_accent(text):
    '''Convert Japanese text to romaji annotated with pitch-accent marks.

    Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html

    The text is split on characters matching `_japanese_marks`; each segment
    that starts with a `_japanese_characters` match is converted through
    pyopenjtalk full-context labels, and the separating marks are re-inserted
    after transliteration with `unidecode`.  A space is written at an accent
    phrase boundary, '↓' where the pitch falls and '↑' where it rises.
    '''
    text = symbols_to_japanese(text)
    sentences = re.split(_japanese_marks, text)
    marks = re.findall(_japanese_marks, text)
    silences = ['sil', 'pau']
    text = ''
    for i, sentence in enumerate(sentences):
        if re.match(_japanese_characters, sentence):
            if text != '':
                text += ' '
            labels = pyopenjtalk.extract_fullcontext(sentence)
            for n, label in enumerate(labels):
                phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
                if phoneme in silences:
                    continue
                text += phoneme.replace('ch', 'ʧ').replace('sh', 'ʃ').replace('cl', 'Q')

                # Accent fields of the current label.
                a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
                a2 = int(re.search(r"\+(\d+)\+", label).group(1))
                a3 = int(re.search(r"\+(\d+)/", label).group(1))

                # Same a2 field of the following label, or -1 when that
                # label is a silence/pause.
                following = labels[n + 1]
                if re.search(r'\-([^\+]*)\+', following).group(1) in silences:
                    a2_next = -1
                else:
                    a2_next = int(re.search(r"\+(\d+)\+", following).group(1))

                if a3 == 1 and a2_next == 1:
                    # Accent phrase boundary
                    text += ' '
                elif a1 == 0 and a2_next == a2 + 1:
                    # Falling
                    text += '↓'
                elif a2 == 1 and a2_next == 2:
                    # Rising
                    text += '↑'
        if i < len(marks):
            text += unidecode(marks[i]).replace(' ', '')
    return text
113
+
114
+
115
def get_real_sokuon(text):
    '''Apply every `_real_sokuon` substitution to the text, in table order.'''
    for pattern, repl in _real_sokuon:
        text = pattern.sub(repl, text)
    return text
119
+
120
+
121
def get_real_hatsuon(text):
    '''Apply every `_real_hatsuon` substitution to the text, in table order.'''
    for pattern, repl in _real_hatsuon:
        text = pattern.sub(repl, text)
    return text
125
+
126
+
127
def japanese_to_ipa(text):
    '''Convert Japanese text to the IPA-like notation of `_romaji_to_ipa`.

    Runs of a repeated vowel are collapsed to the vowel followed by one 'ː'
    per repetition before the sokuon, hatsuon and romaji tables are applied.
    '''
    def _lengthen(match):
        run = match.group(0)
        return run[0] + 'ː' * (len(run) - 1)

    text = japanese_to_romaji_with_accent(text).replace('...', '…')
    text = re.sub(r'([aiueo])\1+', _lengthen, text)
    text = get_real_hatsuon(get_real_sokuon(text))
    for pattern, ipa in _romaji_to_ipa:
        text = pattern.sub(ipa, text)
    return text
136
+
137
+
138
def japanese_to_ipa2(text):
    '''Convert Japanese text to the IPA-like notation of `_romaji_to_ipa2`.'''
    romaji = japanese_to_romaji_with_accent(text).replace('...', '…')
    converted = get_real_hatsuon(get_real_sokuon(romaji))
    for pattern, ipa in _romaji_to_ipa2:
        converted = pattern.sub(ipa, converted)
    return converted
145
+
146
+
147
def japanese_to_ipa3(text):
    '''Post-process `japanese_to_ipa2` output into a third notation.

    Rewrites 'n^', 'ʃ', '*' and '#', collapses repeated vowels into the vowel
    plus 'ː' marks, and appends 'ʰ' to ts/tɕ/k/p/t at the start of the text
    or after whitespace.
    '''
    def _lengthen(match):
        run = match.group(0)
        return run[0] + 'ː' * (len(run) - 1)

    text = japanese_to_ipa2(text)
    for old, new in (('n^', 'ȵ'), ('ʃ', 'ɕ'), ('*', '\u0325'), ('#', '\u031a')):
        text = text.replace(old, new)
    text = re.sub(r'([aiɯeo])\1+', _lengthen, text)
    text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
    return text
text/symbols.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Defines the set of symbols used in text input to the model.
3
+ '''
4
+
5
+ '''# japanese_cleaners
6
+ _pad = '_'
7
+ _punctuation = ',.!?-'
8
+ _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
9
+ '''
10
+
11
+ # japanese_cleaners2
12
+ _pad = '_'
13
+ _punctuation = ',.!?-~…'
14
+ _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
15
+
16
+
17
+ '''
18
+ # korean_cleaners
19
+ _pad = '_'
20
+ _punctuation = ',.!?…~'
21
+ _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
22
+ '''
23
+
24
+ '''# chinese_cleaners
25
+ _pad = '_'
26
+ _punctuation = ',。!?—…'
27
+ _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
28
+ '''
29
+
30
+ '''# zh_ja_mixture_cleaners
31
+ _pad = '_'
32
+ _punctuation = ',.!?-~…'
33
+ _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
34
+ '''
35
+
36
+ '''# sanskrit_cleaners
37
+ _pad = '_'
38
+ _punctuation = '।'
39
+ _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
40
+ '''
41
+
42
+ '''# cjks_cleaners
43
+ _pad = '_'
44
+ _punctuation = ',.!?-~…'
45
+ _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
46
+ '''
47
+
48
+ '''# thai_cleaners
49
+ _pad = '_'
50
+ _punctuation = '.!? '
51
+ _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
52
+ '''
53
+
54
+ '''# cjke_cleaners2
55
+ _pad = '_'
56
+ _punctuation = ',.!?-~…'
57
+ _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
58
+ '''
59
+
60
+ '''# shanghainese_cleaners
61
+ _pad = '_'
62
+ _punctuation = ',.!?…'
63
+ _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
64
+ '''
65
+
66
+ '''# chinese_dialect_cleaners
67
+ _pad = '_'
68
+ _punctuation = ',.!?~…─'
69
+ _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚αᴀᴇ↑↓∅ⱼ '
70
+ '''
71
+
72
# Full symbol inventory: padding first, then punctuation, then letters.
symbols = [_pad, *_punctuation, *_letters]

# ID of the space symbol within `symbols`.
SPACE_ID = symbols.index(" ")
transforms.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
def piecewise_rational_quadratic_transform(inputs,
                                           unnormalized_widths,
                                           unnormalized_heights,
                                           unnormalized_derivatives,
                                           inverse=False,
                                           tails=None,
                                           tail_bound=1.,
                                           min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                                           min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                                           min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Dispatch to the constrained or unconstrained rational-quadratic spline.

    With ``tails=None`` the call goes to ``rational_quadratic_spline``;
    otherwise it goes to ``unconstrained_rational_quadratic_spline`` and
    ``tails`` / ``tail_bound`` are forwarded as well.

    Returns:
        The ``(outputs, logabsdet)`` pair produced by the selected spline.
    """
    if tails is None:
        spline_fn = rational_quadratic_spline
        tail_kwargs = {}
    else:
        spline_fn = unconstrained_rational_quadratic_spline
        tail_kwargs = {'tails': tails, 'tail_bound': tail_bound}

    outputs, logabsdet = spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **tail_kwargs
    )
    return outputs, logabsdet
45
+
46
+
47
def searchsorted(bin_locations, inputs, eps=1e-6):
    """Return, for each input, the index of the bin it falls into.

    Args:
        bin_locations: tensor whose last axis holds the bin edges.
        inputs: tensor of values to locate; a trailing axis is added so it
            is compared against every edge.
        eps: amount added to the last edge so that an input equal to it is
            still counted in the final bin.

    Returns:
        Tensor of indices: (number of edges <= input) - 1.
    """
    # Build the nudged edges in a new tensor rather than adding eps in place,
    # so the caller's ``bin_locations`` is left unchanged.
    last_edge = bin_locations[..., -1:] + eps
    edges = torch.cat([bin_locations[..., :-1], last_edge], dim=-1)
    return torch.sum(
        inputs[..., None] >= edges,
        dim=-1
    ) - 1
53
+
54
+
55
def unconstrained_rational_quadratic_spline(inputs,
                                            unnormalized_widths,
                                            unnormalized_heights,
                                            unnormalized_derivatives,
                                            inverse=False,
                                            tails='linear',
                                            tail_bound=1.,
                                            min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                                            min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                                            min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Rational-quadratic spline on [-tail_bound, tail_bound], identity outside.

    Inputs outside the interval are passed through unchanged with a zero
    log-determinant; inputs inside are transformed by
    ``rational_quadratic_spline`` on the square
    [-tail_bound, tail_bound] x [-tail_bound, tail_bound].

    Returns:
        ``(outputs, logabsdet)``, both shaped like ``inputs``.

    Raises:
        RuntimeError: if ``tails`` is anything other than ``'linear'``.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == 'linear':
        # Pad one boundary derivative on each side and set it to the value
        # whose softplus plus min_derivative equals 1.
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError('{} tails are not implemented.'.format(tails))

    # With no element inside the interval the inner spline would be called on
    # empty tensors and its torch.min/torch.max domain check would raise, so
    # only call it when there is something to transform.
    if inside_interval_mask.any():
        outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
            inputs=inputs[inside_interval_mask],
            unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
            unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
            unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
            inverse=inverse,
            left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
            min_derivative=min_derivative
        )

    return outputs, logabsdet
95
+
96
def rational_quadratic_spline(inputs,
                              unnormalized_widths,
                              unnormalized_heights,
                              unnormalized_derivatives,
                              inverse=False,
                              left=0., right=1., bottom=0., top=1.,
                              min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                              min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                              min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Evaluate a monotonic rational-quadratic spline, or its inverse.

    Bin widths and heights are softmax-normalised (with a lower bound each)
    and accumulated into knot positions on [left, right] and [bottom, top];
    knot derivatives are ``min_derivative + softplus(...)``.

    Returns:
        ``(outputs, logabsdet)``; ``logabsdet`` is negated when ``inverse``.

    Raises:
        ValueError: if any input lies outside [left, right], or if the
            minimal bin width/height is too large for the number of bins.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError('Input to a transform is not within its domain')

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError('Minimal bin width too large for the number of bins')
    if min_bin_height * num_bins > 1.0:
        raise ValueError('Minimal bin height too large for the number of bins')

    # Knot x-positions: normalised widths -> cumulative sum -> [left, right].
    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    # Knot y-positions, built the same way on [bottom, top].
    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    # Locate the bin of every input: along y when inverting, along x otherwise.
    knots = cumheights if inverse else cumwidths
    bin_idx = searchsorted(knots, inputs)[..., None]

    def _select(per_bin):
        # Pick, for every input, the entry belonging to its bin.
        return per_bin.gather(-1, bin_idx)[..., 0]

    input_cumwidths = _select(cumwidths)
    input_bin_widths = _select(widths)
    input_cumheights = _select(cumheights)
    input_delta = _select(heights / widths)
    input_derivatives = _select(derivatives)
    input_derivatives_plus_one = _select(derivatives[..., 1:])
    input_heights = _select(heights)

    # Term shared by the forward, inverse and derivative formulas.
    curvature = input_derivatives + input_derivatives_plus_one - 2 * input_delta

    if inverse:
        # Solve the quadratic a*theta^2 + b*theta + c = 0 for theta.
        offset = inputs - input_cumheights
        a = (offset * curvature
             + input_heights * (input_delta - input_derivatives))
        b = (input_heights * input_derivatives
             - offset * curvature)
        c = - input_delta * offset

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        theta = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = theta * input_bin_widths + input_cumwidths

        theta_one_minus_theta = theta * (1 - theta)
        denominator = input_delta + (curvature * theta_one_minus_theta)
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2)
                                     + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + (curvature * theta_one_minus_theta)
        outputs = input_cumheights + numerator / denominator

    derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
                                                 + 2 * input_delta * theta_one_minus_theta
                                                 + input_derivatives * (1 - theta).pow(2))
    logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

    if inverse:
        return outputs, -logabsdet
    return outputs, logabsdet
utils.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import sys
4
+ import argparse
5
+ import logging
6
+ import json
7
+ import subprocess
8
+ import numpy as np
9
+ from scipy.io.wavfile import read
10
+ import torch
11
+
12
+ MATPLOTLIB_FLAG = False
13
+
14
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
15
+ logger = logging
16
+
17
+
18
def load_checkpoint(checkpoint_path, model, optimizer=None):
    """Load model (and optionally optimizer) state from a checkpoint file.

    Parameters missing from the checkpoint keep the model's current value and
    are reported through the logger.

    Args:
        checkpoint_path: path to a file written by ``save_checkpoint``.
        model: module to load into; if it has a ``module`` attribute the
            state is loaded into ``model.module`` instead.
        optimizer: optional optimizer whose state is restored as well.

    Returns:
        ``(model, optimizer, learning_rate, iteration)``.
    """
    assert os.path.isfile(checkpoint_path)
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    learning_rate = checkpoint_dict['learning_rate']
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint_dict['optimizer'])
    saved_state_dict = checkpoint_dict['model']

    target = model.module if hasattr(model, 'module') else model
    state_dict = target.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            # Only a missing key is expected here; anything else should surface.
            logger.info("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    target.load_state_dict(new_state_dict)
    logger.info("Loaded checkpoint '{}' (iteration {})" .format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
44
+
45
+
46
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Write model and optimizer state to ``checkpoint_path`` with ``torch.save``.

    If ``model`` has a ``module`` attribute, the state of ``model.module``
    is saved instead.
    """
    logger.info("Saving model and optimizer state at iteration {} to {}".format(
        iteration, checkpoint_path))
    source = model.module if hasattr(model, 'module') else model
    payload = {
        'model': source.state_dict(),
        'iteration': iteration,
        'optimizer': optimizer.state_dict(),
        'learning_rate': learning_rate,
    }
    torch.save(payload, checkpoint_path)
57
+
58
+
59
def summarize(writer, global_step, scalars=None, histograms=None, images=None, audios=None,
              audio_sampling_rate=22050):
    """Forward dictionaries of values to a summary writer.

    Args:
        writer: object providing ``add_scalar``, ``add_histogram``,
            ``add_image`` and ``add_audio``.
        global_step: step passed with every entry.
        scalars, histograms, images, audios: optional ``{tag: value}``
            mappings; ``None`` (the default) means nothing to log.
        audio_sampling_rate: sampling rate passed to ``add_audio``.
    """
    # None defaults replace the former shared mutable ``{}`` defaults.
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats='HWC')
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
68
+
69
+
70
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the matching checkpoint path with the largest number.

    Files in ``dir_path`` matching the glob pattern ``regex`` are ordered by
    the integer formed from all digits in their path; the last one is
    printed and returned.
    """
    def _digits(path):
        return int("".join(ch for ch in path if ch.isdigit()))

    candidates = sorted(glob.glob(os.path.join(dir_path, regex)), key=_digits)
    newest = candidates[-1]
    print(newest)
    return newest
76
+
77
+
78
def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram with matplotlib and return it as an RGB uint8 array.

    On first use the Agg backend is selected and matplotlib's logger is
    raised to WARNING.

    Returns:
        Array of shape (height, width, 3) holding the rendered figure.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring (binary mode) and canvas.tostring_rgb() are deprecated or
    # removed in current NumPy/Matplotlib; read the RGBA buffer and drop the
    # alpha channel instead.  The copy detaches the data from the canvas.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)
    return data
102
+
103
+
104
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix with matplotlib and return it as an RGB uint8 array.

    On first use the Agg backend is selected and matplotlib's logger is
    raised to WARNING.

    Args:
        alignment: matrix to plot; it is transposed before display.
        info: optional text appended below the x-axis label.

    Returns:
        Array of shape (height, width, 3) holding the rendered figure.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
                   interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel += '\n\n' + info
    plt.xlabel(xlabel)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring (binary mode) and canvas.tostring_rgb() are deprecated or
    # removed in current NumPy/Matplotlib; read the RGBA buffer and drop the
    # alpha channel instead.  The copy detaches the data from the canvas.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)
    return data
131
+
132
+
133
def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(samples, sampling_rate)``.

    The samples are cast to float32 (no rescaling) and wrapped in a
    ``torch.FloatTensor``.
    """
    sampling_rate, samples = read(full_path)
    as_float = samples.astype(np.float32)
    return torch.FloatTensor(as_float), sampling_rate
136
+
137
+
138
def load_filepaths_and_text(filename, split="|"):
    """Read a UTF-8 text file and split every stripped line on ``split``.

    Returns:
        A list with one list of fields per line of the file.
    """
    rows = []
    with open(filename, encoding='utf-8') as handle:
        for line in handle:
            rows.append(line.strip().split(split))
    return rows
142
+
143
+
144
def get_hparams(init=True):
    """Build hyper-parameters from the command line and a JSON config.

    Parses ``-c/--config`` and ``-m/--model``, makes sure
    ``../models/<model>`` exists, and then either copies the given config
    into that directory (``init=True``) or reads the copy stored there.

    Returns:
        An ``HParams`` built from the config, with ``model_dir`` set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
                        help='JSON file for configuration')
    parser.add_argument('-m', '--model', type=str, required=True,
                        help='Model name')
    args = parser.parse_args()

    model_dir = os.path.join("../models", args.model)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    config_save_path = os.path.join(model_dir, "config.json")
    if init:
        # Fresh run: read the requested config and keep a copy with the model.
        with open(args.config, "r") as src:
            data = src.read()
        with open(config_save_path, "w") as dst:
            dst.write(data)
    else:
        with open(config_save_path, "r") as src:
            data = src.read()

    hparams = HParams(**json.loads(data))
    hparams.model_dir = model_dir
    return hparams
172
+
173
+
174
def get_hparams_from_dir(model_dir):
    """Load ``config.json`` from ``model_dir`` into an ``HParams``.

    The returned object also gets ``model_dir`` set on it.
    """
    with open(os.path.join(model_dir, "config.json"), "r") as handle:
        raw = handle.read()

    hparams = HParams(**json.loads(raw))
    hparams.model_dir = model_dir
    return hparams
183
+
184
+
185
def get_hparams_from_file(config_path):
    """Load the JSON file at ``config_path`` into an ``HParams``."""
    with open(config_path, "r") as handle:
        raw = handle.read()
    return HParams(**json.loads(raw))
192
+
193
+
194
def check_git_hash(model_dir):
    """Compare the current git commit with the one recorded in ``model_dir``.

    If the directory containing this file has no ``.git`` entry, a warning
    is logged and nothing else happens.  Otherwise the output of
    ``git rev-parse HEAD`` is compared with ``<model_dir>/githash``; a
    mismatch is logged, and a missing file is created with the current hash.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        # ``warn`` is a deprecated alias; ``warning`` is the supported name.
        logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(
            source_dir
        ))
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        # Context managers make sure the file handles are closed.
        with open(path) as f:
            saved_hash = f.read()
        if saved_hash != cur_hash:
            logger.warning("git hash values are different. {}(saved) != {}(current)".format(
                saved_hash[:8], cur_hash[:8]))
    else:
        with open(path, "w") as f:
            f.write(cur_hash)
212
+
213
+
214
def get_logger(model_dir, filename="train.log"):
    """Create a DEBUG-level file logger for ``model_dir``.

    Rebinds the module-level ``logger`` to a logger named after the last
    path component of ``model_dir``, creates the directory if it does not
    exist, and attaches a ``FileHandler`` writing to
    ``<model_dir>/<filename>``.

    Returns:
        The configured logger.
    """
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    file_handler = logging.FileHandler(os.path.join(model_dir, filename))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s"))
    logger.addHandler(file_handler)
    return logger
227
+
228
+
229
class HParams():
    """Attribute-style container for (possibly nested) hyper-parameters.

    Every keyword argument becomes an attribute; dictionary values are
    converted recursively into ``HParams``.  Entries can also be read and
    written with ``obj[key]``, and the object supports ``len``, ``in``,
    ``keys``, ``items`` and ``values`` over its attributes.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance also wraps dict subclasses (e.g. OrderedDict), which
            # the former exact ``type(v) == dict`` comparison skipped.
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()