# Hugging Face Space status: Running on Zero
import os

import spaces  # Required for ZeroGPU compliance
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from scipy.signal import resample
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Hugging Face model repository and the access token used to download it
MODEL_ID = "WMRNORDIC/whisper-swedish-telephonic"
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
# Fail at import time when the token is unset or empty
if HF_API_TOKEN is None or HF_API_TOKEN == "":
    raise ValueError("HF_API_TOKEN not found. Set it in the environment variables.")

# Path of the sample audio clip offered in the UI
SAMPLE_FILE_PATH = "trimmed_resampled_audio.wav"  # Update this path if necessary
def initialize_model():
    """Load the Whisper processor and model and place the model on a device.

    Both objects are fetched from ``MODEL_ID`` using ``HF_API_TOKEN``. The
    model is moved to CUDA when ``torch.cuda.is_available()`` reports a GPU,
    otherwise it stays on the CPU.

    Returns:
        A ``(processor, model)`` tuple.
    """
    print("Initializing model and processor...")
    auth = {"token": HF_API_TOKEN}
    whisper_processor = WhisperProcessor.from_pretrained(MODEL_ID, **auth)
    whisper_model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, **auth)

    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    whisper_model = whisper_model.to(target_device)

    print(f"Model loaded on device: {target_device}")
    return whisper_processor, whisper_model
def _load_waveform(audio):
    """Return ``(samples, sample_rate)`` for a tuple or file-path audio value."""
    if isinstance(audio, tuple):
        # Tuple input is laid out as (sample_rate, samples).
        sample_rate, samples = audio[0], audio[1]
    else:
        samples, sample_rate = sf.read(audio)
    return np.asarray(samples), int(sample_rate)


def _to_mono_16k(samples, sample_rate, target_rate=16000):
    """Convert raw samples to a mono float32 waveform at ``target_rate`` Hz."""
    kind = samples.dtype.kind
    if kind == "i":
        # Signed integer PCM: scale into [-1.0, 1.0).
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
        samples = samples.astype(np.float32) / full_scale
    elif kind == "u":
        # Unsigned integer PCM: remove the mid-point offset, then scale.
        half_scale = (float(np.iinfo(samples.dtype).max) + 1.0) / 2.0
        samples = (samples.astype(np.float32) - half_scale) / half_scale
    else:
        samples = samples.astype(np.float32)

    if samples.ndim > 1:
        # Samples are (frames, channels); average the channels down to mono.
        samples = samples.mean(axis=1)

    if samples.size and sample_rate != target_rate:
        num_samples = max(1, int(len(samples) * target_rate / sample_rate))
        samples = resample(samples, num_samples)

    return samples.astype(np.float32)


def transcribe_audio(audio):
    """Transcribe an audio clip with the Whisper model.

    The processor and model are created on first use via
    ``initialize_model()`` and cached in the module globals ``processor`` and
    ``model``.

    Args:
        audio: Either a ``(sample_rate, samples)`` tuple or a path readable by
            ``soundfile``. ``None`` (no audio selected) is accepted.

    Returns:
        The decoded transcription, ``"No audio provided."`` when there is
        nothing to transcribe, or an ``"Error during transcription: ..."``
        message if any step raises.
    """
    # NOTE(review): `spaces` is imported for ZeroGPU but no function here is
    # wrapped with `@spaces.GPU` -- confirm whether this entry point needs it.
    global processor, model

    if audio is None:
        # Nothing to transcribe (e.g. the audio widget was cleared).
        return "No audio provided."

    try:
        # Lazy-load model and processor on the first request.
        if "processor" not in globals() or "model" not in globals():
            processor, model = initialize_model()

        samples, sample_rate = _load_waveform(audio)
        waveform = _to_mono_16k(samples, sample_rate)
        if waveform.size == 0:
            return "No audio provided."

        features = processor(waveform, return_tensors="pt", sampling_rate=16000)
        # Follow the model's actual device instead of re-detecting it, so the
        # inputs can never land on a different device than the weights.
        input_features = features.input_features.to(model.device)

        with torch.no_grad():
            predicted_ids = model.generate(input_features)

        return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    except Exception as e:
        # UI boundary: report the failure in the output box instead of raising.
        return f"Error during transcription: {str(e)}"
# Gradio Interface
def create_demo():
    """Build the Gradio Blocks UI for the transcription app.

    The layout has a row with an audio input and a transcription text box,
    followed by a row that offers the sample clip at ``SAMPLE_FILE_PATH``
    with a "Transcribe Sample" button. Changing the audio input, or clicking
    the sample button, calls ``transcribe_audio`` and writes the result to
    the text box.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Swedish Telephonic Speech-to-Text")

        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Record via Microphone or Upload a File")
            transcription_output = gr.Textbox(label="Transcription")

        with gr.Row():
            # Skip the sample section when the path is unset or the file is
            # absent, rather than handing Gradio a path that does not exist.
            if SAMPLE_FILE_PATH and os.path.isfile(SAMPLE_FILE_PATH):
                gr.Markdown("### Sample Audio File")
                sample_audio = gr.Audio(value=SAMPLE_FILE_PATH, label="Sample Audio File", type="filepath")
                transcribe_button = gr.Button("Transcribe Sample")
                transcribe_button.click(transcribe_audio, inputs=sample_audio, outputs=transcription_output)

        # Connect audio input to transcription logic
        audio_input.change(transcribe_audio, inputs=audio_input, outputs=transcription_output)

    return demo
# Build the Blocks app at import time and bind it to the module-level name `demo`
demo = create_demo()

# Start the server only when this file is executed as a script
if __name__ == "__main__":
    demo.launch()