# davidmeikle's picture
# Update app.py
# 04e9646 verified
import spaces
import gradio as gr
import torch
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import platform
import librosa
# One checkpoint id feeds both the input processor and the CTC model, so the
# two can never drift apart.
_CHECKPOINT = "facebook/wav2vec2-lv-60-espeak-cv-ft"

processor = Wav2Vec2Processor.from_pretrained(_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(_CHECKPOINT)
# Move the weights to the GPU once at import time; inference code below sends
# its inputs to the same device.
model.to('cuda')
class PhoneticEnhancer:
    """Rule-based post-processor for space-separated phoneme strings.

    The enhancer applies up to four independent rewrites, selected by name:

    * ``'diphthongs'`` - merge two adjacent phonemes whose concatenation is
      a key of ``self.diphthongs`` into the mapped single token.
    * ``'length'`` - replace phonemes that are keys of ``self.long_vowels``
      with their long form.
    * ``'quality'`` - replace phonemes that are keys of
      ``self.vowel_quality``.
    * ``'stress'`` - insert a ``'ˈ'`` token in front of one syllable when
      the input has more than one syllable.
    """

    # Characters that make a phoneme count as a vowel (syllable nucleus).
    _VOWELS = frozenset('aeiouɑɐəæɛɪʊʌɔ')
    # Stress marker, emitted as its own space-separated token.
    _STRESS_MARK = 'ˈ'

    def __init__(self):
        # Vowel length rules
        self.long_vowels = {
            'i': 'iː',
            'u': 'uː',
            'a': 'ɑː',
            'ɑ': 'ɑː',
            'e': 'eː',
            'o': 'oː',
        }
        # Common diphthongs, keyed by the concatenation of two phonemes
        self.diphthongs = {
            'ei': 'eɪ',
            'ai': 'aɪ',
            'oi': 'ɔɪ',
            'ou': 'əʊ',
            'au': 'aʊ',
        }
        # Vowel quality adjustments (applied to every occurrence)
        self.vowel_quality = {
            'ə': 'æ',  # In stressed positions
            'ɐ': 'æ',  # Common substitution
        }
        # Stress pattern rules.
        # NOTE: not consulted by any method of this class; kept so existing
        # readers of the attribute keep working.
        self.stress_patterns = [
            # (pattern, position) - position is index from start
            (['CV', 'CV'], 1),  # For words like "piage"
            (['CVV', 'CV'], 0),  # For words with long first vowel
        ]

    def _is_vowel(self, phoneme: str) -> bool:
        """Return True if any character of *phoneme* is a vowel symbol."""
        return any(char in self._VOWELS for char in phoneme)

    def _split_into_syllables(self, phonemes: list) -> list:
        """Group *phonemes* into syllables that each end on a vowel.

        Every vowel closes the current syllable. Phonemes left over after
        the last vowel are appended to the final syllable, or form a
        syllable of their own if the input contains no vowel at all.
        """
        syllables = []
        current_syllable = []
        for phoneme in phonemes:
            current_syllable.append(phoneme)
            if self._is_vowel(phoneme):
                syllables.append(current_syllable)
                current_syllable = []
        if current_syllable:
            if syllables:
                syllables[-1].extend(current_syllable)
            else:
                syllables.append(current_syllable)
        return syllables

    def _merge_diphthongs(self, phonemes: list) -> list:
        """Return *phonemes* with adjacent diphthong pairs merged."""
        merged = []
        i = 0
        count = len(phonemes)
        while i < count:
            if i + 1 < count:
                pair = phonemes[i] + phonemes[i + 1]
                if pair in self.diphthongs:
                    merged.append(self.diphthongs[pair])
                    # Both members are consumed; the merged token is not
                    # re-examined against the following phoneme.
                    i += 2
                    continue
            merged.append(phonemes[i])
            i += 1
        return merged

    def _add_stress(self, phonemes: list) -> list:
        """Return *phonemes* with a stress mark added for multi-syllable input.

        The mark goes before the first syllable containing 'æ'; if there is
        none, before the first syllable. Single-syllable (or empty) input is
        returned without a mark.
        """
        syllables = self._split_into_syllables(phonemes)
        if len(syllables) > 1:
            target = next(
                (
                    index
                    for index, syllable in enumerate(syllables)
                    if any('æ' in p for p in syllable)
                ),
                0,
            )
            syllables[target].insert(0, self._STRESS_MARK)
        return [p for syllable in syllables for p in syllable]

    def enhance_transcription(self, raw_phonemes: str, enhancements: list = None) -> str:
        """Apply the selected enhancements to a phoneme transcription.

        Args:
            raw_phonemes: Phonemes separated by whitespace.
            enhancements: Names of the rewrites to apply, any of
                ``'length'``, ``'quality'``, ``'stress'``, ``'diphthongs'``.
                ``None`` selects all four; an empty list applies none.
                Unrecognised names are ignored.

        Returns:
            The rewritten phonemes joined by single spaces.
        """
        if enhancements is None:
            enhancements = ['length', 'quality', 'stress', 'diphthongs']

        # Split into individual phonemes
        enhanced_phonemes = raw_phonemes.split()

        # Diphthongs must be merged BEFORE vowel lengthening: the diphthong
        # table is keyed on short vowels ('e' + 'i'), so lengthening first
        # ('eː' + 'iː') would prevent every diphthong rule from matching.
        if 'diphthongs' in enhancements:
            enhanced_phonemes = self._merge_diphthongs(enhanced_phonemes)

        if 'length' in enhancements:
            enhanced_phonemes = [
                self.long_vowels.get(p, p) for p in enhanced_phonemes
            ]

        if 'quality' in enhancements:
            enhanced_phonemes = [
                self.vowel_quality.get(p, p) for p in enhanced_phonemes
            ]

        # Stress runs last so it sees 'æ' produced by the quality pass.
        if 'stress' in enhancements:
            enhanced_phonemes = self._add_stress(enhanced_phonemes)

        return ' '.join(enhanced_phonemes)
def preprocess_audio(audio):
    """Convert a ``(sample_rate, samples)`` tuple to 16 kHz mono float32.

    Args:
        audio: A ``(sample_rate, samples)`` tuple. ``samples`` is a 1-D array
            or a 2-D array whose second axis is averaged away to get mono.
            Anything that is not a tuple is rejected.

    Returns:
        A 1-D ``float32`` array resampled to 16000 Hz, or ``None`` when
        *audio* is not a tuple or contains no samples.
    """
    if not isinstance(audio, tuple):
        return None
    sample_rate, audio_data = audio

    audio_data = np.asarray(audio_data)
    if audio_data.size == 0:
        # Nothing to transcribe; also avoids max()/min() raising on an
        # empty array further down.
        return None

    source_dtype = audio_data.dtype
    if np.issubdtype(source_dtype, np.integer):
        # Integer PCM: scale by the dtype's full range, not by inspecting
        # sample values, so quiet clips and non-16-bit widths are handled
        # correctly.
        info = np.iinfo(source_dtype)
        audio_data = audio_data.astype(np.float32)
        if info.min < 0:
            audio_data = audio_data / (float(info.max) + 1.0)
        else:
            # Unsigned PCM is centred on the middle of its range.
            half_range = (float(info.max) + 1.0) / 2.0
            audio_data = (audio_data - half_range) / half_range
    else:
        audio_data = audio_data.astype(np.float32)
        # Float data outside [-1, 1] is assumed to hold 16-bit PCM values.
        if audio_data.max() > 1.0 or audio_data.min() < -1.0:
            audio_data = audio_data / 32768.0

    # Down-mix multi-channel audio by averaging across the second axis.
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)

    if sample_rate != 16000:
        audio_data = librosa.resample(
            y=audio_data,
            orig_sr=sample_rate,
            target_sr=16000
        )
    return audio_data
@spaces.GPU
def transcribe_to_phonemes(audio, enhancements):
    """Transcribe audio to phonemes with enhancements.

    Args:
        audio: Audio input forwarded to ``preprocess_audio``.
        enhancements: Comma-separated enhancement names. Whitespace around
            each name is ignored and empty entries are dropped. An empty or
            missing value applies no enhancements.

    Returns:
        A text report with the raw transcription, the enhanced
        transcription and the enhancements applied. On failure, the error
        message followed by the traceback. If the audio cannot be
        preprocessed, a prompt asking for valid input.
    """
    try:
        audio_data = preprocess_audio(audio)
        if audio_data is None:
            return "Please provide valid audio input"

        # Strip each name so "length, quality" selects both rules; without
        # this, ' quality' would silently match nothing.
        selected_enhancements = [
            name.strip()
            for name in (enhancements or "").split(',')
            if name.strip()
        ]

        inputs = processor(
            audio_data,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        ).input_values.to('cuda')

        with torch.no_grad():
            logits = model(inputs).logits

        # Greedy decoding: highest-scoring token id at every position.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]

        enhancer = PhoneticEnhancer()
        enhanced = enhancer.enhance_transcription(
            transcription,
            selected_enhancements
        )
        return (
            f"Raw IPA: {transcription}\n"
            f"Enhanced IPA: {enhanced}\n"
            f"Applied enhancements: {', '.join(selected_enhancements) or 'none'}"
        )
    except Exception as e:
        # Top-level UI boundary: report the failure in the output box rather
        # than letting the request crash.
        import traceback
        return f"Error processing audio: {str(e)}\n{traceback.format_exc()}"
# Help text passed verbatim as the interface description. The lines stay
# flush-left because any leading whitespace would become part of the string.
_DESCRIPTION = """Convert speech to phonemes with customizable IPA enhancements.
Available enhancements:
- length: Add vowel length markers (ː)
- quality: Adjust vowel quality (e.g., ə → æ)
- stress: Add stress marks (ˈ)
- diphthongs: Combine vowels into diphthongs (e.g., ei → eɪ)
"""

# Input 1: recorded or uploaded audio, delivered to the callback as numpy data.
_audio_input = gr.Audio(sources=["microphone", "upload"], type="numpy")

# Input 2: free-text list of enhancement names, pre-filled with all of them.
_enhancements_input = gr.Textbox(
    label="Enhancements (comma-separated)",
    value="length,quality,stress,diphthongs",
    placeholder="e.g., length,quality,stress,diphthongs",
)

iface = gr.Interface(
    fn=transcribe_to_phonemes,
    inputs=[_audio_input, _enhancements_input],
    outputs="text",
    title="Speech to Phoneme Converter - Enhanced IPA",
    description=_DESCRIPTION,
)

# Start serving the UI as soon as the module is executed.
iface.launch()