# Workaround for platforms whose bundled sqlite3 is too old for Chroma:
# uncomment the three lines below to swap pysqlite3 in for the stdlib module.
# __import__('pysqlite3')
# import sys
# sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
import streamlit as st
from langchain.llms import Ollama
import os
import chromadb
from constants import CHROMA_SETTINGS
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from chromadb.config import Settings
from langchain.vectorstores import Chroma
from langchain.llms import Ollama
from langchain.callbacks.base import BaseCallbackHandler
# Custom streamlit handler to display LLM outputs in stream mode
class StreamHandler(BaseCallbackHandler):
    """LangChain callback that streams LLM tokens into a Streamlit container.

    Each new token is appended to the accumulated text and the container is
    re-rendered, producing a live "typing" effect in the chat UI.
    """

    def __init__(self, container, initial_text: str = ""):
        # container: a Streamlit placeholder (e.g. st.empty()) updated in place.
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Append the token (the original concatenated a redundant empty
        # string here — removed) and re-render everything received so far.
        self.text += token
        self.container.markdown(self.text)
# streamlit UI configuration
def setup_page():
    """Configure the Streamlit page: wide layout, page title, and a centered
    link to the book's website, followed by a divider.

    NOTE(review): the original markup in this function was garbled — a bare
    string literal split across source lines (a SyntaxError) and a ``%`` format
    applied to a string with no placeholder (a TypeError). The HTML below is a
    reconstruction of the apparent intent; confirm the exact styling.
    """
    st.set_page_config(layout="wide")
    st.markdown(
        "<h1 style='text-align: center;'>Your Personal MBA</h1>",
        unsafe_allow_html=True,
    )
    url = 'https://personalmba.com/'
    col1, col2, col3 = st.columns(3)
    with col2:
        # Center the link by placing it in the middle of three columns.
        st.markdown(
            "<div style='text-align: center;'><a href='%s'>%s</a></div>" % (url, url),
            unsafe_allow_html=True,
        )
    st.divider()
# get necessary environment variables for later use
def get_environment_variables():
    """Read runtime configuration from the environment.

    Returns:
        tuple: ``(model, embeddings_model_name, persist_directory,
        target_source_chunks)``, falling back to sensible defaults for any
        variable that is unset.
    """
    lookup = os.environ.get
    model = lookup("MODEL", "mistral")
    embeddings_model_name = lookup("EMBEDDINGS_MODEL_NAME", "all-MiniLM-L6-v2")
    persist_directory = lookup("PERSIST_DIRECTORY", "db")
    # Number of retrieved chunks; the env var arrives as a string.
    target_source_chunks = int(lookup('TARGET_SOURCE_CHUNKS', 4))
    return model, embeddings_model_name, persist_directory, target_source_chunks
# create knowledge base retriever
def create_knowledge_base(embeddings_model_name, persist_directory, target_source_chunks):
    """Open the persisted Chroma vector store and wrap it as a retriever.

    Returns a retriever that fetches the ``target_source_chunks`` most
    similar chunks for each query.
    """
    embedding_function = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    vector_store = Chroma(
        persist_directory=persist_directory,
        embedding_function=embedding_function,
    )
    return vector_store.as_retriever(search_kwargs={"k": target_source_chunks})
# handle query when user hit 'enter' on a question
def handle_query(query, model, retriever):
    """Answer *query* with a RetrievalQA chain, streaming tokens into the chat.

    Tokens are streamed into a placeholder via StreamHandler while the model
    generates; the placeholder is then overwritten with the final answer,
    which is also returned.
    """
    with st.chat_message('assistant'):
        with st.spinner("Generating answer..."):
            placeholder = st.empty()
            # Stream each generated token into the placeholder as it arrives.
            handler = StreamHandler(placeholder)
            llm = Ollama(model=model, callbacks=[handler])
            chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=False,
            )
            response = chain(query)
            answer = response['result']
            # Replace the streamed text with the final, complete answer.
            placeholder.markdown(answer)
    return answer
# dictionary to store the previous messages, create a 'memory' for the LLM
def initialize_session():
    """Ensure the chat-history list exists in Streamlit session state."""
    if 'messages' not in st.session_state:
        # First page load: start with an empty conversation history.
        st.session_state['messages'] = []
# display the messages
def display_messages():
    """Re-render every stored chat message (user and assistant) in order."""
    for entry in st.session_state.messages:
        role, content = entry['role'], entry['content']
        with st.chat_message(role):
            st.markdown(content)
# example questions when user first load up the app. Will disappear after user send the first query
def show_examples():
    """Render starter questions inside an emptiable container.

    Returns the container so the caller can clear it (``.empty()``) once the
    user submits their first query.
    """
    example_lines = (
        'Example questions:',
        ' - How do I know that I am making the right decisions?',
        ' - What are the key ideas in Chapter 6 "The Human Mind"?',
        ' - What are common traits shared by the most sucessful individuals in the world?',
        ' - I want to be a millionaire, build me a 5 year roadmap based on the top 0.01 percent of the human population.',
        'So, how may I help you today?',
    )
    examples = st.empty()
    with examples.container():
        with st.chat_message('assistant'):
            for line in example_lines:
                st.markdown(line)
    return examples
def main():
    """Wire the app together: page setup, chat history, knowledge base, and
    the query/answer loop driven by the chat input box."""
    setup_page()
    initialize_session()
    display_messages()
    examples = show_examples()

    model, embeddings_model_name, persist_directory, target_source_chunks = get_environment_variables()
    retriever = create_knowledge_base(embeddings_model_name, persist_directory, target_source_chunks)

    # None until the user types a question and presses Enter.
    query = st.chat_input(placeholder='Ask a question...')
    if query:
        # First real question: clear the example prompts.
        examples.empty()
        # Record the question, echo it, answer it, then record the answer.
        st.session_state.messages.append({'role': 'user', 'content': query})
        with st.chat_message('user'):
            st.markdown(query)
        answer = handle_query(query, model, retriever)
        st.session_state.messages.append({'role': 'assistant', 'content': answer})


if __name__ == "__main__":
    main()