# chat-tts/tool/func.py
import gradio as gr
import random
import torch
import lzma
import numpy as np
import pybase16384 as b14
seed_min = 1
seed_max = 4294967295
# Role seeds: preset seeds for the built-in narration roles
seeds = {
    "旁白": {"seed": 4444},      # Narrator
    "中年女性": {"seed": 7869},  # Middle-aged female
    "年轻女性": {"seed": 6615},  # Young female
    "中年男性": {"seed": 4099},  # Middle-aged male
    "年轻男性": {"seed": 6653},  # Young male
}
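# --- Illustrative sketch (not in the original file) ---------------------------
# One way a role seed such as seeds["旁白"]["seed"] could be applied: seed all
# RNGs before sampling a speaker so the same role reproduces the same timbre.
# The helper name set_all_seeds is an assumption, not part of this module.
def set_all_seeds(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)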
# Timbre options: preset seeds for suitable voice timbres
voices = {
    "旁白": {"seed": 2},  # Narrator
    "Timbre1": {"seed": 1111},
    "Timbre2": {"seed": 2222},
    "Timbre3": {"seed": 3333},
    "Timbre4": {"seed": 4444},
    "Timbre5": {"seed": 5555},
    "Timbre6": {"seed": 6666},
    "Timbre7": {"seed": 7777},
    "Timbre8": {"seed": 8888},
    "Timbre9": {"seed": 9999},
}
def on_voice_change(voice_selection):
    """Return the preset seed for the selected timbre."""
    return voices.get(voice_selection)["seed"]
def generate_seed():
    """Generate a random seed and push it to the bound Gradio component."""
    return gr.update(value=random.randint(seed_min, seed_max))
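# --- Illustrative sketch (not in the original file) ---------------------------
# Shows one way the two callbacks above could be wired to Gradio components.
# The helper name build_demo and the component names (voice_selection,
# seed_input, random_seed_btn) are assumptions for demonstration only.
def build_demo() -> gr.Blocks:
    with gr.Blocks() as demo:
        voice_selection = gr.Dropdown(
            label="Timbre", choices=list(voices.keys()), value="旁白"
        )
        seed_input = gr.Number(
            label="Audio seed", value=voices["旁白"]["seed"], precision=0
        )
        random_seed_btn = gr.Button("Random seed")
        # Picking a timbre fills in its preset seed; the button rolls a random one.
        voice_selection.change(
            fn=on_voice_change, inputs=voice_selection, outputs=seed_input
        )
        random_seed_btn.click(fn=generate_seed, inputs=[], outputs=seed_input)
    return demo
# Usage (assumed entry point): build_demo().launch()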
@torch.no_grad()
def encode_prompt(prompt: torch.Tensor) -> str:
    """Encode a 2D audio prompt tensor into a compact base16384 string."""
    arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy()
    shp = arr.shape
    assert len(shp) == 2, "prompt must be a 2D tensor"
    # Prepend the shape as two little-endian uint16 values, then append the
    # LZMA2-compressed raw payload; base16384-encode the resulting bytes.
    s = b14.encode_to_string(
        np.array(shp, dtype="<u2").tobytes()
        + lzma.compress(
            arr.astype("<u2").tobytes(),
            format=lzma.FORMAT_RAW,
            filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
        ),
    )
    del arr
    return s
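# --- Illustrative sketch (not in the original file) ---------------------------
# A matching decoder, assuming it simply reverses encode_prompt: base16384-
# decode, read the two little-endian uint16 shape values, LZMA2-decompress the
# payload, and rebuild the 2D tensor. The name decode_prompt is an assumption.
@torch.no_grad()
def decode_prompt(prompt: str) -> torch.Tensor:
    dec: bytes = b14.decode_from_string(prompt)
    shp = np.frombuffer(dec[:4], dtype="<u2")  # original (rows, cols)
    arr = np.frombuffer(
        lzma.decompress(
            dec[4:],
            format=lzma.FORMAT_RAW,
            filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
        ),
        dtype="<u2",
    ).copy()  # np.frombuffer is read-only; copy before handing to torch
    # Widen to int32 (torch's uint16 support is limited) and restore the shape.
    return torch.from_numpy(arr.astype(np.int32)).reshape(int(shp[0]), int(shp[1]))
# Round-trip expectation for a 2D non-negative int tensor t:
# torch.equal(decode_prompt(encode_prompt(t)), t.to(torch.int32))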