Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,415 Bytes
902e3e0 b04db1b 4a6298a b04db1b 535f29e 2191fe9 b04db1b 902e3e0 4a6298a d03b85c 902e3e0 1e79bc7 902e3e0 1e79bc7 d03b85c 902e3e0 0a32df4 902e3e0 945e29c 902e3e0 945e29c b519c08 902e3e0 4a6298a 902e3e0 535f29e 902e3e0 b28e9ae 902e3e0 8a71a48 902e3e0 8a71a48 4a6298a 902e3e0 945e29c 902e3e0 4a6298a 8a71a48 f327118 38be9f3 4a6298a f327118 535f29e c6a7213 535f29e 902e3e0 535f29e f327118 38be9f3 c6a7213 1c8e9da f327118 38be9f3 ed2c51e b04db1b 902e3e0 f327118 535f29e f327118 b04db1b f327118 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import spaces # Required for ZeroGPU compliance
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import os
import soundfile as sf
from scipy.signal import resample
# Model ID and Hugging Face Token
MODEL_ID = "WMRNORDIC/whisper-swedish-telephonic"  # fine-tuned Whisper checkpoint on the HF Hub
HF_API_TOKEN = os.getenv("HF_API_TOKEN")  # read from the environment; never hard-code credentials
if not HF_API_TOKEN:
    # Fail fast at import time: the model download below requires authentication.
    raise ValueError("HF_API_TOKEN not found. Set it in the environment variables.")
# Sample file path
SAMPLE_FILE_PATH = "trimmed_resampled_audio.wav"  # Update this path if necessary
@spaces.GPU
def initialize_model():
    """Load the Whisper processor and model, placing the model on GPU when available.

    Returns:
        A ``(processor, model)`` pair ready for inference.
    """
    print("Initializing model and processor...")
    loaded_processor = WhisperProcessor.from_pretrained(MODEL_ID, token=HF_API_TOKEN)
    loaded_model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, token=HF_API_TOKEN)

    # Prefer CUDA if the ZeroGPU allocation gave us a device; otherwise stay on CPU.
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    loaded_model = loaded_model.to(target_device)

    print(f"Model loaded on device: {target_device}")
    return loaded_processor, loaded_model
@spaces.GPU
def transcribe_audio(audio):
    """Transcribe a Swedish audio clip with the fine-tuned Whisper model.

    Args:
        audio: Either a filepath string (uploaded file) or a
            ``(sample_rate, data)`` tuple as produced by a Gradio microphone
            component.

    Returns:
        The transcription string on success, or an error-message string on
        failure (Gradio renders whatever string is returned).
    """
    try:
        # Lazy-load model and processor once per process. Loading happens inside
        # a @spaces.GPU call so the ZeroGPU allocation is active during .to().
        global processor, model
        if 'processor' not in globals() or 'model' not in globals():
            processor, model = initialize_model()

        # Handle audio input: a Gradio microphone yields (sample_rate, data).
        if isinstance(audio, tuple):
            audio_data, sample_rate = audio[1], audio[0]
        else:  # Uploaded file path
            audio_data, sample_rate = sf.read(audio)

        # Fix: soundfile returns a (frames, channels) 2-D array for stereo
        # files, but WhisperProcessor expects a mono 1-D waveform. Downmix by
        # averaging channels before any further processing.
        if getattr(audio_data, "ndim", 1) > 1:
            audio_data = audio_data.mean(axis=1)

        # Whisper is trained on 16 kHz audio; resample anything else.
        if sample_rate != 16000:
            num_samples = int(len(audio_data) * 16000 / sample_rate)
            audio_data = resample(audio_data, num_samples)

        # Move input features to the same device the model lives on.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        input_features = processor(audio_data, return_tensors="pt", sampling_rate=16000).input_features.to(device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            predicted_ids = model.generate(input_features)

        # Decode token ids back to text; take the single batch element.
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        return transcription
    except Exception as e:
        # Surface the error in the UI instead of crashing the Space.
        return f"Error during transcription: {str(e)}"
# Gradio Interface
def create_demo():
    """Build and return the Gradio Blocks UI for the transcription app."""
    with gr.Blocks() as demo:
        gr.Markdown("## Swedish Telephonic Speech-to-Text")

        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Record via Microphone or Upload a File")
            transcription_output = gr.Textbox(label="Transcription")

        with gr.Row():
            # Only render the sample section when a sample path is configured.
            if SAMPLE_FILE_PATH:
                gr.Markdown("### Sample Audio File")
                sample_audio = gr.Audio(value=SAMPLE_FILE_PATH, label="Sample Audio File", type="filepath")
                sample_button = gr.Button("Transcribe Sample")
                sample_button.click(transcribe_audio, inputs=sample_audio, outputs=transcription_output)

        # Re-run transcription whenever the main audio input changes.
        audio_input.change(transcribe_audio, inputs=audio_input, outputs=transcription_output)
    return demo
# Initialize Gradio app
# Built at import time so the Spaces runtime can find the module-level `demo`.
demo = create_demo()
# Launch the app
if __name__ == "__main__":
    demo.launch()
|