Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -14,14 +14,16 @@ from huggingface_hub import HfApi, hf_hub_download, snapshot_download
|
|
14 |
from TTS.tts.configs.xtts_config import XttsConfig
|
15 |
from TTS.tts.models.xtts import Xtts
|
16 |
from vinorm import TTSnorm
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
# download for mecab
|
19 |
os.system("python -m unidic download")
|
20 |
|
21 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
22 |
api = HfApi(token=HF_TOKEN)
|
23 |
|
24 |
-
# This will trigger downloading model
|
25 |
print("Downloading if not downloaded viXTTS")
|
26 |
checkpoint_dir = "model/"
|
27 |
repo_id = "capleaf/viXTTS"
|
@@ -75,7 +77,6 @@ def normalize_vietnamese_text(text):
|
|
75 |
|
76 |
|
77 |
def calculate_keep_len(text, lang):
|
78 |
-
"""Simple hack for short sentences"""
|
79 |
if lang in ["ja", "zh-cn"]:
|
80 |
return -1
|
81 |
|
@@ -90,17 +91,9 @@ def calculate_keep_len(text, lang):
|
|
90 |
|
91 |
|
92 |
@spaces.GPU(queue=False)
|
93 |
-
|
94 |
-
def predict(
|
95 |
-
prompt,
|
96 |
-
language,
|
97 |
-
audio_file_pth,
|
98 |
-
normalize_text=True,
|
99 |
-
):
|
100 |
if language not in supported_languages:
|
101 |
-
metrics_text = gr.Warning(
|
102 |
-
f"Language you put {language} in is not in is not in our Supported Languages, please choose from dropdown"
|
103 |
-
)
|
104 |
|
105 |
return (None, metrics_text)
|
106 |
|
@@ -111,12 +104,7 @@ def predict(
|
|
111 |
return (None, metrics_text)
|
112 |
|
113 |
if len(prompt) > 250:
|
114 |
-
metrics_text = gr.Warning(
|
115 |
-
str(len(prompt))
|
116 |
-
+ " characters.\n"
|
117 |
-
+ "Your prompt is too long, please keep it under 250 characters\n"
|
118 |
-
+ "Văn bản quá dài, vui lòng giữ dưới 250 ký tự."
|
119 |
-
)
|
120 |
return (None, metrics_text)
|
121 |
|
122 |
try:
|
@@ -124,21 +112,11 @@ def predict(
|
|
124 |
t_latent = time.time()
|
125 |
|
126 |
try:
|
127 |
-
(
|
128 |
-
gpt_cond_latent,
|
129 |
-
speaker_embedding,
|
130 |
-
) = MODEL.get_conditioning_latents(
|
131 |
-
audio_path=speaker_wav,
|
132 |
-
gpt_cond_len=30,
|
133 |
-
gpt_cond_chunk_len=4,
|
134 |
-
max_ref_length=60,
|
135 |
-
)
|
136 |
|
137 |
except Exception as e:
|
138 |
print("Speaker encoding error", str(e))
|
139 |
-
metrics_text = gr.Warning(
|
140 |
-
"It appears something wrong with reference, did you unmute your microphone?"
|
141 |
-
)
|
142 |
return (None, metrics_text)
|
143 |
|
144 |
prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
|
@@ -148,25 +126,14 @@ def predict(
|
|
148 |
|
149 |
print("I: Generating new audio...")
|
150 |
t0 = time.time()
|
151 |
-
out = MODEL.inference(
|
152 |
-
prompt,
|
153 |
-
language,
|
154 |
-
gpt_cond_latent,
|
155 |
-
speaker_embedding,
|
156 |
-
repetition_penalty=5.0,
|
157 |
-
temperature=0.75,
|
158 |
-
enable_text_splitting=True,
|
159 |
-
)
|
160 |
inference_time = time.time() - t0
|
161 |
print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
|
162 |
-
metrics_text += (
|
163 |
-
f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
|
164 |
-
)
|
165 |
real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
|
166 |
print(f"Real-time factor (RTF): {real_time_factor}")
|
167 |
metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
|
168 |
|
169 |
-
# Temporary hack for short sentences
|
170 |
keep_len = calculate_keep_len(prompt, language)
|
171 |
out["wav"] = out["wav"][:keep_len]
|
172 |
|
@@ -174,21 +141,12 @@ def predict(
|
|
174 |
|
175 |
except RuntimeError as e:
|
176 |
if "device-side assert" in str(e):
|
177 |
-
|
178 |
-
print(
|
179 |
-
f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
|
180 |
-
flush=True,
|
181 |
-
)
|
182 |
gr.Warning("Unhandled Exception encounter, please retry in a minute")
|
183 |
print("Cuda device-assert Runtime encountered need restart")
|
184 |
|
185 |
error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
|
186 |
-
error_data = [
|
187 |
-
error_time,
|
188 |
-
prompt,
|
189 |
-
language,
|
190 |
-
audio_file_pth,
|
191 |
-
]
|
192 |
error_data = [str(e) if type(e) != str else e for e in error_data]
|
193 |
print(error_data)
|
194 |
print(speaker_wav)
|
@@ -199,25 +157,13 @@ def predict(
|
|
199 |
filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
|
200 |
print("Writing error csv")
|
201 |
error_api = HfApi()
|
202 |
-
error_api.upload_file(
|
203 |
-
|
204 |
-
path_in_repo=filename,
|
205 |
-
repo_id="coqui/xtts-flagged-dataset",
|
206 |
-
repo_type="dataset",
|
207 |
-
)
|
208 |
-
|
209 |
-
# speaker_wav
|
210 |
print("Writing error reference audio")
|
211 |
speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
|
212 |
error_api = HfApi()
|
213 |
-
error_api.upload_file(
|
214 |
-
|
215 |
-
path_in_repo=speaker_filename,
|
216 |
-
repo_id="coqui/xtts-flagged-dataset",
|
217 |
-
repo_type="dataset",
|
218 |
-
)
|
219 |
-
|
220 |
-
# HF Space specific.. This error is unrecoverable need to restart space
|
221 |
space = api.get_space_runtime(repo_id=repo_id)
|
222 |
if space.stage != "BUILDING":
|
223 |
api.restart_space(repo_id=repo_id)
|
@@ -227,80 +173,41 @@ def predict(
|
|
227 |
else:
|
228 |
if "Failed to decode" in str(e):
|
229 |
print("Speaker encoding error", str(e))
|
230 |
-
metrics_text = gr.Warning(
|
231 |
-
metrics_text="It appears something wrong with reference, did you unmute your microphone?"
|
232 |
-
)
|
233 |
else:
|
234 |
print("RuntimeError: non device-side assert error:", str(e))
|
235 |
-
metrics_text = gr.Warning(
|
236 |
-
"Something unexpected happened please retry again."
|
237 |
-
)
|
238 |
return (None, metrics_text)
|
239 |
return ("output.wav", metrics_text)
|
240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
with gr.Blocks(analytics_enabled=False) as demo:
|
243 |
with gr.Row():
|
244 |
with gr.Column():
|
245 |
-
gr.Markdown(
|
246 |
-
"""
|
247 |
# viXTTS Demo ✨
|
248 |
- Github: https://github.com/thinhlpg/vixtts-demo/
|
249 |
- viVoice: https://github.com/thinhlpg/viVoice
|
250 |
-
"""
|
251 |
-
)
|
252 |
with gr.Column():
|
253 |
-
# placeholder to align the image
|
254 |
pass
|
255 |
|
256 |
with gr.Row():
|
257 |
with gr.Column():
|
258 |
-
input_text_gr = gr.Textbox(
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
)
|
263 |
-
language_gr = gr.Dropdown(
|
264 |
-
label="Language (Ngôn ngữ)",
|
265 |
-
choices=[
|
266 |
-
"vi",
|
267 |
-
"en",
|
268 |
-
"es",
|
269 |
-
"fr",
|
270 |
-
"de",
|
271 |
-
"it",
|
272 |
-
"pt",
|
273 |
-
"pl",
|
274 |
-
"tr",
|
275 |
-
"ru",
|
276 |
-
"nl",
|
277 |
-
"cs",
|
278 |
-
"ar",
|
279 |
-
"zh-cn",
|
280 |
-
"ja",
|
281 |
-
"ko",
|
282 |
-
"hu",
|
283 |
-
"hi",
|
284 |
-
],
|
285 |
-
max_choices=1,
|
286 |
-
value="vi",
|
287 |
-
)
|
288 |
-
normalize_text = gr.Checkbox(
|
289 |
-
label="Chuẩn hóa văn bản tiếng Việt",
|
290 |
-
info="Normalize Vietnamese text",
|
291 |
-
value=True,
|
292 |
-
)
|
293 |
-
ref_gr = gr.Audio(
|
294 |
-
label="Reference Audio (Giọng mẫu)",
|
295 |
-
type="filepath",
|
296 |
-
value="model/samples/nu-luu-loat.wav",
|
297 |
-
)
|
298 |
-
tts_button = gr.Button(
|
299 |
-
"Đọc 🗣️🔥",
|
300 |
-
elem_id="send-btn",
|
301 |
-
visible=True,
|
302 |
-
variant="primary",
|
303 |
-
)
|
304 |
|
305 |
with gr.Column():
|
306 |
audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
|
@@ -308,15 +215,10 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
308 |
|
309 |
tts_button.click(
|
310 |
predict,
|
311 |
-
[
|
312 |
-
input_text_gr,
|
313 |
-
language_gr,
|
314 |
-
ref_gr,
|
315 |
-
normalize_text,
|
316 |
-
],
|
317 |
outputs=[audio_gr, out_text_gr],
|
318 |
api_name="predict",
|
319 |
)
|
320 |
|
321 |
demo.queue()
|
322 |
-
demo.launch(debug=True, show_api=True, share=True)
|
|
|
14 |
from TTS.tts.configs.xtts_config import XttsConfig
|
15 |
from TTS.tts.models.xtts import Xtts
|
16 |
from vinorm import TTSnorm
|
17 |
+
from fastapi import FastAPI, File, UploadFile
|
18 |
+
from fastapi.responses import FileResponse
|
19 |
+
|
20 |
+
app = FastAPI()
|
21 |
|
|
|
22 |
os.system("python -m unidic download")
|
23 |
|
24 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
25 |
api = HfApi(token=HF_TOKEN)
|
26 |
|
|
|
27 |
print("Downloading if not downloaded viXTTS")
|
28 |
checkpoint_dir = "model/"
|
29 |
repo_id = "capleaf/viXTTS"
|
|
|
77 |
|
78 |
|
79 |
def calculate_keep_len(text, lang):
|
|
|
80 |
if lang in ["ja", "zh-cn"]:
|
81 |
return -1
|
82 |
|
|
|
91 |
|
92 |
|
93 |
@spaces.GPU(queue=False)
|
94 |
+
def predict(prompt, language, audio_file_pth, normalize_text=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
if language not in supported_languages:
|
96 |
+
metrics_text = gr.Warning(f"Language you put {language} in is not in is not in our Supported Languages, please choose from dropdown")
|
|
|
|
|
97 |
|
98 |
return (None, metrics_text)
|
99 |
|
|
|
104 |
return (None, metrics_text)
|
105 |
|
106 |
if len(prompt) > 250:
|
107 |
+
metrics_text = gr.Warning(str(len(prompt)) + " characters.\n" + "Your prompt is too long, please keep it under 250 characters\n" + "Văn bản quá dài, vui lòng giữ dưới 250 ký tự.")
|
|
|
|
|
|
|
|
|
|
|
108 |
return (None, metrics_text)
|
109 |
|
110 |
try:
|
|
|
112 |
t_latent = time.time()
|
113 |
|
114 |
try:
|
115 |
+
(gpt_cond_latent, speaker_embedding) = MODEL.get_conditioning_latents(audio_path=speaker_wav, gpt_cond_len=30, gpt_cond_chunk_len=4, max_ref_length=60)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
except Exception as e:
|
118 |
print("Speaker encoding error", str(e))
|
119 |
+
metrics_text = gr.Warning("It appears something wrong with reference, did you unmute your microphone?")
|
|
|
|
|
120 |
return (None, metrics_text)
|
121 |
|
122 |
prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
|
|
|
126 |
|
127 |
print("I: Generating new audio...")
|
128 |
t0 = time.time()
|
129 |
+
out = MODEL.inference(prompt, language, gpt_cond_latent, speaker_embedding, repetition_penalty=5.0, temperature=0.75, enable_text_splitting=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
inference_time = time.time() - t0
|
131 |
print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
|
132 |
+
metrics_text += f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
|
|
|
|
|
133 |
real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
|
134 |
print(f"Real-time factor (RTF): {real_time_factor}")
|
135 |
metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
|
136 |
|
|
|
137 |
keep_len = calculate_keep_len(prompt, language)
|
138 |
out["wav"] = out["wav"][:keep_len]
|
139 |
|
|
|
141 |
|
142 |
except RuntimeError as e:
|
143 |
if "device-side assert" in str(e):
|
144 |
+
print(f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}", flush=True)
|
|
|
|
|
|
|
|
|
145 |
gr.Warning("Unhandled Exception encounter, please retry in a minute")
|
146 |
print("Cuda device-assert Runtime encountered need restart")
|
147 |
|
148 |
error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
|
149 |
+
error_data = [error_time, prompt, language, audio_file_pth]
|
|
|
|
|
|
|
|
|
|
|
150 |
error_data = [str(e) if type(e) != str else e for e in error_data]
|
151 |
print(error_data)
|
152 |
print(speaker_wav)
|
|
|
157 |
filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
|
158 |
print("Writing error csv")
|
159 |
error_api = HfApi()
|
160 |
+
error_api.upload_file(path_or_fileobj=csv_upload, path_in_repo=filename, repo_id="coqui/xtts-flagged-dataset", repo_type="dataset")
|
161 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
print("Writing error reference audio")
|
163 |
speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
|
164 |
error_api = HfApi()
|
165 |
+
error_api.upload_file(path_or_fileobj=speaker_wav, path_in_repo=speaker_filename, repo_id="coqui/xtts-flagged-dataset", repo_type="dataset")
|
166 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
space = api.get_space_runtime(repo_id=repo_id)
|
168 |
if space.stage != "BUILDING":
|
169 |
api.restart_space(repo_id=repo_id)
|
|
|
173 |
else:
|
174 |
if "Failed to decode" in str(e):
|
175 |
print("Speaker encoding error", str(e))
|
176 |
+
metrics_text = gr.Warning(metrics_text="It appears something wrong with reference, did you unmute your microphone?")
|
|
|
|
|
177 |
else:
|
178 |
print("RuntimeError: non device-side assert error:", str(e))
|
179 |
+
metrics_text = gr.Warning("Something unexpected happened please retry again.")
|
|
|
|
|
180 |
return (None, metrics_text)
|
181 |
return ("output.wav", metrics_text)
|
182 |
|
183 |
+
@app.post("/synthesize")
|
184 |
+
async def api_synthesize(prompt: str, language: str = "vi", audio_file: UploadFile = File(...)):
|
185 |
+
audio_file_path = f"temp_{uuid.uuid4()}.wav"
|
186 |
+
with open(audio_file_path, "wb") as f:
|
187 |
+
f.write(await audio_file.read())
|
188 |
+
|
189 |
+
audio_output_path, metrics_text = predict(prompt, language, audio_file_path)
|
190 |
+
|
191 |
+
return FileResponse(audio_output_path, media_type="audio/wav")
|
192 |
|
193 |
with gr.Blocks(analytics_enabled=False) as demo:
|
194 |
with gr.Row():
|
195 |
with gr.Column():
|
196 |
+
gr.Markdown("""
|
|
|
197 |
# viXTTS Demo ✨
|
198 |
- Github: https://github.com/thinhlpg/vixtts-demo/
|
199 |
- viVoice: https://github.com/thinhlpg/viVoice
|
200 |
+
""")
|
|
|
201 |
with gr.Column():
|
|
|
202 |
pass
|
203 |
|
204 |
with gr.Row():
|
205 |
with gr.Column():
|
206 |
+
input_text_gr = gr.Textbox(label="Text Prompt (Văn bản cần đọc)", info="Mỗi câu nên từ 10 từ trở lên. Tối đa 250 ký tự (khoảng 2 - 3 câu).", value="Xin chào, tôi là một mô hình chuyển đổi văn bản thành giọng nói tiếng Việt.")
|
207 |
+
language_gr = gr.Dropdown(label="Language (Ngôn ngữ)", choices=["vi", "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja", "ko", "hu", "hi"], max_choices=1, value="vi")
|
208 |
+
normalize_text = gr.Checkbox(label="Chuẩn hóa văn bản tiếng Việt", info="Normalize Vietnamese text", value=True)
|
209 |
+
ref_gr = gr.Audio(label="Reference Audio (Giọng mẫu)", type="filepath", value="model/samples/nu-luu-loat.wav")
|
210 |
+
tts_button = gr.Button("Đọc 🗣️🔥", elem_id="send-btn", visible=True, variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
with gr.Column():
|
213 |
audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
|
|
|
215 |
|
216 |
tts_button.click(
|
217 |
predict,
|
218 |
+
[input_text_gr, language_gr, ref_gr, normalize_text],
|
|
|
|
|
|
|
|
|
|
|
219 |
outputs=[audio_gr, out_text_gr],
|
220 |
api_name="predict",
|
221 |
)
|
222 |
|
223 |
demo.queue()
|
224 |
+
demo.launch(debug=True, show_api=True, share=True, server_name="0.0.0.0", server_port=7862)
|