Find_The_Fake / app.py
Arnab Das
reverting changes
1e4d7e8
raw
history blame
3.89 kB
import torch
import gradio as gr
import models as MOD
import process_data as PD
from transformers import pipeline
# Registry of the selectable detection models, keyed by the label shown in the UI.
# Per-model fields:
#   eer_threshold     - decision threshold reported to the user (presumably the
#                       equal-error-rate operating point; confirm with training logs)
#   data_process_func - name of the preprocessing function looked up on process_data
#   note              - free-text remark appended to the response
#   model_class       - name of the model class looked up on models
#   model_checkpoint  - path of the weights file to load
model_master = {
    "SSL-AASIST (Trained on ASV-Spoof5)": {
        "eer_threshold": 3.3330237865448,
        "data_process_func": "process_ssl_assist_input",
        "note": "This model is trained only on ASVSpoof 2024 training data.",
        "model_class": "Model",
        "model_checkpoint": "ssl_aasist_epoch_7.pth",
    },
    "AASIST": {
        "eer_threshold": 1.8018419742584229,
        "data_process_func": "process_assist_input",
        "note": "This model is trained on ASVSpoof 2024 training data.",
        "model_class": "AASIST_Model",
        "model_checkpoint": "orig_aasist_epoch_1.pth",
    },
}
# Model loaded at start-up. Its class and checkpoint are resolved through
# model_master (the same way process() resolves them) so the default cannot
# drift away from the registry entry it is labelled with.
loaded_model = "SSL-AASIST (Trained on ASV-Spoof5)"
model = getattr(MOD, model_master[loaded_model]["model_class"])(None, "cpu")
model.load_state_dict(
    torch.load(model_master[loaded_model]["model_checkpoint"], map_location="cpu")
)
# Inference only: switch the module to evaluation mode.
model.eval()
def process(file, type):
    """Score an audio file with the selected detection model.

    Args:
        file: Audio file to analyse, as handed over by the Gradio audio input
            (configured with ``type="filepath"``); ``None`` if nothing was
            uploaded.
        type: Key of ``model_master`` naming the model to use; ``None`` if no
            model was selected. The name shadows the builtin but is kept so
            the callable's signature stays unchanged.

    Returns:
        A multi-line string containing the decision score, the model's
        decision threshold and usage notes.

    Raises:
        gr.Error: If no file was submitted or no known model was selected.
    """
    global model
    global loaded_model

    # Fail with a readable message instead of a KeyError/TypeError deep inside
    # the lookup or the preprocessing code.
    if file is None:
        raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")
    if type not in model_master:
        raise gr.Error("No model selected! Please select a model before submitting your request.")

    spec = model_master[type]
    inp = getattr(PD, spec["data_process_func"])(file)

    if loaded_model != type:
        # Invalidate the bookkeeping first: if building the model or loading
        # the checkpoint fails midway, the next request reloads from scratch
        # instead of trusting a half-initialised global model.
        loaded_model = None
        model = getattr(MOD, spec["model_class"])(None, "cpu")
        model.load_state_dict(torch.load(spec["model_checkpoint"], map_location="cpu"))
        model.eval()
        loaded_model = type

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # Index 1 of the squeezed output is used as the decision score
        # (presumably the bona fide class logit; confirm against the model).
        op = model(inp).detach().squeeze()[1].item()

    response_text = "Decision score: {} \nDecision threshold: {} \nNotes: 1. Any score below threshold is indicative of fake. \n2. {} ".format(
        str(op), str(spec["eer_threshold"]), spec["note"])
    return response_text
# Top-level Blocks container; the interfaces defined below are attached to it
# as tabs inside the `with demo:` section near the end of the file.
demo = gr.Blocks()
# Tab 1: upload an audio file, pick a detection model, get the score report
# produced by process().
file_proc = gr.Interface(
    title="Find the Fake: Analyze 'Real' or 'Fake'.",
    description="Analyze fake or real with a click of a button. Upload a .wav or .flac file.",
    fn=process,
    inputs=[
        gr.Audio(sources=["upload"], label="Audio file", type="filepath"),
        gr.Radio(
            ["SSL-AASIST (Trained on ASV-Spoof5)", "AASIST"],
            label="Select Model",
            type="value",
        ),
    ],
    outputs="text",
    # Each example row is (audio path, model label), matching the inputs above.
    examples=[
        ["./bonafide.flac", "SSL-AASIST (Trained on ASV-Spoof5)"],
        ["./fake.flac", "SSL-AASIST (Trained on ASV-Spoof5)"],
        ["./bonafide.flac", "AASIST"],
        ["./fake.flac", "AASIST"],
    ],
    cache_examples=True,
    allow_flagging="never",
)
# -----------------------------------------------------------------------------
# ASR interface
# -----------------------------------------------------------------------------
# Speech-recognition pipeline shared by every transcribe() call; created once
# at import time and kept on the CPU.
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",
    device="cpu",
    chunk_length_s=30,
)
def transcribe(inputs):
    """Transcribe an audio file and report the language identified for it.

    Args:
        inputs: Audio file to transcribe, as handed over by the Gradio audio
            input (configured with ``type="filepath"``); ``None`` if nothing
            was submitted.

    Returns:
        A ``(language, text)`` tuple: the language attached to the first
        chunk of the pipeline output, and the full transcription.

    Raises:
        gr.Error: If no audio was submitted, or if the pipeline returned no
            chunks, so no language is available.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    op = pipe(
        inputs,
        batch_size=8,
        generate_kwargs={"task": "transcribe"},
        return_timestamps=False,
        return_language=True,
    )

    # The language is read from the first chunk. Guard against an empty chunk
    # list (presumably what silent or unintelligible audio produces) instead
    # of failing with an IndexError.
    chunks = op.get("chunks") or []
    if not chunks:
        raise gr.Error("No speech could be detected in the submitted audio.")

    lang = chunks[0]["language"]
    text = op["text"]
    return lang, text
# Tab 2: record or upload a short clip and show the language and transcription
# returned by transcribe().
transcribe_proc = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(
            type="filepath",
            label="Speech file (<30s)",
            max_length=30,
            sources=["microphone", "upload"],
            show_download_button=True,
        ),
    ],
    # Order matches the (language, text) tuple returned by transcribe().
    outputs=[
        gr.Text(label="Predicted Language", info="Language identification is performed automatically."),
        gr.Text(label="Predicted transcription", info="Best hypothesis."),
    ],
    title="Transcribe Anything.",
    description=(
        "Automatic language identification and transcription service by Whisper Large V3. Upload a .wav or .flac file."
    ),
    allow_flagging="never",
)
# Attach both interfaces to the shared Blocks container, one tab each.
with demo:
    gr.TabbedInterface(
        interface_list=[file_proc, transcribe_proc],
        tab_names=["Analyze Audio File", "Transcribe Audio File"],
    )

# Enable the request queue (at most 10 queued events) and start the app with
# share=True.
demo.queue(max_size=10).launch(share=True)