Kiril commited on
Commit
0b3043b
·
1 Parent(s): 435b3cb

RAG with feedback

Browse files
Files changed (6) hide show
  1. .gitignore +2 -0
  2. app.py +186 -0
  3. data/.gitkeep +0 -0
  4. feedback.py +67 -0
  5. rag_bot.py +92 -0
  6. requirements.txt +18 -0
.gitignore CHANGED
@@ -1,3 +1,5 @@
 
 
1
  # Byte-compiled / optimized / DLL files
2
  __pycache__/
3
  *.py[cod]
 
1
+ data/
2
+
3
  # Byte-compiled / optimized / DLL files
4
  __pycache__/
5
  *.py[cod]
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List
4
+
5
+ import chainlit as cl
6
+ import chainlit.data as cl_data
7
+ from langchain.callbacks.base import BaseCallbackHandler
8
+ from langchain.indexes import SQLRecordManager, index
9
+ from langchain.prompts import ChatPromptTemplate
10
+ from langchain.schema import Document
11
+ from langchain.schema import StrOutputParser
12
+ from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
13
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
14
+ from langchain_community.document_loaders import (
15
+ PyPDFDirectoryLoader,
16
+ )
17
+ from langchain_community.vectorstores import Chroma
18
+ # from langchain_openai import ChatOpenAI, OpenAIEmbeddings
19
+ from langchain_groq import ChatGroq
20
+ from langchain_huggingface import HuggingFaceEndpointEmbeddings
21
+
22
+ from feedback import CustomDataLayer
23
+ from rag_bot import RagBot
24
+
25
+ chunk_size = 1024
26
+ chunk_overlap = 50
27
+
28
+ embeddings_model = HuggingFaceEndpointEmbeddings(
29
+ huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
30
+ model="sentence-transformers/all-MiniLM-L12-v2",
31
+ )
32
+
33
+
34
+ # Feedback
35
+ cl_data._data_layer = CustomDataLayer()
36
+
37
+ PDF_STORAGE_PATH = "./data"
38
+
39
+
40
+ def process_pdfs(pdf_storage_path: str):
41
+ pdf_directory = Path(pdf_storage_path)
42
+ docs = [] # type: List[Document]
43
+ # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
44
+
45
+ loader = PyPDFDirectoryLoader(pdf_directory)
46
+ documents = loader.load()
47
+ recursive_text_splitter = RecursiveCharacterTextSplitter(
48
+ chunk_size=chunk_size,
49
+ chunk_overlap=chunk_overlap,
50
+ length_function=len,
51
+ is_separator_regex=False,
52
+ )
53
+ docs = recursive_text_splitter.split_documents(documents)
54
+
55
+ doc_search = Chroma.from_documents(docs, embeddings_model)
56
+
57
+ namespace = "chromadb/my_documents"
58
+ record_manager = SQLRecordManager(
59
+ namespace, db_url="sqlite:///record_manager_cache.sql"
60
+ )
61
+ record_manager.create_schema()
62
+
63
+ index_result = index(
64
+ docs,
65
+ record_manager,
66
+ doc_search,
67
+ cleanup="full",
68
+ source_id_key="source",
69
+ )
70
+
71
+ print(f"Indexing stats: {index_result}")
72
+
73
+ return doc_search
74
+
75
+
76
+ doc_search = process_pdfs(PDF_STORAGE_PATH)
77
+ # model = ChatOpenAI(model_name="gpt-4", streaming=True)
78
+ model = ChatGroq(
79
+ model='llama-3.1-70b-versatile',
80
+ temperature=0,
81
+ max_tokens=1024,
82
+ timeout=None,
83
+ max_retries=5,
84
+ api_key=os.getenv("GROQ_API_KEY"),
85
+ # other params...
86
+ )
87
+
88
+
89
+ @cl.on_chat_start
90
+ async def on_chat_start():
91
+
92
+ prompt = ChatPromptTemplate.from_messages(
93
+ [
94
+ ("system",
95
+ """You are a helpful assistant that can answer questions about technical documents in any language.
96
+ Keep your answers only in the language of the question(s).
97
+
98
+ Only use the factual information from the document(s) to answer the question(s). Keep your answers concise and to the point.
99
+
100
+ If you do not have have sufficient information to answer a question, politely refuse to answer and say "I don't know".
101
+ \n\nRelevant documents will be retrieved below."""
102
+ "Context: {context}"
103
+ ),
104
+ ("human", "{question}"),
105
+ ]
106
+ )
107
+
108
+ def format_docs(docs):
109
+ return "\n\n".join([d.page_content for d in docs])
110
+
111
+ retriever = doc_search.as_retriever(search_kwargs={"k": 5})
112
+
113
+ runnable = (
114
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
115
+ | prompt
116
+ | model
117
+ | StrOutputParser()
118
+ )
119
+
120
+ cl.user_session.set("runnable", runnable)
121
+
122
+
123
+ @cl.on_message
124
+ async def on_message(message: cl.Message):
125
+ runnable = cl.user_session.get("runnable") # type: Runnable
126
+ msg = cl.Message(content="")
127
+
128
+ class PostMessageHandler(BaseCallbackHandler):
129
+ """
130
+ Callback handler for handling the retriever and LLM processes.
131
+ Used to post the sources of the retrieved documents as a Chainlit element.
132
+ """
133
+
134
+ def __init__(self, msg: cl.Message):
135
+ BaseCallbackHandler.__init__(self)
136
+ self.msg = msg
137
+ self.sources = [] # To store unique pairs
138
+
139
+ def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
140
+ for doc in documents:
141
+ source = doc.metadata.get('source', 'Unknown Source')
142
+ page = doc.metadata.get('page', 'N/A')
143
+ page_content = doc.page_content
144
+ # self.sources.add(source_page_pair) # Add unique pairs to the set
145
+ if not any(s["source"] == source and s["page"] == page for s in self.sources):
146
+ self.sources.append({
147
+ "source": source,
148
+ "page": page,
149
+ "content": page_content
150
+ })
151
+
152
+ def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
153
+ if len(self.sources):
154
+ # Create a list of clickable elements for sources
155
+ text_elements = []
156
+ source_references = []
157
+ for idx, src in enumerate(self.sources):
158
+ source_name = f"{src['source']} p.{src['page']}"
159
+ source_references.append(source_name)
160
+
161
+ # Add a previewable Chainlit element
162
+ text_elements.append(
163
+ cl.Text(
164
+ name=source_name,
165
+ content=src["content"],
166
+ display="side",
167
+ )
168
+ )
169
+ # Generate the answer with clickable source names
170
+ self.msg.content += f"\n\nSources: {", ".join(
171
+ source_references
172
+ )}"
173
+
174
+ # Append text elements to the message
175
+ self.msg.elements.extend(text_elements)
176
+
177
+ async for chunk in runnable.astream(
178
+ message.content,
179
+ config=RunnableConfig(callbacks=[
180
+ cl.LangchainCallbackHandler(),
181
+ PostMessageHandler(msg)
182
+ ]),
183
+ ):
184
+ await msg.stream_token(chunk)
185
+
186
+ await msg.send()
data/.gitkeep ADDED
File without changes
feedback.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chainlit.data as cl_data
2
+ import chainlit as cl
3
+ from langsmith import traceable, Client
4
+ import uuid
5
+
6
+
7
+ class CustomDataLayer(cl_data.BaseDataLayer):
8
+ async def upsert_feedback(self, feedback: cl_data.base.Feedback) -> str:
9
+ client = Client()
10
+ run_id = uuid.uuid4()
11
+ cl.message(f"Creating feedback for run_id: {run_id} \n{feedback}")
12
+
13
+ client.create_feedback(
14
+ run_id,
15
+ key="correction",
16
+ score=feedback.value,
17
+ comment=feedback.comment,
18
+ )
19
+
20
+
21
+ return await super().upsert_feedback(feedback)
22
+
23
+ async def build_debug_url(self, *args, **kwargs):
24
+ pass
25
+
26
+ async def create_element(self, *args, **kwargs):
27
+ pass
28
+
29
+ async def create_step(self, *args, **kwargs):
30
+ pass
31
+
32
+ async def create_user(self, *args, **kwargs):
33
+ pass
34
+
35
+ async def delete_element(self, *args, **kwargs):
36
+ pass
37
+
38
+ async def delete_feedback(self, *args, **kwargs):
39
+ pass
40
+
41
+ async def delete_step(self, *args, **kwargs):
42
+ pass
43
+
44
+ async def delete_thread(self, *args, **kwargs):
45
+ pass
46
+
47
+ async def get_element(self, *args, **kwargs):
48
+ pass
49
+
50
+ async def get_thread(self, *args, **kwargs):
51
+ pass
52
+
53
+ async def get_thread_author(self, *args, **kwargs):
54
+ pass
55
+
56
+ async def get_user(self, *args, **kwargs):
57
+ pass
58
+
59
+ async def list_threads(self, *args, **kwargs):
60
+ pass
61
+
62
+ async def update_step(self, *args, **kwargs):
63
+ pass
64
+
65
+ async def update_thread(self, *args, **kwargs):
66
+ pass
67
+
rag_bot.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from operator import itemgetter
3
+ from typing import List
4
+
5
+ from langchain.retrievers import EnsembleRetriever
6
+ from langchain_community.retrievers import BM25Retriever
7
+ from langchain_core.documents import Document
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from langchain_core.prompts import ChatPromptTemplate
10
+ from langchain_core.runnables import chain
11
+ from langsmith import traceable
12
+ from nltk.tokenize import word_tokenize
13
+
14
+
15
+ class RagBot:
16
+
17
+ def __init__(self, retriever, model, is_local_model):
18
+ self._retriever = retriever
19
+ # Wrapping the client instruments the LLM
20
+ self._model = model
21
+ self._is_local_model = is_local_model
22
+ self._prompt = self.prompt_template()
23
+
24
+ # Set up the prompt template
25
+ def prompt_template(self):
26
+ return ChatPromptTemplate.from_messages(
27
+ [
28
+ ("system",
29
+ """You are a helpful assistant that can answer questions about technical documents in any language.
30
+ Keep your answers only in the language of the question(s).
31
+
32
+ Only use the factual information from the document(s) to answer the question(s). Keep your answers concise and to the point.
33
+
34
+ If you do not have have sufficient information to answer a question, politely refuse to answer and say "I don't know".
35
+ \n\nRelevant documents will be retrieved below."""
36
+ "Context: {context}"
37
+ ),
38
+ ("human", "{question}"),
39
+ ])
40
+
41
+ @traceable()
42
+ def retrieve_docs(self, question):
43
+ return self._retriever.invoke(question)
44
+
45
+ @traceable()
46
+ def invoke_llm(self, query, docs):
47
+ chain = (
48
+ # {"docs": retriever,"question": RunnablePassthrough()}
49
+ {"context": itemgetter("context"), "question": itemgetter("question")}
50
+ | self._prompt | self._model | StrOutputParser()
51
+ )
52
+
53
+ # Visualize input schema if needed
54
+ # chain.input_schema.schema()
55
+ # Retrieve context docs
56
+ # context = retriever.invoke(query)
57
+
58
+ print(f"Question : \n{query}\n\n")
59
+ # Stream the result if HuggingFaceEndpoint is used
60
+ result = ""
61
+ stopwatch = time.perf_counter() # measure time
62
+ if not self._is_local_model:
63
+ print(f"Invoking the result with Inference API...\n")
64
+ chunks = []
65
+ result = chain.invoke({"question": query, "context": docs})
66
+ print(result)
67
+ # for chunk in chain.stream({"context": context, "question": query}):
68
+ # result+=chunk
69
+ # print(chunk, end='|', flush=True)
70
+
71
+ else:
72
+ print(f"Invoking the result with Local LLM...\n")
73
+ result = chain.invoke({"context": docs, "question": query})
74
+ # result.append(chunk)
75
+ # print(chunk, end='|', flush=True)
76
+ print(f"\n\nTime for invoke {(time.perf_counter() - stopwatch) / 60}")
77
+ print(f"\nThe answer is based on the following {self._retriever.k} relevant documents:")
78
+ # context = result.get("context", []) # Retrieve the context
79
+ for doc in docs:
80
+ print(f"\n{doc.page_content}\nMetadata: {doc.metadata}\n")
81
+
82
+ # Evaluators will expect "answer" and "contexts"
83
+ return {
84
+ "answer": result,
85
+ "contexts": docs,
86
+ }
87
+
88
+ @traceable()
89
+ def get_answer(self, query: str):
90
+ docs = self.retrieve_docs(query)
91
+ return self.invoke_llm(query, docs)
92
+
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Chainlit
2
+ chainlit
3
+ pydantic==2.10.1 # Required for chainlit
4
+ # Langchain
5
+ langchain
6
+ langchain-core
7
+ langchain-community
8
+ langchain-groq
9
+ langchain-huggingface
10
+ langsmith
11
+
12
+ chromadb
13
+ tiktoken
14
+ pypdf
15
+ cryptography # required for pypdf
16
+ # BM25 retriever
17
+ nltk
18
+