Zeta / app.py
Ritvik19's picture
Upload 2 files
7112b8c verified
raw
history blame
14.8 kB
import streamlit as st
import os
import pandas as pd
from command_center import CommandCenter
from process_documents import process_documents, num_tokens
from embed_documents import create_retriever
import json
from langchain.callbacks import get_openai_callback
from langchain_openai import ChatOpenAI
import base64
from chat_chains import (
parse_model_response,
qa_chain,
format_docs,
parse_context_and_question,
ai_response_format,
)
from autoqa_chains import auto_qa_chain, auto_qa_output_parser, followup_qa_chain
from chain_of_density import chain_of_density_chain
from insights_bullet_chain import insights_bullet_chain
from insights_mind_map_chain import insights_mind_map_chain
from synopsis_chain import synopsis_chain
from custom_exceptions import InvalidArgumentError, InvalidCommandError
from openai_configuration import openai_parser
from summary_chain import summary_chain
# Use Streamlit's wide page layout for the chat UI.
st.set_page_config(layout="wide")
# Markdown shown as the assistant's opening chat message and again by `/help-me`.
# Blank lines separate paragraphs so the command table renders as a table and
# the `<br>` line is not absorbed into it as an extra row.
welcome_message = """
Hi, I'm Agent Zeta, your AI assistant, dedicated to making your journey through machine learning research papers as insightful and interactive as possible.

Whether you're diving into the latest studies or brushing up on foundational papers, I'm here to help navigate, discuss, and analyze content with you.

Here's a quick guide to getting started with me:

| Command | Description |
|---------|-------------|
| `/configure --key <api key> --model <model>` | Configure the OpenAI API key and model for our conversation. |
| `/add-papers <list of urls>` | Upload and process documents for our conversation. |
| `/library` | View an index of processed documents to easily navigate your research. |
| `/view-snip <snippet id>` | View the content of a specific snippet. |
| `/session-expense` | Calculate the cost of our conversation, ensuring transparency in resource usage. |
| `/export` | Download conversation data for your records or further analysis. |
| `/auto-insight <list of snippet ids>` | Automatically generate questions and answers for the paper. |
| `/condense-summary <list of snippet ids>` | Generate increasingly concise, entity-dense summaries of the paper. |
| `/insight-bullets <list of snippet ids>` | Extract and summarize key insights, methods, results, and conclusions. |
| `/insight-mind-map <list of snippet ids>` | Create a structured outline of the key insights in Markdown format. |
| `/paper-synopsis <list of snippet ids>` | Generate a synopsis of the paper. |
| `/deep-dive [<list of snippet ids>] <query>` | Query me with a specific context. |
| `/summarise-section [<list of snippet ids>] <section name>` | Summarize a specific section of the paper. |

<br>

Feel free to use these commands to enhance your research experience. Let's embark on this exciting journey of discovery together!

Use `/help-me` at any time to view this guide again.
"""
def process_documents_wrapper(inputs):
    """Handle `/add-papers`: ingest the given URLs and rebuild the session state.

    Overwrites any retriever, URL list, snippet index and document store left
    by an earlier `/add-papers` call.

    Args:
        inputs: List of document URLs.

    Returns:
        ``(message, "identity")`` confirming which URLs were processed.

    Raises:
        InvalidArgumentError: if ``inputs`` is an empty list.
    """
    if inputs == []:
        raise InvalidArgumentError("Please provide document urls")
    chunks, docs = process_documents(inputs)
    state = st.session_state
    state.retriever = create_retriever(chunks)
    state.source_doc_urls = inputs
    # One [chunk_id, header, token count] row per snippet, shown by `/library`.
    state.index = [
        [chunk.metadata["chunk_id"], chunk.metadata["header"], num_tokens(chunk.page_content)]
        for chunk in chunks
    ]
    reply = f"Uploaded and processed documents {inputs}"
    state.messages.append((f"/add-papers {inputs}", reply, "identity"))
    state.documents = docs
    return reply, "identity"
def index_documents_wrapper(inputs=None):
    """Handle `/library`: show the snippet index built by `/add-papers`.

    Args:
        inputs: Unused.

    Returns:
        ``(DataFrame, "dataframe")`` with columns ``id``, ``reference`` and
        ``tokens``. The frame is empty when no papers have been added yet.
    """
    # Before the first `/add-papers` there is no index in the session. Show an
    # empty table, matching the `[]` default `/export` uses, instead of failing
    # on the missing attribute.
    response = pd.DataFrame(
        st.session_state.get("index", []), columns=["id", "reference", "tokens"]
    )
    st.session_state.messages.append(("/library", response, "dataframe"))
    return (response, "dataframe")
def view_document_wrapper(inputs):
    """Handle `/view-snip`: return the raw text of one stored snippet.

    Args:
        inputs: The snippet id, used to look up ``st.session_state.documents``.

    Returns:
        ``(text, "identity")`` where ``text`` is the snippet's ``page_content``.
    """
    snippet_text = st.session_state.documents[inputs].page_content
    history_entry = (f"/view-snip {inputs}", snippet_text, "identity")
    st.session_state.messages.append(history_entry)
    return snippet_text, "identity"
def calculate_cost_wrapper(inputs=None):
    """Handle `/session-expense`: tabulate token usage and cost for this session.

    Args:
        inputs: Unused.

    Returns:
        ``(response, "dataframe")``. ``response`` is either a DataFrame with one
        row per recorded LLM call plus a ``"total"`` row, or the string
        ``"No cost incurred yet"`` when nothing has been recorded.
    """
    costing = st.session_state.get("costing", [])
    # Check for emptiness explicitly. Relying on pandas raising ValueError when
    # a row is added to a column-less frame is fragile, and it would also report
    # any unrelated ValueError as "no cost".
    if costing:
        response = pd.DataFrame(costing)
        # Grand-total row: column-wise sum over every recorded call.
        response.loc["total"] = response.sum()
    else:
        response = "No cost incurred yet"
    st.session_state.messages.append(("/session-expense", response, "dataframe"))
    return (response, "dataframe")
def download_conversation_wrapper(inputs=None):
    """Handle `/export`: build a downloadable JSON snapshot of the session.

    The snapshot contains the source document URLs, the snippet index, every
    recorded (human, ai) exchange, the per-call usage records and their
    per-field totals. Each AI payload is converted through the module-level
    ``jsonify_functions`` table, which is defined in the ``__main__`` block.

    Args:
        inputs: Unused.

    Returns:
        ``(html, "identity")`` where ``html`` is an anchor tag embedding the
        base64-encoded JSON as a data URI.
    """
    state = st.session_state
    costing = state.get("costing", [])
    conversation_data = json.dumps(
        {
            "document_urls": state.get("source_doc_urls", []),
            "document_snippets": state.get("index", []),
            "conversation": [
                {"human": message[0], "ai": jsonify_functions[message[2]](message[1])}
                for message in state.messages
            ],
            "costing": costing,
            # Sum each usage field over all calls. Field names are taken from
            # the first record.
            "total_cost": (
                {k: sum(record[k] for record in costing) for k in costing[0]}
                if costing
                else {}
            ),
        }
    )
    conversation_data = base64.b64encode(conversation_data.encode()).decode()
    # Recorded after serialising, so the export itself is not in the file.
    state.messages.append(("/export", "Conversation data downloaded", "identity"))
    # The payload is JSON, so label the data URI as application/json (it was
    # previously mislabelled as text/csv).
    return (
        f'<a href="data:application/json;base64,{conversation_data}" download="conversation_data.json">Download Conversation</a>',
        "identity",
    )
def query_llm(inputs, relevant_docs):
    """Answer ``inputs`` from ``relevant_docs`` and record the exchange and usage.

    The model's reply is parsed into an answer plus citations. One extra
    citation, quoted as ``"other sources"``, lists every supplied snippet id
    (as strings, sorted, each wrapped in brackets).

    Args:
        inputs: The user's question.
        relevant_docs: Snippets whose ``metadata["chunk_id"]`` and content form
            the context.

    Returns:
        ``({"answer": ..., "citations": [...]}, "reponse_with_citations")``.
        The format tag is spelt to match the formatter tables.
    """
    with get_openai_callback() as usage:
        raw_reply = (
            qa_chain(ChatOpenAI(model=st.session_state.model, temperature=0))
            .invoke({"context": format_docs(relevant_docs), "question": inputs})
            .content
        )
    parsed = parse_model_response(raw_reply)
    answer, citations = parsed["answer"], parsed["citations"]

    chunk_ids = sorted(str(doc.metadata["chunk_id"]) for doc in relevant_docs)
    citations.append(
        {
            "source_id": " ".join(f"[{chunk_id}]" for chunk_id in chunk_ids),
            "quote": "other sources",
        }
    )

    result_format = "reponse_with_citations"
    payload = {"answer": answer, "citations": citations}
    st.session_state.messages.append((inputs, payload, result_format))
    st.session_state.costing.append(
        {
            "prompt tokens": usage.prompt_tokens,
            "completion tokens": usage.completion_tokens,
            "cost": usage.total_cost,
        }
    )
    return dict(payload), result_format
def rag_llm_wrapper(inputs):
    """Answer a free-text query using snippets fetched by the session retriever.

    Registered as the CommandCenter ``default_function``.

    Args:
        inputs: The user's query text.

    Returns:
        Whatever ``query_llm`` returns for the retrieved snippets.

    Raises:
        InvalidArgumentError: if no papers have been added yet, i.e. there is no
            retriever in the session.
    """
    # Without this guard, a question asked before `/add-papers` fails on the
    # missing session attribute, which `boot` does not catch.
    if "retriever" not in st.session_state:
        raise InvalidArgumentError(
            "Please add papers with /add-papers before asking questions"
        )
    relevant_docs = st.session_state.retriever.get_relevant_documents(inputs)
    return query_llm(inputs, relevant_docs)
def query_llm_wrapper(inputs):
    """Handle `/deep-dive`: answer a question against explicitly chosen snippets.

    Args:
        inputs: Raw argument string, split by ``parse_context_and_question``
            into snippet ids and the question.

    Returns:
        Whatever ``query_llm`` returns for the chosen snippets.
    """
    snippet_ids, question = parse_context_and_question(inputs)
    chosen_snippets = [st.session_state.documents[sid] for sid in snippet_ids]
    return query_llm(question, chosen_snippets)
def summarise_wrapper(inputs):
    """Handle `/summarise-section`: summarise one named section of chosen snippets.

    Args:
        inputs: Raw argument string (``[<snippet ids>] <section name>``), split
            by ``parse_context_and_question`` into snippet ids and the section
            name.

    Returns:
        ``(summary, "identity")``.

    Raises:
        InvalidArgumentError: if no snippet ids were given.
    """
    context, query = parse_context_and_question(inputs)
    # Same guard as the other snippet-based commands: without snippets there
    # is nothing to summarise.
    if not context:
        raise InvalidArgumentError("Please provide snippet ids")
    # NOTE(review): the other paper-level commands pass the joined
    # `page_content` text as "paper". This one passes the snippet objects
    # themselves; confirm that is what summary_chain expects.
    document = [st.session_state.documents[c] for c in context]
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = summary_chain(llm).invoke({"section_name": query, "paper": document})
    # Record the full arguments so history and exports show which snippets
    # were used, not only the section name.
    st.session_state.messages.append(
        (f"/summarise-section {inputs}", summary, "identity")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": cb.prompt_tokens,
            "completion tokens": cb.completion_tokens,
            "cost": cb.total_cost,
        }
    )
    return (summary, "identity")
def chain_of_density_wrapper(inputs):
    """Handle `/condense-summary`: run the chain-of-density chain on snippets.

    Args:
        inputs: List of snippet ids. Their ``page_content`` is joined with blank
            lines to form the paper text.

    Returns:
        ``(summary, "identity")``.

    Raises:
        InvalidArgumentError: if no snippet ids were given.
    """
    if not inputs:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join(st.session_state.documents[c].page_content for c in inputs)
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = chain_of_density_chain(llm).invoke({"paper": document})
    # Include the snippet ids, as `/auto-insight` does, so history and exports
    # show what was summarised.
    st.session_state.messages.append(
        (f"/condense-summary {inputs}", summary, "identity")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": cb.prompt_tokens,
            "completion tokens": cb.completion_tokens,
            "cost": cb.total_cost,
        }
    )
    return (summary, "identity")
def synopsis_wrapper(inputs):
    """Handle `/paper-synopsis`: generate a synopsis from the given snippets.

    Args:
        inputs: List of snippet ids. Their ``page_content`` is joined with blank
            lines to form the paper text.

    Returns:
        ``(summary, "identity")``.

    Raises:
        InvalidArgumentError: if no snippet ids were given.
    """
    if not inputs:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join(st.session_state.documents[c].page_content for c in inputs)
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = synopsis_chain(llm).invoke({"paper": document})
    # Include the snippet ids, as `/auto-insight` does, so history and exports
    # show which snippets were used.
    st.session_state.messages.append(
        (f"/paper-synopsis {inputs}", summary, "identity")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": cb.prompt_tokens,
            "completion tokens": cb.completion_tokens,
            "cost": cb.total_cost,
        }
    )
    return (summary, "identity")
def insights_bullet_wrapper(inputs):
    """Handle `/insight-bullets`: extract bullet-point insights from snippets.

    Args:
        inputs: List of snippet ids. Their ``page_content`` is joined with blank
            lines to form the paper text.

    Returns:
        ``(insights, "identity")``.

    Raises:
        InvalidArgumentError: if no snippet ids were given.
    """
    if not inputs:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join(st.session_state.documents[c].page_content for c in inputs)
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        insights = insights_bullet_chain(llm).invoke({"paper": document})
    # Include the snippet ids, as `/auto-insight` does, so history and exports
    # show which snippets were used.
    st.session_state.messages.append(
        (f"/insight-bullets {inputs}", insights, "identity")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": cb.prompt_tokens,
            "completion tokens": cb.completion_tokens,
            "cost": cb.total_cost,
        }
    )
    return (insights, "identity")
def insights_mind_map_wrapper(inputs):
    """Handle `/insight-mind-map`: build a Markdown outline of key insights.

    Args:
        inputs: List of snippet ids. Their ``page_content`` is joined with blank
            lines to form the paper text.

    Returns:
        ``(insights, "identity")``.

    Raises:
        InvalidArgumentError: if no snippet ids were given.
    """
    if not inputs:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join(st.session_state.documents[c].page_content for c in inputs)
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        insights = insights_mind_map_chain(llm).invoke({"paper": document})
    # Include the snippet ids, as `/auto-insight` does, so history and exports
    # show which snippets were used.
    st.session_state.messages.append(
        (f"/insight-mind-map {inputs}", insights, "identity")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": cb.prompt_tokens,
            "completion tokens": cb.completion_tokens,
            "cost": cb.total_cost,
        }
    )
    return (insights, "identity")
def auto_qa_chain_wrapper(inputs):
    """Handle `/auto-insight`: generate question/answer pairs from snippets.

    Each pair is rendered as a level-4 Markdown heading (the question) followed
    by its answer, with pairs separated by blank lines.

    Args:
        inputs: List of snippet ids. Their ``page_content`` is joined with blank
            lines to form the paper text.

    Returns:
        ``(markdown, "identity")``.

    Raises:
        InvalidArgumentError: if ``inputs`` is an empty list.
    """
    if inputs == []:
        raise InvalidArgumentError("Please provide snippet ids")
    paper_text = "\n\n".join(
        [st.session_state.documents[snippet_id].page_content for snippet_id in inputs]
    )
    chat_model = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as usage:
        raw_output = auto_qa_chain(chat_model).invoke({"paper": paper_text})
        qa_pairs = auto_qa_output_parser.invoke(raw_output)["questions"]
        markdown = "\n\n".join(
            f"#### {pair['question']}\n\n{pair['answer']}" for pair in qa_pairs
        )
    st.session_state.messages.append((f"/auto-insight {inputs}", markdown, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": usage.prompt_tokens,
            "completion tokens": usage.completion_tokens,
            "cost": usage.total_cost,
        }
    )
    return markdown, "identity"
def boot(command_center, formating_functions):
    """Render the chat page: title, welcome text, stored history, then new input.

    Ensures ``costing`` and ``messages`` exist in the session state. Replays
    every stored (human, payload, format name) message, then executes at most
    one new chat input through ``command_center``. InvalidArgumentError and
    InvalidCommandError from a command are shown with ``st.error``.

    Args:
        command_center: Object whose ``execute_command(query)`` returns
            ``(payload, format name)``.
        formating_functions: Maps a format name to a callable that turns a
            payload into something ``st.write`` can render.
    """
    st.write("# Agent Zeta")
    for key in ("costing", "messages"):
        if key not in st.session_state:
            st.session_state[key] = []

    def render_ai(payload, fmt_name):
        st.chat_message("ai").write(
            formating_functions[fmt_name](payload), unsafe_allow_html=True
        )

    st.chat_message("ai").write(welcome_message, unsafe_allow_html=True)
    for human_text, ai_payload, fmt_name in st.session_state.messages:
        st.chat_message("human").write(human_text)
        render_ai(ai_payload, fmt_name)

    query = st.chat_input()
    if not query:
        return
    try:
        st.chat_message("human").write(query)
        reply, fmt_name = command_center.execute_command(query)
        render_ai(reply, fmt_name)
    except (InvalidArgumentError, InvalidCommandError) as err:
        st.error(err)
def _mask_secret(secret):
    """Return a display-safe version of ``secret``, keeping only its last 4 chars.

    Values that are not non-empty strings (e.g. ``None``) are returned unchanged.
    """
    if not isinstance(secret, str) or not secret:
        return secret
    if len(secret) <= 8:
        return "****"
    return "****" + secret[-4:]


def configure_openai_wrapper(inputs):
    """Handle `/configure`: set the OpenAI API key and chat model for the session.

    Args:
        inputs: Raw argument string (``--key <api key> --model <model>``),
            parsed with ``openai_parser``.

    Returns:
        ``(summary, "identity")`` where ``summary`` is the parsed arguments'
        string form with the API key masked.
    """
    import copy

    args = openai_parser.parse_args(inputs.split())
    # os.environ only accepts strings; assigning None would raise TypeError.
    if args.key is not None:
        os.environ["OPENAI_API_KEY"] = args.key
    st.session_state.model = args.model
    # Never echo the raw key. This text is shown in the chat, stored in the
    # message history, and written verbatim into `/export` downloads.
    shown = copy.copy(args)
    shown.key = _mask_secret(args.key)
    summary = str(shown)
    st.session_state.messages.append(("/configure", summary, "identity"))
    return (summary, "identity")
if __name__ == "__main__":
    # (command, argument type, handler) triples registered with CommandCenter.
    # Input that matches none of them is handled by `default_function`
    # (presumably; see CommandCenter).
    all_commands = [
        ("/configure", str, configure_openai_wrapper),
        ("/add-papers", list, process_documents_wrapper),
        ("/library", None, index_documents_wrapper),
        ("/view-snip", str, view_document_wrapper),
        ("/session-expense", None, calculate_cost_wrapper),
        ("/export", None, download_conversation_wrapper),
        ("/help-me", None, lambda x: (welcome_message, "identity")),
        ("/auto-insight", list, auto_qa_chain_wrapper),
        ("/deep-dive", str, query_llm_wrapper),
        ("/condense-summary", list, chain_of_density_wrapper),
        ("/insight-bullets", list, insights_bullet_wrapper),
        ("/insight-mind-map", list, insights_mind_map_wrapper),
        ("/paper-synopsis", list, synopsis_wrapper),
        ("/summarise-section", str, summarise_wrapper),
    ]
    command_center = CommandCenter(
        default_input_type=str,
        default_function=rag_llm_wrapper,
        all_commands=all_commands,
    )
    # Format name -> callable turning a stored payload into something st.write
    # can render. "reponse_with_citations" is the (misspelt) tag that query_llm
    # stores, so the key must keep this spelling.
    formating_functions = {
        "identity": lambda x: x,
        "dataframe": lambda x: x,
        "reponse_with_citations": lambda x: ai_response_format(
            x["answer"], x["citations"]
        ),
    }

    def jsonify_table(x):
        """Convert a "dataframe"-tagged payload into JSON-serialisable data."""
        if isinstance(x, pd.DataFrame):
            return x.to_dict(orient="records")
        if isinstance(x, pd.Series):
            # Series.to_dict() has no `orient` argument; passing one raises TypeError.
            return x.to_dict()
        return x

    # Format name -> callable turning a stored payload into JSON-serialisable
    # data. download_conversation_wrapper looks this up as a module global.
    jsonify_functions = {
        "identity": lambda x: x,
        "dataframe": jsonify_table,
        "reponse_with_citations": lambda x: x,
    }
    boot(command_center, formating_functions)