|
import gradio as gr |
|
import librosa |
|
import soundfile as sf |
|
import torch |
|
import warnings |
|
import os |
|
from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2CTCTokenizer |
|
|
|
# Suppress every warning category process-wide. This is presumably meant to
# keep third-party library noise out of the demo's logs. It also hides
# genuinely useful warnings, so consider narrowing it by category/module.
warnings.simplefilter("ignore")
|
|
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
from fastapi import FastAPI, HTTPException, File |
|
|
|
from transformers import pipeline |
|
|
|
|
|
|
|
|
|
# Shared chunking settings for all ASR pipelines:
# - audio is split into 20 s windows;
# - each window gets 3 s of extra context on the left and right;
# - that extra context is discarded when the windows are stitched back together.
# This lets recordings longer than one window be transcribed.
_ASR_CHUNK_KWARGS = {"chunk_length_s": 20, "stride_length_s": (3, 3)}

# The three selectable speech-recognition models, named by parameter count.
pipe_95m = pipeline(model="Finnish-NLP/wav2vec2-base-fi-voxpopuli-v2-finetuned", **_ASR_CHUNK_KWARGS)
pipe_300m = pipeline(model="Finnish-NLP/wav2vec2-large-uralic-voxpopuli-v2-finnish", **_ASR_CHUNK_KWARGS)
pipe_1b = pipeline(model="Finnish-NLP/wav2vec2-xlsr-1b-finnish-lm-v2", **_ASR_CHUNK_KWARGS)

# NOTE(review): the ASR pipelines above are built without a `device` argument,
# so they presumably run on the transformers default (CPU). Only the T5
# correction model below is moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seq2seq model that restores casing and punctuation on the raw ASR output.
model_checkpoint = 'Finnish-NLP/t5-small-nl24-casing-punctuation-correction'

# Optional Hugging Face access token read from the `hf_token` environment
# variable (None when unset). The same token is passed to BOTH the tokenizer
# and the model: if the checkpoint requires auth, loading only the model with
# it would still fail at the tokenizer.
_hf_token = os.environ.get('hf_token')

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_auth_token=_hf_token)

model = AutoModelForSeq2SeqLM.from_pretrained(
    model_checkpoint,
    from_flax=False,
    torch_dtype=torch.float32,
    use_auth_token=_hf_token,
).to(device)
|
|
|
|
|
def _select_asr_pipeline(model_params):
    """Return the ASR pipeline for a dropdown label, or None if the label is unknown.

    A module-level pipeline name is only evaluated when its label matches, so
    an unknown label never touches the loaded models.
    """
    if model_params == "1 billion":
        return pipe_1b
    if model_params == "300 million":
        return pipe_300m
    if model_params == "95 million":
        return pipe_95m
    return None


def _correct_casing_and_punctuation(text):
    """Run the T5 correction model over ``text`` and return its decoded output."""
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    # Scale the generation budget with the input length. A fixed 128-token cap
    # would silently cut off the corrected text for long, multi-chunk
    # transcripts. Generation still stops early at the end-of-sequence token.
    max_length = max(128, 2 * input_ids.shape[-1])
    outputs = model.generate(input_ids, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def asr_transcript(audio, audio_microphone, model_params):
    """Transcribe audio and return ``(raw_text, corrected_text)`` for the two output boxes.

    Args:
        audio: Uploaded-file value from Gradio (an object with a ``.name``
            path), or None when nothing was uploaded.
        audio_microphone: Microphone-recording value in the same form, or None.
            When present it takes precedence over ``audio``.
        model_params: Dropdown label selecting the ASR model:
            "95 million", "300 million" or "1 billion".

    Returns:
        A 2-tuple of strings. On success these are the raw ASR text and the
        casing/punctuation-corrected text. On invalid input, both entries hold
        a user-facing message, because the Interface always expects one value
        per output textbox.
    """
    # A microphone recording wins over an uploaded file.
    audio = audio_microphone if audio_microphone else audio

    if audio is None:
        return (
            "Please provide audio (wav or mp3) by uploading a file or by recording audio using microphone by pressing Record (And allow usage of microphone)",
            "Please provide audio by uploading a file or by recording audio using microphone by pressing Record (And allow usage of microphone)",
        )

    if not audio:
        return "File not valid", "File not valid"

    asr_pipeline = _select_asr_pipeline(model_params)
    if asr_pipeline is None:
        # Previously this fell through with text == "" and crashed on text['text'].
        message = f"Unknown speech recognition model: {model_params!r}"
        return message, message

    raw_text = asr_pipeline(audio.name)["text"]

    if not raw_text.strip():
        # Nothing was recognised. Skip the seq2seq pass rather than letting
        # it generate from an empty prompt.
        return raw_text, raw_text

    return raw_text, _correct_casing_and_punctuation(raw_text)
|
|
|
# Web UI: two optional audio inputs (upload / microphone) plus a model
# selector, mapped onto asr_transcript's three parameters in order. Its two
# return values fill the two output textboxes.
# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio namespaces, which
# are removed in Gradio 4. Verify the pinned gradio version before upgrading.
gradio_ui = gr.Interface(
    fn=asr_transcript,
    title="Finnish Automatic Speech Recognition",
    description="Upload an audio clip or record from browser using microphone, and let AI do the hard work of transcribing.",
    article="""
This demo runs two kinds of models in sequence. First, the selected ASR model performs speech recognition, which produces lowercase text without punctuation.
After that, a sequence-to-sequence model corrects the casing and punctuation to produce the final output.

You can select one of the three speech recognition models listed below:

1. 1 billion: best accuracy, but by far the slowest. Based on the multilingual wav2vec2-xlsr model by Meta. More info here https://huggingface.co/Finnish-NLP/wav2vec2-xlsr-1b-finnish-lm-v2
2. 300 million: on par in accuracy with model 1, but a lot faster. Based on the Uralic wav2vec2 model by Meta. More info here https://huggingface.co/Finnish-NLP/wav2vec2-large-uralic-voxpopuli-v2-finnish
3. 95 million: almost as accurate as model 1, and much faster still. Based on the Finnish wav2vec2 model by Meta. More info here https://huggingface.co/Finnish-NLP/wav2vec2-base-fi-voxpopuli-v2-finetuned

More info about the casing+punctuation correction model can be found here https://huggingface.co/Finnish-NLP/t5-small-nl24-casing-punctuation-correction
""",
    inputs=[
        gr.inputs.Audio(label="Upload Audio File", type="file", optional=True),
        gr.inputs.Audio(source="microphone", type="file", optional=True, label="Record from microphone"),
        gr.inputs.Dropdown(
            choices=["95 million", "300 million", "1 billion"],
            type="value",
            default="300 million",
            label="Select speech recognition model parameter amount",
            optional=False,
        ),
    ],
    outputs=[
        gr.outputs.Textbox(label="Recognized speech"),
        gr.outputs.Textbox(label="Recognized speech with case correction and punctuation"),
    ],
)

gradio_ui.launch()