Joshua Sundance Bailey committed on
Commit
500ff1a
·
unverified ·
2 Parent(s): 85014c3 54cd375

Merge pull request #18 from joshuasundance-swca/dev

Browse files
Files changed (1) hide show
  1. langchain-streamlit-demo/app.py +85 -25
langchain-streamlit-demo/app.py CHANGED
@@ -40,6 +40,7 @@ st_init_null(
40
  "chain",
41
  "client",
42
  "doc_chain",
 
43
  "llm",
44
  "ls_tracer",
45
  "retriever",
@@ -105,16 +106,31 @@ PROVIDER_KEY_DICT = {
105
  }
106
  OPENAI_API_KEY = PROVIDER_KEY_DICT["OpenAI"]
107
 
 
 
 
 
 
 
 
 
108
 
109
  @st.cache_data
110
- def get_retriever(uploaded_file_bytes: bytes) -> BaseRetriever:
 
 
 
 
111
  with NamedTemporaryFile() as temp_file:
112
  temp_file.write(uploaded_file_bytes)
113
  temp_file.seek(0)
114
 
115
  loader = PyPDFLoader(temp_file.name)
116
  documents = loader.load()
117
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
 
 
 
118
  texts = text_splitter.split_documents(documents)
119
  embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
120
  db = FAISS.from_documents(texts, embeddings)
@@ -139,33 +155,77 @@ with sidebar:
139
  type="password",
140
  )
141
 
142
- uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
 
 
 
 
 
 
 
143
 
144
- openai_api_key = (
145
- provider_api_key
146
- if provider == "OpenAI"
147
- else OPENAI_API_KEY
148
- or st.sidebar.text_input("OpenAI API Key: ", type="password")
149
- )
150
 
151
- if uploaded_file:
152
- if openai_api_key:
153
- st.session_state.retriever = get_retriever(
154
- uploaded_file_bytes=uploaded_file.getvalue(),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  )
156
- else:
157
- st.error("Please enter a valid OpenAI API key.", icon="❌")
158
 
159
- document_chat = st.checkbox(
160
- "Document Chat",
161
- value=False,
162
- help="Uploaded document will provide context for the chat.",
163
- )
164
 
165
- if st.button("Clear message history"):
166
- STMEMORY.clear()
167
- st.session_state.trace_link = None
168
- st.session_state.run_id = None
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  # --- Advanced Options ---
171
  with st.expander("Advanced Options", expanded=False):
@@ -270,7 +330,7 @@ if st.session_state.llm:
270
 
271
  st.session_state.doc_chain = RetrievalQA.from_chain_type(
272
  llm=st.session_state.llm,
273
- chain_type="stuff",
274
  retriever=st.session_state.retriever,
275
  memory=MEMORY,
276
  )
 
40
  "chain",
41
  "client",
42
  "doc_chain",
43
+ "document_chat_chain_type",
44
  "llm",
45
  "ls_tracer",
46
  "retriever",
 
106
  }
107
  OPENAI_API_KEY = PROVIDER_KEY_DICT["OpenAI"]
108
 
109
+ MIN_CHUNK_SIZE = 1
110
+ MAX_CHUNK_SIZE = 10000
111
+ DEFAULT_CHUNK_SIZE = 1000
112
+
113
+ MIN_CHUNK_OVERLAP = 0
114
+ MAX_CHUNK_OVERLAP = 10000
115
+ DEFAULT_CHUNK_OVERLAP = 0
116
+
117
 
118
  @st.cache_data
119
+ def get_retriever(
120
+ uploaded_file_bytes: bytes,
121
+ chunk_size: int = DEFAULT_CHUNK_SIZE,
122
+ chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
123
+ ) -> BaseRetriever:
124
  with NamedTemporaryFile() as temp_file:
125
  temp_file.write(uploaded_file_bytes)
126
  temp_file.seek(0)
127
 
128
  loader = PyPDFLoader(temp_file.name)
129
  documents = loader.load()
130
+ text_splitter = CharacterTextSplitter(
131
+ chunk_size=chunk_size,
132
+ chunk_overlap=chunk_overlap,
133
+ )
134
  texts = text_splitter.split_documents(documents)
135
  embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
136
  db = FAISS.from_documents(texts, embeddings)
 
155
  type="password",
156
  )
157
 
158
+ if st.button("Clear message history"):
159
+ STMEMORY.clear()
160
+ st.session_state.trace_link = None
161
+ st.session_state.run_id = None
162
+
163
+ # --- Document Chat Options ---
164
+ with st.expander("Document Chat", expanded=False):
165
+ uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
166
 
167
+ openai_api_key = (
168
+ provider_api_key
169
+ if provider == "OpenAI"
170
+ else OPENAI_API_KEY
171
+ or st.sidebar.text_input("OpenAI API Key: ", type="password")
172
+ )
173
 
174
+ document_chat = st.checkbox(
175
+ "Document Chat",
176
+ value=False,
177
+ help="Uploaded document will provide context for the chat.",
178
+ )
179
+
180
+ chunk_size = st.slider(
181
+ label="chunk_size",
182
+ help="Size of each chunk of text",
183
+ min_value=MIN_CHUNK_SIZE,
184
+ max_value=MAX_CHUNK_SIZE,
185
+ value=DEFAULT_CHUNK_SIZE,
186
+ )
187
+ chunk_overlap = st.slider(
188
+ label="chunk_overlap",
189
+ help="Number of characters to overlap between chunks",
190
+ min_value=MIN_CHUNK_OVERLAP,
191
+ max_value=MAX_CHUNK_OVERLAP,
192
+ value=DEFAULT_CHUNK_OVERLAP,
193
+ )
194
+
195
+ chain_type_help_root = (
196
+ "https://python.langchain.com/docs/modules/chains/document/"
197
+ )
198
+ chain_type_help_dict = {
199
+ chain_type_name: f"{chain_type_help_root}/{chain_type_name}"
200
+ for chain_type_name in (
201
+ "stuff",
202
+ "refine",
203
+ "map_reduce",
204
+ "map_rerank",
205
  )
206
+ }
 
207
 
208
+ chain_type_help = "\n".join(
209
+ f"- [{k}]({v})" for k, v in chain_type_help_dict.items()
210
+ )
 
 
211
 
212
+ document_chat_chain_type = st.selectbox(
213
+ label="Document Chat Chain Type",
214
+ options=["stuff", "refine", "map_reduce", "map_rerank"],
215
+ index=0,
216
+ help=chain_type_help,
217
+ disabled=not document_chat,
218
+ )
219
+
220
+ if uploaded_file:
221
+ if openai_api_key:
222
+ st.session_state.retriever = get_retriever(
223
+ uploaded_file_bytes=uploaded_file.getvalue(),
224
+ chunk_size=chunk_size,
225
+ chunk_overlap=chunk_overlap,
226
+ )
227
+ else:
228
+ st.error("Please enter a valid OpenAI API key.", icon="❌")
229
 
230
  # --- Advanced Options ---
231
  with st.expander("Advanced Options", expanded=False):
 
330
 
331
  st.session_state.doc_chain = RetrievalQA.from_chain_type(
332
  llm=st.session_state.llm,
333
+ chain_type=document_chat_chain_type,
334
  retriever=st.session_state.retriever,
335
  memory=MEMORY,
336
  )