import os
import sys
from contextlib import contextmanager
from typing import Dict, List, Optional

from IPython.display import HTML, Image, display
from langchain.schema import Document
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict

from .chains.answer_chitchat import make_chitchat_node
from .chains.answer_ai_impact import make_ai_impact_node
from .chains.query_transformation import make_query_transform_node
from .chains.translation import make_translation_node
from .chains.intent_categorization import make_intent_categorization_node
from .chains.retrieve_documents import make_retriever_node
from .chains.answer_rag import make_rag_node
from .chains.graph_retriever import make_graph_retriever_node
from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
from .chains.set_defaults import set_defaults
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    This is the schema handed to ``StateGraph`` in ``make_graph_agent``; the
    graph's nodes and routers exchange data through these keys.

    NOTE(review): the values assigned to some fields below (``audience``,
    ``sources_input``, ``relevant_content_sources``, ``sources_auto``,
    ``min_year``, ``max_year``) only become class attributes. ``TypedDict``
    does not apply them to state dicts, so these keys must be supplied by the
    caller or by a node. Presumably ``set_defaults`` (imported at the top of
    this module) does that; TODO confirm.
    """

    user_input: str
    # Read by route_translation: "english" (any case) skips translation.
    language: str
    # Read by route_intent: "chitchat" and "esg" go to answer_chitchat.
    intent: str
    # Read by chitchat_route_intent: True triggers retrieve_graphs_chitchat.
    search_graphs_chitchat: bool
    query: str
    # The retrieve_documents node is re-entered while this list is non-empty.
    remaining_questions: List[dict]
    n_questions: int
    answer: str
    audience: str = "experts"
    sources_input: List[str] = ["IPCC", "IPBES"]
    # "OurWorldInData" in this list makes transform_query also run retrieve_graphs.
    relevant_content_sources: List[str] = ["IPCC figures"]
    sources_auto: bool = True
    min_year: int = 1960
    # The declared default is None, so the annotation has to admit None.
    max_year: Optional[int] = None
    # route_based_on_relevant_docs compares each document's
    # metadata["reranking_score"] against its threshold.
    documents: List[Document]
    related_contents: Dict[str, Document]
    recommended_content: List[Document]
def search(state):
    """Pass-through node at the head of the search branch.

    Returns ``state`` unchanged. The choice between translating and directly
    transforming the query is made afterwards by ``route_translation``.
    """
    return state
def answer_search(state):
    """Pass-through node reached once document retrieval is finished.

    Returns ``state`` unchanged. Which RAG answer node runs next is decided by
    ``route_based_on_relevant_docs``.
    """
    return state
def route_intent(state):
    """Route after intent categorization.

    Returns ``"answer_chitchat"`` when ``state["intent"]`` is ``"chitchat"`` or
    ``"esg"``, and ``"search"`` for every other intent.
    """
    chitchat_like = ("chitchat", "esg")
    return "answer_chitchat" if state["intent"] in chitchat_like else "search"
def chitchat_route_intent(state):
    """Decide whether a chitchat turn should also search for graphs.

    Returns ``"retrieve_graphs_chitchat"`` only when
    ``state["search_graphs_chitchat"]`` is exactly ``True``. Any other value
    (``False``, ``None``, a non-bool, or a missing key) routes to ``END``.

    The fallback guarantees a valid destination. Returning ``None`` would
    match no entry of the conditional-edge path map.
    """
    if state.get("search_graphs_chitchat") is True:
        return "retrieve_graphs_chitchat"
    return END
def route_translation(state):
    """Route after the ``search`` node based on the detected language.

    English input (compared case-insensitively) goes straight to
    ``"transform_query"``. Any other language is sent to ``"translate_query"``
    first.
    """
    is_english = state["language"].lower() == "english"
    if not is_english:
        return "translate_query"
    return "transform_query"
def route_based_on_relevant_docs(state, threshold_docs=0.2):
    """Choose the RAG answer node depending on whether any document is relevant.

    Args:
        state: graph state. ``state["documents"]`` holds objects with a
            ``metadata`` mapping that contains ``"reranking_score"``. A missing
            or ``None`` ``documents`` entry is treated as no documents.
        threshold_docs: a document counts as relevant when its reranking
            score is strictly greater than this value.

    Returns:
        ``"answer_rag"`` if at least one document exceeds the threshold,
        otherwise ``"answer_rag_no_docs"``.
    """
    documents = state.get("documents") or []
    # any() stops at the first relevant document instead of materialising
    # a filtered list only to check its length.
    if any(doc.metadata["reranking_score"] > threshold_docs for doc in documents):
        return "answer_rag"
    return "answer_rag_no_docs"
def make_id_dict(values):
    """Return a dict mapping every item of ``values`` to itself.

    Used to build identity path maps for ``add_conditional_edges``, where the
    router's return value is also the destination node name.
    """
    identity = {}
    for value in values:
        identity[value] = value
    return identity
def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, threshold_docs=0.2):
    """Build and compile the LangGraph workflow for the question-answering agent.

    Wiring (as set up below):

    * ``categorize_intent`` goes to ``answer_chitchat`` for chitchat/ESG
      intents, otherwise to ``search``.
    * ``answer_chitchat`` goes to ``chitchat_categorize_intent``, which either
      ends or runs ``retrieve_graphs_chitchat``.
    * ``search`` goes to ``translate_query`` for non-English input, otherwise
      straight to ``transform_query``. ``translate_query`` always continues to
      ``transform_query``.
    * ``transform_query`` always continues to ``retrieve_documents``. It also
      runs ``retrieve_graphs`` when "OurWorldInData" is among
      ``relevant_content_sources``.
    * ``retrieve_documents`` re-enters itself while ``remaining_questions`` is
      non-empty, then goes to ``answer_search``.
    * ``answer_search`` goes to ``answer_rag`` if any document scores above
      ``threshold_docs``, else to ``answer_rag_no_docs``.

    Args:
        llm: language model passed to every LLM-backed node factory.
        vectorstore_ipcc: vector store passed to the document retriever.
        vectorstore_graphs: vector store passed to the graph retriever.
        reranker: reranker shared by the document and graph retrievers.
        threshold_docs: reranking score a document must strictly exceed for
            the documents-based RAG answer to be used.

    Returns:
        The result of ``StateGraph.compile()``.
    """
    workflow = StateGraph(GraphState)

    # Build node callables.
    categorize_intent = make_intent_categorization_node(llm)
    transform_query = make_query_transform_node(llm)
    translate_query = make_translation_node(llm)
    answer_chitchat = make_chitchat_node(llm)
    retrieve_documents = make_retriever_node(vectorstore_ipcc, reranker, llm)
    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
    answer_rag = make_rag_node(llm, with_docs=True)
    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)

    # Register nodes.
    workflow.add_node("categorize_intent", categorize_intent)
    workflow.add_node("search", search)
    workflow.add_node("answer_search", answer_search)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("translate_query", translate_query)
    workflow.add_node("answer_chitchat", answer_chitchat)
    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
    workflow.add_node("retrieve_graphs", retrieve_graphs)
    # Same retriever as above, registered under a second name for the chitchat path.
    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
    workflow.add_node("retrieve_documents", retrieve_documents)
    workflow.add_node("answer_rag", answer_rag)
    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)

    def route_retrieval_loop(state):
        # Keep re-entering retrieve_documents while sub-questions remain.
        # Presumably that node consumes one entry per pass (TODO confirm).
        # A missing/None list is treated as empty.
        if state.get("remaining_questions"):
            return "retrieve_documents"
        return "answer_search"

    def route_answer(state):
        return route_based_on_relevant_docs(state, threshold_docs=threshold_docs)

    def route_graph_retrieval(state):
        # A missing key is treated as an empty source list. The TypedDict
        # class-level default is not applied to the state at runtime.
        sources = state.get("relevant_content_sources") or []
        return "retrieve_graphs" if "OurWorldInData" in sources else END

    # Conditional routing.
    workflow.set_entry_point("categorize_intent")
    workflow.add_conditional_edges(
        "categorize_intent",
        route_intent,
        make_id_dict(["answer_chitchat", "search"]),
    )
    workflow.add_conditional_edges(
        "chitchat_categorize_intent",
        chitchat_route_intent,
        make_id_dict(["retrieve_graphs_chitchat", END]),
    )
    workflow.add_conditional_edges(
        "search",
        route_translation,
        make_id_dict(["translate_query", "transform_query"]),
    )
    workflow.add_conditional_edges(
        "retrieve_documents",
        route_retrieval_loop,
        make_id_dict(["retrieve_documents", "answer_search"]),
    )
    workflow.add_conditional_edges(
        "answer_search",
        route_answer,
        make_id_dict(["answer_rag", "answer_rag_no_docs"]),
    )
    # transform_query fans out: this branch optionally adds retrieve_graphs,
    # while the plain edge below always continues to retrieve_documents.
    workflow.add_conditional_edges(
        "transform_query",
        route_graph_retrieval,
        make_id_dict(["retrieve_graphs", END]),
    )

    # Static edges.
    workflow.add_edge("translate_query", "transform_query")
    workflow.add_edge("transform_query", "retrieve_documents")
    workflow.add_edge("retrieve_graphs", END)
    # Mirror retrieve_graphs so the chitchat graph retriever is not a dangling node.
    workflow.add_edge("retrieve_graphs_chitchat", END)
    workflow.add_edge("answer_rag", END)
    workflow.add_edge("answer_rag_no_docs", END)
    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")

    app = workflow.compile()
    return app
def display_graph(app):
    """Render ``app``'s graph (with ``xray=True``) as a Mermaid PNG in IPython.

    The PNG is produced by ``draw_mermaid_png`` using
    ``MermaidDrawMethod.API``. The image is then shown with IPython's
    ``display``.
    """
    graph = app.get_graph(xray=True)
    png_bytes = graph.draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    display(Image(png_bytes))