Spaces:
Runtime error
Runtime error
Create new file
Browse files
app.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
from haystack.schema import Answer
|
5 |
+
from haystack.document_stores import InMemoryDocumentStore
|
6 |
+
from haystack.pipeline import FAQPipeline
|
7 |
+
from haystack.retriever.dense import EmbeddingRetriever
|
8 |
+
from haystack.utils import print_answers
|
9 |
+
import logging
|
10 |
+
|
11 |
+
#Haystack function calls - streamlit structure from Tuana GoT QA Haystack demo
|
12 |
+
@st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None},allow_output_mutation=True) # use streamlit cache
|
13 |
+
|
14 |
+
def start_haystack():
|
15 |
+
document_store = InMemoryDocumentStore(index="document", embedding_field='embedding', embedding_dim=384, similarity='cosine')
|
16 |
+
retriever = EmbeddingRetriever(document_store=document_store, embedding_model='sentence-transformers/all-MiniLM-L6-v2', use_gpu=True, top_k=1)
|
17 |
+
load_data_to_store(document_store,retriever)
|
18 |
+
pipeline = FAQPipeline(retriever=retriever)
|
19 |
+
return pipeline
|
20 |
+
|
21 |
+
def load_data_to_store(document_store, retriever):
|
22 |
+
df = pd.read_csv('monopoly_qa-v1.csv')
|
23 |
+
questions = list(df.Question)
|
24 |
+
df['embedding'] = retriever.embed_queries(texts=questions)
|
25 |
+
df = df.rename(columns={"Question":"content","Answer":"answer"})
|
26 |
+
df.drop('link to source (to prevent duplicate sources)',axis=1, inplace=True)
|
27 |
+
|
28 |
+
dicts = df.to_dict(orient="records")
|
29 |
+
document_store.write_documents(dicts)
|
30 |
+
|
31 |
+
pipeline = start_haystack()
|
32 |
+
|
33 |
+
def predict(question):
|
34 |
+
predictions = pipeline.run(question)
|
35 |
+
answer = predictions["answers"]
|
36 |
+
return answer
|
37 |
+
|
38 |
+
gr.Interface(
|
39 |
+
predict,
|
40 |
+
inputs=gr.inputs.Textbox(label="enter your monopoly question here"),
|
41 |
+
outputs=gr.outputs.Label(num_top_classes=1),
|
42 |
+
title="Monopoly FAQ Semantic Search",
|
43 |
+
).launch()
|