Spaces:
Build error
Build error
XThomasBU
commited on
Commit
·
9b7a7cf
1
Parent(s):
05f78f2
updates
Browse files- code/main.py +123 -65
- code/modules/chat/chat_model_loader.py +5 -1
- code/modules/chat/langgraph/langgraph_rag.py +0 -303
- code/modules/chat/llm_tutor.py +0 -9
- code/modules/chat_processor/base.py +0 -18
- code/modules/chat_processor/chat_processor.py +0 -55
- code/modules/chat_processor/literal_ai.py +5 -108
- code/modules/config/config.yml +3 -4
- code/modules/config/constants.py +3 -1
- code/modules/vectorstore/store_manager.py +18 -5
- code/modules/vectorstore/vectorstore.py +2 -2
- code/public/test.css +10 -0
code/main.py
CHANGED
@@ -1,14 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import json
|
2 |
import yaml
|
3 |
import os
|
4 |
from typing import Any, Dict, no_type_check
|
5 |
import chainlit as cl
|
6 |
from modules.chat.llm_tutor import LLMTutor
|
7 |
-
from modules.chat_processor.chat_processor import ChatProcessor
|
8 |
-
from modules.config.constants import LLAMA_PATH
|
9 |
from modules.chat.helpers import get_sources
|
10 |
import copy
|
11 |
from typing import Optional
|
|
|
12 |
|
13 |
USER_TIMEOUT = 60_000
|
14 |
SYSTEM = "System 🖥️"
|
@@ -18,12 +26,18 @@ YOU = "You 😃"
|
|
18 |
ERROR = "Error 🚫"
|
19 |
|
20 |
|
|
|
|
|
|
|
|
|
|
|
21 |
class Chatbot:
|
22 |
def __init__(self):
|
23 |
"""
|
24 |
Initialize the Chatbot class.
|
25 |
"""
|
26 |
self.config = self._load_config()
|
|
|
27 |
|
28 |
def _load_config(self):
|
29 |
"""
|
@@ -60,11 +74,9 @@ class Chatbot:
|
|
60 |
self.chain = self.llm_tutor.qa_bot(memory=memory)
|
61 |
|
62 |
tags = [chat_profile, self.config["vectorstore"]["db_option"]]
|
63 |
-
self.chat_processor.config = self.config
|
64 |
|
65 |
cl.user_session.set("chain", self.chain)
|
66 |
cl.user_session.set("llm_tutor", self.llm_tutor)
|
67 |
-
cl.user_session.set("chat_processor", self.chat_processor)
|
68 |
|
69 |
@no_type_check
|
70 |
async def update_llm(self, new_settings: Dict[str, Any]):
|
@@ -91,14 +103,21 @@ class Chatbot:
|
|
91 |
cl.input_widget.Select(
|
92 |
id="chat_model",
|
93 |
label="Model Name (Default GPT-3)",
|
94 |
-
values=["local_llm", "gpt-3.5-turbo-1106", "gpt-4"],
|
95 |
-
initial_index=[
|
|
|
|
|
|
|
|
|
|
|
96 |
),
|
97 |
cl.input_widget.Select(
|
98 |
id="retriever_method",
|
99 |
label="Retriever (Default FAISS)",
|
100 |
values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
|
101 |
-
initial_index=["FAISS", "Chroma", "RAGatouille", "RAPTOR"].index(
|
|
|
|
|
102 |
),
|
103 |
cl.input_widget.Slider(
|
104 |
id="memory_window",
|
@@ -112,7 +131,7 @@ class Chatbot:
|
|
112 |
id="view_sources", label="View Sources", initial=False
|
113 |
),
|
114 |
cl.input_widget.Switch(
|
115 |
-
id="stream_response", label="Stream response", initial=
|
116 |
),
|
117 |
cl.input_widget.Select(
|
118 |
id="llm_style",
|
@@ -158,28 +177,37 @@ class Chatbot:
|
|
158 |
"""
|
159 |
Set starter messages for the chatbot.
|
160 |
"""
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
)
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
def rename(self, orig_author: str):
|
185 |
"""
|
@@ -194,44 +222,25 @@ class Chatbot:
|
|
194 |
rename_dict = {"Chatbot": "AI Tutor"}
|
195 |
return rename_dict.get(orig_author, orig_author)
|
196 |
|
197 |
-
async def start(self):
|
198 |
"""
|
199 |
Start the chatbot, initialize settings widgets,
|
200 |
and display and load previous conversation if chat logging is enabled.
|
201 |
"""
|
202 |
-
await cl.Message(content="Welcome back! Setting up your session...").send()
|
203 |
|
204 |
await self.make_llm_settings_widgets(self.config)
|
205 |
user = cl.user_session.get("user")
|
206 |
self.user = {
|
207 |
"user_id": user.identifier,
|
208 |
-
"session_id":
|
209 |
}
|
|
|
|
|
210 |
cl.user_session.set("user", self.user)
|
211 |
-
self.chat_processor = ChatProcessor(self.config, self.user)
|
212 |
self.llm_tutor = LLMTutor(self.config, user=self.user)
|
213 |
-
if self.config["chat_logging"]["log_chat"]:
|
214 |
-
# get previous conversation of the user
|
215 |
-
memory = self.chat_processor.processor.prev_conv
|
216 |
-
if len(self.chat_processor.processor.prev_conv) > 0:
|
217 |
-
for idx, conv in enumerate(self.chat_processor.processor.prev_conv):
|
218 |
-
await cl.Message(
|
219 |
-
author="User", content=conv[0], type="user_message"
|
220 |
-
).send()
|
221 |
-
await cl.Message(author="AI Tutor", content=conv[1]).send()
|
222 |
-
else:
|
223 |
-
memory = []
|
224 |
self.chain = self.llm_tutor.qa_bot(memory=memory)
|
225 |
cl.user_session.set("llm_tutor", self.llm_tutor)
|
226 |
cl.user_session.set("chain", self.chain)
|
227 |
-
cl.user_session.set("chat_processor", self.chat_processor)
|
228 |
-
|
229 |
-
async def on_chat_end(self):
|
230 |
-
"""
|
231 |
-
Handle the end of the chat session by sending a goodbye message.
|
232 |
-
# TODO: Not used as of now - useful when the implementation for the conversation limiting is implemented
|
233 |
-
"""
|
234 |
-
await cl.Message(content="Sorry, I have to go now. Goodbye!").send()
|
235 |
|
236 |
async def stream_response(self, response):
|
237 |
"""
|
@@ -245,8 +254,8 @@ class Chatbot:
|
|
245 |
|
246 |
output = {}
|
247 |
for chunk in response:
|
248 |
-
if
|
249 |
-
await msg.stream_token(chunk[
|
250 |
|
251 |
for key in chunk:
|
252 |
if key not in output:
|
@@ -262,39 +271,88 @@ class Chatbot:
|
|
262 |
Args:
|
263 |
message: The incoming chat message.
|
264 |
"""
|
|
|
265 |
chain = cl.user_session.get("chain")
|
266 |
llm_settings = cl.user_session.get("llm_settings", {})
|
267 |
view_sources = llm_settings.get("view_sources", False)
|
268 |
-
stream = (llm_settings.get("stream_response", True)) or (
|
269 |
-
|
270 |
-
|
271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
|
273 |
if stream:
|
|
|
274 |
res = await self.stream_response(res)
|
|
|
|
|
275 |
|
276 |
answer = res.get("answer", res.get("result"))
|
277 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
278 |
answer_with_sources, source_elements, sources_dict = get_sources(
|
279 |
res, answer, stream=stream, view_sources=view_sources
|
280 |
)
|
281 |
-
processor._process(message.content, answer, sources_dict)
|
282 |
|
283 |
await cl.Message(content=answer_with_sources, elements=source_elements).send()
|
284 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
@cl.oauth_callback
|
286 |
def auth_callback(
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
) -> Optional[cl.User]:
|
292 |
return default_user
|
293 |
|
|
|
294 |
chatbot = Chatbot()
|
295 |
cl.set_starters(chatbot.set_starters)
|
296 |
cl.author_rename(chatbot.rename)
|
297 |
cl.on_chat_start(chatbot.start)
|
298 |
-
cl.
|
299 |
cl.on_message(chatbot.main)
|
300 |
cl.on_settings_update(chatbot.update_llm)
|
|
|
1 |
+
import chainlit.data as cl_data
|
2 |
+
|
3 |
+
from modules.config.constants import (
|
4 |
+
LLAMA_PATH,
|
5 |
+
LITERAL_API_KEY_LOGGING,
|
6 |
+
LITERAL_API_URL,
|
7 |
+
)
|
8 |
+
from modules.chat_processor.literal_ai import CustomLiteralDataLayer
|
9 |
+
|
10 |
import json
|
11 |
import yaml
|
12 |
import os
|
13 |
from typing import Any, Dict, no_type_check
|
14 |
import chainlit as cl
|
15 |
from modules.chat.llm_tutor import LLMTutor
|
|
|
|
|
16 |
from modules.chat.helpers import get_sources
|
17 |
import copy
|
18 |
from typing import Optional
|
19 |
+
from chainlit.types import ThreadDict
|
20 |
|
21 |
USER_TIMEOUT = 60_000
|
22 |
SYSTEM = "System 🖥️"
|
|
|
26 |
ERROR = "Error 🚫"
|
27 |
|
28 |
|
29 |
+
cl_data._data_layer = CustomLiteralDataLayer(
|
30 |
+
api_key=LITERAL_API_KEY_LOGGING, server=LITERAL_API_URL
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
class Chatbot:
|
35 |
def __init__(self):
|
36 |
"""
|
37 |
Initialize the Chatbot class.
|
38 |
"""
|
39 |
self.config = self._load_config()
|
40 |
+
self.literal_client = cl_data._data_layer.client
|
41 |
|
42 |
def _load_config(self):
|
43 |
"""
|
|
|
74 |
self.chain = self.llm_tutor.qa_bot(memory=memory)
|
75 |
|
76 |
tags = [chat_profile, self.config["vectorstore"]["db_option"]]
|
|
|
77 |
|
78 |
cl.user_session.set("chain", self.chain)
|
79 |
cl.user_session.set("llm_tutor", self.llm_tutor)
|
|
|
80 |
|
81 |
@no_type_check
|
82 |
async def update_llm(self, new_settings: Dict[str, Any]):
|
|
|
103 |
cl.input_widget.Select(
|
104 |
id="chat_model",
|
105 |
label="Model Name (Default GPT-3)",
|
106 |
+
values=["local_llm", "gpt-3.5-turbo-1106", "gpt-4", "gpt-4o-mini"],
|
107 |
+
initial_index=[
|
108 |
+
"local_llm",
|
109 |
+
"gpt-3.5-turbo-1106",
|
110 |
+
"gpt-4",
|
111 |
+
"gpt-4o-mini",
|
112 |
+
].index(config["llm_params"]["llm_loader"]),
|
113 |
),
|
114 |
cl.input_widget.Select(
|
115 |
id="retriever_method",
|
116 |
label="Retriever (Default FAISS)",
|
117 |
values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
|
118 |
+
initial_index=["FAISS", "Chroma", "RAGatouille", "RAPTOR"].index(
|
119 |
+
config["vectorstore"]["db_option"]
|
120 |
+
),
|
121 |
),
|
122 |
cl.input_widget.Slider(
|
123 |
id="memory_window",
|
|
|
131 |
id="view_sources", label="View Sources", initial=False
|
132 |
),
|
133 |
cl.input_widget.Switch(
|
134 |
+
id="stream_response", label="Stream response", initial=False
|
135 |
),
|
136 |
cl.input_widget.Select(
|
137 |
id="llm_style",
|
|
|
177 |
"""
|
178 |
Set starter messages for the chatbot.
|
179 |
"""
|
180 |
+
# Return Starters only if the chat is new
|
181 |
+
|
182 |
+
try:
|
183 |
+
thread = cl_data._data_layer.get_thread(
|
184 |
+
cl.context.session.thread_id
|
185 |
+
) # see if the thread has any steps
|
186 |
+
if thread.steps or len(thread.steps) > 0:
|
187 |
+
return None
|
188 |
+
except:
|
189 |
+
return [
|
190 |
+
cl.Starter(
|
191 |
+
label="recording on CNNs?",
|
192 |
+
message="Where can I find the recording for the lecture on Transformers?",
|
193 |
+
icon="/public/adv-screen-recorder-svgrepo-com.svg",
|
194 |
+
),
|
195 |
+
cl.Starter(
|
196 |
+
label="where's the slides?",
|
197 |
+
message="When are the lectures? I can't find the schedule.",
|
198 |
+
icon="/public/alarmy-svgrepo-com.svg",
|
199 |
+
),
|
200 |
+
cl.Starter(
|
201 |
+
label="Due Date?",
|
202 |
+
message="When is the final project due?",
|
203 |
+
icon="/public/calendar-samsung-17-svgrepo-com.svg",
|
204 |
+
),
|
205 |
+
cl.Starter(
|
206 |
+
label="Explain backprop.",
|
207 |
+
message="I didn't understand the math behind backprop, could you explain it?",
|
208 |
+
icon="/public/acastusphoton-svgrepo-com.svg",
|
209 |
+
),
|
210 |
+
]
|
211 |
|
212 |
def rename(self, orig_author: str):
|
213 |
"""
|
|
|
222 |
rename_dict = {"Chatbot": "AI Tutor"}
|
223 |
return rename_dict.get(orig_author, orig_author)
|
224 |
|
225 |
+
async def start(self, thread=None, memory=[]):
|
226 |
"""
|
227 |
Start the chatbot, initialize settings widgets,
|
228 |
and display and load previous conversation if chat logging is enabled.
|
229 |
"""
|
|
|
230 |
|
231 |
await self.make_llm_settings_widgets(self.config)
|
232 |
user = cl.user_session.get("user")
|
233 |
self.user = {
|
234 |
"user_id": user.identifier,
|
235 |
+
"session_id": cl.context.session.thread_id,
|
236 |
}
|
237 |
+
print(self.user)
|
238 |
+
|
239 |
cl.user_session.set("user", self.user)
|
|
|
240 |
self.llm_tutor = LLMTutor(self.config, user=self.user)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
self.chain = self.llm_tutor.qa_bot(memory=memory)
|
242 |
cl.user_session.set("llm_tutor", self.llm_tutor)
|
243 |
cl.user_session.set("chain", self.chain)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
|
245 |
async def stream_response(self, response):
|
246 |
"""
|
|
|
254 |
|
255 |
output = {}
|
256 |
for chunk in response:
|
257 |
+
if "answer" in chunk:
|
258 |
+
await msg.stream_token(chunk["answer"])
|
259 |
|
260 |
for key in chunk:
|
261 |
if key not in output:
|
|
|
271 |
Args:
|
272 |
message: The incoming chat message.
|
273 |
"""
|
274 |
+
|
275 |
chain = cl.user_session.get("chain")
|
276 |
llm_settings = cl.user_session.get("llm_settings", {})
|
277 |
view_sources = llm_settings.get("view_sources", False)
|
278 |
+
stream = (llm_settings.get("stream_response", True)) or (
|
279 |
+
not self.config["llm_params"]["stream"]
|
280 |
+
)
|
281 |
+
user_query_dict = {"input": message.content}
|
282 |
+
# Define the base configuration
|
283 |
+
chain_config = {
|
284 |
+
"configurable": {
|
285 |
+
"user_id": self.user["user_id"],
|
286 |
+
"conversation_id": self.user["session_id"],
|
287 |
+
"memory_window": self.config["llm_params"]["memory_window"],
|
288 |
+
}
|
289 |
+
}
|
290 |
|
291 |
if stream:
|
292 |
+
res = chain.stream(user_query=user_query_dict, config=chain_config)
|
293 |
res = await self.stream_response(res)
|
294 |
+
else:
|
295 |
+
res = chain.invoke(user_query=user_query_dict, config=chain_config)
|
296 |
|
297 |
answer = res.get("answer", res.get("result"))
|
298 |
|
299 |
+
with cl_data._data_layer.client.step(
|
300 |
+
type="retrieval",
|
301 |
+
name="RAG",
|
302 |
+
thread_id=cl.context.session.thread_id,
|
303 |
+
# tags=self.tags,
|
304 |
+
) as step:
|
305 |
+
step.input = {"question": user_query_dict["input"]}
|
306 |
+
step.output = {
|
307 |
+
"chat_history": res.get("chat_history"),
|
308 |
+
"context": res.get("context"),
|
309 |
+
"answer": answer,
|
310 |
+
"rephrase_prompt": res.get("rephrase_prompt"),
|
311 |
+
"qa_prompt": res.get("qa_prompt"),
|
312 |
+
}
|
313 |
+
step.metadata = self.config
|
314 |
+
|
315 |
answer_with_sources, source_elements, sources_dict = get_sources(
|
316 |
res, answer, stream=stream, view_sources=view_sources
|
317 |
)
|
|
|
318 |
|
319 |
await cl.Message(content=answer_with_sources, elements=source_elements).send()
|
320 |
|
321 |
+
async def on_chat_resume(self, thread: ThreadDict):
|
322 |
+
steps = thread["steps"]
|
323 |
+
conversation_pairs = []
|
324 |
+
|
325 |
+
user_message = None
|
326 |
+
k = self.config["llm_params"]["memory_window"]
|
327 |
+
count = 0
|
328 |
+
|
329 |
+
for step in steps:
|
330 |
+
if step["type"] == "user_message":
|
331 |
+
user_message = step["output"]
|
332 |
+
elif step["type"] == "assistant_message" and user_message is not None:
|
333 |
+
assistant_message = step["output"]
|
334 |
+
conversation_pairs.append((user_message, assistant_message))
|
335 |
+
user_message = None
|
336 |
+
count += 1
|
337 |
+
if count >= k:
|
338 |
+
break
|
339 |
+
|
340 |
+
await self.start(thread, memory=conversation_pairs)
|
341 |
+
|
342 |
@cl.oauth_callback
|
343 |
def auth_callback(
|
344 |
+
provider_id: str,
|
345 |
+
token: str,
|
346 |
+
raw_user_data: Dict[str, str],
|
347 |
+
default_user: cl.User,
|
348 |
) -> Optional[cl.User]:
|
349 |
return default_user
|
350 |
|
351 |
+
|
352 |
chatbot = Chatbot()
|
353 |
cl.set_starters(chatbot.set_starters)
|
354 |
cl.author_rename(chatbot.rename)
|
355 |
cl.on_chat_start(chatbot.start)
|
356 |
+
cl.on_chat_resume(chatbot.on_chat_resume)
|
357 |
cl.on_message(chatbot.main)
|
358 |
cl.on_settings_update(chatbot.update_llm)
|
code/modules/chat/chat_model_loader.py
CHANGED
@@ -16,7 +16,11 @@ class ChatModelLoader:
|
|
16 |
self.huggingface_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
17 |
|
18 |
def load_chat_model(self):
|
19 |
-
if self.config["llm_params"]["llm_loader"] in [
|
|
|
|
|
|
|
|
|
20 |
llm = ChatOpenAI(model_name=self.config["llm_params"]["llm_loader"])
|
21 |
elif self.config["llm_params"]["llm_loader"] == "local_llm":
|
22 |
n_batch = 512 # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
|
|
|
16 |
self.huggingface_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
17 |
|
18 |
def load_chat_model(self):
|
19 |
+
if self.config["llm_params"]["llm_loader"] in [
|
20 |
+
"gpt-3.5-turbo-1106",
|
21 |
+
"gpt-4",
|
22 |
+
"gpt-4o-mini",
|
23 |
+
]:
|
24 |
llm = ChatOpenAI(model_name=self.config["llm_params"]["llm_loader"])
|
25 |
elif self.config["llm_params"]["llm_loader"] == "local_llm":
|
26 |
n_batch = 512 # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
|
code/modules/chat/langgraph/langgraph_rag.py
DELETED
@@ -1,303 +0,0 @@
|
|
1 |
-
# Adapted from https://github.com/langchain-ai/langgraph/blob/main/examples/rag/langgraph_crag.ipynb?ref=blog.langchain.dev
|
2 |
-
|
3 |
-
from typing import List
|
4 |
-
|
5 |
-
from typing_extensions import TypedDict
|
6 |
-
from langgraph.graph import END, StateGraph, START
|
7 |
-
from modules.chat.base import BaseRAG
|
8 |
-
from langchain.memory import ChatMessageHistory
|
9 |
-
from langchain_core.prompts import ChatPromptTemplate
|
10 |
-
from langchain_core.pydantic_v1 import BaseModel, Field
|
11 |
-
from langchain_openai import ChatOpenAI
|
12 |
-
from langchain_core.output_parsers import StrOutputParser
|
13 |
-
from langchain_core.prompts import ChatPromptTemplate
|
14 |
-
|
15 |
-
|
16 |
-
class GradeDocuments(BaseModel):
|
17 |
-
"""Binary score for relevance check on retrieved documents."""
|
18 |
-
|
19 |
-
binary_score: str = Field(
|
20 |
-
description="Documents are relevant to the question, 'yes' or 'no'"
|
21 |
-
)
|
22 |
-
|
23 |
-
|
24 |
-
class GraphState(TypedDict):
|
25 |
-
"""
|
26 |
-
Represents the state of our graph.
|
27 |
-
|
28 |
-
Attributes:
|
29 |
-
question: question
|
30 |
-
generation: LLM generation
|
31 |
-
documents: list of documents
|
32 |
-
"""
|
33 |
-
|
34 |
-
question: str
|
35 |
-
generation: str
|
36 |
-
documents: List[str]
|
37 |
-
|
38 |
-
|
39 |
-
class Langgraph_RAG(BaseRAG):
|
40 |
-
def __init__(self, llm, memory, retriever, qa_prompt: str, rephrase_prompt: str):
|
41 |
-
"""
|
42 |
-
Initialize the Langgraph_RAG class.
|
43 |
-
|
44 |
-
Args:
|
45 |
-
llm (LanguageModelLike): The language model instance.
|
46 |
-
memory (BaseChatMessageHistory): The chat message history instance.
|
47 |
-
retriever (BaseRetriever): The retriever instance.
|
48 |
-
qa_prompt (str): The QA prompt string.
|
49 |
-
rephrase_prompt (str): The rephrase prompt string.
|
50 |
-
"""
|
51 |
-
self.llm = llm
|
52 |
-
self.structured_llm_grader = llm.with_structured_output(GradeDocuments)
|
53 |
-
self.memory = self.add_history_from_list(memory)
|
54 |
-
self.retriever = retriever
|
55 |
-
self.qa_prompt = (
|
56 |
-
"You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
|
57 |
-
"If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
|
58 |
-
"Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
|
59 |
-
"Context:\n{context}\n\n"
|
60 |
-
"Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
|
61 |
-
"Student: {question}\n"
|
62 |
-
"AI Tutor:"
|
63 |
-
)
|
64 |
-
self.rephrase_prompt = rephrase_prompt
|
65 |
-
self.store = {}
|
66 |
-
|
67 |
-
## Fix below ##
|
68 |
-
|
69 |
-
system = """You are a grader assessing relevance of a retrieved document to a user question. \n
|
70 |
-
If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
|
71 |
-
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
|
72 |
-
grade_prompt = ChatPromptTemplate.from_messages(
|
73 |
-
[
|
74 |
-
("system", system),
|
75 |
-
(
|
76 |
-
"human",
|
77 |
-
"Retrieved document: \n\n {document} \n\n User question: {question}",
|
78 |
-
),
|
79 |
-
]
|
80 |
-
)
|
81 |
-
|
82 |
-
self.retrieval_grader = grade_prompt | self.structured_llm_grader
|
83 |
-
|
84 |
-
system = """You a question re-writer that converts an input question to a better version that is optimized \n
|
85 |
-
for web search. Look at the input and try to reason about the underlying semantic intent / meaning."""
|
86 |
-
re_write_prompt = ChatPromptTemplate.from_messages(
|
87 |
-
[
|
88 |
-
("system", system),
|
89 |
-
(
|
90 |
-
"human",
|
91 |
-
"Here is the initial question: \n\n {question} \n Formulate an improved question.",
|
92 |
-
),
|
93 |
-
]
|
94 |
-
)
|
95 |
-
|
96 |
-
self.question_rewriter = re_write_prompt | self.llm | StrOutputParser()
|
97 |
-
|
98 |
-
# Generate
|
99 |
-
self.qa_prompt_template = ChatPromptTemplate.from_template(self.qa_prompt)
|
100 |
-
self.rag_chain = self.qa_prompt_template | self.llm | StrOutputParser()
|
101 |
-
|
102 |
-
###
|
103 |
-
|
104 |
-
# build the agentic graph
|
105 |
-
self.app = self.create_agentic_graph()
|
106 |
-
|
107 |
-
def retrieve(self, state):
|
108 |
-
"""
|
109 |
-
Retrieve documents
|
110 |
-
|
111 |
-
Args:
|
112 |
-
state (dict): The current graph state
|
113 |
-
|
114 |
-
Returns:
|
115 |
-
state (dict): New key added to state, documents, that contains retrieved documents
|
116 |
-
"""
|
117 |
-
print("---RETRIEVE---")
|
118 |
-
question = state["question"]
|
119 |
-
|
120 |
-
# Retrieval
|
121 |
-
documents = self.retriever.get_relevant_documents(question)
|
122 |
-
return {"documents": documents, "question": question}
|
123 |
-
|
124 |
-
def generate(self, state):
|
125 |
-
"""
|
126 |
-
Generate answer
|
127 |
-
|
128 |
-
Args:
|
129 |
-
state (dict): The current graph state
|
130 |
-
|
131 |
-
Returns:
|
132 |
-
state (dict): New key added to state, generation, that contains LLM generation
|
133 |
-
"""
|
134 |
-
print("---GENERATE---")
|
135 |
-
question = state["question"]
|
136 |
-
documents = state["documents"]
|
137 |
-
|
138 |
-
# RAG generation
|
139 |
-
generation = self.rag_chain.invoke({"context": documents, "question": question})
|
140 |
-
return {"documents": documents, "question": question, "generation": generation}
|
141 |
-
|
142 |
-
def transform_query(self, state):
|
143 |
-
"""
|
144 |
-
Transform the query to produce a better question.
|
145 |
-
|
146 |
-
Args:
|
147 |
-
state (dict): The current graph state
|
148 |
-
|
149 |
-
Returns:
|
150 |
-
state (dict): Updates question key with a re-phrased question
|
151 |
-
"""
|
152 |
-
|
153 |
-
print("---TRANSFORM QUERY---")
|
154 |
-
question = state["question"]
|
155 |
-
documents = state["documents"]
|
156 |
-
|
157 |
-
# Re-write question
|
158 |
-
better_question = self.question_rewriter.invoke({"question": question})
|
159 |
-
return {"documents": documents, "question": better_question}
|
160 |
-
|
161 |
-
def grade_documents(self, state):
|
162 |
-
"""
|
163 |
-
Determines whether the retrieved documents are relevant to the question.
|
164 |
-
|
165 |
-
Args:
|
166 |
-
state (dict): The current graph state
|
167 |
-
|
168 |
-
Returns:
|
169 |
-
state (dict): Updates documents key with only filtered relevant documents
|
170 |
-
"""
|
171 |
-
|
172 |
-
print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
|
173 |
-
question = state["question"]
|
174 |
-
documents = state["documents"]
|
175 |
-
|
176 |
-
# Score each doc
|
177 |
-
filtered_docs = []
|
178 |
-
web_search = "No"
|
179 |
-
for d in documents:
|
180 |
-
score = self.retrieval_grader.invoke(
|
181 |
-
{"question": question, "document": d.page_content}
|
182 |
-
)
|
183 |
-
grade = score.binary_score
|
184 |
-
if grade == "yes":
|
185 |
-
print("---GRADE: DOCUMENT RELEVANT---")
|
186 |
-
filtered_docs.append(d)
|
187 |
-
else:
|
188 |
-
print("---GRADE: DOCUMENT NOT RELEVANT---")
|
189 |
-
web_search = "Yes"
|
190 |
-
continue
|
191 |
-
return {
|
192 |
-
"documents": filtered_docs,
|
193 |
-
"question": question,
|
194 |
-
"web_search": web_search,
|
195 |
-
}
|
196 |
-
|
197 |
-
def decide_to_generate(self, state):
|
198 |
-
"""
|
199 |
-
Determines whether to generate an answer, or re-generate a question.
|
200 |
-
|
201 |
-
Args:
|
202 |
-
state (dict): The current graph state
|
203 |
-
|
204 |
-
Returns:
|
205 |
-
str: Binary decision for next node to call
|
206 |
-
"""
|
207 |
-
|
208 |
-
print("---ASSESS GRADED DOCUMENTS---")
|
209 |
-
state["question"]
|
210 |
-
web_search = state["web_search"]
|
211 |
-
state["documents"]
|
212 |
-
|
213 |
-
if web_search == "Yes":
|
214 |
-
# All documents have been filtered check_relevance
|
215 |
-
# We will re-generate a new query
|
216 |
-
print(
|
217 |
-
"---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
|
218 |
-
)
|
219 |
-
return "transform_query"
|
220 |
-
else:
|
221 |
-
# We have relevant documents, so generate answer
|
222 |
-
print("---DECISION: GENERATE---")
|
223 |
-
return "generate"
|
224 |
-
|
225 |
-
def create_agentic_graph(self):
|
226 |
-
"""
|
227 |
-
Create an agentic graph to answer questions.
|
228 |
-
|
229 |
-
Returns:
|
230 |
-
dict: Agentic graph
|
231 |
-
"""
|
232 |
-
self.workflow = StateGraph(GraphState)
|
233 |
-
self.workflow.add_node("retrieve", self.retrieve)
|
234 |
-
self.workflow.add_node(
|
235 |
-
"grade_documents", self.grade_documents
|
236 |
-
) # grade documents
|
237 |
-
self.workflow.add_node("generate", self.generate) # generatae
|
238 |
-
self.workflow.add_node(
|
239 |
-
"transform_query", self.transform_query
|
240 |
-
) # transform_query
|
241 |
-
|
242 |
-
# build the graph
|
243 |
-
self.workflow.add_edge(START, "retrieve")
|
244 |
-
self.workflow.add_edge("retrieve", "grade_documents")
|
245 |
-
self.workflow.add_conditional_edges(
|
246 |
-
"grade_documents",
|
247 |
-
self.decide_to_generate,
|
248 |
-
{
|
249 |
-
"transform_query": "transform_query",
|
250 |
-
"generate": "generate",
|
251 |
-
},
|
252 |
-
)
|
253 |
-
|
254 |
-
self.workflow.add_edge("transform_query", "generate")
|
255 |
-
self.workflow.add_edge("generate", END)
|
256 |
-
|
257 |
-
# Compile
|
258 |
-
app = self.workflow.compile()
|
259 |
-
return app
|
260 |
-
|
261 |
-
def invoke(self, user_query, config):
|
262 |
-
"""
|
263 |
-
Invoke the chain.
|
264 |
-
|
265 |
-
Args:
|
266 |
-
kwargs: The input variables.
|
267 |
-
|
268 |
-
Returns:
|
269 |
-
dict: The output variables.
|
270 |
-
"""
|
271 |
-
|
272 |
-
inputs = {
|
273 |
-
"question": user_query["input"],
|
274 |
-
}
|
275 |
-
|
276 |
-
for output in self.app.stream(inputs):
|
277 |
-
for key, value in output.items():
|
278 |
-
# Node
|
279 |
-
print(f"Node {key} returned: {value}")
|
280 |
-
print("\n\n")
|
281 |
-
|
282 |
-
print(value["generation"])
|
283 |
-
|
284 |
-
# rename generation to answer
|
285 |
-
value["answer"] = value.pop("generation")
|
286 |
-
value["context"] = value.pop("documents")
|
287 |
-
|
288 |
-
return value
|
289 |
-
|
290 |
-
def add_history_from_list(self, history_list):
|
291 |
-
"""
|
292 |
-
Add messages from a list to the chat history.
|
293 |
-
|
294 |
-
Args:
|
295 |
-
messages (list): The list of messages to add.
|
296 |
-
"""
|
297 |
-
history = ChatMessageHistory()
|
298 |
-
|
299 |
-
for idx, message_pairs in enumerate(history_list):
|
300 |
-
history.add_user_message(message_pairs[0])
|
301 |
-
history.add_ai_message(message_pairs[1])
|
302 |
-
|
303 |
-
return history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
code/modules/chat/llm_tutor.py
CHANGED
@@ -3,7 +3,6 @@ from modules.chat.chat_model_loader import ChatModelLoader
|
|
3 |
from modules.vectorstore.store_manager import VectorStoreManager
|
4 |
from modules.retriever.retriever import Retriever
|
5 |
from modules.chat.langchain.langchain_rag import Langchain_RAG
|
6 |
-
from modules.chat.langgraph.langgraph_rag import Langgraph_RAG
|
7 |
|
8 |
|
9 |
class LLMTutor:
|
@@ -111,14 +110,6 @@ class LLMTutor:
|
|
111 |
qa_prompt=qa_prompt,
|
112 |
rephrase_prompt=rephrase_prompt,
|
113 |
)
|
114 |
-
elif self.config["llm_params"]["llm_arch"] == "langgraph_agentic":
|
115 |
-
self.qa_chain = Langgraph_RAG(
|
116 |
-
llm=llm,
|
117 |
-
memory=memory,
|
118 |
-
retriever=retriever,
|
119 |
-
qa_prompt=qa_prompt,
|
120 |
-
rephrase_prompt=rephrase_prompt,
|
121 |
-
)
|
122 |
else:
|
123 |
raise ValueError(
|
124 |
f"Invalid LLM Architecture: {self.config['llm_params']['llm_arch']}"
|
|
|
3 |
from modules.vectorstore.store_manager import VectorStoreManager
|
4 |
from modules.retriever.retriever import Retriever
|
5 |
from modules.chat.langchain.langchain_rag import Langchain_RAG
|
|
|
6 |
|
7 |
|
8 |
class LLMTutor:
|
|
|
110 |
qa_prompt=qa_prompt,
|
111 |
rephrase_prompt=rephrase_prompt,
|
112 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
else:
|
114 |
raise ValueError(
|
115 |
f"Invalid LLM Architecture: {self.config['llm_params']['llm_arch']}"
|
code/modules/chat_processor/base.py
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
# Template for chat processor classes
|
2 |
-
|
3 |
-
|
4 |
-
class ChatProcessorBase:
|
5 |
-
def __init__(self):
|
6 |
-
pass
|
7 |
-
|
8 |
-
def process(self, message):
|
9 |
-
"""
|
10 |
-
Processes and Logs the message
|
11 |
-
"""
|
12 |
-
raise NotImplementedError("process method not implemented")
|
13 |
-
|
14 |
-
async def rag(self, user_query: dict, config: dict, chain):
|
15 |
-
"""
|
16 |
-
Retrieves the response from the chain
|
17 |
-
"""
|
18 |
-
raise NotImplementedError("rag method not implemented")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
code/modules/chat_processor/chat_processor.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
from modules.chat_processor.literal_ai import LiteralaiChatProcessor
|
2 |
-
|
3 |
-
|
4 |
-
class ChatProcessor:
|
5 |
-
def __init__(self, config, user, tags=None):
|
6 |
-
self.config = config
|
7 |
-
self.chat_processor_type = self.config["chat_logging"]["platform"]
|
8 |
-
self.logging = self.config["chat_logging"]["log_chat"]
|
9 |
-
self.user = user
|
10 |
-
if tags is None:
|
11 |
-
self.tags = self._create_tags()
|
12 |
-
else:
|
13 |
-
self.tags = tags
|
14 |
-
if self.logging:
|
15 |
-
self._init_processor()
|
16 |
-
|
17 |
-
def _create_tags(self):
|
18 |
-
tags = []
|
19 |
-
tags.append(self.config["vectorstore"]["db_option"])
|
20 |
-
return tags
|
21 |
-
|
22 |
-
def _init_processor(self):
|
23 |
-
if self.chat_processor_type == "literalai":
|
24 |
-
self.processor = LiteralaiChatProcessor(self.user, self.tags)
|
25 |
-
else:
|
26 |
-
raise ValueError(
|
27 |
-
f"Chat processor type {self.chat_processor_type} not supported"
|
28 |
-
)
|
29 |
-
|
30 |
-
def _process(self, user_message, assistant_message, source_dict):
|
31 |
-
if self.logging:
|
32 |
-
return self.processor.process(user_message, assistant_message, source_dict)
|
33 |
-
else:
|
34 |
-
pass
|
35 |
-
|
36 |
-
async def rag(self, user_query: str, chain, stream):
|
37 |
-
user_query_dict = {"input": user_query}
|
38 |
-
# Define the base configuration
|
39 |
-
config = {
|
40 |
-
"configurable": {
|
41 |
-
"user_id": self.user["user_id"],
|
42 |
-
"conversation_id": self.user["session_id"],
|
43 |
-
"memory_window": self.config["llm_params"]["memory_window"],
|
44 |
-
}
|
45 |
-
}
|
46 |
-
|
47 |
-
# Process the user query using the appropriate method
|
48 |
-
if self.logging:
|
49 |
-
return await self.processor.rag(
|
50 |
-
user_query=user_query_dict, config=config, chain=chain
|
51 |
-
)
|
52 |
-
else:
|
53 |
-
if stream:
|
54 |
-
return chain.stream(user_query=user_query_dict, config=config)
|
55 |
-
return chain.invoke(user_query=user_query_dict, config=config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
code/modules/chat_processor/literal_ai.py
CHANGED
@@ -1,110 +1,7 @@
|
|
1 |
-
from
|
2 |
-
from literalai.api import LiteralAPI
|
3 |
-
from literalai.filter import Filter as ThreadFilter
|
4 |
|
5 |
-
import os
|
6 |
-
from .base import ChatProcessorBase
|
7 |
|
8 |
-
|
9 |
-
class
|
10 |
-
def __init__(self,
|
11 |
-
super().__init__()
|
12 |
-
self.user = user
|
13 |
-
self.tags = tags
|
14 |
-
self.literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))
|
15 |
-
self.literal_api = LiteralAPI(
|
16 |
-
api_key=os.getenv("LITERAL_API_KEY"), url=os.getenv("LITERAL_API_URL")
|
17 |
-
)
|
18 |
-
self.literal_client.reset_context()
|
19 |
-
self.user_info = self._fetch_userinfo()
|
20 |
-
self.user_thread = self._fetch_user_threads()
|
21 |
-
if len(self.user_thread["data"]) == 0:
|
22 |
-
self.thread = self._create_user_thread()
|
23 |
-
else:
|
24 |
-
self.thread = self._get_user_thread()
|
25 |
-
self.thread_id = self.thread["id"]
|
26 |
-
|
27 |
-
self.prev_conv = self._get_prev_k_conversations()
|
28 |
-
|
29 |
-
def _get_user_thread(self):
|
30 |
-
thread = self.literal_api.get_thread(id=self.user_thread["data"][0]["id"])
|
31 |
-
return thread.to_dict()
|
32 |
-
|
33 |
-
def _create_user_thread(self):
|
34 |
-
thread = self.literal_api.create_thread(
|
35 |
-
name=f"{self.user_info['identifier']}",
|
36 |
-
participant_id=self.user_info["metadata"]["id"],
|
37 |
-
environment="dev",
|
38 |
-
)
|
39 |
-
|
40 |
-
return thread.to_dict()
|
41 |
-
|
42 |
-
def _get_prev_k_conversations(self, k=3):
|
43 |
-
|
44 |
-
steps = self.thread["steps"]
|
45 |
-
conversation_pairs = []
|
46 |
-
count = 0
|
47 |
-
for i in range(len(steps) - 1, 0, -1):
|
48 |
-
if (
|
49 |
-
steps[i - 1]["type"] == "user_message"
|
50 |
-
and steps[i]["type"] == "assistant_message"
|
51 |
-
):
|
52 |
-
user_message = steps[i - 1]["output"]["content"]
|
53 |
-
assistant_message = steps[i]["output"]["content"]
|
54 |
-
conversation_pairs.append((user_message, assistant_message))
|
55 |
-
|
56 |
-
count += 1
|
57 |
-
if count >= k:
|
58 |
-
break
|
59 |
-
|
60 |
-
# Return the last k conversation pairs, reversed to maintain chronological order
|
61 |
-
return conversation_pairs[::-1]
|
62 |
-
|
63 |
-
def _fetch_user_threads(self):
|
64 |
-
filters = filters = [
|
65 |
-
{
|
66 |
-
"operator": "eq",
|
67 |
-
"field": "participantId",
|
68 |
-
"value": self.user_info["metadata"]["id"],
|
69 |
-
}
|
70 |
-
]
|
71 |
-
user_threads = self.literal_api.get_threads(filters=filters)
|
72 |
-
return user_threads.to_dict()
|
73 |
-
|
74 |
-
def _fetch_userinfo(self):
|
75 |
-
user_info = self.literal_api.get_or_create_user(
|
76 |
-
identifier=self.user["user_id"]
|
77 |
-
).to_dict()
|
78 |
-
# TODO: Have to do this more elegantly
|
79 |
-
# update metadata with unique id for now
|
80 |
-
# (literalai seems to not return the unique id as of now,
|
81 |
-
# so have to explicitly update it in the metadata)
|
82 |
-
user_info = self.literal_api.update_user(
|
83 |
-
id=user_info["id"],
|
84 |
-
metadata={
|
85 |
-
"id": user_info["id"],
|
86 |
-
},
|
87 |
-
).to_dict()
|
88 |
-
return user_info
|
89 |
-
|
90 |
-
def process(self, user_message, assistant_message, source_dict):
|
91 |
-
with self.literal_client.thread(thread_id=self.thread_id) as thread:
|
92 |
-
self.literal_client.message(
|
93 |
-
content=user_message,
|
94 |
-
type="user_message",
|
95 |
-
name="User",
|
96 |
-
)
|
97 |
-
self.literal_client.message(
|
98 |
-
content=assistant_message,
|
99 |
-
type="assistant_message",
|
100 |
-
name="AI_Tutor",
|
101 |
-
)
|
102 |
-
|
103 |
-
async def rag(self, user_query: dict, config: dict, chain):
|
104 |
-
with self.literal_client.step(
|
105 |
-
type="retrieval", name="RAG", thread_id=self.thread_id, tags=self.tags
|
106 |
-
) as step:
|
107 |
-
step.input = {"question": user_query["input"]}
|
108 |
-
res = chain.invoke(user_query, config)
|
109 |
-
step.output = res
|
110 |
-
return res
|
|
|
1 |
+
from chainlit.data import ChainlitDataLayer
|
|
|
|
|
2 |
|
|
|
|
|
3 |
|
4 |
+
# update custom methods here (Ref: https://github.com/Chainlit/chainlit/blob/4b533cd53173bcc24abe4341a7108f0070d60099/backend/chainlit/data/__init__.py)
|
5 |
+
class CustomLiteralDataLayer(ChainlitDataLayer):
|
6 |
+
def __init__(self, **kwargs):
|
7 |
+
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
code/modules/config/config.yml
CHANGED
@@ -4,13 +4,12 @@ device: 'cpu' # str [cuda, cpu]
|
|
4 |
|
5 |
vectorstore:
|
6 |
load_from_HF: True # bool
|
7 |
-
HF_path: "XThomasBU/Colbert_Index" # str
|
8 |
embedd_files: False # bool
|
9 |
data_path: '../storage/data' # str
|
10 |
url_file_path: '../storage/data/urls.txt' # str
|
11 |
expand_urls: True # bool
|
12 |
-
db_option : '
|
13 |
-
db_path : 'vectorstores' # str
|
14 |
model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
|
15 |
search_top_k : 3 # int
|
16 |
score_threshold : 0.2 # float
|
@@ -30,7 +29,7 @@ llm_params:
|
|
30 |
use_history: True # bool
|
31 |
memory_window: 3 # int
|
32 |
llm_style: 'Normal' # str [Normal, ELI5, Socratic]
|
33 |
-
llm_loader: 'gpt-
|
34 |
openai_params:
|
35 |
temperature: 0.7 # float
|
36 |
local_llm_params:
|
|
|
4 |
|
5 |
vectorstore:
|
6 |
load_from_HF: True # bool
|
|
|
7 |
embedd_files: False # bool
|
8 |
data_path: '../storage/data' # str
|
9 |
url_file_path: '../storage/data/urls.txt' # str
|
10 |
expand_urls: True # bool
|
11 |
+
db_option : 'RAGatouille' # str [FAISS, Chroma, RAGatouille, RAPTOR]
|
12 |
+
db_path : '../vectorstores' # str
|
13 |
model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
|
14 |
search_top_k : 3 # int
|
15 |
score_threshold : 0.2 # float
|
|
|
29 |
use_history: True # bool
|
30 |
memory_window: 3 # int
|
31 |
llm_style: 'Normal' # str [Normal, ELI5, Socratic]
|
32 |
+
llm_loader: 'gpt-4o-mini' # str [local_llm, gpt-3.5-turbo-1106, gpt-4, gpt-4o-mini]
|
33 |
openai_params:
|
34 |
temperature: 0.7 # float
|
35 |
local_llm_params:
|
code/modules/config/constants.py
CHANGED
@@ -7,7 +7,7 @@ load_dotenv()
|
|
7 |
|
8 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
9 |
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
|
10 |
-
|
11 |
LITERAL_API_URL = os.getenv("LITERAL_API_URL")
|
12 |
|
13 |
OAUTH_GOOGLE_CLIENT_ID = os.getenv("OAUTH_GOOGLE_CLIENT_ID")
|
@@ -18,3 +18,5 @@ opening_message = f"Hey, What Can I Help You With?\n\nYou can me ask me question
|
|
18 |
# Model Paths
|
19 |
|
20 |
LLAMA_PATH = "../storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
|
|
|
|
|
|
7 |
|
8 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
9 |
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
|
10 |
+
LITERAL_API_KEY_LOGGING = os.getenv("LITERAL_API_KEY_LOGGING")
|
11 |
LITERAL_API_URL = os.getenv("LITERAL_API_URL")
|
12 |
|
13 |
OAUTH_GOOGLE_CLIENT_ID = os.getenv("OAUTH_GOOGLE_CLIENT_ID")
|
|
|
18 |
# Model Paths
|
19 |
|
20 |
LLAMA_PATH = "../storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
|
21 |
+
|
22 |
+
RETRIEVER_HF_PATHS = {"RAGatouille": "XThomasBU/Colbert_Index"}
|
code/modules/vectorstore/store_manager.py
CHANGED
@@ -3,6 +3,7 @@ from modules.vectorstore.helpers import *
|
|
3 |
from modules.dataloader.webpage_crawler import WebpageCrawler
|
4 |
from modules.dataloader.data_loader import DataLoader
|
5 |
from modules.dataloader.helpers import *
|
|
|
6 |
from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
|
7 |
import logging
|
8 |
import os
|
@@ -135,7 +136,13 @@ class VectorStoreManager:
|
|
135 |
self.embedding_model = self.create_embedding_model()
|
136 |
else:
|
137 |
self.embedding_model = None
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
end_time = time.time() # End time for loading database
|
140 |
self.logger.info(
|
141 |
f"Time taken to load database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
|
@@ -143,9 +150,9 @@ class VectorStoreManager:
|
|
143 |
self.logger.info("Loaded database")
|
144 |
return self.loaded_vector_db
|
145 |
|
146 |
-
def load_from_HF(self):
|
147 |
start_time = time.time() # Start time for loading database
|
148 |
-
self.vector_db._load_from_HF()
|
149 |
end_time = time.time()
|
150 |
self.logger.info(
|
151 |
f"Time taken to Download database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
|
@@ -164,8 +171,14 @@ if __name__ == "__main__":
|
|
164 |
print(config)
|
165 |
print(f"Trying to create database with config: {config}")
|
166 |
vector_db = VectorStoreManager(config)
|
167 |
-
if config["vectorstore"]["load_from_HF"]
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
else:
|
170 |
vector_db.create_database()
|
171 |
print("Created database")
|
|
|
3 |
from modules.dataloader.webpage_crawler import WebpageCrawler
|
4 |
from modules.dataloader.data_loader import DataLoader
|
5 |
from modules.dataloader.helpers import *
|
6 |
+
from modules.config.constants import RETRIEVER_HF_PATHS
|
7 |
from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
|
8 |
import logging
|
9 |
import os
|
|
|
136 |
self.embedding_model = self.create_embedding_model()
|
137 |
else:
|
138 |
self.embedding_model = None
|
139 |
+
try:
|
140 |
+
self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
|
141 |
+
except Exception as e:
|
142 |
+
raise ValueError(f"Error loading database, check if it exists. if not run python -m modules.vectorstore.store_manager / Resteart the HF Space: {e}")
|
143 |
+
# print(f"Creating database")
|
144 |
+
# self.create_database()
|
145 |
+
# self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
|
146 |
end_time = time.time() # End time for loading database
|
147 |
self.logger.info(
|
148 |
f"Time taken to load database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
|
|
|
150 |
self.logger.info("Loaded database")
|
151 |
return self.loaded_vector_db
|
152 |
|
153 |
+
def load_from_HF(self, HF_PATH):
|
154 |
start_time = time.time() # Start time for loading database
|
155 |
+
self.vector_db._load_from_HF(HF_PATH)
|
156 |
end_time = time.time()
|
157 |
self.logger.info(
|
158 |
f"Time taken to Download database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
|
|
|
171 |
print(config)
|
172 |
print(f"Trying to create database with config: {config}")
|
173 |
vector_db = VectorStoreManager(config)
|
174 |
+
if config["vectorstore"]["load_from_HF"]:
|
175 |
+
if config["vectorstore"]["db_option"] in RETRIEVER_HF_PATHS:
|
176 |
+
vector_db.load_from_HF(HF_PATH = RETRIEVER_HF_PATHS[config["vectorstore"]["db_option"]])
|
177 |
+
else:
|
178 |
+
# print(f"HF_PATH not available for {config['vectorstore']['db_option']}")
|
179 |
+
# print("Creating database")
|
180 |
+
# vector_db.create_database()
|
181 |
+
raise ValueError(f"HF_PATH not available for {config['vectorstore']['db_option']}")
|
182 |
else:
|
183 |
vector_db.create_database()
|
184 |
print("Created database")
|
code/modules/vectorstore/vectorstore.py
CHANGED
@@ -53,11 +53,11 @@ class VectorStore:
|
|
53 |
else:
|
54 |
return self.vectorstore.load_database(embedding_model)
|
55 |
|
56 |
-
def _load_from_HF(self):
|
57 |
# Download the snapshot from Hugging Face Hub
|
58 |
# Note: Download goes to the cache directory
|
59 |
snapshot_path = snapshot_download(
|
60 |
-
repo_id=
|
61 |
repo_type="dataset",
|
62 |
force_download=True,
|
63 |
)
|
|
|
53 |
else:
|
54 |
return self.vectorstore.load_database(embedding_model)
|
55 |
|
56 |
+
def _load_from_HF(self, HF_PATH):
|
57 |
# Download the snapshot from Hugging Face Hub
|
58 |
# Note: Download goes to the cache directory
|
59 |
snapshot_path = snapshot_download(
|
60 |
+
repo_id=HF_PATH,
|
61 |
repo_type="dataset",
|
62 |
force_download=True,
|
63 |
)
|
code/public/test.css
CHANGED
@@ -31,3 +31,13 @@ a[href*='https://github.com/Chainlit/chainlit'] {
|
|
31 |
.MuiAvatar-root.MuiAvatar-circular.css-v72an7 .MuiAvatar-img.css-1hy9t21 {
|
32 |
display: none;
|
33 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
.MuiAvatar-root.MuiAvatar-circular.css-v72an7 .MuiAvatar-img.css-1hy9t21 {
|
32 |
display: none;
|
33 |
}
|
34 |
+
|
35 |
+
/* Hide the new chat button
|
36 |
+
#new-chat-button {
|
37 |
+
display: none;
|
38 |
+
} */
|
39 |
+
|
40 |
+
/* Hide the open sidebar button
|
41 |
+
#open-sidebar-button {
|
42 |
+
display: none;
|
43 |
+
} */
|