Ritvik19 commited on
Commit
2159374
·
verified ·
1 Parent(s): edfc4f1

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +168 -49
  2. requirements.txt +9 -5
app.py CHANGED
@@ -1,61 +1,50 @@
 
1
  import os
 
2
  from pathlib import Path
 
3
 
 
 
4
  from langchain.chains import ConversationalRetrievalChain
5
- from langchain.vectorstores import Chroma
6
- from langchain.llms.openai import OpenAIChat, OpenAI
7
- from langchain.document_loaders import PyPDFLoader, WebBaseLoader
8
- from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain.embeddings.openai import OpenAIEmbeddings
10
- from langchain.retrievers import ContextualCompressionRetriever
11
- from langchain.retrievers.document_compressors import LLMChainExtractor
12
- from langchain_experimental.text_splitter import SemanticChunker
 
13
 
14
- import streamlit as st
15
 
 
 
16
 
17
  LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store")
18
 
19
-
20
- def load_documents():
21
- loaders = [
22
- PyPDFLoader(source_doc_url)
23
- if source_doc_url.endswith(".pdf")
24
- else WebBaseLoader(source_doc_url)
25
- for source_doc_url in st.session_state.source_doc_urls
26
- ]
27
- documents = []
28
- for loader in loaders:
29
- documents.extend(loader.load())
30
- return documents
31
-
32
-
33
- def split_documents(documents):
34
- text_splitter = SemanticChunker(OpenAIEmbeddings(temperature=0))
35
- texts = text_splitter.split_documents(documents)
36
- return texts
37
 
38
 
39
  def embeddings_on_local_vectordb(texts):
40
- vectordb = Chroma.from_documents(
41
- texts,
42
- embedding=OpenAIEmbeddings(temperature=0),
43
- persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(),
 
 
44
  )
45
- vectordb.persist()
46
- retriever = ContextualCompressionRetriever(
47
- base_compressor=LLMChainExtractor.from_llm(OpenAI(temperature=0)),
48
- base_retriever=vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr"),
49
  )
50
  return retriever
51
 
52
 
53
  def query_llm(retriever, query):
54
  qa_chain = ConversationalRetrievalChain.from_llm(
55
- llm=OpenAIChat(temperature=0),
56
  retriever=retriever,
57
  return_source_documents=True,
58
- chain_type="refine",
59
  )
60
  relevant_docs = retriever.get_relevant_documents(query)
61
  result = qa_chain({"question": query, "chat_history": st.session_state.messages})
@@ -72,30 +61,160 @@ def input_fields():
72
 
73
  def process_documents():
74
  try:
75
- documents = load_documents()
76
- texts = split_documents(documents)
77
- st.session_state.retriever = embeddings_on_local_vectordb(texts)
 
 
 
 
 
 
 
78
  except Exception as e:
79
  st.error(f"An error occurred: {e}")
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def boot():
83
- st.title("Enigma Chatbot")
84
  input_fields()
 
85
  st.sidebar.button("Submit Documents", on_click=process_documents)
86
- st.sidebar.write("---")
87
- st.sidebar.write("References made during the chat will appear here")
 
88
  if "messages" not in st.session_state:
89
  st.session_state.messages = []
90
  for message in st.session_state.messages:
91
- st.chat_message("human").write(message[0])
92
- st.chat_message("ai").write(message[1])
93
- if query := st.chat_input():
94
- st.chat_message("human").write(query)
95
  references, response = query_llm(st.session_state.retriever, query)
96
- for doc in references:
97
- st.sidebar.info(f"Page {doc.metadata['page']}\n\n{doc.page_content}")
98
- st.chat_message("ai").write(response)
 
 
99
 
100
 
101
  if __name__ == "__main__":
 
1
+ import math
2
  import os
3
+ import re
4
  from pathlib import Path
5
+ from statistics import median
6
 
7
+ import streamlit as st
8
+ from bs4 import BeautifulSoup
9
  from langchain.chains import ConversationalRetrievalChain
10
+ from langchain.docstore.document import Document
11
+ from langchain.document_loaders import PDFMinerPDFasHTMLLoader, WebBaseLoader
 
 
12
  from langchain.embeddings.openai import OpenAIEmbeddings
13
+ from langchain_openai import ChatOpenAI, OpenAI
14
+ from langchain.vectorstores import Chroma
15
+ from langchain.retrievers.multi_query import MultiQueryRetriever
16
+ from ragatouille import RAGPretrainedModel
17
 
 
18
 
19
+ st.set_page_config(layout="wide")
20
+ os.environ["OPENAI_API_KEY"] = "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"
21
 
22
  LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store")
23
 
24
+ deep_strip = lambda text: re.sub(r"\s+", " ", text or "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  def embeddings_on_local_vectordb(texts):
28
+ colbert = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv1.9")
29
+ colbert.index(
30
+ collection=[chunk.page_content for chunk in texts],
31
+ split_documents=False,
32
+ document_metadatas=[chunk.metadata for chunk in texts],
33
+ index_name="vector_store",
34
  )
35
+ retriever = colbert.as_langchain_retriever(k=5)
36
+ retriever = MultiQueryRetriever.from_llm(
37
+ retriever=retriever, llm=ChatOpenAI(temperature=0)
 
38
  )
39
  return retriever
40
 
41
 
42
  def query_llm(retriever, query):
43
  qa_chain = ConversationalRetrievalChain.from_llm(
44
+ llm=ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
45
  retriever=retriever,
46
  return_source_documents=True,
47
+ chain_type="stuff",
48
  )
49
  relevant_docs = retriever.get_relevant_documents(query)
50
  result = qa_chain({"question": query, "chat_history": st.session_state.messages})
 
61
 
62
  def process_documents():
63
  try:
64
+ snippets = []
65
+ for url in st.session_state.source_doc_urls:
66
+ if url.endswith(".pdf"):
67
+ snippets.extend(process_pdf(url))
68
+ else:
69
+ snippets.extend(process_web(url))
70
+ st.session_state.retriever = embeddings_on_local_vectordb(snippets)
71
+ st.session_state.headers = [
72
+ " ".join(snip.metadata["header"].split()[:10]) for snip in snippets
73
+ ]
74
  except Exception as e:
75
  st.error(f"An error occurred: {e}")
76
 
77
 
78
+ def process_pdf(url):
79
+ data = PDFMinerPDFasHTMLLoader(url).load()[0]
80
+ content = BeautifulSoup(data.page_content, "html.parser").find_all("div")
81
+ snippets = get_pdf_snippets(content)
82
+ filtered_snippets = filter_pdf_snippets(snippets, new_line_threshold_ratio=0.4)
83
+ median_font_size = math.ceil(
84
+ median([font_size for _, font_size in filtered_snippets])
85
+ )
86
+ semantic_snippets = get_pdf_semantic_snippets(filtered_snippets, median_font_size)
87
+ document_snippets = [
88
+ Document(
89
+ page_content=deep_strip(snip[1]["header_text"]) + " " + deep_strip(snip[0]),
90
+ metadata={
91
+ "header": deep_strip(snip[1]["header_text"]),
92
+ "source_url": url,
93
+ "source_type": "pdf",
94
+ },
95
+ )
96
+ for snip in semantic_snippets
97
+ ]
98
+ return document_snippets
99
+
100
+
101
+ def get_pdf_snippets(content):
102
+ current_font_size = None
103
+ current_text = ""
104
+ snippets = []
105
+ for cntnt in content:
106
+ span = cntnt.find("span")
107
+ if not span:
108
+ continue
109
+ style = span.get("style")
110
+ if not style:
111
+ continue
112
+ font_size = re.findall("font-size:(\d+)px", style)
113
+ if not font_size:
114
+ continue
115
+ font_size = int(font_size[0])
116
+
117
+ if not current_font_size:
118
+ current_font_size = font_size
119
+ if font_size == current_font_size:
120
+ current_text += cntnt.text
121
+ else:
122
+ snippets.append((current_text, current_font_size))
123
+ current_font_size = font_size
124
+ current_text = cntnt.text
125
+ snippets.append((current_text, current_font_size))
126
+ return snippets
127
+
128
+
129
+ def filter_pdf_snippets(content_list, new_line_threshold_ratio):
130
+ filtered_list = []
131
+ for e, (content, font_size) in enumerate(content_list):
132
+ newline_count = content.count("\n")
133
+ total_chars = len(content)
134
+ ratio = newline_count / total_chars
135
+ if ratio <= new_line_threshold_ratio:
136
+ filtered_list.append((content, font_size))
137
+ return filtered_list
138
+
139
+
140
+ def get_pdf_semantic_snippets(filtered_snippets, median_font_size):
141
+ semantic_snippets = []
142
+ current_header = None
143
+ current_content = []
144
+ header_font_size = None
145
+ content_font_sizes = []
146
+
147
+ for content, font_size in filtered_snippets:
148
+ if font_size > median_font_size:
149
+ if current_header is not None:
150
+ metadata = {
151
+ "header_font_size": header_font_size,
152
+ "content_font_size": (
153
+ median(content_font_sizes) if content_font_sizes else None
154
+ ),
155
+ "header_text": current_header,
156
+ }
157
+ semantic_snippets.append((current_content, metadata))
158
+ current_content = []
159
+ content_font_sizes = []
160
+
161
+ current_header = content
162
+ header_font_size = font_size
163
+ else:
164
+ content_font_sizes.append(font_size)
165
+ if current_content:
166
+ current_content += " " + content
167
+ else:
168
+ current_content = content
169
+
170
+ if current_header is not None:
171
+ metadata = {
172
+ "header_font_size": header_font_size,
173
+ "content_font_size": (
174
+ median(content_font_sizes) if content_font_sizes else None
175
+ ),
176
+ "header_text": current_header,
177
+ }
178
+ semantic_snippets.append((current_content, metadata))
179
+ return semantic_snippets
180
+
181
+
182
+ def process_web(url):
183
+ data = WebBaseLoader(url).load()[0]
184
+ document_snippets = [
185
+ Document(
186
+ page_content=deep_strip(data.page_content),
187
+ metadata={
188
+ "header": data.metadata["title"],
189
+ "source_url": url,
190
+ "source_type": "web",
191
+ },
192
+ )
193
+ ]
194
+ return document_snippets
195
+
196
+
197
  def boot():
198
+ st.title("Xi Chatbot")
199
  input_fields()
200
+ col1, col2 = st.columns([4, 1])
201
  st.sidebar.button("Submit Documents", on_click=process_documents)
202
+ if "headers" in st.session_state:
203
+ for header in st.session_state.headers:
204
+ col2.info(header)
205
  if "messages" not in st.session_state:
206
  st.session_state.messages = []
207
  for message in st.session_state.messages:
208
+ col1.chat_message("human").write(message[0])
209
+ col1.chat_message("ai").write(message[1])
210
+ if query := col1.chat_input():
211
+ col1.chat_message("human").write(query)
212
  references, response = query_llm(st.session_state.retriever, query)
213
+ for snip in references:
214
+ st.sidebar.success(
215
+ f'Section {" ".join(snip.metadata["header"].split()[:10])}'
216
+ )
217
+ col1.chat_message("ai").write(response)
218
 
219
 
220
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,6 +1,10 @@
1
- openai==0.28
2
- langchain==0.1.1
3
- pypdf==4.0.0
4
- chromadb==0.4.22
5
  langchain-experimental==0.0.49
6
- tiktoken==0.5.2
 
 
 
 
 
1
+ openai==1.12.0
2
+ langchain==0.1.9
3
+ langchain-community==0.0.24
4
+ langchain-core==0.1.27
5
  langchain-experimental==0.0.49
6
+ langchain-openai==0.0.8
7
+ chromadb==0.4.22
8
+ tiktoken==0.5.2
9
+ pdfminer.six==20231228
10
+ beautifulsoup4==4.12.3