juliuserictuliao committed
Commit 28cd99e · verified · 1 Parent(s): 24b2ad1

Update app.py

Files changed (1)
  1. app.py +43 -43
app.py CHANGED
@@ -1,52 +1,52 @@
-#Importing all the necessary packages
+# Importing all the necessary packages
 import nltk
-import librosa
 import torch
 import gradio as gr
-from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+import numpy as np
+
+# Downloading the necessary NLTK data
 nltk.download("punkt")
 
-#Loading the pre-trained model and the tokenizer
+# Loading the pre-trained model and the processor
 model_name = "facebook/wav2vec2-base-960h"
-tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
+processor = Wav2Vec2Processor.from_pretrained(model_name)
 model = Wav2Vec2ForCTC.from_pretrained(model_name)
 
-def load_data(input_file):
-
-    #reading the file
-    speech, sample_rate = librosa.load(input_file)
-    #make it 1-D
-    if len(speech.shape) > 1:
-        speech = speech[:,0] + speech[:,1]
-    #Resampling the audio at 16KHz
-    if sample_rate !=16000:
-        speech = librosa.resample(speech, sample_rate,16000)
-    return speech
-
 def correct_casing(input_sentence):
-
-    sentences = nltk.sent_tokenize(input_sentence)
-    return (' '.join([s.replace(s[0],s[0].capitalize(),1) for s in sentences]))
-
-def asr_transcript(input_file):
-
-    speech = load_data(input_file)
-    #Tokenize
-    input_values = tokenizer(speech, return_tensors="pt").input_values
-    #Take logits
-    logits = model(input_values).logits
-    #Take argmax
-    predicted_ids = torch.argmax(logits, dim=-1)
-    #Get the words from predicted word ids
-    transcription = tokenizer.decode(predicted_ids[0])
-    #Correcting the letter casing
-    transcription = correct_casing(transcription.lower())
-    return transcription
-
-gr.Interface(asr_transcript,
-             inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Speaker"),
-             outputs = gr.outputs.Textbox(label="Output Text"),
-             title="ASR using Wav2Vec 2.0",
-             description = "This application displays transcribed text for given audio input",
-             examples = [["Test_File1.wav"], ["Test_File2.wav"], ["Test_File3.wav"]], theme="grass").launch()
-
+    sentences = nltk.sent_tokenize(input_sentence)
+    return ' '.join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences])
+
+def asr_transcript(audio):
+    # Process the audio
+    input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
+    # Get logits
+    logits = model(input_values).logits
+    # Get predicted IDs
+    predicted_ids = torch.argmax(logits, dim=-1)
+    # Decode the IDs to text
+    transcription = processor.decode(predicted_ids[0])
+    # Correct the casing
+    transcription = correct_casing(transcription.lower())
+    return transcription
+
+def real_time_asr(audio, state=None):
+    if state is None:
+        state = ""
+    audio = np.array(audio)
+    transcription = asr_transcript(audio)
+    state += " " + transcription
+    return state, state
+
+# Create the Gradio interface
+iface = gr.Interface(
+    fn=real_time_asr,
+    inputs=[gr.Audio(source="microphone", streaming=True), "state"],
+    outputs="text",
+    live=True,
+    title="Real-Time ASR using Wav2Vec 2.0",
+    description="This application displays transcribed text in real-time for given audio input"
+)
+
+# Launch the interface
+iface.launch()
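
A practical note on the streaming path above: with Gradio's default numpy audio handling, each microphone chunk reaches the callback as a (sample_rate, data) tuple of int16 samples, typically captured at a higher rate than the 16 kHz that processor(..., sampling_rate=16000) assumes. A minimal adapter along the following lines could normalize each chunk before it reaches asr_transcript; the helper name prepare_chunk, the int16 scaling, and the librosa resampling step are illustrative assumptions, not part of this commit.

import numpy as np
import librosa

def prepare_chunk(chunk, target_sr=16000):
    # Hypothetical helper: Gradio's numpy audio format is a (rate, samples) tuple
    sample_rate, data = chunk
    speech = data.astype(np.float32) / 32768.0  # assume int16 input; scale to [-1.0, 1.0]
    if speech.ndim > 1:
        speech = speech.mean(axis=1)            # down-mix stereo to mono
    if sample_rate != target_sr:
        # Match the 16 kHz rate Wav2Vec2 was trained on
        speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=target_sr)
    return speech

With an adapter like this, real_time_asr could call asr_transcript(prepare_chunk(audio)) in place of the bare np.array(audio) conversion.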