Merge pull request #44 from joshuasundance-swca/azure
- langchain-streamlit-demo/app.py +82 -13
- requirements.txt +3 -3
langchain-streamlit-demo/app.py
CHANGED
@@ -12,7 +12,12 @@ from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
 from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
 from langchain.chains import RetrievalQA
 from langchain.chains.llm import LLMChain
-from langchain.chat_models import ChatAnthropic, ChatAnyscale, ChatOpenAI
+from langchain.chat_models import (
+    AzureChatOpenAI,
+    ChatAnthropic,
+    ChatAnyscale,
+    ChatOpenAI,
+)
 from langchain.document_loaders import PyPDFLoader
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
@@ -90,6 +95,7 @@ MODEL_DICT = {
     "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
     "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
     "codellama/CodeLlama-34b-Instruct-hf": "Anyscale Endpoints",
+    "Azure OpenAI": "Azure OpenAI",
 }
 SUPPORTED_MODELS = list(MODEL_DICT.keys())
 
@@ -107,6 +113,17 @@ MIN_MAX_TOKENS = int(os.environ.get("MIN_MAX_TOKENS", 1))
 MAX_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
 DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
 DEFAULT_LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
+
+AZURE_VARS = [
+    "AZURE_OPENAI_BASE_URL",
+    "AZURE_OPENAI_API_VERSION",
+    "AZURE_OPENAI_DEPLOYMENT_NAME",
+    "AZURE_OPENAI_API_KEY",
+    "AZURE_OPENAI_MODEL_VERSION",
+]
+
+AZURE_DICT = {v: os.environ.get(v, "") for v in AZURE_VARS}
+
 PROVIDER_KEY_DICT = {
     "OpenAI": os.environ.get("OPENAI_API_KEY", ""),
     "Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
@@ -173,11 +190,16 @@ with sidebar:
 
     st.session_state.provider = MODEL_DICT[model]
 
-    provider_api_key = PROVIDER_KEY_DICT.get(
-        st.session_state.provider,
-    ) or st.text_input(
-        f"{st.session_state.provider} API key",
-        type="password",
+    provider_api_key = (
+        PROVIDER_KEY_DICT.get(
+            st.session_state.provider,
+        )
+        or st.text_input(
+            f"{st.session_state.provider} API key",
+            type="password",
+        )
+        if st.session_state.provider != "Azure OpenAI"
+        else ""
     )
 
     if st.button("Clear message history"):
@@ -266,8 +288,8 @@ with sidebar:
     else:
        st.error("Please enter a valid OpenAI API key.", icon="❌")
 
-    # --- Advanced
-    with st.expander("Advanced
+    # --- Advanced Settings ---
+    with st.expander("Advanced Settings", expanded=False):
        st.markdown("## Feedback Scale")
        use_faces = st.toggle(label="`Thumbs` ⇄ `Faces`", value=False)
        feedback_option = "faces" if use_faces else "thumbs"
@@ -298,14 +320,16 @@ with sidebar:
            help="Higher values give longer results.",
        )
 
-    LANGSMITH_API_KEY = st.text_input(
+    # --- LangSmith Options ---
+    with st.expander("LangSmith Options", expanded=False):
+        LANGSMITH_API_KEY = st.text_input(
            "LangSmith API Key (optional)",
            type="password",
+            value=PROVIDER_KEY_DICT.get("LANGSMITH"),
        )
-    LANGSMITH_PROJECT = st.text_input(
+        LANGSMITH_PROJECT = st.text_input(
            "LangSmith Project Name",
-        value="langchain-streamlit-demo",
+            value=DEFAULT_LANGSMITH_PROJECT or "langchain-streamlit-demo",
        )
        if st.session_state.client is None and LANGSMITH_API_KEY:
            st.session_state.client = Client(
@@ -317,6 +341,40 @@ with sidebar:
            client=st.session_state.client,
        )
 
+    # --- Azure Options ---
+    with st.expander("Azure Options", expanded=False):
+        AZURE_OPENAI_BASE_URL = st.text_input(
+            "AZURE_OPENAI_BASE_URL",
+            value=AZURE_DICT["AZURE_OPENAI_BASE_URL"],
+        )
+        AZURE_OPENAI_API_VERSION = st.text_input(
+            "AZURE_OPENAI_API_VERSION",
+            value=AZURE_DICT["AZURE_OPENAI_API_VERSION"],
+        )
+        AZURE_OPENAI_DEPLOYMENT_NAME = st.text_input(
+            "AZURE_OPENAI_DEPLOYMENT_NAME",
+            value=AZURE_DICT["AZURE_OPENAI_DEPLOYMENT_NAME"],
+        )
+        AZURE_OPENAI_API_KEY = st.text_input(
+            "AZURE_OPENAI_API_KEY",
+            value=AZURE_DICT["AZURE_OPENAI_API_KEY"],
+            type="password",
+        )
+        AZURE_OPENAI_MODEL_VERSION = st.text_input(
+            "AZURE_OPENAI_MODEL_VERSION",
+            value=AZURE_DICT["AZURE_OPENAI_MODEL_VERSION"],
+        )
+
+        AZURE_AVAILABLE = all(
+            [
+                AZURE_OPENAI_BASE_URL,
+                AZURE_OPENAI_API_VERSION,
+                AZURE_OPENAI_DEPLOYMENT_NAME,
+                AZURE_OPENAI_API_KEY,
+                AZURE_OPENAI_MODEL_VERSION,
+            ],
+        )
+
 
 # --- LLM Instantiation ---
 if provider_api_key:
@@ -344,7 +402,18 @@ if provider_api_key:
         streaming=True,
         max_tokens=max_tokens,
     )
+elif AZURE_AVAILABLE and st.session_state.provider == "Azure OpenAI":
+    st.session_state.llm = AzureChatOpenAI(
+        openai_api_base=AZURE_OPENAI_BASE_URL,
+        openai_api_version=AZURE_OPENAI_API_VERSION,
+        deployment_name=AZURE_OPENAI_DEPLOYMENT_NAME,
+        openai_api_key=AZURE_OPENAI_API_KEY,
+        openai_api_type="azure",
+        model_version=AZURE_OPENAI_MODEL_VERSION,
+        temperature=temperature,
+        streaming=True,
+        max_tokens=max_tokens,
+    )
 
 # --- Chat History ---
 if len(STMEMORY.messages) == 0:
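For reviewers who want to smoke-test the new Azure path outside of Streamlit, here is a minimal sketch mirroring the wiring in the diff above. It assumes the same five AZURE_OPENAI_* environment variables are exported and that langchain==0.0.308 is installed; the temperature and max_tokens literals are placeholders for the app's sidebar-driven values.

import os

from langchain.chat_models import AzureChatOpenAI

# The same five variables the PR reads; the Azure branch only activates
# when all of them are non-empty (mirrors AZURE_AVAILABLE above).
AZURE_VARS = [
    "AZURE_OPENAI_BASE_URL",
    "AZURE_OPENAI_API_VERSION",
    "AZURE_OPENAI_DEPLOYMENT_NAME",
    "AZURE_OPENAI_API_KEY",
    "AZURE_OPENAI_MODEL_VERSION",
]
azure_config = {v: os.environ.get(v, "") for v in AZURE_VARS}

if all(azure_config.values()):
    llm = AzureChatOpenAI(
        openai_api_base=azure_config["AZURE_OPENAI_BASE_URL"],
        openai_api_version=azure_config["AZURE_OPENAI_API_VERSION"],
        deployment_name=azure_config["AZURE_OPENAI_DEPLOYMENT_NAME"],
        openai_api_key=azure_config["AZURE_OPENAI_API_KEY"],
        openai_api_type="azure",
        model_version=azure_config["AZURE_OPENAI_MODEL_VERSION"],
        temperature=0.7,  # placeholder; the app reads this from the sidebar
        streaming=True,
        max_tokens=1000,  # placeholder; the app reads this from the sidebar
    )
    print(llm.predict("Say hello from Azure OpenAI."))
else:
    missing = [v for v in AZURE_VARS if not azure_config[v]]
    print(f"Azure OpenAI is not configured; missing: {missing}")
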
requirements.txt
CHANGED
@@ -1,12 +1,12 @@
 anthropic==0.3.11
 faiss-cpu==1.7.4
-langchain==0.0.
-langsmith==0.0.
+langchain==0.0.308
+langsmith==0.0.43
 numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
 openai==0.28.1
 pypdf==3.16.2
 rank_bm25==0.2.2
-streamlit==1.27.
+streamlit==1.27.2
 streamlit-feedback==0.1.2
 tiktoken==0.5.1
 tornado>=6.3.3 # not directly required, pinned by Snyk to avoid a vulnerability
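After bumping the pins, a quick sanity check that the active environment actually picked up the new versions; this is a convenience sketch, assuming the three upgraded packages are importable:

from importlib.metadata import version

# The three pins this PR changes; all other pins are untouched.
EXPECTED = {
    "langchain": "0.0.308",
    "langsmith": "0.0.43",
    "streamlit": "1.27.2",
}

for pkg, pinned in EXPECTED.items():
    installed = version(pkg)
    status = "ok" if installed == pinned else f"MISMATCH (pinned {pinned})"
    print(f"{pkg}=={installed}  {status}")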