import json
import math
import os
import re
from pathlib import Path
from statistics import median

import pandas as pd
import streamlit as st
from bs4 import BeautifulSoup
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain.docstore.document import Document
from langchain.document_loaders import PDFMinerPDFasHTMLLoader, WebBaseLoader
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_openai import ChatOpenAI
from ragatouille import RAGPretrainedModel

st.set_page_config(layout="wide")

os.environ["OPENAI_API_KEY"] = "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"

LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store")

deep_strip = lambda text: re.sub(r"\s+", " ", text or "").strip()

get_references = lambda relevant_docs: " ".join(
    [f"[{ref}]" for ref in sorted([ref.metadata["chunk_id"] for ref in relevant_docs])]
)

session_state_2_llm_chat_history = lambda session_state: [
    ss[:2] for ss in session_state
]


def get_conversation_history():
    return json.dumps(
        {
            "document_urls": (
                st.session_state.source_doc_urls
                if "source_doc_urls" in st.session_state
                else []
            ),
            "document_snippets": (
                st.session_state.headers.to_list()
                if "headers" in st.session_state
                else []
            ),
            "conversation": [
                {"human": message[0], "ai": message[1], "references": message[2]}
                for message in st.session_state.messages
            ],
            "costing": (
                st.session_state.costing if "costing" in st.session_state else []
            ),
            "total_cost": (
                {
                    k: sum(d[k] for d in st.session_state.costing)
                    for k in st.session_state.costing[0]
                }
                if "costing" in st.session_state and len(st.session_state.costing) > 0
                else {}
            ),
        }
    )


ai_message_format = lambda message, references: f"{message}\n\n---\n\n{references}"


def embeddings_on_local_vectordb(texts):
    # Index the chunks with ColBERT, then wrap the index as a LangChain
    # retriever whose queries are expanded by an LLM (MultiQueryRetriever).
    colbert = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv1.9")
    colbert.index(
        collection=[chunk.page_content for chunk in texts],
        split_documents=False,
        document_metadatas=[chunk.metadata for chunk in texts],
        index_name="vector_store",
    )
    retriever = colbert.as_langchain_retriever(k=5)
    retriever = MultiQueryRetriever.from_llm(
        retriever=retriever, llm=ChatOpenAI(temperature=0)
    )
    return retriever


def query_llm(retriever, query):
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
        retriever=retriever,
        return_source_documents=True,
        chain_type="stuff",
    )
    relevant_docs = retriever.get_relevant_documents(query)
    with get_openai_callback() as cb:
        result = qa_chain(
            {
                "question": query,
                "chat_history": session_state_2_llm_chat_history(
                    st.session_state.messages
                ),
            }
        )
        stats = cb
    result = result["answer"]
    references = get_references(relevant_docs)
    st.session_state.messages.append((query, result, references))
    return result, references, stats


def input_fields():
    st.session_state.source_doc_urls = [
        url.strip()
        for url in st.sidebar.text_area(
            "Source Document URLs\n(New line separated)", height=50
        ).split("\n")
    ]


def process_documents():
    try:
        snippets = []
        for url in st.session_state.source_doc_urls:
            if url.endswith(".pdf"):
                snippets.extend(process_pdf(url))
            else:
                snippets.extend(process_web(url))
        st.session_state.retriever = embeddings_on_local_vectordb(snippets)
        st.session_state.headers = pd.Series(
            [snip.metadata["header"] for snip in snippets], name="references"
        )
    except Exception as e:
        st.error(f"An error occurred: {e}")


def process_pdf(url):
    # PDFMiner renders the PDF as HTML; text runs are grouped by font size,
    # mostly-empty runs are dropped, and runs larger than the median font size
    # act as section headers for the content that follows them.
    data = PDFMinerPDFasHTMLLoader(url).load()[0]
    content = BeautifulSoup(data.page_content, "html.parser").find_all("div")
    snippets = get_pdf_snippets(content)
    filtered_snippets = filter_pdf_snippets(snippets, new_line_threshold_ratio=0.4)
    median_font_size = math.ceil(
        median([font_size for _, font_size in filtered_snippets])
    )
    semantic_snippets = get_pdf_semantic_snippets(filtered_snippets, median_font_size)
    document_snippets = [
        Document(
            page_content=deep_strip(snip[1]["header_text"]) + " " + deep_strip(snip[0]),
            metadata={
                "header": " ".join(snip[1]["header_text"].split()[:10]),
                "source_url": url,
                "source_type": "pdf",
                "chunk_id": i,
            },
        )
        for i, snip in enumerate(semantic_snippets)
    ]
    return document_snippets


def get_pdf_snippets(content):
    # Merge consecutive divs that share a span font size into (text, font_size) runs.
    current_font_size = None
    current_text = ""
    snippets = []
    for cntnt in content:
        span = cntnt.find("span")
        if not span:
            continue
        style = span.get("style")
        if not style:
            continue
        font_size = re.findall(r"font-size:(\d+)px", style)
        if not font_size:
            continue
        font_size = int(font_size[0])
        if current_font_size is None:
            current_font_size = font_size
        if font_size == current_font_size:
            current_text += cntnt.text
        else:
            snippets.append((current_text, current_font_size))
            current_font_size = font_size
            current_text = cntnt.text
    snippets.append((current_text, current_font_size))
    return snippets


def filter_pdf_snippets(content_list, new_line_threshold_ratio):
    # Drop runs that are empty or consist mostly of newlines.
    filtered_list = []
    for content, font_size in content_list:
        newline_count = content.count("\n")
        total_chars = len(content)
        if total_chars == 0:
            continue
        if newline_count / total_chars <= new_line_threshold_ratio:
            filtered_list.append((content, font_size))
    return filtered_list


def get_pdf_semantic_snippets(filtered_snippets, median_font_size):
    # Runs larger than the median font size are treated as headers; all other
    # runs are accumulated as the content of the most recent header.
    semantic_snippets = []
    current_header = None
    current_content = ""
    header_font_size = None
    content_font_sizes = []
    for content, font_size in filtered_snippets:
        if font_size > median_font_size:
            if current_header is not None:
                metadata = {
                    "header_font_size": header_font_size,
                    "content_font_size": (
                        median(content_font_sizes) if content_font_sizes else None
                    ),
                    "header_text": current_header,
                }
                semantic_snippets.append((current_content, metadata))
                current_content = ""
                content_font_sizes = []
            current_header = content
            header_font_size = font_size
        else:
            content_font_sizes.append(font_size)
            if current_content:
                current_content += " " + content
            else:
                current_content = content
    if current_header is not None:
        metadata = {
            "header_font_size": header_font_size,
            "content_font_size": (
                median(content_font_sizes) if content_font_sizes else None
            ),
            "header_text": current_header,
        }
        semantic_snippets.append((current_content, metadata))
    return semantic_snippets


def process_web(url):
    data = WebBaseLoader(url).load()[0]
    document_snippets = [
        Document(
            page_content=deep_strip(data.page_content),
            metadata={
                "header": data.metadata["title"],
                "source_url": url,
                "source_type": "web",
                # Web pages are indexed as a single chunk; a chunk_id is still
                # required because get_references() reads it for every retrieved doc.
                "chunk_id": 0,
            },
        )
    ]
    return document_snippets
def boot():
    st.title("Agent Xi - An ArXiv Chatbot")
    st.sidebar.title("Input Documents")
    input_fields()
    st.sidebar.button("Submit Documents", on_click=process_documents)
    if "headers" in st.session_state:
        st.sidebar.write("### References")
        st.sidebar.write(st.session_state.headers)
    if "costing" not in st.session_state:
        st.session_state.costing = []
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for message in st.session_state.messages:
        st.chat_message("human").write(message[0])
        st.chat_message("ai").write(ai_message_format(message[1], message[2]))
    if query := st.chat_input():
        st.chat_message("human").write(query)
        response, references, stats = query_llm(st.session_state.retriever, query)
        st.chat_message("ai").write(ai_message_format(response, references))
        st.session_state.costing.append(
            {
                "prompt tokens": stats.prompt_tokens,
                "completion tokens": stats.completion_tokens,
                "cost": stats.total_cost,
            }
        )
        stats_df = pd.DataFrame(st.session_state.costing)
        stats_df.loc["total"] = stats_df.sum()
        st.sidebar.write(stats_df)
        st.sidebar.download_button(
            "Download Conversation",
            get_conversation_history(),
            "conversation.json",
        )


if __name__ == "__main__":
    boot()
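# Usage sketch (the file name and dependency list below are assumptions, not
# stated in the source): save this module as app.py, install the usual PyPI
# packages for the imports above (streamlit, pandas, beautifulsoup4, langchain,
# langchain-openai, ragatouille, pdfminer.six), and launch with
# `streamlit run app.py`. Paste one or more PDF or web URLs in the sidebar,
# click "Submit Documents", then chat in the main panel; per-query token and
# cost stats and a JSON download of the conversation appear in the sidebar.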