Ritvik19 committed on
Commit
20c0b83
·
verified ·
1 Parent(s): b05ff4a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -17
app.py CHANGED
@@ -22,6 +22,48 @@ LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store"
22
 
23
  deep_strip = lambda text: re.sub(r"\s+", " ", text or "").strip()
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def embeddings_on_local_vectordb(texts):
27
  colbert = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv1.9")
@@ -48,12 +90,18 @@ def query_llm(retriever, query):
48
  relevant_docs = retriever.get_relevant_documents(query)
49
  with get_openai_callback() as cb:
50
  result = qa_chain(
51
- {"question": query, "chat_history": st.session_state.messages}
 
 
 
 
 
52
  )
53
  stats = cb
54
  result = result["answer"]
55
- st.session_state.messages.append((query, result))
56
- return relevant_docs, result, stats
 
57
 
58
 
59
  def input_fields():
@@ -202,7 +250,7 @@ def process_web(url):
202
 
203
 
204
  def boot():
205
- st.title("Xi Chatbot")
206
  st.sidebar.title("Input Documents")
207
  input_fields()
208
  st.sidebar.button("Submit Documents", on_click=process_documents)
@@ -216,19 +264,17 @@ def boot():
216
 
217
  for message in st.session_state.messages:
218
  st.chat_message("human").write(message[0])
219
- st.chat_message("ai").write(message[1])
220
  if query := st.chat_input():
221
  st.chat_message("human").write(query)
222
- references, response, stats = query_llm(st.session_state.retriever, query)
223
- sorted_references = sorted([ref.metadata["chunk_id"] for ref in references])
224
- references_str = " ".join([f"[{ref}]" for ref in sorted_references])
225
- st.chat_message("ai").write(response + "\n\n---\nReferences:" + references_str)
226
 
227
  st.session_state.costing.append(
228
  {
229
  "prompt tokens": stats.prompt_tokens,
230
  "completion tokens": stats.completion_tokens,
231
- "total cost": stats.total_cost,
232
  }
233
  )
234
  stats_df = pd.DataFrame(st.session_state.costing)
@@ -236,15 +282,10 @@ def boot():
236
  st.sidebar.write(stats_df)
237
  st.sidebar.download_button(
238
  "Download Conversation",
239
- json.dumps(
240
- [
241
- {"human": message[0], "ai": message[1]}
242
- for message in st.session_state.messages
243
- ]
244
- ),
245
  "conversation.json",
246
  )
247
 
248
 
249
  if __name__ == "__main__":
250
- boot()
 
22
 
23
  deep_strip = lambda text: re.sub(r"\s+", " ", text or "").strip()
24
 
25
+ get_references = lambda relevant_docs: " ".join(
26
+ [f"[{ref}]" for ref in sorted([ref.metadata["chunk_id"] for ref in relevant_docs])]
27
+ )
28
+ session_state_2_llm_chat_history = lambda session_state: [
29
+ ss[:2] for ss in session_state
30
+ ]
31
+
32
+
33
+ def get_conversation_history():
34
+ return json.dumps(
35
+ {
36
+ "document_urls": (
37
+ st.session_state.source_doc_urls
38
+ if "source_doc_urls" in st.session_state
39
+ else []
40
+ ),
41
+ "document_snippets": (
42
+ st.session_state.headers.to_list()
43
+ if "headers" in st.session_state
44
+ else []
45
+ ),
46
+ "conversation": [
47
+ {"human": message[0], "ai": message[1], "references": message[2]}
48
+ for message in st.session_state.messages
49
+ ],
50
+ "costing": (
51
+ st.session_state.costing if "costing" in st.session_state else []
52
+ ),
53
+ "total_cost": (
54
+ {
55
+ k: sum(d[k] for d in st.session_state.costing)
56
+ for k in st.session_state.costing[0]
57
+ }
58
+ if "costing" in st.session_state and len(st.session_state.costing) > 0
59
+ else {}
60
+ ),
61
+ }
62
+ )
63
+
64
+
65
+ ai_message_format = lambda message, references: f"{message}\n\n---\n\n{references}"
66
+
67
 
68
  def embeddings_on_local_vectordb(texts):
69
  colbert = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv1.9")
 
90
  relevant_docs = retriever.get_relevant_documents(query)
91
  with get_openai_callback() as cb:
92
  result = qa_chain(
93
+ {
94
+ "question": query,
95
+ "chat_history": session_state_2_llm_chat_history(
96
+ st.session_state.messages
97
+ ),
98
+ }
99
  )
100
  stats = cb
101
  result = result["answer"]
102
+ references = get_references(relevant_docs)
103
+ st.session_state.messages.append((query, result, references))
104
+ return result, references, stats
105
 
106
 
107
  def input_fields():
 
250
 
251
 
252
  def boot():
253
+ st.title("Agent Xi - An ArXiv Chatbot")
254
  st.sidebar.title("Input Documents")
255
  input_fields()
256
  st.sidebar.button("Submit Documents", on_click=process_documents)
 
264
 
265
  for message in st.session_state.messages:
266
  st.chat_message("human").write(message[0])
267
+ st.chat_message("ai").write(ai_message_format(message[1], message[2]))
268
  if query := st.chat_input():
269
  st.chat_message("human").write(query)
270
+ response, references, stats = query_llm(st.session_state.retriever, query)
271
+ st.chat_message("ai").write(ai_message_format(response, references))
 
 
272
 
273
  st.session_state.costing.append(
274
  {
275
  "prompt tokens": stats.prompt_tokens,
276
  "completion tokens": stats.completion_tokens,
277
+ "cost": stats.total_cost,
278
  }
279
  )
280
  stats_df = pd.DataFrame(st.session_state.costing)
 
282
  st.sidebar.write(stats_df)
283
  st.sidebar.download_button(
284
  "Download Conversation",
285
+ get_conversation_history(),
 
 
 
 
 
286
  "conversation.json",
287
  )
288
 
289
 
290
  if __name__ == "__main__":
291
+ boot()