XThomasBU committed on
Commit
c658776
·
1 Parent(s): 65ce8c0

working commit

Browse files
code/main.py CHANGED
@@ -60,7 +60,47 @@ class Chatbot:
60
  )
61
 
62
  chain = cl.user_session.get("chain")
63
- memory = chain.memory if chain else []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  old_config = copy.deepcopy(self.config)
66
  self.config["vectorstore"]["db_option"] = retriever_method
@@ -71,7 +111,7 @@ class Chatbot:
71
  self.llm_tutor.update_llm(
72
  old_config, self.config
73
  ) # update only attributes that are changed
74
- self.chain = self.llm_tutor.qa_bot(memory=memory)
75
 
76
  tags = [chat_profile, self.config["vectorstore"]["db_option"]]
77
 
@@ -222,7 +262,7 @@ class Chatbot:
222
  rename_dict = {"Chatbot": "AI Tutor"}
223
  return rename_dict.get(orig_author, orig_author)
224
 
225
- async def start(self, thread=None, memory=[]):
226
  """
227
  Start the chatbot, initialize settings widgets,
228
  and display and load previous conversation if chat logging is enabled.
@@ -236,6 +276,8 @@ class Chatbot:
236
  }
237
  print(self.user)
238
 
 
 
239
  cl.user_session.set("user", self.user)
240
  self.llm_tutor = LLMTutor(self.config, user=self.user)
241
  self.chain = self.llm_tutor.qa_bot(memory=memory)
@@ -273,6 +315,18 @@ class Chatbot:
273
  """
274
 
275
  chain = cl.user_session.get("chain")
 
 
 
 
 
 
 
 
 
 
 
 
276
  llm_settings = cl.user_session.get("llm_settings", {})
277
  view_sources = llm_settings.get("view_sources", False)
278
  stream = (llm_settings.get("stream_response", True)) or (
@@ -318,28 +372,47 @@ class Chatbot:
318
  res, answer, stream=stream, view_sources=view_sources
319
  )
320
 
321
- await cl.Message(content=answer_with_sources, elements=source_elements).send()
 
 
322
 
323
  async def on_chat_resume(self, thread: ThreadDict):
324
  steps = thread["steps"]
325
- conversation_pairs = []
 
326
 
327
  user_message = None
328
  k = self.config["llm_params"]["memory_window"]
329
  count = 0
330
 
331
- for step in steps:
332
- if step["type"] == "user_message":
333
- user_message = step["output"]
334
- elif step["type"] == "assistant_message" and user_message is not None:
335
- assistant_message = step["output"]
336
- conversation_pairs.append((user_message, assistant_message))
337
- user_message = None
338
- count += 1
339
- if count >= k:
340
- break
341
-
342
- await self.start(thread, memory=conversation_pairs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
  @cl.oauth_callback
345
  def auth_callback(
 
60
  )
61
 
62
  chain = cl.user_session.get("chain")
63
+ print(list(chain.store.values()))
64
+ memory_list = cl.user_session.get(
65
+ "memory",
66
+ (
67
+ list(chain.store.values())[0].messages
68
+ if len(chain.store.values()) > 0
69
+ else []
70
+ ),
71
+ )
72
+ conversation_list = []
73
+ for message in memory_list:
74
+ # Convert to dictionary if possible
75
+ message_dict = message.to_dict() if hasattr(message, "to_dict") else message
76
+
77
+ # Check if the type attribute is present as a key or attribute
78
+ message_type = (
79
+ message_dict.get("type", None)
80
+ if isinstance(message_dict, dict)
81
+ else getattr(message, "type", None)
82
+ )
83
+
84
+ # Check if content is present as a key or attribute
85
+ message_content = (
86
+ message_dict.get("content", None)
87
+ if isinstance(message_dict, dict)
88
+ else getattr(message, "content", None)
89
+ )
90
+
91
+ if message_type in ["ai", "ai_message"]:
92
+ conversation_list.append(
93
+ {"type": "ai_message", "content": message_content}
94
+ )
95
+ elif message_type in ["human", "user_message"]:
96
+ conversation_list.append(
97
+ {"type": "user_message", "content": message_content}
98
+ )
99
+ else:
100
+ raise ValueError("Invalid message type")
101
+ print("\n\n\n")
102
+ print("history at setup_llm", conversation_list)
103
+ print("\n\n\n")
104
 
105
  old_config = copy.deepcopy(self.config)
106
  self.config["vectorstore"]["db_option"] = retriever_method
 
111
  self.llm_tutor.update_llm(
112
  old_config, self.config
113
  ) # update only attributes that are changed
114
+ self.chain = self.llm_tutor.qa_bot(memory=conversation_list)
115
 
116
  tags = [chat_profile, self.config["vectorstore"]["db_option"]]
117
 
 
262
  rename_dict = {"Chatbot": "AI Tutor"}
263
  return rename_dict.get(orig_author, orig_author)
264
 
265
+ async def start(self):
266
  """
267
  Start the chatbot, initialize settings widgets,
268
  and display and load previous conversation if chat logging is enabled.
 
276
  }
277
  print(self.user)
278
 
279
+ memory = cl.user_session.get("memory", [])
280
+
281
  cl.user_session.set("user", self.user)
282
  self.llm_tutor = LLMTutor(self.config, user=self.user)
283
  self.chain = self.llm_tutor.qa_bot(memory=memory)
 
315
  """
316
 
317
  chain = cl.user_session.get("chain")
318
+
319
+ print("\n\n\n")
320
+ print(
321
+ "session history",
322
+ chain.get_session_history(
323
+ self.user["user_id"],
324
+ self.user["session_id"],
325
+ self.config["llm_params"]["memory_window"],
326
+ ),
327
+ )
328
+ print("\n\n\n")
329
+
330
  llm_settings = cl.user_session.get("llm_settings", {})
331
  view_sources = llm_settings.get("view_sources", False)
332
  stream = (llm_settings.get("stream_response", True)) or (
 
372
  res, answer, stream=stream, view_sources=view_sources
373
  )
374
 
375
+ await cl.Message(
376
+ content=answer_with_sources, elements=source_elements, author=LLM
377
+ ).send()
378
 
379
  async def on_chat_resume(self, thread: ThreadDict):
380
  steps = thread["steps"]
381
+ # conversation_pairs = []
382
+ conversation_list = []
383
 
384
  user_message = None
385
  k = self.config["llm_params"]["memory_window"]
386
  count = 0
387
 
388
+ print(steps)
389
+
390
+ for step in reversed(steps):
391
+ print(step["type"])
392
+ if step["name"] not in [SYSTEM]:
393
+ if step["type"] == "user_message":
394
+ conversation_list.append(
395
+ {"type": "user_message", "content": step["output"]}
396
+ )
397
+ elif step["type"] == "assistant_message":
398
+ if step["name"] == LLM:
399
+ conversation_list.append(
400
+ {"type": "ai_message", "content": step["output"]}
401
+ )
402
+ else:
403
+ raise ValueError("Invalid message type")
404
+ count += 1
405
+ if count >= 2 * k: # 2 * k to account for both user and assistant messages
406
+ break
407
+
408
+ conversation_list = conversation_list[::-1]
409
+
410
+ print("\n\n\n")
411
+ print("history at on_chat_resume", conversation_list)
412
+ print(len(conversation_list))
413
+ print("\n\n\n")
414
+ cl.user_session.set("memory", conversation_list)
415
+ await self.start()
416
 
417
  @cl.oauth_callback
418
  def auth_callback(
code/modules/chat/helpers.py CHANGED
@@ -6,6 +6,11 @@ def get_sources(res, answer, stream=True, view_sources=False):
6
  source_elements = []
7
  source_dict = {} # Dictionary to store URL elements
8
 
 
 
 
 
 
9
  for idx, source in enumerate(res["context"]):
10
  source_metadata = source.metadata
11
  url = source_metadata.get("source", "N/A")
@@ -20,6 +25,9 @@ def get_sources(res, answer, stream=True, view_sources=False):
20
  source_type = source_metadata.get("source_type", "N/A")
21
 
22
  url_name = f"{url}_{page}"
 
 
 
23
  if url_name not in source_dict:
24
  source_dict[url_name] = {
25
  "text": source.page_content,
 
6
  source_elements = []
7
  source_dict = {} # Dictionary to store URL elements
8
 
9
+ print("\n\n\n")
10
+ print(res["context"])
11
+ print(len(res["context"]))
12
+ print("\n\n\n")
13
+
14
  for idx, source in enumerate(res["context"]):
15
  source_metadata = source.metadata
16
  url = source_metadata.get("source", "N/A")
 
25
  source_type = source_metadata.get("source_type", "N/A")
26
 
27
  url_name = f"{url}_{page}"
28
+ print("url")
29
+ print(url_name)
30
+ print("\n\n\n")
31
  if url_name not in source_dict:
32
  source_dict[url_name] = {
33
  "text": source.page_content,
code/modules/chat/langchain/langchain_rag.py CHANGED
@@ -211,7 +211,7 @@ class Langchain_RAG_V2(BaseRAG):
211
  res = self.rag_chain.stream(user_query, config)
212
  return res
213
 
214
- def add_history_from_list(self, history_list):
215
  """
216
  Add messages from a list to the chat history.
217
 
@@ -220,8 +220,22 @@ class Langchain_RAG_V2(BaseRAG):
220
  """
221
  history = ChatMessageHistory()
222
 
223
- for idx, message_pairs in enumerate(history_list):
224
- history.add_user_message(message_pairs[0])
225
- history.add_ai_message(message_pairs[1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  return history
 
211
  res = self.rag_chain.stream(user_query, config)
212
  return res
213
 
214
+ def add_history_from_list(self, conversation_list):
215
  """
216
  Add messages from a list to the chat history.
217
 
 
220
  """
221
  history = ChatMessageHistory()
222
 
223
+ for idx, message in enumerate(conversation_list):
224
+ message_type = (
225
+ message.get("type", None)
226
+ if isinstance(message, dict)
227
+ else getattr(message, "type", None)
228
+ )
229
+
230
+ message_content = (
231
+ message.get("content", None)
232
+ if isinstance(message, dict)
233
+ else getattr(message, "content", None)
234
+ )
235
+
236
+ if message_type in ["human", "user_message"]:
237
+ history.add_user_message(message_content)
238
+ elif message_type in ["ai", "ai_message"]:
239
+ history.add_ai_message(message_content)
240
 
241
  return history
code/modules/chat/langchain/utils.py CHANGED
@@ -203,6 +203,7 @@ class CustomRunnableWithHistory(RunnableWithMessageHistory):
203
  print("Hist: ", hist)
204
  print("\n\n\n")
205
  messages = (await hist.aget_messages()).copy()
 
206
 
207
  if not self.history_messages_key:
208
  # return all messages
 
203
  print("Hist: ", hist)
204
  print("\n\n\n")
205
  messages = (await hist.aget_messages()).copy()
206
+ print("messages: ", messages)
207
 
208
  if not self.history_messages_key:
209
  # return all messages
code/modules/config/config.yml CHANGED
@@ -3,7 +3,7 @@ log_chunk_dir: '../storage/logs/chunks' # str
3
  device: 'cpu' # str [cuda, cpu]
4
 
5
  vectorstore:
6
- load_from_HF: True # bool
7
  embedd_files: False # bool
8
  data_path: '../storage/data' # str
9
  url_file_path: '../storage/data/urls.txt' # str
 
3
  device: 'cpu' # str [cuda, cpu]
4
 
5
  vectorstore:
6
+ load_from_HF: False # bool
7
  embedd_files: False # bool
8
  data_path: '../storage/data' # str
9
  url_file_path: '../storage/data/urls.txt' # str
code/modules/dataloader/data_loader.py CHANGED
@@ -228,11 +228,11 @@ class ChunkProcessor:
228
 
229
  page_num = doc.metadata.get("page", 0)
230
  file_data[page_num] = doc.page_content
231
- metadata = (
232
- addl_metadata.get(file_path, {})
233
- if metadata_source == "file"
234
- else {"source": file_path, "page": page_num}
235
- )
236
  file_metadata[page_num] = metadata
237
 
238
  if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
 
228
 
229
  page_num = doc.metadata.get("page", 0)
230
  file_data[page_num] = doc.page_content
231
+
232
+ # Create a new dictionary for metadata in each iteration
233
+ metadata = addl_metadata.get(file_path, {}).copy()
234
+ metadata["page"] = page_num
235
+ metadata["source"] = file_path
236
  file_metadata[page_num] = metadata
237
 
238
  if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]: