Joshua Sundance Bailey commited on
Commit
ef382e2
·
unverified ·
2 Parent(s): daaed02 979e3bd

Merge pull request #44 from joshuasundance-swca/azure

Browse files
Files changed (2) hide show
  1. langchain-streamlit-demo/app.py +82 -13
  2. requirements.txt +3 -3
langchain-streamlit-demo/app.py CHANGED
@@ -12,7 +12,12 @@ from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_
12
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
13
  from langchain.chains import RetrievalQA
14
  from langchain.chains.llm import LLMChain
15
- from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
 
 
 
 
 
16
  from langchain.document_loaders import PyPDFLoader
17
  from langchain.embeddings import OpenAIEmbeddings
18
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
@@ -90,6 +95,7 @@ MODEL_DICT = {
90
  "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
91
  "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
92
  "codellama/CodeLlama-34b-Instruct-hf": "Anyscale Endpoints",
 
93
  }
94
  SUPPORTED_MODELS = list(MODEL_DICT.keys())
95
 
@@ -107,6 +113,17 @@ MIN_MAX_TOKENS = int(os.environ.get("MIN_MAX_TOKENS", 1))
107
  MAX_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
108
  DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
109
  DEFAULT_LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
 
 
 
 
 
 
 
 
 
 
 
110
  PROVIDER_KEY_DICT = {
111
  "OpenAI": os.environ.get("OPENAI_API_KEY", ""),
112
  "Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
@@ -173,11 +190,16 @@ with sidebar:
173
 
174
  st.session_state.provider = MODEL_DICT[model]
175
 
176
- provider_api_key = PROVIDER_KEY_DICT.get(
177
- st.session_state.provider,
178
- ) or st.text_input(
179
- f"{st.session_state.provider} API key",
180
- type="password",
 
 
 
 
 
181
  )
182
 
183
  if st.button("Clear message history"):
@@ -266,8 +288,8 @@ with sidebar:
266
  else:
267
  st.error("Please enter a valid OpenAI API key.", icon="❌")
268
 
269
- # --- Advanced Options ---
270
- with st.expander("Advanced Options", expanded=False):
271
  st.markdown("## Feedback Scale")
272
  use_faces = st.toggle(label="`Thumbs` ⇄ `Faces`", value=False)
273
  feedback_option = "faces" if use_faces else "thumbs"
@@ -298,14 +320,16 @@ with sidebar:
298
  help="Higher values give longer results.",
299
  )
300
 
301
- # --- API Keys ---
302
- LANGSMITH_API_KEY = PROVIDER_KEY_DICT.get("LANGSMITH") or st.text_input(
 
303
  "LangSmith API Key (optional)",
304
  type="password",
 
305
  )
306
- LANGSMITH_PROJECT = DEFAULT_LANGSMITH_PROJECT or st.text_input(
307
  "LangSmith Project Name",
308
- value="langchain-streamlit-demo",
309
  )
310
  if st.session_state.client is None and LANGSMITH_API_KEY:
311
  st.session_state.client = Client(
@@ -317,6 +341,40 @@ with sidebar:
317
  client=st.session_state.client,
318
  )
319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
  # --- LLM Instantiation ---
322
  if provider_api_key:
@@ -344,7 +402,18 @@ if provider_api_key:
344
  streaming=True,
345
  max_tokens=max_tokens,
346
  )
347
-
 
 
 
 
 
 
 
 
 
 
 
348
 
349
  # --- Chat History ---
350
  if len(STMEMORY.messages) == 0:
 
12
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
13
  from langchain.chains import RetrievalQA
14
  from langchain.chains.llm import LLMChain
15
+ from langchain.chat_models import (
16
+ AzureChatOpenAI,
17
+ ChatAnthropic,
18
+ ChatAnyscale,
19
+ ChatOpenAI,
20
+ )
21
  from langchain.document_loaders import PyPDFLoader
22
  from langchain.embeddings import OpenAIEmbeddings
23
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
 
95
  "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
96
  "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
97
  "codellama/CodeLlama-34b-Instruct-hf": "Anyscale Endpoints",
98
+ "Azure OpenAI": "Azure OpenAI",
99
  }
100
  SUPPORTED_MODELS = list(MODEL_DICT.keys())
101
 
 
113
  MAX_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
114
  DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
115
  DEFAULT_LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
116
+
117
+ AZURE_VARS = [
118
+ "AZURE_OPENAI_BASE_URL",
119
+ "AZURE_OPENAI_API_VERSION",
120
+ "AZURE_OPENAI_DEPLOYMENT_NAME",
121
+ "AZURE_OPENAI_API_KEY",
122
+ "AZURE_OPENAI_MODEL_VERSION",
123
+ ]
124
+
125
+ AZURE_DICT = {v: os.environ.get(v, "") for v in AZURE_VARS}
126
+
127
  PROVIDER_KEY_DICT = {
128
  "OpenAI": os.environ.get("OPENAI_API_KEY", ""),
129
  "Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
 
190
 
191
  st.session_state.provider = MODEL_DICT[model]
192
 
193
+ provider_api_key = (
194
+ PROVIDER_KEY_DICT.get(
195
+ st.session_state.provider,
196
+ )
197
+ or st.text_input(
198
+ f"{st.session_state.provider} API key",
199
+ type="password",
200
+ )
201
+ if st.session_state.provider != "Azure OpenAI"
202
+ else ""
203
  )
204
 
205
  if st.button("Clear message history"):
 
288
  else:
289
  st.error("Please enter a valid OpenAI API key.", icon="❌")
290
 
291
+ # --- Advanced Settings ---
292
+ with st.expander("Advanced Settings", expanded=False):
293
  st.markdown("## Feedback Scale")
294
  use_faces = st.toggle(label="`Thumbs` ⇄ `Faces`", value=False)
295
  feedback_option = "faces" if use_faces else "thumbs"
 
320
  help="Higher values give longer results.",
321
  )
322
 
323
+ # --- LangSmith Options ---
324
+ with st.expander("LangSmith Options", expanded=False):
325
+ LANGSMITH_API_KEY = st.text_input(
326
  "LangSmith API Key (optional)",
327
  type="password",
328
+ value=PROVIDER_KEY_DICT.get("LANGSMITH"),
329
  )
330
+ LANGSMITH_PROJECT = st.text_input(
331
  "LangSmith Project Name",
332
+ value=DEFAULT_LANGSMITH_PROJECT or "langchain-streamlit-demo",
333
  )
334
  if st.session_state.client is None and LANGSMITH_API_KEY:
335
  st.session_state.client = Client(
 
341
  client=st.session_state.client,
342
  )
343
 
344
+ # --- Azure Options ---
345
+ with st.expander("Azure Options", expanded=False):
346
+ AZURE_OPENAI_BASE_URL = st.text_input(
347
+ "AZURE_OPENAI_BASE_URL",
348
+ value=AZURE_DICT["AZURE_OPENAI_BASE_URL"],
349
+ )
350
+ AZURE_OPENAI_API_VERSION = st.text_input(
351
+ "AZURE_OPENAI_API_VERSION",
352
+ value=AZURE_DICT["AZURE_OPENAI_API_VERSION"],
353
+ )
354
+ AZURE_OPENAI_DEPLOYMENT_NAME = st.text_input(
355
+ "AZURE_OPENAI_DEPLOYMENT_NAME",
356
+ value=AZURE_DICT["AZURE_OPENAI_DEPLOYMENT_NAME"],
357
+ )
358
+ AZURE_OPENAI_API_KEY = st.text_input(
359
+ "AZURE_OPENAI_API_KEY",
360
+ value=AZURE_DICT["AZURE_OPENAI_API_KEY"],
361
+ type="password",
362
+ )
363
+ AZURE_OPENAI_MODEL_VERSION = st.text_input(
364
+ "AZURE_OPENAI_MODEL_VERSION",
365
+ value=AZURE_DICT["AZURE_OPENAI_MODEL_VERSION"],
366
+ )
367
+
368
+ AZURE_AVAILABLE = all(
369
+ [
370
+ AZURE_OPENAI_BASE_URL,
371
+ AZURE_OPENAI_API_VERSION,
372
+ AZURE_OPENAI_DEPLOYMENT_NAME,
373
+ AZURE_OPENAI_API_KEY,
374
+ AZURE_OPENAI_MODEL_VERSION,
375
+ ],
376
+ )
377
+
378
 
379
  # --- LLM Instantiation ---
380
  if provider_api_key:
 
402
  streaming=True,
403
  max_tokens=max_tokens,
404
  )
405
+ elif AZURE_AVAILABLE and st.session_state.provider == "Azure OpenAI":
406
+ st.session_state.llm = AzureChatOpenAI(
407
+ openai_api_base=AZURE_OPENAI_BASE_URL,
408
+ openai_api_version=AZURE_OPENAI_API_VERSION,
409
+ deployment_name=AZURE_OPENAI_DEPLOYMENT_NAME,
410
+ openai_api_key=AZURE_OPENAI_API_KEY,
411
+ openai_api_type="azure",
412
+ model_version=AZURE_OPENAI_MODEL_VERSION,
413
+ temperature=temperature,
414
+ streaming=True,
415
+ max_tokens=max_tokens,
416
+ )
417
 
418
  # --- Chat History ---
419
  if len(STMEMORY.messages) == 0:
requirements.txt CHANGED
@@ -1,12 +1,12 @@
1
  anthropic==0.3.11
2
  faiss-cpu==1.7.4
3
- langchain==0.0.305
4
- langsmith==0.0.41
5
  numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
6
  openai==0.28.1
7
  pypdf==3.16.2
8
  rank_bm25==0.2.2
9
- streamlit==1.27.1
10
  streamlit-feedback==0.1.2
11
  tiktoken==0.5.1
12
  tornado>=6.3.3 # not directly required, pinned by Snyk to avoid a vulnerability
 
1
  anthropic==0.3.11
2
  faiss-cpu==1.7.4
3
+ langchain==0.0.308
4
+ langsmith==0.0.43
5
  numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
6
  openai==0.28.1
7
  pypdf==3.16.2
8
  rank_bm25==0.2.2
9
+ streamlit==1.27.2
10
  streamlit-feedback==0.1.2
11
  tiktoken==0.5.1
12
  tornado>=6.3.3 # not directly required, pinned by Snyk to avoid a vulnerability