chenjgtea
commited on
Commit
·
0bc3ceb
1
Parent(s):
c23cee5
新增gpu模式下chattts代码
Browse files- Chat2TTS/__init__.py +1 -0
- Chat2TTS/core.py +171 -0
- Chat2TTS/experimental/llm.py +40 -0
- Chat2TTS/infer/api.py +126 -0
- Chat2TTS/model/dvae.py +155 -0
- Chat2TTS/model/gpt.py +265 -0
- Chat2TTS/utils/gpu_utils.py +23 -0
- Chat2TTS/utils/infer_utils.py +45 -0
- Chat2TTS/utils/io_utils.py +14 -0
- tool/gpu.py +4 -5
- web/app.py +5 -5
- web/app_sc.py +267 -0
Chat2TTS/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .core import Chat
|
Chat2TTS/core.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import logging
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from vocos import Vocos
|
8 |
+
from .model.dvae import DVAE
|
9 |
+
from .model.gpt import GPT_warpper
|
10 |
+
from .utils.gpu_utils import select_device
|
11 |
+
from .utils.io_utils import get_latest_modified_file
|
12 |
+
from .infer.api import refine_text, infer_code
|
13 |
+
from dataclasses import dataclass
|
14 |
+
from typing import Literal, Optional, List, Tuple, Dict
|
15 |
+
|
16 |
+
from huggingface_hub import snapshot_download
|
17 |
+
|
18 |
+
logging.basicConfig(level = logging.INFO)
|
19 |
+
|
20 |
+
|
21 |
+
class Chat:
|
22 |
+
def __init__(self, ):
|
23 |
+
self.pretrain_models = {}
|
24 |
+
self.logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
def check_model(self, level = logging.INFO, use_decoder = False):
|
27 |
+
not_finish = False
|
28 |
+
check_list = ['vocos', 'gpt', 'tokenizer']
|
29 |
+
|
30 |
+
if use_decoder:
|
31 |
+
check_list.append('decoder')
|
32 |
+
else:
|
33 |
+
check_list.append('dvae')
|
34 |
+
|
35 |
+
for module in check_list:
|
36 |
+
if module not in self.pretrain_models:
|
37 |
+
self.logger.log(logging.WARNING, f'{module} not initialized.')
|
38 |
+
not_finish = True
|
39 |
+
|
40 |
+
if not not_finish:
|
41 |
+
self.logger.log(level, f'All initialized.')
|
42 |
+
|
43 |
+
return not not_finish
|
44 |
+
|
45 |
+
def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>'):
|
46 |
+
if source == 'huggingface':
|
47 |
+
hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
|
48 |
+
try:
|
49 |
+
download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
|
50 |
+
except:
|
51 |
+
download_path = None
|
52 |
+
if download_path is None or force_redownload:
|
53 |
+
self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
|
54 |
+
download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
|
55 |
+
else:
|
56 |
+
self.logger.log(logging.INFO, f'Load from cache: {download_path}')
|
57 |
+
self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
|
58 |
+
elif source == 'local':
|
59 |
+
self.logger.log(logging.INFO, f'Load from local: {local_path}')
|
60 |
+
self._load(**{k: os.path.join(local_path, v) for k, v in OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()})
|
61 |
+
|
62 |
+
def _load(
|
63 |
+
self,
|
64 |
+
vocos_config_path: str = None,
|
65 |
+
vocos_ckpt_path: str = None,
|
66 |
+
dvae_config_path: str = None,
|
67 |
+
dvae_ckpt_path: str = None,
|
68 |
+
gpt_config_path: str = None,
|
69 |
+
gpt_ckpt_path: str = None,
|
70 |
+
decoder_config_path: str = None,
|
71 |
+
decoder_ckpt_path: str = None,
|
72 |
+
tokenizer_path: str = None,
|
73 |
+
device: str = None
|
74 |
+
):
|
75 |
+
if not device:
|
76 |
+
device = select_device(4096)
|
77 |
+
self.logger.log(logging.INFO, f'use {device}')
|
78 |
+
|
79 |
+
if vocos_config_path:
|
80 |
+
vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
|
81 |
+
assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
|
82 |
+
vocos.load_state_dict(torch.load(vocos_ckpt_path))
|
83 |
+
self.pretrain_models['vocos'] = vocos
|
84 |
+
self.logger.log(logging.INFO, 'vocos loaded.')
|
85 |
+
|
86 |
+
if dvae_config_path:
|
87 |
+
cfg = OmegaConf.load(dvae_config_path)
|
88 |
+
dvae = DVAE(**cfg).to(device).eval()
|
89 |
+
assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
|
90 |
+
dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
|
91 |
+
self.pretrain_models['dvae'] = dvae
|
92 |
+
self.logger.log(logging.INFO, 'dvae loaded.')
|
93 |
+
|
94 |
+
if gpt_config_path:
|
95 |
+
cfg = OmegaConf.load(gpt_config_path)
|
96 |
+
gpt = GPT_warpper(**cfg).to(device).eval()
|
97 |
+
assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
|
98 |
+
gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
|
99 |
+
self.pretrain_models['gpt'] = gpt
|
100 |
+
self.logger.log(logging.INFO, 'gpt loaded.')
|
101 |
+
|
102 |
+
if decoder_config_path:
|
103 |
+
cfg = OmegaConf.load(decoder_config_path)
|
104 |
+
decoder = DVAE(**cfg).to(device).eval()
|
105 |
+
assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
|
106 |
+
decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
|
107 |
+
self.pretrain_models['decoder'] = decoder
|
108 |
+
self.logger.log(logging.INFO, 'decoder loaded.')
|
109 |
+
|
110 |
+
if tokenizer_path:
|
111 |
+
tokenizer = torch.load(tokenizer_path, map_location='cpu')
|
112 |
+
tokenizer.padding_side = 'left'
|
113 |
+
self.pretrain_models['tokenizer'] = tokenizer
|
114 |
+
self.logger.log(logging.INFO, 'tokenizer loaded.')
|
115 |
+
|
116 |
+
self.check_model()
|
117 |
+
|
118 |
+
@dataclass(repr=False, eq=False)
|
119 |
+
class RefineTextParams:
|
120 |
+
prompt: str = ""
|
121 |
+
top_P: float = 0.7
|
122 |
+
top_K: int = 20
|
123 |
+
temperature: float = 0.7
|
124 |
+
repetition_penalty: float = 1.0
|
125 |
+
max_new_token: int = 384
|
126 |
+
min_new_token: int = 0
|
127 |
+
show_tqdm: bool = True
|
128 |
+
ensure_non_empty: bool = True
|
129 |
+
|
130 |
+
@dataclass(repr=False, eq=False)
|
131 |
+
class InferCodeParams(RefineTextParams):
|
132 |
+
prompt: str = "[speed_5]"
|
133 |
+
spk_emb: Optional[str] = None
|
134 |
+
temperature: float = 0.3
|
135 |
+
repetition_penalty: float = 1.05
|
136 |
+
max_new_token: int = 2048
|
137 |
+
|
138 |
+
def infer(
|
139 |
+
self,
|
140 |
+
text,
|
141 |
+
skip_refine_text=False,
|
142 |
+
refine_text_only=False,
|
143 |
+
params_refine_text={},
|
144 |
+
params_infer_code={},
|
145 |
+
use_decoder=False
|
146 |
+
):
|
147 |
+
|
148 |
+
assert self.check_model(use_decoder=use_decoder)
|
149 |
+
|
150 |
+
if not skip_refine_text:
|
151 |
+
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
|
152 |
+
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
|
153 |
+
text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
|
154 |
+
if refine_text_only:
|
155 |
+
return text
|
156 |
+
|
157 |
+
text = [params_infer_code.get('prompt', '') + i for i in text]
|
158 |
+
params_infer_code.pop('prompt', '')
|
159 |
+
result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
|
160 |
+
|
161 |
+
if use_decoder:
|
162 |
+
mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
|
163 |
+
else:
|
164 |
+
mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
|
165 |
+
|
166 |
+
wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
|
167 |
+
|
168 |
+
return wav
|
169 |
+
|
170 |
+
|
171 |
+
|
Chat2TTS/experimental/llm.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from openai import OpenAI
|
3 |
+
|
4 |
+
prompt_dict = {
|
5 |
+
'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"},
|
6 |
+
{"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
|
7 |
+
{"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
|
8 |
+
'deepseek': [
|
9 |
+
{"role": "system", "content": "You are a helpful assistant"},
|
10 |
+
{"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
|
11 |
+
{"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
|
12 |
+
'deepseek_TN': [
|
13 |
+
{"role": "system", "content": "You are a helpful assistant"},
|
14 |
+
{"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"},
|
15 |
+
{"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"},
|
16 |
+
{"role": "user", "content": "We paid $123 for this desk."},
|
17 |
+
{"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."},
|
18 |
+
{"role": "user", "content": "详询请拨打010-724654"},
|
19 |
+
{"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"},
|
20 |
+
{"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"},
|
21 |
+
{"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"},
|
22 |
+
],
|
23 |
+
}
|
24 |
+
|
25 |
+
class llm_api:
|
26 |
+
def __init__(self, api_key, base_url, model):
|
27 |
+
self.client = OpenAI(
|
28 |
+
api_key = api_key,
|
29 |
+
base_url = base_url,
|
30 |
+
)
|
31 |
+
self.model = model
|
32 |
+
def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs):
|
33 |
+
|
34 |
+
completion = self.client.chat.completions.create(
|
35 |
+
model = self.model,
|
36 |
+
messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},],
|
37 |
+
temperature = temperature,
|
38 |
+
**kwargs
|
39 |
+
)
|
40 |
+
return completion.choices[0].message.content
|
Chat2TTS/infer/api.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
|
5 |
+
from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
|
6 |
+
|
7 |
+
def infer_code(
|
8 |
+
models,
|
9 |
+
text,
|
10 |
+
spk_emb = None,
|
11 |
+
top_P = 0.7,
|
12 |
+
top_K = 20,
|
13 |
+
temperature = 0.3,
|
14 |
+
repetition_penalty = 1.05,
|
15 |
+
max_new_token = 2048,
|
16 |
+
**kwargs
|
17 |
+
):
|
18 |
+
|
19 |
+
device = next(models['gpt'].parameters()).device
|
20 |
+
|
21 |
+
if not isinstance(text, list):
|
22 |
+
text = [text]
|
23 |
+
|
24 |
+
if not isinstance(temperature, list):
|
25 |
+
temperature = [temperature] * models['gpt'].num_vq
|
26 |
+
|
27 |
+
if spk_emb is not None:
|
28 |
+
text = [f'[Stts][spk_emb]{i}[uv_break][Ptts]' for i in text]
|
29 |
+
else:
|
30 |
+
text = [f'[Stts][empty_spk]{i}[uv_break][Ptts]' for i in text]
|
31 |
+
|
32 |
+
text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
|
33 |
+
input_ids = text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq)
|
34 |
+
text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
|
35 |
+
|
36 |
+
inputs = {
|
37 |
+
'input_ids': input_ids,
|
38 |
+
'text_mask': text_mask,
|
39 |
+
'attention_mask': text_token['attention_mask'],
|
40 |
+
}
|
41 |
+
|
42 |
+
emb = models['gpt'].get_emb(**inputs)
|
43 |
+
if spk_emb is not None:
|
44 |
+
emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \
|
45 |
+
F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12)
|
46 |
+
|
47 |
+
num_code = models['gpt'].emb_code[0].num_embeddings - 1
|
48 |
+
|
49 |
+
LogitsWarpers = []
|
50 |
+
if top_P is not None:
|
51 |
+
LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
|
52 |
+
if top_K is not None:
|
53 |
+
LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
|
54 |
+
|
55 |
+
LogitsProcessors = []
|
56 |
+
if repetition_penalty is not None and repetition_penalty != 1:
|
57 |
+
LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\
|
58 |
+
repetition_penalty, num_code, 16))
|
59 |
+
|
60 |
+
result = models['gpt'].generate(
|
61 |
+
emb, inputs['input_ids'],
|
62 |
+
temperature = torch.tensor(temperature, device=device),
|
63 |
+
attention_mask = inputs['attention_mask'],
|
64 |
+
LogitsWarpers = LogitsWarpers,
|
65 |
+
LogitsProcessors = LogitsProcessors,
|
66 |
+
eos_token = num_code,
|
67 |
+
max_new_token = max_new_token,
|
68 |
+
infer_text = False,
|
69 |
+
**kwargs
|
70 |
+
)
|
71 |
+
|
72 |
+
return result
|
73 |
+
|
74 |
+
|
75 |
+
def refine_text(
|
76 |
+
models,
|
77 |
+
text,
|
78 |
+
top_P = 0.7,
|
79 |
+
top_K = 20,
|
80 |
+
temperature = 0.7,
|
81 |
+
repetition_penalty = 1.0,
|
82 |
+
max_new_token = 384,
|
83 |
+
prompt = '',
|
84 |
+
**kwargs
|
85 |
+
):
|
86 |
+
|
87 |
+
device = next(models['gpt'].parameters()).device
|
88 |
+
|
89 |
+
if not isinstance(text, list):
|
90 |
+
text = [text]
|
91 |
+
|
92 |
+
assert len(text), 'text should not be empty'
|
93 |
+
|
94 |
+
text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
|
95 |
+
text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
|
96 |
+
text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
|
97 |
+
|
98 |
+
inputs = {
|
99 |
+
'input_ids': text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq),
|
100 |
+
'text_mask': text_mask,
|
101 |
+
'attention_mask': text_token['attention_mask'],
|
102 |
+
}
|
103 |
+
|
104 |
+
LogitsWarpers = []
|
105 |
+
if top_P is not None:
|
106 |
+
LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
|
107 |
+
if top_K is not None:
|
108 |
+
LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
|
109 |
+
|
110 |
+
LogitsProcessors = []
|
111 |
+
if repetition_penalty is not None and repetition_penalty != 1:
|
112 |
+
LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))
|
113 |
+
|
114 |
+
result = models['gpt'].generate(
|
115 |
+
models['gpt'].get_emb(**inputs),
|
116 |
+
inputs['input_ids'],
|
117 |
+
temperature = torch.tensor([temperature,], device=device),
|
118 |
+
attention_mask = inputs['attention_mask'],
|
119 |
+
LogitsWarpers = LogitsWarpers,
|
120 |
+
LogitsProcessors = LogitsProcessors,
|
121 |
+
eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
|
122 |
+
max_new_token = max_new_token,
|
123 |
+
infer_text = True,
|
124 |
+
**kwargs
|
125 |
+
)
|
126 |
+
return result
|
Chat2TTS/model/dvae.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from einops import rearrange
|
3 |
+
from vector_quantize_pytorch import GroupedResidualFSQ
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
class ConvNeXtBlock(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
dim: int,
|
13 |
+
intermediate_dim: int,
|
14 |
+
kernel, dilation,
|
15 |
+
layer_scale_init_value: float = 1e-6,
|
16 |
+
):
|
17 |
+
# ConvNeXt Block copied from Vocos.
|
18 |
+
super().__init__()
|
19 |
+
self.dwconv = nn.Conv1d(dim, dim,
|
20 |
+
kernel_size=kernel, padding=dilation*(kernel//2),
|
21 |
+
dilation=dilation, groups=dim
|
22 |
+
) # depthwise conv
|
23 |
+
|
24 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
25 |
+
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
26 |
+
self.act = nn.GELU()
|
27 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
28 |
+
self.gamma = (
|
29 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
30 |
+
if layer_scale_init_value > 0
|
31 |
+
else None
|
32 |
+
)
|
33 |
+
|
34 |
+
def forward(self, x: torch.Tensor, cond = None) -> torch.Tensor:
|
35 |
+
residual = x
|
36 |
+
x = self.dwconv(x)
|
37 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
38 |
+
x = self.norm(x)
|
39 |
+
x = self.pwconv1(x)
|
40 |
+
x = self.act(x)
|
41 |
+
x = self.pwconv2(x)
|
42 |
+
if self.gamma is not None:
|
43 |
+
x = self.gamma * x
|
44 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
45 |
+
|
46 |
+
x = residual + x
|
47 |
+
return x
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
class GFSQ(nn.Module):
|
52 |
+
|
53 |
+
def __init__(self,
|
54 |
+
dim, levels, G, R, eps=1e-5, transpose = True
|
55 |
+
):
|
56 |
+
super(GFSQ, self).__init__()
|
57 |
+
self.quantizer = GroupedResidualFSQ(
|
58 |
+
dim=dim,
|
59 |
+
levels=levels,
|
60 |
+
num_quantizers=R,
|
61 |
+
groups=G,
|
62 |
+
)
|
63 |
+
self.n_ind = math.prod(levels)
|
64 |
+
self.eps = eps
|
65 |
+
self.transpose = transpose
|
66 |
+
self.G = G
|
67 |
+
self.R = R
|
68 |
+
|
69 |
+
def _embed(self, x):
|
70 |
+
if self.transpose:
|
71 |
+
x = x.transpose(1,2)
|
72 |
+
x = rearrange(
|
73 |
+
x, "b t (g r) -> g b t r", g = self.G, r = self.R,
|
74 |
+
)
|
75 |
+
feat = self.quantizer.get_output_from_indices(x)
|
76 |
+
return feat.transpose(1,2) if self.transpose else feat
|
77 |
+
|
78 |
+
def forward(self, x,):
|
79 |
+
if self.transpose:
|
80 |
+
x = x.transpose(1,2)
|
81 |
+
feat, ind = self.quantizer(x)
|
82 |
+
ind = rearrange(
|
83 |
+
ind, "g b t r ->b t (g r)",
|
84 |
+
)
|
85 |
+
embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
|
86 |
+
e_mean = torch.mean(embed_onehot, dim=[0,1])
|
87 |
+
e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
|
88 |
+
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
|
89 |
+
|
90 |
+
return (
|
91 |
+
torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
|
92 |
+
feat.transpose(1,2) if self.transpose else feat,
|
93 |
+
perplexity,
|
94 |
+
None,
|
95 |
+
ind.transpose(1,2) if self.transpose else ind,
|
96 |
+
)
|
97 |
+
|
98 |
+
class DVAEDecoder(nn.Module):
|
99 |
+
def __init__(self, idim, odim,
|
100 |
+
n_layer = 12, bn_dim = 64, hidden = 256,
|
101 |
+
kernel = 7, dilation = 2, up = False
|
102 |
+
):
|
103 |
+
super().__init__()
|
104 |
+
self.up = up
|
105 |
+
self.conv_in = nn.Sequential(
|
106 |
+
nn.Conv1d(idim, bn_dim, 3, 1, 1), nn.GELU(),
|
107 |
+
nn.Conv1d(bn_dim, hidden, 3, 1, 1)
|
108 |
+
)
|
109 |
+
self.decoder_block = nn.ModuleList([
|
110 |
+
ConvNeXtBlock(hidden, hidden* 4, kernel, dilation,)
|
111 |
+
for _ in range(n_layer)])
|
112 |
+
self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
|
113 |
+
|
114 |
+
def forward(self, input, conditioning=None):
|
115 |
+
# B, T, C
|
116 |
+
x = input.transpose(1, 2)
|
117 |
+
x = self.conv_in(x)
|
118 |
+
for f in self.decoder_block:
|
119 |
+
x = f(x, conditioning)
|
120 |
+
|
121 |
+
x = self.conv_out(x)
|
122 |
+
return x.transpose(1, 2)
|
123 |
+
|
124 |
+
|
125 |
+
class DVAE(nn.Module):
|
126 |
+
def __init__(
|
127 |
+
self, decoder_config, vq_config, dim=512
|
128 |
+
):
|
129 |
+
super().__init__()
|
130 |
+
self.register_buffer('coef', torch.randn(1, 100, 1))
|
131 |
+
|
132 |
+
self.decoder = DVAEDecoder(**decoder_config)
|
133 |
+
self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
|
134 |
+
if vq_config is not None:
|
135 |
+
self.vq_layer = GFSQ(**vq_config)
|
136 |
+
else:
|
137 |
+
self.vq_layer = None
|
138 |
+
|
139 |
+
def forward(self, inp):
|
140 |
+
|
141 |
+
if self.vq_layer is not None:
|
142 |
+
vq_feats = self.vq_layer._embed(inp)
|
143 |
+
else:
|
144 |
+
vq_feats = inp.detach().clone()
|
145 |
+
|
146 |
+
temp = torch.chunk(vq_feats, 2, dim=1) # flatten trick :)
|
147 |
+
temp = torch.stack(temp, -1)
|
148 |
+
vq_feats = temp.reshape(*temp.shape[:2], -1)
|
149 |
+
|
150 |
+
vq_feats = vq_feats.transpose(1, 2)
|
151 |
+
dec_out = self.decoder(input=vq_feats)
|
152 |
+
dec_out = self.out_conv(dec_out.transpose(1, 2))
|
153 |
+
mel = dec_out * self.coef
|
154 |
+
|
155 |
+
return mel
|
Chat2TTS/model/gpt.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
3 |
+
|
4 |
+
import logging
|
5 |
+
from tqdm import tqdm
|
6 |
+
from einops import rearrange
|
7 |
+
from transformers.cache_utils import Cache
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torch.nn.utils.parametrize as P
|
13 |
+
from torch.nn.utils.parametrizations import weight_norm
|
14 |
+
from transformers import LlamaModel, LlamaConfig
|
15 |
+
|
16 |
+
|
17 |
+
class LlamaMLP(nn.Module):
|
18 |
+
def __init__(self, hidden_size, intermediate_size):
|
19 |
+
super().__init__()
|
20 |
+
self.hidden_size = hidden_size
|
21 |
+
self.intermediate_size = intermediate_size
|
22 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
23 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
24 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
25 |
+
self.act_fn = F.silu
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
29 |
+
return down_proj
|
30 |
+
|
31 |
+
|
32 |
+
class GPT_warpper(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
gpt_config,
|
36 |
+
num_audio_tokens,
|
37 |
+
num_text_tokens,
|
38 |
+
num_vq=4,
|
39 |
+
**kwargs,
|
40 |
+
):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
self.logger = logging.getLogger(__name__)
|
44 |
+
self.gpt = self.build_model(gpt_config)
|
45 |
+
self.model_dim = self.gpt.config.hidden_size
|
46 |
+
|
47 |
+
self.num_vq = num_vq
|
48 |
+
self.emb_code = nn.ModuleList([nn.Embedding(num_audio_tokens, self.model_dim) for i in range(self.num_vq)])
|
49 |
+
self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
|
50 |
+
self.head_text = weight_norm(nn.Linear(self.model_dim, num_text_tokens, bias=False), name='weight')
|
51 |
+
self.head_code = nn.ModuleList([weight_norm(nn.Linear(self.model_dim, num_audio_tokens, bias=False), name='weight') for i in range(self.num_vq)])
|
52 |
+
|
53 |
+
def build_model(self, config):
|
54 |
+
|
55 |
+
configuration = LlamaConfig(**config)
|
56 |
+
model = LlamaModel(configuration)
|
57 |
+
del model.embed_tokens
|
58 |
+
|
59 |
+
return model
|
60 |
+
|
61 |
+
def get_emb(self, input_ids, text_mask, **kwargs):
|
62 |
+
|
63 |
+
emb_text = self.emb_text(input_ids[text_mask][:, 0])
|
64 |
+
|
65 |
+
emb_code = [self.emb_code[i](input_ids[~text_mask][:, i]) for i in range(self.num_vq)]
|
66 |
+
emb_code = torch.stack(emb_code, 2).sum(2)
|
67 |
+
|
68 |
+
emb = torch.zeros((input_ids.shape[:-1])+(emb_text.shape[-1],), device=emb_text.device, dtype=emb_text.dtype)
|
69 |
+
emb[text_mask] = emb_text
|
70 |
+
emb[~text_mask] = emb_code.to(emb.dtype)
|
71 |
+
|
72 |
+
return emb
|
73 |
+
|
74 |
+
def prepare_inputs_for_generation(
|
75 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
|
76 |
+
):
|
77 |
+
# With static cache, the `past_key_values` is None
|
78 |
+
# TODO joao: standardize interface for the different Cache classes and remove of this if
|
79 |
+
has_static_cache = False
|
80 |
+
if past_key_values is None:
|
81 |
+
past_key_values = getattr(self.gpt.layers[0].self_attn, "past_key_value", None)
|
82 |
+
has_static_cache = past_key_values is not None
|
83 |
+
|
84 |
+
past_length = 0
|
85 |
+
if past_key_values is not None:
|
86 |
+
if isinstance(past_key_values, Cache):
|
87 |
+
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
88 |
+
max_cache_length = (
|
89 |
+
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
90 |
+
if past_key_values.get_max_length() is not None
|
91 |
+
else None
|
92 |
+
)
|
93 |
+
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
94 |
+
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
|
95 |
+
else:
|
96 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
97 |
+
max_cache_length = None
|
98 |
+
|
99 |
+
# Keep only the unprocessed tokens:
|
100 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
101 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
102 |
+
# input)
|
103 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
104 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
105 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
106 |
+
# input_ids based on the past_length.
|
107 |
+
elif past_length < input_ids.shape[1]:
|
108 |
+
input_ids = input_ids[:, past_length:]
|
109 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
110 |
+
|
111 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
112 |
+
if (
|
113 |
+
max_cache_length is not None
|
114 |
+
and attention_mask is not None
|
115 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
116 |
+
):
|
117 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
118 |
+
|
119 |
+
position_ids = kwargs.get("position_ids", None)
|
120 |
+
if attention_mask is not None and position_ids is None:
|
121 |
+
# create position_ids on the fly for batch generation
|
122 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
123 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
124 |
+
if past_key_values:
|
125 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
126 |
+
|
127 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
128 |
+
if inputs_embeds is not None and past_key_values is None:
|
129 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
130 |
+
else:
|
131 |
+
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
132 |
+
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
133 |
+
# TODO: use `next_tokens` directly instead.
|
134 |
+
model_inputs = {"input_ids": input_ids.contiguous()}
|
135 |
+
|
136 |
+
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
137 |
+
if cache_position is None:
|
138 |
+
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
139 |
+
else:
|
140 |
+
cache_position = cache_position[-input_length:]
|
141 |
+
|
142 |
+
if has_static_cache:
|
143 |
+
past_key_values = None
|
144 |
+
|
145 |
+
model_inputs.update(
|
146 |
+
{
|
147 |
+
"position_ids": position_ids,
|
148 |
+
"cache_position": cache_position,
|
149 |
+
"past_key_values": past_key_values,
|
150 |
+
"use_cache": kwargs.get("use_cache"),
|
151 |
+
"attention_mask": attention_mask,
|
152 |
+
}
|
153 |
+
)
|
154 |
+
return model_inputs
|
155 |
+
|
156 |
+
def generate(
|
157 |
+
self,
|
158 |
+
emb,
|
159 |
+
inputs_ids,
|
160 |
+
temperature,
|
161 |
+
eos_token,
|
162 |
+
attention_mask = None,
|
163 |
+
max_new_token = 2048,
|
164 |
+
min_new_token = 0,
|
165 |
+
LogitsWarpers = [],
|
166 |
+
LogitsProcessors = [],
|
167 |
+
infer_text=False,
|
168 |
+
return_attn=False,
|
169 |
+
return_hidden=False,
|
170 |
+
):
|
171 |
+
|
172 |
+
with torch.no_grad():
|
173 |
+
|
174 |
+
attentions = []
|
175 |
+
hiddens = []
|
176 |
+
|
177 |
+
start_idx, end_idx = inputs_ids.shape[1], torch.zeros(inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long)
|
178 |
+
finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
|
179 |
+
|
180 |
+
temperature = temperature[None].expand(inputs_ids.shape[0], -1)
|
181 |
+
temperature = rearrange(temperature, "b n -> (b n) 1")
|
182 |
+
|
183 |
+
attention_mask_cache = torch.ones((inputs_ids.shape[0], inputs_ids.shape[1]+max_new_token,), dtype=torch.bool, device=inputs_ids.device)
|
184 |
+
if attention_mask is not None:
|
185 |
+
attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask
|
186 |
+
|
187 |
+
for i in tqdm(range(max_new_token)):
|
188 |
+
|
189 |
+
model_input = self.prepare_inputs_for_generation(inputs_ids,
|
190 |
+
outputs.past_key_values if i!=0 else None,
|
191 |
+
attention_mask_cache[:, :inputs_ids.shape[1]], use_cache=True)
|
192 |
+
|
193 |
+
if i == 0:
|
194 |
+
model_input['inputs_embeds'] = emb
|
195 |
+
else:
|
196 |
+
if infer_text:
|
197 |
+
model_input['inputs_embeds'] = self.emb_text(model_input['input_ids'][:,:,0])
|
198 |
+
else:
|
199 |
+
code_emb = [self.emb_code[i](model_input['input_ids'][:,:,i]) for i in range(self.num_vq)]
|
200 |
+
model_input['inputs_embeds'] = torch.stack(code_emb, 3).sum(3)
|
201 |
+
|
202 |
+
model_input['input_ids'] = None
|
203 |
+
outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
|
204 |
+
attentions.append(outputs.attentions)
|
205 |
+
hidden_states = outputs[0] # 🐻
|
206 |
+
if return_hidden:
|
207 |
+
hiddens.append(hidden_states[:, -1])
|
208 |
+
|
209 |
+
with P.cached():
|
210 |
+
if infer_text:
|
211 |
+
logits = self.head_text(hidden_states)
|
212 |
+
else:
|
213 |
+
logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
|
214 |
+
|
215 |
+
logits = logits[:, -1].float()
|
216 |
+
|
217 |
+
if not infer_text:
|
218 |
+
logits = rearrange(logits, "b c n -> (b n) c")
|
219 |
+
logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
|
220 |
+
else:
|
221 |
+
logits_token = inputs_ids[:, start_idx:, 0]
|
222 |
+
|
223 |
+
logits = logits / temperature
|
224 |
+
|
225 |
+
for logitsProcessors in LogitsProcessors:
|
226 |
+
logits = logitsProcessors(logits_token, logits)
|
227 |
+
|
228 |
+
for logitsWarpers in LogitsWarpers:
|
229 |
+
logits = logitsWarpers(logits_token, logits)
|
230 |
+
|
231 |
+
if i < min_new_token:
|
232 |
+
logits[:, eos_token] = -torch.inf
|
233 |
+
|
234 |
+
scores = F.softmax(logits, dim=-1)
|
235 |
+
|
236 |
+
idx_next = torch.multinomial(scores, num_samples=1)
|
237 |
+
|
238 |
+
if not infer_text:
|
239 |
+
idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
|
240 |
+
finish = finish | (idx_next == eos_token).any(1)
|
241 |
+
inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
|
242 |
+
else:
|
243 |
+
finish = finish | (idx_next == eos_token).any(1)
|
244 |
+
inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq)], 1)
|
245 |
+
|
246 |
+
end_idx = end_idx + (~finish).int()
|
247 |
+
|
248 |
+
if finish.all():
|
249 |
+
break
|
250 |
+
|
251 |
+
inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
|
252 |
+
inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
|
253 |
+
|
254 |
+
if return_hidden:
|
255 |
+
hiddens = torch.stack(hiddens, 1)
|
256 |
+
hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
|
257 |
+
|
258 |
+
if not finish.all():
|
259 |
+
self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')
|
260 |
+
|
261 |
+
return {
|
262 |
+
'ids': inputs_ids,
|
263 |
+
'attentions': attentions,
|
264 |
+
'hiddens':hiddens,
|
265 |
+
}
|
Chat2TTS/utils/gpu_utils.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import logging
|
4 |
+
|
5 |
+
def select_device(min_memory = 2048):
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
if torch.cuda.is_available():
|
8 |
+
available_gpus = []
|
9 |
+
for i in range(torch.cuda.device_count()):
|
10 |
+
props = torch.cuda.get_device_properties(i)
|
11 |
+
free_memory = props.total_memory - torch.cuda.memory_reserved(i)
|
12 |
+
available_gpus.append((i, free_memory))
|
13 |
+
selected_gpu, max_free_memory = max(available_gpus, key=lambda x: x[1])
|
14 |
+
device = torch.device(f'cuda:{selected_gpu}')
|
15 |
+
free_memory_mb = max_free_memory / (1024 * 1024)
|
16 |
+
if free_memory_mb < min_memory:
|
17 |
+
logger.log(logging.WARNING, f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left.')
|
18 |
+
device = torch.device('cpu')
|
19 |
+
else:
|
20 |
+
logger.log(logging.WARNING, f'No GPU found, use CPU instead')
|
21 |
+
device = torch.device('cpu')
|
22 |
+
|
23 |
+
return device
|
Chat2TTS/utils/infer_utils.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class CustomRepetitionPenaltyLogitsProcessorRepeat():
|
7 |
+
|
8 |
+
def __init__(self, penalty: float, max_input_ids, past_window):
|
9 |
+
if not isinstance(penalty, float) or not (penalty > 0):
|
10 |
+
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
11 |
+
|
12 |
+
self.penalty = penalty
|
13 |
+
self.max_input_ids = max_input_ids
|
14 |
+
self.past_window = past_window
|
15 |
+
|
16 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
17 |
+
|
18 |
+
input_ids = input_ids[:, -self.past_window:]
|
19 |
+
freq = F.one_hot(input_ids, scores.size(1)).sum(1)
|
20 |
+
freq[self.max_input_ids:] = 0
|
21 |
+
alpha = self.penalty**freq
|
22 |
+
scores = torch.where(scores < 0, scores*alpha, scores/alpha)
|
23 |
+
|
24 |
+
return scores
|
25 |
+
|
26 |
+
class CustomRepetitionPenaltyLogitsProcessor():
|
27 |
+
|
28 |
+
def __init__(self, penalty: float, max_input_ids, past_window):
|
29 |
+
if not isinstance(penalty, float) or not (penalty > 0):
|
30 |
+
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
31 |
+
|
32 |
+
self.penalty = penalty
|
33 |
+
self.max_input_ids = max_input_ids
|
34 |
+
self.past_window = past_window
|
35 |
+
|
36 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
37 |
+
|
38 |
+
input_ids = input_ids[:, -self.past_window:]
|
39 |
+
score = torch.gather(scores, 1, input_ids)
|
40 |
+
_score = score.detach().clone()
|
41 |
+
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
42 |
+
score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
|
43 |
+
scores.scatter_(1, input_ids, score)
|
44 |
+
|
45 |
+
return scores
|
Chat2TTS/utils/io_utils.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import logging
|
4 |
+
|
5 |
+
def get_latest_modified_file(directory):
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
files = [os.path.join(directory, f) for f in os.listdir(directory)]
|
9 |
+
if not files:
|
10 |
+
logger.log(logging.WARNING, f'No files found in the directory: {directory}')
|
11 |
+
return None
|
12 |
+
latest_file = max(files, key=os.path.getmtime)
|
13 |
+
|
14 |
+
return latest_file
|
tool/gpu.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1 |
import torch
|
2 |
import os, sys
|
3 |
-
import spaces
|
4 |
|
5 |
if sys.platform == "darwin":
|
6 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
7 |
now_dir = os.getcwd()
|
8 |
sys.path.append(now_dir)
|
9 |
-
from
|
10 |
|
11 |
logger = get_logger("gpu")
|
12 |
|
@@ -36,13 +35,13 @@ def select_device(min_memory=2047, experimental=False):
|
|
36 |
"""
|
37 |
if experimental:
|
38 |
# For Apple M1/M2 chips with Metal Performance Shaders
|
39 |
-
logger.
|
40 |
device = torch.device("mps")
|
41 |
else:
|
42 |
-
logger.
|
43 |
device = torch.device("cpu")
|
44 |
else:
|
45 |
-
logger.
|
46 |
device = torch.device("cpu")
|
47 |
|
48 |
return device
|
|
|
1 |
import torch
|
2 |
import os, sys
|
|
|
3 |
|
4 |
if sys.platform == "darwin":
|
5 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
6 |
now_dir = os.getcwd()
|
7 |
sys.path.append(now_dir)
|
8 |
+
from .logger.log import get_logger
|
9 |
|
10 |
logger = get_logger("gpu")
|
11 |
|
|
|
35 |
"""
|
36 |
if experimental:
|
37 |
# For Apple M1/M2 chips with Metal Performance Shaders
|
38 |
+
logger.warn("experimantal: found apple GPU, using MPS.")
|
39 |
device = torch.device("mps")
|
40 |
else:
|
41 |
+
logger.info("found Apple GPU, but use CPU.")
|
42 |
device = torch.device("cpu")
|
43 |
else:
|
44 |
+
logger.warning("no GPU found, use CPU instead")
|
45 |
device = torch.device("cpu")
|
46 |
|
47 |
return device
|
web/app.py
CHANGED
@@ -10,7 +10,7 @@ from tool.logger import get_logger
|
|
10 |
from tool.func import *
|
11 |
from tool.np import *
|
12 |
from tool.gpu import select_device
|
13 |
-
import
|
14 |
import argparse
|
15 |
import torch._dynamo
|
16 |
|
@@ -21,7 +21,7 @@ torch._dynamo.config.suppress_errors = True
|
|
21 |
logger = get_logger("app")
|
22 |
|
23 |
# Initialize and load the model:
|
24 |
-
chat =
|
25 |
|
26 |
|
27 |
def init_chat(args):
|
@@ -38,7 +38,7 @@ def init_chat(args):
|
|
38 |
|
39 |
logger.info("loading ChatTTS device :" + str(device))
|
40 |
|
41 |
-
if chat.
|
42 |
print("Models loaded successfully.")
|
43 |
logger.info("Models loaded successfully.")
|
44 |
else:
|
@@ -203,7 +203,7 @@ def get_chat_infer_audio(chat_txt,
|
|
203 |
spk_emb_text):
|
204 |
logger.info("========开始生成音频文件=====")
|
205 |
#音频参数设置
|
206 |
-
params_infer_code =
|
207 |
spk_emb=spk_emb_text, # add sampled speaker
|
208 |
temperature=temperature_slider, # using custom temperature
|
209 |
top_P=top_p_slider, # top P decode
|
@@ -227,7 +227,7 @@ def get_chat_infer_text(text,seed,refine_text_checkBox):
|
|
227 |
logger.info("========文本内容无需优化=====")
|
228 |
return text
|
229 |
|
230 |
-
params_refine_text =
|
231 |
prompt='[oral_2][laugh_0][break_6]',
|
232 |
)
|
233 |
|
|
|
10 |
from tool.func import *
|
11 |
from tool.np import *
|
12 |
from tool.gpu import select_device
|
13 |
+
import Chat2TTS
|
14 |
import argparse
|
15 |
import torch._dynamo
|
16 |
|
|
|
21 |
logger = get_logger("app")
|
22 |
|
23 |
# Initialize and load the model:
|
24 |
+
chat = Chat2TTS.Chat()
|
25 |
|
26 |
|
27 |
def init_chat(args):
|
|
|
38 |
|
39 |
logger.info("loading ChatTTS device :" + str(device))
|
40 |
|
41 |
+
if chat.load_models(source=source, local_path="D:\\chenjgspace\\ai-model\\chattts"):
|
42 |
print("Models loaded successfully.")
|
43 |
logger.info("Models loaded successfully.")
|
44 |
else:
|
|
|
203 |
spk_emb_text):
|
204 |
logger.info("========开始生成音频文件=====")
|
205 |
#音频参数设置
|
206 |
+
params_infer_code = Chat2TTS.Chat.InferCodeParams(
|
207 |
spk_emb=spk_emb_text, # add sampled speaker
|
208 |
temperature=temperature_slider, # using custom temperature
|
209 |
top_P=top_p_slider, # top P decode
|
|
|
227 |
logger.info("========文本内容无需优化=====")
|
228 |
return text
|
229 |
|
230 |
+
params_refine_text = Chat2TTS.Chat.RefineTextParams(
|
231 |
prompt='[oral_2][laugh_0][break_6]',
|
232 |
)
|
233 |
|
web/app_sc.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
import spaces
|
3 |
+
|
4 |
+
if sys.platform == "darwin":
|
5 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
6 |
+
now_dir = os.getcwd()
|
7 |
+
sys.path.append(now_dir)
|
8 |
+
|
9 |
+
from tool.logger import get_logger
|
10 |
+
from tool.func import *
|
11 |
+
from tool.np import *
|
12 |
+
from tool.gpu import select_device
|
13 |
+
from tool.ctx import TorchSeedContext
|
14 |
+
import ChatTTS
|
15 |
+
import argparse
|
16 |
+
import torch._dynamo
|
17 |
+
|
18 |
+
torch._dynamo.config.suppress_errors = True
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
logger = get_logger("app")
|
23 |
+
|
24 |
+
# Initialize and load the model:
|
25 |
+
chat = ChatTTS.Chat()
|
26 |
+
|
27 |
+
|
28 |
+
def init_chat(args):
|
29 |
+
global chat
|
30 |
+
source = "custom"
|
31 |
+
# 获取启动模式
|
32 |
+
MODEL = os.getenv('MODEL')
|
33 |
+
logger.info("loading ChatTTS model..., start MODEL:" + str(MODEL))
|
34 |
+
# huggingface 部署模式下,模型则直接使用hf的模型数据
|
35 |
+
if MODEL == "HF":
|
36 |
+
source = "huggingface"
|
37 |
+
|
38 |
+
device=select_device()
|
39 |
+
|
40 |
+
logger.info("loading ChatTTS device :" + str(device))
|
41 |
+
|
42 |
+
if chat.load(source=source, custom_path="D:\\chenjgspace\\ai-model\\chattts",device=device):
|
43 |
+
print("Models loaded successfully.")
|
44 |
+
logger.info("Models loaded successfully.")
|
45 |
+
else:
|
46 |
+
logger.error("=========Models load failed.")
|
47 |
+
sys.exit(1)
|
48 |
+
|
49 |
+
|
50 |
+
def main(args):
|
51 |
+
with gr.Blocks() as demo:
|
52 |
+
gr.Markdown("# ChatTTS demo")
|
53 |
+
with gr.Row():
|
54 |
+
with gr.Column(scale=1):
|
55 |
+
text_input = gr.Textbox(
|
56 |
+
label="转换内容",
|
57 |
+
lines=4,
|
58 |
+
max_lines=4,
|
59 |
+
placeholder="Please Input Text...",
|
60 |
+
value="柔柔的,浓浓的,痴痴的风,牵引起心底灵动的思潮;情愫悠悠,思情绵绵,风里默坐,红尘中的浅醉,诗词中的优柔,任那自在飞花轻似梦的情怀,裁一束霓衣,织就清浅淡薄的安寂。",
|
61 |
+
interactive=True,
|
62 |
+
)
|
63 |
+
with gr.Row():
|
64 |
+
refine_text_checkBox = gr.Checkbox(
|
65 |
+
label="是否优化文本,如是则先对文本内容做优化分词",
|
66 |
+
interactive=True,
|
67 |
+
value=True
|
68 |
+
)
|
69 |
+
temperature_slider = gr.Slider(
|
70 |
+
minimum=0.00001,
|
71 |
+
maximum=1.0,
|
72 |
+
step=0.00001,
|
73 |
+
value=0.3,
|
74 |
+
interactive=True,
|
75 |
+
label="模型 Temperature 参数设置"
|
76 |
+
)
|
77 |
+
top_p_slider = gr.Slider(
|
78 |
+
minimum=0.1,
|
79 |
+
maximum=0.9,
|
80 |
+
step=0.05,
|
81 |
+
value=0.7,
|
82 |
+
label="模型 top_P 参数设置",
|
83 |
+
interactive=True,
|
84 |
+
)
|
85 |
+
top_k_slider = gr.Slider(
|
86 |
+
minimum=1,
|
87 |
+
maximum=20,
|
88 |
+
step=1,
|
89 |
+
value=20,
|
90 |
+
label="模型 top_K 参数设置",
|
91 |
+
interactive=True,
|
92 |
+
)
|
93 |
+
with gr.Row():
|
94 |
+
voice_selection = gr.Dropdown(
|
95 |
+
label="Timbre",
|
96 |
+
choices=voices.keys(),
|
97 |
+
value="旁白",
|
98 |
+
interactive=True,
|
99 |
+
show_label=True
|
100 |
+
)
|
101 |
+
audio_seed_input = gr.Number(
|
102 |
+
value=2,
|
103 |
+
label="音色种子",
|
104 |
+
interactive=True,
|
105 |
+
minimum=seed_min,
|
106 |
+
maximum=seed_max,
|
107 |
+
)
|
108 |
+
generate_audio_seed = gr.Button("随机生成音色种子", interactive=True)
|
109 |
+
text_seed_input = gr.Number(
|
110 |
+
value=42,
|
111 |
+
label="文本种子",
|
112 |
+
interactive=True,
|
113 |
+
minimum=seed_min,
|
114 |
+
maximum=seed_max,
|
115 |
+
)
|
116 |
+
generate_text_seed = gr.Button("随机生成文本种子", interactive=True)
|
117 |
+
|
118 |
+
with gr.Row():
|
119 |
+
spk_emb_text = gr.Textbox(
|
120 |
+
label="Speaker Embedding",
|
121 |
+
max_lines=3,
|
122 |
+
show_copy_button=True,
|
123 |
+
interactive=False,
|
124 |
+
scale=2,
|
125 |
+
|
126 |
+
)
|
127 |
+
reload_chat_button = gr.Button("Reload", scale=1, interactive=True)
|
128 |
+
|
129 |
+
with gr.Row():
|
130 |
+
generate_button = gr.Button("生成音频文件", scale=1, interactive=True)
|
131 |
+
|
132 |
+
with gr.Row():
|
133 |
+
text_output = gr.Textbox(
|
134 |
+
label="输出文本",
|
135 |
+
interactive=False,
|
136 |
+
show_copy_button=True,
|
137 |
+
)
|
138 |
+
|
139 |
+
audio_output = gr.Audio(
|
140 |
+
label="输出音频",
|
141 |
+
value=None,
|
142 |
+
format="wav",
|
143 |
+
autoplay=False,
|
144 |
+
streaming=False,
|
145 |
+
interactive=False,
|
146 |
+
show_label=True,
|
147 |
+
waveform_options=gr.WaveformOptions(
|
148 |
+
sample_rate=24000,
|
149 |
+
),
|
150 |
+
)
|
151 |
+
# 针对页面元素新增 监听事件
|
152 |
+
voice_selection.change(fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input)
|
153 |
+
|
154 |
+
audio_seed_input.change(fn=on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text)
|
155 |
+
|
156 |
+
generate_audio_seed.click(fn=generate_seed, outputs=audio_seed_input)
|
157 |
+
|
158 |
+
generate_text_seed.click(fn=generate_seed,outputs=text_seed_input)
|
159 |
+
|
160 |
+
# reload_chat_button.click()
|
161 |
+
|
162 |
+
generate_button.click(fn=get_chat_infer_text,
|
163 |
+
inputs=[text_input,
|
164 |
+
text_seed_input,
|
165 |
+
refine_text_checkBox
|
166 |
+
],
|
167 |
+
outputs=[text_output]
|
168 |
+
).then(fn=get_chat_infer_audio,
|
169 |
+
inputs=[text_output,
|
170 |
+
temperature_slider,
|
171 |
+
top_p_slider,
|
172 |
+
top_k_slider,
|
173 |
+
audio_seed_input,
|
174 |
+
spk_emb_text
|
175 |
+
],
|
176 |
+
outputs=[audio_output])
|
177 |
+
# 初始化 spk_emb_text 数值
|
178 |
+
spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
|
179 |
+
logger.info("元素初始化完成,启动gradio服务=======")
|
180 |
+
|
181 |
+
# 运行gradio服务
|
182 |
+
demo.launch()
|
183 |
+
|
184 |
+
|
185 |
+
'''
|
186 |
+
top_K: "top_K"(K个最高得分)是指在所有可能的生成结果中,模型会选取前K个得分最高的结果。这个设置常常用于基于概率的生成任务,例如语言模型中的词或句子生成。当你设置top_K为K时,你要求模型选择得分最高的K个选项,这样输出通常会有一定的多样性,但仍然是基于模型的前K个预测。
|
187 |
+
|
188 |
+
top_P: "top_P"(概率阈值)则是一个连续的值,而不是离散的整数。它代表的是一个概率阈值,模型会生成所有得分高于该阈值的概率的项目。换句话说,top_P会生成那些概率大于等于给定值的所有生成结果。这个设置更为灵活,可以根据实际需求调整生成内容的不确定性,高频选项被生成的概率较高,而低频可能性则可能根据阈值随机出现。
|
189 |
+
|
190 |
+
在实际应用中,选择top_K还是top_P取决于具体任务需求,如是否希望生成内容有一定程度的多样化(top_K),还是希望生成的内容更接近于最可能发生的选项(top_P)。较高的top_P可能会引入更多的随机性和创新,而较低的top_K则更倾向于保守的选择。
|
191 |
+
|
192 |
+
spk_embedding(Speaker Embedding): 这个术语一般用于语音识别或者多说话者模型中。"spk_embedding"指的是每个说话人的身份或特征向量,或者说是用户标识的嵌入表示。在对话系统中,它能帮助模型区分不同的说话者,比如在多轮对话中区分是同一个用户的不同回复,或者是不同用户的交互。这个嵌入可能包含了说话人的个性、语调、口音等信息,有助于提高对话的连贯性和自然性。
|
193 |
+
|
194 |
+
temperature: 通常在语言模型的生成(如基于概率的softmax)中使用。"temperature"是一个正数,用于控制生成内容的随机性和多样性。当温度较低(如0.1)时,模型倾向于生成最可能的结果,文字更保守,少有创新;当温度较高(如1或更高)时,模型将更倾向于产生多样化的内容,但可能性较大的选项将被稀释。因此,temperature调整是一个常用的平衡方法,使得生成更具创造性或是更符合预期。
|
195 |
+
|
196 |
+
简而言之,"spk_embedding"关注的是对话参与者的身份特征,而"temperature"是用于调整生成文本不确定性的一个超参数。
|
197 |
+
'''
|
198 |
+
#@spaces.GPU 使用chattts三方包 gpu模式没有跑通
|
199 |
+
def get_chat_infer_audio(chat_txt,
|
200 |
+
temperature_slider,
|
201 |
+
top_p_slider,
|
202 |
+
top_k_slider,
|
203 |
+
audio_seed_input,
|
204 |
+
spk_emb_text):
|
205 |
+
logger.info("========开始生成音频文件=====")
|
206 |
+
#音频参数设置
|
207 |
+
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
208 |
+
spk_emb=spk_emb_text, # add sampled speaker
|
209 |
+
temperature=temperature_slider, # using custom temperature
|
210 |
+
top_P=top_p_slider, # top P decode
|
211 |
+
top_K=top_k_slider, # top K decode
|
212 |
+
)
|
213 |
+
|
214 |
+
with TorchSeedContext(audio_seed_input):
|
215 |
+
wav = chat.infer(
|
216 |
+
text=chat_txt,
|
217 |
+
skip_refine_text=True, #跳过文本优化
|
218 |
+
params_infer_code=params_infer_code,
|
219 |
+
)
|
220 |
+
yield 24000, float_to_int16(wav[0]).T
|
221 |
+
|
222 |
+
#@spaces.GPU 使用chattts三方包 gpu模式没有跑通
|
223 |
+
def get_chat_infer_text(text,seed,refine_text_checkBox):
|
224 |
+
|
225 |
+
logger.info("========开始优化文本内容2=====")
|
226 |
+
global chat
|
227 |
+
if not refine_text_checkBox:
|
228 |
+
logger.info("========文本内容无需优化=====")
|
229 |
+
return text
|
230 |
+
|
231 |
+
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
232 |
+
prompt='[oral_2][laugh_0][break_6]',
|
233 |
+
)
|
234 |
+
|
235 |
+
with TorchSeedContext(seed):
|
236 |
+
chat_text = chat.infer(
|
237 |
+
text=text,
|
238 |
+
skip_refine_text=False,
|
239 |
+
refine_text_only=True, #仅返回优化后文本内容
|
240 |
+
params_refine_text=params_refine_text,
|
241 |
+
)
|
242 |
+
|
243 |
+
return chat_text[0] if isinstance(chat_text, list) else chat_text
|
244 |
+
|
245 |
+
@spaces.GPU
|
246 |
+
def on_audio_seed_change(audio_seed_input):
|
247 |
+
global chat
|
248 |
+
with TorchSeedContext(audio_seed_input):
|
249 |
+
rand_spk = chat.sample_random_speaker()
|
250 |
+
return rand_spk
|
251 |
+
|
252 |
+
|
253 |
+
if __name__ == "__main__":
|
254 |
+
parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
|
255 |
+
parser.add_argument(
|
256 |
+
"--server_name", type=str, default="0.0.0.0", help="server name"
|
257 |
+
)
|
258 |
+
parser.add_argument("--server_port", type=int, default=7860, help="server port")
|
259 |
+
parser.add_argument(
|
260 |
+
"--custom_path", type=str, default="D:\\chenjgspace\\ai-model\\chattts", help="custom model path"
|
261 |
+
)
|
262 |
+
parser.add_argument(
|
263 |
+
"--coef", type=str, default=None, help="custom dvae coefficient"
|
264 |
+
)
|
265 |
+
args = parser.parse_args()
|
266 |
+
init_chat(args)
|
267 |
+
main(args)
|