Respair commited on
Commit
394f443
·
verified ·
1 Parent(s): 759eb3f

Update styletts2importable.py

Browse files
Files changed (1) hide show
  1. styletts2importable.py +660 -20
styletts2importable.py CHANGED
@@ -39,7 +39,639 @@ from utils import *
39
  from text_utils import TextCleaner
40
  textclenaer = TextCleaner()
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  to_mel = torchaudio.transforms.MelSpectrogram(
44
  n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
45
  mean, std = -4, 4
@@ -94,7 +726,7 @@ pitch_extractor = load_F0_models(F0_path)
94
 
95
  # load BERT model
96
  from Utils.PLBERT.util import load_plbert
97
- BERT_path = config.get('PLBERT_dir', False)
98
  plbert = load_plbert(BERT_path)
99
 
100
  model_params = recursive_munch(config['model_params'])
@@ -103,7 +735,7 @@ _ = [model[key].eval() for key in model]
103
  _ = [model[key].to(device) for key in model]
104
 
105
  # params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
106
- params_whole = torch.load(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu')
107
  params = params_whole['net']
108
 
109
  for key in model:
@@ -134,11 +766,14 @@ sampler = DiffusionSampler(
134
  )
135
 
136
  def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
137
- text = text.strip()
138
- ps = global_phonemizer.phonemize([text])
139
- ps = word_tokenize(ps[0])
140
- ps = ' '.join(ps)
141
- tokens = textclenaer(ps)
 
 
 
142
  tokens.insert(0, 0)
143
  tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
144
 
@@ -203,14 +838,16 @@ def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding
203
  return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
204
 
205
  def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
206
- text = text.strip()
207
- ps = global_phonemizer.phonemize([text])
208
- ps = word_tokenize(ps[0])
209
- ps = ' '.join(ps)
210
- ps = ps.replace('``', '"')
211
- ps = ps.replace("''", '"')
212
-
213
- tokens = textclenaer(ps)
 
 
214
  tokens.insert(0, 0)
215
  tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
216
 
@@ -280,12 +917,15 @@ def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion
280
  return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later
281
 
282
  def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
283
- text = text.strip()
284
- ps = global_phonemizer.phonemize([text])
285
- ps = word_tokenize(ps[0])
286
- ps = ' '.join(ps)
287
 
288
- tokens = textclenaer(ps)
 
 
 
 
 
 
 
289
  tokens.insert(0, 0)
290
  tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
291
 
 
39
  from text_utils import TextCleaner
40
  textclenaer = TextCleaner()
41
 
42
+ from cached_path import cached_path
43
+
44
+
45
+ import torch
46
+ torch.manual_seed(0)
47
+ torch.backends.cudnn.benchmark = False
48
+ torch.backends.cudnn.deterministic = True
49
+
50
+ import random
51
+ random.seed(0)
52
+
53
+ import numpy as np
54
+ np.random.seed(0)
55
+
56
+ import nltk
57
+ nltk.download('punkt')
58
+
59
+ # load packages
60
+ import time
61
+ import random
62
+ import yaml
63
+ from munch import Munch
64
+ import numpy as np
65
+ import torch
66
+ from torch import nn
67
+ import torch.nn.functional as F
68
+ import torchaudio
69
+ import librosa
70
+ from nltk.tokenize import word_tokenize
71
+
72
+ from models import *
73
+ from utils import *
74
+ from text_utils import TextCleaner
75
+ textclenaer = TextCleaner()
76
+
77
+
78
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
79
 
80
+ to_mel = torchaudio.transforms.MelSpectrogram(
81
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
82
+ mean, std = -4, 4
83
+
84
+ def length_to_mask(lengths):
85
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
86
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
87
+ return mask
88
+
89
+ def preprocess(wave):
90
+ wave_tensor = torch.from_numpy(wave).float()
91
+ mel_tensor = to_mel(wave_tensor)
92
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
93
+ return mel_tensor
94
+
95
+ def compute_style(ref_dicts):
96
+ reference_embeddings = {}
97
+ for key, path in ref_dicts.items():
98
+ wave, sr = librosa.load(path, sr=24000)
99
+ audio, index = librosa.effects.trim(wave, top_db=30)
100
+ if sr != 24000:
101
+ audio = librosa.resample(audio, sr, 24000)
102
+ mel_tensor = preprocess(audio).to(device)
103
+
104
+ with torch.no_grad():
105
+ ref = model.style_encoder(mel_tensor.unsqueeze(1))
106
+ reference_embeddings[key] = (ref.squeeze(1), audio)
107
+
108
+ return reference_embeddings
109
+
110
+ # load phonemizer
111
+ # import phonemizer
112
+ # global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, words_mismatch='ignore')
113
+
114
+ # phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
115
+ import fugashi
116
+ import pykakasi
117
+ from collections import OrderedDict
118
+
119
+
120
+ # MB-iSTFT-VITS2
121
+
122
+ import re
123
+ from unidecode import unidecode
124
+ import pyopenjtalk
125
+
126
+
127
+ # Regular expression matching Japanese without punctuation marks:
128
+ _japanese_characters = re.compile(
129
+ r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
130
+
131
+ # Regular expression matching non-Japanese characters or punctuation marks:
132
+ _japanese_marks = re.compile(
133
+ r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
134
+
135
+ # List of (symbol, Japanese) pairs for marks:
136
+ _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
137
+ ('%', 'パーセント')
138
+ ]]
139
+
140
+ # List of (romaji, ipa) pairs for marks:
141
+ _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
142
+ ('ts', 'ʦ'),
143
+ ('u', 'ɯ'),
144
+ ('j', 'ʥ'),
145
+ ('y', 'j'),
146
+ ('ni', 'n^i'),
147
+ ('nj', 'n^'),
148
+ ('hi', 'çi'),
149
+ ('hj', 'ç'),
150
+ ('f', 'ɸ'),
151
+ ('I', 'i*'),
152
+ ('U', 'ɯ*'),
153
+ ('r', 'ɾ')
154
+ ]]
155
+
156
+ # List of (romaji, ipa2) pairs for marks:
157
+ _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
158
+ ('u', 'ɯ'),
159
+ ('ʧ', 'tʃ'),
160
+ ('j', 'dʑ'),
161
+ ('y', 'j'),
162
+ ('ni', 'n^i'),
163
+ ('nj', 'n^'),
164
+ ('hi', 'çi'),
165
+ ('hj', 'ç'),
166
+ ('f', 'ɸ'),
167
+ ('I', 'i*'),
168
+ ('U', 'ɯ*'),
169
+ ('r', 'ɾ')
170
+ ]]
171
+
172
+ # List of (consonant, sokuon) pairs:
173
+ _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
174
+ (r'Q([↑↓]*[kg])', r'k#\1'),
175
+ (r'Q([↑↓]*[tdjʧ])', r't#\1'),
176
+ (r'Q([↑↓]*[sʃ])', r's\1'),
177
+ (r'Q([↑↓]*[pb])', r'p#\1')
178
+ ]]
179
+
180
+ # List of (consonant, hatsuon) pairs:
181
+ _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
182
+ (r'N([↑↓]*[pbm])', r'm\1'),
183
+ (r'N([↑↓]*[ʧʥj])', r'n^\1'),
184
+ (r'N([↑↓]*[tdn])', r'n\1'),
185
+ (r'N([↑↓]*[kg])', r'ŋ\1')
186
+ ]]
187
+
188
+
189
+ def symbols_to_japanese(text):
190
+ for regex, replacement in _symbols_to_japanese:
191
+ text = re.sub(regex, replacement, text)
192
+ return text
193
+
194
+
195
+ def japanese_to_romaji_with_accent(text):
196
+ '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
197
+ text = symbols_to_japanese(text)
198
+ sentences = re.split(_japanese_marks, text)
199
+ marks = re.findall(_japanese_marks, text)
200
+ text = ''
201
+ for i, sentence in enumerate(sentences):
202
+ if re.match(_japanese_characters, sentence):
203
+ if text != '':
204
+ text += ' '
205
+ labels = pyopenjtalk.extract_fullcontext(sentence)
206
+ for n, label in enumerate(labels):
207
+ phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
208
+ if phoneme not in ['sil', 'pau']:
209
+ text += phoneme.replace('ch', 'ʧ').replace('sh',
210
+ 'ʃ').replace('cl', 'Q')
211
+ else:
212
+ continue
213
+ # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
214
+ a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
215
+ a2 = int(re.search(r"\+(\d+)\+", label).group(1))
216
+ a3 = int(re.search(r"\+(\d+)/", label).group(1))
217
+ if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
218
+ a2_next = -1
219
+ else:
220
+ a2_next = int(
221
+ re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
222
+ # Accent phrase boundary
223
+ if a3 == 1 and a2_next == 1:
224
+ text += ' '
225
+ # Falling
226
+ elif a1 == 0 and a2_next == a2 + 1:
227
+ text += '↓'
228
+ # Rising
229
+ elif a2 == 1 and a2_next == 2:
230
+ text += '↑'
231
+ if i < len(marks):
232
+ text += unidecode(marks[i]).replace(' ', '')
233
+ return text
234
+
235
+
236
+ def get_real_sokuon(text):
237
+ for regex, replacement in _real_sokuon:
238
+ text = re.sub(regex, replacement, text)
239
+ return text
240
+
241
+
242
+ def get_real_hatsuon(text):
243
+ for regex, replacement in _real_hatsuon:
244
+ text = re.sub(regex, replacement, text)
245
+ return text
246
+
247
+
248
+ def japanese_to_ipa(text):
249
+ text = japanese_to_romaji_with_accent(text).replace('...', '…')
250
+ text = re.sub(
251
+ r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
252
+ text = get_real_sokuon(text)
253
+ text = get_real_hatsuon(text)
254
+ for regex, replacement in _romaji_to_ipa:
255
+ text = re.sub(regex, replacement, text)
256
+ return text
257
+
258
+
259
+ def japanese_to_ipa2(text):
260
+ text = japanese_to_romaji_with_accent(text).replace('...', '…')
261
+ text = get_real_sokuon(text)
262
+ text = get_real_hatsuon(text)
263
+ for regex, replacement in _romaji_to_ipa2:
264
+ text = re.sub(regex, replacement, text)
265
+ return text
266
+
267
+
268
+ def japanese_to_ipa3(text):
269
+ text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
270
+ 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
271
+ text = re.sub(
272
+ r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
273
+ text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
274
+ return text
275
+
276
+
277
+ """ from https://github.com/keithito/tacotron """
278
+
279
+ '''
280
+ Cleaners are transformations that run over the input text at both training and eval time.
281
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
282
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
283
+ 1. "english_cleaners" for English text
284
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
285
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
286
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
287
+ the symbols in symbols.py to match your data).
288
+ '''
289
+
290
+
291
+ # Regular expression matching whitespace:
292
+
293
+
294
+ import re
295
+ import inflect
296
+ from unidecode import unidecode
297
+
298
+ _inflect = inflect.engine()
299
+ _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
300
+ _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
301
+ _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
302
+ _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
303
+ _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
304
+ _number_re = re.compile(r'[0-9]+')
305
+
306
+ # List of (regular expression, replacement) pairs for abbreviations:
307
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
308
+ ('mrs', 'misess'),
309
+ ('mr', 'mister'),
310
+ ('dr', 'doctor'),
311
+ ('st', 'saint'),
312
+ ('co', 'company'),
313
+ ('jr', 'junior'),
314
+ ('maj', 'major'),
315
+ ('gen', 'general'),
316
+ ('drs', 'doctors'),
317
+ ('rev', 'reverend'),
318
+ ('lt', 'lieutenant'),
319
+ ('hon', 'honorable'),
320
+ ('sgt', 'sergeant'),
321
+ ('capt', 'captain'),
322
+ ('esq', 'esquire'),
323
+ ('ltd', 'limited'),
324
+ ('col', 'colonel'),
325
+ ('ft', 'fort'),
326
+ ]]
327
+
328
+
329
+ # List of (ipa, lazy ipa) pairs:
330
+ _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
331
+ ('r', 'ɹ'),
332
+ ('æ', 'e'),
333
+ ('ɑ', 'a'),
334
+ ('ɔ', 'o'),
335
+ ('ð', 'z'),
336
+ ('θ', 's'),
337
+ ('ɛ', 'e'),
338
+ ('ɪ', 'i'),
339
+ ('ʊ', 'u'),
340
+ ('ʒ', 'ʥ'),
341
+ ('ʤ', 'ʥ'),
342
+ ('', '↓'),
343
+ ]]
344
+
345
+ # List of (ipa, lazy ipa2) pairs:
346
+ _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
347
+ ('r', 'ɹ'),
348
+ ('ð', 'z'),
349
+ ('θ', 's'),
350
+ ('ʒ', 'ʑ'),
351
+ ('ʤ', 'dʑ'),
352
+ ('', '↓'),
353
+ ]]
354
+
355
+ # List of (ipa, ipa2) pairs
356
+ _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
357
+ ('r', 'ɹ'),
358
+ ('ʤ', 'dʒ'),
359
+ ('ʧ', 'tʃ')
360
+ ]]
361
+
362
+
363
+ def expand_abbreviations(text):
364
+ for regex, replacement in _abbreviations:
365
+ text = re.sub(regex, replacement, text)
366
+ return text
367
+
368
+
369
+ def collapse_whitespace(text):
370
+ return re.sub(r'\s+', ' ', text)
371
+
372
+
373
+ def _remove_commas(m):
374
+ return m.group(1).replace(',', '')
375
+
376
+
377
+ def _expand_decimal_point(m):
378
+ return m.group(1).replace('.', ' point ')
379
+
380
+
381
+ def _expand_dollars(m):
382
+ match = m.group(1)
383
+ parts = match.split('.')
384
+ if len(parts) > 2:
385
+ return match + ' dollars' # Unexpected format
386
+ dollars = int(parts[0]) if parts[0] else 0
387
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
388
+ if dollars and cents:
389
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
390
+ cent_unit = 'cent' if cents == 1 else 'cents'
391
+ return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
392
+ elif dollars:
393
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
394
+ return '%s %s' % (dollars, dollar_unit)
395
+ elif cents:
396
+ cent_unit = 'cent' if cents == 1 else 'cents'
397
+ return '%s %s' % (cents, cent_unit)
398
+ else:
399
+ return 'zero dollars'
400
+
401
+
402
+ def _expand_ordinal(m):
403
+ return _inflect.number_to_words(m.group(0))
404
+
405
+
406
+ def _expand_number(m):
407
+ num = int(m.group(0))
408
+ if num > 1000 and num < 3000:
409
+ if num == 2000:
410
+ return 'two thousand'
411
+ elif num > 2000 and num < 2010:
412
+ return 'two thousand ' + _inflect.number_to_words(num % 100)
413
+ elif num % 100 == 0:
414
+ return _inflect.number_to_words(num // 100) + ' hundred'
415
+ else:
416
+ return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
417
+ else:
418
+ return _inflect.number_to_words(num, andword='')
419
+
420
+
421
+ def normalize_numbers(text):
422
+ text = re.sub(_comma_number_re, _remove_commas, text)
423
+ text = re.sub(_pounds_re, r'\1 pounds', text)
424
+ text = re.sub(_dollars_re, _expand_dollars, text)
425
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
426
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
427
+ text = re.sub(_number_re, _expand_number, text)
428
+ return text
429
+
430
+
431
+ def mark_dark_l(text):
432
+ return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
433
+
434
+
435
+ import re
436
+ #from text.thai import num_to_thai, latin_to_thai
437
+ #from text.shanghainese import shanghainese_to_ipa
438
+ #from text.cantonese import cantonese_to_ipa
439
+ #from text.ngu_dialect import ngu_dialect_to_ipa
440
+ from unidecode import unidecode
441
+
442
+
443
+ _whitespace_re = re.compile(r'\s+')
444
+
445
+ # Regular expression matching Japanese without punctuation marks:
446
+ _japanese_characters = re.compile(r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
447
+
448
+ # Regular expression matching non-Japanese characters or punctuation marks:
449
+ _japanese_marks = re.compile(r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
450
+
451
+ # List of (regular expression, replacement) pairs for abbreviations:
452
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
453
+ ('mrs', 'misess'),
454
+ ('mr', 'mister'),
455
+ ('dr', 'doctor'),
456
+ ('st', 'saint'),
457
+ ('co', 'company'),
458
+ ('jr', 'junior'),
459
+ ('maj', 'major'),
460
+ ('gen', 'general'),
461
+ ('drs', 'doctors'),
462
+ ('rev', 'reverend'),
463
+ ('lt', 'lieutenant'),
464
+ ('hon', 'honorable'),
465
+ ('sgt', 'sergeant'),
466
+ ('capt', 'captain'),
467
+ ('esq', 'esquire'),
468
+ ('ltd', 'limited'),
469
+ ('col', 'colonel'),
470
+ ('ft', 'fort'),
471
+ ]]
472
+
473
+
474
+ def expand_abbreviations(text):
475
+ for regex, replacement in _abbreviations:
476
+ text = re.sub(regex, replacement, text)
477
+ return text
478
+
479
+ def collapse_whitespace(text):
480
+ return re.sub(_whitespace_re, ' ', text)
481
+
482
+
483
+ def convert_to_ascii(text):
484
+ return unidecode(text)
485
+
486
+
487
+ def basic_cleaners(text):
488
+ # - For replication of https://github.com/FENRlR/MB-iSTFT-VITS2/issues/2
489
+ # you may need to replace the symbol to Russian one
490
+ '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
491
+ text = text.lower()
492
+ text = collapse_whitespace(text)
493
+ return text
494
+
495
+ '''
496
+ def fix_g2pk2_error(text):
497
+ new_text = ""
498
+ i = 0
499
+ while i < len(text) - 4:
500
+ if (text[i:i+3] == 'ㅇㅡㄹ' or text[i:i+3] == 'ㄹㅡㄹ') and text[i+3] == ' ' and text[i+4] == 'ㄹ':
501
+ new_text += text[i:i+3] + ' ' + 'ㄴ'
502
+ i += 5
503
+ else:
504
+ new_text += text[i]
505
+ i += 1
506
+ new_text += text[i:]
507
+ return new_text
508
+ '''
509
+
510
+
511
+
512
+ def japanese_cleaners(text):
513
+ text = japanese_to_romaji_with_accent(text)
514
+ text = re.sub(r'([A-Za-z])$', r'\1.', text)
515
+ return text
516
+
517
+
518
+ def japanese_cleaners2(text):
519
+ return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…')
520
+
521
+ def japanese_cleaners3(text):
522
+ text = japanese_to_ipa3(text)
523
+ if "<<" in text or ">>" in text or "¡" in text or "¿" in text:
524
+ text = text.replace("<<","«")
525
+ text = text.replace(">>","»")
526
+ text = text.replace("!","¡")
527
+ text = text.replace("?","¿")
528
+
529
+ if'"'in text:
530
+ text = text.replace('"','”')
531
+
532
+ if'--'in text:
533
+ text = text.replace('--','—')
534
+ if ' ' in text:
535
+ text = text.replace(' ','')
536
+ return text
537
+
538
+
539
+
540
+ # ------------------------------
541
+ ''' cjke type cleaners below '''
542
+ #- text for these cleaners must be labeled first
543
+ # ex1 (single) : some.wav|[EN]put some text here[EN]
544
+ # ex2 (multi) : some.wav|0|[EN]put some text here[EN]
545
+ # ------------------------------
546
+
547
+
548
+ def kej_cleaners(text):
549
+ text = re.sub(r'\[KO\](.*?)\[KO\]',
550
+ lambda x: korean_to_ipa(x.group(1))+' ', text)
551
+ text = re.sub(r'\[EN\](.*?)\[EN\]',
552
+ lambda x: english_to_ipa2(x.group(1)) + ' ', text)
553
+ text = re.sub(r'\[JA\](.*?)\[JA\]',
554
+ lambda x: japanese_to_ipa2(x.group(1)) + ' ', text)
555
+ text = re.sub(r'\s+$', '', text)
556
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
557
+ return text
558
+
559
+
560
+ def cjks_cleaners(text):
561
+ text = re.sub(r'\[JA\](.*?)\[JA\]',
562
+ lambda x: japanese_to_ipa(x.group(1))+' ', text)
563
+ #text = re.sub(r'\[SA\](.*?)\[SA\]',
564
+ # lambda x: devanagari_to_ipa(x.group(1))+' ', text)
565
+ text = re.sub(r'\[EN\](.*?)\[EN\]',
566
+ lambda x: english_to_lazy_ipa(x.group(1))+' ', text)
567
+ text = re.sub(r'\s+$', '', text)
568
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
569
+ return text
570
+
571
+ '''
572
+ #- reserves
573
+ def thai_cleaners(text):
574
+ text = num_to_thai(text)
575
+ text = latin_to_thai(text)
576
+ return text
577
+ def shanghainese_cleaners(text):
578
+ text = shanghainese_to_ipa(text)
579
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
580
+ return text
581
+ def chinese_dialect_cleaners(text):
582
+ text = re.sub(r'\[ZH\](.*?)\[ZH\]',
583
+ lambda x: chinese_to_ipa2(x.group(1))+' ', text)
584
+ text = re.sub(r'\[JA\](.*?)\[JA\]',
585
+ lambda x: japanese_to_ipa3(x.group(1)).replace('Q', 'ʔ')+' ', text)
586
+ text = re.sub(r'\[SH\](.*?)\[SH\]', lambda x: shanghainese_to_ipa(x.group(1)).replace('1', '˥˧').replace('5',
587
+ '˧˧˦').replace('6', '˩˩˧').replace('7', '˥').replace('8', '˩˨').replace('ᴀ', 'ɐ').replace('ᴇ', 'e')+' ', text)
588
+ text = re.sub(r'\[GD\](.*?)\[GD\]',
589
+ lambda x: cantonese_to_ipa(x.group(1))+' ', text)
590
+ text = re.sub(r'\[EN\](.*?)\[EN\]',
591
+ lambda x: english_to_lazy_ipa2(x.group(1))+' ', text)
592
+ text = re.sub(r'\[([A-Z]{2})\](.*?)\[\1\]', lambda x: ngu_dialect_to_ipa(x.group(2), x.group(
593
+ 1)).replace('ʣ', 'dz').replace('ʥ', 'dʑ').replace('ʦ', 'ts').replace('ʨ', 'tɕ')+' ', text)
594
+ text = re.sub(r'\s+$', '', text)
595
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
596
+ return text
597
+ '''
598
+ def japanese_cleaners3(text):
599
+
600
+ global orig
601
+
602
+ orig = text # saving the original unmodifed text for future use
603
+
604
+ text = japanese_to_ipa2(text)
605
+
606
+ if '' in text:
607
+ text = text.replace('','')
608
+ if "<<" in text or ">>" in text or "¡" in text or "¿" in text:
609
+ text = text.replace("<<","«")
610
+ text = text.replace(">>","»")
611
+ text = text.replace("!","¡")
612
+ text = text.replace("?","¿")
613
+
614
+ if'"'in text:
615
+ text = text.replace('"','”')
616
+
617
+ if'--'in text:
618
+ text = text.replace('--','—')
619
+
620
+ text = text.replace("#","ʔ")
621
+ text = text.replace("^","")
622
+
623
+ text = text.replace("kj","kʲ")
624
+ text = text.replace("kj","kʲ")
625
+ text = text.replace("ɾj","ɾʲ")
626
+
627
+ text = text.replace("mj","mʲ")
628
+ text = text.replace("ʃ","ɕ")
629
+ text = text.replace("*","")
630
+ text = text.replace("bj","bʲ")
631
+ text = text.replace("h","ç")
632
+ text = text.replace("gj","gʲ")
633
+
634
+
635
+ return text
636
+
637
+ def japanese_cleaners4(text):
638
+
639
+ text = japanese_cleaners3(text)
640
+
641
+ if "にゃ" in orig:
642
+ text = text.replace("na","nʲa")
643
+
644
+ elif "にゅ" in orig:
645
+ text = text.replace("n","nʲ")
646
+
647
+ elif "にょ" in orig:
648
+ text = text.replace("n","nʲ")
649
+ elif "にぃ" in orig:
650
+ text = text.replace("ni i","niː")
651
+
652
+ elif "いゃ" in orig:
653
+ text = text.replace("i↑ja","ja")
654
+
655
+ elif "いゃ" in orig:
656
+ text = text.replace("i↑ja","ja")
657
+
658
+ elif "ひょ" in orig:
659
+ text = text.replace("ço","çʲo")
660
+
661
+ elif "しょ" in orig:
662
+ text = text.replace("ɕo","ɕʲo")
663
+
664
+
665
+ text = text.replace("Q","ʔ")
666
+ text = text.replace("N","ɴ")
667
+
668
+ text = re.sub(r'.ʔ', 'ʔ', text)
669
+ text = text.replace('" ', '"')
670
+ text = text.replace('” ', '”')
671
+
672
+ return text
673
+
674
+
675
  to_mel = torchaudio.transforms.MelSpectrogram(
676
  n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
677
  mean, std = -4, 4
 
726
 
727
  # load BERT model
728
  from Utils.PLBERT.util import load_plbert
729
+ BERT_path = "Utils/PLBERT/step_1040000.t7"
730
  plbert = load_plbert(BERT_path)
731
 
732
  model_params = recursive_munch(config['model_params'])
 
735
  _ = [model[key].to(device) for key in model]
736
 
737
  # params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
738
+ params_whole = torch.load("Models/Kaede.pth", map_location='cpu')
739
  params = params_whole['net']
740
 
741
  for key in model:
 
766
  )
767
 
768
  def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
769
+ # text = text.strip()
770
+ # ps = global_phonemizer.phonemize([text])
771
+ # ps = word_tokenize(ps[0])
772
+ # ps = ' '.join(ps)
773
+
774
+ text = japanese_cleaners4(text)
775
+ print(text)
776
+ tokens = textclenaer(text)
777
  tokens.insert(0, 0)
778
  tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
779
 
 
838
  return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
839
 
840
  def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
841
+ # text = text.strip()
842
+ # ps = global_phonemizer.phonemize([text])
843
+ # ps = word_tokenize(ps[0])
844
+ # ps = ' '.join(ps)
845
+ # ps = ps.replace('``', '"')
846
+ # ps = ps.replace("''", '"')
847
+
848
+ text = japanese_cleaners4(text)
849
+ print(text)
850
+ tokens = textclenaer(text)
851
  tokens.insert(0, 0)
852
  tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
853
 
 
917
  return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later
918
 
919
  def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
 
 
 
 
920
 
921
+ print("don't use")
922
+
923
+ # text = text.strip()
924
+ # ps = global_phonemizer.phonemize([text])
925
+ # ps = word_tokenize(ps[0])
926
+ # ps = ' '.join(ps)
927
+ text = japanese_cleaners4(text)
928
+ tokens = textclenaer(text)
929
  tokens.insert(0, 0)
930
  tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
931