techysanoj committed on
Commit
6ab3f9b
·
1 Parent(s): 9412793

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -29
app.py CHANGED
@@ -1,31 +1,37 @@
1
- import gradio as gr
2
- import soundfile as sf
3
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- # Load the pre-trained model and tokenizer
6
- model_name = "facebook/wav2vec2-large-960h-lv60-self"
7
- tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
8
- model = SpeechRecognitionModel("jonatasgrosman/wav2vec2-large-xlsr-53-english")
9
-
10
- # Define the speech recognition function
11
- def transcribe_audio(audio):
12
- audio_path = "audio.wav"
13
- sf.write(audio_path, audio, samplerate=16000)
14
- transcriptions = model.transcribe(audio_path)
15
- return transcriptions["transcription"]
16
-
17
- # Set up the Gradio interface
18
- audio_input = gr.inputs.Audio(source="microphone", type="numpy")
19
- text_output = gr.outputs.Textbox()
20
-
21
- interface = gr.Interface(
22
- fn=transcribe_audio,
23
- inputs=audio_input,
24
- outputs=text_output,
25
- title="Speech Recognition",
26
- description="Transcribe speech in real-time.",
27
- server_port=8000,
28
- )
29
-
30
- if __name__ == "__main__":
31
- interface.launch()
 
1
+ import torch
2
+ import torchaudio
3
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
4
+ import gradio as gr
5
+
6
# Both the acoustic model and its tokenizer come from the same checkpoint.
_CHECKPOINT = "facebook/wav2vec2-large-960h"
tokenizer = Wav2Vec2Tokenizer.from_pretrained(_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(_CHECKPOINT)
9
+
10
def transcribe_speech(audio_file, target_sample_rate=16000):
    """Transcribe a speech recording with the module-level Wav2Vec2 model.

    Args:
        audio_file: Path to an audio file, or a file-like object. Objects
            that expose a string ``name`` attribute (such as temporary-file
            wrappers) are resolved to that path before loading.
        target_sample_rate: Sampling rate, in Hz, the audio is resampled to
            before inference. Defaults to 16000.

    Returns:
        The decoded transcription as a string; an empty string when the
        file contains no samples.
    """
    # Upload widgets may hand over a temp-file wrapper instead of a path;
    # loading by path works regardless of the audio backend in use.
    name = getattr(audio_file, "name", None)
    source = name if isinstance(name, str) else audio_file

    waveform, sample_rate = torchaudio.load(source)

    # Nothing to transcribe; normalising an empty signal would yield NaNs.
    if waveform.numel() == 0:
        return ""

    # torchaudio returns (channels, frames); the tokenizer wants a single
    # 1-D signal, so average the channels down to mono.
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)

    # Bring the audio to the rate the model is fed at.
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, target_sample_rate
        )

    # A 1-D NumPy array is treated as one unbatched example, producing
    # input_values of shape (1, frames).
    input_values = tokenizer(waveform.numpy(), return_tensors="pt").input_values

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: most likely token per frame, then collapse.
    predicted_ids = torch.argmax(logits, dim=-1)
    return tokenizer.batch_decode(predicted_ids)[0]
26
+
27
+ # Define Gradio interface
28
def speech_recognition(audio_file):
    """Gradio callback: return the transcription of the uploaded audio."""
    return transcribe_speech(audio_file)
31
+
32
# Gradio wiring: one uploaded audio file in, its transcription out.
inputs = gr.inputs.Audio(label="Upload Audio File", type="file")
outputs = gr.outputs.Textbox(label="Transcription")
interface = gr.Interface(
    fn=speech_recognition,
    inputs=inputs,
    outputs=outputs,
)

# Launch the Gradio interface.
interface.launch()