Ritvik19 commited on
Commit
827774a
·
verified ·
1 Parent(s): a6f365e
Files changed (2) hide show
  1. app.py +43 -18
  2. chains.py +107 -0
app.py CHANGED
@@ -6,22 +6,34 @@ from process_documents import process_documents
6
  from embed_documents import create_retriever
7
  import json
8
  from langchain.callbacks import get_openai_callback
9
- from langchain.chains import ConversationalRetrievalChain
10
  from langchain_openai import ChatOpenAI
11
  import base64
 
 
12
 
13
  st.set_page_config(layout="wide")
14
  os.environ["OPENAI_API_KEY"] = "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"
15
 
16
- get_references = lambda relevant_docs: " ".join(
17
- [f"[{ref}]" for ref in sorted([ref.metadata["chunk_id"] for ref in relevant_docs])]
18
  )
19
- session_state_2_llm_chat_history = lambda session_state: [
20
- ss[:2] for ss in session_state if not ss[0].startswith("/")
21
- ]
 
 
 
 
 
 
 
 
22
  ai_message_format = lambda message, references: (
23
- f"{message}\n\n---\n\n{references}" if references != "" else message
 
 
24
  )
 
25
  welcome_message = """
26
  Hi I'm Agent Zeta, your AI assistant, dedicated to making your journey through machine learning research papers as insightful and interactive as possible. Whether you're diving into the latest studies or brushing up on foundational papers, I'm here to help navigate, discuss, and analyze content with you.
27
 
@@ -108,26 +120,39 @@ def download_conversation_wrapper(inputs=None):
108
 
109
  def query_llm_wrapper(inputs):
110
  retriever = st.session_state.retriever
111
- qa_chain = ConversationalRetrievalChain.from_llm(
112
- llm=ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
113
- retriever=retriever,
114
- return_source_documents=True,
115
- chain_type="stuff",
116
  )
117
  relevant_docs = retriever.get_relevant_documents(inputs)
118
  with get_openai_callback() as cb:
119
- result = qa_chain(
120
  {
121
  "question": inputs,
122
  "chat_history": session_state_2_llm_chat_history(
123
  st.session_state.messages
124
  ),
125
  }
126
- )
127
  stats = cb
128
- result = result["answer"]
129
- references = get_references(relevant_docs)
130
- st.session_state.messages.append((inputs, result, references))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  st.session_state.costing.append(
132
  {
133
  "prompt tokens": stats.prompt_tokens,
@@ -135,7 +160,7 @@ def query_llm_wrapper(inputs):
135
  "cost": stats.total_cost,
136
  }
137
  )
138
- return result, references
139
 
140
 
141
  def boot(command_center):
 
6
  from embed_documents import create_retriever
7
  import json
8
  from langchain.callbacks import get_openai_callback
 
9
  from langchain_openai import ChatOpenAI
10
  import base64
11
+ from chains import rag_chain, parse_model_response
12
+ from langchain_core.messages import AIMessage, HumanMessage
13
 
14
  st.set_page_config(layout="wide")
15
  os.environ["OPENAI_API_KEY"] = "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"
16
 
17
+ format_citations = lambda citations: "\n\n".join(
18
+ [f"{citation['quote']} ... [{citation['source_id']}]" for citation in citations]
19
  )
20
+
21
+
22
+ def session_state_2_llm_chat_history(session_state):
23
+ chat_history = []
24
+ for ss in session_state:
25
+ if not ss[0].startswith("/"):
26
+ chat_history.append(HumanMessage(content=ss[0]))
27
+ chat_history.append(AIMessage(content=ss[1]))
28
+ return chat_history
29
+
30
+
31
  ai_message_format = lambda message, references: (
32
+ f"{message}\n\n---\n\n{format_citations(references)}"
33
+ if references != ""
34
+ else message
35
  )
36
+
37
  welcome_message = """
38
  Hi I'm Agent Zeta, your AI assistant, dedicated to making your journey through machine learning research papers as insightful and interactive as possible. Whether you're diving into the latest studies or brushing up on foundational papers, I'm here to help navigate, discuss, and analyze content with you.
39
 
 
120
 
121
  def query_llm_wrapper(inputs):
122
  retriever = st.session_state.retriever
123
+ qa_chain = rag_chain(
124
+ retriever, ChatOpenAI(model="gpt-4-0125-preview", temperature=0)
 
 
 
125
  )
126
  relevant_docs = retriever.get_relevant_documents(inputs)
127
  with get_openai_callback() as cb:
128
+ response = qa_chain.invoke(
129
  {
130
  "question": inputs,
131
  "chat_history": session_state_2_llm_chat_history(
132
  st.session_state.messages
133
  ),
134
  }
135
+ ).content
136
  stats = cb
137
+ response = parse_model_response(response)
138
+ answer = response["answer"]
139
+ citations = response["citations"]
140
+ citations.append(
141
+ {
142
+ "source_id": " ".join(
143
+ [
144
+ f"[{ref}]"
145
+ for ref in sorted(
146
+ [ref.metadata["chunk_id"] for ref in relevant_docs],
147
+ key=lambda x: int(x.split("_")[1]),
148
+ )
149
+ ]
150
+ ),
151
+ "quote": "other sources",
152
+ }
153
+ )
154
+
155
+ st.session_state.messages.append((inputs, answer, citations))
156
  st.session_state.costing.append(
157
  {
158
  "prompt tokens": stats.prompt_tokens,
 
160
  "cost": stats.total_cost,
161
  }
162
  )
163
+ return answer, citations
164
 
165
 
166
  def boot(command_center):
chains.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
2
+ from langchain_core.output_parsers import StrOutputParser
3
+ from langchain_core.runnables import RunnablePassthrough
4
+ import xml.etree.ElementTree as ET
5
+ import re
6
+
7
+ contextualize_q_system_prompt = """Given a chat history and the latest user question \
8
+ which might reference context in the chat history, formulate a standalone question \
9
+ which can be understood without the chat history. Do NOT answer the question, \
10
+ just reformulate it if needed and otherwise return it as is."""
11
+ contextualize_q_prompt = ChatPromptTemplate.from_messages(
12
+ [
13
+ ("system", contextualize_q_system_prompt),
14
+ MessagesPlaceholder(variable_name="chat_history"),
15
+ ("human", "{question}"),
16
+ ]
17
+ )
18
+ contextualize_q_chain = lambda llm: contextualize_q_prompt | llm | StrOutputParser()
19
+
20
+ qa_system_prompt = """As Zeta, your mission is to assist users in navigating the vast sea of machine learning research with ease and insight. When responding to inquiries, adhere to the following guidelines to ensure the utmost accuracy and utility:
21
+
22
+ Contextual Understanding: When presented with a question, apply your understanding of machine learning concepts to interpret the context provided accurately. Utilize this context to guide your search for answers within the specified research papers.
23
+
24
+ Answer Provision: Always provide an answer that is directly supported by the research papers' content. If the information needed to answer the question is not available, clearly state, "I don't know."
25
+
26
+ Citation Requirement: For every answer given, include multiple citations from the research papers. A citation must include a direct quote from the paper that supports your answer, along with the identification (ID) of the paper. This ensures that all provided information can be traced back to its source, maintaining a high level of credibility and transparency.
27
+
28
+ Formatting Guidelines: Present your citations in the following structured format at the end of your answer to maintain clarity and consistency:
29
+
30
+
31
+ <citations>
32
+ <citation><source_id>[Source ID]</source_id><quote>[Direct quote from the source]</quote></citation>
33
+ ...
34
+ </citations>
35
+
36
+
37
+ Conflict Resolution: In cases where multiple sources offer conflicting information, evaluate the context, relevance, and credibility of each source to determine the most accurate answer. Explain your reasoning within the citation section to provide insight into your decision-making process.
38
+
39
+ User Engagement: Encourage user engagement by asking clarifying questions if the initial inquiry is ambiguous or lacks specific context. This helps in providing more targeted and relevant responses.
40
+
41
+ Continual Learning: Although you are not expected to generate new text or insights beyond the provided papers, be open to learning from new information as it becomes available to you through user interactions and queries.
42
+
43
+ By following these guidelines, you ensure that users receive valuable, accurate, and source-backed insights into their inquiries, making their exploration of machine learning research more productive and enlightening.
44
+
45
+ {context}"""
46
+ qa_prompt = ChatPromptTemplate.from_messages(
47
+ [
48
+ ("system", qa_system_prompt),
49
+ MessagesPlaceholder(variable_name="chat_history"),
50
+ ("human", "{question}"),
51
+ ]
52
+ )
53
+
54
+
55
+ def format_docs(docs):
56
+ return "\n\n".join(
57
+ f"{doc.metadata['chunk_id']}: {doc.page_content}" for doc in docs
58
+ )
59
+
60
+
61
+ def contextualized_question(input: dict):
62
+ if input.get("chat_history"):
63
+ return contextualize_q_chain
64
+ else:
65
+ return input["question"]
66
+
67
+
68
+ rag_chain = lambda retriever, llm: (
69
+ RunnablePassthrough.assign(
70
+ context=contextualized_question | retriever | format_docs
71
+ )
72
+ | qa_prompt
73
+ | llm
74
+ )
75
+
76
+
77
+ def parse_model_response(input_string):
78
+ parsed_data = {"answer": "", "citations": []}
79
+ xml_matches = re.findall(r"<citations>.*?</citations>", input_string, re.DOTALL)
80
+ if not xml_matches:
81
+ parsed_data["answer"] = input_string
82
+ return parsed_data
83
+
84
+ outside_text_parts = []
85
+ last_end_pos = 0
86
+
87
+ for xml_string in xml_matches:
88
+ match = re.search(re.escape(xml_string), input_string[last_end_pos:], re.DOTALL)
89
+
90
+ if match:
91
+ outside_text_parts.append(
92
+ input_string[last_end_pos : match.start() + last_end_pos]
93
+ )
94
+ last_end_pos += match.end()
95
+
96
+ root = ET.fromstring(xml_string)
97
+
98
+ for citation in root.findall("citation"):
99
+ source_id = citation.find("source_id").text
100
+ quote = citation.find("quote").text
101
+ parsed_data["citations"].append({"source_id": source_id, "quote": quote})
102
+
103
+ outside_text_parts.append(input_string[last_end_pos:])
104
+
105
+ parsed_data["answer"] = "".join(outside_text_parts)
106
+
107
+ return parsed_data