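"""Agent Xi - a Streamlit chatbot for querying ArXiv papers and web pages.

Source documents (PDF or web URLs) are split into header-anchored snippets,
indexed with ColBERT (via RAGatouille), and queried through a LangChain
ConversationalRetrievalChain backed by GPT-4, with per-turn token costs
tracked in the sidebar.

Run with Streamlit (substitute whatever file name this script is saved under):
    streamlit run app.py
"""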
import json
import math
import os
import re
from pathlib import Path
from statistics import median

import pandas as pd
import streamlit as st
from bs4 import BeautifulSoup
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.docstore.document import Document
from langchain.document_loaders import PDFMinerPDFasHTMLLoader, WebBaseLoader
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_openai import ChatOpenAI
from ragatouille import RAGPretrainedModel

st.set_page_config(layout="wide")
# The OpenAI API key is read from the environment; never hard-code real keys in source.
os.environ.setdefault("OPENAI_API_KEY", "")

LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store")


def deep_strip(text):
    # Collapse runs of whitespace into single spaces and trim the ends.
    return re.sub(r"\s+", " ", text or "").strip()


def get_references(relevant_docs):
    # Render the chunk ids of the retrieved documents as sorted "[id]" tags.
    return " ".join(
        f"[{chunk_id}]"
        for chunk_id in sorted(doc.metadata["chunk_id"] for doc in relevant_docs)
    )


def session_state_2_llm_chat_history(session_state):
    # Keep only the (human, ai) pair of each stored message for the chain's chat history.
    return [ss[:2] for ss in session_state]


def get_conversation_history():
    """Serialize the session (sources, snippets, messages, costs) to a JSON string."""
    return json.dumps(
        {
            "document_urls": (
                st.session_state.source_doc_urls
                if "source_doc_urls" in st.session_state
                else []
            ),
            "document_snippets": (
                st.session_state.headers.to_list()
                if "headers" in st.session_state
                else []
            ),
            "conversation": [
                {"human": message[0], "ai": message[1], "references": message[2]}
                for message in st.session_state.messages
            ],
            "costing": (
                st.session_state.costing if "costing" in st.session_state else []
            ),
            "total_cost": (
                {
                    k: sum(d[k] for d in st.session_state.costing)
                    for k in st.session_state.costing[0]
                }
                if "costing" in st.session_state and len(st.session_state.costing) > 0
                else {}
            ),
        }
    )


def ai_message_format(message, references):
    # Append the reference tags to the model answer, separated by a horizontal rule.
    return f"{message}\n\n---\n\n{references}"
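

# Retrieval setup: the snippet texts are indexed with ColBERT (through RAGatouille),
# exposed as a LangChain retriever returning the top 5 chunks, and then wrapped in a
# MultiQueryRetriever, which has an LLM rephrase each question into several variants
# and merges the results to improve recall.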
def embeddings_on_local_vectordb(texts):
    colbert = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv1.9")
    colbert.index(
        collection=[chunk.page_content for chunk in texts],
        split_documents=False,
        document_metadatas=[chunk.metadata for chunk in texts],
        index_name="vector_store",
    )
    retriever = colbert.as_langchain_retriever(k=5)
    retriever = MultiQueryRetriever.from_llm(
        retriever=retriever, llm=ChatOpenAI(temperature=0)
    )
    return retriever


def query_llm(retriever, query):
    """Answer `query` with the retrieval chain and record the exchange in session state."""
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
        retriever=retriever,
        return_source_documents=True,
        chain_type="stuff",
    )
    relevant_docs = retriever.get_relevant_documents(query)
    # The OpenAI callback accumulates token counts and cost for every call made in this block.
    with get_openai_callback() as cb:
        result = qa_chain(
            {
                "question": query,
                "chat_history": session_state_2_llm_chat_history(
                    st.session_state.messages
                ),
            }
        )
        stats = cb
    answer = result["answer"]
    references = get_references(relevant_docs)
    st.session_state.messages.append((query, answer, references))
    return answer, references, stats


def input_fields():
    """Render the sidebar URL input and store the parsed list of source URLs."""
    st.session_state.source_doc_urls = [
        url.strip()
        for url in st.sidebar.text_area(
            "Source Document URLs\n(New line separated)", height=50
        ).split("\n")
        if url.strip()
    ]


def process_documents():
    """Load every source URL, build the retriever, and cache snippet headers for display."""
    try:
        snippets = []
        for url in st.session_state.source_doc_urls:
            if url.endswith(".pdf"):
                snippets.extend(process_pdf(url))
            else:
                snippets.extend(process_web(url))
        st.session_state.retriever = embeddings_on_local_vectordb(snippets)
        st.session_state.headers = pd.Series(
            [snip.metadata["header"] for snip in snippets], name="references"
        )
    except Exception as e:
        st.error(f"An error occurred: {e}")
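

# PDF handling: PDFMiner renders the PDF as HTML, each <div>/<span> is collapsed into
# (text, font_size) snippets, noisy snippets with a high newline ratio are dropped, and
# spans larger than the median font size are treated as section headers. Each header plus
# its following body text becomes one Document, cited later by its chunk_id.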
def process_pdf(url):
    data = PDFMinerPDFasHTMLLoader(url).load()[0]
    content = BeautifulSoup(data.page_content, "html.parser").find_all("div")
    snippets = get_pdf_snippets(content)
    filtered_snippets = filter_pdf_snippets(snippets, new_line_threshold_ratio=0.4)
    median_font_size = math.ceil(
        median([font_size for _, font_size in filtered_snippets])
    )
    semantic_snippets = get_pdf_semantic_snippets(filtered_snippets, median_font_size)
    document_snippets = [
        Document(
            page_content=deep_strip(snip[1]["header_text"]) + " " + deep_strip(snip[0]),
            metadata={
                "header": " ".join(snip[1]["header_text"].split()[:10]),
                "source_url": url,
                "source_type": "pdf",
                "chunk_id": i,
            },
        )
        for i, snip in enumerate(semantic_snippets)
    ]
    return document_snippets


def get_pdf_snippets(content):
    # Walk the PDFMiner divs and merge consecutive runs of text that share a font size.
    current_font_size = None
    current_text = ""
    snippets = []
    for div in content:
        span = div.find("span")
        if not span:
            continue
        style = span.get("style")
        if not style:
            continue
        font_size = re.findall(r"font-size:(\d+)px", style)
        if not font_size:
            continue
        font_size = int(font_size[0])

        if not current_font_size:
            current_font_size = font_size
        if font_size == current_font_size:
            current_text += div.text
        else:
            snippets.append((current_text, current_font_size))
            current_font_size = font_size
            current_text = div.text
    if current_text:
        snippets.append((current_text, current_font_size))
    return snippets


def filter_pdf_snippets(content_list, new_line_threshold_ratio):
    # Drop snippets whose text is mostly newlines (layout noise rather than prose).
    filtered_list = []
    for content, font_size in content_list:
        total_chars = len(content)
        if total_chars == 0:
            continue
        ratio = content.count("\n") / total_chars
        if ratio <= new_line_threshold_ratio:
            filtered_list.append((content, font_size))
    return filtered_list


def get_pdf_semantic_snippets(filtered_snippets, median_font_size):
    # Treat text larger than the median font size as a section header and attach the
    # following smaller-font text to it as that section's content.
    semantic_snippets = []
    current_header = None
    current_content = ""
    header_font_size = None
    content_font_sizes = []

    for content, font_size in filtered_snippets:
        if font_size > median_font_size:
            if current_header is not None:
                metadata = {
                    "header_font_size": header_font_size,
                    "content_font_size": (
                        median(content_font_sizes) if content_font_sizes else None
                    ),
                    "header_text": current_header,
                }
                semantic_snippets.append((current_content, metadata))
                current_content = ""
                content_font_sizes = []

            current_header = content
            header_font_size = font_size
        else:
            content_font_sizes.append(font_size)
            if current_content:
                current_content += " " + content
            else:
                current_content = content

    if current_header is not None:
        metadata = {
            "header_font_size": header_font_size,
            "content_font_size": (
                median(content_font_sizes) if content_font_sizes else None
            ),
            "header_text": current_header,
        }
        semantic_snippets.append((current_content, metadata))
    return semantic_snippets


def process_web(url):
    data = WebBaseLoader(url).load()[0]
    document_snippets = [
        Document(
            page_content=deep_strip(data.page_content),
            metadata={
                "header": data.metadata.get("title", url),
                "source_url": url,
                "source_type": "web",
                # Web pages are kept as a single chunk; give it an id so that
                # get_references can cite it alongside PDF chunks.
                "chunk_id": 0,
            },
        )
    ]
    return document_snippets
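

# Page flow: the sidebar collects source URLs and builds the retriever on "Submit
# Documents"; the main pane replays the chat history and answers new questions, while
# the sidebar shows per-turn token usage and offers a JSON download of the conversation.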
def boot():
    st.title("Agent Xi - An ArXiv Chatbot")
    st.sidebar.title("Input Documents")
    input_fields()
    st.sidebar.button("Submit Documents", on_click=process_documents)
    if "headers" in st.session_state:
        st.sidebar.write("### References")
        st.sidebar.write(st.session_state.headers)
    if "costing" not in st.session_state:
        st.session_state.costing = []
    if "messages" not in st.session_state:
        st.session_state.messages = []

    for message in st.session_state.messages:
        st.chat_message("human").write(message[0])
        st.chat_message("ai").write(ai_message_format(message[1], message[2]))
    if query := st.chat_input():
        # Guard: answering a question requires that documents have been processed first.
        if "retriever" not in st.session_state:
            st.error("Please submit source documents before asking questions.")
            st.stop()
        st.chat_message("human").write(query)
        response, references, stats = query_llm(st.session_state.retriever, query)
        st.chat_message("ai").write(ai_message_format(response, references))

        st.session_state.costing.append(
            {
                "prompt tokens": stats.prompt_tokens,
                "completion tokens": stats.completion_tokens,
                "cost": stats.total_cost,
            }
        )
        stats_df = pd.DataFrame(st.session_state.costing)
        stats_df.loc["total"] = stats_df.sum()
        st.sidebar.write(stats_df)
        st.sidebar.download_button(
            "Download Conversation",
            get_conversation_history(),
            "conversation.json",
        )


if __name__ == "__main__":
    boot()