riteshkr commited on
Commit
a9e9df1
·
verified ·
1 Parent(s): 1b1a505

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
5
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
6
+ from datasets import load_dataset
7
+
8
+ # Check if a GPU is available and set the device
9
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
10
+
11
+ # Load the Whisper ASR model
12
+ whisper_model_id = "riteshkr/quantized-whisper-large-v3"
13
+ whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_id)
14
+ whisper_processor = WhisperProcessor.from_pretrained(whisper_model_id)
15
+
16
+ # Set the language to English using forced_decoder_ids
17
+ forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language="english", task="transcribe")
18
+
19
+ whisper_pipe = pipeline(
20
+ "automatic-speech-recognition",
21
+ model=whisper_model,
22
+ tokenizer=whisper_processor.tokenizer,
23
+ feature_extractor=whisper_processor.feature_extractor,
24
+ device=0 if torch.cuda.is_available() else -1
25
+ )
26
+
27
+ # Load the SpeechT5 TTS model
28
+ tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
29
+ tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
30
+ vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
31
+
32
+ tts_model.to(device)
33
+ vocoder.to(device)
34
+
35
+ # Load speaker embeddings for TTS
36
+ embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
37
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)
38
+
39
+ # Set target data type and max range for speech
40
+ target_dtype = np.int16
41
+ max_range = np.iinfo(target_dtype).max
42
+
43
+ # Define the transcription function (Whisper ASR)
44
+ def transcribe_speech(filepath):
45
+ batch_size = 16 if torch.cuda.is_available() else 4
46
+ output = whisper_pipe(
47
+ filepath,
48
+ max_new_tokens=256,
49
+ generate_kwargs={"forced_decoder_ids": forced_decoder_ids},
50
+ chunk_length_s=30,
51
+ batch_size=batch_size,
52
+ )
53
+ return output["text"]
54
+
55
+ # Define the synthesis function (SpeechT5 TTS)
56
+ def synthesise(text):
57
+ inputs = tts_processor(text=text, return_tensors="pt")
58
+ speech = tts_model.generate_speech(
59
+ inputs["input_ids"].to(device), speaker_embeddings, vocoder=vocoder
60
+ )
61
+ return speech.cpu()
62
+
63
+ # Define the speech-to-speech translation function
64
+ def speech_to_speech_translation(audio):
65
+ # Transcribe speech
66
+ translated_text = transcribe_speech(audio)
67
+ # Synthesize speech
68
+ synthesised_speech = synthesise(translated_text)
69
+ # Convert speech to desired format
70
+ synthesised_speech = (synthesised_speech.numpy() * max_range).astype(np.int16)
71
+ return 16000, synthesised_speech
72
+
73
+ # Define the Gradio interfaces for microphone input and file upload
74
+ mic_translate = gr.Interface(
75
+ fn=speech_to_speech_translation,
76
+ inputs=gr.Audio(source="microphone", type="filepath"),
77
+ outputs=gr.Audio(label="Generated Speech", type="numpy"),
78
+ )
79
+
80
+ file_translate = gr.Interface(
81
+ fn=speech_to_speech_translation,
82
+ inputs=gr.Audio(source="upload", type="filepath"),
83
+ outputs=gr.Audio(label="Generated Speech", type="numpy"),
84
+ )
85
+
86
+ # Define the Gradio interfaces for transcription
87
+ mic_transcribe = gr.Interface(
88
+ fn=transcribe_speech,
89
+ inputs=gr.Audio(source="microphone", type="filepath"),
90
+ outputs=gr.Textbox(),
91
+ )
92
+
93
+ file_transcribe = gr.Interface(
94
+ fn=transcribe_speech,
95
+ inputs=gr.Audio(source="upload", type="filepath"),
96
+ outputs=gr.Textbox(),
97
+ )
98
+
99
+ # Create the app using Gradio Blocks with tabbed interfaces
100
+ demo = gr.Blocks()
101
+
102
+ with demo:
103
+ gr.TabbedInterface(
104
+ [
105
+ mic_transcribe, file_transcribe, # For transcription
106
+ mic_translate, file_translate # For speech-to-speech translation
107
+ ],
108
+ [
109
+ "Transcribe Microphone", "Transcribe Audio File",
110
+ "Translate Microphone", "Translate Audio File"
111
+ ]
112
+ )
113
+
114
+ # Launch the app with debugging enabled
115
+ if __name__ == "__main__":
116
+ demo.launch(debug=True, share=True)