TheoLvs commited on
Commit
99e91d8
1 Parent(s): fd67e15

agents mode

Browse files
climateqa/engine/chains/keywords_extraction.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from langchain_core.pydantic_v1 import BaseModel, Field
3
+ from typing import List
4
+ from typing import Literal
5
+ from langchain.prompts import ChatPromptTemplate
6
+ from langchain_core.utils.function_calling import convert_to_openai_function
7
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
8
+
9
+
10
+ class KeywordExtraction(BaseModel):
11
+ """
12
+ Analyzing the user query to extract keywords to feed a search engine
13
+ """
14
+
15
+ keywords: List[str] = Field(
16
+ description="""
17
+ Extract the keywords from the user query to feed a search engine as a list
18
+ Avoid adding super specific keywords to prefer general keywords
19
+ Maximum 3 keywords
20
+
21
+ Examples:
22
+ - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
23
+ - "How will El Nino be impacted by climate change" -> ["el nino","climate change"]
24
+ - "Is climate change a hoax" -> ["climate change","hoax"]
25
+ """
26
+ )
27
+
28
+
29
+ def make_keywords_extraction_chain(llm):
30
+
31
+ openai_functions = [convert_to_openai_function(KeywordExtraction)]
32
+ llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"KeywordExtraction"})
33
+
34
+ prompt = ChatPromptTemplate.from_messages([
35
+ ("system", "You are a helpful assistant"),
36
+ ("user", "input: {input}")
37
+ ])
38
+
39
+ chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
40
+ return chain
climateqa/engine/chains/query_transformation.py CHANGED
@@ -8,6 +8,13 @@ from langchain_core.utils.function_calling import convert_to_openai_function
8
  from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
9
 
10
 
 
 
 
 
 
 
 
11
  # Prompt from the original paper https://arxiv.org/pdf/2305.14283
12
  # Query Rewriting for Retrieval-Augmented Large Language Models
13
  class QueryDecomposition(BaseModel):
@@ -20,8 +27,8 @@ class QueryDecomposition(BaseModel):
20
  description="""
21
  Think step by step to answer this question, and provide one or several search engine questions in English for knowledge that you need.
22
  Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
23
- - If it's already a standalone question, you don't need to provide more questions, just reformulate it if relevant as a better question for a search engine
24
- - If you need to decompose the question, output a list of maximum 3 questions
25
  """
26
  )
27
 
@@ -125,12 +132,20 @@ def make_query_rewriter_chain(llm):
125
  return chain
126
 
127
 
128
- def make_query_transform_node(llm):
129
 
130
  decomposition_chain = make_query_decomposition_chain(llm)
131
  rewriter_chain = make_query_rewriter_chain(llm)
132
 
133
  def transform_query(state):
 
 
 
 
 
 
 
 
134
 
135
  new_state = {}
136
 
@@ -145,7 +160,33 @@ def make_query_transform_node(llm):
145
  analysis_output = rewriter_chain.invoke({"input":question})
146
  question_state.update(analysis_output)
147
  questions.append(question_state)
148
- new_state["questions"] = questions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  return new_state
151
 
 
8
  from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
9
 
10
 
11
+ ROUTING_INDEX = {
12
+ "Vector":["IPCC","IPBES","IPOS"],
13
+ "OpenAlex":["OpenAlex"],
14
+ }
15
+
16
+ POSSIBLE_SOURCES = [y for values in ROUTING_INDEX.values() for y in values]
17
+
18
  # Prompt from the original paper https://arxiv.org/pdf/2305.14283
19
  # Query Rewriting for Retrieval-Augmented Large Language Models
20
  class QueryDecomposition(BaseModel):
 
27
  description="""
28
  Think step by step to answer this question, and provide one or several search engine questions in English for knowledge that you need.
29
  Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
30
+ - If it's already a standalone and explicit question, just return the reformulated question for the search engine
31
+ - If you need to decompose the question, output a list of maximum 2 to 3 questions
32
  """
33
  )
34
 
 
132
  return chain
133
 
134
 
135
+ def make_query_transform_node(llm,k_final=15):
136
 
137
  decomposition_chain = make_query_decomposition_chain(llm)
138
  rewriter_chain = make_query_rewriter_chain(llm)
139
 
140
  def transform_query(state):
141
+
142
+ if "sources_auto" not in state or state["sources_auto"] is None or state["sources_auto"] is False:
143
+ auto_mode = False
144
+ else:
145
+ auto_mode = True
146
+
147
+ sources_input = state.get("sources_input")
148
+ if sources_input is None: sources_input = ROUTING_INDEX["Vector"]
149
 
150
  new_state = {}
151
 
 
160
  analysis_output = rewriter_chain.invoke({"input":question})
161
  question_state.update(analysis_output)
162
  questions.append(question_state)
163
+
164
+ # Explode the questions into multiple questions with different sources
165
+ new_questions = []
166
+ for q in questions:
167
+ question,sources = q["question"],q["sources"]
168
+
169
+ # If not auto mode we take the configuration
170
+ if not auto_mode:
171
+ sources = sources_input
172
+
173
+ for index,index_sources in ROUTING_INDEX.items():
174
+ selected_sources = list(set(sources).intersection(index_sources))
175
+ if len(selected_sources) > 0:
176
+ new_questions.append({"question":question,"sources":selected_sources,"index":index})
177
+
178
+ # # Add the number of questions to search
179
+ # k_by_question = k_final // len(new_questions)
180
+ # for q in new_questions:
181
+ # q["k"] = k_by_question
182
+
183
+ # new_state["questions"] = new_questions
184
+ # new_state["remaining_questions"] = new_questions
185
+
186
+ new_state = {
187
+ "remaining_questions":new_questions,
188
+ "n_questions":len(new_questions),
189
+ }
190
 
191
  return new_state
192
 
climateqa/engine/chains/{retriever.py → retrieve_documents.py} RENAMED
@@ -2,8 +2,16 @@ import sys
2
  import os
3
  from contextlib import contextmanager
4
 
 
 
 
 
 
5
  from ..reranker import rerank_docs
6
- from ..retriever import ClimateQARetriever
 
 
 
7
 
8
 
9
 
@@ -44,80 +52,107 @@ def suppress_output():
44
  sys.stderr = old_stderr
45
 
46
 
 
 
 
 
 
 
 
 
47
 
48
- def make_retriever_node(vectorstore,reranker,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
49
 
50
- def retrieve_documents(state):
 
 
 
 
 
 
 
51
 
52
- POSSIBLE_SOURCES = ["IPCC","IPBES","IPOS","OpenAlex"]
53
- questions = state["questions"]
54
 
55
- # Use sources from the user input or from the LLM detection
56
- if "sources_input" not in state or state["sources_input"] is None:
57
- sources_input = ["auto"]
58
- else:
59
- sources_input = state["sources_input"]
60
- auto_mode = "auto" in sources_input
61
 
62
- # There are several options to get the final top k
63
- # Option 1 - Get 100 documents by question and rerank by question
64
- # Option 2 - Get 100/n documents by question and rerank the total
65
- if rerank_by_question:
66
- k_by_question = divide_into_parts(k_final,len(questions))
 
67
 
 
 
 
68
  docs = []
 
69
 
70
- for i,q in enumerate(questions):
71
-
72
- sources = q["sources"]
73
- question = q["question"]
74
-
75
- # If auto mode, we use the sources detected by the LLM
76
- if auto_mode:
77
- sources = [x for x in sources if x in POSSIBLE_SOURCES]
78
-
79
- # Otherwise, we use the config
80
- else:
81
- sources = sources_input
82
 
83
  # Search the document store using the retriever
84
  # Configure high top k for further reranking step
85
  retriever = ClimateQARetriever(
86
  vectorstore=vectorstore,
87
  sources = sources,
88
- # reports = ias_reports,
89
- min_size = 200,
90
- k_summary = k_summary,k_total = k_before_reranking,
91
- threshold = 0.5,
92
  )
93
- docs_question = retriever.get_relevant_documents(question)
94
-
95
- # Rerank
96
- if reranker is not None:
97
- with suppress_output():
98
- docs_question = rerank_docs(reranker,docs_question,question)
99
- else:
100
- # Add a default reranking score
101
- for doc in docs_question:
102
- doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
103
-
104
- # If rerank by question we select the top documents for each question
105
- if rerank_by_question:
106
- docs_question = docs_question[:k_by_question[i]]
107
-
108
- # Add sources used in the metadata
 
 
 
 
 
 
 
 
 
109
  for doc in docs_question:
110
- doc.metadata["sources_used"] = sources
111
 
112
- # Add to the list of docs
113
- docs.extend(docs_question)
 
 
 
 
 
 
 
 
 
 
114
 
115
  # Sorting the list in descending order by rerank_score
116
- # Then select the top k
117
  docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
118
- docs = docs[:k_final]
119
-
120
- new_state = {"documents":docs}
121
  return new_state
122
 
123
  return retrieve_documents
 
2
  import os
3
  from contextlib import contextmanager
4
 
5
+ from langchain_core.tools import tool
6
+ from langchain_core.runnables import chain
7
+ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
8
+ from langchain_core.runnables import RunnableLambda
9
+
10
  from ..reranker import rerank_docs
11
+ from ...knowledge.retriever import ClimateQARetriever
12
+ from ...knowledge.openalex import OpenAlexRetriever
13
+ from .keywords_extraction import make_keywords_extraction_chain
14
+ from ..utils import log_event
15
 
16
 
17
 
 
52
  sys.stderr = old_stderr
53
 
54
 
55
+ @tool
56
+ def query_retriever(question):
57
+ """Just a dummy tool to simulate the retriever query"""
58
+ return question
59
+
60
+
61
+
62
+
63
 
 
64
 
65
+
66
+ def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
67
+
68
+ # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
69
+ @chain
70
+ async def retrieve_documents(state,config):
71
+
72
+ keywords_extraction = make_keywords_extraction_chain(llm)
73
 
74
+ current_question = state["remaining_questions"][0]
75
+ remaining_questions = state["remaining_questions"][1:]
76
 
77
+ # ToolMessage(f"Retrieving documents for question: {current_question['question']}",tool_call_id = "retriever")
 
 
 
 
 
78
 
79
+
80
+ # # There are several options to get the final top k
81
+ # # Option 1 - Get 100 documents by question and rerank by question
82
+ # # Option 2 - Get 100/n documents by question and rerank the total
83
+ # if rerank_by_question:
84
+ # k_by_question = divide_into_parts(k_final,len(questions))
85
 
86
+ # docs = state["documents"]
87
+ # if docs is None: docs = []
88
+
89
  docs = []
90
+ k_by_question = k_final // state["n_questions"]
91
 
92
+ sources = current_question["sources"]
93
+ question = current_question["question"]
94
+ index = current_question["index"]
95
+
96
+
97
+ await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
98
+
99
+
100
+ if index == "Vector":
 
 
 
101
 
102
  # Search the document store using the retriever
103
  # Configure high top k for further reranking step
104
  retriever = ClimateQARetriever(
105
  vectorstore=vectorstore,
106
  sources = sources,
107
+ min_size = 200,
108
+ k_summary = k_summary,
109
+ k_total = k_before_reranking,
110
+ threshold = 0.5,
111
  )
112
+ docs_question = await retriever.ainvoke(question,config)
113
+
114
+ elif index == "OpenAlex":
115
+
116
+ keywords = keywords_extraction.invoke(question)["keywords"]
117
+ openalex_query = " AND ".join(keywords)
118
+
119
+ print(f"... OpenAlex query: {openalex_query}")
120
+
121
+ retriever_openalex = OpenAlexRetriever(
122
+ min_year = state.get("min_year",1960),
123
+ max_year = state.get("max_year",None),
124
+ k = k_before_reranking
125
+ )
126
+ docs_question = await retriever_openalex.ainvoke(openalex_query,config)
127
+
128
+ else:
129
+ raise Exception(f"Index {index} not found in the routing index")
130
+
131
+ # Rerank
132
+ if reranker is not None:
133
+ with suppress_output():
134
+ docs_question = rerank_docs(reranker,docs_question,question)
135
+ else:
136
+ # Add a default reranking score
137
  for doc in docs_question:
138
+ doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
139
 
140
+ # If rerank by question we select the top documents for each question
141
+ if rerank_by_question:
142
+ docs_question = docs_question[:k_by_question]
143
+
144
+ # Add sources used in the metadata
145
+ for doc in docs_question:
146
+ doc.metadata["sources_used"] = sources
147
+ doc.metadata["question_used"] = question
148
+ doc.metadata["index_used"] = index
149
+
150
+ # Add to the list of docs
151
+ docs.extend(docs_question)
152
 
153
  # Sorting the list in descending order by rerank_score
 
154
  docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
155
+ new_state = {"documents":docs,"remaining_questions":remaining_questions}
 
 
156
  return new_state
157
 
158
  return retrieve_documents
climateqa/engine/graph.py CHANGED
@@ -16,10 +16,9 @@ from .chains.answer_ai_impact import make_ai_impact_node
16
  from .chains.query_transformation import make_query_transform_node
17
  from .chains.translation import make_translation_node
18
  from .chains.intent_categorization import make_intent_categorization_node
19
- from .chains.retriever import make_retriever_node
20
  from .chains.answer_rag import make_rag_node
21
 
22
-
23
  class GraphState(TypedDict):
24
  """
25
  Represents the state of our graph.
@@ -28,21 +27,28 @@ class GraphState(TypedDict):
28
  language : str
29
  intent : str
30
  query: str
31
- questions : List[dict]
 
32
  answer: str
33
  audience: str = "experts"
34
- sources_input: List[str] = ["auto"]
 
 
 
35
  documents: List[Document]
36
 
37
  def search(state):
38
  return {}
39
 
 
 
 
40
  def route_intent(state):
41
  intent = state["intent"]
42
  if intent in ["chitchat","esg"]:
43
  return "answer_chitchat"
44
- elif intent == "ai_impact":
45
- return "answer_ai_impact"
46
  else:
47
  # Search route
48
  return "search"
@@ -74,17 +80,18 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
74
  translate_query = make_translation_node(llm)
75
  answer_chitchat = make_chitchat_node(llm)
76
  answer_ai_impact = make_ai_impact_node(llm)
77
- retrieve_documents = make_retriever_node(vectorstore,reranker)
78
  answer_rag = make_rag_node(llm,with_docs=True)
79
  answer_rag_no_docs = make_rag_node(llm,with_docs=False)
80
 
81
  # Define the nodes
82
  workflow.add_node("categorize_intent", categorize_intent)
83
  workflow.add_node("search", search)
 
84
  workflow.add_node("transform_query", transform_query)
85
  workflow.add_node("translate_query", translate_query)
86
  workflow.add_node("answer_chitchat", answer_chitchat)
87
- workflow.add_node("answer_ai_impact", answer_ai_impact)
88
  workflow.add_node("retrieve_documents",retrieve_documents)
89
  workflow.add_node("answer_rag",answer_rag)
90
  workflow.add_node("answer_rag_no_docs",answer_rag_no_docs)
@@ -96,7 +103,7 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
96
  workflow.add_conditional_edges(
97
  "categorize_intent",
98
  route_intent,
99
- make_id_dict(["answer_chitchat","answer_ai_impact","search"])
100
  )
101
 
102
  workflow.add_conditional_edges(
@@ -104,9 +111,14 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
104
  route_translation,
105
  make_id_dict(["translate_query","transform_query"])
106
  )
107
-
108
  workflow.add_conditional_edges(
109
  "retrieve_documents",
 
 
 
 
 
 
110
  lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
111
  make_id_dict(["answer_rag","answer_rag_no_docs"])
112
  )
@@ -114,11 +126,10 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
114
  # Define the edges
115
  workflow.add_edge("translate_query", "transform_query")
116
  workflow.add_edge("transform_query", "retrieve_documents")
117
- workflow.add_edge("retrieve_documents", "answer_rag")
118
  workflow.add_edge("answer_rag", END)
119
  workflow.add_edge("answer_rag_no_docs", END)
120
  workflow.add_edge("answer_chitchat", END)
121
- workflow.add_edge("answer_ai_impact", END)
122
 
123
  # Compile
124
  app = workflow.compile()
 
16
  from .chains.query_transformation import make_query_transform_node
17
  from .chains.translation import make_translation_node
18
  from .chains.intent_categorization import make_intent_categorization_node
19
+ from .chains.retrieve_documents import make_retriever_node
20
  from .chains.answer_rag import make_rag_node
21
 
 
22
  class GraphState(TypedDict):
23
  """
24
  Represents the state of our graph.
 
27
  language : str
28
  intent : str
29
  query: str
30
+ remaining_questions : List[dict]
31
+ n_questions : int
32
  answer: str
33
  audience: str = "experts"
34
+ sources_input: List[str] = ["IPCC","IPBES"]
35
+ sources_auto: bool = True
36
+ min_year: int = 1960
37
+ max_year: int = None
38
  documents: List[Document]
39
 
40
  def search(state):
41
  return {}
42
 
43
+ def answer_search(state):
44
+ return {}
45
+
46
  def route_intent(state):
47
  intent = state["intent"]
48
  if intent in ["chitchat","esg"]:
49
  return "answer_chitchat"
50
+ # elif intent == "ai_impact":
51
+ # return "answer_ai_impact"
52
  else:
53
  # Search route
54
  return "search"
 
80
  translate_query = make_translation_node(llm)
81
  answer_chitchat = make_chitchat_node(llm)
82
  answer_ai_impact = make_ai_impact_node(llm)
83
+ retrieve_documents = make_retriever_node(vectorstore,reranker,llm)
84
  answer_rag = make_rag_node(llm,with_docs=True)
85
  answer_rag_no_docs = make_rag_node(llm,with_docs=False)
86
 
87
  # Define the nodes
88
  workflow.add_node("categorize_intent", categorize_intent)
89
  workflow.add_node("search", search)
90
+ workflow.add_node("answer_search", answer_search)
91
  workflow.add_node("transform_query", transform_query)
92
  workflow.add_node("translate_query", translate_query)
93
  workflow.add_node("answer_chitchat", answer_chitchat)
94
+ # workflow.add_node("answer_ai_impact", answer_ai_impact)
95
  workflow.add_node("retrieve_documents",retrieve_documents)
96
  workflow.add_node("answer_rag",answer_rag)
97
  workflow.add_node("answer_rag_no_docs",answer_rag_no_docs)
 
103
  workflow.add_conditional_edges(
104
  "categorize_intent",
105
  route_intent,
106
+ make_id_dict(["answer_chitchat","search"])
107
  )
108
 
109
  workflow.add_conditional_edges(
 
111
  route_translation,
112
  make_id_dict(["translate_query","transform_query"])
113
  )
 
114
  workflow.add_conditional_edges(
115
  "retrieve_documents",
116
+ lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
117
+ make_id_dict(["retrieve_documents","answer_search"])
118
+ )
119
+
120
+ workflow.add_conditional_edges(
121
+ "answer_search",
122
  lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
123
  make_id_dict(["answer_rag","answer_rag_no_docs"])
124
  )
 
126
  # Define the edges
127
  workflow.add_edge("translate_query", "transform_query")
128
  workflow.add_edge("transform_query", "retrieve_documents")
 
129
  workflow.add_edge("answer_rag", END)
130
  workflow.add_edge("answer_rag_no_docs", END)
131
  workflow.add_edge("answer_chitchat", END)
132
+ # workflow.add_edge("answer_ai_impact", END)
133
 
134
  # Compile
135
  app = workflow.compile()
climateqa/engine/llm/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  from climateqa.engine.llm.openai import get_llm as get_openai_llm
2
  from climateqa.engine.llm.azure import get_llm as get_azure_llm
 
3
 
4
 
5
  def get_llm(provider="openai",**kwargs):
@@ -8,6 +9,8 @@ def get_llm(provider="openai",**kwargs):
8
  return get_openai_llm(**kwargs)
9
  elif provider == "azure":
10
  return get_azure_llm(**kwargs)
 
 
11
  else:
12
  raise ValueError(f"Unknown provider: {provider}")
13
 
 
1
  from climateqa.engine.llm.openai import get_llm as get_openai_llm
2
  from climateqa.engine.llm.azure import get_llm as get_azure_llm
3
+ from climateqa.engine.llm.ollama import get_llm as get_ollama_llm
4
 
5
 
6
  def get_llm(provider="openai",**kwargs):
 
9
  return get_openai_llm(**kwargs)
10
  elif provider == "azure":
11
  return get_azure_llm(**kwargs)
12
+ elif provider == "ollama":
13
+ return get_ollama_llm(**kwargs)
14
  else:
15
  raise ValueError(f"Unknown provider: {provider}")
16
 
climateqa/engine/llm/ollama.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+
3
+ from langchain_community.llms import Ollama
4
+
5
+ def get_llm(model="llama3", **kwargs):
6
+ return Ollama(model=model, **kwargs)
climateqa/engine/utils.py CHANGED
@@ -1,8 +1,15 @@
1
  from operator import itemgetter
2
  from typing import Any, Dict, Iterable, Tuple
 
3
  from langchain_core.runnables import RunnablePassthrough
4
 
5
 
 
 
 
 
 
 
6
  def pass_values(x):
7
  if not isinstance(x, list):
8
  x = [x]
@@ -67,3 +74,13 @@ def flatten_dict(
67
  """
68
  flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
69
  return flat_dict
 
 
 
 
 
 
 
 
 
 
 
1
  from operator import itemgetter
2
  from typing import Any, Dict, Iterable, Tuple
3
+ import tiktoken
4
  from langchain_core.runnables import RunnablePassthrough
5
 
6
 
7
+ def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
8
+ encoding = tiktoken.get_encoding(encoding_name)
9
+ num_tokens = len(encoding.encode(string))
10
+ return num_tokens
11
+
12
+
13
  def pass_values(x):
14
  if not isinstance(x, list):
15
  x = [x]
 
74
  """
75
  flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
76
  return flat_dict
77
+
78
+
79
+
80
+ async def log_event(info,name,config):
81
+ """Helper function that will run a dummy chain with the given info
82
+ The astream_event function will catch this chain and stream the dict info to the logger
83
+ """
84
+
85
+ chain = RunnablePassthrough().with_config(run_name=name)
86
+ _ = await chain.ainvoke(info,config)
climateqa/knowledge/__init__.py ADDED
File without changes
climateqa/{papers → knowledge}/openalex.py RENAMED
@@ -3,18 +3,32 @@ import networkx as nx
3
  import matplotlib.pyplot as plt
4
  from pyvis.network import Network
5
 
 
 
 
 
 
 
 
 
 
 
 
6
  from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
7
  import pyalex
8
 
9
  pyalex.config.email = "[email protected]"
10
 
 
 
 
 
11
  class OpenAlex():
12
  def __init__(self):
13
  pass
14
 
15
 
16
-
17
- def search(self,keywords,n_results = 100,after = None,before = None):
18
 
19
  if isinstance(keywords,str):
20
  works = Works().search(keywords)
@@ -27,18 +41,21 @@ class OpenAlex():
27
  break
28
 
29
  df_works = pd.DataFrame(page)
30
- df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x))
 
 
31
  df_works["is_oa"] = df_works["open_access"].map(lambda x : x.get("is_oa",False))
32
  df_works["pdf_url"] = df_works["primary_location"].map(lambda x : x.get("pdf_url",None))
33
- df_works["content"] = df_works["title"] + "\n" + df_works["abstract"]
34
-
 
 
 
 
 
 
35
  else:
36
- df_works = []
37
- for keyword in keywords:
38
- df_keyword = self.search(keyword,n_results = n_results,after = after,before = before)
39
- df_works.append(df_keyword)
40
- df_works = pd.concat(df_works,ignore_index=True,axis = 0)
41
- return df_works
42
 
43
 
44
  def rerank(self,query,df,reranker):
@@ -139,4 +156,36 @@ class OpenAlex():
139
  reconstructed[position] = token
140
 
141
  # Join the tokens to form the reconstructed sentence(s)
142
- return ' '.join(reconstructed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import matplotlib.pyplot as plt
4
  from pyvis.network import Network
5
 
6
+ from langchain_core.retrievers import BaseRetriever
7
+ from langchain_core.vectorstores import VectorStoreRetriever
8
+ from langchain_core.documents.base import Document
9
+ from langchain_core.vectorstores import VectorStore
10
+ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
11
+
12
+ from ..engine.utils import num_tokens_from_string
13
+
14
+ from typing import List
15
+ from pydantic import Field
16
+
17
  from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
18
  import pyalex
19
 
20
  pyalex.config.email = "[email protected]"
21
 
22
+
23
+ def replace_nan_with_empty_dict(x):
24
+ return x if pd.notna(x) else {}
25
+
26
  class OpenAlex():
27
  def __init__(self):
28
  pass
29
 
30
 
31
+ def search(self,keywords:str,n_results = 100,after = None,before = None):
 
32
 
33
  if isinstance(keywords,str):
34
  works = Works().search(keywords)
 
41
  break
42
 
43
  df_works = pd.DataFrame(page)
44
+ df_works = df_works.dropna(subset = ["title"])
45
+ df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
46
+ df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")
47
  df_works["is_oa"] = df_works["open_access"].map(lambda x : x.get("is_oa",False))
48
  df_works["pdf_url"] = df_works["primary_location"].map(lambda x : x.get("pdf_url",None))
49
+ df_works["url"] = df_works["id"]
50
+ df_works["content"] = (df_works["title"] + "\n" + df_works["abstract"]).map(lambda x : x.strip())
51
+ df_works["num_tokens"] = df_works["content"].map(lambda x : num_tokens_from_string(x))
52
+
53
+ df_works = df_works.drop(columns = ["abstract_inverted_index"])
54
+ # df_works["subtitle"] = df_works["title"] + " - " + df_works["primary_location"]["source"]["display_name"] + " - " + df_works["publication_year"]
55
+
56
+ return df_works
57
  else:
58
+ raise Exception("Keywords must be a string")
 
 
 
 
 
59
 
60
 
61
  def rerank(self,query,df,reranker):
 
156
  reconstructed[position] = token
157
 
158
  # Join the tokens to form the reconstructed sentence(s)
159
+ return ' '.join(reconstructed)
160
+
161
+
162
+
163
+ class OpenAlexRetriever(BaseRetriever):
164
+ min_year:int = 1960
165
+ max_year:int = None
166
+ k = 100
167
+
168
+ def _get_relevant_documents(
169
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
170
+ ) -> List[Document]:
171
+
172
+ openalex = OpenAlex()
173
+
174
+ # Search for documents
175
+ df_docs = openalex.search(query,n_results=self.k,after = self.min_year,before = self.max_year)
176
+
177
+ docs = []
178
+ for i,row in df_docs.iterrows():
179
+ num_tokens = row["num_tokens"]
180
+
181
+ if num_tokens < 50 or num_tokens > 1000:
182
+ continue
183
+
184
+ doc = Document(
185
+ page_content = row["content"],
186
+ metadata = row.to_dict()
187
+ )
188
+ docs.append(doc)
189
+ return docs
190
+
191
+
climateqa/{engine → knowledge}/retriever.py RENAMED
@@ -66,6 +66,7 @@ class ClimateQARetriever(BaseRetriever):
66
  # Add score to metadata
67
  results = []
68
  for i,(doc,score) in enumerate(docs):
 
69
  doc.metadata["similarity_score"] = score
70
  doc.metadata["content"] = doc.page_content
71
  doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
@@ -78,86 +79,3 @@ class ClimateQARetriever(BaseRetriever):
78
  return results
79
 
80
 
81
-
82
-
83
- # def filter_summaries(df,k_summary = 3,k_total = 10):
84
- # # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
85
-
86
- # # # Filter by source
87
- # # if source == "IPCC":
88
- # # df = df.loc[df["source"]=="IPCC"]
89
- # # elif source == "IPBES":
90
- # # df = df.loc[df["source"]=="IPBES"]
91
- # # else:
92
- # # pass
93
-
94
- # # Separate summaries and full reports
95
- # df_summaries = df.loc[df["report_type"].isin(["SPM","TS"])]
96
- # df_full = df.loc[~df["report_type"].isin(["SPM","TS"])]
97
-
98
- # # Find passages from summaries dataset
99
- # passages_summaries = df_summaries.head(k_summary)
100
-
101
- # # Find passages from full reports dataset
102
- # passages_fullreports = df_full.head(k_total - len(passages_summaries))
103
-
104
- # # Concatenate passages
105
- # passages = pd.concat([passages_summaries,passages_fullreports],axis = 0,ignore_index = True)
106
- # return passages
107
-
108
-
109
-
110
-
111
- # def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
112
- # assert max_k > k_total
113
-
114
- # validated_sources = ["IPCC","IPBES"]
115
- # sources = [x for x in sources if x in validated_sources]
116
- # filters = {
117
- # "source": { "$in": sources },
118
- # }
119
- # print(filters)
120
-
121
- # # Retrieve documents
122
- # docs = retriever.retrieve(query,top_k = max_k,filters = filters)
123
-
124
- # # Filter by score
125
- # docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs if x.score > threshold]
126
-
127
- # if len(docs) == 0:
128
- # return []
129
- # res = pd.DataFrame(docs)
130
- # passages_df = filter_summaries(res,k_summary,k_total)
131
- # if as_dict:
132
- # contents = passages_df["content"].tolist()
133
- # meta = passages_df.drop(columns = ["content"]).to_dict(orient = "records")
134
- # passages = []
135
- # for i in range(len(contents)):
136
- # passages.append({"content":contents[i],"meta":meta[i]})
137
- # return passages
138
- # else:
139
- # return passages_df
140
-
141
-
142
-
143
- # def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):
144
-
145
-
146
- # print("hellooooo")
147
-
148
- # # Reformulate queries
149
- # reformulated_query,language = reformulate(query)
150
-
151
- # print(reformulated_query)
152
-
153
- # # Retrieve documents
154
- # passages = retrieve_with_summaries(reformulated_query,retriever,k_total = k,k_summary = 3,as_dict = True,sources = sources,threshold = threshold)
155
- # response = {
156
- # "query":query,
157
- # "reformulated_query":reformulated_query,
158
- # "language":language,
159
- # "sources":passages,
160
- # "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
161
- # }
162
- # return response
163
-
 
66
  # Add score to metadata
67
  results = []
68
  for i,(doc,score) in enumerate(docs):
69
+ doc.page_content = doc.page_content.replace("\r\n"," ")
70
  doc.metadata["similarity_score"] = score
71
  doc.metadata["content"] = doc.page_content
72
  doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
 
79
  return results
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/papers/__init__.py DELETED
@@ -1,43 +0,0 @@
1
- import pandas as pd
2
-
3
- from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
4
- import pyalex
5
-
6
- pyalex.config.email = "[email protected]"
7
-
8
- class OpenAlex():
9
- def __init__(self):
10
- pass
11
-
12
-
13
-
14
- def search(self,keywords,n_results = 100,after = None,before = None):
15
- works = Works().search(keywords).get()
16
-
17
- for page in works.paginate(per_page=n_results):
18
- break
19
-
20
- df_works = pd.DataFrame(page)
21
-
22
- return works
23
-
24
-
25
- def make_network(self):
26
- pass
27
-
28
-
29
- def get_abstract_from_inverted_index(self,index):
30
-
31
- # Determine the maximum index to know the length of the reconstructed array
32
- max_index = max([max(positions) for positions in index.values()])
33
-
34
- # Initialize a list with placeholders for all positions
35
- reconstructed = [''] * (max_index + 1)
36
-
37
- # Iterate through the inverted index and place each token at its respective position(s)
38
- for token, positions in index.items():
39
- for position in positions:
40
- reconstructed[position] = token
41
-
42
- # Join the tokens to form the reconstructed sentence(s)
43
- return ' '.join(reconstructed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sandbox/20240310 - CQA - Semantic Routing 1.ipynb CHANGED
The diff for this file is too large to render. See raw diff