Yhhxhfh committed on
Commit 2906d24 · verified · 1 Parent(s): 58b0a1f

Update app.py

Files changed (1)
  1. app.py +37 -135
app.py CHANGED
@@ -14,14 +14,16 @@ from huggingface_hub import HfApi, hf_hub_download, snapshot_download
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
 from vinorm import TTSnorm
+from fastapi import FastAPI, File, UploadFile
+from fastapi.responses import FileResponse
+
+app = FastAPI()
 
-# download for mecab
 os.system("python -m unidic download")
 
 HF_TOKEN = os.environ.get("HF_TOKEN")
 api = HfApi(token=HF_TOKEN)
 
-# This will trigger downloading model
 print("Downloading if not downloaded viXTTS")
 checkpoint_dir = "model/"
 repo_id = "capleaf/viXTTS"
@@ -75,7 +77,6 @@ def normalize_vietnamese_text(text):
 
 
 def calculate_keep_len(text, lang):
-    """Simple hack for short sentences"""
     if lang in ["ja", "zh-cn"]:
         return -1
 
@@ -90,17 +91,9 @@ def calculate_keep_len(text, lang):
 
 
 @spaces.GPU(queue=False)
-
-def predict(
-    prompt,
-    language,
-    audio_file_pth,
-    normalize_text=True,
-):
+def predict(prompt, language, audio_file_pth, normalize_text=True):
     if language not in supported_languages:
-        metrics_text = gr.Warning(
-            f"Language you put {language} in is not in is not in our Supported Languages, please choose from dropdown"
-        )
+        metrics_text = gr.Warning(f"Language you put {language} in is not in is not in our Supported Languages, please choose from dropdown")
 
         return (None, metrics_text)
 
@@ -111,12 +104,7 @@ def predict(
         return (None, metrics_text)
 
     if len(prompt) > 250:
-        metrics_text = gr.Warning(
-            str(len(prompt))
-            + " characters.\n"
-            + "Your prompt is too long, please keep it under 250 characters\n"
-            + "Văn bản quá dài, vui lòng giữ dưới 250 ký tự."
-        )
+        metrics_text = gr.Warning(str(len(prompt)) + " characters.\n" + "Your prompt is too long, please keep it under 250 characters\n" + "Văn bản quá dài, vui lòng giữ dưới 250 ký tự.")
         return (None, metrics_text)
 
     try:
@@ -124,21 +112,11 @@ def predict(
         t_latent = time.time()
 
         try:
-            (
-                gpt_cond_latent,
-                speaker_embedding,
-            ) = MODEL.get_conditioning_latents(
-                audio_path=speaker_wav,
-                gpt_cond_len=30,
-                gpt_cond_chunk_len=4,
-                max_ref_length=60,
-            )
+            (gpt_cond_latent, speaker_embedding) = MODEL.get_conditioning_latents(audio_path=speaker_wav, gpt_cond_len=30, gpt_cond_chunk_len=4, max_ref_length=60)
 
         except Exception as e:
            print("Speaker encoding error", str(e))
-            metrics_text = gr.Warning(
-                "It appears something wrong with reference, did you unmute your microphone?"
-            )
+            metrics_text = gr.Warning("It appears something wrong with reference, did you unmute your microphone?")
            return (None, metrics_text)
 
        prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
@@ -148,25 +126,14 @@ def predict(
 
        print("I: Generating new audio...")
        t0 = time.time()
-        out = MODEL.inference(
-            prompt,
-            language,
-            gpt_cond_latent,
-            speaker_embedding,
-            repetition_penalty=5.0,
-            temperature=0.75,
-            enable_text_splitting=True,
-        )
+        out = MODEL.inference(prompt, language, gpt_cond_latent, speaker_embedding, repetition_penalty=5.0, temperature=0.75, enable_text_splitting=True)
        inference_time = time.time() - t0
        print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
-        metrics_text += (
-            f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
-        )
+        metrics_text += f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
        real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
        print(f"Real-time factor (RTF): {real_time_factor}")
        metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
 
-        # Temporary hack for short sentences
        keep_len = calculate_keep_len(prompt, language)
        out["wav"] = out["wav"][:keep_len]
 
@@ -174,21 +141,12 @@ def predict(
 
    except RuntimeError as e:
        if "device-side assert" in str(e):
-            # cannot do anything on cuda device side error, need tor estart
-            print(
-                f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
-                flush=True,
-            )
+            print(f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}", flush=True)
            gr.Warning("Unhandled Exception encounter, please retry in a minute")
            print("Cuda device-assert Runtime encountered need restart")
 
            error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
-            error_data = [
-                error_time,
-                prompt,
-                language,
-                audio_file_pth,
-            ]
+            error_data = [error_time, prompt, language, audio_file_pth]
            error_data = [str(e) if type(e) != str else e for e in error_data]
            print(error_data)
            print(speaker_wav)
@@ -199,25 +157,13 @@ def predict(
            filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
            print("Writing error csv")
            error_api = HfApi()
-            error_api.upload_file(
-                path_or_fileobj=csv_upload,
-                path_in_repo=filename,
-                repo_id="coqui/xtts-flagged-dataset",
-                repo_type="dataset",
-            )
-
-            # speaker_wav
+            error_api.upload_file(path_or_fileobj=csv_upload, path_in_repo=filename, repo_id="coqui/xtts-flagged-dataset", repo_type="dataset")
+
            print("Writing error reference audio")
            speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
            error_api = HfApi()
-            error_api.upload_file(
-                path_or_fileobj=speaker_wav,
-                path_in_repo=speaker_filename,
-                repo_id="coqui/xtts-flagged-dataset",
-                repo_type="dataset",
-            )
-
-            # HF Space specific.. This error is unrecoverable need to restart space
+            error_api.upload_file(path_or_fileobj=speaker_wav, path_in_repo=speaker_filename, repo_id="coqui/xtts-flagged-dataset", repo_type="dataset")
+
            space = api.get_space_runtime(repo_id=repo_id)
            if space.stage != "BUILDING":
                api.restart_space(repo_id=repo_id)
@@ -227,80 +173,41 @@ def predict(
        else:
            if "Failed to decode" in str(e):
                print("Speaker encoding error", str(e))
-                metrics_text = gr.Warning(
-                    metrics_text="It appears something wrong with reference, did you unmute your microphone?"
-                )
+                metrics_text = gr.Warning(metrics_text="It appears something wrong with reference, did you unmute your microphone?")
            else:
                print("RuntimeError: non device-side assert error:", str(e))
-                metrics_text = gr.Warning(
-                    "Something unexpected happened please retry again."
-                )
+                metrics_text = gr.Warning("Something unexpected happened please retry again.")
            return (None, metrics_text)
    return ("output.wav", metrics_text)
 
+@app.post("/synthesize")
+async def api_synthesize(prompt: str, language: str = "vi", audio_file: UploadFile = File(...)):
+    audio_file_path = f"temp_{uuid.uuid4()}.wav"
+    with open(audio_file_path, "wb") as f:
+        f.write(await audio_file.read())
+
+    audio_output_path, metrics_text = predict(prompt, language, audio_file_path)
+
+    return FileResponse(audio_output_path, media_type="audio/wav")
 
 with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
-            gr.Markdown(
-                """
+            gr.Markdown("""
            # viXTTS Demo ✨
            - Github: https://github.com/thinhlpg/vixtts-demo/
            - viVoice: https://github.com/thinhlpg/viVoice
-                """
-            )
+            """)
        with gr.Column():
-            # placeholder to align the image
            pass
 
    with gr.Row():
        with gr.Column():
-            input_text_gr = gr.Textbox(
-                label="Text Prompt (Văn bản cần đọc)",
-                info="Mỗi câu nên từ 10 từ trở lên. Tối đa 250 tự (khoảng 2 - 3 câu).",
-                value="Xin chào, tôi một mô hình chuyển đổi văn bản thành giọng nói tiếng Việt.",
-            )
-            language_gr = gr.Dropdown(
-                label="Language (Ngôn ngữ)",
-                choices=[
-                    "vi",
-                    "en",
-                    "es",
-                    "fr",
-                    "de",
-                    "it",
-                    "pt",
-                    "pl",
-                    "tr",
-                    "ru",
-                    "nl",
-                    "cs",
-                    "ar",
-                    "zh-cn",
-                    "ja",
-                    "ko",
-                    "hu",
-                    "hi",
-                ],
-                max_choices=1,
-                value="vi",
-            )
-            normalize_text = gr.Checkbox(
-                label="Chuẩn hóa văn bản tiếng Việt",
-                info="Normalize Vietnamese text",
-                value=True,
-            )
-            ref_gr = gr.Audio(
-                label="Reference Audio (Giọng mẫu)",
-                type="filepath",
-                value="model/samples/nu-luu-loat.wav",
-            )
-            tts_button = gr.Button(
-                "Đọc 🗣️🔥",
-                elem_id="send-btn",
-                visible=True,
-                variant="primary",
-            )
+            input_text_gr = gr.Textbox(label="Text Prompt (Văn bản cần đọc)", info="Mỗi câu nên từ 10 từ trở lên. Tối đa 250 ký tự (khoảng 2 - 3 câu).", value="Xin chào, tôi là một mô hình chuyển đổi văn bản thành giọng nói tiếng Việt.")
+            language_gr = gr.Dropdown(label="Language (Ngôn ngữ)", choices=["vi", "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja", "ko", "hu", "hi"], max_choices=1, value="vi")
+            normalize_text = gr.Checkbox(label="Chuẩn hóa văn bản tiếng Việt", info="Normalize Vietnamese text", value=True)
+            ref_gr = gr.Audio(label="Reference Audio (Giọng mẫu)", type="filepath", value="model/samples/nu-luu-loat.wav")
+            tts_button = gr.Button("Đọc 🗣️🔥", elem_id="send-btn", visible=True, variant="primary")
 
        with gr.Column():
            audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
@@ -308,15 +215,10 @@ with gr.Blocks(analytics_enabled=False) as demo:
 
    tts_button.click(
        predict,
-        [
-            input_text_gr,
-            language_gr,
-            ref_gr,
-            normalize_text,
-        ],
+        [input_text_gr, language_gr, ref_gr, normalize_text],
        outputs=[audio_gr, out_text_gr],
        api_name="predict",
    )
 
    demo.queue()
-    demo.launch(debug=True, show_api=True, share=True)
+    demo.launch(debug=True, show_api=True, share=True, server_name="0.0.0.0", server_port=7862)
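Note on the new endpoint: the /synthesize route added above takes prompt and language as query parameters, takes the reference voice as a multipart file upload, and returns the synthesised WAV through FileResponse. The sketch below shows one way to call it. It is a minimal sketch, assuming the FastAPI app object is actually served (for example with uvicorn app:app --port 8000); the diff itself only launches the Gradio demo, so the host, port, and reference path here are illustrative assumptions, not part of the commit.

import requests

# Assumed base URL; the commit does not show how the FastAPI app is served.
url = "http://localhost:8000/synthesize"

with open("model/samples/nu-luu-loat.wav", "rb") as ref:
    resp = requests.post(
        url,
        params={"prompt": "Xin chào, đây là một ví dụ.", "language": "vi"},  # plain str parameters arrive as query args
        files={"audio_file": ("reference.wav", ref, "audio/wav")},  # UploadFile arrives as a multipart part
    )

resp.raise_for_status()
with open("synthesized.wav", "wb") as f:
    f.write(resp.content)  # response body is the WAV returned by FileResponse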