Ritvik19 commited on
Commit
c323312
·
verified ·
1 Parent(s): 7793370

Upload 8 files

Browse files
Files changed (6) hide show
  1. app.py +119 -83
  2. chain_of_density.py +42 -0
  3. chat_chains.py +33 -27
  4. command_center.py +6 -0
  5. custom_exceptions.py +6 -0
  6. process_documents.py +20 -9
app.py CHANGED
@@ -8,32 +8,20 @@ import json
8
  from langchain.callbacks import get_openai_callback
9
  from langchain_openai import ChatOpenAI
10
  import base64
11
- from chat_chains import rag_chain, parse_model_response
12
- from langchain_core.messages import AIMessage, HumanMessage
13
- from autoqa_chains import auto_qa_chain, followup_qa_chain, auto_qa_output_parser
 
 
 
 
 
 
 
14
 
15
  st.set_page_config(layout="wide")
16
  os.environ["OPENAI_API_KEY"] = "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"
17
 
18
- format_citations = lambda citations: "\n\n".join(
19
- [f"{citation['quote']} ... [{citation['source_id']}]" for citation in citations]
20
- )
21
-
22
-
23
- def session_state_2_llm_chat_history(session_state):
24
- chat_history = []
25
- for ss in session_state:
26
- if not ss[0].startswith("/"):
27
- chat_history.append(HumanMessage(content=ss[0]))
28
- chat_history.append(AIMessage(content=ss[1]))
29
- return chat_history
30
-
31
-
32
- ai_message_format = lambda message, references: (
33
- f"{message}\n\n---\n\n{format_citations(references)}"
34
- if references != ""
35
- else message
36
- )
37
 
38
  welcome_message = """
39
  Hi I'm Agent Zeta, your AI assistant, dedicated to making your journey through machine learning research papers as insightful and interactive as possible. Whether you're diving into the latest studies or brushing up on foundational papers, I'm here to help navigate, discuss, and analyze content with you.
@@ -42,17 +30,20 @@ Here's a quick guide to getting started with me:
42
 
43
  | Command | Description |
44
  |---------|-------------|
45
- | `/upload` <list of urls> | Upload and process documents for our conversation. |
46
- | `/index` | View an index of processed documents to easily navigate your research. |
47
- | `/cost` | Calculate the cost of our conversation, ensuring transparency in resource usage. |
48
- | `/download` | Download conversation data for your records or further analysis. |
49
- | `/auto` <document id> | Automatically generate questions and answers for a document. |
 
 
 
50
 
51
  <br>
52
 
53
  Feel free to use these commands to enhance your research experience. Let's embark on this exciting journey of discovery together!
54
 
55
- Use `/man` at any point of time to view this guide again.
56
  """
57
 
58
 
@@ -64,28 +55,26 @@ def process_documents_wrapper(inputs):
64
  [snip.metadata["chunk_id"], snip.metadata["header"]] for snip in snippets
65
  ]
66
  response = f"Uploaded and processed documents {inputs}"
67
- st.session_state.messages.append((f"/upload {inputs}", response, ""))
68
  st.session_state.documents = documents
69
- return response
70
 
71
 
72
  def index_documents_wrapper(inputs=None):
73
- response = pd.DataFrame(
74
- st.session_state.index, columns=["id", "reference"]
75
- ).to_markdown()
76
- st.session_state.messages.append(("/index", response, ""))
77
- return response
78
 
79
 
80
  def calculate_cost_wrapper(inputs=None):
81
  try:
82
  stats_df = pd.DataFrame(st.session_state.costing)
83
  stats_df.loc["total"] = stats_df.sum()
84
- response = stats_df.to_markdown()
85
  except ValueError:
86
  response = "No cost incurred yet"
87
- st.session_state.messages.append(("/cost", response, ""))
88
- return response
89
 
90
 
91
  def download_conversation_wrapper(inputs=None):
@@ -100,7 +89,7 @@ def download_conversation_wrapper(inputs=None):
100
  st.session_state.index if "index" in st.session_state else []
101
  ),
102
  "conversation": [
103
- {"human": message[0], "ai": message[1], "references": message[2]}
104
  for message in st.session_state.messages
105
  ],
106
  "costing": (
@@ -117,25 +106,22 @@ def download_conversation_wrapper(inputs=None):
117
  }
118
  )
119
  conversation_data = base64.b64encode(conversation_data.encode()).decode()
120
- st.session_state.messages.append(("/download", "Conversation data downloaded", ""))
121
- return f'<a href="data:text/csv;base64,{conversation_data}" download="conversation_data.json">Download Conversation</a>'
 
 
 
 
 
122
 
123
 
124
- def query_llm_wrapper(inputs):
125
- retriever = st.session_state.retriever
126
- qa_chain = rag_chain(
127
- retriever, ChatOpenAI(model="gpt-4-0125-preview", temperature=0)
128
- )
129
- relevant_docs = retriever.get_relevant_documents(inputs)
130
  with get_openai_callback() as cb:
131
- response = qa_chain.invoke(
132
- {
133
- "question": inputs,
134
- "chat_history": session_state_2_llm_chat_history(
135
- st.session_state.messages
136
- ),
137
- }
138
- ).content
139
  stats = cb
140
  response = parse_model_response(response)
141
  answer = response["answer"]
@@ -147,7 +133,6 @@ def query_llm_wrapper(inputs):
147
  f"[{ref}]"
148
  for ref in sorted(
149
  [ref.metadata["chunk_id"] for ref in relevant_docs],
150
- key=lambda x: int(x.split("_")[1]),
151
  )
152
  ]
153
  ),
@@ -155,7 +140,41 @@ def query_llm_wrapper(inputs):
155
  }
156
  )
157
 
158
- st.session_state.messages.append((inputs, answer, citations))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  st.session_state.costing.append(
160
  {
161
  "prompt tokens": stats.prompt_tokens,
@@ -163,11 +182,13 @@ def query_llm_wrapper(inputs):
163
  "cost": stats.total_cost,
164
  }
165
  )
166
- return answer, citations
167
 
168
 
169
  def auto_qa_chain_wrapper(inputs):
170
- document = st.session_state.documents[inputs]
 
 
171
  llm = ChatOpenAI(model="gpt-4-turbo-preview", temperature=0)
172
  auto_qa_conversation = []
173
  with get_openai_callback() as cb:
@@ -176,15 +197,15 @@ def auto_qa_chain_wrapper(inputs):
176
  "questions"
177
  ]
178
  auto_qa_conversation = [
179
- (f'/auto {qa["question"]}', qa["answer"], "")
180
  for qa in auto_qa_response_parsed
181
  ]
182
  stats = cb
183
  st.session_state.messages.append(
184
- (f"/auto {inputs}", "Auto Convervation Generated", "")
185
  )
186
  for qa in auto_qa_conversation:
187
- st.session_state.messages.append((qa[0], qa[1], ""))
188
 
189
  st.session_state.costing.append(
190
  {
@@ -193,12 +214,16 @@ def auto_qa_chain_wrapper(inputs):
193
  "cost": stats.total_cost,
194
  }
195
  )
196
- return "\n\n".join(
197
- f"Q: {qa['question']}\n\nA: {qa['answer']}" for qa in auto_qa_response_parsed
 
 
 
 
198
  )
199
 
200
 
201
- def boot(command_center):
202
  st.write("# Agent Zeta")
203
  if "costing" not in st.session_state:
204
  st.session_state.costing = []
@@ -208,34 +233,45 @@ def boot(command_center):
208
  for message in st.session_state.messages:
209
  st.chat_message("human").write(message[0])
210
  st.chat_message("ai").write(
211
- ai_message_format(message[1], message[2]), unsafe_allow_html=True
212
  )
213
  if query := st.chat_input():
214
- st.chat_message("human").write(query)
215
- response = command_center.execute_command(query)
216
- if response is None:
217
- pass
218
- elif type(response) == tuple:
219
- result, references = response
220
  st.chat_message("ai").write(
221
- ai_message_format(result, references), unsafe_allow_html=True
222
  )
223
- else:
224
- st.chat_message("ai").write(response, unsafe_allow_html=True)
225
 
226
 
227
  if __name__ == "__main__":
228
  all_commands = [
229
- ("/upload", list, process_documents_wrapper),
230
- ("/index", None, index_documents_wrapper),
231
- ("/cost", None, calculate_cost_wrapper),
232
- ("/download", None, download_conversation_wrapper),
233
- ("/man", None, lambda x: welcome_message),
234
- ("/auto", int, auto_qa_chain_wrapper),
 
 
235
  ]
236
  command_center = CommandCenter(
237
  default_input_type=str,
238
- default_function=query_llm_wrapper,
239
  all_commands=all_commands,
240
  )
241
- boot(command_center)
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from langchain.callbacks import get_openai_callback
9
  from langchain_openai import ChatOpenAI
10
  import base64
11
+ from chat_chains import (
12
+ parse_model_response,
13
+ qa_chain,
14
+ format_docs,
15
+ parse_context_and_question,
16
+ ai_response_format,
17
+ )
18
+ from autoqa_chains import auto_qa_chain, auto_qa_output_parser
19
+ from chain_of_density import chain_of_density_chain
20
+ from custom_exceptions import InvalidArgumentError, InvalidCommandError
21
 
22
  st.set_page_config(layout="wide")
23
  os.environ["OPENAI_API_KEY"] = "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"
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  welcome_message = """
27
  Hi I'm Agent Zeta, your AI assistant, dedicated to making your journey through machine learning research papers as insightful and interactive as possible. Whether you're diving into the latest studies or brushing up on foundational papers, I'm here to help navigate, discuss, and analyze content with you.
 
30
 
31
  | Command | Description |
32
  |---------|-------------|
33
+ | `/add-papers <list of urls>` | Upload and process documents for our conversation. |
34
+ | `/library` | View an index of processed documents to easily navigate your research. |
35
+ | `/session-expense` | Calculate the cost of our conversation, ensuring transparency in resource usage. |
36
+ | `/export` | Download conversation data for your records or further analysis. |
37
+ | `/auto-insight <document id>` | Automatically generate questions and answers for a document. |
38
+ | `/deep-dive [<list of document ids>] <query>` | Query the AI with a specific document context. |
39
+ | `/condense-summary <document id>` | Generate increasingly concise, entity-dense summaries of a document. |
40
+
41
 
42
  <br>
43
 
44
  Feel free to use these commands to enhance your research experience. Let's embark on this exciting journey of discovery together!
45
 
46
+ Use `/help-me` at any point of time to view this guide again.
47
  """
48
 
49
 
 
55
  [snip.metadata["chunk_id"], snip.metadata["header"]] for snip in snippets
56
  ]
57
  response = f"Uploaded and processed documents {inputs}"
58
+ st.session_state.messages.append((f"/add-papers {inputs}", response, "identity"))
59
  st.session_state.documents = documents
60
+ return (response, "identity")
61
 
62
 
63
  def index_documents_wrapper(inputs=None):
64
+ response = pd.DataFrame(st.session_state.index, columns=["id", "reference"])
65
+ st.session_state.messages.append(("/library", response, "dataframe"))
66
+ return (response, "dataframe")
 
 
67
 
68
 
69
  def calculate_cost_wrapper(inputs=None):
70
  try:
71
  stats_df = pd.DataFrame(st.session_state.costing)
72
  stats_df.loc["total"] = stats_df.sum()
73
+ response = stats_df
74
  except ValueError:
75
  response = "No cost incurred yet"
76
+ st.session_state.messages.append(("/session-expense", response, "dataframe"))
77
+ return (response, "dataframe")
78
 
79
 
80
  def download_conversation_wrapper(inputs=None):
 
89
  st.session_state.index if "index" in st.session_state else []
90
  ),
91
  "conversation": [
92
+ {"human": message[0], "ai": jsonify_functions[message[2]](message[1])}
93
  for message in st.session_state.messages
94
  ],
95
  "costing": (
 
106
  }
107
  )
108
  conversation_data = base64.b64encode(conversation_data.encode()).decode()
109
+ st.session_state.messages.append(
110
+ ("/export", "Conversation data downloaded", "identity")
111
+ )
112
+ return (
113
+ f'<a href="data:text/csv;base64,{conversation_data}" download="conversation_data.json">Download Conversation</a>',
114
+ "identity",
115
+ )
116
 
117
 
118
+ def query_llm(inputs, relevant_docs):
 
 
 
 
 
119
  with get_openai_callback() as cb:
120
+ response = (
121
+ qa_chain(ChatOpenAI(model="gpt-4-0125-preview", temperature=0))
122
+ .invoke({"context": format_docs(relevant_docs), "question": inputs})
123
+ .content
124
+ )
 
 
 
125
  stats = cb
126
  response = parse_model_response(response)
127
  answer = response["answer"]
 
133
  f"[{ref}]"
134
  for ref in sorted(
135
  [ref.metadata["chunk_id"] for ref in relevant_docs],
 
136
  )
137
  ]
138
  ),
 
140
  }
141
  )
142
 
143
+ st.session_state.messages.append(
144
+ (inputs, {"answer": answer, "citations": citations}, "reponse_with_citations")
145
+ )
146
+ st.session_state.costing.append(
147
+ {
148
+ "prompt tokens": stats.prompt_tokens,
149
+ "completion tokens": stats.completion_tokens,
150
+ "cost": stats.total_cost,
151
+ }
152
+ )
153
+ return ({"answer": answer, "citations": citations}, "reponse_with_citations")
154
+
155
+
156
+ def rag_llm_wrapper(inputs):
157
+ retriever = st.session_state.retriever
158
+ relevant_docs = retriever.get_relevant_documents(inputs)
159
+ return query_llm(inputs, relevant_docs)
160
+
161
+
162
+ def query_llm_wrapper(inputs):
163
+ context, question = parse_context_and_question(inputs)
164
+ relevant_docs = [st.session_state.documents[c] for c in context]
165
+ print(context, question)
166
+ return query_llm(question, relevant_docs)
167
+
168
+
169
+ def chain_of_density_wrapper(inputs):
170
+ if inputs == "":
171
+ raise InvalidArgumentError("Please provide a document id")
172
+ document = st.session_state.documents[inputs].page_content
173
+ llm = ChatOpenAI(model="gpt-4-turbo-preview", temperature=0)
174
+ with get_openai_callback() as cb:
175
+ summary = chain_of_density_chain(llm).invoke({"paper": document})
176
+ stats = cb
177
+ st.session_state.messages.append(("/condense-summary", summary, "identity"))
178
  st.session_state.costing.append(
179
  {
180
  "prompt tokens": stats.prompt_tokens,
 
182
  "cost": stats.total_cost,
183
  }
184
  )
185
+ return (summary, "identity")
186
 
187
 
188
  def auto_qa_chain_wrapper(inputs):
189
+ if inputs == "":
190
+ raise InvalidArgumentError("Please provide a document id")
191
+ document = st.session_state.documents[inputs].page_content
192
  llm = ChatOpenAI(model="gpt-4-turbo-preview", temperature=0)
193
  auto_qa_conversation = []
194
  with get_openai_callback() as cb:
 
197
  "questions"
198
  ]
199
  auto_qa_conversation = [
200
+ (f'/auto {qa["question"]}', qa["answer"], "identity")
201
  for qa in auto_qa_response_parsed
202
  ]
203
  stats = cb
204
  st.session_state.messages.append(
205
+ (f"/auto-insight {inputs}", "Auto Convervation Generated", "identity")
206
  )
207
  for qa in auto_qa_conversation:
208
+ st.session_state.messages.append((qa[0], qa[1], "identity"))
209
 
210
  st.session_state.costing.append(
211
  {
 
214
  "cost": stats.total_cost,
215
  }
216
  )
217
+ return (
218
+ "\n\n".join(
219
+ f"Q: {qa['question']}\n\nA: {qa['answer']}"
220
+ for qa in auto_qa_response_parsed
221
+ ),
222
+ "identity",
223
  )
224
 
225
 
226
+ def boot(command_center, formating_functions):
227
  st.write("# Agent Zeta")
228
  if "costing" not in st.session_state:
229
  st.session_state.costing = []
 
233
  for message in st.session_state.messages:
234
  st.chat_message("human").write(message[0])
235
  st.chat_message("ai").write(
236
+ formating_functions[message[2]](message[1]), unsafe_allow_html=True
237
  )
238
  if query := st.chat_input():
239
+ try:
240
+ st.chat_message("human").write(query)
241
+ response, format_fn_name = command_center.execute_command(query)
 
 
 
242
  st.chat_message("ai").write(
243
+ formating_functions[format_fn_name](response), unsafe_allow_html=True
244
  )
245
+ except (InvalidArgumentError, InvalidCommandError) as e:
246
+ st.error(e)
247
 
248
 
249
  if __name__ == "__main__":
250
  all_commands = [
251
+ ("/add-papers", list, process_documents_wrapper),
252
+ ("/library", None, index_documents_wrapper),
253
+ ("/session-expense", None, calculate_cost_wrapper),
254
+ ("/export", None, download_conversation_wrapper),
255
+ ("/help-me", None, lambda x: (welcome_message, "identity")),
256
+ ("/auto-insight", str, auto_qa_chain_wrapper),
257
+ ("/deep-dive", str, query_llm_wrapper),
258
+ ("/condense-summary", str, chain_of_density_wrapper),
259
  ]
260
  command_center = CommandCenter(
261
  default_input_type=str,
262
+ default_function=rag_llm_wrapper,
263
  all_commands=all_commands,
264
  )
265
+ formating_functions = {
266
+ "identity": lambda x: x,
267
+ "dataframe": lambda x: x,
268
+ "reponse_with_citations": lambda x: ai_response_format(
269
+ x["answer"], x["citations"]
270
+ ),
271
+ }
272
+ jsonify_functions = {
273
+ "identity": lambda x: x,
274
+ "dataframe": lambda x: x.to_dict(orient="records"),
275
+ "reponse_with_citations": lambda x: x,
276
+ }
277
+ boot(command_center, formating_functions)
chain_of_density.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.output_parsers import JsonOutputParser
2
+ from langchain_core.prompts import PromptTemplate
3
+
4
+ chain_of_density_prompt_template = """
5
+ Research Paper: {paper}
6
+
7
+ You will generate increasingly concise, entity-dense summaries of the above research paper.
8
+
9
+ Repeat the following 2 steps 10 times.
10
+
11
+ Step 1. Identify 1-3 informative Entities ('; ' delimited) from the research paper that are missing from the previously generated summary. These entities should be key components such as research questions, methodologies, findings, theoretical contributions, or implications.
12
+ Step 2. Write a new, denser summary of identical length which covers every entity and detail from the previous summary plus the Missing Entities.
13
+
14
+ A Missing Entity is:
15
+ - Relevant: critical to understanding the paper’s contribution.
16
+ - Specific: descriptive yet concise (5 words or fewer).
17
+ - Novel: not included in the previous summary.
18
+ - Faithful: accurately represented in the research paper.
19
+ - Anywhere: can be found anywhere in the research paper.
20
+
21
+ Guidelines:
22
+ - The first summary should be long (4-5 sentences, ~100 words) yet focus on general information about the research paper, including its broad topic and objectives, without going into detail.
23
+ - Avoid using verbose language and fillers (e.g., 'This research paper discusses') to reach the word count.
24
+ - Strive for efficiency in word use: rewrite the previous summary to improve readability and make space for additional entities.
25
+ - Employ strategies such as fusion (combining entities), compression (shortening descriptions), and removal of uninformative phrases to make space for new entities.
26
+ - The summaries should evolve to be highly dense and concise yet remain self-contained, meaning they can be understood without reading the full paper.
27
+ - Missing entities should be integrated seamlessly into the new summary.
28
+ - Never omit entities from previous summaries. If space is a challenge, incorporate fewer new entities but maintain the same word count.
29
+
30
+ Remember, use the exact same number of words for each summary.
31
+
32
+ The JSON output should be a list (length 10) of dictionaries. Each dictionary must have two keys: 'missing_entities', listing the 1-3 entities added in each round; and 'denser_summary', presenting the new summary that integrates these entities without increasing the length.
33
+ """
34
+
35
+ chain_of_density_output_parser = JsonOutputParser()
36
+ chain_of_density_prompt = PromptTemplate(
37
+ template=chain_of_density_prompt_template,
38
+ input_variables=["paper"],
39
+ )
40
+ chain_of_density_chain = (
41
+ lambda model: chain_of_density_prompt | model | chain_of_density_output_parser
42
+ )
chat_chains.py CHANGED
@@ -1,22 +1,8 @@
1
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
2
- from langchain_core.output_parsers import StrOutputParser
3
  from langchain_core.runnables import RunnablePassthrough
4
  import xml.etree.ElementTree as ET
5
  import re
6
 
7
- contextualize_q_system_prompt = """Given a chat history and the latest user question \
8
- which might reference context in the chat history, formulate a standalone question \
9
- which can be understood without the chat history. Do NOT answer the question, \
10
- just reformulate it if needed and otherwise return it as is."""
11
- contextualize_q_prompt = ChatPromptTemplate.from_messages(
12
- [
13
- ("system", contextualize_q_system_prompt),
14
- MessagesPlaceholder(variable_name="chat_history"),
15
- ("human", "{question}"),
16
- ]
17
- )
18
- contextualize_q_chain = lambda llm: contextualize_q_prompt | llm | StrOutputParser()
19
-
20
  qa_system_prompt = """As Zeta, your mission is to assist users in navigating the vast sea of machine learning research with ease and insight. When responding to inquiries, adhere to the following guidelines to ensure the utmost accuracy and utility:
21
 
22
  Contextual Understanding: When presented with a question, apply your understanding of machine learning concepts to interpret the context provided accurately. Utilize this context to guide your search for answers within the specified research papers.
@@ -46,7 +32,7 @@ By following these guidelines, you ensure that users receive valuable, accurate,
46
  qa_prompt = ChatPromptTemplate.from_messages(
47
  [
48
  ("system", qa_system_prompt),
49
- MessagesPlaceholder(variable_name="chat_history"),
50
  ("human", "{question}"),
51
  ]
52
  )
@@ -54,21 +40,19 @@ qa_prompt = ChatPromptTemplate.from_messages(
54
 
55
  def format_docs(docs):
56
  return "\n\n".join(
57
- f"{doc.metadata['chunk_id']}: {doc.page_content}" for doc in docs
 
58
  )
59
 
60
 
61
- def contextualized_question(input: dict):
62
- if input.get("chat_history"):
63
- return contextualize_q_chain
64
- else:
65
- return input["question"]
66
-
67
-
68
  rag_chain = lambda retriever, llm: (
69
- RunnablePassthrough.assign(
70
- context=contextualized_question | retriever | format_docs
71
- )
 
 
 
 
72
  | qa_prompt
73
  | llm
74
  )
@@ -105,3 +89,25 @@ def parse_model_response(input_string):
105
  parsed_data["answer"] = "".join(outside_text_parts)
106
 
107
  return parsed_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate
 
2
  from langchain_core.runnables import RunnablePassthrough
3
  import xml.etree.ElementTree as ET
4
  import re
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  qa_system_prompt = """As Zeta, your mission is to assist users in navigating the vast sea of machine learning research with ease and insight. When responding to inquiries, adhere to the following guidelines to ensure the utmost accuracy and utility:
7
 
8
  Contextual Understanding: When presented with a question, apply your understanding of machine learning concepts to interpret the context provided accurately. Utilize this context to guide your search for answers within the specified research papers.
 
32
  qa_prompt = ChatPromptTemplate.from_messages(
33
  [
34
  ("system", qa_system_prompt),
35
+ # MessagesPlaceholder(variable_name="chat_history"),
36
  ("human", "{question}"),
37
  ]
38
  )
 
40
 
41
  def format_docs(docs):
42
  return "\n\n".join(
43
+ f"{doc.metadata['chunk_id']}: {doc.page_content}" if type(doc) != str else doc
44
+ for doc in docs
45
  )
46
 
47
 
 
 
 
 
 
 
 
48
  rag_chain = lambda retriever, llm: (
49
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
50
+ | qa_prompt
51
+ | llm
52
+ )
53
+
54
+ qa_chain = lambda llm: (
55
+ {"context": RunnablePassthrough(), "question": RunnablePassthrough()}
56
  | qa_prompt
57
  | llm
58
  )
 
89
  parsed_data["answer"] = "".join(outside_text_parts)
90
 
91
  return parsed_data
92
+
93
+
94
+ def parse_context_and_question(inputs):
95
+ pattern = r"\[(.*?)\]"
96
+ match = re.search(pattern, inputs)
97
+ if match:
98
+ context = match.group(1)
99
+ context = [c.strip() for c in context.split()]
100
+ question = inputs[: match.start()] + inputs[match.end() :]
101
+ return context, question
102
+ else:
103
+ return "", inputs
104
+
105
+
106
+ format_citations = lambda citations: "\n\n".join(
107
+ [f"{citation['quote']} ... [{citation['source_id']}]" for citation in citations]
108
+ )
109
+ ai_response_format = lambda message, references: (
110
+ f"{message}\n\n---\n\n{format_citations(references)}"
111
+ if references != ""
112
+ else message
113
+ )
command_center.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  class CommandCenter:
2
  def __init__(self, default_input_type, default_function=None, all_commands=None):
3
  self.commands = {}
@@ -20,6 +23,9 @@ class CommandCenter:
20
  command = inputs[0]
21
  argument = inputs[1:]
22
 
 
 
 
23
  # type casting the arguments
24
  if self.commands[command]["input_type"] == str:
25
  argument = " ".join(argument)
 
1
+ from custom_exceptions import InvalidCommandError
2
+
3
+
4
  class CommandCenter:
5
  def __init__(self, default_input_type, default_function=None, all_commands=None):
6
  self.commands = {}
 
23
  command = inputs[0]
24
  argument = inputs[1:]
25
 
26
+ if command not in self.commands:
27
+ raise InvalidCommandError("Invalid command")
28
+
29
  # type casting the arguments
30
  if self.commands[command]["input_type"] == str:
31
  argument = " ".join(argument)
custom_exceptions.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ class InvalidCommandError(Exception):
2
+ pass
3
+
4
+
5
+ class InvalidArgumentError(Exception):
6
+ pass
process_documents.py CHANGED
@@ -10,14 +10,25 @@ deep_strip = lambda text: re.sub(r"\s+", " ", text or "").strip()
10
 
11
  def process_documents(urls):
12
  snippets = []
13
- documents = []
14
  for source_id, url in enumerate(urls):
15
- if url.endswith(".pdf"):
16
- snippets.extend(process_pdf(url, source_id))
17
- documents.append("\n".join([snip.page_content for snip in snippets]))
18
- else:
19
- snippets.extend(process_web(url, source_id))
20
- documents.append("\n".join([snip.page_content for snip in snippets]))
 
 
 
 
 
 
 
 
 
 
 
21
  return snippets, documents
22
 
23
 
@@ -30,7 +41,7 @@ def process_web(url, source_id):
30
  "header": data.metadata["title"],
31
  "source_url": url,
32
  "source_type": "web",
33
- "chunk_id": f"{source_id}_0",
34
  "source_id": source_id,
35
  },
36
  )
@@ -54,7 +65,7 @@ def process_pdf(url, source_id):
54
  "header": " ".join(snip[1]["header_text"].split()[:10]),
55
  "source_url": url,
56
  "source_type": "pdf",
57
- "chunk_id": f"{source_id}_{i}",
58
  "source_id": source_id,
59
  },
60
  )
 
10
 
11
  def process_documents(urls):
12
  snippets = []
13
+ documents = {}
14
  for source_id, url in enumerate(urls):
15
+ snippet = (
16
+ process_pdf(url, source_id)
17
+ if url.endswith(".pdf")
18
+ else process_web(url, source_id)
19
+ )
20
+ snippets.extend(snippet)
21
+ documents[str(source_id)] = Document(
22
+ page_content="\n".join([snip.page_content for snip in snippet]),
23
+ metadata={
24
+ "source_url": url,
25
+ "source_type": "pdf" if url.endswith(".pdf") else "web",
26
+ "source_id": source_id,
27
+ "chunk_id": source_id,
28
+ },
29
+ )
30
+ for snip in snippet:
31
+ documents[snip.metadata["chunk_id"]] = snip
32
  return snippets, documents
33
 
34
 
 
41
  "header": data.metadata["title"],
42
  "source_url": url,
43
  "source_type": "web",
44
+ "chunk_id": source_id,
45
  "source_id": source_id,
46
  },
47
  )
 
65
  "header": " ".join(snip[1]["header_text"].split()[:10]),
66
  "source_url": url,
67
  "source_type": "pdf",
68
+ "chunk_id": f"{source_id}_{i:02d}",
69
  "source_id": source_id,
70
  },
71
  )