HTK / app1.py
faori's picture
Upload folder using huggingface_hub
42073db verified
import gradio as gr
import io
import os
import yaml
import pyarrow
import tokenizers
from retro_reader import RetroReader
os.environ["TOKENIZERS_PARALLELISM"] = "true"
def from_library():
from retro_reader import constants as C
return C, RetroReader
C, RetroReader = from_library()
# Assuming RetroReader.load is a method from your imports
def load_model(config_path):
return RetroReader.load(config_file=config_path)
# Loading models
model_electra_base = load_model("configs/inference_en_electra_base.yaml")
model_electra_large = load_model("configs/inference_en_electra_large.yaml")
model_roberta = load_model("configs/inference_en_roberta.yaml")
model_distilbert = load_model("configs/inference_en_distilbert.yaml")
def retro_reader_demo(query, context, model_choice):
# Choose the model based on the model_choice
if model_choice == "Electra Base":
model = model_electra_base
elif model_choice == "Electra Large":
model = model_electra_large
elif model_choice == "Roberta":
model = model_roberta
elif model_choice == "DistilBERT":
model = model_distilbert
else:
return "Invalid model choice"
# Generate outputs using the chosen model
outputs = model(query=query, context=context, return_submodule_outputs=True)
# Extract the answer
answer = outputs[0]["id-01"] if outputs[0]["id-01"] else "No answer found"
return answer
# Gradio app interface
iface = gr.Interface(
fn=retro_reader_demo,
inputs=[
gr.Textbox(label="Query", placeholder="Type your query here..."),
gr.Textbox(label="Context", placeholder="Provide the context here...", lines=10),
gr.Radio(choices=["Electra Base", "Electra Large", "Roberta", "DistilBERT"], label="Model Choice")
],
outputs=gr.Textbox(label="Answer"),
title="Retrospective Reader Demo",
description="This interface uses the RetroReader model to perform reading comprehension tasks."
)
if __name__ == "__main__":
iface.launch(share=True)