import spaces # Required for ZeroGPU compliance
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import os
import soundfile as sf
from scipy.signal import resample

# Model ID and Hugging Face Token
MODEL_ID = "WMRNORDIC/whisper-swedish-telephonic"
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
if not HF_API_TOKEN:
    raise ValueError("HF_API_TOKEN not found. Set it in the environment variables.")

# Sample file path
SAMPLE_FILE_PATH = "trimmed_resampled_audio.wav"  # Update this path if necessary
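# NOTE: the sample WAV is assumed to be committed alongside app.py in this Space;
# if it is missing, the "Transcribe Sample" controls below will have nothing to play.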


@spaces.GPU
def initialize_model():
    """Lazy initialization of model and processor with GPU allocation."""
    print("Initializing model and processor...")
    processor = WhisperProcessor.from_pretrained(MODEL_ID, token=HF_API_TOKEN)
    model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, token=HF_API_TOKEN)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    print(f"Model loaded on device: {device}")
    return processor, model


@spaces.GPU
def transcribe_audio(audio):
    """Transcription logic with ZeroGPU compliance."""
    try:
        # Lazy-load the model and processor on first use
        global processor, model
        if "processor" not in globals() or "model" not in globals():
            processor, model = initialize_model()

        # Handle audio input: either a (sample_rate, data) tuple, or a path to a
        # recorded/uploaded file (the gr.Audio components here use type="filepath")
        if isinstance(audio, tuple):
            sample_rate, audio_data = audio
        else:
            audio_data, sample_rate = sf.read(audio)

        # Downmix multi-channel audio to mono; the processor expects a 1-D signal
        if getattr(audio_data, "ndim", 1) > 1:
            audio_data = audio_data.mean(axis=1)

        # Resample to the 16 kHz rate the Whisper feature extractor expects
        if sample_rate != 16000:
            num_samples = int(len(audio_data) * 16000 / sample_rate)
            audio_data = resample(audio_data, num_samples)

        # Prepare inputs for the model
        device = "cuda" if torch.cuda.is_available() else "cpu"
        input_features = processor(
            audio_data, return_tensors="pt", sampling_rate=16000
        ).input_features.to(device)

        # Generate transcription
        with torch.no_grad():
            predicted_ids = model.generate(input_features)

        # Decode and return the transcription
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        return transcription
    except Exception as e:
        return f"Error during transcription: {str(e)}"


# Gradio Interface
def create_demo():
    """Set up the Gradio app."""
    with gr.Blocks() as demo:
        gr.Markdown("## Swedish Telephonic Speech-to-Text")
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Record via Microphone or Upload a File")
            transcription_output = gr.Textbox(label="Transcription")
        with gr.Row():
            if SAMPLE_FILE_PATH:  # Ensure the sample file path is set
                gr.Markdown("### Sample Audio File")
                sample_audio = gr.Audio(value=SAMPLE_FILE_PATH, label="Sample Audio File", type="filepath")
                transcribe_button = gr.Button("Transcribe Sample")
                transcribe_button.click(transcribe_audio, inputs=sample_audio, outputs=transcription_output)
        # Connect audio input to transcription logic
        audio_input.change(transcribe_audio, inputs=audio_input, outputs=transcription_output)
    return demo


# Initialize Gradio app
demo = create_demo()

# Launch the app
if __name__ == "__main__":
    demo.launch()
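
# To run locally (assumes HF_API_TOKEN is exported and the model repo is accessible):
#   HF_API_TOKEN=<your token> python app.py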