Fixed bugs with multi LLMs
Files changed:
- app.py (+7, -57)
- climateqa/engine/rag.py (+8, -17)
- climateqa/engine/utils.py (+23, -6)
app.py
CHANGED
@@ -146,88 +146,38 @@ async def chat(query,history,audience,sources,reports):
     if len(reports) == 0:
         reports = []

-
     retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,min_size = 200,reports = reports,k_summary = 3,k_total = 15,threshold=0.5)
     rag_chain = make_rag_chain(retriever,llm)
-
-    # gradio_format = make_pairs([a.content for a in history]) + [(query, "")]
-    # history = history + [(query,"")]
-    # print(history)
-    # print(gradio_format)
-
-    # # reset memory
-    # memory.clear()
-    # for message in history:
-    #     memory.chat_memory.add_message(message)

     inputs = {"query": query,"audience": audience_prompt}
     result = rag_chain.astream_log(inputs) #{"callbacks":[MyCustomAsyncHandler()]})
     # result = rag_chain.stream(inputs)

-
-
-
-    final_output_path_id = "/streamed_output/-"
+    path_reformulation = "/logs/reformulation/final_output"
+    path_retriever = "/logs/find_documents/final_output"
+    path_answer = "/logs/answer/streamed_output_str/-"

     docs_html = ""
     output_query = ""
     output_language = ""
     gallery = []

-    # for output in result:
-
-    #     if "language" in output:
-    #         output_language = output["language"]
-    #     if "question" in output:
-    #         output_query = output["question"]
-    #     if "docs" in output:
-
-    #         try:
-    #             docs = output['docs'] # List[Document]
-    #             docs_html = []
-    #             for i, d in enumerate(docs, 1):
-    #                 docs_html.append(make_html_source(d, i))
-    #             docs_html = "".join(docs_html)
-    #         except TypeError:
-    #             print("No documents found")
-    #             continue
-
-    #     if "answer" in output:
-    #         new_token = output["answer"] # str
-    #         time.sleep(0.03)
-    #         answer_yet = history[-1][1] + new_token
-    #         answer_yet = parse_output_llm_with_sources(answer_yet)
-    #         history[-1] = (query,answer_yet)
-
-    #     yield history,docs_html,output_query,output_language,gallery
-
-
-    # async def fallback_iterator(iterable):
-    #     async for item in iterable:
-    #         try:
-    #             yield item
-    #         except Exception as e:
-    #             print(f"Error in fallback iterator: {e}")
-    #             raise gr.Error(f"ClimateQ&A Error: {e}\nThe error has been noted, try another question and if the error remains, you can contact us :)")
-
     try:
         async for op in result:
-
             op = op.ops[0]
             # print("ITERATION",op)

-            if op['path'] ==
+            if op['path'] == path_reformulation: # reformulated question
                 try:
                     output_language = op['value']["language"] # str
                     output_query = op["value"]["question"]
                 except Exception as e:
                     raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)")

-            elif op['path'] ==
+            elif op['path'] == path_retriever: # documents
                 try:
-                    docs = op['value']['
+                    docs = op['value']['docs'] # List[Document]
                     docs_html = []
                     for i, d in enumerate(docs, 1):
                         docs_html.append(make_html_source(d, i))
@@ -237,7 +187,7 @@ async def chat(query,history,audience,sources,reports):
                     print("op: ",op)
                     continue

-            elif op['path'] ==
+            elif op['path'] == path_answer: # final answer
                 new_token = op['value'] # str
                 time.sleep(0.01)
                 answer_yet = history[-1][1] + new_token
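Note: the old handler keyed everything off a single "/streamed_output/-" path, so intermediate results from the different chain steps could not be told apart; the new code matches the named log paths that astream_log emits for the reformulation, find_documents and answer sub-runs. A rough, self-contained sketch of that mechanism (toy lambdas stand in for the real chain, and it assumes a recent langchain_core; this is not the app's actual code):

import asyncio
from langchain_core.runnables import RunnableLambda

# Naming a sub-run is what creates its "/logs/<run_name>/..." entries.
reformulate = RunnableLambda(lambda q: {"question": q, "language": "English"}) \
    .with_config({"run_name": "reformulation"})
answer = RunnableLambda(lambda d: f"Answer to: {d['question']}") \
    .with_config({"run_name": "answer"})
chain = reformulate | answer

async def main():
    async for patch in chain.astream_log("why is the ocean warming?"):
        for op in patch.ops:  # each op is a JSONPatch dict: {"op", "path", "value"}
            if op["path"].endswith("/final_output"):
                print(op["path"], "->", op["value"])

asyncio.run(main())

In the app the comparison is exact rather than a suffix match, and "/logs/answer/streamed_output_str/-" delivers the answer token by token rather than as one final output.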
climateqa/engine/rag.py
CHANGED
@@ -8,8 +8,7 @@ from langchain_core.prompts.base import format_document

 from climateqa.engine.reformulation import make_reformulation_chain
 from climateqa.engine.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
-from climateqa.engine.utils import pass_values, flatten_dict
-
+from climateqa.engine.utils import pass_values, flatten_dict,prepare_chain,rename_chain

 DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

@@ -44,21 +43,13 @@ def make_rag_chain(retriever,llm):
     prompt_without_docs = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)

     # ------- CHAIN 0 - Reformulation
-
-    reformulation = (
-        {"reformulation":reformulation_chain,**pass_values(["audience","query"])}
-        | RunnablePassthrough()
-        | flatten_dict
-    )
-
+    reformulation = make_reformulation_chain(llm)
+    reformulation = prepare_chain(reformulation,"reformulation")

     # ------- CHAIN 1
     # Retrieved documents
-    find_documents =
-
-        **pass_values(["question","audience","language","query"])
-    } | RunnablePassthrough()
-
+    find_documents = {"docs": itemgetter("question") | retriever} | RunnablePassthrough()
+    find_documents = prepare_chain(find_documents,"find_documents")

     # ------- CHAIN 2
     # Construct inputs for the llm

@@ -69,15 +60,15 @@ def make_rag_chain(retriever,llm):

     # ------- CHAIN 3
     # Bot answer
-
+    llm_final = rename_chain(llm,"answer")

     answer_with_docs = {
-        "answer": input_documents | prompt |
+        "answer": input_documents | prompt | llm_final | StrOutputParser(),
         **pass_values(["question","audience","language","query","docs"]),
     }

     answer_without_docs = {
-        "answer": prompt_without_docs |
+        "answer": prompt_without_docs | llm_final | StrOutputParser(),
         **pass_values(["question","audience","language","query","docs"]),
     }
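Note: two things happen to each step here. rename_chain pins the run name, which is what the "/logs/find_documents/final_output" path in app.py relies on, and propagate_inputs merges the step's output back with its inputs so keys like "question" and "audience" keep flowing downstream. A minimal sketch of the retrieval step, with a hypothetical toy_retriever in place of the ClimateQA vectorstore retriever:

from operator import itemgetter
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# Hypothetical stand-in for the vectorstore retriever.
toy_retriever = RunnableLambda(lambda q: [f"<doc about {q}>"])

find_documents = {"docs": itemgetter("question") | toy_retriever} | RunnablePassthrough()
find_documents = find_documents.with_config({"run_name": "find_documents"})  # what rename_chain does

print(find_documents.invoke({"question": "sea level rise"}))
# {'docs': ['<doc about sea level rise>']}

On its own the mapping step drops the other input keys ("question" is gone from the output above), which is exactly what propagate_inputs in utils.py compensates for.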
climateqa/engine/utils.py
CHANGED
@@ -1,10 +1,29 @@
-
-from typing import Any, Dict, Iterable, Tuple, Union
 from operator import itemgetter
+from typing import Any, Dict, Iterable, Tuple
+from langchain_core.runnables import RunnablePassthrough
+

 def pass_values(x):
-    if not isinstance(x,list):
-
+    if not isinstance(x, list):
+        x = [x]
+    return {k: itemgetter(k) for k in x}
+
+
+def prepare_chain(chain,name):
+    chain = propagate_inputs(chain)
+    chain = rename_chain(chain,name)
+    return chain
+
+
+def propagate_inputs(chain):
+    chain_with_values = {
+        "outputs": chain,
+        "inputs": RunnablePassthrough()
+    } | RunnablePassthrough() | flatten_dict
+    return chain_with_values
+
+def rename_chain(chain,name):
+    return chain.with_config({"run_name":name})


 # Drawn from langchain utils and modified to remove the parent key

@@ -48,5 +67,3 @@ def flatten_dict(
     """
     flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
     return flat_dict
-
-
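Note: propagate_inputs nests a step's result under "outputs" and the untouched input under "inputs", then flatten_dict folds both back into one flat mapping. A quick check of that shape, with a simplified one-level flatten standing in for the module's flatten_dict (whose _flatten_dict helper is not shown here):

from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# Simplified flatten for one level of nesting; the real flatten_dict above is more general.
flatten = RunnableLambda(lambda d: {k: v for sub in d.values() for k, v in sub.items()})

step = RunnableLambda(lambda d: {"language": "English"})  # toy stand-in for a chain step
chain = {"outputs": step, "inputs": RunnablePassthrough()} | RunnablePassthrough() | flatten

print(chain.invoke({"query": "warming", "audience": "experts"}))
# {'language': 'English', 'query': 'warming', 'audience': 'experts'}

Because the result is flattened, later steps and the log consumers in app.py see both the new keys (such as "language") and the original ones (such as "query") in the same value.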