Merge branch 'bugfix/add_dummy_searchs' into feature/graph_recommandation
Browse files- app.py +2 -2
- climateqa/engine/chains/keywords_extraction.py +40 -0
- climateqa/engine/chains/query_transformation.py +45 -4
- climateqa/engine/chains/retrieve_documents.py +159 -0
- climateqa/engine/graph.py +22 -10
- climateqa/engine/llm/__init__.py +3 -0
- climateqa/engine/llm/ollama.py +6 -0
- climateqa/engine/utils.py +17 -0
- climateqa/knowledge/__init__.py +0 -0
- climateqa/{papers → knowledge}/openalex.py +61 -12
- climateqa/{engine → knowledge}/retriever.py +1 -83
- climateqa/papers/__init__.py +0 -43
- requirements.txt +3 -1
- sandbox/20240310 - CQA - Semantic Routing 1.ipynb +0 -0
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from climateqa.engine.embeddings import get_embeddings_function
|
2 |
embeddings_function = get_embeddings_function()
|
3 |
|
4 |
-
from climateqa.
|
5 |
from sentence_transformers import CrossEncoder
|
6 |
|
7 |
# reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
|
@@ -31,7 +31,7 @@ from collections import defaultdict
|
|
31 |
# ClimateQ&A imports
|
32 |
from climateqa.engine.llm import get_llm
|
33 |
from climateqa.engine.vectorstore import get_pinecone_vectorstore
|
34 |
-
from climateqa.
|
35 |
from climateqa.engine.reranker import get_reranker
|
36 |
from climateqa.engine.embeddings import get_embeddings_function
|
37 |
from climateqa.engine.chains.prompts import audience_prompts
|
|
|
1 |
from climateqa.engine.embeddings import get_embeddings_function
|
2 |
embeddings_function = get_embeddings_function()
|
3 |
|
4 |
+
from climateqa.knowledge.openalex import OpenAlex
|
5 |
from sentence_transformers import CrossEncoder
|
6 |
|
7 |
# reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
|
|
|
31 |
# ClimateQ&A imports
|
32 |
from climateqa.engine.llm import get_llm
|
33 |
from climateqa.engine.vectorstore import get_pinecone_vectorstore
|
34 |
+
from climateqa.knowledge.retriever import ClimateQARetriever
|
35 |
from climateqa.engine.reranker import get_reranker
|
36 |
from climateqa.engine.embeddings import get_embeddings_function
|
37 |
from climateqa.engine.chains.prompts import audience_prompts
|
climateqa/engine/chains/keywords_extraction.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
3 |
+
from typing import List
|
4 |
+
from typing import Literal
|
5 |
+
from langchain.prompts import ChatPromptTemplate
|
6 |
+
from langchain_core.utils.function_calling import convert_to_openai_function
|
7 |
+
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
8 |
+
|
9 |
+
|
10 |
+
class KeywordExtraction(BaseModel):
|
11 |
+
"""
|
12 |
+
Analyzing the user query to extract keywords to feed a search engine
|
13 |
+
"""
|
14 |
+
|
15 |
+
keywords: List[str] = Field(
|
16 |
+
description="""
|
17 |
+
Extract the keywords from the user query to feed a search engine as a list
|
18 |
+
Avoid adding super specific keywords to prefer general keywords
|
19 |
+
Maximum 3 keywords
|
20 |
+
|
21 |
+
Examples:
|
22 |
+
- "What is the impact of deep sea mining ?" -> ["deep sea mining"]
|
23 |
+
- "How will El Nino be impacted by climate change" -> ["el nino","climate change"]
|
24 |
+
- "Is climate change a hoax" -> ["climate change","hoax"]
|
25 |
+
"""
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
def make_keywords_extraction_chain(llm):
|
30 |
+
|
31 |
+
openai_functions = [convert_to_openai_function(KeywordExtraction)]
|
32 |
+
llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"KeywordExtraction"})
|
33 |
+
|
34 |
+
prompt = ChatPromptTemplate.from_messages([
|
35 |
+
("system", "You are a helpful assistant"),
|
36 |
+
("user", "input: {input}")
|
37 |
+
])
|
38 |
+
|
39 |
+
chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
|
40 |
+
return chain
|
climateqa/engine/chains/query_transformation.py
CHANGED
@@ -8,6 +8,13 @@ from langchain_core.utils.function_calling import convert_to_openai_function
|
|
8 |
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
9 |
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
# Prompt from the original paper https://arxiv.org/pdf/2305.14283
|
12 |
# Query Rewriting for Retrieval-Augmented Large Language Models
|
13 |
class QueryDecomposition(BaseModel):
|
@@ -20,8 +27,8 @@ class QueryDecomposition(BaseModel):
|
|
20 |
description="""
|
21 |
Think step by step to answer this question, and provide one or several search engine questions in English for knowledge that you need.
|
22 |
Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
|
23 |
-
- If it's already a standalone
|
24 |
-
- If you need to decompose the question, output a list of maximum 3 questions
|
25 |
"""
|
26 |
)
|
27 |
|
@@ -125,12 +132,20 @@ def make_query_rewriter_chain(llm):
|
|
125 |
return chain
|
126 |
|
127 |
|
128 |
-
def make_query_transform_node(llm):
|
129 |
|
130 |
decomposition_chain = make_query_decomposition_chain(llm)
|
131 |
rewriter_chain = make_query_rewriter_chain(llm)
|
132 |
|
133 |
def transform_query(state):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
new_state = {}
|
136 |
|
@@ -150,7 +165,33 @@ def make_query_transform_node(llm):
|
|
150 |
|
151 |
question_state.update(analysis_output)
|
152 |
questions.append(question_state)
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
return new_state
|
156 |
|
|
|
8 |
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
9 |
|
10 |
|
11 |
+
ROUTING_INDEX = {
|
12 |
+
"Vector":["IPCC","IPBES","IPOS"],
|
13 |
+
"OpenAlex":["OpenAlex"],
|
14 |
+
}
|
15 |
+
|
16 |
+
POSSIBLE_SOURCES = [y for values in ROUTING_INDEX.values() for y in values]
|
17 |
+
|
18 |
# Prompt from the original paper https://arxiv.org/pdf/2305.14283
|
19 |
# Query Rewriting for Retrieval-Augmented Large Language Models
|
20 |
class QueryDecomposition(BaseModel):
|
|
|
27 |
description="""
|
28 |
Think step by step to answer this question, and provide one or several search engine questions in English for knowledge that you need.
|
29 |
Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
|
30 |
+
- If it's already a standalone and explicit question, just return the reformulated question for the search engine
|
31 |
+
- If you need to decompose the question, output a list of maximum 2 to 3 questions
|
32 |
"""
|
33 |
)
|
34 |
|
|
|
132 |
return chain
|
133 |
|
134 |
|
135 |
+
def make_query_transform_node(llm,k_final=15):
|
136 |
|
137 |
decomposition_chain = make_query_decomposition_chain(llm)
|
138 |
rewriter_chain = make_query_rewriter_chain(llm)
|
139 |
|
140 |
def transform_query(state):
|
141 |
+
|
142 |
+
if "sources_auto" not in state or state["sources_auto"] is None or state["sources_auto"] is False:
|
143 |
+
auto_mode = False
|
144 |
+
else:
|
145 |
+
auto_mode = True
|
146 |
+
|
147 |
+
sources_input = state.get("sources_input")
|
148 |
+
if sources_input is None: sources_input = ROUTING_INDEX["Vector"]
|
149 |
|
150 |
new_state = {}
|
151 |
|
|
|
165 |
|
166 |
question_state.update(analysis_output)
|
167 |
questions.append(question_state)
|
168 |
+
|
169 |
+
# Explode the questions into multiple questions with different sources
|
170 |
+
new_questions = []
|
171 |
+
for q in questions:
|
172 |
+
question,sources = q["question"],q["sources"]
|
173 |
+
|
174 |
+
# If not auto mode we take the configuration
|
175 |
+
if not auto_mode:
|
176 |
+
sources = sources_input
|
177 |
+
|
178 |
+
for index,index_sources in ROUTING_INDEX.items():
|
179 |
+
selected_sources = list(set(sources).intersection(index_sources))
|
180 |
+
if len(selected_sources) > 0:
|
181 |
+
new_questions.append({"question":question,"sources":selected_sources,"index":index})
|
182 |
+
|
183 |
+
# # Add the number of questions to search
|
184 |
+
# k_by_question = k_final // len(new_questions)
|
185 |
+
# for q in new_questions:
|
186 |
+
# q["k"] = k_by_question
|
187 |
+
|
188 |
+
# new_state["questions"] = new_questions
|
189 |
+
# new_state["remaining_questions"] = new_questions
|
190 |
+
|
191 |
+
new_state = {
|
192 |
+
"remaining_questions":new_questions,
|
193 |
+
"n_questions":len(new_questions),
|
194 |
+
}
|
195 |
|
196 |
return new_state
|
197 |
|
climateqa/engine/chains/retrieve_documents.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
from contextlib import contextmanager
|
4 |
+
|
5 |
+
from langchain_core.tools import tool
|
6 |
+
from langchain_core.runnables import chain
|
7 |
+
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
8 |
+
from langchain_core.runnables import RunnableLambda
|
9 |
+
|
10 |
+
from ..reranker import rerank_docs
|
11 |
+
from ...knowledge.retriever import ClimateQARetriever
|
12 |
+
from ...knowledge.openalex import OpenAlexRetriever
|
13 |
+
from .keywords_extraction import make_keywords_extraction_chain
|
14 |
+
from ..utils import log_event
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
def divide_into_parts(target, parts):
|
19 |
+
# Base value for each part
|
20 |
+
base = target // parts
|
21 |
+
# Remainder to distribute
|
22 |
+
remainder = target % parts
|
23 |
+
# List to hold the result
|
24 |
+
result = []
|
25 |
+
|
26 |
+
for i in range(parts):
|
27 |
+
if i < remainder:
|
28 |
+
# These parts get base value + 1
|
29 |
+
result.append(base + 1)
|
30 |
+
else:
|
31 |
+
# The rest get the base value
|
32 |
+
result.append(base)
|
33 |
+
|
34 |
+
return result
|
35 |
+
|
36 |
+
|
37 |
+
@contextmanager
|
38 |
+
def suppress_output():
|
39 |
+
# Open a null device
|
40 |
+
with open(os.devnull, 'w') as devnull:
|
41 |
+
# Store the original stdout and stderr
|
42 |
+
old_stdout = sys.stdout
|
43 |
+
old_stderr = sys.stderr
|
44 |
+
# Redirect stdout and stderr to the null device
|
45 |
+
sys.stdout = devnull
|
46 |
+
sys.stderr = devnull
|
47 |
+
try:
|
48 |
+
yield
|
49 |
+
finally:
|
50 |
+
# Restore stdout and stderr
|
51 |
+
sys.stdout = old_stdout
|
52 |
+
sys.stderr = old_stderr
|
53 |
+
|
54 |
+
|
55 |
+
@tool
|
56 |
+
def query_retriever(question):
|
57 |
+
"""Just a dummy tool to simulate the retriever query"""
|
58 |
+
return question
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
67 |
+
|
68 |
+
# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
|
69 |
+
@chain
|
70 |
+
async def retrieve_documents(state,config):
|
71 |
+
|
72 |
+
keywords_extraction = make_keywords_extraction_chain(llm)
|
73 |
+
|
74 |
+
current_question = state["remaining_questions"][0]
|
75 |
+
remaining_questions = state["remaining_questions"][1:]
|
76 |
+
|
77 |
+
# ToolMessage(f"Retrieving documents for question: {current_question['question']}",tool_call_id = "retriever")
|
78 |
+
|
79 |
+
|
80 |
+
# # There are several options to get the final top k
|
81 |
+
# # Option 1 - Get 100 documents by question and rerank by question
|
82 |
+
# # Option 2 - Get 100/n documents by question and rerank the total
|
83 |
+
# if rerank_by_question:
|
84 |
+
# k_by_question = divide_into_parts(k_final,len(questions))
|
85 |
+
|
86 |
+
# docs = state["documents"]
|
87 |
+
# if docs is None: docs = []
|
88 |
+
|
89 |
+
docs = []
|
90 |
+
k_by_question = k_final // state["n_questions"]
|
91 |
+
|
92 |
+
sources = current_question["sources"]
|
93 |
+
question = current_question["question"]
|
94 |
+
index = current_question["index"]
|
95 |
+
|
96 |
+
|
97 |
+
await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
|
98 |
+
|
99 |
+
|
100 |
+
if index == "Vector":
|
101 |
+
|
102 |
+
# Search the document store using the retriever
|
103 |
+
# Configure high top k for further reranking step
|
104 |
+
retriever = ClimateQARetriever(
|
105 |
+
vectorstore=vectorstore,
|
106 |
+
sources = sources,
|
107 |
+
min_size = 200,
|
108 |
+
k_summary = k_summary,
|
109 |
+
k_total = k_before_reranking,
|
110 |
+
threshold = 0.5,
|
111 |
+
)
|
112 |
+
docs_question = await retriever.ainvoke(question,config)
|
113 |
+
|
114 |
+
elif index == "OpenAlex":
|
115 |
+
|
116 |
+
keywords = keywords_extraction.invoke(question)["keywords"]
|
117 |
+
openalex_query = " AND ".join(keywords)
|
118 |
+
|
119 |
+
print(f"... OpenAlex query: {openalex_query}")
|
120 |
+
|
121 |
+
retriever_openalex = OpenAlexRetriever(
|
122 |
+
min_year = state.get("min_year",1960),
|
123 |
+
max_year = state.get("max_year",None),
|
124 |
+
k = k_before_reranking
|
125 |
+
)
|
126 |
+
docs_question = await retriever_openalex.ainvoke(openalex_query,config)
|
127 |
+
|
128 |
+
else:
|
129 |
+
raise Exception(f"Index {index} not found in the routing index")
|
130 |
+
|
131 |
+
# Rerank
|
132 |
+
if reranker is not None:
|
133 |
+
with suppress_output():
|
134 |
+
docs_question = rerank_docs(reranker,docs_question,question)
|
135 |
+
else:
|
136 |
+
# Add a default reranking score
|
137 |
+
for doc in docs_question:
|
138 |
+
doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
|
139 |
+
|
140 |
+
# If rerank by question we select the top documents for each question
|
141 |
+
if rerank_by_question:
|
142 |
+
docs_question = docs_question[:k_by_question]
|
143 |
+
|
144 |
+
# Add sources used in the metadata
|
145 |
+
for doc in docs_question:
|
146 |
+
doc.metadata["sources_used"] = sources
|
147 |
+
doc.metadata["question_used"] = question
|
148 |
+
doc.metadata["index_used"] = index
|
149 |
+
|
150 |
+
# Add to the list of docs
|
151 |
+
docs.extend(docs_question)
|
152 |
+
|
153 |
+
# Sorting the list in descending order by rerank_score
|
154 |
+
docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
|
155 |
+
new_state = {"documents":docs,"remaining_questions":remaining_questions}
|
156 |
+
return new_state
|
157 |
+
|
158 |
+
return retrieve_documents
|
159 |
+
|
climateqa/engine/graph.py
CHANGED
@@ -16,10 +16,9 @@ from .chains.answer_ai_impact import make_ai_impact_node
|
|
16 |
from .chains.query_transformation import make_query_transform_node
|
17 |
from .chains.translation import make_translation_node
|
18 |
from .chains.intent_categorization import make_intent_categorization_node
|
19 |
-
from .chains.
|
20 |
from .chains.answer_rag import make_rag_node
|
21 |
|
22 |
-
|
23 |
class GraphState(TypedDict):
|
24 |
"""
|
25 |
Represents the state of our graph.
|
@@ -29,23 +28,30 @@ class GraphState(TypedDict):
|
|
29 |
intent : str
|
30 |
search_graphs_chitchat : bool
|
31 |
query: str
|
32 |
-
|
|
|
33 |
answer: str
|
34 |
audience: str = "experts"
|
35 |
-
sources_input: List[str] = ["
|
|
|
|
|
|
|
36 |
documents: List[Document]
|
37 |
recommended_content : List[Document]
|
38 |
# graphs_returned: Dict[str,str]
|
39 |
|
40 |
-
def search(state):
|
41 |
-
return
|
|
|
|
|
|
|
42 |
|
43 |
def route_intent(state):
|
44 |
intent = state["intent"]
|
45 |
if intent in ["chitchat","esg"]:
|
46 |
return "answer_chitchat"
|
47 |
-
elif intent == "ai_impact":
|
48 |
-
|
49 |
else:
|
50 |
# Search route
|
51 |
return "search"
|
@@ -95,6 +101,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
|
|
95 |
workflow.add_node("set_defaults", set_defaults)
|
96 |
workflow.add_node("categorize_intent", categorize_intent)
|
97 |
workflow.add_node("search", search)
|
|
|
98 |
workflow.add_node("transform_query", transform_query)
|
99 |
workflow.add_node("translate_query", translate_query)
|
100 |
# workflow.add_node("transform_query_ai", transform_query)
|
@@ -118,7 +125,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
|
|
118 |
workflow.add_conditional_edges(
|
119 |
"categorize_intent",
|
120 |
route_intent,
|
121 |
-
make_id_dict(["answer_chitchat","
|
122 |
)
|
123 |
|
124 |
workflow.add_conditional_edges(
|
@@ -132,9 +139,14 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
|
|
132 |
route_translation,
|
133 |
make_id_dict(["translate_query","transform_query"])
|
134 |
)
|
135 |
-
|
136 |
workflow.add_conditional_edges(
|
137 |
"retrieve_documents",
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
|
139 |
make_id_dict(["answer_rag","answer_rag_no_docs"])
|
140 |
)
|
|
|
16 |
from .chains.query_transformation import make_query_transform_node
|
17 |
from .chains.translation import make_translation_node
|
18 |
from .chains.intent_categorization import make_intent_categorization_node
|
19 |
+
from .chains.retrieve_documents import make_retriever_node
|
20 |
from .chains.answer_rag import make_rag_node
|
21 |
|
|
|
22 |
class GraphState(TypedDict):
|
23 |
"""
|
24 |
Represents the state of our graph.
|
|
|
28 |
intent : str
|
29 |
search_graphs_chitchat : bool
|
30 |
query: str
|
31 |
+
remaining_questions : List[dict]
|
32 |
+
n_questions : int
|
33 |
answer: str
|
34 |
audience: str = "experts"
|
35 |
+
sources_input: List[str] = ["IPCC","IPBES"]
|
36 |
+
sources_auto: bool = True
|
37 |
+
min_year: int = 1960
|
38 |
+
max_year: int = None
|
39 |
documents: List[Document]
|
40 |
recommended_content : List[Document]
|
41 |
# graphs_returned: Dict[str,str]
|
42 |
|
43 |
+
def search(state): #TODO
|
44 |
+
return state
|
45 |
+
|
46 |
+
def answer_search(state):#TODO
|
47 |
+
return state
|
48 |
|
49 |
def route_intent(state):
|
50 |
intent = state["intent"]
|
51 |
if intent in ["chitchat","esg"]:
|
52 |
return "answer_chitchat"
|
53 |
+
# elif intent == "ai_impact":
|
54 |
+
# return "answer_ai_impact"
|
55 |
else:
|
56 |
# Search route
|
57 |
return "search"
|
|
|
101 |
workflow.add_node("set_defaults", set_defaults)
|
102 |
workflow.add_node("categorize_intent", categorize_intent)
|
103 |
workflow.add_node("search", search)
|
104 |
+
workflow.add_node("answer_search", answer_search)
|
105 |
workflow.add_node("transform_query", transform_query)
|
106 |
workflow.add_node("translate_query", translate_query)
|
107 |
# workflow.add_node("transform_query_ai", transform_query)
|
|
|
125 |
workflow.add_conditional_edges(
|
126 |
"categorize_intent",
|
127 |
route_intent,
|
128 |
+
make_id_dict(["answer_chitchat","search"])
|
129 |
)
|
130 |
|
131 |
workflow.add_conditional_edges(
|
|
|
139 |
route_translation,
|
140 |
make_id_dict(["translate_query","transform_query"])
|
141 |
)
|
|
|
142 |
workflow.add_conditional_edges(
|
143 |
"retrieve_documents",
|
144 |
+
lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
|
145 |
+
make_id_dict(["retrieve_documents","answer_search"])
|
146 |
+
)
|
147 |
+
|
148 |
+
workflow.add_conditional_edges(
|
149 |
+
"answer_search",
|
150 |
lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
|
151 |
make_id_dict(["answer_rag","answer_rag_no_docs"])
|
152 |
)
|
climateqa/engine/llm/__init__.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from climateqa.engine.llm.openai import get_llm as get_openai_llm
|
2 |
from climateqa.engine.llm.azure import get_llm as get_azure_llm
|
|
|
3 |
|
4 |
|
5 |
def get_llm(provider="openai",**kwargs):
|
@@ -8,6 +9,8 @@ def get_llm(provider="openai",**kwargs):
|
|
8 |
return get_openai_llm(**kwargs)
|
9 |
elif provider == "azure":
|
10 |
return get_azure_llm(**kwargs)
|
|
|
|
|
11 |
else:
|
12 |
raise ValueError(f"Unknown provider: {provider}")
|
13 |
|
|
|
1 |
from climateqa.engine.llm.openai import get_llm as get_openai_llm
|
2 |
from climateqa.engine.llm.azure import get_llm as get_azure_llm
|
3 |
+
from climateqa.engine.llm.ollama import get_llm as get_ollama_llm
|
4 |
|
5 |
|
6 |
def get_llm(provider="openai",**kwargs):
|
|
|
9 |
return get_openai_llm(**kwargs)
|
10 |
elif provider == "azure":
|
11 |
return get_azure_llm(**kwargs)
|
12 |
+
elif provider == "ollama":
|
13 |
+
return get_ollama_llm(**kwargs)
|
14 |
else:
|
15 |
raise ValueError(f"Unknown provider: {provider}")
|
16 |
|
climateqa/engine/llm/ollama.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from langchain_community.llms import Ollama
|
4 |
+
|
5 |
+
def get_llm(model="llama3", **kwargs):
|
6 |
+
return Ollama(model=model, **kwargs)
|
climateqa/engine/utils.py
CHANGED
@@ -1,8 +1,15 @@
|
|
1 |
from operator import itemgetter
|
2 |
from typing import Any, Dict, Iterable, Tuple
|
|
|
3 |
from langchain_core.runnables import RunnablePassthrough
|
4 |
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
def pass_values(x):
|
7 |
if not isinstance(x, list):
|
8 |
x = [x]
|
@@ -67,3 +74,13 @@ def flatten_dict(
|
|
67 |
"""
|
68 |
flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
|
69 |
return flat_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from operator import itemgetter
|
2 |
from typing import Any, Dict, Iterable, Tuple
|
3 |
+
import tiktoken
|
4 |
from langchain_core.runnables import RunnablePassthrough
|
5 |
|
6 |
|
7 |
+
def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
|
8 |
+
encoding = tiktoken.get_encoding(encoding_name)
|
9 |
+
num_tokens = len(encoding.encode(string))
|
10 |
+
return num_tokens
|
11 |
+
|
12 |
+
|
13 |
def pass_values(x):
|
14 |
if not isinstance(x, list):
|
15 |
x = [x]
|
|
|
74 |
"""
|
75 |
flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
|
76 |
return flat_dict
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
async def log_event(info,name,config):
|
81 |
+
"""Helper function that will run a dummy chain with the given info
|
82 |
+
The astream_event function will catch this chain and stream the dict info to the logger
|
83 |
+
"""
|
84 |
+
|
85 |
+
chain = RunnablePassthrough().with_config(run_name=name)
|
86 |
+
_ = await chain.ainvoke(info,config)
|
climateqa/knowledge/__init__.py
ADDED
File without changes
|
climateqa/{papers → knowledge}/openalex.py
RENAMED
@@ -3,18 +3,32 @@ import networkx as nx
|
|
3 |
import matplotlib.pyplot as plt
|
4 |
from pyvis.network import Network
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
|
7 |
import pyalex
|
8 |
|
9 |
pyalex.config.email = "[email protected]"
|
10 |
|
|
|
|
|
|
|
|
|
11 |
class OpenAlex():
|
12 |
def __init__(self):
|
13 |
pass
|
14 |
|
15 |
|
16 |
-
|
17 |
-
def search(self,keywords,n_results = 100,after = None,before = None):
|
18 |
|
19 |
if isinstance(keywords,str):
|
20 |
works = Works().search(keywords)
|
@@ -27,18 +41,21 @@ class OpenAlex():
|
|
27 |
break
|
28 |
|
29 |
df_works = pd.DataFrame(page)
|
30 |
-
df_works
|
|
|
|
|
31 |
df_works["is_oa"] = df_works["open_access"].map(lambda x : x.get("is_oa",False))
|
32 |
df_works["pdf_url"] = df_works["primary_location"].map(lambda x : x.get("pdf_url",None))
|
33 |
-
df_works["
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
else:
|
36 |
-
|
37 |
-
for keyword in keywords:
|
38 |
-
df_keyword = self.search(keyword,n_results = n_results,after = after,before = before)
|
39 |
-
df_works.append(df_keyword)
|
40 |
-
df_works = pd.concat(df_works,ignore_index=True,axis = 0)
|
41 |
-
return df_works
|
42 |
|
43 |
|
44 |
def rerank(self,query,df,reranker):
|
@@ -139,4 +156,36 @@ class OpenAlex():
|
|
139 |
reconstructed[position] = token
|
140 |
|
141 |
# Join the tokens to form the reconstructed sentence(s)
|
142 |
-
return ' '.join(reconstructed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import matplotlib.pyplot as plt
|
4 |
from pyvis.network import Network
|
5 |
|
6 |
+
from langchain_core.retrievers import BaseRetriever
|
7 |
+
from langchain_core.vectorstores import VectorStoreRetriever
|
8 |
+
from langchain_core.documents.base import Document
|
9 |
+
from langchain_core.vectorstores import VectorStore
|
10 |
+
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
11 |
+
|
12 |
+
from ..engine.utils import num_tokens_from_string
|
13 |
+
|
14 |
+
from typing import List
|
15 |
+
from pydantic import Field
|
16 |
+
|
17 |
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
|
18 |
import pyalex
|
19 |
|
20 |
pyalex.config.email = "[email protected]"
|
21 |
|
22 |
+
|
23 |
+
def replace_nan_with_empty_dict(x):
|
24 |
+
return x if pd.notna(x) else {}
|
25 |
+
|
26 |
class OpenAlex():
|
27 |
def __init__(self):
|
28 |
pass
|
29 |
|
30 |
|
31 |
+
def search(self,keywords:str,n_results = 100,after = None,before = None):
|
|
|
32 |
|
33 |
if isinstance(keywords,str):
|
34 |
works = Works().search(keywords)
|
|
|
41 |
break
|
42 |
|
43 |
df_works = pd.DataFrame(page)
|
44 |
+
df_works = df_works.dropna(subset = ["title"])
|
45 |
+
df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
|
46 |
+
df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")
|
47 |
df_works["is_oa"] = df_works["open_access"].map(lambda x : x.get("is_oa",False))
|
48 |
df_works["pdf_url"] = df_works["primary_location"].map(lambda x : x.get("pdf_url",None))
|
49 |
+
df_works["url"] = df_works["id"]
|
50 |
+
df_works["content"] = (df_works["title"] + "\n" + df_works["abstract"]).map(lambda x : x.strip())
|
51 |
+
df_works["num_tokens"] = df_works["content"].map(lambda x : num_tokens_from_string(x))
|
52 |
+
|
53 |
+
df_works = df_works.drop(columns = ["abstract_inverted_index"])
|
54 |
+
# df_works["subtitle"] = df_works["title"] + " - " + df_works["primary_location"]["source"]["display_name"] + " - " + df_works["publication_year"]
|
55 |
+
|
56 |
+
return df_works
|
57 |
else:
|
58 |
+
raise Exception("Keywords must be a string")
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
|
61 |
def rerank(self,query,df,reranker):
|
|
|
156 |
reconstructed[position] = token
|
157 |
|
158 |
# Join the tokens to form the reconstructed sentence(s)
|
159 |
+
return ' '.join(reconstructed)
|
160 |
+
|
161 |
+
|
162 |
+
|
163 |
+
class OpenAlexRetriever(BaseRetriever):
|
164 |
+
min_year:int = 1960
|
165 |
+
max_year:int = None
|
166 |
+
k:int = 100
|
167 |
+
|
168 |
+
def _get_relevant_documents(
|
169 |
+
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
170 |
+
) -> List[Document]:
|
171 |
+
|
172 |
+
openalex = OpenAlex()
|
173 |
+
|
174 |
+
# Search for documents
|
175 |
+
df_docs = openalex.search(query,n_results=self.k,after = self.min_year,before = self.max_year)
|
176 |
+
|
177 |
+
docs = []
|
178 |
+
for i,row in df_docs.iterrows():
|
179 |
+
num_tokens = row["num_tokens"]
|
180 |
+
|
181 |
+
if num_tokens < 50 or num_tokens > 1000:
|
182 |
+
continue
|
183 |
+
|
184 |
+
doc = Document(
|
185 |
+
page_content = row["content"],
|
186 |
+
metadata = row.to_dict()
|
187 |
+
)
|
188 |
+
docs.append(doc)
|
189 |
+
return docs
|
190 |
+
|
191 |
+
|
climateqa/{engine → knowledge}/retriever.py
RENAMED
@@ -67,6 +67,7 @@ class ClimateQARetriever(BaseRetriever):
|
|
67 |
# Add score to metadata
|
68 |
results = []
|
69 |
for i,(doc,score) in enumerate(docs):
|
|
|
70 |
doc.metadata["similarity_score"] = score
|
71 |
doc.metadata["content"] = doc.page_content
|
72 |
doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
|
@@ -79,86 +80,3 @@ class ClimateQARetriever(BaseRetriever):
|
|
79 |
return results
|
80 |
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
# def filter_summaries(df,k_summary = 3,k_total = 10):
|
85 |
-
# # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
|
86 |
-
|
87 |
-
# # # Filter by source
|
88 |
-
# # if source == "IPCC":
|
89 |
-
# # df = df.loc[df["source"]=="IPCC"]
|
90 |
-
# # elif source == "IPBES":
|
91 |
-
# # df = df.loc[df["source"]=="IPBES"]
|
92 |
-
# # else:
|
93 |
-
# # pass
|
94 |
-
|
95 |
-
# # Separate summaries and full reports
|
96 |
-
# df_summaries = df.loc[df["report_type"].isin(["SPM","TS"])]
|
97 |
-
# df_full = df.loc[~df["report_type"].isin(["SPM","TS"])]
|
98 |
-
|
99 |
-
# # Find passages from summaries dataset
|
100 |
-
# passages_summaries = df_summaries.head(k_summary)
|
101 |
-
|
102 |
-
# # Find passages from full reports dataset
|
103 |
-
# passages_fullreports = df_full.head(k_total - len(passages_summaries))
|
104 |
-
|
105 |
-
# # Concatenate passages
|
106 |
-
# passages = pd.concat([passages_summaries,passages_fullreports],axis = 0,ignore_index = True)
|
107 |
-
# return passages
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
# def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
|
113 |
-
# assert max_k > k_total
|
114 |
-
|
115 |
-
# validated_sources = ["IPCC","IPBES"]
|
116 |
-
# sources = [x for x in sources if x in validated_sources]
|
117 |
-
# filters = {
|
118 |
-
# "source": { "$in": sources },
|
119 |
-
# }
|
120 |
-
# print(filters)
|
121 |
-
|
122 |
-
# # Retrieve documents
|
123 |
-
# docs = retriever.retrieve(query,top_k = max_k,filters = filters)
|
124 |
-
|
125 |
-
# # Filter by score
|
126 |
-
# docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs if x.score > threshold]
|
127 |
-
|
128 |
-
# if len(docs) == 0:
|
129 |
-
# return []
|
130 |
-
# res = pd.DataFrame(docs)
|
131 |
-
# passages_df = filter_summaries(res,k_summary,k_total)
|
132 |
-
# if as_dict:
|
133 |
-
# contents = passages_df["content"].tolist()
|
134 |
-
# meta = passages_df.drop(columns = ["content"]).to_dict(orient = "records")
|
135 |
-
# passages = []
|
136 |
-
# for i in range(len(contents)):
|
137 |
-
# passages.append({"content":contents[i],"meta":meta[i]})
|
138 |
-
# return passages
|
139 |
-
# else:
|
140 |
-
# return passages_df
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
# def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):
|
145 |
-
|
146 |
-
|
147 |
-
# print("hellooooo")
|
148 |
-
|
149 |
-
# # Reformulate queries
|
150 |
-
# reformulated_query,language = reformulate(query)
|
151 |
-
|
152 |
-
# print(reformulated_query)
|
153 |
-
|
154 |
-
# # Retrieve documents
|
155 |
-
# passages = retrieve_with_summaries(reformulated_query,retriever,k_total = k,k_summary = 3,as_dict = True,sources = sources,threshold = threshold)
|
156 |
-
# response = {
|
157 |
-
# "query":query,
|
158 |
-
# "reformulated_query":reformulated_query,
|
159 |
-
# "language":language,
|
160 |
-
# "sources":passages,
|
161 |
-
# "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
|
162 |
-
# }
|
163 |
-
# return response
|
164 |
-
|
|
|
67 |
# Add score to metadata
|
68 |
results = []
|
69 |
for i,(doc,score) in enumerate(docs):
|
70 |
+
doc.page_content = doc.page_content.replace("\r\n"," ")
|
71 |
doc.metadata["similarity_score"] = score
|
72 |
doc.metadata["content"] = doc.page_content
|
73 |
doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
|
|
|
80 |
return results
|
81 |
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
climateqa/papers/__init__.py
DELETED
@@ -1,43 +0,0 @@
|
|
1 |
-
import pandas as pd
|
2 |
-
|
3 |
-
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
|
4 |
-
import pyalex
|
5 |
-
|
6 |
-
pyalex.config.email = "[email protected]"
|
7 |
-
|
8 |
-
class OpenAlex():
|
9 |
-
def __init__(self):
|
10 |
-
pass
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
def search(self,keywords,n_results = 100,after = None,before = None):
|
15 |
-
works = Works().search(keywords).get()
|
16 |
-
|
17 |
-
for page in works.paginate(per_page=n_results):
|
18 |
-
break
|
19 |
-
|
20 |
-
df_works = pd.DataFrame(page)
|
21 |
-
|
22 |
-
return works
|
23 |
-
|
24 |
-
|
25 |
-
def make_network(self):
|
26 |
-
pass
|
27 |
-
|
28 |
-
|
29 |
-
def get_abstract_from_inverted_index(self,index):
|
30 |
-
|
31 |
-
# Determine the maximum index to know the length of the reconstructed array
|
32 |
-
max_index = max([max(positions) for positions in index.values()])
|
33 |
-
|
34 |
-
# Initialize a list with placeholders for all positions
|
35 |
-
reconstructed = [''] * (max_index + 1)
|
36 |
-
|
37 |
-
# Iterate through the inverted index and place each token at its respective position(s)
|
38 |
-
for token, positions in index.items():
|
39 |
-
for position in positions:
|
40 |
-
reconstructed[position] = token
|
41 |
-
|
42 |
-
# Join the tokens to form the reconstructed sentence(s)
|
43 |
-
return ' '.join(reconstructed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
gradio==4.
|
2 |
azure-storage-file-share==12.11.1
|
3 |
azure-storage-blob
|
4 |
python-dotenv==1.0.0
|
@@ -15,3 +15,5 @@ flashrank==0.2.5
|
|
15 |
rerankers==0.3.0
|
16 |
torch==2.3.0
|
17 |
nvidia-cudnn-cu12==8.9.2.26
|
|
|
|
|
|
1 |
+
gradio==4.44
|
2 |
azure-storage-file-share==12.11.1
|
3 |
azure-storage-blob
|
4 |
python-dotenv==1.0.0
|
|
|
15 |
rerankers==0.3.0
|
16 |
torch==2.3.0
|
17 |
nvidia-cudnn-cu12==8.9.2.26
|
18 |
+
langchain-community==0.2
|
19 |
+
msal==1.31
|
sandbox/20240310 - CQA - Semantic Routing 1.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|