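"""Xi Chatbot: a Streamlit RAG app.

Paste comma-separated PDF / web URLs in the sidebar, click "Submit Documents"
to index them with ColBERT (via RAGatouille), then chat over the indexed
content; answers come from a ConversationalRetrievalChain backed by GPT-4.
Run with: streamlit run <this_file>.py
"""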
import math
import os
import re
from pathlib import Path
from statistics import median

import streamlit as st
from bs4 import BeautifulSoup
from langchain.chains import ConversationalRetrievalChain
from langchain.docstore.document import Document
from langchain.document_loaders import PDFMinerPDFasHTMLLoader, WebBaseLoader
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_openai import ChatOpenAI
from ragatouille import RAGPretrainedModel

st.set_page_config(layout="wide")
# Supply your own key here; in practice it should come from an environment
# variable or st.secrets rather than being hard-coded in the source.
os.environ["OPENAI_API_KEY"] = "sk-..."

LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store")
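# Note: RAGatouille persists its ColBERT index under its own default directory
# (typically .ragatouille/colbert/indexes/<index_name>), so the path above is
# not actually consulted by the indexing code below.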

def deep_strip(text):
    """Collapse runs of whitespace into single spaces and strip the ends."""
    return re.sub(r"\s+", " ", text or "").strip()


def embeddings_on_local_vectordb(texts):
    """Index the chunks with ColBERT (via RAGatouille) and wrap the index in a
    MultiQueryRetriever so each question is expanded into several
    reformulations before retrieval."""
    colbert = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv1.9")
    colbert.index(
        collection=[chunk.page_content for chunk in texts],
        split_documents=False,
        document_metadatas=[chunk.metadata for chunk in texts],
        index_name="vector_store",
    )
    retriever = colbert.as_langchain_retriever(k=5)
    retriever = MultiQueryRetriever.from_llm(
        retriever=retriever, llm=ChatOpenAI(temperature=0)
    )
    return retriever


def query_llm(retriever, query):
    """Answer `query` with a ConversationalRetrievalChain over `retriever`,
    keeping the running chat history in Streamlit session state."""
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
        retriever=retriever,
        return_source_documents=True,
        chain_type="stuff",
    )
    # Retrieved separately so the UI can show the supporting sections; the
    # chain result also carries result["source_documents"].
    relevant_docs = retriever.get_relevant_documents(query)
    result = qa_chain({"question": query, "chat_history": st.session_state.messages})
    result = result["answer"]
    st.session_state.messages.append((query, result))
    return relevant_docs, result


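# Chat history is kept as a list of (question, answer) tuples, e.g.
#   [("What does the paper propose?", "It introduces ...")]
# which is the shape ConversationalRetrievalChain accepts for chat_history.
# (Example values are illustrative only.)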
def input_fields():
    """Collect comma-separated source document URLs from the sidebar."""
    st.session_state.source_doc_urls = [
        url.strip()
        for url in st.sidebar.text_input("Source Document URLs").split(",")
        if url.strip()
    ]


def process_documents():
    """Load every source URL, split it into header-anchored snippets, and
    build the retriever used by the chat loop."""
    try:
        snippets = []
        for url in st.session_state.source_doc_urls:
            if url.endswith(".pdf"):
                snippets.extend(process_pdf(url))
            else:
                snippets.extend(process_web(url))
        st.session_state.retriever = embeddings_on_local_vectordb(snippets)
        st.session_state.headers = [
            " ".join(snip.metadata["header"].split()[:10]) for snip in snippets
        ]
    except Exception as e:
        st.error(f"An error occurred: {e}")


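# PDF handling below: PDFMinerPDFasHTMLLoader renders the PDF as HTML, each
# text <span> carries an inline "font-size:<n>px" style, spans larger than the
# document's median font size are treated as section headers, and the
# remaining text is grouped under the most recent header.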
def process_pdf(url):
    """Convert a PDF into Document snippets, one per detected section."""
    data = PDFMinerPDFasHTMLLoader(url).load()[0]
    content = BeautifulSoup(data.page_content, "html.parser").find_all("div")
    snippets = get_pdf_snippets(content)
    filtered_snippets = filter_pdf_snippets(snippets, new_line_threshold_ratio=0.4)
    median_font_size = math.ceil(
        median([font_size for _, font_size in filtered_snippets])
    )
    semantic_snippets = get_pdf_semantic_snippets(filtered_snippets, median_font_size)
    document_snippets = [
        Document(
            # snip is a (section_text, metadata) pair from get_pdf_semantic_snippets.
            page_content=deep_strip(snip[1]["header_text"]) + " " + deep_strip(snip[0]),
            metadata={
                "header": deep_strip(snip[1]["header_text"]),
                "source_url": url,
                "source_type": "pdf",
            },
        )
        for snip in semantic_snippets
    ]
    return document_snippets


def get_pdf_snippets(content):
    """Walk the PDFMiner HTML divs and merge consecutive runs of text that
    share the same font size into (text, font_size) snippets."""
    current_font_size = None
    current_text = ""
    snippets = []
    for div in content:
        span = div.find("span")
        if not span:
            continue
        style = span.get("style")
        if not style:
            continue
        font_size = re.findall(r"font-size:(\d+)px", style)
        if not font_size:
            continue
        font_size = int(font_size[0])

        if not current_font_size:
            current_font_size = font_size
        if font_size == current_font_size:
            current_text += div.text
        else:
            snippets.append((current_text, current_font_size))
            current_font_size = font_size
            current_text = div.text
    # Flush the last run (skipped if the PDF produced no text spans at all).
    if current_text:
        snippets.append((current_text, current_font_size))
    return snippets


def filter_pdf_snippets(content_list, new_line_threshold_ratio):
    """Drop snippets dominated by newlines (layout artefacts), keeping those
    whose newline-to-character ratio is at or below the threshold."""
    filtered_list = []
    for content, font_size in content_list:
        newline_count = content.count("\n")
        total_chars = len(content)
        if total_chars == 0:
            continue
        ratio = newline_count / total_chars
        if ratio <= new_line_threshold_ratio:
            filtered_list.append((content, font_size))
    return filtered_list


def get_pdf_semantic_snippets(filtered_snippets, median_font_size):
    """Group snippets into (content, metadata) sections: any snippet set in a
    font larger than the document median starts a new header, and the snippets
    that follow are accumulated as that section's content."""
    semantic_snippets = []
    current_header = None
    current_content = ""
    header_font_size = None
    content_font_sizes = []

    for content, font_size in filtered_snippets:
        if font_size > median_font_size:
            # A larger-than-median font marks a header: flush the previous section.
            if current_header is not None:
                metadata = {
                    "header_font_size": header_font_size,
                    "content_font_size": (
                        median(content_font_sizes) if content_font_sizes else None
                    ),
                    "header_text": current_header,
                }
                semantic_snippets.append((current_content, metadata))
                current_content = ""
                content_font_sizes = []

            current_header = content
            header_font_size = font_size
        else:
            content_font_sizes.append(font_size)
            if current_content:
                current_content += " " + content
            else:
                current_content = content

    # Flush the final section, if any.
    if current_header is not None:
        metadata = {
            "header_font_size": header_font_size,
            "content_font_size": (
                median(content_font_sizes) if content_font_sizes else None
            ),
            "header_text": current_header,
        }
        semantic_snippets.append((current_content, metadata))
    return semantic_snippets


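# A semantic snippet is a (content, metadata) pair, for example:
#   ("We evaluate retrieval quality on ...",
#    {"header_font_size": 17, "content_font_size": 10, "header_text": "4 Results"})
# (values above are illustrative, not taken from any real document).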
def process_web(url):
    """Fetch a web page and wrap its full text in a single Document."""
    data = WebBaseLoader(url).load()[0]
    document_snippets = [
        Document(
            page_content=deep_strip(data.page_content),
            metadata={
                # Fall back to the URL when the page exposes no <title>.
                "header": data.metadata.get("title", url),
                "source_url": url,
                "source_type": "web",
            },
        )
    ]
    return document_snippets


def boot():
    """Wire up the Streamlit UI: sidebar for document ingestion, main column
    for the chat, and a second column listing the indexed section headers."""
    st.title("Xi Chatbot")
    input_fields()
    col1, col2 = st.columns([4, 1])
    st.sidebar.button("Submit Documents", on_click=process_documents)
    if "headers" in st.session_state:
        for header in st.session_state.headers:
            col2.info(header)
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for message in st.session_state.messages:
        col1.chat_message("human").write(message[0])
        col1.chat_message("ai").write(message[1])
    if query := col1.chat_input():
        # Guard against chatting before any documents have been indexed.
        if "retriever" not in st.session_state:
            st.error("Please submit source documents before asking questions.")
            st.stop()
        col1.chat_message("human").write(query)
        references, response = query_llm(st.session_state.retriever, query)
        for snip in references:
            st.sidebar.success(
                f'Section {" ".join(snip.metadata["header"].split()[:10])}'
            )
        col1.chat_message("ai").write(response)


if __name__ == "__main__":
    boot()