"""Xi Chatbot: a Streamlit RAG app that indexes PDFs and web pages with a
ColBERT retriever (via RAGatouille) and answers questions over them with a
LangChain ConversationalRetrievalChain."""

import math
import os
import re
from statistics import median

import streamlit as st
from bs4 import BeautifulSoup
from langchain.chains import ConversationalRetrievalChain
from langchain.docstore.document import Document
from langchain.document_loaders import PDFMinerPDFasHTMLLoader, WebBaseLoader
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_openai import ChatOpenAI
from ragatouille import RAGPretrainedModel

st.set_page_config(layout="wide")

# Supply your own key here or via the environment; never commit a real API
# key to source control.
os.environ["OPENAI_API_KEY"] = "sk-..."


def deep_strip(text):
    """Collapse whitespace runs to single spaces and trim the ends."""
    return re.sub(r"\s+", " ", text or "").strip()


def embeddings_on_local_vectordb(texts):
    """Index the chunks with ColBERT and wrap the retriever in MultiQuery."""
    colbert = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv1.9")
    colbert.index(
        collection=[chunk.page_content for chunk in texts],
        split_documents=False,
        document_metadatas=[chunk.metadata for chunk in texts],
        index_name="vector_store",
    )
    retriever = colbert.as_langchain_retriever(k=5)
    # MultiQueryRetriever rephrases the question several ways and merges the
    # retrieved documents, which improves recall.
    retriever = MultiQueryRetriever.from_llm(
        retriever=retriever, llm=ChatOpenAI(temperature=0)
    )
    return retriever


def query_llm(retriever, query):
    """Answer `query` against the retriever, threading the chat history."""
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
        retriever=retriever,
        return_source_documents=True,
        chain_type="stuff",
    )
    # Fetch the documents separately so the UI can list the cited sections.
    relevant_docs = retriever.get_relevant_documents(query)
    result = qa_chain({"question": query, "chat_history": st.session_state.messages})
    answer = result["answer"]
    st.session_state.messages.append((query, answer))
    return relevant_docs, answer


def input_fields():
    """Sidebar input: a comma-separated list of source document URLs."""
    st.session_state.source_doc_urls = [
        url.strip()
        for url in st.sidebar.text_input("Source Document URLs").split(",")
        if url.strip()
    ]


def process_documents():
    """Load every URL, chunk it, and build the retriever."""
    try:
        snippets = []
        for url in st.session_state.source_doc_urls:
            if url.endswith(".pdf"):
                snippets.extend(process_pdf(url))
            else:
                snippets.extend(process_web(url))
        st.session_state.retriever = embeddings_on_local_vectordb(snippets)
        st.session_state.headers = [
            " ".join(snip.metadata["header"].split()[:10]) for snip in snippets
        ]
    except Exception as e:
        st.error(f"An error occurred: {e}")


def process_pdf(url):
    """Convert a PDF into header-grouped Document chunks.

    PDFMiner renders the PDF as HTML; runs of text larger than the median
    font size are treated as section headers, and each header plus the body
    text that follows it becomes one Document.
    """
    data = PDFMinerPDFasHTMLLoader(url).load()[0]
    content = BeautifulSoup(data.page_content, "html.parser").find_all("div")
    snippets = get_pdf_snippets(content)
    filtered_snippets = filter_pdf_snippets(snippets, new_line_threshold_ratio=0.4)
    median_font_size = math.ceil(
        median([font_size for _, font_size in filtered_snippets])
    )
    semantic_snippets = get_pdf_semantic_snippets(filtered_snippets, median_font_size)
    document_snippets = [
        Document(
            page_content=(
                deep_strip(snip[1]["header_text"]) + " " + deep_strip(snip[0])
            ),
            metadata={
                "header": deep_strip(snip[1]["header_text"]),
                "source_url": url,
                "source_type": "pdf",
            },
        )
        for snip in semantic_snippets
    ]
    return document_snippets


def get_pdf_snippets(content):
    """Group consecutive divs that share a font size into (text, size) runs."""
    current_font_size = None
    current_text = ""
    snippets = []
    for div in content:
        span = div.find("span")
        if not span:
            continue
        style = span.get("style")
        if not style:
            continue
        font_size = re.findall(r"font-size:(\d+)px", style)
        if not font_size:
            continue
        font_size = int(font_size[0])
        if not current_font_size:
            current_font_size = font_size
        if font_size == current_font_size:
            current_text += div.text
        else:
            snippets.append((current_text, current_font_size))
            current_font_size = font_size
            current_text = div.text
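    # The loop above only emits a snippet when the font size changes, so the
    # trailing run of text still needs to be flushed.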
    snippets.append((current_text, current_font_size))
    return snippets


def filter_pdf_snippets(content_list, new_line_threshold_ratio):
    """Drop snippets that are mostly newlines (tables, TOCs, sparse layout)."""
    filtered_list = []
    for content, font_size in content_list:
        if not content:
            continue  # guard against division by zero on empty snippets
        ratio = content.count("\n") / len(content)
        if ratio <= new_line_threshold_ratio:
            filtered_list.append((content, font_size))
    return filtered_list


def get_pdf_semantic_snippets(filtered_snippets, median_font_size):
    """Pair each header (any run whose font size exceeds the median) with the
    body text that follows it, recording font-size metadata per section."""
    semantic_snippets = []
    current_header = None
    current_content = ""
    header_font_size = None
    content_font_sizes = []
    for content, font_size in filtered_snippets:
        if font_size > median_font_size:
            # A new header begins: close out the previous section first.
            if current_header is not None:
                metadata = {
                    "header_font_size": header_font_size,
                    "content_font_size": (
                        median(content_font_sizes) if content_font_sizes else None
                    ),
                    "header_text": current_header,
                }
                semantic_snippets.append((current_content, metadata))
                current_content = ""
                content_font_sizes = []
            current_header = content
            header_font_size = font_size
        else:
            content_font_sizes.append(font_size)
            if current_content:
                current_content += " " + content
            else:
                current_content = content
    # Close out the final section.
    if current_header is not None:
        metadata = {
            "header_font_size": header_font_size,
            "content_font_size": (
                median(content_font_sizes) if content_font_sizes else None
            ),
            "header_text": current_header,
        }
        semantic_snippets.append((current_content, metadata))
    return semantic_snippets


def process_web(url):
    """Load a web page as a single Document chunk."""
    data = WebBaseLoader(url).load()[0]
    return [
        Document(
            page_content=deep_strip(data.page_content),
            metadata={
                # Fall back to the URL when the page has no <title>.
                "header": data.metadata.get("title", url),
                "source_url": url,
                "source_type": "web",
            },
        )
    ]


def boot():
    """Wire up the Streamlit UI: sidebar inputs, chat history, chat loop."""
    st.title("Xi Chatbot")
    input_fields()
    col1, col2 = st.columns([4, 1])
    st.sidebar.button("Submit Documents", on_click=process_documents)
    if "headers" in st.session_state:
        for header in st.session_state.headers:
            col2.info(header)
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Replay the conversation so far.
    for human_msg, ai_msg in st.session_state.messages:
        col1.chat_message("human").write(human_msg)
        col1.chat_message("ai").write(ai_msg)
    if query := col1.chat_input():
        if "retriever" not in st.session_state:
            col1.error("Please submit source documents first.")
            return
        col1.chat_message("human").write(query)
        references, response = query_llm(st.session_state.retriever, query)
        for snip in references:
            st.sidebar.success(
                f'Section {" ".join(snip.metadata["header"].split()[:10])}'
            )
        col1.chat_message("ai").write(response)


if __name__ == "__main__":
    boot()
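# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the app itself): if this file is
# saved as app.py, the UI is served with `streamlit run app.py`. The indexing
# and retrieval pipeline can also be exercised headlessly, e.g. in a smoke
# test; the module name and URL below are hypothetical placeholders.
#
#   from app import process_pdf, embeddings_on_local_vectordb
#
#   snippets = process_pdf("https://example.com/paper.pdf")  # placeholder URL
#   retriever = embeddings_on_local_vectordb(snippets)
#   print(retriever.get_relevant_documents("What does the paper conclude?"))
# ---------------------------------------------------------------------------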