timeki committed on
Commit
aa904c1
2 Parent(s): 6b43c86 25e32e6

Merge branch 'bugfix/add_dummy_searchs' into feature/graph_recommandation

Browse files
app.py CHANGED
@@ -1,7 +1,7 @@
1
  from climateqa.engine.embeddings import get_embeddings_function
2
  embeddings_function = get_embeddings_function()
3
 
4
- from climateqa.papers.openalex import OpenAlex
5
  from sentence_transformers import CrossEncoder
6
 
7
  # reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
@@ -31,7 +31,7 @@ from collections import defaultdict
31
  # ClimateQ&A imports
32
  from climateqa.engine.llm import get_llm
33
  from climateqa.engine.vectorstore import get_pinecone_vectorstore
34
- from climateqa.engine.retriever import ClimateQARetriever
35
  from climateqa.engine.reranker import get_reranker
36
  from climateqa.engine.embeddings import get_embeddings_function
37
  from climateqa.engine.chains.prompts import audience_prompts
 
1
  from climateqa.engine.embeddings import get_embeddings_function
2
  embeddings_function = get_embeddings_function()
3
 
4
+ from climateqa.knowledge.openalex import OpenAlex
5
  from sentence_transformers import CrossEncoder
6
 
7
  # reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
 
31
  # ClimateQ&A imports
32
  from climateqa.engine.llm import get_llm
33
  from climateqa.engine.vectorstore import get_pinecone_vectorstore
34
+ from climateqa.knowledge.retriever import ClimateQARetriever
35
  from climateqa.engine.reranker import get_reranker
36
  from climateqa.engine.embeddings import get_embeddings_function
37
  from climateqa.engine.chains.prompts import audience_prompts
climateqa/engine/chains/keywords_extraction.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from langchain_core.pydantic_v1 import BaseModel, Field
3
+ from typing import List
4
+ from typing import Literal
5
+ from langchain.prompts import ChatPromptTemplate
6
+ from langchain_core.utils.function_calling import convert_to_openai_function
7
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
8
+
9
+
10
+ class KeywordExtraction(BaseModel):
11
+ """
12
+ Analyzing the user query to extract keywords to feed a search engine
13
+ """
14
+
15
+ keywords: List[str] = Field(
16
+ description="""
17
+ Extract the keywords from the user query to feed a search engine as a list
18
+ Avoid adding super specific keywords to prefer general keywords
19
+ Maximum 3 keywords
20
+
21
+ Examples:
22
+ - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
23
+ - "How will El Nino be impacted by climate change" -> ["el nino","climate change"]
24
+ - "Is climate change a hoax" -> ["climate change","hoax"]
25
+ """
26
+ )
27
+
28
+
29
def make_keywords_extraction_chain(llm):
    """Build a runnable that extracts search-engine keywords from a query.

    The model is bound to the ``KeywordExtraction`` function schema and is
    forced to call it; the function-call payload is then parsed by
    ``JsonOutputFunctionsParser``.

    Args:
        llm: A LangChain model exposing ``bind`` with OpenAI-style
            ``functions`` / ``function_call`` keyword arguments.

    Returns:
        The composed ``prompt | llm | parser`` runnable. The prompt has a
        single ``input`` variable holding the user query.
    """
    # Force the model to answer through the KeywordExtraction schema.
    keyword_function = convert_to_openai_function(KeywordExtraction)
    constrained_llm = llm.bind(
        functions=[keyword_function],
        function_call={"name": "KeywordExtraction"},
    )

    messages = [
        ("system", "You are a helpful assistant"),
        ("user", "input: {input}"),
    ]
    return (
        ChatPromptTemplate.from_messages(messages)
        | constrained_llm
        | JsonOutputFunctionsParser()
    )
climateqa/engine/chains/query_transformation.py CHANGED
@@ -8,6 +8,13 @@ from langchain_core.utils.function_calling import convert_to_openai_function
8
  from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
9
 
10
 
 
 
 
 
 
 
 
11
  # Prompt from the original paper https://arxiv.org/pdf/2305.14283
12
  # Query Rewriting for Retrieval-Augmented Large Language Models
13
  class QueryDecomposition(BaseModel):
@@ -20,8 +27,8 @@ class QueryDecomposition(BaseModel):
20
  description="""
21
  Think step by step to answer this question, and provide one or several search engine questions in English for knowledge that you need.
22
  Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
23
- - If it's already a standalone question, you don't need to provide more questions, just reformulate it if relevant as a better question for a search engine
24
- - If you need to decompose the question, output a list of maximum 3 questions
25
  """
26
  )
27
 
@@ -125,12 +132,20 @@ def make_query_rewriter_chain(llm):
125
  return chain
126
 
127
 
128
- def make_query_transform_node(llm):
129
 
130
  decomposition_chain = make_query_decomposition_chain(llm)
131
  rewriter_chain = make_query_rewriter_chain(llm)
132
 
133
  def transform_query(state):
 
 
 
 
 
 
 
 
134
 
135
  new_state = {}
136
 
@@ -150,7 +165,33 @@ def make_query_transform_node(llm):
150
 
151
  question_state.update(analysis_output)
152
  questions.append(question_state)
153
- new_state["questions"] = questions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  return new_state
156
 
 
8
  from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
9
 
10
 
11
+ ROUTING_INDEX = {
12
+ "Vector":["IPCC","IPBES","IPOS"],
13
+ "OpenAlex":["OpenAlex"],
14
+ }
15
+
16
+ POSSIBLE_SOURCES = [y for values in ROUTING_INDEX.values() for y in values]
17
+
18
  # Prompt from the original paper https://arxiv.org/pdf/2305.14283
19
  # Query Rewriting for Retrieval-Augmented Large Language Models
20
  class QueryDecomposition(BaseModel):
 
27
  description="""
28
  Think step by step to answer this question, and provide one or several search engine questions in English for knowledge that you need.
29
  Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
30
+ - If it's already a standalone and explicit question, just return the reformulated question for the search engine
31
+ - If you need to decompose the question, output a list of maximum 2 to 3 questions
32
  """
33
  )
34
 
 
132
  return chain
133
 
134
 
135
+ def make_query_transform_node(llm,k_final=15):
136
 
137
  decomposition_chain = make_query_decomposition_chain(llm)
138
  rewriter_chain = make_query_rewriter_chain(llm)
139
 
140
  def transform_query(state):
141
+
142
+ if "sources_auto" not in state or state["sources_auto"] is None or state["sources_auto"] is False:
143
+ auto_mode = False
144
+ else:
145
+ auto_mode = True
146
+
147
+ sources_input = state.get("sources_input")
148
+ if sources_input is None: sources_input = ROUTING_INDEX["Vector"]
149
 
150
  new_state = {}
151
 
 
165
 
166
  question_state.update(analysis_output)
167
  questions.append(question_state)
168
+
169
+ # Explode the questions into multiple questions with different sources
170
+ new_questions = []
171
+ for q in questions:
172
+ question,sources = q["question"],q["sources"]
173
+
174
+ # If not auto mode we take the configuration
175
+ if not auto_mode:
176
+ sources = sources_input
177
+
178
+ for index,index_sources in ROUTING_INDEX.items():
179
+ selected_sources = list(set(sources).intersection(index_sources))
180
+ if len(selected_sources) > 0:
181
+ new_questions.append({"question":question,"sources":selected_sources,"index":index})
182
+
183
+ # # Add the number of questions to search
184
+ # k_by_question = k_final // len(new_questions)
185
+ # for q in new_questions:
186
+ # q["k"] = k_by_question
187
+
188
+ # new_state["questions"] = new_questions
189
+ # new_state["remaining_questions"] = new_questions
190
+
191
+ new_state = {
192
+ "remaining_questions":new_questions,
193
+ "n_questions":len(new_questions),
194
+ }
195
 
196
  return new_state
197
 
climateqa/engine/chains/retrieve_documents.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ from contextlib import contextmanager
4
+
5
+ from langchain_core.tools import tool
6
+ from langchain_core.runnables import chain
7
+ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
8
+ from langchain_core.runnables import RunnableLambda
9
+
10
+ from ..reranker import rerank_docs
11
+ from ...knowledge.retriever import ClimateQARetriever
12
+ from ...knowledge.openalex import OpenAlexRetriever
13
+ from .keywords_extraction import make_keywords_extraction_chain
14
+ from ..utils import log_event
15
+
16
+
17
+
18
def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` integers that are as equal as possible.

    The first ``target % parts`` entries receive one extra unit, so the
    entries always sum to ``target``.

    Args:
        target: Integer total to distribute.
        parts: Number of entries to produce.

    Returns:
        A list of ``parts`` integers summing to ``target``, or an empty list
        when ``parts`` is zero or negative.
    """
    # Nothing to distribute into: return early instead of dividing by zero.
    if parts <= 0:
        return []

    base, remainder = divmod(target, parts)
    # The first `remainder` slots absorb the leftover units, one each.
    return [base + 1 if i < remainder else base for i in range(parts)]
35
+
36
+
37
@contextmanager
def suppress_output():
    """Silence ``sys.stdout`` and ``sys.stderr`` for the duration of the block.

    Both streams are pointed at the OS null device while the ``with`` body
    runs and are restored afterwards, including when the body raises.
    Only the Python-level ``sys.stdout`` / ``sys.stderr`` objects are
    replaced.
    """
    # Local import: the module only imports `contextmanager` from contextlib.
    from contextlib import redirect_stderr, redirect_stdout

    with open(os.devnull, "w") as devnull:
        # The stdlib redirectors save and restore the previous streams.
        with redirect_stdout(devnull), redirect_stderr(devnull):
            yield
53
+
54
+
55
+ @tool
56
+ def query_retriever(question):
57
+ """Just a dummy tool to simulate the retriever query"""
58
+ return question
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
def make_retriever_node(
    vectorstore,
    reranker,
    llm,
    rerank_by_question=True,
    k_final=15,
    k_before_reranking=100,
    k_summary=5,
):
    """Build the graph node that retrieves documents for one pending question.

    Each call of the returned node consumes the first entry of
    ``state["remaining_questions"]``, queries the index named by that entry
    (``"Vector"`` or ``"OpenAlex"``), optionally reranks the hits, and
    returns the updated ``documents`` and ``remaining_questions`` keys.

    Args:
        vectorstore: Vector store handed to ``ClimateQARetriever``.
        reranker: Reranker passed to ``rerank_docs``; ``None`` disables
            reranking and falls back to the similarity score.
        llm: Model used to build the keyword-extraction chain for OpenAlex.
        rerank_by_question: When true, keep only the top documents of each
            question (``k_final`` shared across the questions).
        k_final: Total number of documents targeted across all questions.
        k_before_reranking: Number of candidates fetched per question.
        k_summary: Forwarded to ``ClimateQARetriever`` as ``k_summary``.

    Returns:
        The ``retrieve_documents`` runnable to register as a graph node.

    Raises:
        ValueError: (from the node) if a question targets an unknown index.
    """
    # The keyword chain depends only on `llm`: build it once, not per call.
    keywords_extraction = make_keywords_extraction_chain(llm)

    # The chain decorator is not strictly necessary, but it propagates the
    # langchain callbacks to the astream_events logger so that intermediate
    # results can be displayed.
    @chain
    async def retrieve_documents(state, config):
        # The node runs once per question and the `documents` state key is
        # replaced by whatever is returned, so start from the documents
        # already gathered for earlier questions instead of an empty list.
        docs = list(state.get("documents") or [])

        pending = state["remaining_questions"]
        if not pending:
            # Nothing left to retrieve: leave the state untouched.
            return {"documents": docs, "remaining_questions": []}

        current_question = pending[0]
        remaining_questions = pending[1:]

        # Share k_final between the questions, keeping at least one document
        # per question even when there are more questions than k_final.
        k_by_question = max(1, k_final // max(1, state["n_questions"]))

        sources = current_question["sources"]
        question = current_question["question"]
        index = current_question["index"]

        await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)

        if index == "Vector":
            # Search the document store with a high top-k; the reranking
            # step below narrows the candidates down.
            retriever = ClimateQARetriever(
                vectorstore=vectorstore,
                sources=sources,
                min_size=200,
                k_summary=k_summary,
                k_total=k_before_reranking,
                threshold=0.5,
            )
            docs_question = await retriever.ainvoke(question, config)

        elif index == "OpenAlex":
            # Await the async entry point so the LLM call does not block the
            # event loop of this async node.
            extraction = await keywords_extraction.ainvoke(question)
            openalex_query = " AND ".join(extraction["keywords"])

            print(f"... OpenAlex query: {openalex_query}")

            retriever_openalex = OpenAlexRetriever(
                min_year=state.get("min_year", 1960),
                max_year=state.get("max_year", None),
                k=k_before_reranking,
            )
            docs_question = await retriever_openalex.ainvoke(openalex_query, config)

        else:
            raise ValueError(f"Index {index} not found in the routing index")

        if reranker is not None:
            with suppress_output():
                docs_question = rerank_docs(reranker, docs_question, question)
        else:
            # No reranker: reuse the similarity score as the ranking score.
            # Documents without one default to 0.0 instead of raising
            # KeyError.
            for doc in docs_question:
                doc.metadata["reranking_score"] = doc.metadata.get("similarity_score", 0.0)

        # When reranking by question, keep only the top documents of this
        # question.
        if rerank_by_question:
            docs_question = docs_question[:k_by_question]

        # Record which sources, question and index produced each document.
        for doc in docs_question:
            doc.metadata["sources_used"] = sources
            doc.metadata["question_used"] = question
            doc.metadata["index_used"] = index

        docs.extend(docs_question)

        # Best documents first, across all questions processed so far.
        docs.sort(key=lambda d: d.metadata["reranking_score"], reverse=True)
        return {"documents": docs, "remaining_questions": remaining_questions}

    return retrieve_documents
159
+
climateqa/engine/graph.py CHANGED
@@ -16,10 +16,9 @@ from .chains.answer_ai_impact import make_ai_impact_node
16
  from .chains.query_transformation import make_query_transform_node
17
  from .chains.translation import make_translation_node
18
  from .chains.intent_categorization import make_intent_categorization_node
19
- from .chains.retriever import make_retriever_node
20
  from .chains.answer_rag import make_rag_node
21
 
22
-
23
  class GraphState(TypedDict):
24
  """
25
  Represents the state of our graph.
@@ -29,23 +28,30 @@ class GraphState(TypedDict):
29
  intent : str
30
  search_graphs_chitchat : bool
31
  query: str
32
- questions : List[dict]
 
33
  answer: str
34
  audience: str = "experts"
35
- sources_input: List[str] = ["auto"]
 
 
 
36
  documents: List[Document]
37
  recommended_content : List[Document]
38
  # graphs_returned: Dict[str,str]
39
 
40
- def search(state):
41
- return {}
 
 
 
42
 
43
  def route_intent(state):
44
  intent = state["intent"]
45
  if intent in ["chitchat","esg"]:
46
  return "answer_chitchat"
47
- elif intent == "ai_impact":
48
- return "answer_ai_impact"
49
  else:
50
  # Search route
51
  return "search"
@@ -95,6 +101,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
95
  workflow.add_node("set_defaults", set_defaults)
96
  workflow.add_node("categorize_intent", categorize_intent)
97
  workflow.add_node("search", search)
 
98
  workflow.add_node("transform_query", transform_query)
99
  workflow.add_node("translate_query", translate_query)
100
  # workflow.add_node("transform_query_ai", transform_query)
@@ -118,7 +125,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
118
  workflow.add_conditional_edges(
119
  "categorize_intent",
120
  route_intent,
121
- make_id_dict(["answer_chitchat","answer_ai_impact","search"])
122
  )
123
 
124
  workflow.add_conditional_edges(
@@ -132,9 +139,14 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
132
  route_translation,
133
  make_id_dict(["translate_query","transform_query"])
134
  )
135
-
136
  workflow.add_conditional_edges(
137
  "retrieve_documents",
 
 
 
 
 
 
138
  lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
139
  make_id_dict(["answer_rag","answer_rag_no_docs"])
140
  )
 
16
  from .chains.query_transformation import make_query_transform_node
17
  from .chains.translation import make_translation_node
18
  from .chains.intent_categorization import make_intent_categorization_node
19
+ from .chains.retrieve_documents import make_retriever_node
20
  from .chains.answer_rag import make_rag_node
21
 
 
22
  class GraphState(TypedDict):
23
  """
24
  Represents the state of our graph.
 
28
  intent : str
29
  search_graphs_chitchat : bool
30
  query: str
31
+ remaining_questions : List[dict]
32
+ n_questions : int
33
  answer: str
34
  audience: str = "experts"
35
+ sources_input: List[str] = ["IPCC","IPBES"]
36
+ sources_auto: bool = True
37
+ min_year: int = 1960
38
+ max_year: int = None
39
  documents: List[Document]
40
  recommended_content : List[Document]
41
  # graphs_returned: Dict[str,str]
42
 
43
+ def search(state): #TODO
44
+ return state
45
+
46
+ def answer_search(state):#TODO
47
+ return state
48
 
49
  def route_intent(state):
50
  intent = state["intent"]
51
  if intent in ["chitchat","esg"]:
52
  return "answer_chitchat"
53
+ # elif intent == "ai_impact":
54
+ # return "answer_ai_impact"
55
  else:
56
  # Search route
57
  return "search"
 
101
  workflow.add_node("set_defaults", set_defaults)
102
  workflow.add_node("categorize_intent", categorize_intent)
103
  workflow.add_node("search", search)
104
+ workflow.add_node("answer_search", answer_search)
105
  workflow.add_node("transform_query", transform_query)
106
  workflow.add_node("translate_query", translate_query)
107
  # workflow.add_node("transform_query_ai", transform_query)
 
125
  workflow.add_conditional_edges(
126
  "categorize_intent",
127
  route_intent,
128
+ make_id_dict(["answer_chitchat","search"])
129
  )
130
 
131
  workflow.add_conditional_edges(
 
139
  route_translation,
140
  make_id_dict(["translate_query","transform_query"])
141
  )
 
142
  workflow.add_conditional_edges(
143
  "retrieve_documents",
144
+ lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
145
+ make_id_dict(["retrieve_documents","answer_search"])
146
+ )
147
+
148
+ workflow.add_conditional_edges(
149
+ "answer_search",
150
  lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
151
  make_id_dict(["answer_rag","answer_rag_no_docs"])
152
  )
climateqa/engine/llm/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  from climateqa.engine.llm.openai import get_llm as get_openai_llm
2
  from climateqa.engine.llm.azure import get_llm as get_azure_llm
 
3
 
4
 
5
  def get_llm(provider="openai",**kwargs):
@@ -8,6 +9,8 @@ def get_llm(provider="openai",**kwargs):
8
  return get_openai_llm(**kwargs)
9
  elif provider == "azure":
10
  return get_azure_llm(**kwargs)
 
 
11
  else:
12
  raise ValueError(f"Unknown provider: {provider}")
13
 
 
1
  from climateqa.engine.llm.openai import get_llm as get_openai_llm
2
  from climateqa.engine.llm.azure import get_llm as get_azure_llm
3
+ from climateqa.engine.llm.ollama import get_llm as get_ollama_llm
4
 
5
 
6
  def get_llm(provider="openai",**kwargs):
 
9
  return get_openai_llm(**kwargs)
10
  elif provider == "azure":
11
  return get_azure_llm(**kwargs)
12
+ elif provider == "ollama":
13
+ return get_ollama_llm(**kwargs)
14
  else:
15
  raise ValueError(f"Unknown provider: {provider}")
16
 
climateqa/engine/llm/ollama.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+
3
+ from langchain_community.llms import Ollama
4
+
5
+ def get_llm(model="llama3", **kwargs):
6
+ return Ollama(model=model, **kwargs)
climateqa/engine/utils.py CHANGED
@@ -1,8 +1,15 @@
1
  from operator import itemgetter
2
  from typing import Any, Dict, Iterable, Tuple
 
3
  from langchain_core.runnables import RunnablePassthrough
4
 
5
 
 
 
 
 
 
 
6
  def pass_values(x):
7
  if not isinstance(x, list):
8
  x = [x]
@@ -67,3 +74,13 @@ def flatten_dict(
67
  """
68
  flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
69
  return flat_dict
 
 
 
 
 
 
 
 
 
 
 
1
  from operator import itemgetter
2
  from typing import Any, Dict, Iterable, Tuple
3
+ import tiktoken
4
  from langchain_core.runnables import RunnablePassthrough
5
 
6
 
7
def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Count the tokens produced when encoding ``string`` with tiktoken.

    Args:
        string: Text to tokenize.
        encoding_name: Name of the tiktoken encoding to look up.

    Returns:
        The length of the encoded token sequence.
    """
    return len(tiktoken.get_encoding(encoding_name).encode(string))
11
+
12
+
13
  def pass_values(x):
14
  if not isinstance(x, list):
15
  x = [x]
 
74
  """
75
  flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
76
  return flat_dict
77
+
78
+
79
+
80
async def log_event(info,name,config):
    """Emit ``info`` through a named no-op runnable.

    A ``RunnablePassthrough`` configured with ``run_name=name`` is invoked
    with ``info`` and ``config``. The intent is that an ``astream_events``
    consumer sees that named run and can forward the dict to a logger.

    Args:
        info: Payload passed unchanged through the runnable.
        name: Run name attached to the passthrough runnable.
        config: Runnable config forwarded to ``ainvoke``.
    """
    passthrough = RunnablePassthrough().with_config(run_name=name)
    # The result is just `info` again; only the emitted run matters.
    await passthrough.ainvoke(info, config)
climateqa/knowledge/__init__.py ADDED
File without changes
climateqa/{papers → knowledge}/openalex.py RENAMED
@@ -3,18 +3,32 @@ import networkx as nx
3
  import matplotlib.pyplot as plt
4
  from pyvis.network import Network
5
 
 
 
 
 
 
 
 
 
 
 
 
6
  from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
7
  import pyalex
8
 
9
  pyalex.config.email = "[email protected]"
10
 
 
 
 
 
11
  class OpenAlex():
12
  def __init__(self):
13
  pass
14
 
15
 
16
-
17
- def search(self,keywords,n_results = 100,after = None,before = None):
18
 
19
  if isinstance(keywords,str):
20
  works = Works().search(keywords)
@@ -27,18 +41,21 @@ class OpenAlex():
27
  break
28
 
29
  df_works = pd.DataFrame(page)
30
- df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x))
 
 
31
  df_works["is_oa"] = df_works["open_access"].map(lambda x : x.get("is_oa",False))
32
  df_works["pdf_url"] = df_works["primary_location"].map(lambda x : x.get("pdf_url",None))
33
- df_works["content"] = df_works["title"] + "\n" + df_works["abstract"]
34
-
 
 
 
 
 
 
35
  else:
36
- df_works = []
37
- for keyword in keywords:
38
- df_keyword = self.search(keyword,n_results = n_results,after = after,before = before)
39
- df_works.append(df_keyword)
40
- df_works = pd.concat(df_works,ignore_index=True,axis = 0)
41
- return df_works
42
 
43
 
44
  def rerank(self,query,df,reranker):
@@ -139,4 +156,36 @@ class OpenAlex():
139
  reconstructed[position] = token
140
 
141
  # Join the tokens to form the reconstructed sentence(s)
142
- return ' '.join(reconstructed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import matplotlib.pyplot as plt
4
  from pyvis.network import Network
5
 
6
+ from langchain_core.retrievers import BaseRetriever
7
+ from langchain_core.vectorstores import VectorStoreRetriever
8
+ from langchain_core.documents.base import Document
9
+ from langchain_core.vectorstores import VectorStore
10
+ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
11
+
12
+ from ..engine.utils import num_tokens_from_string
13
+
14
+ from typing import List
15
+ from pydantic import Field
16
+
17
  from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
18
  import pyalex
19
 
20
  pyalex.config.email = "[email protected]"
21
 
22
+
23
def replace_nan_with_empty_dict(x):
    """Return ``x`` unchanged unless pandas considers it missing.

    Values for which ``pd.notna`` is false (``None``, ``NaN``) are replaced
    by a fresh empty dict.
    """
    if pd.notna(x):
        return x
    return {}
25
+
26
  class OpenAlex():
27
  def __init__(self):
28
  pass
29
 
30
 
31
+ def search(self,keywords:str,n_results = 100,after = None,before = None):
 
32
 
33
  if isinstance(keywords,str):
34
  works = Works().search(keywords)
 
41
  break
42
 
43
  df_works = pd.DataFrame(page)
44
+ df_works = df_works.dropna(subset = ["title"])
45
+ df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
46
+ df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")
47
  df_works["is_oa"] = df_works["open_access"].map(lambda x : x.get("is_oa",False))
48
  df_works["pdf_url"] = df_works["primary_location"].map(lambda x : x.get("pdf_url",None))
49
+ df_works["url"] = df_works["id"]
50
+ df_works["content"] = (df_works["title"] + "\n" + df_works["abstract"]).map(lambda x : x.strip())
51
+ df_works["num_tokens"] = df_works["content"].map(lambda x : num_tokens_from_string(x))
52
+
53
+ df_works = df_works.drop(columns = ["abstract_inverted_index"])
54
+ # df_works["subtitle"] = df_works["title"] + " - " + df_works["primary_location"]["source"]["display_name"] + " - " + df_works["publication_year"]
55
+
56
+ return df_works
57
  else:
58
+ raise Exception("Keywords must be a string")
 
 
 
 
 
59
 
60
 
61
  def rerank(self,query,df,reranker):
 
156
  reconstructed[position] = token
157
 
158
  # Join the tokens to form the reconstructed sentence(s)
159
+ return ' '.join(reconstructed)
160
+
161
+
162
+
163
class OpenAlexRetriever(BaseRetriever):
    """Retriever that turns OpenAlex search results into LangChain documents.

    The query is sent to ``OpenAlex.search``. Each returned row whose token
    count lies within ``[min_tokens, max_tokens]`` becomes a ``Document``
    with the row's ``content`` as page content and the whole row as
    metadata.
    """

    # Publication-year bounds, forwarded to OpenAlex.search as after/before.
    min_year:int = 1960
    # NOTE(review): defaults to None although annotated `int`; Optional[int]
    # would be more accurate. Confirm against the pydantic version in use.
    max_year:int = None
    # Number of results requested from OpenAlex.
    k:int = 100
    # Inclusive token-count window; rows outside it are skipped. The defaults
    # are the previously hard-coded bounds.
    min_tokens:int = 50
    max_tokens:int = 1000

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search OpenAlex for ``query`` and wrap the usable rows as documents.

        Args:
            query: Search string passed to ``OpenAlex.search``.
            run_manager: Callback manager supplied by the retriever
                framework (not used here).

        Returns:
            One ``Document`` per row whose ``num_tokens`` is within the
            configured window.
        """
        openalex = OpenAlex()

        df_docs = openalex.search(
            query,
            n_results=self.k,
            after=self.min_year,
            before=self.max_year,
        )

        docs = []
        for _, row in df_docs.iterrows():
            num_tokens = row["num_tokens"]

            # Skip rows whose title + abstract is too short or too long.
            if num_tokens < self.min_tokens or num_tokens > self.max_tokens:
                continue

            docs.append(
                Document(
                    page_content=row["content"],
                    metadata=row.to_dict(),
                )
            )
        return docs
190
+
191
+
climateqa/{engine → knowledge}/retriever.py RENAMED
@@ -67,6 +67,7 @@ class ClimateQARetriever(BaseRetriever):
67
  # Add score to metadata
68
  results = []
69
  for i,(doc,score) in enumerate(docs):
 
70
  doc.metadata["similarity_score"] = score
71
  doc.metadata["content"] = doc.page_content
72
  doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
@@ -79,86 +80,3 @@ class ClimateQARetriever(BaseRetriever):
79
  return results
80
 
81
 
82
-
83
-
84
- # def filter_summaries(df,k_summary = 3,k_total = 10):
85
- # # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
86
-
87
- # # # Filter by source
88
- # # if source == "IPCC":
89
- # # df = df.loc[df["source"]=="IPCC"]
90
- # # elif source == "IPBES":
91
- # # df = df.loc[df["source"]=="IPBES"]
92
- # # else:
93
- # # pass
94
-
95
- # # Separate summaries and full reports
96
- # df_summaries = df.loc[df["report_type"].isin(["SPM","TS"])]
97
- # df_full = df.loc[~df["report_type"].isin(["SPM","TS"])]
98
-
99
- # # Find passages from summaries dataset
100
- # passages_summaries = df_summaries.head(k_summary)
101
-
102
- # # Find passages from full reports dataset
103
- # passages_fullreports = df_full.head(k_total - len(passages_summaries))
104
-
105
- # # Concatenate passages
106
- # passages = pd.concat([passages_summaries,passages_fullreports],axis = 0,ignore_index = True)
107
- # return passages
108
-
109
-
110
-
111
-
112
- # def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
113
- # assert max_k > k_total
114
-
115
- # validated_sources = ["IPCC","IPBES"]
116
- # sources = [x for x in sources if x in validated_sources]
117
- # filters = {
118
- # "source": { "$in": sources },
119
- # }
120
- # print(filters)
121
-
122
- # # Retrieve documents
123
- # docs = retriever.retrieve(query,top_k = max_k,filters = filters)
124
-
125
- # # Filter by score
126
- # docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs if x.score > threshold]
127
-
128
- # if len(docs) == 0:
129
- # return []
130
- # res = pd.DataFrame(docs)
131
- # passages_df = filter_summaries(res,k_summary,k_total)
132
- # if as_dict:
133
- # contents = passages_df["content"].tolist()
134
- # meta = passages_df.drop(columns = ["content"]).to_dict(orient = "records")
135
- # passages = []
136
- # for i in range(len(contents)):
137
- # passages.append({"content":contents[i],"meta":meta[i]})
138
- # return passages
139
- # else:
140
- # return passages_df
141
-
142
-
143
-
144
- # def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):
145
-
146
-
147
- # print("hellooooo")
148
-
149
- # # Reformulate queries
150
- # reformulated_query,language = reformulate(query)
151
-
152
- # print(reformulated_query)
153
-
154
- # # Retrieve documents
155
- # passages = retrieve_with_summaries(reformulated_query,retriever,k_total = k,k_summary = 3,as_dict = True,sources = sources,threshold = threshold)
156
- # response = {
157
- # "query":query,
158
- # "reformulated_query":reformulated_query,
159
- # "language":language,
160
- # "sources":passages,
161
- # "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
162
- # }
163
- # return response
164
-
 
67
  # Add score to metadata
68
  results = []
69
  for i,(doc,score) in enumerate(docs):
70
+ doc.page_content = doc.page_content.replace("\r\n"," ")
71
  doc.metadata["similarity_score"] = score
72
  doc.metadata["content"] = doc.page_content
73
  doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
 
80
  return results
81
 
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/papers/__init__.py DELETED
@@ -1,43 +0,0 @@
1
- import pandas as pd
2
-
3
- from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
4
- import pyalex
5
-
6
- pyalex.config.email = "[email protected]"
7
-
8
- class OpenAlex():
9
- def __init__(self):
10
- pass
11
-
12
-
13
-
14
- def search(self,keywords,n_results = 100,after = None,before = None):
15
- works = Works().search(keywords).get()
16
-
17
- for page in works.paginate(per_page=n_results):
18
- break
19
-
20
- df_works = pd.DataFrame(page)
21
-
22
- return works
23
-
24
-
25
- def make_network(self):
26
- pass
27
-
28
-
29
- def get_abstract_from_inverted_index(self,index):
30
-
31
- # Determine the maximum index to know the length of the reconstructed array
32
- max_index = max([max(positions) for positions in index.values()])
33
-
34
- # Initialize a list with placeholders for all positions
35
- reconstructed = [''] * (max_index + 1)
36
-
37
- # Iterate through the inverted index and place each token at its respective position(s)
38
- for token, positions in index.items():
39
- for position in positions:
40
- reconstructed[position] = token
41
-
42
- # Join the tokens to form the reconstructed sentence(s)
43
- return ' '.join(reconstructed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- gradio==4.19.1
2
  azure-storage-file-share==12.11.1
3
  azure-storage-blob
4
  python-dotenv==1.0.0
@@ -15,3 +15,5 @@ flashrank==0.2.5
15
  rerankers==0.3.0
16
  torch==2.3.0
17
  nvidia-cudnn-cu12==8.9.2.26
 
 
 
1
+ gradio==4.44
2
  azure-storage-file-share==12.11.1
3
  azure-storage-blob
4
  python-dotenv==1.0.0
 
15
  rerankers==0.3.0
16
  torch==2.3.0
17
  nvidia-cudnn-cu12==8.9.2.26
18
+ langchain-community==0.2
19
+ msal==1.31
sandbox/20240310 - CQA - Semantic Routing 1.ipynb CHANGED
The diff for this file is too large to render. See raw diff