NinaAchache committed on
Commit
a5686cb
·
1 Parent(s): 78ffcff

Add app.py & requirements

Browse files

app.py is a V0 and has not been cleaned up yet.
The OpenAI API key is requested from the user.
This V0 is meant to test requirements.txt.

Files changed (2) hide show
  1. app.py +123 -0
  2. requirements.txt +23 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ from haystack.document_stores import FAISSDocumentStore
4
+ from haystack.nodes import EmbeddingRetriever
5
+ import numpy as np
6
+ import openai
7
+
8
# Zero-shot classifier used to decide whether a question is about climate change.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# System message placed at the start of every conversation history.
system_template = {
    "role": "system",
    "content": "You have been a climate change expert for 30 years. You answer questions about climate change in an educationnal and concise manner.",
}


# Load the prebuilt FAISS index and its config from the working directory.
# Plain string literals: the paths contain no placeholders, so no f-string is needed.
document_store = FAISSDocumentStore.load(
    index_path="./climate_gpt.faiss",
    config_path="./climate_gpt.json",
)
# Embedding-based retriever over the loaded document store.
dense = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
    model_format="sentence_transformers",
)
25
+
26
+
27
def is_climate_change_related(sentence: str) -> bool:
    """Tell whether *sentence* is classified as being about climate change.

    The module-level zero-shot ``classifier`` scores the sentence against two
    candidate labels; the sentence counts as related when the label with the
    highest score is "climate change related".
    """
    positive_label = "climate change related"
    prediction = classifier(
        sequences=sentence,
        candidate_labels=[positive_label, "non climate change related"],
    )
    # Pick the label whose score is the largest.
    best_index = np.argmax(prediction["scores"])
    top_label = prediction["labels"][best_index]
    return top_label == positive_label
33
+
34
+
35
def make_pairs(lst):
    """Group consecutive items of *lst* into 2-tuples.

    ``[a, b, c, d]`` becomes ``[(a, b), (c, d)]``. If *lst* has an odd number
    of items, the trailing item is paired with ``None`` instead of raising
    ``IndexError``.

    Args:
        lst: A sliceable sequence (e.g. a list) of items to pair up.

    Returns:
        A list of 2-tuples in the original order.
    """
    # Even-indexed items are the first members, odd-indexed the second members.
    pairs = list(zip(lst[::2], lst[1::2]))
    if len(lst) % 2:
        # Unpaired trailing item: keep it rather than dropping it or crashing.
        pairs.append((lst[-1], None))
    return pairs
38
+
39
+
40
def _source_header(doc):
    """Return the "<path> Page <page> paragraph <paragraph>:" citation line for *doc*."""
    return f"{doc.meta['path']} Page {doc.meta['page_id']} paragraph {doc.meta['paragraph_id']}:"


def gen_conv(query: str, history=None, ipcc=True):
    """Answer *query* with the chat model, optionally grounding it on retrieved extracts.

    Args:
        query: The user's question.
        history: Previous chat messages as ``{"role", "content"}`` dicts. When
            omitted, the conversation starts from ``[system_template]``. The
            first entry is skipped when building the chat display.
        ipcc: When true, and the query is classified as climate change related,
            the top 5 documents are retrieved and passed to the model.

    Returns:
        A 3-tuple ``(chat_pairs, messages, sources)``:
        ``chat_pairs`` is the list of message-content pairs produced by
        ``make_pairs`` for the chatbot display, ``messages`` is the updated
        history (list of dicts), and ``sources`` is a string listing the
        retrieved extracts truncated to 100 characters each (empty when
        nothing was retrieved).
    """
    # Resolve the default at call time instead of sharing one list object
    # across calls through a mutable default argument.
    if history is None:
        history = [system_template]

    retrieve = ipcc and is_climate_change_related(query)
    sources = ""
    # Concatenation builds a new list, so the caller's history is not mutated.
    messages = history + [{"role": "user", "content": query}]

    if retrieve:
        docs = dense.retrieve(query=query, top_k=5)
        sources = "\n\n".join(
            ["If relevant, use those extracts from IPCC reports in your answer"]
            + [f"{_source_header(d)}\n{d.content}" for d in docs]
        )
        # Temporary system message carrying the full extracts for this call only.
        messages.append({"role": "system", "content": sources})

    answer = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0.2,
    )["choices"][0]["message"]["content"]

    if retrieve:
        # Drop the temporary extracts message so it is not kept in the history.
        messages.pop()
        answer = "(top 5 documents retrieved) " + answer
        # Shortened version of the extracts for display.
        sources = "\n\n".join(
            f"{_source_header(d)}\n{d.content[:100]} [...]" for d in docs
        )

    messages.append({"role": "assistant", "content": answer})

    # Skip the first message (the system prompt in the default history) and
    # pair the remaining contents for the chatbot component.
    gradio_format = make_pairs([m["content"] for m in messages[1:]])

    return gradio_format, messages, sources
79
+
80
+
81
def connect(text):
    """Register *text* as the OpenAI API key and return a confirmation message.

    Leading and trailing whitespace (e.g. a newline picked up when pasting) is
    stripped from string input before the key is stored; non-string values are
    stored unchanged.

    Args:
        text: The API key entered by the user.

    Returns:
        The confirmation string shown in the UI.
    """
    openai.api_key = text.strip() if isinstance(text, str) else text
    return "You're all set"
84
+
85
+
86
# Gradio interface: an API-key connection row, then the chat panel next to a
# read-only panel listing the retrieved sources.
# NOTE(review): the layout nesting below was reconstructed from a flattened
# listing — confirm it against the original app.py.
with gr.Blocks(title="Eki IPCC Explorer") as demo:
    with gr.Row():
        with gr.Column():
            # Mask the key while typing so the secret is not shown in clear text.
            api_key = gr.Textbox(label="Open AI api key", type="password")
            connect_btn = gr.Button(value="Connect")
        with gr.Column():
            result = gr.Textbox(label="Connection")

    connect_btn.click(connect, inputs=api_key, outputs=result, api_name="Connection")

    gr.Markdown(
        """
    # Ask me anything, I'm an IPCC report
    """
    )

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            # Per-session message history, seeded with the system prompt.
            state = gr.State([system_template])

            with gr.Row():
                ask = gr.Textbox(
                    show_label=False, placeholder="Enter text and press enter"
                ).style(container=False)

        with gr.Column(scale=1, variant="panel"):

            gr.Markdown("### Sources")
            sources_textbox = gr.Textbox(
                interactive=False, show_label=False, max_lines=50
            )

    # Submitting the question runs gen_conv and updates the chat display, the
    # stored history, and the sources panel.
    ask.submit(
        fn=gen_conv, inputs=[ask, state], outputs=[chatbot, state, sources_textbox]
    )

demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ faiss-cpu==1.7.2
3
+ farm-haystack==1.14.0
4
+ gradio==3.20.1
5
+ huggingface-hub==0.12.1
6
+ mlflow==2.2.1
7
+ mmh3==3.0.0
8
+ openai==0.27.0
9
+ orjson==3.8.7
10
+ pandas==1.5.3
11
+ simplejson==3.18.3
12
+ six==1.16.0
13
+ slicer==0.0.7
14
+ smmap==5.0.0
15
+ SQLAlchemy==1.4.46
16
+ SQLAlchemy-Utils==0.40.0
17
+ sqlparse==0.4.3
18
+ tokenizers==0.13.2
19
+ torch==1.13.1
20
+ torchvision==0.14.1
21
+ transformers==4.25.1
22
+ trove-classifiers==2023.2.20
23
+