from climateqa.engine.embeddings import get_embeddings_function embeddings_function = get_embeddings_function() import gradio as gr import pandas as pd import numpy as np import os import time import re import json from gradio_modal import Modal from io import BytesIO import base64 from datetime import datetime from azure.storage.fileshare import ShareServiceClient from utils import create_user_id # ClimateQ&A imports from climateqa.engine.llm import get_llm from climateqa.engine.rag import make_rag_chain from climateqa.engine.vectorstore import get_pinecone_vectorstore from climateqa.engine.retriever import ClimateQARetriever from climateqa.engine.embeddings import get_embeddings_function from climateqa.engine.prompts import audience_prompts from climateqa.sample_questions import QUESTIONS from climateqa.constants import POSSIBLE_REPORTS from climateqa.utils import get_image_from_azure_blob_storage # Load environment variables in local mode try: from dotenv import load_dotenv load_dotenv() except Exception as e: pass # Set up Gradio Theme theme = gr.themes.Base( primary_hue="blue", secondary_hue="red", font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"], ) init_prompt = "" system_template = { "role": "system", "content": init_prompt, } account_key = os.environ["BLOB_ACCOUNT_KEY"] if len(account_key) == 86: account_key += "==" credential = { "account_key": account_key, "account_name": os.environ["BLOB_ACCOUNT_NAME"], } account_url = os.environ["BLOB_ACCOUNT_URL"] file_share_name = "climateqa" service = ShareServiceClient(account_url=account_url, credential=credential) share_client = service.get_share_client(file_share_name) user_id = create_user_id() def parse_output_llm_with_sources(output): # Split the content into a list of text and "[Doc X]" references content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output) parts = [] for part in content_parts: if part.startswith("Doc"): subparts = part.split(",") subparts = [subpart.lower().replace("doc","").strip() for subpart in subparts] subparts = [f"{subpart}" for subpart in subparts] parts.append("".join(subparts)) else: parts.append(part) content_parts = "".join(parts) return content_parts # Create vectorstore and retriever vectorstore = get_pinecone_vectorstore(embeddings_function) llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0) def make_pairs(lst): """from a list of even lenght, make tupple pairs""" return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)] def serialize_docs(docs): new_docs = [] for doc in docs: new_doc = {} new_doc["page_content"] = doc.page_content new_doc["metadata"] = doc.metadata new_docs.append(new_doc) return new_docs async def chat(query,history,audience,sources,reports): """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of: (messages in gradio format, messages in langchain format, source documents)""" print(f">> NEW QUESTION : {query}") if audience == "Children": audience_prompt = audience_prompts["children"] elif audience == "General public": audience_prompt = audience_prompts["general"] elif audience == "Experts": audience_prompt = audience_prompts["experts"] else: audience_prompt = audience_prompts["experts"] # Prepare default values if len(sources) == 0: sources = ["IPCC"] if len(reports) == 0: reports = [] retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,min_size = 200,reports = reports,k_summary = 3,k_total = 15,threshold=0.5) rag_chain = make_rag_chain(retriever,llm) inputs = {"query": query,"audience": audience_prompt} result = rag_chain.astream_log(inputs) #{"callbacks":[MyCustomAsyncHandler()]}) # result = rag_chain.stream(inputs) path_reformulation = "/logs/reformulation/final_output" path_retriever = "/logs/find_documents/final_output" path_answer = "/logs/answer/streamed_output_str/-" docs_html = "" output_query = "" output_language = "" gallery = [] try: async for op in result: op = op.ops[0] # print("ITERATION",op) if op['path'] == path_reformulation: # reforulated question try: output_language = op['value']["language"] # str output_query = op["value"]["question"] except Exception as e: raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)") elif op['path'] == path_retriever: # documents try: docs = op['value']['docs'] # List[Document] docs_html = [] for i, d in enumerate(docs, 1): docs_html.append(make_html_source(d, i)) docs_html = "".join(docs_html) except TypeError: print("No documents found") print("op: ",op) continue elif op['path'] == path_answer: # final answer new_token = op['value'] # str # time.sleep(0.01) previous_answer = history[-1][1] previous_answer = previous_answer if previous_answer is not None else "" answer_yet = previous_answer + new_token answer_yet = parse_output_llm_with_sources(answer_yet) history[-1] = (query,answer_yet) # elif op['path'] == final_output_path_id: # final_output = op['value'] # if "answer" in final_output: # final_output = final_output["answer"] # print(final_output) # answer = history[-1][1] + final_output # answer = parse_output_llm_with_sources(answer) # history[-1] = (query,answer) else: continue history = [tuple(x) for x in history] yield history,docs_html,output_query,output_language,gallery except Exception as e: raise gr.Error(f"{e}") try: # Log answer on Azure Blob Storage if os.getenv("GRADIO_ENV") != "local": timestamp = str(datetime.now().timestamp()) file = timestamp + ".json" prompt = history[-1][0] logs = { "user_id": str(user_id), "prompt": prompt, "query": prompt, "question":output_query, "sources":sources, "docs":serialize_docs(docs), "answer": history[-1][1], "time": timestamp, } log_on_azure(file, logs, share_client) except Exception as e: print(f"Error logging on Azure Blob Storage: {e}") raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") image_dict = {} for i,doc in enumerate(docs): if doc.metadata["chunk_type"] == "image": try: key = f"Image {i+1}" image_path = doc.metadata["image_path"].split("documents/")[1] img = get_image_from_azure_blob_storage(image_path) # Convert the image to a byte buffer buffered = BytesIO() img.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode() # Embedding the base64 string in Markdown markdown_image = f"![Alt text](data:image/png;base64,{img_str})" image_dict[key] = {"img":img,"md":markdown_image,"caption":doc.page_content,"key":key,"figure_code":doc.metadata["figure_code"]} except Exception as e: print(f"Skipped adding image {i} because of {e}") if len(image_dict) > 0: gallery = [x["img"] for x in list(image_dict.values())] img = list(image_dict.values())[0] img_md = img["md"] img_caption = img["caption"] img_code = img["figure_code"] if img_code != "N/A": img_name = f"{img['key']} - {img['figure_code']}" else: img_name = f"{img['key']}" answer_yet = history[-1][1] + f"\n\n{img_md}\n
" history[-1] = (history[-1][0],answer_yet) history = [tuple(x) for x in history] # gallery = [x.metadata["image_path"] for x in docs if (len(x.metadata["image_path"]) > 0 and "IAS" in x.metadata["image_path"])] # if len(gallery) > 0: # gallery = list(set("|".join(gallery).split("|"))) # gallery = [get_image_from_azure_blob_storage(x) for x in gallery] yield history,docs_html,output_query,output_language,gallery # memory.save_context(inputs, {"answer": gradio_format[-1][1]}) # yield gradio_format, memory.load_memory_variables({})["history"], source_string # async def chat_with_timeout(query, history, audience, sources, reports, timeout_seconds=2): # async def timeout_gen(async_gen, timeout): # try: # while True: # try: # yield await asyncio.wait_for(async_gen.__anext__(), timeout) # except StopAsyncIteration: # break # except asyncio.TimeoutError: # raise gr.Error("Operation timed out. Please try again.") # return timeout_gen(chat(query, history, audience, sources, reports), timeout_seconds) # # A wrapper function that includes a timeout # async def chat_with_timeout(query, history, audience, sources, reports, timeout_seconds=2): # try: # # Use asyncio.wait_for to apply a timeout to the chat function # return await asyncio.wait_for(chat(query, history, audience, sources, reports), timeout_seconds) # except asyncio.TimeoutError: # # Handle the timeout error as desired # raise gr.Error("Operation timed out. Please try again.") def make_html_source(source,i): meta = source.metadata # content = source.page_content.split(":",1)[1].strip() content = source.page_content.strip() toc_levels = [] for j in range(2): level = meta[f"toc_level{j}"] if level != "N/A": toc_levels.append(level) else: break toc_levels = " > ".join(toc_levels) if len(toc_levels) > 0: name = f"{toc_levels}{content}
{content}
AI-generated description