sayashi committed on
Commit
426811c
·
1 Parent(s): 89b7aca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -2,6 +2,7 @@
2
  import time
3
  import gradio as gr
4
  import utils
 
5
  import commons
6
  from models import SynthesizerTrn
7
  from text import text_to_sequence
@@ -14,7 +15,7 @@ net_g_ms = SynthesizerTrn(
14
  hps_ms.train.segment_size // hps_ms.data.hop_length,
15
  n_speakers=hps_ms.data.n_speakers,
16
  **hps_ms.model)
17
- _ = net_g_ms.eval()
18
  speakers = hps_ms.speakers
19
  model, optimizer, learning_rate, epochs = utils.load_checkpoint(r'./model/G_953000.pth', net_g_ms, None)
20
 
@@ -40,9 +41,9 @@ def vits(text, language, speaker_id, noise_scale, noise_scale_w, length_scale):
40
  text = f"{text}"
41
  stn_tst, clean_text = get_text(text, hps_ms)
42
  with no_grad():
43
- x_tst = stn_tst.unsqueeze(0)
44
- x_tst_lengths = LongTensor([stn_tst.size(0)])
45
- speaker_id = LongTensor([speaker_id])
46
  audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
47
  length_scale=length_scale)[0][0, 0].data.float().numpy()
48
 
@@ -85,6 +86,12 @@ download_audio_js = """
85
  """
86
 
87
  if __name__ == '__main__':
 
 
 
 
 
 
88
  with gr.Blocks() as app:
89
  gr.Markdown(
90
  "# <center> VITS语音在线合成demo\n"
@@ -121,4 +128,4 @@ if __name__ == '__main__':
121
  lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls])
122
  with gr.TabItem("可用人物一览"):
123
  gr.Radio(label="Speaker", choices=speakers, interactive=False, type="index")
124
- app.queue(concurrency_count=1).launch()
 
2
  import time
3
  import gradio as gr
4
  import utils
5
+ import argparse
6
  import commons
7
  from models import SynthesizerTrn
8
  from text import text_to_sequence
 
15
  hps_ms.train.segment_size // hps_ms.data.hop_length,
16
  n_speakers=hps_ms.data.n_speakers,
17
  **hps_ms.model)
18
+ _ = net_g_ms.eval().to(device)
19
  speakers = hps_ms.speakers
20
  model, optimizer, learning_rate, epochs = utils.load_checkpoint(r'./model/G_953000.pth', net_g_ms, None)
21
 
 
41
  text = f"{text}"
42
  stn_tst, clean_text = get_text(text, hps_ms)
43
  with no_grad():
44
+ x_tst = stn_tst.unsqueeze(0).to(device)
45
+ x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
46
+ speaker_id = LongTensor([speaker_id]).to(device)
47
  audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
48
  length_scale=length_scale)[0][0, 0].data.float().numpy()
49
 
 
86
  """
87
 
88
  if __name__ == '__main__':
89
+ parser = argparse.ArgumentParser()
90
+ parser.add_argument('--device', type=str, default='cpu')
91
+ parser.add_argument('--api', action="store_true", default=False)
92
+ parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
93
+ args = parser.parse_args()
94
+ device = torch.device(args.device)
95
  with gr.Blocks() as app:
96
  gr.Markdown(
97
  "# <center> VITS语音在线合成demo\n"
 
128
  lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls])
129
  with gr.TabItem("可用人物一览"):
130
  gr.Radio(label="Speaker", choices=speakers, interactive=False, type="index")
131
+ app.queue(concurrency_count=1, api_open=args.api).launch(share=args.share)