File size: 3,415 Bytes
902e3e0
b04db1b
4a6298a
 
b04db1b
535f29e
2191fe9
b04db1b
902e3e0
4a6298a
d03b85c
902e3e0
1e79bc7
902e3e0
1e79bc7
d03b85c
 
 
902e3e0
 
0a32df4
902e3e0
 
945e29c
 
902e3e0
 
 
945e29c
b519c08
902e3e0
4a6298a
902e3e0
535f29e
902e3e0
b28e9ae
 
 
902e3e0
 
8a71a48
902e3e0
8a71a48
4a6298a
 
902e3e0
 
 
 
 
 
945e29c
 
902e3e0
 
4a6298a
 
8a71a48
f327118
38be9f3
4a6298a
f327118
535f29e
c6a7213
535f29e
902e3e0
535f29e
f327118
38be9f3
c6a7213
1c8e9da
 
 
 
 
 
 
 
 
f327118
38be9f3
ed2c51e
b04db1b
902e3e0
f327118
535f29e
 
f327118
b04db1b
f327118
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import spaces  # Required for ZeroGPU compliance
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import os
import soundfile as sf
from scipy.signal import resample

# Hugging Face model identifier, plus the access token (taken from the
# environment) that is passed along when the model files are downloaded.
MODEL_ID = "WMRNORDIC/whisper-swedish-telephonic"
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")

# Abort at import time when the token is unset or empty.
if HF_API_TOKEN is None or HF_API_TOKEN == "":
    raise ValueError("HF_API_TOKEN not found. Set it in the environment variables.")

# Audio file offered as a ready-made sample in the UI.
SAMPLE_FILE_PATH = "trimmed_resampled_audio.wav"  # Update this path if necessary


@spaces.GPU
def initialize_model():
    """Load the Whisper processor and model and move the model to a device.

    Returns:
        A ``(processor, model)`` tuple. The model is placed on ``"cuda"``
        when ``torch.cuda.is_available()`` is true, otherwise on ``"cpu"``.
    """
    print("Initializing model and processor...")

    # Both downloads authenticate with the same token.
    auth_kwargs = {"token": HF_API_TOKEN}
    whisper_processor = WhisperProcessor.from_pretrained(MODEL_ID, **auth_kwargs)
    whisper_model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, **auth_kwargs)

    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    whisper_model = whisper_model.to(target_device)

    print(f"Model loaded on device: {target_device}")
    return whisper_processor, whisper_model

@spaces.GPU
def transcribe_audio(audio):
    """Transcribe one audio clip to text with the lazily loaded Whisper model.

    Args:
        audio: A path to an audio file readable by ``soundfile``, a
            ``(sample_rate, samples)`` tuple, or ``None`` when the audio
            widget has been cleared.

    Returns:
        The decoded transcription; an empty string when no audio was
        supplied; or an ``"Error during transcription: ..."`` message when
        anything fails (errors are shown in the UI rather than raised).
    """
    # Clearing the audio widget fires the change event without a value;
    # answer with an empty result instead of surfacing a read error.
    if audio is None:
        return ""

    try:
        # Lazy-load model and processor once and keep them as module globals.
        global processor, model
        if 'processor' not in globals() or 'model' not in globals():
            processor, model = initialize_model()

        if isinstance(audio, tuple):  # (sample_rate, samples) input
            sample_rate, audio_data = audio
            # Integer PCM must be scaled to [-1.0, 1.0] before feature
            # extraction; assumes the samples are a NumPy array.
            sample_kind = audio_data.dtype.kind
            if sample_kind in "iu":
                full_scale = float(2 ** (8 * audio_data.dtype.itemsize - 1))
                audio_data = audio_data.astype("float32")
                if sample_kind == "u":
                    # Unsigned PCM is centred on mid-scale, not on zero.
                    audio_data -= full_scale
                audio_data /= full_scale
        else:  # File path
            audio_data, sample_rate = sf.read(audio)

        # Multi-channel audio arrives as (frames, channels); the model expects
        # a single channel, so average the channels down to mono.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)

        # Resample to the 16 kHz rate passed to the processor below.
        if sample_rate != 16000:
            num_samples = int(len(audio_data) * 16000 / sample_rate)
            audio_data = resample(audio_data, num_samples)

        # Put the features on whatever device the model actually lives on,
        # rather than re-deriving it from CUDA availability.
        input_features = processor(
            audio_data, return_tensors="pt", sampling_rate=16000
        ).input_features.to(model.device)

        # Generate transcription without tracking gradients.
        with torch.no_grad():
            predicted_ids = model.generate(input_features)

        # Decode the first (only) sequence of the batch.
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        return transcription

    except Exception as e:
        # Best-effort UI boundary: report the failure as text.
        return f"Error during transcription: {str(e)}"

# Gradio Interface
def create_demo():
    """Build the Gradio Blocks UI for the transcription app.

    The layout has an audio input wired to ``transcribe_audio`` on change and
    a textbox for the result. A sample-audio row with its own "Transcribe
    Sample" button is added only when ``SAMPLE_FILE_PATH`` points to an
    existing file, so a missing sample does not break app construction.

    Returns:
        The assembled ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Swedish Telephonic Speech-to-Text")
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Record via Microphone or Upload a File")
            transcription_output = gr.Textbox(label="Transcription")

        # A non-empty path alone is not enough: only offer the sample when the
        # file is really there, and skip the (otherwise empty) row if not.
        if SAMPLE_FILE_PATH and os.path.isfile(SAMPLE_FILE_PATH):
            with gr.Row():
                gr.Markdown("### Sample Audio File")
                sample_audio = gr.Audio(value=SAMPLE_FILE_PATH, label="Sample Audio File", type="filepath")
                transcribe_button = gr.Button("Transcribe Sample")
                transcribe_button.click(transcribe_audio, inputs=sample_audio, outputs=transcription_output)

        # Connect audio input to transcription logic
        audio_input.change(transcribe_audio, inputs=audio_input, outputs=transcription_output)
    return demo


# Build the Gradio app at import time and bind it to the module-level name
# `demo`, so it exists whether or not this file is run as a script.
demo = create_demo()

# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()