Antoniskaraolis committed
Commit d7aa11b · 1 Parent(s): 6c1a3af

Update app.py

Files changed (1)
app.py +18 -17
app.py CHANGED
@@ -1,22 +1,23 @@
-import torch
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
-import torchaudio
-
-def speech_recognition(audio_file_path):
-    tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
-    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
-
-    waveform, sample_rate = torchaudio.load(audio_file_path)
-
-    if sample_rate != 16000:
-        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
-        waveform = resampler(waveform)
-
-    input_values = tokenizer(waveform.squeeze().numpy(), return_tensors="pt", padding="longest").input_values
-    with torch.no_grad():
-        logits = model(input_values).logits
-
-    predicted_ids = torch.argmax(logits, dim=-1)
-    transcription = tokenizer.batch_decode(predicted_ids)
-
-    return transcription[0]
+import whisper
+import gradio as gr
+
+def transcribe_audio(file_info):
+    model = whisper.load_model("base")  # Choose the appropriate model size
+    audio = whisper.load_audio(file_info.name)
+    audio = whisper.pad_or_trim(audio)
+    mel = whisper.log_mel_spectrogram(audio).to(model.device)
+
+    _, probs = model.detect_language(mel)
+    language = max(probs, key=probs.get)
+    print(f"Detected language: {language}")
+
+    result = model.transcribe(mel)
+    return result["text"]
+
+iface = gr.Interface(
+    fn=transcribe_audio,
+    inputs=gr.inputs.Audio(source="microphone", type="file"),
+    outputs="text"
+)
+
+iface.launch()
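Two calls in the new version are worth flagging. In the openai-whisper API, `model.transcribe()` expects a file path or raw waveform array and computes its own spectrogram internally, so passing the precomputed `mel` would build a spectrogram of a spectrogram; a mel tensor should instead go through `whisper.decode()`. And `gr.inputs.Audio` is the legacy pre-Gradio-3 namespace, removed in current Gradio releases. A minimal corrected sketch, assuming the openai-whisper package and Gradio 4.x (where the microphone option is spelled `sources=["microphone"]`); the `"base"` model size and the detected-language printout are kept from the commit:

```python
import whisper
import gradio as gr

# Load the model once at startup rather than on every request.
model = whisper.load_model("base")

def transcribe_audio(filepath):
    # Read the recording and fit it to Whisper's fixed 30-second window.
    audio = whisper.load_audio(filepath)
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # Language detection operates on the mel spectrogram.
    _, probs = model.detect_language(mel)
    language = max(probs, key=probs.get)
    print(f"Detected language: {language}")

    # decode() consumes a precomputed mel; transcribe() would instead
    # take the raw `audio` array (or the file path) directly.
    options = whisper.DecodingOptions(fp16=False)  # FP32 so this also runs on CPU
    result = whisper.decode(model, mel, options)
    return result.text

iface = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs="text",
)

iface.launch()
```

Loading the model at module scope avoids re-initializing the weights on each request, and `type="filepath"` hands the handler a plain path string, so the `file_info.name` attribute access from the commit is no longer needed.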