L0SG committed on
Commit
78885f4
·
1 Parent(s): 0d9e821

Improve the Gradio UI demo (thanks @blaisewf on GitHub)

Browse files
Files changed (1) hide show
  1. app.py +138 -164
app.py CHANGED
@@ -1,11 +1,12 @@
 
 
 
1
  import spaces
2
  import gradio as gr
3
- from huggingface_hub import hf_hub_download
4
-
5
- import json
6
  import torch
7
  import os
8
- from env import AttrDict
9
  from meldataset import get_mel_spectrogram, MAX_WAV_VALUE
10
  from bigvgan import BigVGAN
11
  import librosa
@@ -14,22 +15,14 @@ from utils import plot_spectrogram
14
  import PIL
15
 
16
  if torch.cuda.is_available():
17
- device = torch.device('cuda')
18
  torch.backends.cudnn.benchmark = False
19
  print(f"using GPU")
20
  else:
21
- device = torch.device('cpu')
22
  print(f"using CPU")
23
 
24
 
25
- def load_checkpoint(filepath):
26
- assert os.path.isfile(filepath)
27
- print("Loading '{}'".format(filepath))
28
- checkpoint_dict = torch.load(filepath, map_location='cpu')
29
- print("Complete.")
30
- return checkpoint_dict
31
-
32
-
33
  def inference_gradio(input, model_choice): # input is audio waveform in [T, channel]
34
  sr, audio = input # unpack input to sampling rate and audio itself
35
  audio = np.transpose(audio) # transpose to [channel, T] for librosa
@@ -49,17 +42,11 @@ def inference_gradio(input, model_choice): # input is audio waveform in [T, cha
49
 
50
  spec_plot_gen = plot_spectrogram(spec_gen)
51
 
52
- output_audio = (model.h.sampling_rate, output) # tuple for gr.Audio output
53
 
54
  buffer = spec_plot_gen.canvas.buffer_rgba()
55
  output_image = PIL.Image.frombuffer(
56
- "RGBA",
57
- spec_plot_gen.canvas.get_width_height(),
58
- buffer,
59
- "raw",
60
- "RGBA",
61
- 0,
62
- 1
63
  )
64
 
65
  return output_audio, output_image
@@ -228,7 +215,7 @@ css = """
228
  }
229
  """
230
 
231
- ######################## script for loading the models ########################
232
 
233
  LIST_MODEL_ID = [
234
  "bigvgan_24khz_100band",
@@ -239,7 +226,7 @@ LIST_MODEL_ID = [
239
  "bigvgan_v2_22khz_80band_fmax8k_256x",
240
  "bigvgan_v2_24khz_100band_256x",
241
  "bigvgan_v2_44khz_128band_256x",
242
- "bigvgan_v2_44khz_128band_512x"
243
  ]
244
 
245
  dict_model = {}
@@ -247,16 +234,16 @@ dict_config = {}
247
 
248
  for model_name in LIST_MODEL_ID:
249
 
250
- generator = BigVGAN.from_pretrained('nvidia/'+model_name, token=os.environ['TOKEN'])
251
- generator.eval()
252
  generator.remove_weight_norm()
 
253
 
254
  dict_model[model_name] = generator
255
  dict_config[model_name] = generator.h
256
 
257
- ######################## script for gradio UI ########################
258
 
259
- iface = gr.Blocks(css=css)
260
 
261
  with iface:
262
  gr.HTML(
@@ -267,10 +254,10 @@ with iface:
267
  display: inline-flex;
268
  align-items: center;
269
  gap: 0.8rem;
270
- font-size: 1.75rem;
271
  "
272
  >
273
- <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
274
  BigVGAN: A Universal Neural Vocoder with Large-Scale Training
275
  </h1>
276
  </div>
@@ -299,14 +286,15 @@ with iface:
299
  <div>
300
  <h3>Model Overview</h3>
301
  BigVGAN is a universal neural vocoder model that generates audio waveforms using mel spectrogram as inputs.
302
- <center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800" style="margin-top: 20px;"></center>
303
  </div>
304
  """
305
  )
 
306
 
307
- with gr.Group():
308
  model_choice = gr.Dropdown(
309
- label="Select the model. Default: bigvgan_v2_24khz_100band_256x",
 
310
  value="bigvgan_v2_24khz_100band_256x",
311
  choices=[m for m in LIST_MODEL_ID],
312
  interactive=True,
@@ -316,143 +304,129 @@ with iface:
316
  label="Input Audio", elem_id="input-audio", interactive=True
317
  )
318
 
319
- button = gr.Button("Submit")
320
 
321
- output_audio = gr.Audio(label="Output Audio", elem_id="output-audio")
322
- output_image = gr.Image(label="Output Mel Spectrogram", elem_id="output-image-gen")
 
 
 
 
323
 
324
- button.click(
325
- inference_gradio,
326
- inputs=[audio_input, model_choice],
327
- outputs=[output_audio, output_image],
328
- concurrency_limit=10,
329
- )
330
 
331
- gr.Examples(
 
332
  [
333
- [os.path.join(os.path.dirname(__file__), "examples/jensen_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
334
- [os.path.join(os.path.dirname(__file__), "examples/libritts_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
335
- [os.path.join(os.path.dirname(__file__), "examples/queen_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
336
- [os.path.join(os.path.dirname(__file__), "examples/dance_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
337
- [os.path.join(os.path.dirname(__file__), "examples/megalovania_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
338
- [os.path.join(os.path.dirname(__file__), "examples/hifitts_44k.wav"), "bigvgan_v2_44khz_128band_256x"],
339
- [os.path.join(os.path.dirname(__file__), "examples/musdbhq_44k.wav"), "bigvgan_v2_44khz_128band_256x"],
340
- [os.path.join(os.path.dirname(__file__), "examples/musiccaps1_44k.wav"), "bigvgan_v2_44khz_128band_256x"],
341
- [os.path.join(os.path.dirname(__file__), "examples/musiccaps2_44k.wav"), "bigvgan_v2_44khz_128band_256x"],
342
  ],
343
- fn=inference_gradio,
344
- inputs=[audio_input, model_choice],
345
- outputs=[output_audio, output_image]
346
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
- gr.HTML(
349
- """
350
- <table border="1" cellspacing="0" cellpadding="5">
351
- <thead>
352
- <tr>
353
- <th>Model Name</th>
354
- <th>Sampling Rate</th>
355
- <th>Mel band</th>
356
- <th>fmax</th>
357
- <th>Upsampling Ratio</th>
358
- <th>Parameters</th>
359
- <th>Dataset</th>
360
- <th>Fine-Tuned</th>
361
- </tr>
362
- </thead>
363
- <tbody>
364
- <tr>
365
- <td><a href="https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x">bigvgan_v2_44khz_128band_512x</a></td>
366
- <td>44 kHz</td>
367
- <td>128</td>
368
- <td>22050</td>
369
- <td>512</td>
370
- <td>122M</td>
371
- <td>Large-scale Compilation</td>
372
- <td>No</td>
373
- </tr>
374
- <tr>
375
- <td><a href="https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x">bigvgan_v2_44khz_128band_256x</a></td>
376
- <td>44 kHz</td>
377
- <td>128</td>
378
- <td>22050</td>
379
- <td>256</td>
380
- <td>112M</td>
381
- <td>Large-scale Compilation</td>
382
- <td>No</td>
383
- </tr>
384
- <tr>
385
- <td><a href="https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x">bigvgan_v2_24khz_100band_256x</a></td>
386
- <td>24 kHz</td>
387
- <td>100</td>
388
- <td>12000</td>
389
- <td>256</td>
390
- <td>112M</td>
391
- <td>Large-scale Compilation</td>
392
- <td>No</td>
393
- </tr>
394
- <tr>
395
- <td><a href="https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x">bigvgan_v2_22khz_80band_256x</a></td>
396
- <td>22 kHz</td>
397
- <td>80</td>
398
- <td>11025</td>
399
- <td>256</td>
400
- <td>112M</td>
401
- <td>Large-scale Compilation</td>
402
- <td>No</td>
403
- </tr>
404
- <tr>
405
- <td><a href="https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x">bigvgan_v2_22khz_80band_fmax8k_256x</a></td>
406
- <td>22 kHz</td>
407
- <td>80</td>
408
- <td>8000</td>
409
- <td>256</td>
410
- <td>112M</td>
411
- <td>Large-scale Compilation</td>
412
- <td>No</td>
413
- </tr>
414
- <tr>
415
- <td><a href="https://huggingface.co/nvidia/bigvgan_24khz_100band">bigvgan_24khz_100band</a></td>
416
- <td>24 kHz</td>
417
- <td>100</td>
418
- <td>12000</td>
419
- <td>256</td>
420
- <td>112M</td>
421
- <td>LibriTTS</td>
422
- <td>No</td>
423
- </tr>
424
- <tr>
425
- <td><a href="https://huggingface.co/nvidia/bigvgan_base_24khz_100band">bigvgan_base_24khz_100band</a></td>
426
- <td>24 kHz</td>
427
- <td>100</td>
428
- <td>12000</td>
429
- <td>256</td>
430
- <td>14M</td>
431
- <td>LibriTTS</td>
432
- <td>No</td>
433
- </tr>
434
- <tr>
435
- <td><a href="https://huggingface.co/nvidia/bigvgan_22khz_80band">bigvgan_22khz_80band</a></td>
436
- <td>22 kHz</td>
437
- <td>80</td>
438
- <td>8000</td>
439
- <td>256</td>
440
- <td>112M</td>
441
- <td>LibriTTS + VCTK + LJSpeech</td>
442
- <td>No</td>
443
- </tr>
444
- <tr>
445
- <td><a href="https://huggingface.co/nvidia/bigvgan_base_22khz_80band">bigvgan_base_22khz_80band</a></td>
446
- <td>22 kHz</td>
447
- <td>80</td>
448
- <td>8000</td>
449
- <td>256</td>
450
- <td>14M</td>
451
- <td>LibriTTS + VCTK + LJSpeech</td>
452
- <td>No</td>
453
- </tr>
454
- </tbody>
455
- </table>
456
  <p><b>NOTE: The v1 models are trained using speech audio datasets ONLY! (24kHz models: LibriTTS, 22kHz models: LibriTTS + VCTK + LJSpeech).</b></p>
457
  </div>
458
  """
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
  import spaces
5
  import gradio as gr
6
+ import pandas as pd
 
 
7
  import torch
8
  import os
9
+
10
  from meldataset import get_mel_spectrogram, MAX_WAV_VALUE
11
  from bigvgan import BigVGAN
12
  import librosa
 
15
  import PIL
16
 
17
  if torch.cuda.is_available():
18
+ device = torch.device("cuda")
19
  torch.backends.cudnn.benchmark = False
20
  print(f"using GPU")
21
  else:
22
+ device = torch.device("cpu")
23
  print(f"using CPU")
24
 
25
 
 
 
 
 
 
 
 
 
26
  def inference_gradio(input, model_choice): # input is audio waveform in [T, channel]
27
  sr, audio = input # unpack input to sampling rate and audio itself
28
  audio = np.transpose(audio) # transpose to [channel, T] for librosa
 
42
 
43
  spec_plot_gen = plot_spectrogram(spec_gen)
44
 
45
+ output_audio = (model.h.sampling_rate, output) # tuple for gr.Audio output
46
 
47
  buffer = spec_plot_gen.canvas.buffer_rgba()
48
  output_image = PIL.Image.frombuffer(
49
+ "RGBA", spec_plot_gen.canvas.get_width_height(), buffer, "raw", "RGBA", 0, 1
 
 
 
 
 
 
50
  )
51
 
52
  return output_audio, output_image
 
215
  }
216
  """
217
 
218
+ # Script for loading the models
219
 
220
  LIST_MODEL_ID = [
221
  "bigvgan_24khz_100band",
 
226
  "bigvgan_v2_22khz_80band_fmax8k_256x",
227
  "bigvgan_v2_24khz_100band_256x",
228
  "bigvgan_v2_44khz_128band_256x",
229
+ "bigvgan_v2_44khz_128band_512x",
230
  ]
231
 
232
  dict_model = {}
 
234
 
235
  for model_name in LIST_MODEL_ID:
236
 
237
+ generator = BigVGAN.from_pretrained("nvidia/" + model_name)
 
238
  generator.remove_weight_norm()
239
+ generator.eval()
240
 
241
  dict_model[model_name] = generator
242
  dict_config[model_name] = generator.h
243
 
244
+ # Script for Gradio UI
245
 
246
+ iface = gr.Blocks(css=css, title="BigVGAN - Demo")
247
 
248
  with iface:
249
  gr.HTML(
 
254
  display: inline-flex;
255
  align-items: center;
256
  gap: 0.8rem;
257
+ font-size: 1.5rem;
258
  "
259
  >
260
+ <h1 style="font-weight: 700; margin-bottom: 7px; line-height: normal;">
261
  BigVGAN: A Universal Neural Vocoder with Large-Scale Training
262
  </h1>
263
  </div>
 
286
  <div>
287
  <h3>Model Overview</h3>
288
  BigVGAN is a universal neural vocoder model that generates audio waveforms using mel spectrogram as inputs.
289
+ <center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800" style="margin-top: 20px; border-radius: 15px;"></center>
290
  </div>
291
  """
292
  )
293
+ with gr.Accordion("Input"):
294
 
 
295
  model_choice = gr.Dropdown(
296
+ label="Select the model to use",
297
+ info="The default model is bigvgan_v2_24khz_100band_256x",
298
  value="bigvgan_v2_24khz_100band_256x",
299
  choices=[m for m in LIST_MODEL_ID],
300
  interactive=True,
 
304
  label="Input Audio", elem_id="input-audio", interactive=True
305
  )
306
 
307
+ button = gr.Button("Submit")
308
 
309
+ with gr.Accordion("Output"):
310
+ with gr.Column():
311
+ output_audio = gr.Audio(label="Output Audio", elem_id="output-audio")
312
+ output_image = gr.Image(
313
+ label="Output Mel Spectrogram", elem_id="output-image-gen"
314
+ )
315
 
316
+ button.click(
317
+ inference_gradio,
318
+ inputs=[audio_input, model_choice],
319
+ outputs=[output_audio, output_image],
320
+ concurrency_limit=10,
321
+ )
322
 
323
+ gr.Examples(
324
+ [
325
  [
326
+ os.path.join(os.path.dirname(__file__), "examples/jensen_24k.wav"),
327
+ "bigvgan_v2_24khz_100band_256x",
 
 
 
 
 
 
 
328
  ],
329
+ [
330
+ os.path.join(os.path.dirname(__file__), "examples/libritts_24k.wav"),
331
+ "bigvgan_v2_24khz_100band_256x",
332
+ ],
333
+ [
334
+ os.path.join(os.path.dirname(__file__), "examples/queen_24k.wav"),
335
+ "bigvgan_v2_24khz_100band_256x",
336
+ ],
337
+ [
338
+ os.path.join(os.path.dirname(__file__), "examples/dance_24k.wav"),
339
+ "bigvgan_v2_24khz_100band_256x",
340
+ ],
341
+ [
342
+ os.path.join(os.path.dirname(__file__), "examples/megalovania_24k.wav"),
343
+ "bigvgan_v2_24khz_100band_256x",
344
+ ],
345
+ [
346
+ os.path.join(os.path.dirname(__file__), "examples/hifitts_44k.wav"),
347
+ "bigvgan_v2_44khz_128band_256x",
348
+ ],
349
+ [
350
+ os.path.join(os.path.dirname(__file__), "examples/musdbhq_44k.wav"),
351
+ "bigvgan_v2_44khz_128band_256x",
352
+ ],
353
+ [
354
+ os.path.join(os.path.dirname(__file__), "examples/musiccaps1_44k.wav"),
355
+ "bigvgan_v2_44khz_128band_256x",
356
+ ],
357
+ [
358
+ os.path.join(os.path.dirname(__file__), "examples/musiccaps2_44k.wav"),
359
+ "bigvgan_v2_44khz_128band_256x",
360
+ ],
361
+ ],
362
+ fn=inference_gradio,
363
+ inputs=[audio_input, model_choice],
364
+ outputs=[output_audio, output_image],
365
+ )
366
 
367
+ # Define the data for the table
368
+ data = {
369
+ "Model Name": [
370
+ "bigvgan_v2_44khz_128band_512x",
371
+ "bigvgan_v2_44khz_128band_256x",
372
+ "bigvgan_v2_24khz_100band_256x",
373
+ "bigvgan_v2_22khz_80band_256x",
374
+ "bigvgan_v2_22khz_80band_fmax8k_256x",
375
+ "bigvgan_24khz_100band",
376
+ "bigvgan_base_24khz_100band",
377
+ "bigvgan_22khz_80band",
378
+ "bigvgan_base_22khz_80band",
379
+ ],
380
+ "Sampling Rate": [
381
+ "44 kHz",
382
+ "44 kHz",
383
+ "24 kHz",
384
+ "22 kHz",
385
+ "22 kHz",
386
+ "24 kHz",
387
+ "24 kHz",
388
+ "22 kHz",
389
+ "22 kHz",
390
+ ],
391
+ "Mel band": [128, 128, 100, 80, 80, 100, 100, 80, 80],
392
+ "fmax": [22050, 22050, 12000, 11025, 8000, 12000, 12000, 8000, 8000],
393
+ "Upsampling Ratio": [512, 256, 256, 256, 256, 256, 256, 256, 256],
394
+ "Parameters": [
395
+ "122M",
396
+ "112M",
397
+ "112M",
398
+ "112M",
399
+ "112M",
400
+ "112M",
401
+ "14M",
402
+ "112M",
403
+ "14M",
404
+ ],
405
+ "Dataset": [
406
+ "Large-scale Compilation",
407
+ "Large-scale Compilation",
408
+ "Large-scale Compilation",
409
+ "Large-scale Compilation",
410
+ "Large-scale Compilation",
411
+ "LibriTTS",
412
+ "LibriTTS",
413
+ "LibriTTS + VCTK + LJSpeech",
414
+ "LibriTTS + VCTK + LJSpeech",
415
+ ],
416
+ "Fine-Tuned": ["No", "No", "No", "No", "No", "No", "No", "No", "No"],
417
+ }
418
+
419
+ base_url = "https://huggingface.co/nvidia/"
420
+
421
+ df = pd.DataFrame(data)
422
+ df["Model Name"] = df["Model Name"].apply(
423
+ lambda x: f'<a href="{base_url}{x}">{x}</a>'
424
+ )
425
+
426
+ html_table = gr.HTML(
427
+ f"""
428
+ <div style="text-align: center;">
429
+ {df.to_html(index=False, escape=False, classes='border="1" cellspacing="0" cellpadding="5" style="margin-left: auto; margin-right: auto;')}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  <p><b>NOTE: The v1 models are trained using speech audio datasets ONLY! (24kHz models: LibriTTS, 22kHz models: LibriTTS + VCTK + LJSpeech).</b></p>
431
  </div>
432
  """