File size: 2,542 Bytes
0323180 d347764 0323180 d347764 0323180 d347764 0323180 21f0bae 0323180 d347764 0323180 d347764 0323180 d347764 0323180 d347764 0323180 d347764 0323180 d347764 0323180 d347764 0323180 d347764 f805e49 50e4622 f805e49 c737803 d347764 226ec3a d347764 f805e49 d347764 c737803 0323180 c737803 0323180 c737803 3946ba6 c737803 d347764 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import torch
import gradio as gr
import numpy as np
from transformers import (
VitsModel,
VitsTokenizer,
pipeline,
)
# Use the first GPU in half precision when one is present; otherwise run on the
# CPU in full precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
print(f"Using {device} with fp {torch_dtype}")
# Whisper speech-recognition pipeline, placed on `device` with `torch_dtype`.
# translate() drives it with language="ms" to obtain Malay text.
asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-medium",
    torch_dtype=torch_dtype,
    device=device,
)

# MMS VITS text-to-speech checkpoint for Malay ("zlm"). It is never moved to
# `device`, so synthesis runs wherever from_pretrained loads it.
tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-zlm")
model = VitsModel.from_pretrained("facebook/mms-tts-zlm")
def synthesise(text):
    """Turn `text` into speech with the MMS VITS model.

    Args:
        text: The string to speak.

    Returns:
        The model's ``"waveform"`` output tensor, computed with autograd
        disabled.
    """
    token_ids = tokenizer(text=text, return_tensors="pt")["input_ids"]
    with torch.no_grad():
        return model(token_ids)["waveform"]
def translate(audio):
    """Recognise `audio` with Whisper and return the decoded text.

    Decoding is forced to Malay (``language="ms"``) with the ``transcribe``
    task, so the text matches the Malay TTS model used by ``synthesise``.

    Args:
        audio: Whatever the Gradio ``Audio`` inputs hand over (they are
            configured with ``type="filepath"``).

    Returns:
        The ``"text"`` field of the pipeline result.
    """
    whisper_options = {"task": "transcribe", "language": "ms"}
    result = asr_pipe(audio, max_new_tokens=256, generate_kwargs=whisper_options)
    return result["text"]
def speech_to_speech_translation(audio):
    """Cascade speech -> Malay text -> Malay speech.

    Args:
        audio: The value passed by the Gradio ``Audio`` input
            (``type="filepath"``), or ``None`` when nothing was recorded or
            uploaded.

    Returns:
        ``(16000, samples)`` for a ``gr.Audio(type="numpy")`` output, where
        ``samples`` is the transposed int16 PCM waveform. Returns ``None``
        when there is no input audio or the transcription is empty.
    """
    # Gradio hands over None for an empty or cleared input (the upload tab is
    # live, so clearing it triggers a call); the ASR pipeline cannot take None.
    if audio is None:
        return None

    translated_text = translate(audio)
    # Nothing to speak: skip TTS rather than feeding the model an empty sequence.
    if not translated_text.strip():
        return None

    waveform = synthesise(translated_text)
    # Nothing guarantees the float waveform stays inside [-1, 1]. Clip before
    # scaling so that out-of-range samples saturate instead of wrapping around
    # when cast to int16.
    samples = np.clip(waveform.cpu().numpy(), -1.0, 1.0)
    pcm = (samples * 32767).astype(np.int16)

    # NOTE(review): 16 kHz is assumed to be the MMS-TTS output rate; consider
    # reading it from model.config.sampling_rate instead of hard-coding it.
    # The waveform is presumably (1, samples); transpose so samples run along
    # the first axis.
    return 16000, pcm.T
# Header text shown on both Gradio tabs. The model links must name the
# checkpoints actually loaded (openai/whisper-medium, facebook/mms-tts-zlm).
title = "Cascaded STST"
description = """
Demo for cascaded speech-to-speech translation (STST), mapping from source speech in any language to target speech in **Malay**. Demo uses OpenAI's [Whisper Medium](https://huggingface.co/openai/whisper-medium) model for speech translation, and Facebook's
[MMS-TTS-ZLM](https://huggingface.co/facebook/mms-tts-zlm) model for text-to-speech:
![Cascaded STST](https://huggingface.co/datasets/huggingface-course/audio-course-images/resolve/main/s2st_cascaded.png "Diagram of cascaded speech to speech translation")
"""
def _build_tab(audio_source, **extra_kwargs):
    """Build one Interface that feeds `audio_source` audio into the STST cascade.

    Args:
        audio_source: Value for the input ``gr.Audio`` ``source`` argument.
        **extra_kwargs: Additional ``gr.Interface`` keyword arguments for this tab.
    """
    return gr.Interface(
        fn=speech_to_speech_translation,
        inputs=gr.Audio(source=audio_source, type="filepath"),
        outputs=gr.Audio(label="Generated Speech", type="numpy"),
        title=title,
        description=description,
        **extra_kwargs,
    )


demo = gr.Blocks()

# Microphone tab: translation runs on submit.
mic_translate = _build_tab("microphone")
# Upload tab: ships example clips and re-runs automatically on every change.
file_translate = _build_tab("upload", examples="./examples", live=True)

with demo:
    gr.TabbedInterface(
        [mic_translate, file_translate],
        ["Microphone", "Audio File"],
    )

demo.launch()
|