Joshua Sundance Bailey commited on
Commit
85014c3
Β·
unverified Β·
2 Parent(s): ad0c8c1 68a3064

Merge pull request #17 from joshuasundance-swca/dev

Browse files
docker-compose.yml CHANGED
@@ -4,6 +4,8 @@ services:
4
  langchain-streamlit-demo:
5
  image: langchain-streamlit-demo:latest
6
  build: .
 
 
7
  ports:
8
  - "${APP_PORT:-7860}:${APP_PORT:-7860}"
9
  command: [
 
4
  langchain-streamlit-demo:
5
  image: langchain-streamlit-demo:latest
6
  build: .
7
+ env_file:
8
+ - .env
9
  ports:
10
  - "${APP_PORT:-7860}:${APP_PORT:-7860}"
11
  command: [
langchain-streamlit-demo/app.py CHANGED
@@ -1,43 +1,78 @@
1
  import os
2
  from datetime import datetime
 
3
  from typing import Union
4
 
5
  import anthropic
6
  import openai
7
  import streamlit as st
8
  from langchain import LLMChain
 
9
  from langchain.callbacks.base import BaseCallbackHandler
10
- from langchain.callbacks.tracers.langchain import wait_for_all_tracers
11
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
 
12
  from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
13
- from langchain.chat_models.base import BaseChatModel
 
14
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
15
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
16
- from langchain.schema.runnable import RunnableConfig
 
 
17
  from langsmith.client import Client
18
  from streamlit_feedback import streamlit_feedback
19
 
 
20
  st.set_page_config(
21
  page_title="langchain-streamlit-demo",
22
  page_icon="🦜",
23
  )
24
 
25
- st.sidebar.markdown("# Menu")
26
 
 
 
 
 
27
 
28
- _STMEMORY = StreamlitChatMessageHistory(key="langchain_messages")
29
- _MEMORY = ConversationBufferMemory(
30
- chat_memory=_STMEMORY,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  return_messages=True,
32
  memory_key="chat_history",
33
  )
34
 
35
- _DEFAULT_SYSTEM_PROMPT = os.environ.get(
36
- "DEFAULT_SYSTEM_PROMPT",
37
- "You are a helpful chatbot.",
38
- )
39
 
40
- _MODEL_DICT = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  "gpt-3.5-turbo": "OpenAI",
42
  "gpt-4": "OpenAI",
43
  "claude-instant-v1": "Anthropic",
@@ -46,248 +81,329 @@ _MODEL_DICT = {
46
  "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
47
  "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
48
  }
49
- _SUPPORTED_MODELS = list(_MODEL_DICT.keys())
50
- _DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-3.5-turbo")
51
-
52
- _DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", 0.7))
53
- _MIN_TEMPERATURE = float(os.environ.get("MIN_TEMPERATURE", 0.0))
54
- _MAX_TEMPERATURE = float(os.environ.get("MAX_TEMPERATURE", 1.0))
55
-
56
- _DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
57
- _MIN_TOKENS = int(os.environ.get("MIN_MAX_TOKENS", 1))
58
- _MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
59
-
60
-
61
- def get_llm(
62
- model: str,
63
- provider_api_key: str,
64
- temperature: float,
65
- max_tokens: int = _DEFAULT_MAX_TOKENS,
66
- ) -> BaseChatModel:
67
- if _MODEL_DICT[model] == "OpenAI":
68
- return ChatOpenAI(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  model=model,
70
  openai_api_key=provider_api_key,
71
  temperature=temperature,
72
  streaming=True,
73
  max_tokens=max_tokens,
74
  )
75
- elif _MODEL_DICT[model] == "Anthropic":
76
- return ChatAnthropic(
77
  model_name=model,
78
  anthropic_api_key=provider_api_key,
79
  temperature=temperature,
80
  streaming=True,
81
  max_tokens_to_sample=max_tokens,
82
  )
83
- elif _MODEL_DICT[model] == "Anyscale Endpoints":
84
- return ChatAnyscale(
85
  model=model,
86
  anyscale_api_key=provider_api_key,
87
  temperature=temperature,
88
  streaming=True,
89
  max_tokens=max_tokens,
90
  )
91
- else:
92
- raise NotImplementedError(f"Unknown model {model}")
93
-
94
-
95
- def get_llm_chain(
96
- model: str,
97
- provider_api_key: str,
98
- system_prompt: str = _DEFAULT_SYSTEM_PROMPT,
99
- temperature: float = _DEFAULT_TEMPERATURE,
100
- max_tokens: int = _DEFAULT_MAX_TOKENS,
101
- ) -> LLMChain:
102
- """Return a basic LLMChain with memory."""
103
- prompt = ChatPromptTemplate.from_messages(
104
- [
105
- (
106
- "system",
107
- system_prompt + "\nIt's currently {time}.",
108
- ),
109
- MessagesPlaceholder(variable_name="chat_history"),
110
- ("human", "{input}"),
111
- ],
112
- ).partial(time=lambda: str(datetime.now()))
113
- llm = get_llm(model, provider_api_key, temperature, max_tokens)
114
- return LLMChain(prompt=prompt, llm=llm, memory=_MEMORY)
115
-
116
-
117
- class StreamHandler(BaseCallbackHandler):
118
- def __init__(self, container, initial_text=""):
119
- self.container = container
120
- self.text = initial_text
121
-
122
- def on_llm_new_token(self, token: str, **kwargs) -> None:
123
- self.text += token
124
- self.container.markdown(self.text)
125
-
126
-
127
- def feedback_component(client):
128
- scores = {"πŸ˜€": 1, "πŸ™‚": 0.75, "😐": 0.5, "πŸ™": 0.25, "😞": 0}
129
- if feedback := streamlit_feedback(
130
- feedback_type="faces",
131
- optional_text_label="[Optional] Please provide an explanation",
132
- key=f"feedback_{st.session_state.run_id}",
133
- ):
134
- score = scores[feedback["score"]]
135
- feedback = client.create_feedback(
136
- st.session_state.run_id,
137
- feedback["type"],
138
- score=score,
139
- comment=feedback.get("text", None),
140
- )
141
- st.session_state.feedback = {"feedback_id": str(feedback.id), "score": score}
142
- st.toast("Feedback recorded!", icon="πŸ“")
143
 
144
 
145
- # Initialize State
146
- if "trace_link" not in st.session_state:
147
- st.session_state.trace_link = None
148
- if "run_id" not in st.session_state:
149
- st.session_state.run_id = None
150
- if len(_STMEMORY.messages) == 0:
151
- _STMEMORY.add_ai_message("Hello! I'm a helpful AI chatbot. Ask me a question!")
152
 
153
- for msg in _STMEMORY.messages:
154
  st.chat_message(
155
  msg.type,
156
  avatar="🦜" if msg.type in ("ai", "assistant") else None,
157
  ).write(msg.content)
158
 
159
- model = st.sidebar.selectbox(
160
- label="Chat Model",
161
- options=_SUPPORTED_MODELS,
162
- index=_SUPPORTED_MODELS.index(_DEFAULT_MODEL),
163
- )
164
- provider = _MODEL_DICT[model]
165
-
166
-
167
- def api_key_from_env(_provider: str) -> Union[str, None]:
168
- if _provider == "OpenAI":
169
- return os.environ.get("OPENAI_API_KEY")
170
- elif _provider == "Anthropic":
171
- return os.environ.get("ANTHROPIC_API_KEY")
172
- elif _provider == "Anyscale Endpoints":
173
- return os.environ.get("ANYSCALE_API_KEY")
174
- elif _provider == "LANGSMITH":
175
- return os.environ.get("LANGCHAIN_API_KEY")
176
- else:
177
- return None
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
- provider_api_key = api_key_from_env(provider) or st.sidebar.text_input(
181
- f"{provider} API key",
182
- type="password",
183
- )
184
- langsmith_api_key = api_key_from_env("LANGSMITH") or st.sidebar.text_input(
185
- "LangSmith API Key (optional)",
186
- type="password",
187
- )
188
- if langsmith_api_key:
189
- langsmith_project = os.environ.get("LANGCHAIN_PROJECT") or st.sidebar.text_input(
190
- "LangSmith Project Name",
191
- value="langchain-streamlit-demo",
192
- )
193
- os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
194
- os.environ["LANGCHAIN_API_KEY"] = langsmith_api_key
195
- os.environ["LANGCHAIN_TRACING_V2"] = "true"
196
- os.environ["LANGCHAIN_PROJECT"] = langsmith_project
197
 
198
- client = Client(api_key=langsmith_api_key)
199
- else:
200
- langsmith_project = None
201
- client = None
202
-
203
- system_prompt = (
204
- st.sidebar.text_area(
205
- "Custom Instructions",
206
- _DEFAULT_SYSTEM_PROMPT,
207
- help="Custom instructions to provide the language model to determine style, personality, etc.",
208
- )
209
- .strip()
210
- .replace("{", "{{")
211
- .replace("}", "}}")
212
- )
213
 
214
- if st.sidebar.button("Clear message history"):
215
- print("Clearing message history")
216
- _STMEMORY.clear()
217
- st.session_state.trace_link = None
218
- st.session_state.run_id = None
219
-
220
- temperature = st.sidebar.slider(
221
- "Temperature",
222
- min_value=_MIN_TEMPERATURE,
223
- max_value=_MAX_TEMPERATURE,
224
- value=_DEFAULT_TEMPERATURE,
225
- help="Higher values give more random results.",
226
- )
227
 
228
- max_tokens = st.sidebar.slider(
229
- "Max Tokens",
230
- min_value=_MIN_TOKENS,
231
- max_value=_MAX_TOKENS,
232
- value=_DEFAULT_MAX_TOKENS,
233
- help="Higher values give longer results.",
234
- )
235
- chain = None
236
- if provider_api_key:
237
- chain = get_llm_chain(
238
- model,
239
- provider_api_key,
240
- system_prompt,
241
- temperature,
242
- max_tokens,
243
- )
244
 
245
- run_collector = RunCollectorCallbackHandler()
 
 
 
 
 
 
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
- def _reset_feedback():
249
- st.session_state.feedback_update = None
250
- st.session_state.feedback = None
 
 
 
 
251
 
 
 
 
 
 
252
 
253
- if chain:
254
- prompt = st.chat_input(placeholder="Ask me a question!")
255
- if prompt:
256
- st.chat_message("user").write(prompt)
257
- _reset_feedback()
258
 
259
- with st.chat_message("assistant", avatar="🦜"):
260
- message_placeholder = st.empty()
261
- stream_handler = StreamHandler(message_placeholder)
262
- runnable_config = RunnableConfig(
263
- callbacks=[run_collector, stream_handler],
264
- tags=["Streamlit Chat"],
265
  )
266
- try:
267
- full_response = chain.invoke(
268
- {"input": prompt},
269
- config=runnable_config,
270
- )["text"]
271
- except (openai.error.AuthenticationError, anthropic.AuthenticationError):
272
- st.error(f"Please enter a valid {provider} API key.", icon="❌")
273
- st.stop()
274
- message_placeholder.markdown(full_response)
275
-
276
- if client:
277
- run = run_collector.traced_runs[0]
278
- run_collector.traced_runs = []
279
- st.session_state.run_id = run.id
280
- wait_for_all_tracers()
281
- url = client.read_run(run.id).url
282
- st.session_state.trace_link = url
283
- if client and st.session_state.get("run_id"):
284
- feedback_component(client)
 
 
285
 
286
  else:
287
  st.error(f"Please enter a valid {provider} API key.", icon="❌")
288
-
289
- if client and st.session_state.get("trace_link"):
290
- st.sidebar.markdown(
291
- f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: πŸ› οΈ</button></a>',
292
- unsafe_allow_html=True,
293
- )
 
1
  import os
2
  from datetime import datetime
3
+ from tempfile import NamedTemporaryFile
4
  from typing import Union
5
 
6
  import anthropic
7
  import openai
8
  import streamlit as st
9
  from langchain import LLMChain
10
+ from langchain.callbacks import StreamlitCallbackHandler
11
  from langchain.callbacks.base import BaseCallbackHandler
12
+ from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
13
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
14
+ from langchain.chains import RetrievalQA
15
  from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
16
+ from langchain.document_loaders import PyPDFLoader
17
+ from langchain.embeddings import OpenAIEmbeddings
18
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
19
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
20
+ from langchain.schema.retriever import BaseRetriever
21
+ from langchain.text_splitter import CharacterTextSplitter
22
+ from langchain.vectorstores import FAISS
23
  from langsmith.client import Client
24
  from streamlit_feedback import streamlit_feedback
25
 
26
+ # --- Initialization ---
27
  st.set_page_config(
28
  page_title="langchain-streamlit-demo",
29
  page_icon="🦜",
30
  )
31
 
 
32
 
33
+ def st_init_null(*variable_names) -> None:
34
+ for variable_name in variable_names:
35
+ if variable_name not in st.session_state:
36
+ st.session_state[variable_name] = None
37
 
38
+
39
+ st_init_null(
40
+ "chain",
41
+ "client",
42
+ "doc_chain",
43
+ "llm",
44
+ "ls_tracer",
45
+ "retriever",
46
+ "run",
47
+ "run_id",
48
+ "trace_link",
49
+ )
50
+
51
+ # --- Memory ---
52
+ STMEMORY = StreamlitChatMessageHistory(key="langchain_messages")
53
+ MEMORY = ConversationBufferMemory(
54
+ chat_memory=STMEMORY,
55
  return_messages=True,
56
  memory_key="chat_history",
57
  )
58
 
 
 
 
 
59
 
60
+ # --- Callbacks ---
61
+ class StreamHandler(BaseCallbackHandler):
62
+ def __init__(self, container, initial_text=""):
63
+ self.container = container
64
+ self.text = initial_text
65
+
66
+ def on_llm_new_token(self, token: str, **kwargs) -> None:
67
+ self.text += token
68
+ self.container.markdown(self.text)
69
+
70
+
71
+ RUN_COLLECTOR = RunCollectorCallbackHandler()
72
+
73
+
74
+ # --- Model Selection Helpers ---
75
+ MODEL_DICT = {
76
  "gpt-3.5-turbo": "OpenAI",
77
  "gpt-4": "OpenAI",
78
  "claude-instant-v1": "Anthropic",
 
81
  "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
82
  "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
83
  }
84
+ SUPPORTED_MODELS = list(MODEL_DICT.keys())
85
+
86
+
87
+ # --- Constants from Environment Variables ---
88
+ DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-3.5-turbo")
89
+ DEFAULT_SYSTEM_PROMPT = os.environ.get(
90
+ "DEFAULT_SYSTEM_PROMPT",
91
+ "You are a helpful chatbot.",
92
+ )
93
+ MIN_TEMP = float(os.environ.get("MIN_TEMPERATURE", 0.0))
94
+ MAX_TEMP = float(os.environ.get("MAX_TEMPERATURE", 1.0))
95
+ DEFAULT_TEMP = float(os.environ.get("DEFAULT_TEMPERATURE", 0.7))
96
+ MIN_MAX_TOKENS = int(os.environ.get("MIN_MAX_TOKENS", 1))
97
+ MAX_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
98
+ DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
99
+ DEFAULT_LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
100
+ PROVIDER_KEY_DICT = {
101
+ "OpenAI": os.environ.get("OPENAI_API_KEY", ""),
102
+ "Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
103
+ "Anyscale Endpoints": os.environ.get("ANYSCALE_API_KEY", ""),
104
+ "LANGSMITH": os.environ.get("LANGCHAIN_API_KEY", ""),
105
+ }
106
+ OPENAI_API_KEY = PROVIDER_KEY_DICT["OpenAI"]
107
+
108
+
109
+ @st.cache_data
110
+ def get_retriever(uploaded_file_bytes: bytes) -> BaseRetriever:
111
+ with NamedTemporaryFile() as temp_file:
112
+ temp_file.write(uploaded_file_bytes)
113
+ temp_file.seek(0)
114
+
115
+ loader = PyPDFLoader(temp_file.name)
116
+ documents = loader.load()
117
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
118
+ texts = text_splitter.split_documents(documents)
119
+ embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
120
+ db = FAISS.from_documents(texts, embeddings)
121
+ return db.as_retriever()
122
+
123
+
124
+ # --- Sidebar ---
125
+ sidebar = st.sidebar
126
+ with sidebar:
127
+ st.markdown("# Menu")
128
+
129
+ model = st.selectbox(
130
+ label="Chat Model",
131
+ options=SUPPORTED_MODELS,
132
+ index=SUPPORTED_MODELS.index(DEFAULT_MODEL),
133
+ )
134
+
135
+ provider = MODEL_DICT[model]
136
+
137
+ provider_api_key = PROVIDER_KEY_DICT.get(provider) or st.text_input(
138
+ f"{provider} API key",
139
+ type="password",
140
+ )
141
+
142
+ uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
143
+
144
+ openai_api_key = (
145
+ provider_api_key
146
+ if provider == "OpenAI"
147
+ else OPENAI_API_KEY
148
+ or st.sidebar.text_input("OpenAI API Key: ", type="password")
149
+ )
150
+
151
+ if uploaded_file:
152
+ if openai_api_key:
153
+ st.session_state.retriever = get_retriever(
154
+ uploaded_file_bytes=uploaded_file.getvalue(),
155
+ )
156
+ else:
157
+ st.error("Please enter a valid OpenAI API key.", icon="❌")
158
+
159
+ document_chat = st.checkbox(
160
+ "Document Chat",
161
+ value=False,
162
+ help="Uploaded document will provide context for the chat.",
163
+ )
164
+
165
+ if st.button("Clear message history"):
166
+ STMEMORY.clear()
167
+ st.session_state.trace_link = None
168
+ st.session_state.run_id = None
169
+
170
+ # --- Advanced Options ---
171
+ with st.expander("Advanced Options", expanded=False):
172
+ st.markdown("## Feedback Scale")
173
+ use_faces = st.toggle(label="`Thumbs` ⇄ `Faces`", value=False)
174
+ feedback_option = "faces" if use_faces else "thumbs"
175
+
176
+ system_prompt = (
177
+ st.text_area(
178
+ "Custom Instructions",
179
+ DEFAULT_SYSTEM_PROMPT,
180
+ help="Custom instructions to provide the language model to determine style, personality, etc.",
181
+ )
182
+ .strip()
183
+ .replace("{", "{{")
184
+ .replace("}", "}}")
185
+ )
186
+ temperature = st.slider(
187
+ "Temperature",
188
+ min_value=MIN_TEMP,
189
+ max_value=MAX_TEMP,
190
+ value=DEFAULT_TEMP,
191
+ help="Higher values give more random results.",
192
+ )
193
+
194
+ max_tokens = st.slider(
195
+ "Max Tokens",
196
+ min_value=MIN_MAX_TOKENS,
197
+ max_value=MAX_MAX_TOKENS,
198
+ value=DEFAULT_MAX_TOKENS,
199
+ help="Higher values give longer results.",
200
+ )
201
+
202
+ # --- API Keys ---
203
+ LANGSMITH_API_KEY = PROVIDER_KEY_DICT.get("LANGSMITH") or st.text_input(
204
+ "LangSmith API Key (optional)",
205
+ type="password",
206
+ )
207
+ LANGSMITH_PROJECT = DEFAULT_LANGSMITH_PROJECT or st.text_input(
208
+ "LangSmith Project Name",
209
+ value="langchain-streamlit-demo",
210
+ )
211
+ if st.session_state.client is None and LANGSMITH_API_KEY:
212
+ st.session_state.client = Client(
213
+ api_url="https://api.smith.langchain.com",
214
+ api_key=LANGSMITH_API_KEY,
215
+ )
216
+ st.session_state.ls_tracer = LangChainTracer(
217
+ project_name=LANGSMITH_PROJECT,
218
+ client=st.session_state.client,
219
+ )
220
+
221
+
222
+ # --- LLM Instantiation ---
223
+ if provider_api_key:
224
+ if provider == "OpenAI":
225
+ st.session_state.llm = ChatOpenAI(
226
  model=model,
227
  openai_api_key=provider_api_key,
228
  temperature=temperature,
229
  streaming=True,
230
  max_tokens=max_tokens,
231
  )
232
+ elif provider == "Anthropic":
233
+ st.session_state.llm = ChatAnthropic(
234
  model_name=model,
235
  anthropic_api_key=provider_api_key,
236
  temperature=temperature,
237
  streaming=True,
238
  max_tokens_to_sample=max_tokens,
239
  )
240
+ elif provider == "Anyscale Endpoints":
241
+ st.session_state.llm = ChatAnyscale(
242
  model=model,
243
  anyscale_api_key=provider_api_key,
244
  temperature=temperature,
245
  streaming=True,
246
  max_tokens=max_tokens,
247
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
 
250
+ # --- Chat History ---
251
+ if len(STMEMORY.messages) == 0:
252
+ STMEMORY.add_ai_message("Hello! I'm a helpful AI chatbot. Ask me a question!")
 
 
 
 
253
 
254
+ for msg in STMEMORY.messages:
255
  st.chat_message(
256
  msg.type,
257
  avatar="🦜" if msg.type in ("ai", "assistant") else None,
258
  ).write(msg.content)
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
+ # --- Current Chat ---
262
+ if st.session_state.llm:
263
+ # --- Document Chat ---
264
+ if st.session_state.retriever:
265
+ # st.session_state.doc_chain = ConversationalRetrievalChain.from_llm(
266
+ # st.session_state.llm,
267
+ # st.session_state.retriever,
268
+ # memory=MEMORY,
269
+ # )
270
+
271
+ st.session_state.doc_chain = RetrievalQA.from_chain_type(
272
+ llm=st.session_state.llm,
273
+ chain_type="stuff",
274
+ retriever=st.session_state.retriever,
275
+ memory=MEMORY,
276
+ )
277
 
278
+ else:
279
+ # --- Regular Chat ---
280
+ chat_prompt = ChatPromptTemplate.from_messages(
281
+ [
282
+ (
283
+ "system",
284
+ system_prompt + "\nIt's currently {time}.",
285
+ ),
286
+ MessagesPlaceholder(variable_name="chat_history"),
287
+ ("human", "{query}"),
288
+ ],
289
+ ).partial(time=lambda: str(datetime.now()))
290
+ st.session_state.chain = LLMChain(
291
+ prompt=chat_prompt,
292
+ llm=st.session_state.llm,
293
+ memory=MEMORY,
294
+ )
295
 
296
+ # --- Chat Input ---
297
+ prompt = st.chat_input(placeholder="Ask me a question!")
298
+ if prompt:
299
+ st.chat_message("user").write(prompt)
300
+ feedback_update = None
301
+ feedback = None
 
 
 
 
 
 
 
 
 
302
 
303
+ # --- Chat Output ---
304
+ with st.chat_message("assistant", avatar="🦜"):
305
+ callbacks = [RUN_COLLECTOR]
 
 
 
 
 
 
 
 
 
 
306
 
307
+ if st.session_state.ls_tracer:
308
+ callbacks.append(st.session_state.ls_tracer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
+ use_document_chat = all(
311
+ [
312
+ document_chat,
313
+ st.session_state.doc_chain,
314
+ st.session_state.retriever,
315
+ ],
316
+ )
317
 
318
+ try:
319
+ if use_document_chat:
320
+ st_handler = StreamlitCallbackHandler(st.container())
321
+ callbacks.append(st_handler)
322
+ full_response = st.session_state.doc_chain(
323
+ {"query": prompt},
324
+ callbacks=callbacks,
325
+ tags=["Streamlit Chat"],
326
+ return_only_outputs=True,
327
+ )[st.session_state.doc_chain.output_key]
328
+ st_handler._complete_current_thought()
329
+ st.markdown(full_response)
330
+ else:
331
+ message_placeholder = st.empty()
332
+ stream_handler = StreamHandler(message_placeholder)
333
+ callbacks.append(stream_handler)
334
+ full_response = st.session_state.chain(
335
+ {"query": prompt},
336
+ callbacks=callbacks,
337
+ tags=["Streamlit Chat"],
338
+ return_only_outputs=True,
339
+ )[st.session_state.chain.output_key]
340
+ message_placeholder.markdown(full_response)
341
+ except (openai.error.AuthenticationError, anthropic.AuthenticationError):
342
+ st.error(
343
+ f"Please enter a valid {provider} API key.",
344
+ icon="❌",
345
+ )
346
+ full_response = None
347
+ if full_response:
348
+ # --- Tracing ---
349
+ if st.session_state.client:
350
+ st.session_state.run = RUN_COLLECTOR.traced_runs[0]
351
+ st.session_state.run_id = st.session_state.run.id
352
+ RUN_COLLECTOR.traced_runs = []
353
+ wait_for_all_tracers()
354
+ st.session_state.trace_link = st.session_state.client.read_run(
355
+ st.session_state.run_id,
356
+ ).url
357
+ if st.session_state.trace_link:
358
+ with sidebar:
359
+ st.markdown(
360
+ f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: πŸ› οΈ</button></a>',
361
+ unsafe_allow_html=True,
362
+ )
363
 
364
+ # --- Feedback ---
365
+ if st.session_state.client and st.session_state.run_id:
366
+ feedback = streamlit_feedback(
367
+ feedback_type=feedback_option,
368
+ optional_text_label="[Optional] Please provide an explanation",
369
+ key=f"feedback_{st.session_state.run_id}",
370
+ )
371
 
372
+ # Define score mappings for both "thumbs" and "faces" feedback systems
373
+ score_mappings: dict[str, dict[str, Union[int, float]]] = {
374
+ "thumbs": {"πŸ‘": 1, "πŸ‘Ž": 0},
375
+ "faces": {"πŸ˜€": 1, "πŸ™‚": 0.75, "😐": 0.5, "πŸ™": 0.25, "😞": 0},
376
+ }
377
 
378
+ # Get the score mapping based on the selected feedback option
379
+ scores = score_mappings[feedback_option]
 
 
 
380
 
381
+ if feedback:
382
+ # Get the score from the selected feedback option's score mapping
383
+ score = scores.get(
384
+ feedback["score"],
 
 
385
  )
386
+
387
+ if score is not None:
388
+ # Formulate feedback type string incorporating the feedback option
389
+ # and score value
390
+ feedback_type_str = f"{feedback_option} {feedback['score']}"
391
+
392
+ # Record the feedback with the formulated feedback type string
393
+ # and optional comment
394
+ feedback_record = st.session_state.client.create_feedback(
395
+ st.session_state.run_id,
396
+ feedback_type_str,
397
+ score=score,
398
+ comment=feedback.get("text"),
399
+ )
400
+ # feedback = {
401
+ # "feedback_id": str(feedback_record.id),
402
+ # "score": score,
403
+ # }
404
+ st.toast("Feedback recorded!", icon="πŸ“")
405
+ else:
406
+ st.warning("Invalid feedback score.")
407
 
408
  else:
409
  st.error(f"Please enter a valid {provider} API key.", icon="❌")
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,7 +1,9 @@
1
  anthropic==0.3.11
 
2
  langchain==0.0.293
3
  langsmith==0.0.38
4
  openai==0.28.0
 
5
  streamlit==1.26.0
6
  streamlit-feedback==0.1.2
7
  tiktoken==0.5.1
 
1
  anthropic==0.3.11
2
+ faiss-cpu==1.7.4
3
  langchain==0.0.293
4
  langsmith==0.0.38
5
  openai==0.28.0
6
+ pypdf==3.16.1
7
  streamlit==1.26.0
8
  streamlit-feedback==0.1.2
9
  tiktoken==0.5.1