chenjgtea commited on
Commit
0bc3ceb
·
1 Parent(s): c23cee5

新增gpu模式下chattts代码

Browse files
Chat2TTS/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .core import Chat
Chat2TTS/core.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import logging
4
+ from omegaconf import OmegaConf
5
+
6
+ import torch
7
+ from vocos import Vocos
8
+ from .model.dvae import DVAE
9
+ from .model.gpt import GPT_warpper
10
+ from .utils.gpu_utils import select_device
11
+ from .utils.io_utils import get_latest_modified_file
12
+ from .infer.api import refine_text, infer_code
13
+ from dataclasses import dataclass
14
+ from typing import Literal, Optional, List, Tuple, Dict
15
+
16
+ from huggingface_hub import snapshot_download
17
+
18
+ logging.basicConfig(level = logging.INFO)
19
+
20
+
21
+ class Chat:
22
+ def __init__(self, ):
23
+ self.pretrain_models = {}
24
+ self.logger = logging.getLogger(__name__)
25
+
26
+ def check_model(self, level = logging.INFO, use_decoder = False):
27
+ not_finish = False
28
+ check_list = ['vocos', 'gpt', 'tokenizer']
29
+
30
+ if use_decoder:
31
+ check_list.append('decoder')
32
+ else:
33
+ check_list.append('dvae')
34
+
35
+ for module in check_list:
36
+ if module not in self.pretrain_models:
37
+ self.logger.log(logging.WARNING, f'{module} not initialized.')
38
+ not_finish = True
39
+
40
+ if not not_finish:
41
+ self.logger.log(level, f'All initialized.')
42
+
43
+ return not not_finish
44
+
45
+ def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>'):
46
+ if source == 'huggingface':
47
+ hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
48
+ try:
49
+ download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
50
+ except:
51
+ download_path = None
52
+ if download_path is None or force_redownload:
53
+ self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
54
+ download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
55
+ else:
56
+ self.logger.log(logging.INFO, f'Load from cache: {download_path}')
57
+ self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
58
+ elif source == 'local':
59
+ self.logger.log(logging.INFO, f'Load from local: {local_path}')
60
+ self._load(**{k: os.path.join(local_path, v) for k, v in OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()})
61
+
62
+ def _load(
63
+ self,
64
+ vocos_config_path: str = None,
65
+ vocos_ckpt_path: str = None,
66
+ dvae_config_path: str = None,
67
+ dvae_ckpt_path: str = None,
68
+ gpt_config_path: str = None,
69
+ gpt_ckpt_path: str = None,
70
+ decoder_config_path: str = None,
71
+ decoder_ckpt_path: str = None,
72
+ tokenizer_path: str = None,
73
+ device: str = None
74
+ ):
75
+ if not device:
76
+ device = select_device(4096)
77
+ self.logger.log(logging.INFO, f'use {device}')
78
+
79
+ if vocos_config_path:
80
+ vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
81
+ assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
82
+ vocos.load_state_dict(torch.load(vocos_ckpt_path))
83
+ self.pretrain_models['vocos'] = vocos
84
+ self.logger.log(logging.INFO, 'vocos loaded.')
85
+
86
+ if dvae_config_path:
87
+ cfg = OmegaConf.load(dvae_config_path)
88
+ dvae = DVAE(**cfg).to(device).eval()
89
+ assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
90
+ dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
91
+ self.pretrain_models['dvae'] = dvae
92
+ self.logger.log(logging.INFO, 'dvae loaded.')
93
+
94
+ if gpt_config_path:
95
+ cfg = OmegaConf.load(gpt_config_path)
96
+ gpt = GPT_warpper(**cfg).to(device).eval()
97
+ assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
98
+ gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
99
+ self.pretrain_models['gpt'] = gpt
100
+ self.logger.log(logging.INFO, 'gpt loaded.')
101
+
102
+ if decoder_config_path:
103
+ cfg = OmegaConf.load(decoder_config_path)
104
+ decoder = DVAE(**cfg).to(device).eval()
105
+ assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
106
+ decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
107
+ self.pretrain_models['decoder'] = decoder
108
+ self.logger.log(logging.INFO, 'decoder loaded.')
109
+
110
+ if tokenizer_path:
111
+ tokenizer = torch.load(tokenizer_path, map_location='cpu')
112
+ tokenizer.padding_side = 'left'
113
+ self.pretrain_models['tokenizer'] = tokenizer
114
+ self.logger.log(logging.INFO, 'tokenizer loaded.')
115
+
116
+ self.check_model()
117
+
118
+ @dataclass(repr=False, eq=False)
119
+ class RefineTextParams:
120
+ prompt: str = ""
121
+ top_P: float = 0.7
122
+ top_K: int = 20
123
+ temperature: float = 0.7
124
+ repetition_penalty: float = 1.0
125
+ max_new_token: int = 384
126
+ min_new_token: int = 0
127
+ show_tqdm: bool = True
128
+ ensure_non_empty: bool = True
129
+
130
+ @dataclass(repr=False, eq=False)
131
+ class InferCodeParams(RefineTextParams):
132
+ prompt: str = "[speed_5]"
133
+ spk_emb: Optional[str] = None
134
+ temperature: float = 0.3
135
+ repetition_penalty: float = 1.05
136
+ max_new_token: int = 2048
137
+
138
+ def infer(
139
+ self,
140
+ text,
141
+ skip_refine_text=False,
142
+ refine_text_only=False,
143
+ params_refine_text={},
144
+ params_infer_code={},
145
+ use_decoder=False
146
+ ):
147
+
148
+ assert self.check_model(use_decoder=use_decoder)
149
+
150
+ if not skip_refine_text:
151
+ text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
152
+ text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
153
+ text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
154
+ if refine_text_only:
155
+ return text
156
+
157
+ text = [params_infer_code.get('prompt', '') + i for i in text]
158
+ params_infer_code.pop('prompt', '')
159
+ result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
160
+
161
+ if use_decoder:
162
+ mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
163
+ else:
164
+ mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
165
+
166
+ wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
167
+
168
+ return wav
169
+
170
+
171
+
Chat2TTS/experimental/llm.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from openai import OpenAI
3
+
4
+ prompt_dict = {
5
+ 'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"},
6
+ {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
7
+ {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
8
+ 'deepseek': [
9
+ {"role": "system", "content": "You are a helpful assistant"},
10
+ {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
11
+ {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
12
+ 'deepseek_TN': [
13
+ {"role": "system", "content": "You are a helpful assistant"},
14
+ {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"},
15
+ {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"},
16
+ {"role": "user", "content": "We paid $123 for this desk."},
17
+ {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."},
18
+ {"role": "user", "content": "详询请拨打010-724654"},
19
+ {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"},
20
+ {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"},
21
+ {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"},
22
+ ],
23
+ }
24
+
25
+ class llm_api:
26
+ def __init__(self, api_key, base_url, model):
27
+ self.client = OpenAI(
28
+ api_key = api_key,
29
+ base_url = base_url,
30
+ )
31
+ self.model = model
32
+ def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs):
33
+
34
+ completion = self.client.chat.completions.create(
35
+ model = self.model,
36
+ messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},],
37
+ temperature = temperature,
38
+ **kwargs
39
+ )
40
+ return completion.choices[0].message.content
Chat2TTS/infer/api.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
5
+ from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
6
+
7
+ def infer_code(
8
+ models,
9
+ text,
10
+ spk_emb = None,
11
+ top_P = 0.7,
12
+ top_K = 20,
13
+ temperature = 0.3,
14
+ repetition_penalty = 1.05,
15
+ max_new_token = 2048,
16
+ **kwargs
17
+ ):
18
+
19
+ device = next(models['gpt'].parameters()).device
20
+
21
+ if not isinstance(text, list):
22
+ text = [text]
23
+
24
+ if not isinstance(temperature, list):
25
+ temperature = [temperature] * models['gpt'].num_vq
26
+
27
+ if spk_emb is not None:
28
+ text = [f'[Stts][spk_emb]{i}[uv_break][Ptts]' for i in text]
29
+ else:
30
+ text = [f'[Stts][empty_spk]{i}[uv_break][Ptts]' for i in text]
31
+
32
+ text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
33
+ input_ids = text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq)
34
+ text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
35
+
36
+ inputs = {
37
+ 'input_ids': input_ids,
38
+ 'text_mask': text_mask,
39
+ 'attention_mask': text_token['attention_mask'],
40
+ }
41
+
42
+ emb = models['gpt'].get_emb(**inputs)
43
+ if spk_emb is not None:
44
+ emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \
45
+ F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12)
46
+
47
+ num_code = models['gpt'].emb_code[0].num_embeddings - 1
48
+
49
+ LogitsWarpers = []
50
+ if top_P is not None:
51
+ LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
52
+ if top_K is not None:
53
+ LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
54
+
55
+ LogitsProcessors = []
56
+ if repetition_penalty is not None and repetition_penalty != 1:
57
+ LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\
58
+ repetition_penalty, num_code, 16))
59
+
60
+ result = models['gpt'].generate(
61
+ emb, inputs['input_ids'],
62
+ temperature = torch.tensor(temperature, device=device),
63
+ attention_mask = inputs['attention_mask'],
64
+ LogitsWarpers = LogitsWarpers,
65
+ LogitsProcessors = LogitsProcessors,
66
+ eos_token = num_code,
67
+ max_new_token = max_new_token,
68
+ infer_text = False,
69
+ **kwargs
70
+ )
71
+
72
+ return result
73
+
74
+
75
+ def refine_text(
76
+ models,
77
+ text,
78
+ top_P = 0.7,
79
+ top_K = 20,
80
+ temperature = 0.7,
81
+ repetition_penalty = 1.0,
82
+ max_new_token = 384,
83
+ prompt = '',
84
+ **kwargs
85
+ ):
86
+
87
+ device = next(models['gpt'].parameters()).device
88
+
89
+ if not isinstance(text, list):
90
+ text = [text]
91
+
92
+ assert len(text), 'text should not be empty'
93
+
94
+ text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
95
+ text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
96
+ text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
97
+
98
+ inputs = {
99
+ 'input_ids': text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq),
100
+ 'text_mask': text_mask,
101
+ 'attention_mask': text_token['attention_mask'],
102
+ }
103
+
104
+ LogitsWarpers = []
105
+ if top_P is not None:
106
+ LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
107
+ if top_K is not None:
108
+ LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
109
+
110
+ LogitsProcessors = []
111
+ if repetition_penalty is not None and repetition_penalty != 1:
112
+ LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))
113
+
114
+ result = models['gpt'].generate(
115
+ models['gpt'].get_emb(**inputs),
116
+ inputs['input_ids'],
117
+ temperature = torch.tensor([temperature,], device=device),
118
+ attention_mask = inputs['attention_mask'],
119
+ LogitsWarpers = LogitsWarpers,
120
+ LogitsProcessors = LogitsProcessors,
121
+ eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
122
+ max_new_token = max_new_token,
123
+ infer_text = True,
124
+ **kwargs
125
+ )
126
+ return result
Chat2TTS/model/dvae.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from einops import rearrange
3
+ from vector_quantize_pytorch import GroupedResidualFSQ
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ class ConvNeXtBlock(nn.Module):
10
+ def __init__(
11
+ self,
12
+ dim: int,
13
+ intermediate_dim: int,
14
+ kernel, dilation,
15
+ layer_scale_init_value: float = 1e-6,
16
+ ):
17
+ # ConvNeXt Block copied from Vocos.
18
+ super().__init__()
19
+ self.dwconv = nn.Conv1d(dim, dim,
20
+ kernel_size=kernel, padding=dilation*(kernel//2),
21
+ dilation=dilation, groups=dim
22
+ ) # depthwise conv
23
+
24
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
25
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
26
+ self.act = nn.GELU()
27
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
28
+ self.gamma = (
29
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
30
+ if layer_scale_init_value > 0
31
+ else None
32
+ )
33
+
34
+ def forward(self, x: torch.Tensor, cond = None) -> torch.Tensor:
35
+ residual = x
36
+ x = self.dwconv(x)
37
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
38
+ x = self.norm(x)
39
+ x = self.pwconv1(x)
40
+ x = self.act(x)
41
+ x = self.pwconv2(x)
42
+ if self.gamma is not None:
43
+ x = self.gamma * x
44
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
45
+
46
+ x = residual + x
47
+ return x
48
+
49
+
50
+
51
+ class GFSQ(nn.Module):
52
+
53
+ def __init__(self,
54
+ dim, levels, G, R, eps=1e-5, transpose = True
55
+ ):
56
+ super(GFSQ, self).__init__()
57
+ self.quantizer = GroupedResidualFSQ(
58
+ dim=dim,
59
+ levels=levels,
60
+ num_quantizers=R,
61
+ groups=G,
62
+ )
63
+ self.n_ind = math.prod(levels)
64
+ self.eps = eps
65
+ self.transpose = transpose
66
+ self.G = G
67
+ self.R = R
68
+
69
+ def _embed(self, x):
70
+ if self.transpose:
71
+ x = x.transpose(1,2)
72
+ x = rearrange(
73
+ x, "b t (g r) -> g b t r", g = self.G, r = self.R,
74
+ )
75
+ feat = self.quantizer.get_output_from_indices(x)
76
+ return feat.transpose(1,2) if self.transpose else feat
77
+
78
+ def forward(self, x,):
79
+ if self.transpose:
80
+ x = x.transpose(1,2)
81
+ feat, ind = self.quantizer(x)
82
+ ind = rearrange(
83
+ ind, "g b t r ->b t (g r)",
84
+ )
85
+ embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
86
+ e_mean = torch.mean(embed_onehot, dim=[0,1])
87
+ e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
88
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
89
+
90
+ return (
91
+ torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
92
+ feat.transpose(1,2) if self.transpose else feat,
93
+ perplexity,
94
+ None,
95
+ ind.transpose(1,2) if self.transpose else ind,
96
+ )
97
+
98
+ class DVAEDecoder(nn.Module):
99
+ def __init__(self, idim, odim,
100
+ n_layer = 12, bn_dim = 64, hidden = 256,
101
+ kernel = 7, dilation = 2, up = False
102
+ ):
103
+ super().__init__()
104
+ self.up = up
105
+ self.conv_in = nn.Sequential(
106
+ nn.Conv1d(idim, bn_dim, 3, 1, 1), nn.GELU(),
107
+ nn.Conv1d(bn_dim, hidden, 3, 1, 1)
108
+ )
109
+ self.decoder_block = nn.ModuleList([
110
+ ConvNeXtBlock(hidden, hidden* 4, kernel, dilation,)
111
+ for _ in range(n_layer)])
112
+ self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
113
+
114
+ def forward(self, input, conditioning=None):
115
+ # B, T, C
116
+ x = input.transpose(1, 2)
117
+ x = self.conv_in(x)
118
+ for f in self.decoder_block:
119
+ x = f(x, conditioning)
120
+
121
+ x = self.conv_out(x)
122
+ return x.transpose(1, 2)
123
+
124
+
125
+ class DVAE(nn.Module):
126
+ def __init__(
127
+ self, decoder_config, vq_config, dim=512
128
+ ):
129
+ super().__init__()
130
+ self.register_buffer('coef', torch.randn(1, 100, 1))
131
+
132
+ self.decoder = DVAEDecoder(**decoder_config)
133
+ self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
134
+ if vq_config is not None:
135
+ self.vq_layer = GFSQ(**vq_config)
136
+ else:
137
+ self.vq_layer = None
138
+
139
+ def forward(self, inp):
140
+
141
+ if self.vq_layer is not None:
142
+ vq_feats = self.vq_layer._embed(inp)
143
+ else:
144
+ vq_feats = inp.detach().clone()
145
+
146
+ temp = torch.chunk(vq_feats, 2, dim=1) # flatten trick :)
147
+ temp = torch.stack(temp, -1)
148
+ vq_feats = temp.reshape(*temp.shape[:2], -1)
149
+
150
+ vq_feats = vq_feats.transpose(1, 2)
151
+ dec_out = self.decoder(input=vq_feats)
152
+ dec_out = self.out_conv(dec_out.transpose(1, 2))
153
+ mel = dec_out * self.coef
154
+
155
+ return mel
Chat2TTS/model/gpt.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
3
+
4
+ import logging
5
+ from tqdm import tqdm
6
+ from einops import rearrange
7
+ from transformers.cache_utils import Cache
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.nn.utils.parametrize as P
13
+ from torch.nn.utils.parametrizations import weight_norm
14
+ from transformers import LlamaModel, LlamaConfig
15
+
16
+
17
+ class LlamaMLP(nn.Module):
18
+ def __init__(self, hidden_size, intermediate_size):
19
+ super().__init__()
20
+ self.hidden_size = hidden_size
21
+ self.intermediate_size = intermediate_size
22
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
23
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
24
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
25
+ self.act_fn = F.silu
26
+
27
+ def forward(self, x):
28
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
29
+ return down_proj
30
+
31
+
32
+ class GPT_warpper(nn.Module):
33
+ def __init__(
34
+ self,
35
+ gpt_config,
36
+ num_audio_tokens,
37
+ num_text_tokens,
38
+ num_vq=4,
39
+ **kwargs,
40
+ ):
41
+ super().__init__()
42
+
43
+ self.logger = logging.getLogger(__name__)
44
+ self.gpt = self.build_model(gpt_config)
45
+ self.model_dim = self.gpt.config.hidden_size
46
+
47
+ self.num_vq = num_vq
48
+ self.emb_code = nn.ModuleList([nn.Embedding(num_audio_tokens, self.model_dim) for i in range(self.num_vq)])
49
+ self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
50
+ self.head_text = weight_norm(nn.Linear(self.model_dim, num_text_tokens, bias=False), name='weight')
51
+ self.head_code = nn.ModuleList([weight_norm(nn.Linear(self.model_dim, num_audio_tokens, bias=False), name='weight') for i in range(self.num_vq)])
52
+
53
+ def build_model(self, config):
54
+
55
+ configuration = LlamaConfig(**config)
56
+ model = LlamaModel(configuration)
57
+ del model.embed_tokens
58
+
59
+ return model
60
+
61
+ def get_emb(self, input_ids, text_mask, **kwargs):
62
+
63
+ emb_text = self.emb_text(input_ids[text_mask][:, 0])
64
+
65
+ emb_code = [self.emb_code[i](input_ids[~text_mask][:, i]) for i in range(self.num_vq)]
66
+ emb_code = torch.stack(emb_code, 2).sum(2)
67
+
68
+ emb = torch.zeros((input_ids.shape[:-1])+(emb_text.shape[-1],), device=emb_text.device, dtype=emb_text.dtype)
69
+ emb[text_mask] = emb_text
70
+ emb[~text_mask] = emb_code.to(emb.dtype)
71
+
72
+ return emb
73
+
74
+ def prepare_inputs_for_generation(
75
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
76
+ ):
77
+ # With static cache, the `past_key_values` is None
78
+ # TODO joao: standardize interface for the different Cache classes and remove of this if
79
+ has_static_cache = False
80
+ if past_key_values is None:
81
+ past_key_values = getattr(self.gpt.layers[0].self_attn, "past_key_value", None)
82
+ has_static_cache = past_key_values is not None
83
+
84
+ past_length = 0
85
+ if past_key_values is not None:
86
+ if isinstance(past_key_values, Cache):
87
+ past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
88
+ max_cache_length = (
89
+ torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
90
+ if past_key_values.get_max_length() is not None
91
+ else None
92
+ )
93
+ cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
94
+ # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
95
+ else:
96
+ cache_length = past_length = past_key_values[0][0].shape[2]
97
+ max_cache_length = None
98
+
99
+ # Keep only the unprocessed tokens:
100
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
101
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
102
+ # input)
103
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
104
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
105
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
106
+ # input_ids based on the past_length.
107
+ elif past_length < input_ids.shape[1]:
108
+ input_ids = input_ids[:, past_length:]
109
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
110
+
111
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
112
+ if (
113
+ max_cache_length is not None
114
+ and attention_mask is not None
115
+ and cache_length + input_ids.shape[1] > max_cache_length
116
+ ):
117
+ attention_mask = attention_mask[:, -max_cache_length:]
118
+
119
+ position_ids = kwargs.get("position_ids", None)
120
+ if attention_mask is not None and position_ids is None:
121
+ # create position_ids on the fly for batch generation
122
+ position_ids = attention_mask.long().cumsum(-1) - 1
123
+ position_ids.masked_fill_(attention_mask == 0, 1)
124
+ if past_key_values:
125
+ position_ids = position_ids[:, -input_ids.shape[1] :]
126
+
127
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
128
+ if inputs_embeds is not None and past_key_values is None:
129
+ model_inputs = {"inputs_embeds": inputs_embeds}
130
+ else:
131
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
132
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
133
+ # TODO: use `next_tokens` directly instead.
134
+ model_inputs = {"input_ids": input_ids.contiguous()}
135
+
136
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
137
+ if cache_position is None:
138
+ cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
139
+ else:
140
+ cache_position = cache_position[-input_length:]
141
+
142
+ if has_static_cache:
143
+ past_key_values = None
144
+
145
+ model_inputs.update(
146
+ {
147
+ "position_ids": position_ids,
148
+ "cache_position": cache_position,
149
+ "past_key_values": past_key_values,
150
+ "use_cache": kwargs.get("use_cache"),
151
+ "attention_mask": attention_mask,
152
+ }
153
+ )
154
+ return model_inputs
155
+
156
+ def generate(
157
+ self,
158
+ emb,
159
+ inputs_ids,
160
+ temperature,
161
+ eos_token,
162
+ attention_mask = None,
163
+ max_new_token = 2048,
164
+ min_new_token = 0,
165
+ LogitsWarpers = [],
166
+ LogitsProcessors = [],
167
+ infer_text=False,
168
+ return_attn=False,
169
+ return_hidden=False,
170
+ ):
171
+
172
+ with torch.no_grad():
173
+
174
+ attentions = []
175
+ hiddens = []
176
+
177
+ start_idx, end_idx = inputs_ids.shape[1], torch.zeros(inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long)
178
+ finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
179
+
180
+ temperature = temperature[None].expand(inputs_ids.shape[0], -1)
181
+ temperature = rearrange(temperature, "b n -> (b n) 1")
182
+
183
+ attention_mask_cache = torch.ones((inputs_ids.shape[0], inputs_ids.shape[1]+max_new_token,), dtype=torch.bool, device=inputs_ids.device)
184
+ if attention_mask is not None:
185
+ attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask
186
+
187
+ for i in tqdm(range(max_new_token)):
188
+
189
+ model_input = self.prepare_inputs_for_generation(inputs_ids,
190
+ outputs.past_key_values if i!=0 else None,
191
+ attention_mask_cache[:, :inputs_ids.shape[1]], use_cache=True)
192
+
193
+ if i == 0:
194
+ model_input['inputs_embeds'] = emb
195
+ else:
196
+ if infer_text:
197
+ model_input['inputs_embeds'] = self.emb_text(model_input['input_ids'][:,:,0])
198
+ else:
199
+ code_emb = [self.emb_code[i](model_input['input_ids'][:,:,i]) for i in range(self.num_vq)]
200
+ model_input['inputs_embeds'] = torch.stack(code_emb, 3).sum(3)
201
+
202
+ model_input['input_ids'] = None
203
+ outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
204
+ attentions.append(outputs.attentions)
205
+ hidden_states = outputs[0] # 🐻
206
+ if return_hidden:
207
+ hiddens.append(hidden_states[:, -1])
208
+
209
+ with P.cached():
210
+ if infer_text:
211
+ logits = self.head_text(hidden_states)
212
+ else:
213
+ logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
214
+
215
+ logits = logits[:, -1].float()
216
+
217
+ if not infer_text:
218
+ logits = rearrange(logits, "b c n -> (b n) c")
219
+ logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
220
+ else:
221
+ logits_token = inputs_ids[:, start_idx:, 0]
222
+
223
+ logits = logits / temperature
224
+
225
+ for logitsProcessors in LogitsProcessors:
226
+ logits = logitsProcessors(logits_token, logits)
227
+
228
+ for logitsWarpers in LogitsWarpers:
229
+ logits = logitsWarpers(logits_token, logits)
230
+
231
+ if i < min_new_token:
232
+ logits[:, eos_token] = -torch.inf
233
+
234
+ scores = F.softmax(logits, dim=-1)
235
+
236
+ idx_next = torch.multinomial(scores, num_samples=1)
237
+
238
+ if not infer_text:
239
+ idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
240
+ finish = finish | (idx_next == eos_token).any(1)
241
+ inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
242
+ else:
243
+ finish = finish | (idx_next == eos_token).any(1)
244
+ inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq)], 1)
245
+
246
+ end_idx = end_idx + (~finish).int()
247
+
248
+ if finish.all():
249
+ break
250
+
251
+ inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
252
+ inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
253
+
254
+ if return_hidden:
255
+ hiddens = torch.stack(hiddens, 1)
256
+ hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
257
+
258
+ if not finish.all():
259
+ self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')
260
+
261
+ return {
262
+ 'ids': inputs_ids,
263
+ 'attentions': attentions,
264
+ 'hiddens':hiddens,
265
+ }
Chat2TTS/utils/gpu_utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import logging
4
+
5
+ def select_device(min_memory = 2048):
6
+ logger = logging.getLogger(__name__)
7
+ if torch.cuda.is_available():
8
+ available_gpus = []
9
+ for i in range(torch.cuda.device_count()):
10
+ props = torch.cuda.get_device_properties(i)
11
+ free_memory = props.total_memory - torch.cuda.memory_reserved(i)
12
+ available_gpus.append((i, free_memory))
13
+ selected_gpu, max_free_memory = max(available_gpus, key=lambda x: x[1])
14
+ device = torch.device(f'cuda:{selected_gpu}')
15
+ free_memory_mb = max_free_memory / (1024 * 1024)
16
+ if free_memory_mb < min_memory:
17
+ logger.log(logging.WARNING, f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left.')
18
+ device = torch.device('cpu')
19
+ else:
20
+ logger.log(logging.WARNING, f'No GPU found, use CPU instead')
21
+ device = torch.device('cpu')
22
+
23
+ return device
Chat2TTS/utils/infer_utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class CustomRepetitionPenaltyLogitsProcessorRepeat():
7
+
8
+ def __init__(self, penalty: float, max_input_ids, past_window):
9
+ if not isinstance(penalty, float) or not (penalty > 0):
10
+ raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
11
+
12
+ self.penalty = penalty
13
+ self.max_input_ids = max_input_ids
14
+ self.past_window = past_window
15
+
16
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
17
+
18
+ input_ids = input_ids[:, -self.past_window:]
19
+ freq = F.one_hot(input_ids, scores.size(1)).sum(1)
20
+ freq[self.max_input_ids:] = 0
21
+ alpha = self.penalty**freq
22
+ scores = torch.where(scores < 0, scores*alpha, scores/alpha)
23
+
24
+ return scores
25
+
26
+ class CustomRepetitionPenaltyLogitsProcessor():
27
+
28
+ def __init__(self, penalty: float, max_input_ids, past_window):
29
+ if not isinstance(penalty, float) or not (penalty > 0):
30
+ raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
31
+
32
+ self.penalty = penalty
33
+ self.max_input_ids = max_input_ids
34
+ self.past_window = past_window
35
+
36
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
37
+
38
+ input_ids = input_ids[:, -self.past_window:]
39
+ score = torch.gather(scores, 1, input_ids)
40
+ _score = score.detach().clone()
41
+ score = torch.where(score < 0, score * self.penalty, score / self.penalty)
42
+ score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
43
+ scores.scatter_(1, input_ids, score)
44
+
45
+ return scores
Chat2TTS/utils/io_utils.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import logging
4
+
5
+ def get_latest_modified_file(directory):
6
+ logger = logging.getLogger(__name__)
7
+
8
+ files = [os.path.join(directory, f) for f in os.listdir(directory)]
9
+ if not files:
10
+ logger.log(logging.WARNING, f'No files found in the directory: {directory}')
11
+ return None
12
+ latest_file = max(files, key=os.path.getmtime)
13
+
14
+ return latest_file
tool/gpu.py CHANGED
@@ -1,12 +1,11 @@
1
  import torch
2
  import os, sys
3
- import spaces
4
 
5
  if sys.platform == "darwin":
6
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
7
  now_dir = os.getcwd()
8
  sys.path.append(now_dir)
9
- from tool.logger import get_logger
10
 
11
  logger = get_logger("gpu")
12
 
@@ -36,13 +35,13 @@ def select_device(min_memory=2047, experimental=False):
36
  """
37
  if experimental:
38
  # For Apple M1/M2 chips with Metal Performance Shaders
39
- logger.get_logger().warn("experimantal: found apple GPU, using MPS.")
40
  device = torch.device("mps")
41
  else:
42
- logger.get_logger().info("found Apple GPU, but use CPU.")
43
  device = torch.device("cpu")
44
  else:
45
- logger.get_logger().warning("no GPU found, use CPU instead")
46
  device = torch.device("cpu")
47
 
48
  return device
 
1
  import torch
2
  import os, sys
 
3
 
4
  if sys.platform == "darwin":
5
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
6
  now_dir = os.getcwd()
7
  sys.path.append(now_dir)
8
+ from .logger.log import get_logger
9
 
10
  logger = get_logger("gpu")
11
 
 
35
  """
36
  if experimental:
37
  # For Apple M1/M2 chips with Metal Performance Shaders
38
+ logger.warn("experimantal: found apple GPU, using MPS.")
39
  device = torch.device("mps")
40
  else:
41
+ logger.info("found Apple GPU, but use CPU.")
42
  device = torch.device("cpu")
43
  else:
44
+ logger.warning("no GPU found, use CPU instead")
45
  device = torch.device("cpu")
46
 
47
  return device
web/app.py CHANGED
@@ -10,7 +10,7 @@ from tool.logger import get_logger
10
  from tool.func import *
11
  from tool.np import *
12
  from tool.gpu import select_device
13
- import ChatTTS
14
  import argparse
15
  import torch._dynamo
16
 
@@ -21,7 +21,7 @@ torch._dynamo.config.suppress_errors = True
21
  logger = get_logger("app")
22
 
23
  # Initialize and load the model:
24
- chat = ChatTTS.Chat()
25
 
26
 
27
  def init_chat(args):
@@ -38,7 +38,7 @@ def init_chat(args):
38
 
39
  logger.info("loading ChatTTS device :" + str(device))
40
 
41
- if chat.load(source=source, custom_path="D:\\chenjgspace\\ai-model\\chattts", device=device):
42
  print("Models loaded successfully.")
43
  logger.info("Models loaded successfully.")
44
  else:
@@ -203,7 +203,7 @@ def get_chat_infer_audio(chat_txt,
203
  spk_emb_text):
204
  logger.info("========开始生成音频文件=====")
205
  #音频参数设置
206
- params_infer_code = ChatTTS.Chat.InferCodeParams(
207
  spk_emb=spk_emb_text, # add sampled speaker
208
  temperature=temperature_slider, # using custom temperature
209
  top_P=top_p_slider, # top P decode
@@ -227,7 +227,7 @@ def get_chat_infer_text(text,seed,refine_text_checkBox):
227
  logger.info("========文本内容无需优化=====")
228
  return text
229
 
230
- params_refine_text = ChatTTS.Chat.RefineTextParams(
231
  prompt='[oral_2][laugh_0][break_6]',
232
  )
233
 
 
10
  from tool.func import *
11
  from tool.np import *
12
  from tool.gpu import select_device
13
+ import Chat2TTS
14
  import argparse
15
  import torch._dynamo
16
 
 
21
  logger = get_logger("app")
22
 
23
  # Initialize and load the model:
24
+ chat = Chat2TTS.Chat()
25
 
26
 
27
  def init_chat(args):
 
38
 
39
  logger.info("loading ChatTTS device :" + str(device))
40
 
41
+ if chat.load_models(source=source, local_path="D:\\chenjgspace\\ai-model\\chattts"):
42
  print("Models loaded successfully.")
43
  logger.info("Models loaded successfully.")
44
  else:
 
203
  spk_emb_text):
204
  logger.info("========开始生成音频文件=====")
205
  #音频参数设置
206
+ params_infer_code = Chat2TTS.Chat.InferCodeParams(
207
  spk_emb=spk_emb_text, # add sampled speaker
208
  temperature=temperature_slider, # using custom temperature
209
  top_P=top_p_slider, # top P decode
 
227
  logger.info("========文本内容无需优化=====")
228
  return text
229
 
230
+ params_refine_text = Chat2TTS.Chat.RefineTextParams(
231
  prompt='[oral_2][laugh_0][break_6]',
232
  )
233
 
web/app_sc.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import spaces
3
+
4
+ if sys.platform == "darwin":
5
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
6
+ now_dir = os.getcwd()
7
+ sys.path.append(now_dir)
8
+
9
+ from tool.logger import get_logger
10
+ from tool.func import *
11
+ from tool.np import *
12
+ from tool.gpu import select_device
13
+ from tool.ctx import TorchSeedContext
14
+ import ChatTTS
15
+ import argparse
16
+ import torch._dynamo
17
+
18
+ torch._dynamo.config.suppress_errors = True
19
+
20
+
21
+
22
+ logger = get_logger("app")
23
+
24
+ # Initialize and load the model:
25
+ chat = ChatTTS.Chat()
26
+
27
+
28
+ def init_chat(args):
29
+ global chat
30
+ source = "custom"
31
+ # 获取启动模式
32
+ MODEL = os.getenv('MODEL')
33
+ logger.info("loading ChatTTS model..., start MODEL:" + str(MODEL))
34
+ # huggingface 部署模式下,模型则直接使用hf的模型数据
35
+ if MODEL == "HF":
36
+ source = "huggingface"
37
+
38
+ device=select_device()
39
+
40
+ logger.info("loading ChatTTS device :" + str(device))
41
+
42
+ if chat.load(source=source, custom_path="D:\\chenjgspace\\ai-model\\chattts",device=device):
43
+ print("Models loaded successfully.")
44
+ logger.info("Models loaded successfully.")
45
+ else:
46
+ logger.error("=========Models load failed.")
47
+ sys.exit(1)
48
+
49
+
50
+ def main(args):
51
+ with gr.Blocks() as demo:
52
+ gr.Markdown("# ChatTTS demo")
53
+ with gr.Row():
54
+ with gr.Column(scale=1):
55
+ text_input = gr.Textbox(
56
+ label="转换内容",
57
+ lines=4,
58
+ max_lines=4,
59
+ placeholder="Please Input Text...",
60
+ value="柔柔的,浓浓的,痴痴的风,牵引起心底灵动的思潮;情愫悠悠,思情绵绵,风里默坐,红尘中的浅醉,诗词中的优柔,任那自在飞花轻似梦的情怀,裁一束霓衣,织就清浅淡薄的安寂。",
61
+ interactive=True,
62
+ )
63
+ with gr.Row():
64
+ refine_text_checkBox = gr.Checkbox(
65
+ label="是否优化文本,如是则先对文本内容做优化分词",
66
+ interactive=True,
67
+ value=True
68
+ )
69
+ temperature_slider = gr.Slider(
70
+ minimum=0.00001,
71
+ maximum=1.0,
72
+ step=0.00001,
73
+ value=0.3,
74
+ interactive=True,
75
+ label="模型 Temperature 参数设置"
76
+ )
77
+ top_p_slider = gr.Slider(
78
+ minimum=0.1,
79
+ maximum=0.9,
80
+ step=0.05,
81
+ value=0.7,
82
+ label="模型 top_P 参数设置",
83
+ interactive=True,
84
+ )
85
+ top_k_slider = gr.Slider(
86
+ minimum=1,
87
+ maximum=20,
88
+ step=1,
89
+ value=20,
90
+ label="模型 top_K 参数设置",
91
+ interactive=True,
92
+ )
93
+ with gr.Row():
94
+ voice_selection = gr.Dropdown(
95
+ label="Timbre",
96
+ choices=voices.keys(),
97
+ value="旁白",
98
+ interactive=True,
99
+ show_label=True
100
+ )
101
+ audio_seed_input = gr.Number(
102
+ value=2,
103
+ label="音色种子",
104
+ interactive=True,
105
+ minimum=seed_min,
106
+ maximum=seed_max,
107
+ )
108
+ generate_audio_seed = gr.Button("随机生成音色种子", interactive=True)
109
+ text_seed_input = gr.Number(
110
+ value=42,
111
+ label="文本种子",
112
+ interactive=True,
113
+ minimum=seed_min,
114
+ maximum=seed_max,
115
+ )
116
+ generate_text_seed = gr.Button("随机生成文本种子", interactive=True)
117
+
118
+ with gr.Row():
119
+ spk_emb_text = gr.Textbox(
120
+ label="Speaker Embedding",
121
+ max_lines=3,
122
+ show_copy_button=True,
123
+ interactive=False,
124
+ scale=2,
125
+
126
+ )
127
+ reload_chat_button = gr.Button("Reload", scale=1, interactive=True)
128
+
129
+ with gr.Row():
130
+ generate_button = gr.Button("生成音频文件", scale=1, interactive=True)
131
+
132
+ with gr.Row():
133
+ text_output = gr.Textbox(
134
+ label="输出文本",
135
+ interactive=False,
136
+ show_copy_button=True,
137
+ )
138
+
139
+ audio_output = gr.Audio(
140
+ label="输出音频",
141
+ value=None,
142
+ format="wav",
143
+ autoplay=False,
144
+ streaming=False,
145
+ interactive=False,
146
+ show_label=True,
147
+ waveform_options=gr.WaveformOptions(
148
+ sample_rate=24000,
149
+ ),
150
+ )
151
+ # 针对页面元素新增 监听事件
152
+ voice_selection.change(fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input)
153
+
154
+ audio_seed_input.change(fn=on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text)
155
+
156
+ generate_audio_seed.click(fn=generate_seed, outputs=audio_seed_input)
157
+
158
+ generate_text_seed.click(fn=generate_seed,outputs=text_seed_input)
159
+
160
+ # reload_chat_button.click()
161
+
162
+ generate_button.click(fn=get_chat_infer_text,
163
+ inputs=[text_input,
164
+ text_seed_input,
165
+ refine_text_checkBox
166
+ ],
167
+ outputs=[text_output]
168
+ ).then(fn=get_chat_infer_audio,
169
+ inputs=[text_output,
170
+ temperature_slider,
171
+ top_p_slider,
172
+ top_k_slider,
173
+ audio_seed_input,
174
+ spk_emb_text
175
+ ],
176
+ outputs=[audio_output])
177
+ # 初始化 spk_emb_text 数值
178
+ spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
179
+ logger.info("元素初始化完成,启动gradio服务=======")
180
+
181
+ # 运行gradio服务
182
+ demo.launch()
183
+
184
+
185
+ '''
186
+ top_K: "top_K"(K个最高得分)是指在所有可能的生成结果中,模型会选取前K个得分最高的结果。这个设置常常用于基于概率的生成任务,例如语言模型中的词或句子生成。当你设置top_K为K时,你要求模型选择得分最高的K个选项,这样输出通常会有一定的多样性,但仍然是基于模型的前K个预测。
187
+
188
+ top_P: "top_P"(概率阈值)则是一个连续的值,而不是离散的整数。它代表的是一个概率阈值,模型会生成所有得分高于该阈值的概率的项目。换句话说,top_P会生成那些概率大于等于给定值的所有生成结果。这个设置更为灵活,可以根据实际需求调整生成内容的不确定性,高频选项被生成的概率较高,而低频可能性则可能根据阈值随机出现。
189
+
190
+ 在实际应用中,选择top_K还是top_P取决于具体任务需求,如是否希望生成内容有一定程度的多样化(top_K),还是希望生成的内容更接近于最可能发生的选项(top_P)。较高的top_P可能会引入更多的随机性和创新,而较低的top_K则更倾向于保守的选择。
191
+
192
+ spk_embedding(Speaker Embedding): 这个术语一般用于语音识别或者多说话者模型中。"spk_embedding"指的是每个说话人的身份或特征向量,或者说是用户标识的嵌入表示。在对话系统中,它能帮助模型区分不同的说话者,比如在多轮对话中区分是同一个用户的不同回复,或者是不同用户的交互。这个嵌入可能包含了说话人的个性、语调、口音等信息,有助于提高对话的连贯性和自然性。
193
+
194
+ temperature: 通常在语言模型的生成(如基于概率的softmax)中使用。"temperature"是一个正数,用于控制生成内容的随机性和多样性。当温度较低(如0.1)时,模型倾向于生成最可能的结果,文字更保守,少有创新;当温度较高(如1或更高)时,模型将更倾向于产生多样化的内容,但可能性较大的选项将被稀释。因此,temperature调整是一个常用的平衡方法,使得生成更具创造性或是更符合预期。
195
+
196
+ 简而言之,"spk_embedding"关注的是对话参与者的身份特征,而"temperature"是用于调整生成文本不确定性的一个超参数。
197
+ '''
198
+ #@spaces.GPU 使用chattts三方包 gpu模式没有跑通
199
+ def get_chat_infer_audio(chat_txt,
200
+ temperature_slider,
201
+ top_p_slider,
202
+ top_k_slider,
203
+ audio_seed_input,
204
+ spk_emb_text):
205
+ logger.info("========开始生成音频文件=====")
206
+ #音频参数设置
207
+ params_infer_code = ChatTTS.Chat.InferCodeParams(
208
+ spk_emb=spk_emb_text, # add sampled speaker
209
+ temperature=temperature_slider, # using custom temperature
210
+ top_P=top_p_slider, # top P decode
211
+ top_K=top_k_slider, # top K decode
212
+ )
213
+
214
+ with TorchSeedContext(audio_seed_input):
215
+ wav = chat.infer(
216
+ text=chat_txt,
217
+ skip_refine_text=True, #跳过文本优化
218
+ params_infer_code=params_infer_code,
219
+ )
220
+ yield 24000, float_to_int16(wav[0]).T
221
+
222
+ #@spaces.GPU 使用chattts三方包 gpu模式没有跑通
223
+ def get_chat_infer_text(text,seed,refine_text_checkBox):
224
+
225
+ logger.info("========开始优化文本内容2=====")
226
+ global chat
227
+ if not refine_text_checkBox:
228
+ logger.info("========文本内容无需优化=====")
229
+ return text
230
+
231
+ params_refine_text = ChatTTS.Chat.RefineTextParams(
232
+ prompt='[oral_2][laugh_0][break_6]',
233
+ )
234
+
235
+ with TorchSeedContext(seed):
236
+ chat_text = chat.infer(
237
+ text=text,
238
+ skip_refine_text=False,
239
+ refine_text_only=True, #仅返回优化后文本内容
240
+ params_refine_text=params_refine_text,
241
+ )
242
+
243
+ return chat_text[0] if isinstance(chat_text, list) else chat_text
244
+
245
+ @spaces.GPU
246
+ def on_audio_seed_change(audio_seed_input):
247
+ global chat
248
+ with TorchSeedContext(audio_seed_input):
249
+ rand_spk = chat.sample_random_speaker()
250
+ return rand_spk
251
+
252
+
253
+ if __name__ == "__main__":
254
+ parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
255
+ parser.add_argument(
256
+ "--server_name", type=str, default="0.0.0.0", help="server name"
257
+ )
258
+ parser.add_argument("--server_port", type=int, default=7860, help="server port")
259
+ parser.add_argument(
260
+ "--custom_path", type=str, default="D:\\chenjgspace\\ai-model\\chattts", help="custom model path"
261
+ )
262
+ parser.add_argument(
263
+ "--coef", type=str, default=None, help="custom dvae coefficient"
264
+ )
265
+ args = parser.parse_args()
266
+ init_chat(args)
267
+ main(args)