# SpeechT5-TTS-BN / app.py
# Last commit e500175 by Solo448: "Update app.py"
import os
import re
import unicodedata

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset, Audio
from speechbrain.inference import EncoderClassifier
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
# Pretrained SpeechT5 components: Microsoft's stock processor and HiFi-GAN
# vocoder, plus the acoustic model from "Solo448/SpeechT5-tuned-bn"
# (a Bengali fine-tune, judging by its name).
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("Solo448/SpeechT5-tuned-bn")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Speaker encoder (speechbrain x-vector model). Only this model is placed on
# `device`; the SpeechT5 model and vocoder above are not moved here.
_SPKREC_REPO = "speechbrain/spkrec-xvect-voxceleb"
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
speaker_model = EncoderClassifier.from_hparams(
    source=_SPKREC_REPO,
    savedir=os.path.join("/tmp", _SPKREC_REPO),  # local cache dir under /tmp
    run_opts={"device": device},
)
# Load a sample from the dataset for speaker embedding
try:
dataset = load_dataset("Sajjo/bangala_data_v3", split="train", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
sample = dataset[0]
speaker_embedding = create_speaker_embedding(sample['audio']['array'])
except Exception as e:
print(f"Error loading dataset: {e}")
# Use a random speaker embedding as fallback
speaker_embedding = torch.randn(1, 512)
def create_speaker_embedding(waveform):
with torch.no_grad():
speaker_embeddings = speaker_model.encode_batch(torch.tensor(waveform))
speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings, dim=2)
speaker_embeddings = speaker_embeddings.squeeze().cpu().numpy()
return speaker_embeddings
# Bengali -> Latin romanisation applied before tokenisation (presumably because
# the microsoft/speecht5_tts processor's vocabulary is Latin-script — confirm).
# Multi-code-point keys (consonant + nukta) are listed alongside their parts;
# matching is longest-first, so list order no longer decides which one wins.
# NOTE(review): an earlier version also mapped ড়/ঢ় to "r" further down the
# list; those duplicates were unreachable or form-dependent and are dropped,
# keeping the first mapping ("rh"). Confirm against the fine-tuning data.
_BENGALI_TO_LATIN = [
    ("অ", "a"),
    ("আ", "aa"),
    ("ই", "i"),
    ("ঈ", "ee"),
    ("উ", "u"),
    ("ঊ", "oo"),
    ("ঋ", "ri"),
    ("এ", "e"),
    ("ঐ", "oi"),
    ("ও", "o"),
    ("ঔ", "ou"),
    ("ক", "k"),
    ("খ", "kh"),
    ("গ", "g"),
    ("ঘ", "gh"),
    ("ঙ", "ng"),
    ("চ", "ch"),
    ("ছ", "chh"),
    ("জ", "j"),
    ("ঝ", "jh"),
    ("ঞ", "nj"),
    ("ট", "t"),
    ("ঠ", "th"),
    ("ড", "d"),
    ("ঢ", "dh"),
    ("ণ", "nr"),
    ("ত", "t"),
    ("থ", "th"),
    ("দ", "d"),
    ("ধ", "dh"),
    ("ন", "n"),
    ("প", "p"),
    ("ফ", "ph"),
    ("ব", "b"),
    ("ভ", "bh"),
    ("ম", "m"),
    ("য", "ya"),
    ("র", "r"),
    ("ল", "l"),
    ("শ", "sha"),
    ("ষ", "sh"),
    ("স", "s"),
    ("হ", "ha"),
    ("ড়", "rh"),
    ("ঢ়", "rh"),
    ("য়", "y"),
    ("ৎ", "t"),
    ("ঃ", "h"),
    ("ঁ", "n"),
    ("়", ""),
    ("া", "a"),
    ("ি", "i"),
    ("ী", "ii"),
    ("ু", "u"),
    ("ূ", "uu"),
    ("ৃ", "r"),
    ("ে", "e"),
    ("ৈ", "oi"),
    ("ো", "o"),
    ("ৌ", "ou"),
    ("্", ""),
    ("ৗ", "ou"),
    ("ৰ", "r"),
    ("৵", "lee"),
    ("ং", "ng"),
    ("১", "1"),
    ("২", "2"),
    ("৩", "3"),
    ("৪", "4"),
    ("৫", "5"),
    ("৬", "6"),
    ("৭", "7"),
    ("৮", "8"),
    ("৯", "9"),
    ("০", "0"),
]


def _build_transliteration_table(pairs):
    """Return a dict keyed by the NFC form of each source; first mapping wins."""
    table = {}
    for src, dst in pairs:
        # NFC maps precomposed ড়/ঢ়/য় (U+09DC/09DD/09DF, composition
        # exclusions) to base letter + nukta, so both spellings share one key.
        table.setdefault(unicodedata.normalize("NFC", src), dst)
    return table


_TRANSLITERATION = _build_transliteration_table(_BENGALI_TO_LATIN)
# Longest keys first so "ড" + "়" is matched before "ড" or "়" alone.
_TRANSLITERATION_RE = re.compile(
    "|".join(re.escape(src) for src in sorted(_TRANSLITERATION, key=len, reverse=True))
)


def _transliterate_bengali(text):
    """Romanise Bengali characters in ``text``; other characters pass through."""
    text = unicodedata.normalize("NFC", text)
    return _TRANSLITERATION_RE.sub(lambda m: _TRANSLITERATION[m.group(0)], text)


def text_to_speech(text):
    """Synthesise speech for Bengali ``text``.

    Args:
        text: Input string from the Gradio text box.

    Returns:
        tuple[int, numpy.ndarray]: ``(16000, waveform)`` — sample rate and
        samples, the form this app hands to its Gradio ``"audio"`` output.
    """
    romanised = _transliterate_bengali(text)
    inputs = processor(text=romanised, return_tensors="pt")
    # generate_speech needs a (batch, dim) float tensor; also accept a 1-D
    # numpy embedding such as create_speaker_embedding returns.
    embedding = torch.as_tensor(speaker_embedding, dtype=torch.float32)
    if embedding.dim() == 1:
        embedding = embedding.unsqueeze(0)
    speech = model.generate_speech(inputs["input_ids"], embedding, vocoder=vocoder)
    return (16000, speech.cpu().numpy())
# Web UI: a single text box feeding text_to_speech, rendered as an audio
# player. The server is launched as soon as this module runs.
iface = gr.Interface(
    text_to_speech,
    "text",
    "audio",
    description="Enter bengali text to convert to speech",
    title="Bengali Text-to-Speech",
)
iface.launch()