import spaces
import gradio as gr
import torch
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import platform
import librosa
import multiprocessing
from dataclasses import dataclass
from typing import Dict, Tuple, List


@dataclass
class ModelConfig:
    name: str
    processor: Wav2Vec2Processor
    model: Wav2Vec2ForCTC
    description: str


class PhoneticEnhancer:
    def __init__(self):
        # Vowel length rules
        self.long_vowels = {
            'i': 'iː', 'u': 'uː', 'a': 'ɑː',
            'ɑ': 'ɑː', 'e': 'eː', 'o': 'oː'
        }

        # Common diphthongs
        self.diphthongs = {
            'ei': 'eɪ', 'ai': 'aɪ', 'oi': 'ɔɪ',
            'ou': 'əʊ', 'au': 'aʊ'
        }

        # Vowel quality adjustments
        self.vowel_quality = {
            'ə': 'æ',  # In stressed positions
            'ɐ': 'æ'   # Common substitution
        }

        # Stress pattern rules
        self.stress_patterns = [
            # (pattern, position) - position is index from start
            (['CV', 'CV'], 1),   # For words like "piage"
            (['CVV', 'CV'], 0),  # For words with a long first vowel
        ]

    def _is_vowel(self, phoneme: str) -> bool:
        vowels = set('aeiouɑɐəæɛɪʊʌɔ')
        return any(char in vowels for char in phoneme)

    def _split_into_syllables(self, phonemes: List[str]) -> List[List[str]]:
        # Greedy split: close the current syllable as soon as a vowel is seen.
        syllables = []
        current_syllable = []

        for phoneme in phonemes:
            current_syllable.append(phoneme)
            if self._is_vowel(phoneme) and len(current_syllable) > 0:
                syllables.append(current_syllable)
                current_syllable = []

        # Attach any trailing consonants to the last syllable.
        if current_syllable:
            if len(syllables) > 0:
                syllables[-1].extend(current_syllable)
            else:
                syllables.append(current_syllable)

        return syllables

    def enhance_transcription(self, raw_phonemes: str, enhancements: List[str] = None) -> str:
        if enhancements is None:
            enhancements = ['length', 'quality', 'stress', 'diphthongs']

        # Split into individual phonemes
        phonemes = raw_phonemes.split()
        enhanced_phonemes = phonemes.copy()

        if 'length' in enhancements:
            # Apply vowel length rules
            for i, phoneme in enumerate(enhanced_phonemes):
                if phoneme in self.long_vowels:
                    enhanced_phonemes[i] = self.long_vowels[phoneme]

        if 'quality' in enhancements:
            # Apply vowel quality adjustments
            for i, phoneme in enumerate(enhanced_phonemes):
                if phoneme in self.vowel_quality:
                    enhanced_phonemes[i] = self.vowel_quality[phoneme]

        if 'diphthongs' in enhancements:
            # Merge adjacent vowel pairs into diphthongs
            i = 0
            while i < len(enhanced_phonemes) - 1:
                pair = enhanced_phonemes[i] + enhanced_phonemes[i + 1]
                if pair in self.diphthongs:
                    enhanced_phonemes[i] = self.diphthongs[pair]
                    enhanced_phonemes.pop(i + 1)
                i += 1

        if 'stress' in enhancements:
            # Add stress marks based on syllable structure
            syllables = self._split_into_syllables(enhanced_phonemes)
            if len(syllables) > 1:
                # Add stress to the syllable containing 'æ' if present
                for i, syll in enumerate(syllables):
                    if any('æ' in p for p in syll):
                        syllables[i].insert(0, 'ˈ')
                        break
                else:
                    # If no 'æ' (loop did not break), add stress to the first syllable by default
                    syllables[0].insert(0, 'ˈ')

            # Flatten syllables back to phonemes
            enhanced_phonemes = [p for syll in syllables for p in syll]

        return ' '.join(enhanced_phonemes)


class PhonemeTranscriber:
    def __init__(self):
        self.device = self._get_optimal_device()
        print(f"Using device: {self.device}")
        self.model_config = self._initialize_model()
        self.target_sample_rate = 16_000
        self.enhancer = PhoneticEnhancer()

    def _get_optimal_device(self):
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available() and platform.system() == 'Darwin':
            return "mps"
        return "cpu"

    def _initialize_model(self) -> ModelConfig:
        model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
        processor = Wav2Vec2Processor.from_pretrained(model_name)
        # Move the model to the selected device so it matches the inputs at inference time.
        model = Wav2Vec2ForCTC.from_pretrained(model_name).to(self.device)
        return ModelConfig(
            name=model_name,
            processor=processor,
            model=model,
            description="LV-60 + CommonVoice (26 langs) + eSpeak"
        )

    def preprocess_audio(self, audio):
        """Preprocess audio data for model input."""
        if isinstance(audio, tuple):
            sample_rate, audio_data = audio
        else:
            return None

        if audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)

        # Rescale 16-bit PCM integer ranges to [-1.0, 1.0]
        if audio_data.max() > 1.0 or audio_data.min() < -1.0:
            audio_data = audio_data / 32768.0

        # Down-mix multi-channel audio to mono
        if len(audio_data.shape) > 1:
            audio_data = audio_data.mean(axis=1)

        if sample_rate != self.target_sample_rate:
            audio_data = librosa.resample(
                y=audio_data,
                orig_sr=sample_rate,
                target_sr=self.target_sample_rate
            )

        return audio_data

    @spaces.GPU
    def transcribe_to_phonemes(self, audio, enhancements):
        """Transcribe audio to phonemes with enhancements."""
        try:
            audio_data = self.preprocess_audio(audio)
            if audio_data is None:
                return "Please provide valid audio input"

            # Tolerate stray whitespace around comma-separated enhancement names
            selected_enhancements = (
                [e.strip() for e in enhancements.split(',') if e.strip()]
                if enhancements else []
            )

            inputs = self.model_config.processor(
                audio_data,
                sampling_rate=self.target_sample_rate,
                return_tensors="pt",
                padding=True
            ).input_values.to(self.device)

            with torch.no_grad():
                logits = self.model_config.model(inputs).logits

            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.model_config.processor.batch_decode(predicted_ids)[0]

            enhanced = self.enhancer.enhance_transcription(
                transcription,
                selected_enhancements
            )

            return f"""Raw IPA: {transcription}
Enhanced IPA: {enhanced}
Applied enhancements: {', '.join(selected_enhancements) or 'none'}"""

        except Exception as e:
            import traceback
            return f"Error processing audio: {str(e)}\n{traceback.format_exc()}"


if __name__ == "__main__":
    multiprocessing.freeze_support()
    transcriber = PhonemeTranscriber()

    iface = gr.Interface(
        fn=transcriber.transcribe_to_phonemes,
        inputs=[
            gr.Audio(sources=["microphone", "upload"], type="numpy"),
            gr.Textbox(
                label="Enhancements (comma-separated)",
                value="length,quality,stress,diphthongs",
                placeholder="e.g., length,quality,stress,diphthongs"
            )
        ],
        outputs="text",
        title="Speech to Phoneme Converter - Enhanced IPA",
        description=f"""Convert speech to phonemes with customizable IPA enhancements.
Currently using device: {transcriber.device}

Available enhancements:
- length: Add vowel length markers (ː)
- quality: Adjust vowel quality (e.g., ə → æ)
- stress: Add stress marks (ˈ)
- diphthongs: Combine vowels into diphthongs (e.g., ei → eɪ)
"""
    )

    iface.launch()
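
# Standalone usage sketch (not wired into the app): the rule-based enhancer can also be
# exercised directly on a space-separated phoneme string. The input below is a
# hypothetical raw model output, shown only to illustrate the default rule set:
#
#   enhancer = PhoneticEnhancer()
#   enhancer.enhance_transcription("p i a ʒ")
#   # -> "ˈ p iː ɑː ʒ"  (length marks applied; default stress on the first syllable)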