Joshua Sundance Bailey commited on
Commit
c603886
·
1 Parent(s): 0ce4fb3

add azure embeddings; cleanup

Browse files
kubernetes/resources.yaml CHANGED
@@ -39,6 +39,11 @@ spec:
39
  secretKeyRef:
40
  name: langchain-streamlit-demo-secret
41
  key: AZURE_OPENAI_DEPLOYMENT_NAME
 
 
 
 
 
42
  - name: AZURE_OPENAI_API_KEY
43
  valueFrom:
44
  secretKeyRef:
 
39
  secretKeyRef:
40
  name: langchain-streamlit-demo-secret
41
  key: AZURE_OPENAI_DEPLOYMENT_NAME
42
+ - name: AZURE_OPENAI_EMB_DEPLOYMENT_NAME
43
+ valueFrom:
44
+ secretKeyRef:
45
+ name: langchain-streamlit-demo-secret
46
+ key: AZURE_OPENAI_EMB_DEPLOYMENT_NAME
47
  - name: AZURE_OPENAI_API_KEY
48
  valueFrom:
49
  secretKeyRef:
langchain-streamlit-demo/app.py CHANGED
@@ -1,5 +1,5 @@
1
  from datetime import datetime
2
- from typing import Tuple, List, Dict, Any, Union
3
 
4
  import anthropic
5
  import langsmith.utils
@@ -56,6 +56,43 @@ MEMORY = ConversationBufferMemory(
56
  )
57
  RUN_COLLECTOR = RunCollectorCallbackHandler()
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  @st.cache_data
61
  def get_texts_and_retriever_cacheable_wrapper(
@@ -64,6 +101,8 @@ def get_texts_and_retriever_cacheable_wrapper(
64
  chunk_size: int = default_values.DEFAULT_CHUNK_SIZE,
65
  chunk_overlap: int = default_values.DEFAULT_CHUNK_OVERLAP,
66
  k: int = default_values.DEFAULT_RETRIEVER_K,
 
 
67
  ) -> Tuple[List[Document], BaseRetriever]:
68
  return get_texts_and_retriever(
69
  uploaded_file_bytes=uploaded_file_bytes,
@@ -71,6 +110,8 @@ def get_texts_and_retriever_cacheable_wrapper(
71
  chunk_size=chunk_size,
72
  chunk_overlap=chunk_overlap,
73
  k=k,
 
 
74
  )
75
 
76
 
@@ -173,9 +214,17 @@ with sidebar:
173
  help=chain_type_help,
174
  disabled=not document_chat,
175
  )
 
 
 
 
 
 
 
 
176
 
177
  if uploaded_file:
178
- if openai_api_key:
179
  (
180
  st.session_state.texts,
181
  st.session_state.retriever,
@@ -185,6 +234,8 @@ with sidebar:
185
  chunk_size=chunk_size,
186
  chunk_overlap=chunk_overlap,
187
  k=k,
 
 
188
  )
189
  else:
190
  st.error("Please enter a valid OpenAI API key.", icon="❌")
@@ -223,11 +274,6 @@ with sidebar:
223
  )
224
 
225
  # --- LangSmith Options ---
226
- LANGSMITH_API_KEY = default_values.PROVIDER_KEY_DICT.get("LANGSMITH")
227
- LANGSMITH_PROJECT = (
228
- default_values.DEFAULT_LANGSMITH_PROJECT or "langchain-streamlit-demo"
229
- )
230
-
231
  if default_values.SHOW_LANGSMITH_OPTIONS:
232
  with st.expander("LangSmith Options", expanded=False):
233
  LANGSMITH_API_KEY = st.text_input(
@@ -252,14 +298,6 @@ with sidebar:
252
  )
253
 
254
  # --- Azure Options ---
255
- AZURE_OPENAI_BASE_URL = default_values.AZURE_DICT["AZURE_OPENAI_BASE_URL"]
256
- AZURE_OPENAI_API_VERSION = default_values.AZURE_DICT["AZURE_OPENAI_API_VERSION"]
257
- AZURE_OPENAI_DEPLOYMENT_NAME = default_values.AZURE_DICT[
258
- "AZURE_OPENAI_DEPLOYMENT_NAME"
259
- ]
260
- AZURE_OPENAI_API_KEY = default_values.AZURE_DICT["AZURE_OPENAI_API_KEY"]
261
- AZURE_OPENAI_MODEL_VERSION = default_values.AZURE_DICT["AZURE_OPENAI_MODEL_VERSION"]
262
-
263
  if default_values.SHOW_AZURE_OPTIONS:
264
  with st.expander("Azure Options", expanded=False):
265
  AZURE_OPENAI_BASE_URL = st.text_input(
@@ -288,16 +326,6 @@ with sidebar:
288
  value=AZURE_OPENAI_MODEL_VERSION,
289
  )
290
 
291
- AZURE_AVAILABLE = all(
292
- [
293
- AZURE_OPENAI_BASE_URL,
294
- AZURE_OPENAI_API_VERSION,
295
- AZURE_OPENAI_DEPLOYMENT_NAME,
296
- AZURE_OPENAI_API_KEY,
297
- AZURE_OPENAI_MODEL_VERSION,
298
- ],
299
- )
300
-
301
 
302
  # --- LLM Instantiation ---
303
  st.session_state.llm = get_llm(
 
1
  from datetime import datetime
2
+ from typing import Tuple, List, Dict, Any, Union, Optional
3
 
4
  import anthropic
5
  import langsmith.utils
 
56
  )
57
  RUN_COLLECTOR = RunCollectorCallbackHandler()
58
 
59
+ LANGSMITH_API_KEY = default_values.PROVIDER_KEY_DICT.get("LANGSMITH")
60
+ LANGSMITH_PROJECT = (
61
+ default_values.DEFAULT_LANGSMITH_PROJECT or "langchain-streamlit-demo"
62
+ )
63
+ AZURE_OPENAI_BASE_URL = default_values.AZURE_DICT["AZURE_OPENAI_BASE_URL"]
64
+ AZURE_OPENAI_API_VERSION = default_values.AZURE_DICT["AZURE_OPENAI_API_VERSION"]
65
+ AZURE_OPENAI_DEPLOYMENT_NAME = default_values.AZURE_DICT["AZURE_OPENAI_DEPLOYMENT_NAME"]
66
+ AZURE_OPENAI_EMB_DEPLOYMENT_NAME = default_values.AZURE_DICT[
67
+ "AZURE_OPENAI_EMB_DEPLOYMENT_NAME"
68
+ ]
69
+ AZURE_OPENAI_API_KEY = default_values.AZURE_DICT["AZURE_OPENAI_API_KEY"]
70
+ AZURE_OPENAI_MODEL_VERSION = default_values.AZURE_DICT["AZURE_OPENAI_MODEL_VERSION"]
71
+
72
+ AZURE_AVAILABLE = all(
73
+ [
74
+ AZURE_OPENAI_BASE_URL,
75
+ AZURE_OPENAI_API_VERSION,
76
+ AZURE_OPENAI_DEPLOYMENT_NAME,
77
+ AZURE_OPENAI_API_KEY,
78
+ AZURE_OPENAI_MODEL_VERSION,
79
+ ],
80
+ )
81
+
82
+ AZURE_EMB_AVAILABLE = AZURE_AVAILABLE and AZURE_OPENAI_EMB_DEPLOYMENT_NAME
83
+
84
+ AZURE_KWARGS = (
85
+ None
86
+ if not AZURE_EMB_AVAILABLE
87
+ else {
88
+ "openai_api_base": AZURE_OPENAI_BASE_URL,
89
+ "openai_api_version": AZURE_OPENAI_API_VERSION,
90
+ "deployment": AZURE_OPENAI_EMB_DEPLOYMENT_NAME,
91
+ "openai_api_key": AZURE_OPENAI_API_KEY,
92
+ "openai_api_type": "azure",
93
+ }
94
+ )
95
+
96
 
97
  @st.cache_data
98
  def get_texts_and_retriever_cacheable_wrapper(
 
101
  chunk_size: int = default_values.DEFAULT_CHUNK_SIZE,
102
  chunk_overlap: int = default_values.DEFAULT_CHUNK_OVERLAP,
103
  k: int = default_values.DEFAULT_RETRIEVER_K,
104
+ azure_kwargs: Optional[Dict[str, str]] = None,
105
+ use_azure: bool = False,
106
  ) -> Tuple[List[Document], BaseRetriever]:
107
  return get_texts_and_retriever(
108
  uploaded_file_bytes=uploaded_file_bytes,
 
110
  chunk_size=chunk_size,
111
  chunk_overlap=chunk_overlap,
112
  k=k,
113
+ azure_kwargs=azure_kwargs,
114
+ use_azure=use_azure,
115
  )
116
 
117
 
 
214
  help=chain_type_help,
215
  disabled=not document_chat,
216
  )
217
+ use_azure = False
218
+
219
+ if AZURE_EMB_AVAILABLE:
220
+ use_azure = st.toggle(
221
+ label="Use Azure OpenAI",
222
+ value=AZURE_EMB_AVAILABLE,
223
+ help="Use Azure for embeddings instead of using OpenAI directly.",
224
+ )
225
 
226
  if uploaded_file:
227
+ if AZURE_EMB_AVAILABLE or openai_api_key:
228
  (
229
  st.session_state.texts,
230
  st.session_state.retriever,
 
234
  chunk_size=chunk_size,
235
  chunk_overlap=chunk_overlap,
236
  k=k,
237
+ azure_kwargs=AZURE_KWARGS,
238
+ use_azure=use_azure,
239
  )
240
  else:
241
  st.error("Please enter a valid OpenAI API key.", icon="❌")
 
274
  )
275
 
276
  # --- LangSmith Options ---
 
 
 
 
 
277
  if default_values.SHOW_LANGSMITH_OPTIONS:
278
  with st.expander("LangSmith Options", expanded=False):
279
  LANGSMITH_API_KEY = st.text_input(
 
298
  )
299
 
300
  # --- Azure Options ---
 
 
 
 
 
 
 
 
301
  if default_values.SHOW_AZURE_OPTIONS:
302
  with st.expander("Azure Options", expanded=False):
303
  AZURE_OPENAI_BASE_URL = st.text_input(
 
326
  value=AZURE_OPENAI_MODEL_VERSION,
327
  )
328
 
 
 
 
 
 
 
 
 
 
 
329
 
330
  # --- LLM Instantiation ---
331
  st.session_state.llm = get_llm(
langchain-streamlit-demo/defaults.py CHANGED
@@ -37,6 +37,7 @@ AZURE_VARS = [
37
  "AZURE_OPENAI_BASE_URL",
38
  "AZURE_OPENAI_API_VERSION",
39
  "AZURE_OPENAI_DEPLOYMENT_NAME",
 
40
  "AZURE_OPENAI_API_KEY",
41
  "AZURE_OPENAI_MODEL_VERSION",
42
  ]
 
37
  "AZURE_OPENAI_BASE_URL",
38
  "AZURE_OPENAI_API_VERSION",
39
  "AZURE_OPENAI_DEPLOYMENT_NAME",
40
+ "AZURE_OPENAI_EMB_DEPLOYMENT_NAME",
41
  "AZURE_OPENAI_API_KEY",
42
  "AZURE_OPENAI_MODEL_VERSION",
43
  ]
langchain-streamlit-demo/llm_resources.py CHANGED
@@ -1,5 +1,5 @@
1
  from tempfile import NamedTemporaryFile
2
- from typing import Tuple, List
3
 
4
  from langchain.callbacks.base import BaseCallbackHandler
5
  from langchain.chains import RetrievalQA, LLMChain
@@ -117,6 +117,8 @@ def get_texts_and_retriever(
117
  chunk_size: int = DEFAULT_CHUNK_SIZE,
118
  chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
119
  k: int = DEFAULT_RETRIEVER_K,
 
 
120
  ) -> Tuple[List[Document], BaseRetriever]:
121
  with NamedTemporaryFile() as temp_file:
122
  temp_file.write(uploaded_file_bytes)
@@ -129,7 +131,10 @@ def get_texts_and_retriever(
129
  chunk_overlap=chunk_overlap,
130
  )
131
  texts = text_splitter.split_documents(documents)
132
- embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
 
 
 
133
 
134
  bm25_retriever = BM25Retriever.from_documents(texts)
135
  bm25_retriever.k = k
 
1
  from tempfile import NamedTemporaryFile
2
+ from typing import Tuple, List, Optional, Dict
3
 
4
  from langchain.callbacks.base import BaseCallbackHandler
5
  from langchain.chains import RetrievalQA, LLMChain
 
117
  chunk_size: int = DEFAULT_CHUNK_SIZE,
118
  chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
119
  k: int = DEFAULT_RETRIEVER_K,
120
+ azure_kwargs: Optional[Dict[str, str]] = None,
121
+ use_azure: bool = False,
122
  ) -> Tuple[List[Document], BaseRetriever]:
123
  with NamedTemporaryFile() as temp_file:
124
  temp_file.write(uploaded_file_bytes)
 
131
  chunk_overlap=chunk_overlap,
132
  )
133
  texts = text_splitter.split_documents(documents)
134
+ embeddings_kwargs = {"openai_api_key": openai_api_key}
135
+ if use_azure and azure_kwargs:
136
+ embeddings_kwargs.update(azure_kwargs)
137
+ embeddings = OpenAIEmbeddings(**embeddings_kwargs)
138
 
139
  bm25_retriever = BM25Retriever.from_documents(texts)
140
  bm25_retriever.k = k