|
from climateqa.engine.embeddings import get_embeddings_function |
|
embeddings_function = get_embeddings_function() |
|
|
|
from climateqa.knowledge.openalex import OpenAlex |
|
from sentence_transformers import CrossEncoder |
|
|
|
|
|
oa = OpenAlex() |
|
|
|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
import time |
|
import re |
|
import json |
|
|
|
|
|
|
|
from io import BytesIO |
|
import base64 |
|
|
|
from datetime import datetime |
|
from azure.storage.fileshare import ShareServiceClient |
|
|
|
from utils import create_user_id |
|
|
|
from langchain_chroma import Chroma |
|
from collections import defaultdict |
|
|
|
|
|
from climateqa.engine.llm import get_llm |
|
from climateqa.engine.vectorstore import get_pinecone_vectorstore |
|
from climateqa.knowledge.retriever import ClimateQARetriever |
|
from climateqa.engine.reranker import get_reranker |
|
from climateqa.engine.embeddings import get_embeddings_function |
|
from climateqa.engine.chains.prompts import audience_prompts |
|
from climateqa.sample_questions import QUESTIONS |
|
from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES |
|
from climateqa.utils import get_image_from_azure_blob_storage |
|
from climateqa.engine.keywords import make_keywords_chain |
|
|
|
from climateqa.engine.graph import make_graph_agent,display_graph |
|
from climateqa.engine.embeddings import get_embeddings_function |
|
|
|
from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs |
|
|
|
|
|
# Load variables from a local .env file when python-dotenv is installed.
# This stays best-effort (the app must still start without dotenv), but the
# reason for skipping is now reported instead of being silently discarded.
try:
    from dotenv import load_dotenv

    load_dotenv()
except Exception as e:
    print(f"Skipping .env loading: {e}")


# Gradio theme shared by the whole interface.
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)


# System message template; its content is the (empty) prompt defined here.
init_prompt = ""

system_template = {
    "role": "system",
    "content": init_prompt,
}

# Azure file-share credentials. All three environment variables are required:
# a missing one raises KeyError at import time.
account_key = os.environ["BLOB_ACCOUNT_KEY"]
if len(account_key) == 86:
    # NOTE(review): presumably restores base64 padding stripped from the key
    # by the hosting environment -- confirm.
    account_key += "=="

credential = {
    "account_key": account_key,
    "account_name": os.environ["BLOB_ACCOUNT_NAME"],
}

account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"
service = ShareServiceClient(account_url=account_url, credential=credential)
share_client = service.get_share_client(file_share_name)

# Identifier attached to every log entry written by this process.
user_id = create_user_id()


# NOTE(review): get_embeddings_function() is also called near the top of the
# file; this second call is kept so the block behaves exactly as before.
embeddings_function = get_embeddings_function()
llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
reranker = get_reranker("nano")


vectorstore = get_pinecone_vectorstore(embeddings_function)

# The graphs vectorstore location used to be a hard-coded path on one
# developer machine. It can now be overridden with OWID_VECTORSTORE_PATH;
# the previous path remains the default so existing setups are unaffected.
vectorstore_graphs = Chroma(
    persist_directory=os.getenv(
        "OWID_VECTORSTORE_PATH",
        "/home/tim/ai4s/climate_qa/dora/climate-question-answering-graphs/climate-question-answering-graphs/vectorstore_owid",
    ),
    embedding_function=embeddings_function,
)


agent = make_graph_agent(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    reranker=reranker,
)
|
|
|
def _group_graphs_by_category(recommended_content):
    """Return de-duplicated ``[embedding, category]`` pairs grouped by category.

    Items are de-duplicated on ``metadata["returned_content"]``; categories keep
    the order in which they are first seen, and embeddings keep their order
    within each category.
    """
    grouped = {}
    seen_embeddings = set()
    for item in recommended_content:
        embedding = item.metadata["returned_content"]
        if embedding in seen_embeddings:
            continue
        seen_embeddings.add(embedding)
        grouped.setdefault(item.metadata["category"], []).append(embedding)
    return [
        [embedding, category]
        for category, embeddings in grouped.items()
        for embedding in embeddings
    ]


def _collect_doc_images(docs):
    """Fetch and base64-encode every document whose chunk type is ``"image"``.

    Returns a dict keyed by ``"Image <position>"``. A document that fails to
    load is reported with ``print`` and skipped.
    """
    image_dict = {}
    for i, doc in enumerate(docs):
        # .get(): a document without a "chunk_type" key is simply not an image
        # (the previous direct lookup raised KeyError outside the try block).
        if doc.metadata.get("chunk_type") != "image":
            continue
        try:
            key = f"Image {i+1}"
            image_path = doc.metadata["image_path"].split("documents/")[1]
            img = get_image_from_azure_blob_storage(image_path)

            # Serialize the image to PNG bytes, then to base64 text.
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()

            # Inline the image in Markdown as a data URI.
            markdown_image = f"![Alt text](data:image/png;base64,{img_str})"
            image_dict[key] = {
                "img": img,
                "md": markdown_image,
                "caption": doc.page_content,
                "key": key,
                "figure_code": doc.metadata["figure_code"],
            }
        except Exception as e:
            print(f"Skipped adding image {i} because of {e}")
    return image_dict


async def chat(query, history, audience, sources, reports, current_graphs):
    """Stream the agent's answer to ``query`` and yield UI updates.

    Args:
        query: The user question.
        history: Chat history as (user, bot) pairs; its last entry is
            overwritten with the answer as it is produced.
        audience: "Children", "General public" or "Experts"; any other value
            falls back to the experts prompt.
        sources: Source names to search; ``None`` or empty means
            ``["IPCC", "IPBES", "IPOS"]``.
        reports: Accepted for interface compatibility; it is not forwarded to
            the agent.
        current_graphs: Accepted for interface compatibility; the incoming
            value is discarded and rebuilt from the agent's events.

    Yields:
        ``(history, docs_html, output_query, output_language, gallery,
        current_graphs)`` after every agent event, and once more at the end
        with the first retrieved image (if any) appended to the answer.

    Raises:
        gr.Error: If the agent stream fails, or if logging the exchange fails.
    """
    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f">> NEW QUESTION ({date_now}) : {query}")

    # Unknown audiences use the experts prompt, as before.
    audience_keys = {
        "Children": "children",
        "General public": "general",
        "Experts": "experts",
    }
    audience_prompt = audience_prompts[audience_keys.get(audience, "experts")]

    if sources is None or len(sources) == 0:
        sources = ["IPCC", "IPBES", "IPOS"]

    inputs = {"user_input": query, "audience": audience_prompt, "sources": sources}
    print(f"\n\nInputs:\n {inputs}\n\n")
    result = agent.astream_events(inputs, version="v1")

    docs = []
    docs_used = True
    docs_html = ""
    current_graphs = []
    output_query = ""
    output_language = ""
    gallery = []
    start_streaming = False

    # Status text shown in place of the answer when a pipeline step starts.
    steps_display = {
        "categorize_intent": ("🔄️ Analyzing user message", True),
        "transform_query": ("🔄️ Thinking step by step to answer the question", True),
        "retrieve_documents": ("🔄️ Searching in the knowledge base", False),
    }

    try:
        async for event in result:

            if event["event"] == "on_chat_model_stream" and event["metadata"]["langgraph_node"] in ["answer_rag", "answer_rag_no_docs", "answer_chitchat", "answer_ai_impact"]:
                if not start_streaming:
                    # First token: drop the status text shown so far.
                    start_streaming = True
                    history[-1] = (query, "")

                new_token = event["data"]["chunk"].content

                previous_answer = history[-1][1]
                previous_answer = previous_answer if previous_answer is not None else ""
                answer_yet = parse_output_llm_with_sources(previous_answer + new_token)
                history[-1] = (query, answer_yet)

                # These answer nodes do not rely on retrieved documents, so
                # later retrieval events are ignored.
                if docs_used is True and event["metadata"]["langgraph_node"] in ["answer_rag_no_docs", "answer_chitchat", "answer_ai_impact"]:
                    docs_used = False

            elif docs_used is True and event["name"] == "retrieve_documents" and event["event"] == "on_chain_end":
                try:
                    docs = event["data"]["output"]["documents"]
                    docs_html = "".join(
                        make_html_source(d, i) for i, d in enumerate(docs, 1)
                    )
                except Exception as e:
                    print(f"Error getting documents: {e}")
                    print(event)

            elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
                try:
                    recommended_content = event["data"]["output"]["recommended_content"]
                    # Built separately and added in one go so a failure leaves
                    # current_graphs untouched.
                    current_graphs.extend(_group_graphs_by_category(recommended_content))
                except Exception as e:
                    print(f"Error getting graphs: {e}")

            # Independent of the branches above: show a status line whenever
            # one of the tracked steps starts.
            if event["event"] == "on_chain_start" and event["name"] in steps_display:
                event_description = steps_display[event["name"]][0]
                history[-1] = (query, event_description)

            history = [tuple(x) for x in history]
            yield history, docs_html, output_query, output_language, gallery, current_graphs

    except Exception as e:
        raise gr.Error(f"{e}") from e

    try:
        # Record the exchange unless running locally.
        if os.getenv("GRADIO_ENV") != "local":
            timestamp = str(datetime.now().timestamp())
            file = timestamp + ".json"
            prompt = history[-1][0]
            logs = {
                "user_id": str(user_id),
                "prompt": prompt,
                "query": prompt,
                "question": output_query,
                "sources": sources,
                "docs": serialize_docs(docs),
                "answer": history[-1][1],
                "time": timestamp,
            }
            log_on_azure(file, logs, share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") from e

    image_dict = _collect_doc_images(docs)

    if len(image_dict) > 0:
        images = list(image_dict.values())
        gallery = [x["img"] for x in images]

        # Only the first image is embedded in the chat answer.
        img = images[0]
        img_md = img["md"]
        img_caption = img["caption"]
        img_code = img["figure_code"]
        if img_code != "N/A":
            img_name = f"{img['key']} - {img['figure_code']}"
        else:
            img_name = f"{img['key']}"

        # The answer may still be None if nothing was streamed.
        previous_answer = history[-1][1] or ""
        answer_yet = previous_answer + f"\n\n{img_md}\n<p class='chatbot-caption'><b>{img_name}</b> - {img_caption}</p>"
        history[-1] = (history[-1][0], answer_yet)
        history = [tuple(x) for x in history]

    print(f"\n\nImages:\n{gallery}")

    yield history, docs_html, output_query, output_language, gallery, current_graphs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_feedback(feed: str, user_id):
    """Upload a piece of user feedback as a JSON log file.

    Args:
        feed: Feedback text. ``None`` or text of at most one character is
            ignored.
        user_id: Identifier of the user; converted with ``str`` so that
            non-string identifiers can be used in the file name and serialized.

    Returns:
        The confirmation message when the feedback was uploaded, otherwise
        ``None``.
    """
    # Guard first: the previous len(feed) call raised TypeError on None.
    if feed is None or len(feed) <= 1:
        return None

    timestamp = str(datetime.now().timestamp())
    user_id = str(user_id)
    file = user_id + timestamp + ".json"
    logs = {
        "user_id": user_id,
        "feedback": feed,
        "time": timestamp,
    }
    log_on_azure(file, logs, share_client)
    return "Feedback submitted, thank you!"
|
|
|
|
|
|
|
|
|
def log_on_azure(file, logs, share_client):
    """Serialize ``logs`` to JSON and upload it under the name ``file``.

    The upload goes through the file client that ``share_client`` returns for
    ``file``.
    """
    payload = json.dumps(logs)
    share_client.get_file_client(file).upload_file(payload)
|
|
|
|
|
def generate_keywords(query):
    """Return the keywords produced for ``query``, joined with " AND ".

    A keywords chain is built from the module-level ``llm`` and invoked with
    the query; the ``"keywords"`` entry of its result is joined.
    """
    chain_output = make_keywords_chain(llm).invoke(query)
    return " AND ".join(chain_output["keywords"])
|
|
|
|
|
|
|
# Papers table layout as (column name, width) pairs, then split into the two
# parallel lists: papers_cols (names) and papers_cols_widths (widths).
papers_cols_widths = [
    ("doc", 50),
    ("id", 100),
    ("title", 300),
    ("doi", 100),
    ("publication_year", 100),
    ("abstract", 500),
    ("rerank_score", 100),
    ("is_oa", 50),
]

papers_cols = [name for name, _ in papers_cols_widths]
papers_cols_widths = [width for _, width in papers_cols_widths]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Welcome message (Markdown) used as the chatbot's initial bot message.
# This rebinds init_prompt: system_template was built earlier from the empty
# init_prompt and keeps that empty content.
init_prompt = """
Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.

❓ How to use
- **Language**: You can ask me your questions in any language.
- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
- **Sources**: You can choose to search in the IPCC or IPBES reports, or both.

⚠️ Limitations
*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*

What do you want to learn ?
"""
|
|
|
|
|
def vote(data: gr.LikeData):
    """Print the liked value, or the whole event payload when it is not a like."""
    print(data.value if data.liked else data)
|
|
|
def save_graph(saved_graphs_state, embedding, category):
    """Record ``embedding`` under ``category`` in the saved-graphs state.

    The state dict is updated in place; an embedding already stored for the
    category is not added twice. Returns the state together with a button
    labelled "Graph Saved".
    """
    print(f"\nCategory:\n{saved_graphs_state}\n")
    saved_for_category = saved_graphs_state.setdefault(category, [])
    if embedding not in saved_for_category:
        saved_for_category.append(embedding)
    return saved_graphs_state, gr.Button("Graph Saved")
|
|
|
|
|
# Main Gradio interface: a chat tab (chatbot + side panel), a figures tab and
# an about tab, followed by the event wiring for the two ways of asking a
# question (typing in the textbox or clicking a sample question).
with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme, elem_id="main-component") as demo:
    # State holders.
    user_id_state = gr.State([user_id])

    chat_completed_state = gr.State(0)
    current_graphs = gr.State([])
    saved_graphs = gr.State({})

    with gr.Tab("ClimateQ&A"):

        with gr.Row(elem_id="chatbot-row"):
            # Left column: conversation and input box.
            with gr.Column(scale=2):
                state = gr.State([system_template])
                chatbot = gr.Chatbot(
                    value=[(None, init_prompt)],
                    show_copy_button=True,
                    show_label=False,
                    elem_id="chatbot",
                    layout="panel",
                    avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
                )

                with gr.Row(elem_id="input-message"):
                    textbox = gr.Textbox(
                        placeholder="Ask me anything here!",
                        show_label=False,
                        scale=7,
                        lines=1,
                        interactive=True,
                        elem_id="input-textbox",
                    )

            # Right column: examples, sources and configuration.
            with gr.Column(scale=1, variant="panel", elem_id="right-panel"):

                with gr.Tabs() as tabs:
                    with gr.TabItem("Examples", elem_id="tab-examples", id=0):

                        # Hidden textbox receiving the clicked sample question.
                        examples_hidden = gr.Textbox(visible=False)
                        first_key = list(QUESTIONS.keys())[0]
                        dropdown_samples = gr.Dropdown(
                            QUESTIONS.keys(),
                            value=first_key,
                            interactive=True,
                            show_label=True,
                            label="Select a category of sample questions",
                            elem_id="dropdown-samples",
                        )

                        # One row of examples per category; only the first
                        # category is visible initially.
                        samples = []
                        for i, key in enumerate(QUESTIONS.keys()):
                            examples_visible = i == 0

                            with gr.Row(visible=examples_visible) as group_examples:
                                examples_questions = gr.Examples(
                                    QUESTIONS[key],
                                    [examples_hidden],
                                    examples_per_page=8,
                                    run_on_click=False,
                                    elem_id=f"examples{i}",
                                    api_name=f"examples{i}",
                                )

                            samples.append(group_examples)

                    with gr.Tab("Sources", elem_id="tab-citations", id=1):
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                        docs_textbox = gr.State("")

                    with gr.Tab("Configuration", elem_id="tab-config", id=2):

                        gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")

                        dropdown_sources = gr.CheckboxGroup(
                            ["IPCC", "IPBES", "IPOS"],
                            label="Select source",
                            value=["IPCC", "IPBES", "IPOS"],
                            interactive=True,
                        )

                        dropdown_reports = gr.Dropdown(
                            POSSIBLE_REPORTS,
                            label="Or select specific reports",
                            multiselect=True,
                            value=None,
                            interactive=True,
                        )

                        dropdown_audience = gr.Dropdown(
                            ["Children", "General public", "Experts"],
                            label="Select audience",
                            value="Experts",
                            interactive=True,
                        )

                        output_query = gr.Textbox(
                            label="Query used for retrieval",
                            show_label=True,
                            elem_id="reformulated-query",
                            lines=2,
                            interactive=False,
                        )
                        output_language = gr.Textbox(
                            label="Language",
                            show_label=True,
                            elem_id="language",
                            lines=1,
                            interactive=False,
                        )

    with gr.Tab("Figures", elem_id="tab-images", elem_classes="max-height other-tabs"):
        gallery_component = gr.Gallery()

    with gr.Tab("About", elem_classes="max-height other-tabs"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)")

    def start_chat(query, history):
        """Append the new question to the history, lock the textbox, select tab 1."""
        extended = [tuple(x) for x in history + [(query, None)]]
        return (gr.update(interactive=False), gr.update(selected=1), extended)

    def finish_chat():
        """Unlock and clear the textbox, then select tab 3."""
        # NOTE(review): only tab ids 0, 1 and 2 are declared in this block;
        # confirm which tab id 3 is meant to target.
        return (gr.update(interactive=True, value=""), gr.update(selected=3))

    def change_completion_status(current_state):
        """Toggle the completion flag between 0 and 1."""
        return 1 - current_state

    # Both entry points run the same four-step chain; only the trigger, the
    # component holding the query and the api_name suffix differ.
    chat_settings = [dropdown_audience, dropdown_sources, dropdown_reports, current_graphs]
    chat_outputs = [chatbot, sources_textbox, output_query, output_language, gallery_component, current_graphs]

    for trigger, query_component, suffix in (
        (textbox.submit, textbox, "textbox"),
        (examples_hidden.change, examples_hidden, "examples"),
    ):
        (
            trigger(start_chat, [query_component, chatbot], [textbox, tabs, chatbot], queue=False, api_name=f"start_chat_{suffix}")
            .then(chat, [query_component, chatbot, *chat_settings], chat_outputs, concurrency_limit=8, api_name=f"chat_{suffix}")
            .then(finish_chat, None, [textbox, tabs], api_name=f"finish_chat_{suffix}")
            .then(change_completion_status, [chat_completed_state], [chat_completed_state])
        )

    def change_sample_questions(key):
        """Show only the examples row that belongs to the selected category."""
        selected = list(QUESTIONS.keys()).index(key)
        return [gr.update(visible=(position == selected)) for position in range(len(samples))]

    dropdown_samples.change(change_sample_questions, dropdown_samples, samples)


demo.queue()

demo.launch(debug=True)