import gradio as gr
import json
import librosa
import os
import soundfile as sf
import tempfile
import uuid
import transformers
import torch
import time
from nemo.collections.asr.models import ASRModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
# Read the Hugging Face access token from the environment (used to download gated models such as Llama 3)
HF_TOKEN = os.environ.get("HF_TOKEN", None)
SAMPLE_RATE = 16000 # Hz
MAX_AUDIO_SECONDS = 40 # won't try to transcribe if longer than this
DESCRIPTION = '''
MyAlexa: Voice Chat Assistant
MyAlexa is a demo of a voice chat assistant with chat logs: it accepts audio input and returns an AI-generated response.
This space uses NVIDIA Canary 1B for automatic speech recognition (ASR), Meta Llama 3 8B Instruct as the large language model (LLM), and VITS for text-to-speech (TTS).
This demo accepts audio inputs up to 40 seconds long.
Transcription and responses are limited to English.
'''
PLACEHOLDER = """
What's on your mind?
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
### ASR model
canary_model = ASRModel.from_pretrained("nvidia/canary-1b").to(device)
canary_model.eval()
# make sure beam size is always 1 for consistency
canary_model.change_decoding_strategy(None)
decoding_cfg = canary_model.cfg.decoding
decoding_cfg.beam.beam_size = 1
canary_model.change_decoding_strategy(decoding_cfg)
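# Beam size 1 makes decoding effectively greedy, so transcripts are deterministic for a given input.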
### LLM model
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
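# device_map="auto" lets Accelerate place the model weights on the available GPU(s)/CPU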
llama3_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto")  # alternatively: .to("cuda:0")
terminators = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
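# "<|eot_id|>" is Llama 3's end-of-turn token; adding it to the stop ids lets generation
# stop at the end of the assistant's reply rather than only at the eos token.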
def convert_audio(audio_filepath, tmpdir, utt_id):
"""
Convert all files to monochannel 16 kHz wav files.
Do not convert and raise error if audio is too long.
Returns output filename and duration.
"""
data, sr = librosa.load(audio_filepath, sr=None, mono=True)
duration = librosa.get_duration(y=data, sr=sr)
if duration > MAX_AUDIO_SECONDS:
raise gr.Error(
f"This demo can transcribe up to {MAX_AUDIO_SECONDS} seconds of audio. "
"If you wish, you may trim the audio using the Audio viewer in Step 1 "
"(click on the scissors icon to start trimming audio)."
)
if sr != SAMPLE_RATE:
data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
out_filename = os.path.join(tmpdir, utt_id + '.wav')
# save output audio
sf.write(out_filename, data, SAMPLE_RATE)
return out_filename, duration
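# Illustrative usage only ("speech.mp3" is a hypothetical file, not part of this app):
#   with tempfile.TemporaryDirectory() as tmpdir:
#       wav_path, duration = convert_audio("speech.mp3", tmpdir, "utt0")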
def transcribe(audio_filepath):
"""
Transcribes a converted audio file.
Set to english language with punctuations.
Returns the output text.
"""
if audio_filepath is None:
raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
utt_id = uuid.uuid4()
with tempfile.TemporaryDirectory() as tmpdir:
converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))
# make manifest file and save
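        # Canary expects a NeMo-style JSON manifest: source/target language select English,
        # taskname "asr" requests transcription, and pnc "yes" asks for punctuation and
        # capitalization in the output.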
manifest_data = {
"audio_filepath": converted_audio_filepath,
"source_lang": "en",
"target_lang": "en",
"taskname": "asr",
"pnc": "yes",
"answer": "predict",
"duration": str(duration),
}
manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
with open(manifest_filepath, 'w') as fout:
line = json.dumps(manifest_data)
fout.write(line + '\n')
        # call transcribe with the manifest filepath; it returns one hypothesis per manifest entry
output_text = canary_model.transcribe(manifest_filepath)[0]
return output_text
def add_message(history, message):
"""
Adds the input message in the chatbot.
Returns the updated chatbot with an empty input textbox.
"""
history.append((message, None))
return history
def bot(history, message):
    """
    Streams the response into the chatbot, character by character.
    """
    # response = bot_response(message)  # real LLM call, currently disabled
    response = "bot_response(message)"  # placeholder: stream this literal string
history[-1][1] = ""
for character in response:
history[-1][1] += character
time.sleep(0.05)
yield history
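# Illustrative sketch (not wired in this demo): to stream the actual LLM reply instead of
# the placeholder string, `bot` could iterate over bot_response(); the temperature and
# max_new_tokens values below are assumptions, not settings taken from this app.
#
#   def bot(history, message):
#       history[-1][1] = ""
#       for partial in bot_response(message, history[:-1], temperature=0.7, max_new_tokens=256):
#           history[-1][1] = partial
#           yield history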
def bot_response(message: str,
history: list,
temperature: float,
max_new_tokens: int
) -> str: # type: ignore
"""
Generate a streaming response using the llama3-8b model.
Args:
message (str): The input message.
history (list): The conversation history used by ChatInterface.
temperature (float): The temperature for generating the response.
max_new_tokens (int): The maximum number of new tokens to generate.
    Yields:
        str: The response generated so far (accumulated across streamed chunks).
"""
conversation = []
for user, assistant in history:
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(llama3_model.device)  # add_generation_prompt=True appends the assistant header so the model starts a fresh reply
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
input_ids= input_ids,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
eos_token_id=terminators,
)
    # Enforce greedy generation (do_sample=False) when the temperature is 0, avoiding a crash in generate().
if temperature == 0:
generate_kwargs['do_sample'] = False
t = Thread(target=llama3_model.generate, kwargs=generate_kwargs)
t.start()
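    # generate() runs in a background thread; TextIteratorStreamer yields decoded text
    # chunks as they become available, so the loop below can stream partial output.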
outputs = []
for text in streamer:
outputs.append(text)
#print(outputs)
yield "".join(outputs)
with gr.Blocks(
title="MyAlexa",
css="""
textarea { font-size: 18px;}
""",
theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg) # make text slightly bigger (default is text_md )
) as demo:
gr.HTML(DESCRIPTION)
chatbot = gr.Chatbot(
[],
elem_id="chatbot",
bubble_full_width=False,
placeholder=PLACEHOLDER,
label='MyAlexa'
)
with gr.Row():
with gr.Column():
            gr.HTML(
                "Step 1: Upload an audio file or record with your microphone."
            )
audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")
with gr.Column():
            gr.HTML("Step 2: Enter audio as input and wait for MyAlexa's response.")
submit_button = gr.Button(
value="Submit audio",
variant="primary"
)
chat_input = gr.Textbox(
label="Transcribed text:",
interactive=False,
placeholder="Enter message",
elem_id="chat_input",
visible=True
)
chat_msg = chat_input.change(add_message, [chatbot, chat_input], [chatbot])
bot_msg = chat_msg.then(bot, [chatbot, chat_input], chatbot, api_name="bot_response")
# bot_msg.then(lambda: gr.Textbox(interactive=False), None, [chat_input])
submit_button.click(
fn=transcribe,
inputs = [audio_file],
outputs = [chat_input]
)
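    # Event flow: "Submit audio" runs transcribe() and fills the textbox; the textbox
    # change event adds the transcript to the chat, then bot() streams the reply.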
demo.queue()
if __name__ == "__main__":
demo.launch()