anzorq commited on
Commit
27e87f7
·
1 Parent(s): 1d3c496

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -17
app.py CHANGED
@@ -4,20 +4,27 @@ from TTS.utils.synthesizer import Synthesizer
4
  import gradio as gr
5
  import tempfile
6
 
7
- # Variables
8
  MAX_TXT_LEN = 800
9
- MODEL_DIR = "kbd-vits-tts-male"
10
- MODEL_URL = "https://huggingface.co/anzorq/kbd-vits-tts-male/resolve/main/checkpoint_56000.pth"
11
- CONFIG_URL = "https://huggingface.co/anzorq/kbd-vits-tts-male/resolve/main/config_35000.json"
 
 
12
 
13
- # Downloading model and config
14
- if not os.path.exists(MODEL_DIR):
15
- os.makedirs(MODEL_DIR)
16
- download_url(MODEL_URL, MODEL_DIR, "model.pth")
17
- download_url(CONFIG_URL, MODEL_DIR, "config.json")
 
 
 
 
18
 
 
 
19
 
20
- def tts(text: str):
21
  if len(text) > MAX_TXT_LEN:
22
  text = text[:MAX_TXT_LEN]
23
  print(f"Input text was cutoff since it went over the {MAX_TXT_LEN} character limit.")
@@ -25,8 +32,10 @@ def tts(text: str):
25
 
26
  text = text.replace("I", "ӏ") #replace capital is with "Palochka" symbol
27
 
 
 
28
  # synthesize
29
- synthesizer = Synthesizer(f"{MODEL_DIR}/model.pth", f"{MODEL_DIR}/config.json")
30
  wavs = synthesizer.tts(text)
31
 
32
  # return output
@@ -34,16 +43,22 @@ def tts(text: str):
34
  synthesizer.save_wav(wavs, fp)
35
  return fp.name
36
 
37
- # Gradio interface
38
  iface = gr.Interface(
39
  fn=tts,
40
- inputs=gr.Textbox(
41
- label="Text",
42
- value="Default text here if you need it.",
43
- ),
 
 
 
 
 
 
 
44
  outputs=gr.Audio(label="Output", type='filepath'),
45
  title="KBD TTS",
46
  live=False
47
  )
48
 
49
- iface.launch(share=False)
 
4
  import gradio as gr
5
  import tempfile
6
 
 
7
  MAX_TXT_LEN = 800
8
+ BASE_DIR = "kbd-vits-tts-{}"
9
+ MALE_MODEL_URL = "https://huggingface.co/anzorq/kbd-vits-tts-male/resolve/main/checkpoint_56000.pth"
10
+ MALE_CONFIG_URL = "https://huggingface.co/anzorq/kbd-vits-tts-male/resolve/main/config_35000.json"
11
+ FEMALE_MODEL_URL = "https://huggingface.co/anzorq/kbd-vits-tts-female/resolve/main/best_model_56351.pth"
12
+ FEMALE_CONFIG_URL = "https://huggingface.co/anzorq/kbd-vits-tts-female/resolve/main/config.json"
13
 
14
+ def download_model_and_config(gender):
15
+ dir_path = BASE_DIR.format(gender)
16
+ if not os.path.exists(dir_path):
17
+ os.makedirs(dir_path)
18
+ model_url = MALE_MODEL_URL if gender == "male" else FEMALE_MODEL_URL
19
+ config_url = MALE_CONFIG_URL if gender == "male" else FEMALE_CONFIG_URL
20
+ download_url(model_url, dir_path, "model.pth")
21
+ download_url(config_url, dir_path, "config.json")
22
+ return dir_path
23
 
24
+ download_model_and_config("male")
25
+ download_model_and_config("female")
26
 
27
+ def tts(text: str, voice: str="Male"):
28
  if len(text) > MAX_TXT_LEN:
29
  text = text[:MAX_TXT_LEN]
30
  print(f"Input text was cutoff since it went over the {MAX_TXT_LEN} character limit.")
 
32
 
33
  text = text.replace("I", "ӏ") #replace capital is with "Palochka" symbol
34
 
35
+ model_dir = BASE_DIR.format("male" if voice == "Male" else "female")
36
+
37
  # synthesize
38
+ synthesizer = Synthesizer(f"{model_dir}/model.pth", f"{model_dir}/config.json")
39
  wavs = synthesizer.tts(text)
40
 
41
  # return output
 
43
  synthesizer.save_wav(wavs, fp)
44
  return fp.name
45
 
 
46
  iface = gr.Interface(
47
  fn=tts,
48
+ inputs=[
49
+ gr.Textbox(
50
+ label="Text",
51
+ value="Default text here if you need it.",
52
+ ),
53
+ gr.Radio(
54
+ choices=["Male", "Female"],
55
+ value="Male", # Set Male as the default choice
56
+ label="Voice"
57
+ )
58
+ ],
59
  outputs=gr.Audio(label="Output", type='filepath'),
60
  title="KBD TTS",
61
  live=False
62
  )
63
 
64
+ iface.launch(share=False)