Upload app.py
Browse files
app.py
CHANGED
@@ -22,6 +22,48 @@ LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store"
|
|
22 |
|
23 |
def deep_strip(text):
    """Collapse every run of whitespace in *text* to a single space and trim.

    A falsy input (``None`` or ``""``) is treated as the empty string, so
    callers can pass optionally-missing fields without guarding.
    """
    # \s+ matches tabs/newlines too, normalizing multi-line snippets to one line.
    return re.sub(r"\s+", " ", text or "").strip()
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
def embeddings_on_local_vectordb(texts):
|
27 |
colbert = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv1.9")
|
@@ -48,12 +90,18 @@ def query_llm(retriever, query):
|
|
48 |
relevant_docs = retriever.get_relevant_documents(query)
|
49 |
with get_openai_callback() as cb:
|
50 |
result = qa_chain(
|
51 |
-
{
|
|
|
|
|
|
|
|
|
|
|
52 |
)
|
53 |
stats = cb
|
54 |
result = result["answer"]
|
55 |
-
|
56 |
-
|
|
|
57 |
|
58 |
|
59 |
def input_fields():
|
@@ -202,7 +250,7 @@ def process_web(url):
|
|
202 |
|
203 |
|
204 |
def boot():
|
205 |
-
st.title("Xi Chatbot")
|
206 |
st.sidebar.title("Input Documents")
|
207 |
input_fields()
|
208 |
st.sidebar.button("Submit Documents", on_click=process_documents)
|
@@ -216,19 +264,17 @@ def boot():
|
|
216 |
|
217 |
for message in st.session_state.messages:
|
218 |
st.chat_message("human").write(message[0])
|
219 |
-
st.chat_message("ai").write(message[1])
|
220 |
if query := st.chat_input():
|
221 |
st.chat_message("human").write(query)
|
222 |
-
|
223 |
-
|
224 |
-
references_str = " ".join([f"[{ref}]" for ref in sorted_references])
|
225 |
-
st.chat_message("ai").write(response + "\n\n---\nReferences:" + references_str)
|
226 |
|
227 |
st.session_state.costing.append(
|
228 |
{
|
229 |
"prompt tokens": stats.prompt_tokens,
|
230 |
"completion tokens": stats.completion_tokens,
|
231 |
-
"
|
232 |
}
|
233 |
)
|
234 |
stats_df = pd.DataFrame(st.session_state.costing)
|
@@ -236,15 +282,10 @@ def boot():
|
|
236 |
st.sidebar.write(stats_df)
|
237 |
st.sidebar.download_button(
|
238 |
"Download Conversation",
|
239 |
-
|
240 |
-
[
|
241 |
-
{"human": message[0], "ai": message[1]}
|
242 |
-
for message in st.session_state.messages
|
243 |
-
]
|
244 |
-
),
|
245 |
"conversation.json",
|
246 |
)
|
247 |
|
248 |
|
249 |
if __name__ == "__main__":
|
250 |
-
boot()
|
|
|
22 |
|
23 |
deep_strip = lambda text: re.sub(r"\s+", " ", text or "").strip()
|
24 |
|
25 |
+
def get_references(relevant_docs):
    """Format the sorted chunk ids of *relevant_docs* as ``[id] [id] ...``.

    Each retrieved document is expected to carry its chunk id under
    ``doc.metadata["chunk_id"]``.  Ids are sorted so the reference list is
    stable regardless of retrieval order.
    """
    # Named the loop variables distinctly — the original lambda reused `ref`
    # for both the doc objects and the chunk ids in nested comprehensions.
    chunk_ids = sorted(doc.metadata["chunk_id"] for doc in relevant_docs)
    return " ".join(f"[{chunk_id}]" for chunk_id in chunk_ids)
|
28 |
+
def session_state_2_llm_chat_history(session_state):
    """Project stored ``(human, ai, references)`` messages to ``(human, ai)`` pairs.

    The chat messages kept in session state carry a third "references"
    element that the LLM chain must not see; this strips it, returning the
    (question, answer) history in original order.
    """
    return [message[:2] for message in session_state]
|
31 |
+
|
32 |
+
|
33 |
+
def get_conversation_history():
    """Serialize the whole session (sources, snippets, chat, costs) to JSON.

    Backs the sidebar "Download Conversation" button.  Every session-state
    key is read defensively with an ``in`` check because the session may not
    be fully initialized when this runs — the original guarded every key
    except ``messages``, which is now guarded too for consistency.

    Returns:
        str: a JSON document with ``document_urls``, ``document_snippets``,
        ``conversation``, ``costing`` and ``total_cost`` keys.
    """
    state = st.session_state
    # Hoisted: the original re-tested "costing" in session state three times.
    costing = state.costing if "costing" in state else []
    return json.dumps(
        {
            "document_urls": (
                state.source_doc_urls if "source_doc_urls" in state else []
            ),
            "document_snippets": (
                # NOTE(review): presumably a pandas Series of chunk headers —
                # to_list() makes it JSON-serializable; confirm against caller.
                state.headers.to_list() if "headers" in state else []
            ),
            "conversation": [
                {"human": message[0], "ai": message[1], "references": message[2]}
                for message in (state.messages if "messages" in state else [])
            ],
            "costing": costing,
            "total_cost": (
                # Sum each cost column across all recorded calls; the first
                # record's keys define which columns exist.
                {k: sum(d[k] for d in costing) for k in costing[0]}
                if costing
                else {}
            ),
        }
    )
|
63 |
+
|
64 |
+
|
65 |
+
def ai_message_format(message, references):
    """Render an AI chat message with its references below a horizontal rule.

    The ``---`` separator renders as a divider in Streamlit's markdown.
    """
    return f"{message}\n\n---\n\n{references}"
|
66 |
+
|
67 |
|
68 |
def embeddings_on_local_vectordb(texts):
|
69 |
colbert = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv1.9")
|
|
|
90 |
relevant_docs = retriever.get_relevant_documents(query)
|
91 |
with get_openai_callback() as cb:
|
92 |
result = qa_chain(
|
93 |
+
{
|
94 |
+
"question": query,
|
95 |
+
"chat_history": session_state_2_llm_chat_history(
|
96 |
+
st.session_state.messages
|
97 |
+
),
|
98 |
+
}
|
99 |
)
|
100 |
stats = cb
|
101 |
result = result["answer"]
|
102 |
+
references = get_references(relevant_docs)
|
103 |
+
st.session_state.messages.append((query, result, references))
|
104 |
+
return result, references, stats
|
105 |
|
106 |
|
107 |
def input_fields():
|
|
|
250 |
|
251 |
|
252 |
def boot():
|
253 |
+
st.title("Agent Xi - An ArXiv Chatbot")
|
254 |
st.sidebar.title("Input Documents")
|
255 |
input_fields()
|
256 |
st.sidebar.button("Submit Documents", on_click=process_documents)
|
|
|
264 |
|
265 |
for message in st.session_state.messages:
|
266 |
st.chat_message("human").write(message[0])
|
267 |
+
st.chat_message("ai").write(ai_message_format(message[1], message[2]))
|
268 |
if query := st.chat_input():
|
269 |
st.chat_message("human").write(query)
|
270 |
+
response, references, stats = query_llm(st.session_state.retriever, query)
|
271 |
+
st.chat_message("ai").write(ai_message_format(response, references))
|
|
|
|
|
272 |
|
273 |
st.session_state.costing.append(
|
274 |
{
|
275 |
"prompt tokens": stats.prompt_tokens,
|
276 |
"completion tokens": stats.completion_tokens,
|
277 |
+
"cost": stats.total_cost,
|
278 |
}
|
279 |
)
|
280 |
stats_df = pd.DataFrame(st.session_state.costing)
|
|
|
282 |
st.sidebar.write(stats_df)
|
283 |
st.sidebar.download_button(
|
284 |
"Download Conversation",
|
285 |
+
get_conversation_history(),
|
|
|
|
|
|
|
|
|
|
|
286 |
"conversation.json",
|
287 |
)
|
288 |
|
289 |
|
290 |
if __name__ == "__main__":
|
291 |
+
boot()
|