Timjo88 commited on
Commit
421b7df
Β·
1 Parent(s): 4c072a3

Create new file

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+
4
+ from haystack.schema import Answer
5
+ from haystack.document_stores import InMemoryDocumentStore
6
+ from haystack.pipeline import FAQPipeline
7
+ from haystack.retriever.dense import EmbeddingRetriever
8
+ from haystack.utils import print_answers
9
+ import logging
10
+
11
+ #Haystack function calls - streamlit structure from Tuana GoT QA Haystack demo
12
+ @st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None},allow_output_mutation=True) # use streamlit cache
13
+
14
+ def start_haystack():
15
+ document_store = InMemoryDocumentStore(index="document", embedding_field='embedding', embedding_dim=384, similarity='cosine')
16
+ retriever = EmbeddingRetriever(document_store=document_store, embedding_model='sentence-transformers/all-MiniLM-L6-v2', use_gpu=True, top_k=1)
17
+ load_data_to_store(document_store,retriever)
18
+ pipeline = FAQPipeline(retriever=retriever)
19
+ return pipeline
20
+
21
+ def load_data_to_store(document_store, retriever):
22
+ df = pd.read_csv('monopoly_qa-v1.csv')
23
+ questions = list(df.Question)
24
+ df['embedding'] = retriever.embed_queries(texts=questions)
25
+ df = df.rename(columns={"Question":"content","Answer":"answer"})
26
+ df.drop('link to source (to prevent duplicate sources)',axis=1, inplace=True)
27
+
28
+ dicts = df.to_dict(orient="records")
29
+ document_store.write_documents(dicts)
30
+
31
+ pipeline = start_haystack()
32
+
33
+ def predict(question):
34
+ predictions = pipeline.run(question)
35
+ answer = predictions["answers"]
36
+ return answer
37
+
38
+ gr.Interface(
39
+ predict,
40
+ inputs=gr.inputs.Textbox(label="enter your monopoly question here"),
41
+ outputs=gr.outputs.Label(num_top_classes=1),
42
+ title="Monopoly FAQ Semantic Search",
43
+ ).launch()