NinaAchache
commited on
Commit
·
a5686cb
1
Parent(s):
78ffcff
Add app.py & requirements
Browse filesApp.py is V0 not clean yet
api_key is asked
V0 to test Requirements.txt
- app.py +123 -0
- requirements.txt +23 -0
app.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import pipeline
|
3 |
+
from haystack.document_stores import FAISSDocumentStore
|
4 |
+
from haystack.nodes import EmbeddingRetriever
|
5 |
+
import numpy as np
|
6 |
+
import openai
|
7 |
+
|
8 |
+
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
9 |
+
|
10 |
+
system_template = {
|
11 |
+
"role": "system",
|
12 |
+
"content": "You have been a climate change expert for 30 years. You answer questions about climate change in an educationnal and concise manner.",
|
13 |
+
}
|
14 |
+
|
15 |
+
|
16 |
+
document_store = FAISSDocumentStore.load(
|
17 |
+
index_path=f"./climate_gpt.faiss",
|
18 |
+
config_path=f"./climate_gpt.json",
|
19 |
+
)
|
20 |
+
dense = EmbeddingRetriever(
|
21 |
+
document_store=document_store,
|
22 |
+
embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
|
23 |
+
model_format="sentence_transformers",
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
def is_climate_change_related(sentence: str) -> bool:
|
28 |
+
results = classifier(
|
29 |
+
sequences=sentence,
|
30 |
+
candidate_labels=["climate change related", "non climate change related"],
|
31 |
+
)
|
32 |
+
return results["labels"][np.argmax(results["scores"])] == "climate change related"
|
33 |
+
|
34 |
+
|
35 |
+
def make_pairs(lst):
|
36 |
+
"""from a list of even lenght, make tupple pairs"""
|
37 |
+
return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
|
38 |
+
|
39 |
+
|
40 |
+
def gen_conv(query: str, history=[system_template], ipcc=True):
|
41 |
+
"""return (answer:str, history:list[dict], sources:str)"""
|
42 |
+
retrieve = ipcc and is_climate_change_related(query)
|
43 |
+
sources = ""
|
44 |
+
messages = history + [
|
45 |
+
{"role": "user", "content": query},
|
46 |
+
]
|
47 |
+
|
48 |
+
if retrieve:
|
49 |
+
docs = dense.retrieve(query=query, top_k=5)
|
50 |
+
sources = "\n\n".join(
|
51 |
+
["If relevant, use those extracts from IPCC reports in your answer"]
|
52 |
+
+ [
|
53 |
+
f"{d.meta['path']} Page {d.meta['page_id']} paragraph {d.meta['paragraph_id']}:\n{d.content}"
|
54 |
+
for d in docs
|
55 |
+
]
|
56 |
+
)
|
57 |
+
messages.append({"role": "system", "content": sources})
|
58 |
+
|
59 |
+
answer = openai.ChatCompletion.create(
|
60 |
+
model="gpt-3.5-turbo",
|
61 |
+
messages=messages,
|
62 |
+
temperature=0.2,
|
63 |
+
# max_tokens=200,
|
64 |
+
)["choices"][0]["message"]["content"]
|
65 |
+
|
66 |
+
if retrieve:
|
67 |
+
messages.pop()
|
68 |
+
answer = "(top 5 documents retrieved) " + answer
|
69 |
+
sources = "\n\n".join(
|
70 |
+
f"{d.meta['path']} Page {d.meta['page_id']} paragraph {d.meta['paragraph_id']}:\n{d.content[:100]} [...]"
|
71 |
+
for d in docs
|
72 |
+
)
|
73 |
+
|
74 |
+
messages.append({"role": "assistant", "content": answer})
|
75 |
+
|
76 |
+
gradio_format = make_pairs([a["content"] for a in messages[1:]])
|
77 |
+
|
78 |
+
return gradio_format, messages, sources
|
79 |
+
|
80 |
+
|
81 |
+
def connect(text):
|
82 |
+
openai.api_key = text
|
83 |
+
return "You're all set"
|
84 |
+
|
85 |
+
|
86 |
+
with gr.Blocks(title="Eki IPCC Explorer") as demo:
|
87 |
+
with gr.Row():
|
88 |
+
with gr.Column():
|
89 |
+
api_key = gr.Textbox(label="Open AI api key")
|
90 |
+
connect_btn = gr.Button(value="Connect")
|
91 |
+
with gr.Column():
|
92 |
+
result = gr.Textbox(label="Connection")
|
93 |
+
|
94 |
+
connect_btn.click(connect, inputs=api_key, outputs=result, api_name="Connection")
|
95 |
+
|
96 |
+
gr.Markdown(
|
97 |
+
"""
|
98 |
+
# Ask me anything, I'm an IPCC report
|
99 |
+
"""
|
100 |
+
)
|
101 |
+
|
102 |
+
with gr.Row():
|
103 |
+
with gr.Column(scale=2):
|
104 |
+
chatbot = gr.Chatbot()
|
105 |
+
state = gr.State([system_template])
|
106 |
+
|
107 |
+
with gr.Row():
|
108 |
+
ask = gr.Textbox(
|
109 |
+
show_label=False, placeholder="Enter text and press enter"
|
110 |
+
).style(container=False)
|
111 |
+
|
112 |
+
with gr.Column(scale=1, variant="panel"):
|
113 |
+
|
114 |
+
gr.Markdown("### Sources")
|
115 |
+
sources_textbox = gr.Textbox(
|
116 |
+
interactive=False, show_label=False, max_lines=50
|
117 |
+
)
|
118 |
+
|
119 |
+
ask.submit(
|
120 |
+
fn=gen_conv, inputs=[ask, state], outputs=[chatbot, state, sources_textbox]
|
121 |
+
)
|
122 |
+
|
123 |
+
demo.launch(share=True)
|
requirements.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
faiss-cpu==1.7.2
|
3 |
+
farm-haystack==1.14.0
|
4 |
+
gradio==3.20.1
|
5 |
+
huggingface-hub==0.12.1
|
6 |
+
mlflow==2.2.1
|
7 |
+
mmh3==3.0.0
|
8 |
+
openai==0.27.0
|
9 |
+
orjson==3.8.7
|
10 |
+
pandas==1.5.3
|
11 |
+
simplejson==3.18.3
|
12 |
+
six==1.16.0
|
13 |
+
slicer==0.0.7
|
14 |
+
smmap==5.0.0
|
15 |
+
SQLAlchemy==1.4.46
|
16 |
+
SQLAlchemy-Utils==0.40.0
|
17 |
+
sqlparse==0.4.3
|
18 |
+
tokenizers==0.13.2
|
19 |
+
torch==1.13.1
|
20 |
+
torchvision==0.14.1
|
21 |
+
transformers==4.25.1
|
22 |
+
trove-classifiers==2023.2.20
|
23 |
+
|