Spaces:
Sleeping
Sleeping
import gradio as gr | |
import soundfile | |
import time | |
import torch | |
import scipy.io.wavfile | |
from espnet2.utils.types import str_or_none | |
from espnet2.bin.asr_inference import Speech2Text | |
from subprocess import call | |
import os | |
from espnet_model_zoo.downloader import ModelDownloader | |
# print(a1) | |
# exit() | |
# exit() | |
# tagen = 'kan-bayashi/ljspeech_vits' | |
# vocoder_tagen = "none" | |
# ESC-50 sound-event labels, written as '<id>."<name>"' entries separated by ", ".
audio_class_str='0."dog", 1."rooster", 2."pig", 3."cow", 4."frog", 5."cat", 6."hen", 7."insects", 8."sheep", 9."crow", 10."rain", 11."sea waves", 12."crackling fire", 13."crickets", 14."chirping birds", 15."water drops", 16."wind", 17."pouring water", 18."toilet flush", 19."thunderstorm", 20."crying baby", 21."sneezing", 22."clapping", 23."breathing", 24."coughing", 25."footsteps", 26."laughing", 27."brushing teeth", 28."snoring", 29."drinking sipping", 30."door wood knock", 31."mouse click", 32."keyboard typing", 33."door wood creaks", 34."can opening", 35."washing machine", 36."vacuum cleaner", 37."clock alarm", 38."clock tick", 39."glass breaking", 40."helicopter", 41."chainsaw", 42."siren", 43."car horn", 44."engine", 45."train", 46."church bells", 47."airplane", 48."fireworks", 49."hand saw".'
# Label names in id order: audio_class_arr[i] is the text inside the quotes of
# entry i (the listing above is already sorted by id, starting at 0).
audio_class_arr = [
    entry.split('"')[1] for entry in audio_class_str.split(", ")
]
from functools import lru_cache

from scipy.signal import resample_poly

# Every task is served by the same checkpoint; tasks differ only in the decoding
# prompt (task-specifier tokens) and a few decoding options.
_EXP_DIR = (
    "UniverSLU-17-Task-Specifier/exp/"
    "asr_train_asr_whisper_full_correct_specaug2_copy_raw_en_whisper_multilingual"
)
_ASR_TRAIN_CONFIG = _EXP_DIR + "/config.yaml"
_ASR_MODEL_FILE = _EXP_DIR + "/valid.acc.ave_10best.pth"
_PROMPT_TOKEN_FILE = "UniverSLU-17-Task-Specifier/add_tokens-Copy1.txt"

# Whisper front-ends consume 16 kHz audio, while browser microphone recordings are
# usually 44.1/48 kHz, so input is resampled to this rate before decoding.
# NOTE(review): confirm against the frontend fs in config.yaml.
_MODEL_SAMPLE_RATE = 16000


@lru_cache(maxsize=1)
def _get_speech2text(lang_prompt_token=None, lid_prompt=False, penalty=None):
    """Load (or reuse) the Speech2Text decoder for one prompt configuration.

    Only the options a task actually sets are forwarded to ``from_pretrained``;
    everything else keeps the library default. ``maxsize=1`` keeps at most one
    model resident, so back-to-back requests for the same task skip the reload
    without holding a copy per task in memory.
    NOTE(review): the cached model is shared across requests; this assumes the
    Gradio handler runs one request at a time (its default concurrency limit).
    """
    options = {
        "asr_train_config": _ASR_TRAIN_CONFIG,
        "asr_model_file": _ASR_MODEL_FILE,
        "prompt_token_file": _PROMPT_TOKEN_FILE,
        # Decoding parameters are not included in the model file
        "ctc_weight": 0.0,
        "beam_size": 1,
        "nbest": 1,
    }
    if lang_prompt_token is not None:
        options["lang_prompt_token"] = lang_prompt_token
    if lid_prompt:
        options["lid_prompt"] = True
    if penalty is not None:
        options["penalty"] = penalty
    return Speech2Text.from_pretrained(**options)


def _hyp_text(nbests):
    """Return the best hypothesis text with everything up to the last '|>' removed."""
    text, *_ = nbests[0]
    return text.split("|>")[-1]


def _first_hyp_token(speech2text, nbests):
    """Return the first token of the best hypothesis as a string (e.g. '<|en|>')."""
    token_ids = nbests[0][2]
    return speech2text.converter.tokenizer.tokenizer.convert_ids_to_tokens(token_ids[0])


def _label_formatter(heading, prefix=""):
    """Build a formatter printing ``heading`` + the first word with ``prefix`` removed."""
    def _format(speech2text, nbests):
        label = _hyp_text(nbests).split(" ")[0]
        if prefix:
            label = label.replace(prefix, "")
        return heading + label
    return _format


def _format_slurp(speech2text, nbests):
    """Format 'in:<scenario>_<action> SEP sl:<slot> FILL <value> SEP ...' output."""
    text = _hyp_text(nbests)
    intent = text.split(" ")[0].replace("in:", "")
    # The action may itself contain underscores (e.g. "audio_volume_mute"),
    # so split on the first underscore only.
    scenario, _, action = intent.partition("_")
    entities = []
    # The first SEP-chunk holds the intent; the last chunk is not an entity.
    for chunk in text.split(" SEP ")[1:-1]:
        slot_name, _, slot_val = chunk.partition(" FILL ")
        entities.append(" " + slot_name.replace("sl:", "") + " : " + slot_val + ",")
    return (
        "INTENT: {scenario: " + scenario + ", action: " + action + "}\n"
        + "NAMED ENTITIES: {" + "".join(entities) + "}"
    )


def _format_fsc(speech2text, nbests):
    """Format an FSC intent of the form '<action>_<object>_<location>'."""
    intent = _hyp_text(nbests).split(" ")[0].replace("in:", "")
    # Object and location are the last two fields; anything before them is the
    # action, which may contain underscores of its own.
    parts = intent.rsplit("_", 2)
    if len(parts) != 3:
        # Unexpected model output: show it as-is rather than crash.
        return "INTENT: " + intent
    action, obj, location = parts
    return "INTENT: {action: " + action + ", object: " + obj + ", location: " + location + "}"


def _format_lithuanian_scr(speech2text, nbests):
    """The Lithuanian SC output is shown verbatim (no label prefix to strip)."""
    return "SPEECH COMMAND: " + _hyp_text(nbests)


def _format_lid(speech2text, nbests):
    """Language ID: the first hypothesis token is the language tag, e.g. '<|en|>'."""
    lang = _first_hyp_token(speech2text, nbests).replace("|>", "").replace("<|", "")
    return "LANG: " + lang


def _format_vad(speech2text, nbests):
    """VAD: the '<|nospeech|>' first token means no speech was detected."""
    if _first_hyp_token(speech2text, nbests) == "<|nospeech|>":
        return "VAD: no speech"
    return "VAD: speech"


_EMOTION_LABELS = {
    "em:neu": "Neutral",
    "em:ang": "Angry",
    "em:sad": "Sad",
    "em:hap": "Happy",
}


def _format_emotion(speech2text, nbests):
    """Map the IEMOCAP 'em:*' label to a readable name (raw label if unknown)."""
    label = _hyp_text(nbests).split(" ")[0]
    return "EMOTION: " + _EMOTION_LABELS.get(label, label)


def _format_gender(speech2text, nbests):
    """Expand 'gender:f' / 'gender:m' to 'female' / 'male'."""
    label = _hyp_text(nbests).split(" ")[0]
    return "GENDER: " + label.replace("gender:f", "female").replace("gender:m", "male")


def _format_esc50(speech2text, nbests):
    """Map the ESC-50 'audio_class:<id>' output to its label name."""
    label = _hyp_text(nbests).split(" ")[0].replace("audio_class:", "")
    # Only translate well-formed, in-range ids; otherwise show the raw label
    # (int() / indexing would crash, and a negative id would silently wrap).
    if label.isdecimal() and int(label) < len(audio_class_arr):
        label = audio_class_arr[int(label)]
    return "AUDIO EVENT CLASS: " + label


def _format_stop(speech2text, nbests):
    """STOP semantic parse, with the '_STOP' marker removed."""
    return "SEMANTIC PARSE SEQUENCE: " + _hyp_text(nbests).replace("_STOP", "")


# Task name (as shown in the UI) -> (Speech2Text prompt options, output formatter).
_TASKS = {
    "English intent classification and named entity recognition task based on the SLURP database":
        ({"lang_prompt_token": "<|en|> <|ner|> <|SLURP|>", "penalty": 0.1}, _format_slurp),
    "English intent classification task based on the FSC database":
        ({"lang_prompt_token": "<|en|> <|ic|> <|fsc|>"}, _format_fsc),
    "English intent classification task based on the SNIPS database":
        ({"lang_prompt_token": "<|en|> <|ic|> <|SNIPS|>"}, _label_formatter("INTENT: ", "in:")),
    "Dutch speech command recognition task based on the Grabo database":
        ({"lang_prompt_token": "<|nl|> <|scr|> <|grabo_scr|>"}, _label_formatter("SPEECH COMMAND: ")),
    "English speech command recognition task based on the Google Speech Commands database":
        ({"lang_prompt_token": "<|en|> <|scr|> <|google_scr|>"}, _label_formatter("SPEECH COMMAND: ", "command:")),
    "Lithuanian speech command recognition task based on the Lithuanian SC database":
        ({"lang_prompt_token": "<|lt|> <|scr|> <|lt_scr|>"}, _format_lithuanian_scr),
    "Arabic speech command recognition task based on the Arabic SC database":
        ({"lang_prompt_token": "<|ar|> <|scr|> <|ar_scr|>"}, _label_formatter("SPEECH COMMAND: ", "command:")),
    "Language Identification task based on the VoxForge database":
        ({"lid_prompt": True}, _format_lid),
    "English Fake Speech Detection task based on the ASVSpoof database":
        ({"lang_prompt_token": "<|en|> <|fsd|> <|asvspoof|>"}, _label_formatter("SPEECH CLASS: ", "class:")),
    "English emotion recognition task based on the IEMOCAP database":
        ({"lang_prompt_token": "<|en|> <|er|> <|iemocap|>"}, _format_emotion),
    "English accent classification task based on the Accent DB database":
        ({"lang_prompt_token": "<|en|> <|accent_rec|> <|accentdb|>"}, _label_formatter("ACCENT: ", "accent:")),
    "English sarcasm detection task based on the MUStARD database":
        ({"lang_prompt_token": "<|en|> <|scd|> <|mustard|>"}, _label_formatter("SARCASM CLASS: ", "class:")),
    "English sarcasm detection task based on the MUStARD++ database":
        ({"lang_prompt_token": "<|en|> <|scd|> <|mustard_plus_plus|>"}, _label_formatter("SARCASM CLASS: ", "class:")),
    "English gender identification task based on the VoxCeleb1 database":
        ({"lang_prompt_token": "<|en|> <|gid|> <|voxceleb|>"}, _format_gender),
    "Audio classification task based on the ESC-50 database":
        ({"lang_prompt_token": "<|audio|> <|auc|> <|esc50|>"}, _format_esc50),
    "English semantic parsing task based on the STOP database":
        ({"lang_prompt_token": "<|en|> <|sp|> <|STOP|>", "penalty": 0.1}, _format_stop),
    "Voice activity detection task based on the Google Speech Commands v2 and Freesound database":
        ({"lid_prompt": True}, _format_vad),
}


def inference(wav, data):
    """Run the UniverSLU model on a recorded clip for the selected task.

    Args:
        wav: Path to the audio file (the Gradio Audio input uses
            ``type="filepath"``), or None when nothing was recorded.
        data: Task name selected in the "Task" radio; must be a key of ``_TASKS``.

    Returns:
        The human-readable result string shown in the Output textbox.

    Raises:
        gr.Error: If no audio was provided or the task name is not recognised.
    """
    if wav is None:
        raise gr.Error("Please record audio or pick an example before submitting.")
    task = _TASKS.get(data)
    if task is None:
        raise gr.Error(f"Unsupported task: {data!r}")
    prompt_options, formatter = task

    with torch.no_grad():
        speech, rate = soundfile.read(wav)
        if speech.ndim == 2:
            # Multi-channel input: keep only the first channel.
            speech = speech[:, 0]
        if rate != _MODEL_SAMPLE_RATE:
            speech = resample_poly(speech, _MODEL_SAMPLE_RATE, rate)
        speech2text = _get_speech2text(**prompt_options)
        nbests = speech2text(speech)
        return formatter(speech2text, nbests)
title = "UniverSLU"
description = "Gradio demo for UniverSLU Task Specifier (https://huggingface.co/espnet/UniverSLU-17-Task-Specifier). UniverSLU-17 Task Specifier is a Multi-task Spoken Language Understanding model from CMU WAVLab. It adapts Whisper to additional tasks using single-token task specifiers. To use it, simply record your audio or click one of the examples to load them. More details about the SLU tasks that the model is trained on and its performance on these tasks can be found in our paper: https://aclanthology.org/2024.naacl-long.151/"
article = "<p style='text-align: center'><a href='https://github.com/espnet/espnet' target='_blank'>Github Repo</a></p>"

# One [audio file, task name] pair per supported task.
examples = [
    ['audio_slurp_ner.flac', "English intent classification and named entity recognition task based on the SLURP database"],
    ['audio_fsc.wav', "English intent classification task based on the FSC database"],
    ['audio_grabo.wav', "Dutch speech command recognition task based on the Grabo database"],
    ['audio_english_scr.wav', "English speech command recognition task based on the Google Speech Commands database"],
    ['audio_lt_scr.wav', "Lithuanian speech command recognition task based on the Lithuanian SC database"],
    ['audio_ar_scr.wav', "Arabic speech command recognition task based on the Arabic SC database"],
    ['audio_snips.wav', "English intent classification task based on the SNIPS database"],
    ['audio_lid.wav', "Language Identification task based on the VoxForge database"],
    ['audio_fsd.wav', "English Fake Speech Detection task based on the ASVSpoof database"],
    ['audio_er.wav', "English emotion recognition task based on the IEMOCAP database"],
    ['audio_acc.wav', "English accent classification task based on the Accent DB database"],
    ['audio_mustard.wav', "English sarcasm detection task based on the MUStARD database"],
    ['audio_mustard_plus.wav', "English sarcasm detection task based on the MUStARD++ database"],
    ['audio_voxceleb1.wav', "English gender identification task based on the VoxCeleb1 database"],
    ['audio_esc50.wav', "Audio classification task based on the ESC-50 database"],
    ['audio_stop.wav', "English semantic parsing task based on the STOP database"],
    ['audio_freesound.wav', "Voice activity detection task based on the Google Speech Commands v2 and Freesound database"],
]

# The task selector lists exactly the example tasks, in the same order, so the
# two lists cannot drift apart.
task_choices = [task for _, task in examples]

demo = gr.Interface(
    inference,
    [
        gr.Audio(label="input audio", sources=["microphone"], type="filepath"),
        # A default task avoids submitting with no selection, which would pass
        # None as the task to inference().
        gr.Radio(choices=task_choices, value=task_choices[0], type="value", label="Task"),
    ],
    gr.Textbox(type="text", label="Output"),
    title=title,
    cache_examples=False,
    description=description,
    article=article,
    examples=examples,
)
demo.launch(debug=True)