Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
import models as MOD | |
import process_data as PD | |
from transformers import pipeline | |
# Registry of the selectable spoof-detection models, keyed by the label shown
# in the "Select Model" radio button. Each entry holds:
#   eer_threshold     - decision threshold reported next to the score
#   data_process_func - name of the preprocessing function looked up on PD
#   note              - free-text remark appended to the response
#   model_class       - name of the model class looked up on MOD
#   model_checkpoint  - checkpoint file passed to torch.load
model_master = {
    "SSL-AASIST (Trained on ASV-Spoof5)": dict(
        eer_threshold=3.3330237865448,
        data_process_func="process_ssl_assist_input",
        note="This model is trained only on ASVSpoof 2024 training data.",
        model_class="Model",
        model_checkpoint="ssl_aasist_epoch_7.pth",
    ),
    "AASIST": dict(
        eer_threshold=1.8018419742584229,
        data_process_func="process_assist_input",
        note="This model is trained on ASVSpoof 2024 training data.",
        model_class="AASIST_Model",
        model_checkpoint="orig_aasist_epoch_1.pth",
    ),
}
# Model that is resident at start-up. `loaded_model` records which
# `model_master` entry the global `model` currently holds, so `process`
# only reloads weights when a different model is requested.
loaded_model = "SSL-AASIST (Trained on ASV-Spoof5)"
model = MOD.Model(None, "cpu")
_startup_state = torch.load("ssl_aasist_epoch_7.pth", map_location="cpu")
model.load_state_dict(_startup_state)
del _startup_state  # do not keep a second reference to the weights alive
model.eval()
def process(file, type):
    """Score an audio file with the selected spoof-detection model.

    The model named by ``type`` is loaded on demand and cached in the
    module-level ``model`` / ``loaded_model`` globals, so consecutive
    requests for the same model reuse the weights already in memory.

    Args:
        file: Path of the audio file to analyse (the Gradio audio input is
            configured with ``type="filepath"``).
        type: Key of ``model_master`` selecting the model. The name shadows
            the builtin but is kept so existing callers are unaffected.

    Returns:
        A text report containing the decision score, the model's decision
        threshold and its note.

    Raises:
        gr.Error: If no audio file was submitted or no known model was
            selected.
    """
    global model
    global loaded_model

    # Fail with a readable UI message instead of a raw KeyError/TypeError,
    # matching how `transcribe` reports a missing input.
    if file is None:
        raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")
    if type not in model_master:
        raise gr.Error("No model selected! Please select a model before submitting your request.")

    spec = model_master[type]
    inp = getattr(PD, spec["data_process_func"])(file)

    if loaded_model != type:
        # Build and load into a local first: if the checkpoint fails to load,
        # the globals still describe the previous, fully initialised model
        # rather than an untrained network under the old label.
        candidate = getattr(MOD, spec["model_class"])(None, "cpu")
        candidate.load_state_dict(torch.load(spec["model_checkpoint"], map_location="cpu"))
        candidate.eval()
        model = candidate
        loaded_model = type

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # NOTE(review): index 1 is presumably the bona fide score, since the
        # report below treats lower values as fake - confirm against the model.
        op = model(inp).squeeze()[1].item()

    response_text = "Decision score: {} \nDecision threshold: {} \nNotes: 1. Any score below threshold is indicative of fake. \n2. {} ".format(
        str(op), str(spec["eer_threshold"]), spec["note"])
    return response_text
# Top-level container; the tabbed layout is attached to it further down.
demo = gr.Blocks()

# Labels offered in the model selector (they are also the keys of model_master).
_MODEL_CHOICES = ["SSL-AASIST (Trained on ASV-Spoof5)", "AASIST"]

# "Analyze Audio File" tab: upload a file, pick a model, get the text report
# produced by `process`.
file_proc = gr.Interface(
    fn=process,
    inputs=[
        gr.Audio(sources=["upload"], label="Audio file", type="filepath"),
        gr.Radio(_MODEL_CHOICES, label="Select Model", type="value"),
    ],
    outputs="text",
    title="Find the Fake: Analyze 'Real' or 'Fake'.",
    description=(
        "Analyze fake or real with a click of a button. Upload a .wav or .flac file."
    ),
    # One bona fide and one fake sample per model, grouped by model.
    examples=[
        [clip, label]
        for label in _MODEL_CHOICES
        for clip in ("./bonafide.flac", "./fake.flac")
    ],
    cache_examples=True,
    allow_flagging="never",
)
#####################################################################################
# ASR interface
# Whisper Large V3 speech-recognition pipeline, created once at import time
# and shared by every `transcribe` call. Runs on CPU with 30-second chunks.
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",
    device="cpu",
    chunk_length_s=30,
)
def transcribe(inputs):
    """Transcribe an audio file and report the language found by the pipeline.

    Args:
        inputs: Path of the audio file to transcribe, or ``None`` when the
            user submitted nothing.

    Returns:
        A ``(language, text)`` tuple. ``language`` is ``"unknown"`` when the
        pipeline result carries no language information.

    Raises:
        gr.Error: If ``inputs`` is ``None``.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    result = pipe(
        inputs,
        batch_size=8,
        generate_kwargs={"task": "transcribe"},
        return_timestamps=False,
        return_language=True,
    )

    # The language is read from the first chunk. The chunk list may be absent
    # or empty (e.g. nothing was decoded), so fall back to a placeholder
    # instead of failing with KeyError/IndexError after the work is done.
    chunks = result.get("chunks") or [{}]
    lang = chunks[0].get("language", "unknown")
    text = result["text"]
    return lang, text
# "Transcribe Audio File" tab: record or upload a clip of at most 30 seconds
# and show the language and transcription returned by `transcribe`.
transcribe_proc = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(
            type="filepath",
            label="Speech file (<30s)",
            max_length=30,
            sources=["microphone", "upload"],
            show_download_button=True,
        ),
    ],
    outputs=[
        gr.Text(label="Predicted Language", info="Language identification is performed automatically."),
        gr.Text(label="Predicted transcription", info="Best hypothesis."),
    ],
    title="Transcribe Anything.",
    # Typo fix in the user-facing text: "Automatactic" -> "Automatic".
    description=(
        "Automatic language identification and transcription service by Whisper Large V3. Upload a .wav or .flac file."
    ),
    allow_flagging="never",
)
# Put both interfaces into the Blocks container as tabs, then serve the app
# with a request queue capped at 10 pending jobs.
with demo:
    gr.TabbedInterface(
        [file_proc, transcribe_proc],
        tab_names=["Analyze Audio File", "Transcribe Audio File"],
    )

demo.queue(max_size=10).launch(share=True)