chenjgtea
commited on
Commit
·
a536c15
1
Parent(s):
0699795
拆分gpu、cpu模式运行模式
Browse files- Chat2TTS/core.py +78 -61
- test/audio_test.py +48 -0
- test/common_test.py +1 -1
- tool/__init__.py +1 -2
- tool/func.py +29 -2
- tool/np.py +19 -2
- tool/pcm.py +0 -21
- web/app_cpu.py +1 -1
- web/app_gpu.py +31 -20
Chat2TTS/core.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
import os
|
3 |
import logging
|
4 |
from omegaconf import OmegaConf
|
@@ -11,9 +10,11 @@ from .utils.gpu_utils import select_device
|
|
11 |
from .utils.io_utils import get_latest_modified_file
|
12 |
from .infer.api import refine_text, infer_code
|
13 |
from dataclasses import dataclass
|
14 |
-
from typing import Literal, Optional, List, Tuple, Dict
|
|
|
15 |
from tool.logger import get_logger
|
16 |
-
from tool.normalizer import normalizer_en_nemo_text,normalizer_cn_tn
|
|
|
17 |
|
18 |
from ChatTTS.norm import Normalizer
|
19 |
|
@@ -23,31 +24,31 @@ from huggingface_hub import snapshot_download
|
|
23 |
class Chat:
|
24 |
def __init__(self, ):
|
25 |
self.pretrain_models = {}
|
26 |
-
self.logger = get_logger(__name__,lv=logging.INFO)
|
27 |
self.normalizer = Normalizer(
|
28 |
os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"),
|
29 |
self.logger,
|
30 |
)
|
31 |
-
|
32 |
-
def check_model(self, level
|
33 |
not_finish = False
|
34 |
check_list = ['vocos', 'gpt', 'tokenizer']
|
35 |
-
|
36 |
if use_decoder:
|
37 |
check_list.append('decoder')
|
38 |
else:
|
39 |
check_list.append('dvae')
|
40 |
-
|
41 |
for module in check_list:
|
42 |
if module not in self.pretrain_models:
|
43 |
self.logger.log(logging.WARNING, f'{module} not initialized.')
|
44 |
not_finish = True
|
45 |
-
|
46 |
if not not_finish:
|
47 |
self.logger.log(level, f'All initialized.')
|
48 |
-
|
49 |
return not not_finish
|
50 |
-
|
51 |
def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>'):
|
52 |
if source == 'huggingface':
|
53 |
hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
|
@@ -55,25 +56,27 @@ class Chat:
|
|
55 |
download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
|
56 |
except:
|
57 |
download_path = None
|
58 |
-
if download_path is None or force_redownload:
|
59 |
self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
|
60 |
download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
|
61 |
else:
|
62 |
self.logger.log(logging.INFO, f'Load from cache: {download_path}')
|
63 |
-
self._load(**{k: os.path.join(download_path, v) for k, v in
|
|
|
64 |
self._regist_normalizer()
|
65 |
elif source == 'local':
|
66 |
self.logger.log(logging.INFO, f'Load from local: {local_path}')
|
67 |
-
self._load(**{k: os.path.join(local_path, v) for k, v in
|
|
|
68 |
|
69 |
def _regist_normalizer(self):
|
70 |
|
71 |
self.logger.info("==========开始注册 normalizer===========")
|
72 |
|
73 |
try:
|
74 |
-
|
75 |
except ValueError as e:
|
76 |
-
self.logger.error('normalizer_en_nemo_text register fail'
|
77 |
except:
|
78 |
self.logger.error("Package nemo_text_processing not found!")
|
79 |
self.logger.error(
|
@@ -81,40 +84,40 @@ class Chat:
|
|
81 |
)
|
82 |
|
83 |
try:
|
84 |
-
self.normalizer.register("zh",normalizer_cn_tn())
|
85 |
except ValueError as e:
|
86 |
-
self.logger.error('normalizer_cn_tn register fail'
|
87 |
except:
|
88 |
self.logger.error("Package WeTextProcessing not found!")
|
89 |
self.logger.error(
|
90 |
"Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing",
|
91 |
)
|
92 |
|
93 |
-
|
94 |
def _load(
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
):
|
107 |
if not device:
|
108 |
device = select_device(4096)
|
109 |
self.logger.log(logging.INFO, f'use {device}')
|
110 |
-
|
|
|
111 |
if vocos_config_path:
|
112 |
vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
|
113 |
assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
|
114 |
vocos.load_state_dict(torch.load(vocos_ckpt_path))
|
115 |
self.pretrain_models['vocos'] = vocos
|
116 |
self.logger.log(logging.INFO, 'vocos loaded.')
|
117 |
-
|
118 |
if dvae_config_path:
|
119 |
cfg = OmegaConf.load(dvae_config_path)
|
120 |
dvae = DVAE(**cfg).to(device).eval()
|
@@ -122,7 +125,7 @@ class Chat:
|
|
122 |
dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
|
123 |
self.pretrain_models['dvae'] = dvae
|
124 |
self.logger.log(logging.INFO, 'dvae loaded.')
|
125 |
-
|
126 |
if gpt_config_path:
|
127 |
cfg = OmegaConf.load(gpt_config_path)
|
128 |
gpt = GPT_warpper(**cfg).to(device).eval()
|
@@ -139,7 +142,6 @@ class Chat:
|
|
139 |
spk_stat_path, weights_only=True, mmap=True, map_location='cpu'
|
140 |
).to(device)
|
141 |
|
142 |
-
|
143 |
if decoder_config_path:
|
144 |
cfg = OmegaConf.load(decoder_config_path)
|
145 |
decoder = DVAE(**cfg).to(device).eval()
|
@@ -147,13 +149,13 @@ class Chat:
|
|
147 |
decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
|
148 |
self.pretrain_models['decoder'] = decoder
|
149 |
self.logger.log(logging.INFO, 'decoder loaded.')
|
150 |
-
|
151 |
if tokenizer_path:
|
152 |
tokenizer = torch.load(tokenizer_path, map_location='cpu')
|
153 |
tokenizer.padding_side = 'left'
|
154 |
self.pretrain_models['tokenizer'] = tokenizer
|
155 |
self.logger.log(logging.INFO, 'tokenizer loaded.')
|
156 |
-
|
157 |
self.check_model()
|
158 |
|
159 |
@dataclass(repr=False, eq=False)
|
@@ -177,16 +179,19 @@ class Chat:
|
|
177 |
max_new_token: int = 2048
|
178 |
|
179 |
def infer(
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
):
|
189 |
-
|
|
|
|
|
|
|
190 |
assert self.check_model(use_decoder=use_decoder)
|
191 |
|
192 |
if not isinstance(text, list):
|
@@ -203,36 +208,48 @@ class Chat:
|
|
203 |
]
|
204 |
|
205 |
if skip_refine_text:
|
206 |
-
self.logger.info(f"
|
207 |
else:
|
208 |
self.logger.info(f"========针对文本内容做模型优化处理,lang:{lang}======")
|
209 |
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
|
210 |
-
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in
|
|
|
211 |
text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
|
212 |
if refine_text_only:
|
213 |
return text
|
214 |
-
|
215 |
text = [params_infer_code.get('prompt', '') + i for i in text]
|
216 |
params_infer_code.pop('prompt', '')
|
217 |
result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
|
218 |
-
|
219 |
if use_decoder:
|
220 |
-
mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
|
221 |
else:
|
222 |
-
mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
|
223 |
-
|
224 |
wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
|
225 |
-
|
226 |
return wav
|
227 |
|
|
|
228 |
def emptpy_audio(self):
|
229 |
-
return
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
# def sample_random_speaker(self) -> str:
|
238 |
# return self._encode_spk_emb(self.sample_random_speaker_tensor())
|
@@ -266,4 +283,4 @@ class Chat:
|
|
266 |
.add_(mean)
|
267 |
)
|
268 |
del out, std, mean
|
269 |
-
return spk
|
|
|
|
|
1 |
import os
|
2 |
import logging
|
3 |
from omegaconf import OmegaConf
|
|
|
10 |
from .utils.io_utils import get_latest_modified_file
|
11 |
from .infer.api import refine_text, infer_code
|
12 |
from dataclasses import dataclass
|
13 |
+
from typing import Literal, Optional, List, Tuple, Dict, Union
|
14 |
+
import numpy as np
|
15 |
from tool.logger import get_logger
|
16 |
+
from tool.normalizer import normalizer_en_nemo_text, normalizer_cn_tn
|
17 |
+
from tool.func import encode_prompt
|
18 |
|
19 |
from ChatTTS.norm import Normalizer
|
20 |
|
|
|
24 |
class Chat:
|
25 |
def __init__(self, ):
|
26 |
self.pretrain_models = {}
|
27 |
+
self.logger = get_logger(__name__, lv=logging.INFO)
|
28 |
self.normalizer = Normalizer(
|
29 |
os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"),
|
30 |
self.logger,
|
31 |
)
|
32 |
+
|
33 |
+
def check_model(self, level=logging.INFO, use_decoder=False):
|
34 |
not_finish = False
|
35 |
check_list = ['vocos', 'gpt', 'tokenizer']
|
36 |
+
|
37 |
if use_decoder:
|
38 |
check_list.append('decoder')
|
39 |
else:
|
40 |
check_list.append('dvae')
|
41 |
+
|
42 |
for module in check_list:
|
43 |
if module not in self.pretrain_models:
|
44 |
self.logger.log(logging.WARNING, f'{module} not initialized.')
|
45 |
not_finish = True
|
46 |
+
|
47 |
if not not_finish:
|
48 |
self.logger.log(level, f'All initialized.')
|
49 |
+
|
50 |
return not not_finish
|
51 |
+
|
52 |
def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>'):
|
53 |
if source == 'huggingface':
|
54 |
hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
|
|
|
56 |
download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
|
57 |
except:
|
58 |
download_path = None
|
59 |
+
if download_path is None or force_redownload:
|
60 |
self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
|
61 |
download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
|
62 |
else:
|
63 |
self.logger.log(logging.INFO, f'Load from cache: {download_path}')
|
64 |
+
self._load(**{k: os.path.join(download_path, v) for k, v in
|
65 |
+
OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
|
66 |
self._regist_normalizer()
|
67 |
elif source == 'local':
|
68 |
self.logger.log(logging.INFO, f'Load from local: {local_path}')
|
69 |
+
self._load(**{k: os.path.join(local_path, v) for k, v in
|
70 |
+
OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()})
|
71 |
|
72 |
def _regist_normalizer(self):
|
73 |
|
74 |
self.logger.info("==========开始注册 normalizer===========")
|
75 |
|
76 |
try:
|
77 |
+
self.normalizer.register("en", normalizer_en_nemo_text())
|
78 |
except ValueError as e:
|
79 |
+
self.logger.error('normalizer_en_nemo_text register fail', e)
|
80 |
except:
|
81 |
self.logger.error("Package nemo_text_processing not found!")
|
82 |
self.logger.error(
|
|
|
84 |
)
|
85 |
|
86 |
try:
|
87 |
+
self.normalizer.register("zh", normalizer_cn_tn())
|
88 |
except ValueError as e:
|
89 |
+
self.logger.error('normalizer_cn_tn register fail', e)
|
90 |
except:
|
91 |
self.logger.error("Package WeTextProcessing not found!")
|
92 |
self.logger.error(
|
93 |
"Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing",
|
94 |
)
|
95 |
|
|
|
96 |
def _load(
|
97 |
+
self,
|
98 |
+
vocos_config_path: str = None,
|
99 |
+
vocos_ckpt_path: str = None,
|
100 |
+
dvae_config_path: str = None,
|
101 |
+
dvae_ckpt_path: str = None,
|
102 |
+
gpt_config_path: str = None,
|
103 |
+
gpt_ckpt_path: str = None,
|
104 |
+
decoder_config_path: str = None,
|
105 |
+
decoder_ckpt_path: str = None,
|
106 |
+
tokenizer_path: str = None,
|
107 |
+
device: str = None
|
108 |
):
|
109 |
if not device:
|
110 |
device = select_device(4096)
|
111 |
self.logger.log(logging.INFO, f'use {device}')
|
112 |
+
|
113 |
+
self.device = device
|
114 |
if vocos_config_path:
|
115 |
vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
|
116 |
assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
|
117 |
vocos.load_state_dict(torch.load(vocos_ckpt_path))
|
118 |
self.pretrain_models['vocos'] = vocos
|
119 |
self.logger.log(logging.INFO, 'vocos loaded.')
|
120 |
+
|
121 |
if dvae_config_path:
|
122 |
cfg = OmegaConf.load(dvae_config_path)
|
123 |
dvae = DVAE(**cfg).to(device).eval()
|
|
|
125 |
dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
|
126 |
self.pretrain_models['dvae'] = dvae
|
127 |
self.logger.log(logging.INFO, 'dvae loaded.')
|
128 |
+
|
129 |
if gpt_config_path:
|
130 |
cfg = OmegaConf.load(gpt_config_path)
|
131 |
gpt = GPT_warpper(**cfg).to(device).eval()
|
|
|
142 |
spk_stat_path, weights_only=True, mmap=True, map_location='cpu'
|
143 |
).to(device)
|
144 |
|
|
|
145 |
if decoder_config_path:
|
146 |
cfg = OmegaConf.load(decoder_config_path)
|
147 |
decoder = DVAE(**cfg).to(device).eval()
|
|
|
149 |
decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
|
150 |
self.pretrain_models['decoder'] = decoder
|
151 |
self.logger.log(logging.INFO, 'decoder loaded.')
|
152 |
+
|
153 |
if tokenizer_path:
|
154 |
tokenizer = torch.load(tokenizer_path, map_location='cpu')
|
155 |
tokenizer.padding_side = 'left'
|
156 |
self.pretrain_models['tokenizer'] = tokenizer
|
157 |
self.logger.log(logging.INFO, 'tokenizer loaded.')
|
158 |
+
|
159 |
self.check_model()
|
160 |
|
161 |
@dataclass(repr=False, eq=False)
|
|
|
179 |
max_new_token: int = 2048
|
180 |
|
181 |
def infer(
|
182 |
+
self,
|
183 |
+
text,
|
184 |
+
skip_refine_text=False,
|
185 |
+
refine_text_only=False,
|
186 |
+
params_refine_text={},
|
187 |
+
params_infer_code={},
|
188 |
+
use_decoder=False,
|
189 |
+
lang=None
|
190 |
):
|
191 |
+
|
192 |
+
self.logger.info(
|
193 |
+
f"========开始infer模型,use_decoder:{use_decoder},lang:{lang},"
|
194 |
+
f"mskip_refine_text:{skip_refine_text},refine_text_only:{refine_text_only}======")
|
195 |
assert self.check_model(use_decoder=use_decoder)
|
196 |
|
197 |
if not isinstance(text, list):
|
|
|
208 |
]
|
209 |
|
210 |
if skip_refine_text:
|
211 |
+
self.logger.info(f"========对文本内容不做优化处理,仅做规则处理======")
|
212 |
else:
|
213 |
self.logger.info(f"========针对文本内容做模型优化处理,lang:{lang}======")
|
214 |
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
|
215 |
+
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in
|
216 |
+
text_tokens]
|
217 |
text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
|
218 |
if refine_text_only:
|
219 |
return text
|
220 |
+
|
221 |
text = [params_infer_code.get('prompt', '') + i for i in text]
|
222 |
params_infer_code.pop('prompt', '')
|
223 |
result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
|
224 |
+
|
225 |
if use_decoder:
|
226 |
+
mel_spec = [self.pretrain_models['decoder'](i[None].permute(0, 2, 1)) for i in result['hiddens']]
|
227 |
else:
|
228 |
+
mel_spec = [self.pretrain_models['dvae'](i[None].permute(0, 2, 1)) for i in result['ids']]
|
229 |
+
|
230 |
wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
|
231 |
+
|
232 |
return wav
|
233 |
|
234 |
+
# 返回一个空的wav 音频文件
|
235 |
def emptpy_audio(self):
|
236 |
+
return self.infer(" ",
|
237 |
+
skip_refine_text=True,
|
238 |
+
refine_text_only=False,
|
239 |
+
params_refine_text={},
|
240 |
+
params_infer_code={},
|
241 |
+
use_decoder=False)
|
242 |
+
|
243 |
+
'''
|
244 |
+
将音频张量 做转码处理
|
245 |
+
'''
|
246 |
+
|
247 |
+
@torch.inference_mode()
|
248 |
+
def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str:
|
249 |
+
if isinstance(wav, np.ndarray):
|
250 |
+
wav = torch.from_numpy(wav).to(self.device)
|
251 |
+
squeeze = self.pretrain_models['dvae'](wav, "encode").squeeze_(0)
|
252 |
+
return encode_prompt(squeeze)
|
253 |
|
254 |
# def sample_random_speaker(self) -> str:
|
255 |
# return self._encode_spk_emb(self.sample_random_speaker_tensor())
|
|
|
283 |
.add_(mean)
|
284 |
)
|
285 |
del out, std, mean
|
286 |
+
return spk
|
test/audio_test.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
if sys.platform == "darwin":
|
4 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
5 |
+
now_dir = os.getcwd()
|
6 |
+
sys.path.append(now_dir)
|
7 |
+
|
8 |
+
import Chat2TTS
|
9 |
+
from tool.av import load_audio
|
10 |
+
from tool.logger import get_logger
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
logger = get_logger("audio_test")
|
15 |
+
# Initialize and load the model:
|
16 |
+
chat = Chat2TTS.Chat()
|
17 |
+
|
18 |
+
def init_chat():
|
19 |
+
global chat
|
20 |
+
source = "local"
|
21 |
+
# 获取启动模式
|
22 |
+
MODEL = os.getenv('MODEL')
|
23 |
+
# huggingface 部署模式下,模型则直接使用hf的模型数据
|
24 |
+
if MODEL == "HF":
|
25 |
+
source = "huggingface"
|
26 |
+
|
27 |
+
logger.info("loading Chat2TTS model..., start source:" + source)
|
28 |
+
|
29 |
+
|
30 |
+
if chat.load_models(source=source, local_path="D:\\chenjgspace\\ai-model\\chattts"):
|
31 |
+
print("Models loaded successfully.")
|
32 |
+
logger.info("Models loaded end.")
|
33 |
+
# else:
|
34 |
+
# logger.error("=========Models load failed.")
|
35 |
+
# sys.exit(1)
|
36 |
+
|
37 |
+
def audo_encode():
|
38 |
+
sample_audio = load_audio("D:\\Download\\audio_test.wav",24000)
|
39 |
+
logger.info("================sample_audio:"+str(sample_audio))
|
40 |
+
spk_smp=chat.sample_audio_speaker(sample_audio)
|
41 |
+
logger.info("================spk_smp:"+str(spk_smp))
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == "__main__":
|
45 |
+
|
46 |
+
init_chat()
|
47 |
+
# 还需要继续调试
|
48 |
+
audo_encode()
|
test/common_test.py
CHANGED
@@ -8,7 +8,7 @@ from tool.logger import get_logger
|
|
8 |
|
9 |
logger=get_logger("common-test")
|
10 |
def save_mp3_file(wav, index, prefix_name):
|
11 |
-
from tool.
|
12 |
data = pcm_arr_to_mp3_view(wav)
|
13 |
mp3_filename = prefix_name + "_" + str(index) + ".mp3"
|
14 |
with open(mp3_filename, "wb") as f:
|
|
|
8 |
|
9 |
logger=get_logger("common-test")
|
10 |
def save_mp3_file(wav, index, prefix_name):
|
11 |
+
from tool.np import pcm_arr_to_mp3_view
|
12 |
data = pcm_arr_to_mp3_view(wav)
|
13 |
mp3_filename = prefix_name + "_" + str(index) + ".mp3"
|
14 |
with open(mp3_filename, "wb") as f:
|
tool/__init__.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
from .av import load_audio
|
2 |
-
from .
|
3 |
-
from .np import float_to_int16
|
4 |
from .ctx import TorchSeedContext
|
5 |
from .gpu import select_device
|
|
|
1 |
from .av import load_audio
|
2 |
+
from .np import float_to_int16,pcm_arr_to_mp3_view
|
|
|
3 |
from .ctx import TorchSeedContext
|
4 |
from .gpu import select_device
|
tool/func.py
CHANGED
@@ -1,6 +1,11 @@
|
|
1 |
|
2 |
import gradio as gr
|
3 |
import random
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
seed_min = 1
|
6 |
seed_max = 4294967295
|
@@ -30,6 +35,28 @@ voices = {
|
|
30 |
def on_voice_change(vocie_selection):
|
31 |
return voices.get(vocie_selection)["seed"]
|
32 |
|
33 |
-
|
|
|
|
|
34 |
def generate_seed():
|
35 |
-
return gr.update(value=random.randint(seed_min, seed_max))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
2 |
import gradio as gr
|
3 |
import random
|
4 |
+
import torch
|
5 |
+
import lzma
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import pybase16384 as b14
|
9 |
|
10 |
seed_min = 1
|
11 |
seed_max = 4294967295
|
|
|
35 |
def on_voice_change(vocie_selection):
|
36 |
return voices.get(vocie_selection)["seed"]
|
37 |
|
38 |
+
'''
|
39 |
+
随机生成种子
|
40 |
+
'''
|
41 |
def generate_seed():
|
42 |
+
return gr.update(value=random.randint(seed_min, seed_max))
|
43 |
+
|
44 |
+
'''
|
45 |
+
音频文件张量 编码
|
46 |
+
'''
|
47 |
+
|
48 |
+
@torch.no_grad()
|
49 |
+
def encode_prompt(prompt: torch.Tensor) -> str:
|
50 |
+
arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy()
|
51 |
+
shp = arr.shape
|
52 |
+
assert len(shp) == 2, "prompt must be a 2D tensor"
|
53 |
+
s = b14.encode_to_string(
|
54 |
+
np.array(shp, dtype="<u2").tobytes()
|
55 |
+
+ lzma.compress(
|
56 |
+
arr.astype("<u2").tobytes(),
|
57 |
+
format=lzma.FORMAT_RAW,
|
58 |
+
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
|
59 |
+
),
|
60 |
+
)
|
61 |
+
del arr
|
62 |
+
return s
|
tool/np.py
CHANGED
@@ -1,11 +1,28 @@
|
|
1 |
import math
|
2 |
|
3 |
-
import numpy as np
|
4 |
from numba import jit
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
-
@jit
|
8 |
def float_to_int16(audio: np.ndarray) -> np.ndarray:
|
9 |
am = int(math.ceil(float(np.abs(audio).max())) * 32768)
|
10 |
am = 32767 * 32768 // am
|
11 |
return np.multiply(audio, am).astype(np.int16)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import math
|
2 |
|
|
|
3 |
from numba import jit
|
4 |
+
import wave
|
5 |
+
from io import BytesIO
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from .av import wav2
|
9 |
+
|
10 |
|
11 |
|
|
|
12 |
def float_to_int16(audio: np.ndarray) -> np.ndarray:
|
13 |
am = int(math.ceil(float(np.abs(audio).max())) * 32768)
|
14 |
am = 32767 * 32768 // am
|
15 |
return np.multiply(audio, am).astype(np.int16)
|
16 |
+
|
17 |
+
def pcm_arr_to_mp3_view(wav: np.ndarray):
|
18 |
+
buf = BytesIO()
|
19 |
+
with wave.open(buf, "wb") as wf:
|
20 |
+
wf.setnchannels(1) # Mono channel
|
21 |
+
wf.setsampwidth(2) # Sample width in bytes
|
22 |
+
wf.setframerate(24000) # Sample rate in Hz
|
23 |
+
wf.writeframes(float_to_int16(wav))
|
24 |
+
buf.seek(0, 0)
|
25 |
+
buf2 = BytesIO()
|
26 |
+
wav2(buf, buf2, "mp3")
|
27 |
+
buf.seek(0, 0)
|
28 |
+
return buf2.getbuffer()
|
tool/pcm.py
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
import wave
|
2 |
-
from io import BytesIO
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
|
6 |
-
from .np import float_to_int16
|
7 |
-
from .av import wav2
|
8 |
-
|
9 |
-
|
10 |
-
def pcm_arr_to_mp3_view(wav: np.ndarray):
|
11 |
-
buf = BytesIO()
|
12 |
-
with wave.open(buf, "wb") as wf:
|
13 |
-
wf.setnchannels(1) # Mono channel
|
14 |
-
wf.setsampwidth(2) # Sample width in bytes
|
15 |
-
wf.setframerate(24000) # Sample rate in Hz
|
16 |
-
wf.writeframes(float_to_int16(wav))
|
17 |
-
buf.seek(0, 0)
|
18 |
-
buf2 = BytesIO()
|
19 |
-
wav2(buf, buf2, "mp3")
|
20 |
-
buf.seek(0, 0)
|
21 |
-
return buf2.getbuffer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
web/app_cpu.py
CHANGED
@@ -45,7 +45,7 @@ def init_chat(args):
|
|
45 |
|
46 |
def main(args):
|
47 |
with gr.Blocks() as demo:
|
48 |
-
gr.Markdown("# ChatTTS demo")
|
49 |
with gr.Row():
|
50 |
with gr.Column(scale=1):
|
51 |
text_input = gr.Textbox(
|
|
|
45 |
|
46 |
def main(args):
|
47 |
with gr.Blocks() as demo:
|
48 |
+
gr.Markdown("# ChatTTS demo CPU模式下运行")
|
49 |
with gr.Row():
|
50 |
with gr.Column(scale=1):
|
51 |
text_input = gr.Textbox(
|
web/app_gpu.py
CHANGED
@@ -48,7 +48,7 @@ def init_chat(args):
|
|
48 |
|
49 |
def main(args):
|
50 |
with gr.Blocks() as demo:
|
51 |
-
gr.Markdown("# ChatTTS demo")
|
52 |
with gr.Row():
|
53 |
with gr.Column(scale=1):
|
54 |
text_input = gr.Textbox(
|
@@ -71,6 +71,12 @@ def main(args):
|
|
71 |
interactive=True,
|
72 |
value=True
|
73 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
temperature_slider = gr.Slider(
|
75 |
minimum=0.00001,
|
76 |
maximum=1.0,
|
@@ -79,22 +85,23 @@ def main(args):
|
|
79 |
interactive=True,
|
80 |
label="模型 Temperature 参数设置"
|
81 |
)
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
98 |
with gr.Row():
|
99 |
lang_selection = gr.Dropdown(
|
100 |
label="语种",
|
@@ -139,7 +146,7 @@ def main(args):
|
|
139 |
# )
|
140 |
|
141 |
with gr.Row():
|
142 |
-
|
143 |
generate_button = gr.Button("生成音频文件", scale=1, interactive=True)
|
144 |
|
145 |
with gr.Row():
|
@@ -177,11 +184,12 @@ def main(args):
|
|
177 |
text_seed_input,
|
178 |
refine_text_checkBox,
|
179 |
refine_audio_checkBox,
|
|
|
180 |
temperature_slider,
|
181 |
top_p_slider,
|
182 |
top_k_slider,
|
183 |
audio_seed_input,
|
184 |
-
|
185 |
],
|
186 |
outputs=[text_output,audio_output])
|
187 |
# 初始化 spk_emb_text 数值
|
@@ -212,6 +220,7 @@ def general_chat_infer_audio(text,
|
|
212 |
text_seed_input,
|
213 |
refine_text_checkBox,
|
214 |
refine_audio_checkBox,
|
|
|
215 |
temperature_slider,
|
216 |
top_p_slider,
|
217 |
top_k_slider,
|
@@ -239,7 +248,8 @@ def general_chat_infer_audio(text,
|
|
239 |
skip_refine_text=False,
|
240 |
refine_text_only=True, #仅返回优化后文本内容
|
241 |
params_refine_text=params_refine_text,
|
242 |
-
lang=lang
|
|
|
243 |
)
|
244 |
|
245 |
|
@@ -265,6 +275,7 @@ def general_chat_infer_audio(text,
|
|
265 |
skip_refine_text=True, #跳过文本优化
|
266 |
params_refine_text=params_refine_text,
|
267 |
params_infer_code=params_infer_code,
|
|
|
268 |
)
|
269 |
|
270 |
#yield 24000, float_to_int16(wav[0]).T
|
|
|
48 |
|
49 |
def main(args):
|
50 |
with gr.Blocks() as demo:
|
51 |
+
gr.Markdown("# ChatTTS demo GPU模式下运行")
|
52 |
with gr.Row():
|
53 |
with gr.Column(scale=1):
|
54 |
text_input = gr.Textbox(
|
|
|
71 |
interactive=True,
|
72 |
value=True
|
73 |
)
|
74 |
+
|
75 |
+
use_decoder_checkBox = gr.Checkbox(
|
76 |
+
label="是否使用decoder模型,如否则使用dvae模型",
|
77 |
+
interactive=True,
|
78 |
+
value=True
|
79 |
+
)
|
80 |
temperature_slider = gr.Slider(
|
81 |
minimum=0.00001,
|
82 |
maximum=1.0,
|
|
|
85 |
interactive=True,
|
86 |
label="模型 Temperature 参数设置"
|
87 |
)
|
88 |
+
with gr.Column():
|
89 |
+
top_p_slider = gr.Slider(
|
90 |
+
minimum=0.1,
|
91 |
+
maximum=0.9,
|
92 |
+
step=0.05,
|
93 |
+
value=0.7,
|
94 |
+
label="模型 top_P 参数设置",
|
95 |
+
interactive=True,
|
96 |
+
)
|
97 |
+
top_k_slider = gr.Slider(
|
98 |
+
minimum=1,
|
99 |
+
maximum=20,
|
100 |
+
step=1,
|
101 |
+
value=20,
|
102 |
+
label="模型 top_K 参数设置",
|
103 |
+
interactive=True,
|
104 |
+
)
|
105 |
with gr.Row():
|
106 |
lang_selection = gr.Dropdown(
|
107 |
label="语种",
|
|
|
146 |
# )
|
147 |
|
148 |
with gr.Row():
|
149 |
+
# reload_chat_button = gr.Button("Reload", scale=1, interactive=True)
|
150 |
generate_button = gr.Button("生成音频文件", scale=1, interactive=True)
|
151 |
|
152 |
with gr.Row():
|
|
|
184 |
text_seed_input,
|
185 |
refine_text_checkBox,
|
186 |
refine_audio_checkBox,
|
187 |
+
use_decoder_checkBox,
|
188 |
temperature_slider,
|
189 |
top_p_slider,
|
190 |
top_k_slider,
|
191 |
audio_seed_input,
|
192 |
+
lang_selection
|
193 |
],
|
194 |
outputs=[text_output,audio_output])
|
195 |
# 初始化 spk_emb_text 数值
|
|
|
220 |
text_seed_input,
|
221 |
refine_text_checkBox,
|
222 |
refine_audio_checkBox,
|
223 |
+
use_decoder_checkBox,
|
224 |
temperature_slider,
|
225 |
top_p_slider,
|
226 |
top_k_slider,
|
|
|
248 |
skip_refine_text=False,
|
249 |
refine_text_only=True, #仅返回优化后文本内容
|
250 |
params_refine_text=params_refine_text,
|
251 |
+
lang=lang,
|
252 |
+
use_decoder=use_decoder_checkBox
|
253 |
)
|
254 |
|
255 |
|
|
|
275 |
skip_refine_text=True, #跳过文本优化
|
276 |
params_refine_text=params_refine_text,
|
277 |
params_infer_code=params_infer_code,
|
278 |
+
use_decoder=use_decoder_checkBox
|
279 |
)
|
280 |
|
281 |
#yield 24000, float_to_int16(wav[0]).T
|