# fireredtts/fireredtts.py
import os
import json
import time

import torch

from fireredtts.modules.gpt.gpt import GPT
from fireredtts.modules import Token2Wav, MelSpectrogramExtractor
from fireredtts.modules.tokenizer.tokenizer import VoiceBpeTokenizer
from fireredtts.modules.codec.speaker import SpeakerEmbedddingExtractor
from fireredtts.utils.utils import load_audio


class FireRedTTS:
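    """Two-stage TTS pipeline: a GPT model maps text tokens plus a speaker
    embedding (extracted from a prompt recording) to acoustic tokens, and a
    token2wav model renders those tokens into a waveform conditioned on the
    prompt mel-spectrogram."""
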
def __init__(self, config_path, pretrained_path, device="cuda"):
self.device = device
        with open(config_path) as f:
            self.config = json.load(f)
self.gpt_path = os.path.join(pretrained_path, "fireredtts_gpt.pt")
self.token2wav_path = os.path.join(pretrained_path, "fireredtts_token2wav.pt")
self.speaker_extractor_path = os.path.join(
pretrained_path, "fireredtts_speaker.bin"
)
assert os.path.exists(self.token2wav_path)
assert os.path.exists(self.gpt_path)
assert os.path.exists(self.speaker_extractor_path)
        # text tokenizer
        self.text_tokenizer = VoiceBpeTokenizer()
        # speaker embedding extractor
        self.speaker_extractor = SpeakerEmbedddingExtractor(
            ckpt_path=self.speaker_extractor_path, device=device
        )
# load gpt model
self.gpt = GPT(
start_text_token=self.config["gpt"]["gpt_start_text_token"],
stop_text_token=self.config["gpt"]["gpt_stop_text_token"],
layers=self.config["gpt"]["gpt_layers"],
model_dim=self.config["gpt"]["gpt_n_model_channels"],
heads=self.config["gpt"]["gpt_n_heads"],
max_text_tokens=self.config["gpt"]["gpt_max_text_tokens"],
max_mel_tokens=self.config["gpt"]["gpt_max_audio_tokens"],
max_prompt_tokens=self.config["gpt"]["gpt_max_prompt_tokens"],
code_stride_len=self.config["gpt"]["gpt_code_stride_len"],
number_text_tokens=self.config["gpt"]["gpt_number_text_tokens"],
num_audio_tokens=self.config["gpt"]["gpt_num_audio_tokens"],
start_audio_token=self.config["gpt"]["gpt_start_audio_token"],
stop_audio_token=self.config["gpt"]["gpt_stop_audio_token"],
)
sd = torch.load(self.gpt_path, map_location=device)["model"]
self.gpt.load_state_dict(sd, strict=True)
self.gpt = self.gpt.to(device=device)
self.gpt.eval()
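        # prepare the GPT for autoregressive inference with key/value caching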
self.gpt.init_gpt_for_inference(kv_cache=True)
# mel-spectrogram extractor
self.mel_extractor = MelSpectrogramExtractor()
# load token2wav model
self.token2wav = Token2Wav.init_from_config(self.config)
sd = torch.load(self.token2wav_path, map_location="cpu")
self.token2wav.load_state_dict(sd, strict=True)
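        # weight norm is only needed during training; remove it for inference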
self.token2wav.generator.remove_weight_norm()
self.token2wav.eval()
self.token2wav = self.token2wav.to(device)

    def extract_spk_embeddings(self, prompt_wav):
        """Extract a speaker embedding from the prompt audio.

        Args:
            prompt_wav: path to the prompt audio file.

        Returns:
            Speaker embedding tensor of shape [1, 512].
        """
        _, _, audio_resampled = load_audio(audiopath=prompt_wav, sampling_rate=16000)
        # use the configured device rather than hard-coding "cuda"
        spk_embeddings = self.speaker_extractor(
            audio_resampled.to(device=self.device)
        ).unsqueeze(0)
        return spk_embeddings

    def do_gpt_inference(self, spk_gpt, text_tokens):
        """Sample acoustic-token candidates from the GPT and pick one.

        Args:
            spk_gpt: speaker embedding projected into the GPT latent space.
            text_tokens: encoded text tokens, shape [1, len].

        Returns:
            Selected acoustic-token sequence of shape [1, len].
        """
with torch.no_grad():
gpt_codes = self.gpt.generate(
cond_latents=spk_gpt,
text_inputs=text_tokens,
input_tokens=None,
do_sample=True,
top_p=0.85,
top_k=30,
temperature=0.75,
num_return_sequences=9,
num_beams=1,
length_penalty=1.0,
repetition_penalty=2.0,
output_attentions=False,
)
        seqs = []
        EOS_TOKEN = self.config["gpt"]["gpt_stop_audio_token"]
        for seq in gpt_codes:
            # Trim each candidate at its first EOS token; if generation hit
            # the length limit without emitting EOS, keep the whole sequence.
            eos_positions = (seq == EOS_TOKEN).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                seq = seq[: eos_positions[0]]
            seqs.append(seq)
        # Heuristic: of the sampled candidates, keep the third-shortest,
        # trading off truncated outputs against rambling ones.
        sorted_seqs = sorted(seqs, key=len)
        gpt_codes = sorted_seqs[2].unsqueeze(0)  # [1, len]
        return gpt_codes

    def synthesize(self, prompt_wav, text, lang="auto"):
        """Synthesize speech in the voice of the prompt speaker.

        Args:
            prompt_wav: path to the prompt audio file.
            text: text to synthesize.
            lang: language of the text ("zh", "en", or "auto").

        Returns:
            Waveform tensor of shape [1, t] at 24 kHz.
        """
        # Currently only supports Chinese and English
        assert lang in ["zh", "en", "auto"]
        assert os.path.exists(prompt_wav)
        # text to tokens
        text_tokens = self.text_tokenizer.encode(text=text, lang=lang)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0).to(self.device)
        # guard against inputs beyond the GPT's text-token budget
        assert text_tokens.shape[-1] < 400
        # extract speaker embedding and project it into the GPT latent space
        spk_embeddings = self.extract_spk_embeddings(prompt_wav=prompt_wav).unsqueeze(0)
        with torch.no_grad():
            spk_gpt = self.gpt.reference_embedding(spk_embeddings)
# gpt inference
gpt_start_time = time.time()
gpt_codes = self.do_gpt_inference(spk_gpt=spk_gpt, text_tokens=text_tokens)
gpt_end_time = time.time()
gpt_dur = gpt_end_time - gpt_start_time
# prompt mel-spectrogram compute
prompt_mel = (
self.mel_extractor(wav_path=prompt_wav).unsqueeze(0).to(self.device)
)
        # convert tokens to waveform (b=1, t)
        voc_start_time = time.time()
        rec_wavs = self.token2wav.inference(gpt_codes, prompt_mel, n_timesteps=10)
        voc_end_time = time.time()
        voc_dur = voc_end_time - voc_start_time
        all_dur = voc_end_time - gpt_start_time
        # Real-time-factor (RTF) bookkeeping; output sample rate is 24 kHz.
        # audio_dur = rec_wavs.shape[-1] / 24000
        # rtf_gpt = gpt_dur / audio_dur
        # rtf_voc = voc_dur / audio_dur
        # rtf_all = all_dur / audio_dur
        return rec_wavs
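

# Minimal usage sketch. The config filename, checkpoint directory, and prompt
# path below are hypothetical placeholders, and saving via torchaudio is an
# assumption; the 24 kHz output rate follows the RTF comment above.
if __name__ == "__main__":
    import torchaudio  # assumed available for saving the waveform

    tts = FireRedTTS(
        config_path="configs/config_24k.json",  # hypothetical config path
        pretrained_path="pretrained_models",  # hypothetical checkpoint dir
    )
    # Clone the voice in prompt.wav and speak the given text with it.
    wav = tts.synthesize(prompt_wav="prompt.wav", text="Hello, world.", lang="en")
    # wav has shape [1, t], which torchaudio treats as [channels, frames].
    torchaudio.save("output.wav", wav.detach().cpu(), sample_rate=24000)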