anzorq commited on
Commit
800e3a8
·
verified ·
1 Parent(s): 7cdf3f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -20
app.py CHANGED
@@ -51,32 +51,22 @@ def wiener_filter(audio_tensor):
51
  return torch.tensor(filtered_audio, dtype=audio_tensor.dtype)
52
 
53
  @spaces.GPU
54
- def transcribe_speech(audio, apply_wiener_filter=False, apply_normalization=False, apply_spectral_gating=False, progress=gr.Progress()):
55
- if audio is None:
56
  return "No audio received.", None
57
 
58
  progress(0.1, desc="Preprocessing audio...")
59
- audio_tensor, original_sample_rate = torchaudio.load(audio)
60
  audio_tensor = preprocess_audio(audio_tensor, original_sample_rate, apply_normalization)
61
 
62
- if apply_wiener_filter:
63
- progress(0.3, desc="Applying Wiener filter...")
64
- audio_tensor = wiener_filter(audio_tensor)
65
-
66
- if apply_spectral_gating:
67
- progress(0.5, desc="Applying Spectral Gating filter...")
68
- audio_tensor = spectral_gating(audio_tensor)
69
-
70
  progress(0.7, desc="Transcribing audio...")
71
  audio_np = audio_tensor.numpy().squeeze()
72
  transcription = pipe(audio_np, chunk_length_s=10)['text']
73
  transcription = replace_symbols_back(transcription)
74
 
75
- audio_np = audio_tensor.numpy().squeeze()
76
- sf.write("temp_audio.wav", audio_np, 16000, subtype='PCM_16')
77
-
78
- return transcription, "temp_audio.wav"
79
 
 
80
  def transcribe_from_youtube(url, apply_wiener_filter, apply_normalization, apply_spectral_gating, progress=gr.Progress()):
81
  progress(0, "Downloading YouTube audio...")
82
 
@@ -95,10 +85,13 @@ def transcribe_from_youtube(url, apply_wiener_filter, apply_normalization, apply
95
  audio_tensor = wiener_filter(audio_tensor)
96
 
97
  if apply_spectral_gating:
98
- progress(0.4, "Applying Spectral Gating filter...")
99
  audio_tensor = spectral_gating(audio_tensor)
100
 
101
- transcription, _ = transcribe_speech(audio_tensor)
 
 
 
102
 
103
  audio_np = audio_tensor.numpy().squeeze()
104
  sf.write("temp_audio.wav", audio_np, 16000, subtype='PCM_16')
@@ -106,7 +99,7 @@ def transcribe_from_youtube(url, apply_wiener_filter, apply_normalization, apply
106
  except Exception as e:
107
  return str(e), None
108
 
109
- return transcription, "temp_audio.wav"
110
 
111
  def populate_metadata(url):
112
  yt = YouTube(url)
@@ -131,9 +124,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
131
  mic_audio = gr.Audio(sources=['microphone','upload'], type="filepath", label="Record or upload an audio")
132
  transcribe_button = gr.Button("Transcribe")
133
  transcription_output = gr.Textbox(label="Transcription")
134
- audio_output = gr.Audio(label="Processed Audio")
135
 
136
- transcribe_button.click(fn=transcribe_speech, inputs=[mic_audio], outputs=[transcription_output, audio_output])
137
 
138
  with gr.Tab("YouTube URL"):
139
  gr.Markdown("## Transcribe speech from YouTube video")
 
51
  return torch.tensor(filtered_audio, dtype=audio_tensor.dtype)
52
 
53
  @spaces.GPU
54
+ def transcribe_speech(audio_path, progress=gr.Progress()):
55
+ if audio_path is None:
56
  return "No audio received.", None
57
 
58
  progress(0.1, desc="Preprocessing audio...")
59
+ audio_tensor, original_sample_rate = torchaudio.load(audio_path)
60
  audio_tensor = preprocess_audio(audio_tensor, original_sample_rate, apply_normalization)
61
 
 
 
 
 
 
 
 
 
62
  progress(0.7, desc="Transcribing audio...")
63
  audio_np = audio_tensor.numpy().squeeze()
64
  transcription = pipe(audio_np, chunk_length_s=10)['text']
65
  transcription = replace_symbols_back(transcription)
66
 
67
+ return transcription
 
 
 
68
 
69
+ @spaces.GPU
70
  def transcribe_from_youtube(url, apply_wiener_filter, apply_normalization, apply_spectral_gating, progress=gr.Progress()):
71
  progress(0, "Downloading YouTube audio...")
72
 
 
85
  audio_tensor = wiener_filter(audio_tensor)
86
 
87
  if apply_spectral_gating:
88
+ progress(0.6, "Applying Spectral Gating filter...")
89
  audio_tensor = spectral_gating(audio_tensor)
90
 
91
+ progress(0.8, "Transcribing audio...")
92
+ audio_np = audio_tensor.numpy().squeeze()
93
+ transcription = pipe(audio_np, chunk_length_s=10)['text']
94
+ transcription = replace_symbols_back(transcription)
95
 
96
  audio_np = audio_tensor.numpy().squeeze()
97
  sf.write("temp_audio.wav", audio_np, 16000, subtype='PCM_16')
 
99
  except Exception as e:
100
  return str(e), None
101
 
102
+ return transcription, "temp_audio.wav"
103
 
104
  def populate_metadata(url):
105
  yt = YouTube(url)
 
124
  mic_audio = gr.Audio(sources=['microphone','upload'], type="filepath", label="Record or upload an audio")
125
  transcribe_button = gr.Button("Transcribe")
126
  transcription_output = gr.Textbox(label="Transcription")
 
127
 
128
+ transcribe_button.click(fn=transcribe_speech, inputs=[mic_audio], outputs=[transcription_output])
129
 
130
  with gr.Tab("YouTube URL"):
131
  gr.Markdown("## Transcribe speech from YouTube video")