Timjo88's picture
Create new file
421b7df
raw
history blame
1.68 kB
import gradio as gr
import pandas as pd
from haystack.schema import Answer
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipeline import FAQPipeline
from haystack.retriever.dense import EmbeddingRetriever
from haystack.utils import print_answers
import logging
# Haystack function calls - structure adapted from Tuana's GoT QA Haystack demo.
# That demo was a Streamlit app and cached this function with ``@st.cache``;
# this is a Gradio app that never imports streamlit (the decorator raised
# NameError at import time), and the function is called exactly once at
# module level, so no caching decorator is needed.
def start_haystack():
    """Create the document store and retriever, index the FAQ, and return the pipeline.

    Returns:
        A ``FAQPipeline`` whose embedding retriever returns the single
        closest stored question (``top_k=1``) using cosine similarity.
    """
    # embedding_dim must equal the retriever model's output size
    # (384 for all-MiniLM-L6-v2).
    document_store = InMemoryDocumentStore(
        index="document",
        embedding_field="embedding",
        embedding_dim=384,
        similarity="cosine",
    )
    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        use_gpu=True,
        top_k=1,
    )
    load_data_to_store(document_store, retriever)
    return FAQPipeline(retriever=retriever)
def load_data_to_store(document_store, retriever, csv_path="monopoly_qa-v1.csv"):
    """Embed the FAQ questions from ``csv_path`` and write them to ``document_store``.

    The CSV must have ``Question`` and ``Answer`` columns. It may also have a
    ``link to source (to prevent duplicate sources)`` bookkeeping column,
    which is dropped if present. Each row becomes one document whose
    ``content`` is the question, with an ``answer`` field and a precomputed
    ``embedding`` field.

    Args:
        document_store: Document store the records are written into.
        retriever: Retriever whose ``embed_queries`` embeds the questions.
            Using the same retriever that later embeds user queries keeps
            stored vectors and query vectors comparable.
        csv_path: Path of the FAQ CSV file (defaults to the bundled file).
    """
    df = pd.read_csv(csv_path)
    # Empty cells would otherwise become float NaN in the documents.
    df = df.fillna("")
    # Strip stray whitespace so it does not leak into the question embeddings.
    questions = [str(q).strip() for q in df["Question"]]
    df["Question"] = questions
    # embed_queries may return one 2-D array rather than a list of vectors;
    # list(...) yields one vector per row, which fits a single column.
    df["embedding"] = list(retriever.embed_queries(texts=questions))
    df = df.rename(columns={"Question": "content", "Answer": "answer"})
    # The source-link column is only bookkeeping; ignore it if it is missing.
    df = df.drop(
        columns="link to source (to prevent duplicate sources)", errors="ignore"
    )
    document_store.write_documents(df.to_dict(orient="records"))
# Build the FAQ pipeline once at import time; every request reuses it.
pipeline = start_haystack()


def predict(question):
    """Return the stored answer of the FAQ entry closest to ``question``.

    Args:
        question: Free-text question typed into the Gradio input box.

    Returns:
        The answer text of the best-matching FAQ entry as a plain string, or
        a short message when the input is blank or nothing was retrieved.
    """
    if not question or not question.strip():
        return "Please enter a question."
    predictions = pipeline.run(query=question)
    answers = predictions.get("answers") or []
    if not answers:
        return "No matching answer found."
    best = answers[0]
    # The Gradio output component needs plain text, not Haystack objects.
    # Haystack 1.x yields ``Answer`` objects; older releases yielded dicts.
    return best.answer if hasattr(best, "answer") else best["answer"]
# Build the UI. Launch it only when the file is executed as a script
# (e.g. ``python app.py``) so that importing this module does not start a
# blocking web server.
demo = gr.Interface(
    predict,
    inputs=gr.inputs.Textbox(label="enter your monopoly question here"),
    # The answer is free text, so a Textbox fits better than a
    # classification Label.
    outputs=gr.outputs.Textbox(label="answer"),
    title="Monopoly FAQ Semantic Search",
)

if __name__ == "__main__":
    demo.launch()