Merge pull request #18 from joshuasundance-swca/dev
Browse files- langchain-streamlit-demo/app.py +85 -25
langchain-streamlit-demo/app.py
CHANGED
@@ -40,6 +40,7 @@ st_init_null(
|
|
40 |
"chain",
|
41 |
"client",
|
42 |
"doc_chain",
|
|
|
43 |
"llm",
|
44 |
"ls_tracer",
|
45 |
"retriever",
|
@@ -105,16 +106,31 @@ PROVIDER_KEY_DICT = {
|
|
105 |
}
|
106 |
OPENAI_API_KEY = PROVIDER_KEY_DICT["OpenAI"]
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
@st.cache_data
|
110 |
-
def get_retriever(
|
|
|
|
|
|
|
|
|
111 |
with NamedTemporaryFile() as temp_file:
|
112 |
temp_file.write(uploaded_file_bytes)
|
113 |
temp_file.seek(0)
|
114 |
|
115 |
loader = PyPDFLoader(temp_file.name)
|
116 |
documents = loader.load()
|
117 |
-
text_splitter = CharacterTextSplitter(
|
|
|
|
|
|
|
118 |
texts = text_splitter.split_documents(documents)
|
119 |
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
|
120 |
db = FAISS.from_documents(texts, embeddings)
|
@@ -139,33 +155,77 @@ with sidebar:
|
|
139 |
type="password",
|
140 |
)
|
141 |
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
)
|
156 |
-
|
157 |
-
st.error("Please enter a valid OpenAI API key.", icon="❌")
|
158 |
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
help="Uploaded document will provide context for the chat.",
|
163 |
-
)
|
164 |
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
# --- Advanced Options ---
|
171 |
with st.expander("Advanced Options", expanded=False):
|
@@ -270,7 +330,7 @@ if st.session_state.llm:
|
|
270 |
|
271 |
st.session_state.doc_chain = RetrievalQA.from_chain_type(
|
272 |
llm=st.session_state.llm,
|
273 |
-
chain_type=
|
274 |
retriever=st.session_state.retriever,
|
275 |
memory=MEMORY,
|
276 |
)
|
|
|
40 |
"chain",
|
41 |
"client",
|
42 |
"doc_chain",
|
43 |
+
"document_chat_chain_type",
|
44 |
"llm",
|
45 |
"ls_tracer",
|
46 |
"retriever",
|
|
|
106 |
}
|
107 |
OPENAI_API_KEY = PROVIDER_KEY_DICT["OpenAI"]
|
108 |
|
109 |
+
MIN_CHUNK_SIZE = 1
|
110 |
+
MAX_CHUNK_SIZE = 10000
|
111 |
+
DEFAULT_CHUNK_SIZE = 1000
|
112 |
+
|
113 |
+
MIN_CHUNK_OVERLAP = 0
|
114 |
+
MAX_CHUNK_OVERLAP = 10000
|
115 |
+
DEFAULT_CHUNK_OVERLAP = 0
|
116 |
+
|
117 |
|
118 |
@st.cache_data
|
119 |
+
def get_retriever(
|
120 |
+
uploaded_file_bytes: bytes,
|
121 |
+
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
122 |
+
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
|
123 |
+
) -> BaseRetriever:
|
124 |
with NamedTemporaryFile() as temp_file:
|
125 |
temp_file.write(uploaded_file_bytes)
|
126 |
temp_file.seek(0)
|
127 |
|
128 |
loader = PyPDFLoader(temp_file.name)
|
129 |
documents = loader.load()
|
130 |
+
text_splitter = CharacterTextSplitter(
|
131 |
+
chunk_size=chunk_size,
|
132 |
+
chunk_overlap=chunk_overlap,
|
133 |
+
)
|
134 |
texts = text_splitter.split_documents(documents)
|
135 |
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
|
136 |
db = FAISS.from_documents(texts, embeddings)
|
|
|
155 |
type="password",
|
156 |
)
|
157 |
|
158 |
+
if st.button("Clear message history"):
|
159 |
+
STMEMORY.clear()
|
160 |
+
st.session_state.trace_link = None
|
161 |
+
st.session_state.run_id = None
|
162 |
+
|
163 |
+
# --- Document Chat Options ---
|
164 |
+
with st.expander("Document Chat", expanded=False):
|
165 |
+
uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
|
166 |
|
167 |
+
openai_api_key = (
|
168 |
+
provider_api_key
|
169 |
+
if provider == "OpenAI"
|
170 |
+
else OPENAI_API_KEY
|
171 |
+
or st.sidebar.text_input("OpenAI API Key: ", type="password")
|
172 |
+
)
|
173 |
|
174 |
+
document_chat = st.checkbox(
|
175 |
+
"Document Chat",
|
176 |
+
value=False,
|
177 |
+
help="Uploaded document will provide context for the chat.",
|
178 |
+
)
|
179 |
+
|
180 |
+
chunk_size = st.slider(
|
181 |
+
label="chunk_size",
|
182 |
+
help="Size of each chunk of text",
|
183 |
+
min_value=MIN_CHUNK_SIZE,
|
184 |
+
max_value=MAX_CHUNK_SIZE,
|
185 |
+
value=DEFAULT_CHUNK_SIZE,
|
186 |
+
)
|
187 |
+
chunk_overlap = st.slider(
|
188 |
+
label="chunk_overlap",
|
189 |
+
help="Number of characters to overlap between chunks",
|
190 |
+
min_value=MIN_CHUNK_OVERLAP,
|
191 |
+
max_value=MAX_CHUNK_OVERLAP,
|
192 |
+
value=DEFAULT_CHUNK_OVERLAP,
|
193 |
+
)
|
194 |
+
|
195 |
+
chain_type_help_root = (
|
196 |
+
"https://python.langchain.com/docs/modules/chains/document/"
|
197 |
+
)
|
198 |
+
chain_type_help_dict = {
|
199 |
+
chain_type_name: f"{chain_type_help_root}/{chain_type_name}"
|
200 |
+
for chain_type_name in (
|
201 |
+
"stuff",
|
202 |
+
"refine",
|
203 |
+
"map_reduce",
|
204 |
+
"map_rerank",
|
205 |
)
|
206 |
+
}
|
|
|
207 |
|
208 |
+
chain_type_help = "\n".join(
|
209 |
+
f"- [{k}]({v})" for k, v in chain_type_help_dict.items()
|
210 |
+
)
|
|
|
|
|
211 |
|
212 |
+
document_chat_chain_type = st.selectbox(
|
213 |
+
label="Document Chat Chain Type",
|
214 |
+
options=["stuff", "refine", "map_reduce", "map_rerank"],
|
215 |
+
index=0,
|
216 |
+
help=chain_type_help,
|
217 |
+
disabled=not document_chat,
|
218 |
+
)
|
219 |
+
|
220 |
+
if uploaded_file:
|
221 |
+
if openai_api_key:
|
222 |
+
st.session_state.retriever = get_retriever(
|
223 |
+
uploaded_file_bytes=uploaded_file.getvalue(),
|
224 |
+
chunk_size=chunk_size,
|
225 |
+
chunk_overlap=chunk_overlap,
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
st.error("Please enter a valid OpenAI API key.", icon="❌")
|
229 |
|
230 |
# --- Advanced Options ---
|
231 |
with st.expander("Advanced Options", expanded=False):
|
|
|
330 |
|
331 |
st.session_state.doc_chain = RetrievalQA.from_chain_type(
|
332 |
llm=st.session_state.llm,
|
333 |
+
chain_type=document_chat_chain_type,
|
334 |
retriever=st.session_state.retriever,
|
335 |
memory=MEMORY,
|
336 |
)
|