chenjgtea commited on
Commit
f097912
·
1 Parent(s): 52b0147

gpu模式下新增中英文的normal花处理

Browse files
Chat2TTS/core.py CHANGED
@@ -13,6 +13,7 @@ from .infer.api import refine_text, infer_code
13
  from dataclasses import dataclass
14
  from typing import Literal, Optional, List, Tuple, Dict
15
  from tool.logger import get_logger
 
16
 
17
  from ChatTTS.norm import Normalizer
18
 
@@ -60,10 +61,36 @@ class Chat:
60
  else:
61
  self.logger.log(logging.INFO, f'Load from cache: {download_path}')
62
  self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
 
63
  elif source == 'local':
64
  self.logger.log(logging.INFO, f'Load from local: {local_path}')
65
  self._load(**{k: os.path.join(local_path, v) for k, v in OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()})
66
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def _load(
68
  self,
69
  vocos_config_path: str = None,
@@ -156,7 +183,8 @@ class Chat:
156
  refine_text_only=False,
157
  params_refine_text={},
158
  params_infer_code={},
159
- use_decoder=False
 
160
  ):
161
 
162
  assert self.check_model(use_decoder=use_decoder)
@@ -169,7 +197,7 @@ class Chat:
169
  text=t,
170
  do_text_normalization=True,
171
  do_homophone_replacement=True,
172
- lang=None,
173
  )
174
  for t in text
175
  ]
 
13
  from dataclasses import dataclass
14
  from typing import Literal, Optional, List, Tuple, Dict
15
  from tool.logger import get_logger
16
+ from tool.normalizer import normalizer_en_nemo_text,normalizer_cn_tn
17
 
18
  from ChatTTS.norm import Normalizer
19
 
 
61
  else:
62
  self.logger.log(logging.INFO, f'Load from cache: {download_path}')
63
  self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
64
+ self._regist_normalizer()
65
  elif source == 'local':
66
  self.logger.log(logging.INFO, f'Load from local: {local_path}')
67
  self._load(**{k: os.path.join(local_path, v) for k, v in OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()})
68
+
69
+ def _regist_normalizer(self):
70
+
71
+ self.logger.info("==========开始注册 normalizer===========")
72
+
73
+ try:
74
+ self.normalizer.register("en",normalizer_en_nemo_text())
75
+ except ValueError as e:
76
+ self.logger.error('normalizer_en_nemo_text register fail' , e)
77
+ except:
78
+ self.logger.error("Package nemo_text_processing not found!")
79
+ self.logger.error(
80
+ "Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing",
81
+ )
82
+
83
+ try:
84
+ self.normalizer.register("zh",normalizer_cn_tn())
85
+ except ValueError as e:
86
+ self.logger.error('normalizer_cn_tn register fail' , e)
87
+ except:
88
+ self.logger.error("Package WeTextProcessing not found!")
89
+ self.logger.error(
90
+ "Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing",
91
+ )
92
+
93
+
94
  def _load(
95
  self,
96
  vocos_config_path: str = None,
 
183
  refine_text_only=False,
184
  params_refine_text={},
185
  params_infer_code={},
186
+ use_decoder=False,
187
+ lang=None
188
  ):
189
 
190
  assert self.check_model(use_decoder=use_decoder)
 
197
  text=t,
198
  do_text_normalization=True,
199
  do_homophone_replacement=True,
200
+ lang=lang,
201
  )
202
  for t in text
203
  ]
requirements.txt CHANGED
@@ -20,6 +20,10 @@ vector_quantize_pytorch
20
  # Hugging Face Hub client
21
  huggingface_hub
22
 
 
 
 
 
23
  vocos
24
 
25
  spaces
 
20
  # Hugging Face Hub client
21
  huggingface_hub
22
 
23
+ pynini==2.1.5; sys_platform == 'linux'
24
+ WeTextProcessing; sys_platform == 'linux'
25
+ nemo_text_processing; sys_platform == 'linux'
26
+
27
  vocos
28
 
29
  spaces
tool/normalizer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .normali_util import normalizer_en_nemo_text,normalizer_cn_tn
tool/normalizer/normali_util.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+ from functools import partial
3
+
4
+
5
+ def normalizer_en_nemo_text() -> Callable[[str], str]:
6
+ from nemo_text_processing.text_normalization.normalize import Normalizer
7
+
8
+ return partial(
9
+ Normalizer(input_case="cased", lang="en").normalize,
10
+ verbose=False,
11
+ punct_post_process=True,
12
+ )
13
+
14
+ def normalizer_cn_tn() -> Callable[[str], str]:
15
+ from tn.chinese.normalizer import Normalizer
16
+
17
+ return Normalizer().normalize
web/app_gpu.py CHANGED
@@ -96,6 +96,13 @@ def main(args):
96
  interactive=True,
97
  )
98
  with gr.Row():
 
 
 
 
 
 
 
99
  voice_selection = gr.Dropdown(
100
  label="Timbre",
101
  choices=voices.keys(),
@@ -110,7 +117,6 @@ def main(args):
110
  minimum=seed_min,
111
  maximum=seed_max,
112
  )
113
- generate_audio_seed = gr.Button("随机生成音色种子", interactive=True)
114
  text_seed_input = gr.Number(
115
  value=42,
116
  label="文本种子",
@@ -118,7 +124,9 @@ def main(args):
118
  minimum=seed_min,
119
  maximum=seed_max,
120
  )
121
- generate_text_seed = gr.Button("随机生成文本种子", interactive=True)
 
 
122
 
123
  # with gr.Row():
124
  # spk_emb_text = gr.Textbox(
@@ -172,7 +180,8 @@ def main(args):
172
  temperature_slider,
173
  top_p_slider,
174
  top_k_slider,
175
- audio_seed_input
 
176
  ],
177
  outputs=[text_output,audio_output])
178
  # 初始化 spk_emb_text 数值
@@ -206,7 +215,8 @@ def general_chat_infer_audio(text,
206
  temperature_slider,
207
  top_p_slider,
208
  top_k_slider,
209
- audio_seed_input):
 
210
 
211
  logger.info("========开始处理TTS模型=====")
212
  #音频参数设置
@@ -229,6 +239,7 @@ def general_chat_infer_audio(text,
229
  skip_refine_text=False,
230
  refine_text_only=True, #仅返回优化后文本内容
231
  params_refine_text=params_refine_text,
 
232
  )
233
 
234
 
 
96
  interactive=True,
97
  )
98
  with gr.Row():
99
+ lang_selection = gr.Dropdown(
100
+ label="语种",
101
+ choices=["zh" , "en"],
102
+ value="zh",
103
+ interactive=True,
104
+ show_label=True
105
+ )
106
  voice_selection = gr.Dropdown(
107
  label="Timbre",
108
  choices=voices.keys(),
 
117
  minimum=seed_min,
118
  maximum=seed_max,
119
  )
 
120
  text_seed_input = gr.Number(
121
  value=42,
122
  label="文本种子",
 
124
  minimum=seed_min,
125
  maximum=seed_max,
126
  )
127
+ with gr.Column():
128
+ generate_audio_seed = gr.Button("随机生成音色种子", interactive=True)
129
+ generate_text_seed = gr.Button("随机生成文本种子", interactive=True)
130
 
131
  # with gr.Row():
132
  # spk_emb_text = gr.Textbox(
 
180
  temperature_slider,
181
  top_p_slider,
182
  top_k_slider,
183
+ audio_seed_input,
184
+ lang_selection
185
  ],
186
  outputs=[text_output,audio_output])
187
  # 初始化 spk_emb_text 数值
 
215
  temperature_slider,
216
  top_p_slider,
217
  top_k_slider,
218
+ audio_seed_input,
219
+ lang):
220
 
221
  logger.info("========开始处理TTS模型=====")
222
  #音频参数设置
 
239
  skip_refine_text=False,
240
  refine_text_only=True, #仅返回优化后文本内容
241
  params_refine_text=params_refine_text,
242
+ lang=lang
243
  )
244
 
245