# Spaces: Sleeping
import multiprocessing
import platform
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

# NOTE(review): `spaces` stays first among the third-party imports, as in the
# original ordering; it presumably needs to load before torch - confirm.
import spaces
import gradio as gr
import librosa
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
@dataclass
class ModelConfig:
    """Bundle of a loaded model, its processor, and descriptive metadata.

    The ``@dataclass`` decorator generates the keyword ``__init__`` that
    ``PhonemeTranscriber._initialize_model`` relies on; without it the
    class only carries annotations and cannot be constructed with fields.
    """

    # Model identifier passed to ``from_pretrained``.
    name: str
    # Annotations are quoted so defining the class does not require the
    # transformers names to be resolvable at class-creation time.
    processor: "Wav2Vec2Processor"
    model: "Wav2Vec2ForCTC"
    # Free-form, human-readable summary of the model.
    description: str
class PhoneticEnhancer:
    """Rule-based post-processor for space-separated phoneme strings.

    Four independent rewrites can be enabled by name:

    * ``'diphthongs'`` - merge adjacent vowel pairs listed in
      ``self.diphthongs`` into a single token.
    * ``'length'`` - replace vowels listed in ``self.long_vowels`` with
      their long form.
    * ``'quality'`` - replace vowels listed in ``self.vowel_quality``.
    * ``'stress'`` - insert a stress-mark token in front of one syllable.
    """

    # Characters that make a phoneme token count as a vowel.
    _VOWELS = frozenset('aeiouɑɐəæɛɪʊʌɔ')
    # Token placed in front of the stressed syllable.
    _STRESS_MARK = 'ˈ'

    def __init__(self):
        # Vowel length rules
        self.long_vowels = {
            'i': 'iː',
            'u': 'uː',
            'a': 'ɑː',
            'ɑ': 'ɑː',
            'e': 'eː',
            'o': 'oː',
        }
        # Common diphthongs
        self.diphthongs = {
            'ei': 'eɪ',
            'ai': 'aɪ',
            'oi': 'ɔɪ',
            'ou': 'əʊ',
            'au': 'aʊ',
        }
        # Vowel quality adjustments
        self.vowel_quality = {
            'ə': 'æ',  # In stressed positions
            'ɐ': 'æ',  # Common substitution
        }
        # Stress pattern rules: (pattern, position), position is the index
        # from the start. Not read by any method of this class.
        self.stress_patterns = [
            (['CV', 'CV'], 1),   # For words like "piage"
            (['CVV', 'CV'], 0),  # For words with long first vowel
        ]

    def _is_vowel(self, phoneme: str) -> bool:
        """Return True if any character of ``phoneme`` is a vowel symbol."""
        return any(char in self._VOWELS for char in phoneme)

    def _split_into_syllables(self, phonemes: List[str]) -> List[List[str]]:
        """Group phonemes into syllables, closing each one at a vowel.

        Phonemes left over after the last vowel are appended to the final
        syllable; if no vowel occurs at all they form a single syllable.
        """
        syllables = []
        current_syllable = []
        for phoneme in phonemes:
            current_syllable.append(phoneme)
            if self._is_vowel(phoneme):
                syllables.append(current_syllable)
                current_syllable = []
        if current_syllable:
            if syllables:
                syllables[-1].extend(current_syllable)
            else:
                syllables.append(current_syllable)
        return syllables

    def _merge_diphthongs(self, phonemes: List[str]) -> List[str]:
        """Return a new list with adjacent pairs in ``self.diphthongs`` merged."""
        merged = []
        i = 0
        total = len(phonemes)
        while i < total:
            if i + 1 < total:
                pair = phonemes[i] + phonemes[i + 1]
                if pair in self.diphthongs:
                    merged.append(self.diphthongs[pair])
                    i += 2  # both members of the pair are consumed
                    continue
            merged.append(phonemes[i])
            i += 1
        return merged

    def _add_stress(self, phonemes: List[str]) -> List[str]:
        """Return the phonemes with a stress mark added for multi-syllable input.

        The mark goes before the first syllable containing 'æ', or before
        the first syllable when none does. Input with fewer than two
        syllables is returned without a mark.
        """
        syllables = self._split_into_syllables(phonemes)
        if len(syllables) > 1:
            stressed = next(
                (idx for idx, syll in enumerate(syllables)
                 if any('æ' in p for p in syll)),
                0,
            )
            syllables[stressed].insert(0, self._STRESS_MARK)
        # Flatten syllables back to phonemes
        return [p for syll in syllables for p in syll]

    def enhance_transcription(self, raw_phonemes: str,
                              enhancements: Optional[List[str]] = None) -> str:
        """Apply the selected enhancements to a phoneme string.

        Args:
            raw_phonemes: Phonemes separated by whitespace.
            enhancements: Names of the rewrites to apply. ``None`` enables
                all four; an empty list applies none.

        Returns:
            The rewritten phonemes joined by single spaces.
        """
        if enhancements is None:
            enhancements = ['length', 'quality', 'stress', 'diphthongs']

        # Split into individual phonemes
        enhanced_phonemes = raw_phonemes.split()

        # Diphthongs must be merged before vowel lengthening: every key in
        # self.diphthongs is built from short vowels that the length rule
        # rewrites, so lengthening first would leave nothing to match.
        if 'diphthongs' in enhancements:
            enhanced_phonemes = self._merge_diphthongs(enhanced_phonemes)

        if 'length' in enhancements:
            enhanced_phonemes = [
                self.long_vowels.get(p, p) for p in enhanced_phonemes
            ]

        if 'quality' in enhancements:
            enhanced_phonemes = [
                self.vowel_quality.get(p, p) for p in enhanced_phonemes
            ]

        if 'stress' in enhancements:
            enhanced_phonemes = self._add_stress(enhanced_phonemes)

        return ' '.join(enhanced_phonemes)
class PhonemeTranscriber:
    """Run a wav2vec2 CTC model on audio and post-process its transcription.

    The decoded model output is passed to ``PhoneticEnhancer`` together
    with the enhancement names chosen by the caller.
    """

    def __init__(self):
        # The device must be known before the model is loaded so the model
        # can be placed on it.
        self.device = self._get_optimal_device()
        print(f"Using device: {self.device}")
        self.model_config = self._initialize_model()
        self.target_sample_rate = 16_000
        self.enhancer = PhoneticEnhancer()

    def _get_optimal_device(self):
        """Return "cuda", "mps" or "cpu", in that order of preference."""
        if torch.cuda.is_available():
            return "cuda"
        # Looked up defensively: not every torch build exposes an ``mps``
        # backend attribute.
        mps_backend = getattr(torch.backends, "mps", None)
        if (platform.system() == 'Darwin'
                and mps_backend is not None
                and mps_backend.is_available()):
            return "mps"
        return "cpu"

    def _initialize_model(self) -> "ModelConfig":
        """Load the processor and model and place the model on ``self.device``."""
        model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
        processor = Wav2Vec2Processor.from_pretrained(model_name)
        model = Wav2Vec2ForCTC.from_pretrained(model_name)
        # Inputs are moved to self.device in transcribe_to_phonemes; the
        # model has to live on the same device or the forward pass fails.
        model.to(self.device)
        model.eval()
        return ModelConfig(
            name=model_name,
            processor=processor,
            model=model,
            description="LV-60 + CommonVoice (26 langs) + eSpeak"
        )

    def preprocess_audio(self, audio):
        """Convert ``(sample_rate, samples)`` audio into mono float32.

        Args:
            audio: A ``(sample_rate, ndarray)`` tuple. Multi-dimensional
                arrays are averaged over axis 1 (assumes a
                (samples, channels) layout - TODO confirm against caller).

        Returns:
            A 1-D float32 array at ``self.target_sample_rate``, or ``None``
            when ``audio`` is not a tuple or contains no samples.
        """
        if not isinstance(audio, tuple):
            return None
        sample_rate, audio_data = audio

        audio_data = np.asarray(audio_data)
        if audio_data.size == 0:
            # max()/min() below would raise on an empty array.
            return None

        source_dtype = audio_data.dtype
        samples = audio_data.astype(np.float32)
        if np.issubdtype(source_dtype, np.signedinteger):
            # Scale by the full range of the source integer type. Deciding
            # from the value range instead would treat quiet PCM audio
            # (all samples within [-1, 1]) as full-scale float audio.
            samples /= float(np.iinfo(source_dtype).max) + 1.0
        elif samples.max() > 1.0 or samples.min() < -1.0:
            # Non-integer input outside [-1, 1]: keep the 16-bit fallback.
            samples /= 32768.0

        if samples.ndim > 1:
            samples = samples.mean(axis=1)

        if sample_rate != self.target_sample_rate:
            samples = librosa.resample(
                y=samples,
                orig_sr=sample_rate,
                target_sr=self.target_sample_rate
            )
        return samples

    @staticmethod
    def _parse_enhancements(enhancements):
        """Turn a comma-separated string (or iterable) into clean names.

        Names are stripped and lower-cased; empty entries are dropped, so
        ``"length, Stress,"`` yields ``['length', 'stress']``.
        """
        if not enhancements:
            return []
        if isinstance(enhancements, str):
            enhancements = enhancements.split(',')
        names = (str(name).strip().lower() for name in enhancements)
        return [name for name in names if name]

    def transcribe_to_phonemes(self, audio, enhancements):
        """Transcribe audio to phonemes with enhancements.

        Args:
            audio: Audio in the form accepted by ``preprocess_audio``.
            enhancements: Comma-separated enhancement names; empty or
                ``None`` applies no enhancements.

        Returns:
            A three-line report (raw, enhanced, applied enhancements), a
            prompt for valid input, or an error message with traceback.
        """
        try:
            audio_data = self.preprocess_audio(audio)
            if audio_data is None:
                return "Please provide valid audio input"

            selected_enhancements = self._parse_enhancements(enhancements)

            inputs = self.model_config.processor(
                audio_data,
                sampling_rate=self.target_sample_rate,
                return_tensors="pt",
                padding=True
            ).input_values.to(self.device)

            with torch.no_grad():
                logits = self.model_config.model(inputs).logits

            # Decode from host memory regardless of the compute device.
            predicted_ids = torch.argmax(logits, dim=-1).cpu()
            transcription = self.model_config.processor.batch_decode(predicted_ids)[0]

            enhanced = self.enhancer.enhance_transcription(
                transcription,
                selected_enhancements
            )
            applied = ', '.join(selected_enhancements) or 'none'
            return (
                f"Raw IPA: {transcription}\n"
                f"Enhanced IPA: {enhanced}\n"
                f"Applied enhancements: {applied}"
            )
        except Exception as e:
            # Boundary handler: the caller receives the failure as text
            # instead of an exception.
            import traceback
            return f"Error processing audio: {str(e)}\n{traceback.format_exc()}"
if __name__ == "__main__":
    # Script entry point: load the model once, then serve the Gradio UI.
    multiprocessing.freeze_support()
    transcriber = PhonemeTranscriber()

    # Input widgets, in the order of transcribe_to_phonemes' parameters.
    audio_input = gr.Audio(sources=["microphone", "upload"], type="numpy")
    enhancements_input = gr.Textbox(
        label="Enhancements (comma-separated)",
        value="length,quality,stress,diphthongs",
        placeholder="e.g., length,quality,stress,diphthongs"
    )

    # Text shown under the title; includes the device picked at start-up.
    app_description = f"""Convert speech to phonemes with customizable IPA enhancements.
Currently using device: {transcriber.device}
Available enhancements:
- length: Add vowel length markers (ː)
- quality: Adjust vowel quality (e.g., ə → æ)
- stress: Add stress marks (ˈ)
- diphthongs: Combine vowels into diphthongs (e.g., ei → eɪ)
"""

    iface = gr.Interface(
        fn=transcriber.transcribe_to_phonemes,
        inputs=[audio_input, enhancements_input],
        outputs="text",
        title="Speech to Phoneme Converter - Enhanced IPA",
        description=app_description
    )
    iface.launch()