import os
import streamlit as st
from PIL import Image, ImageOps
from langchain_openai import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from dotenv import load_dotenv
from langchain_community.embeddings.bedrock import BedrockEmbeddings
# Called before any of the os.getenv() lookups below so that API keys kept in
# a local .env file (if one exists) are visible to them.
load_dotenv()
# Hyperparameters
# Chunking settings. They are not referenced in the visible part of this
# file; presumably they record the values used when the FAISS index was
# built (TODO confirm against the indexing script).
PDF_CHUNK_SIZE = 1024
PDF_CHUNK_OVERLAP = 256
# Number of documents the base retriever fetches per question (passed as
# search_kwargs["k"] in main(), before the reranking compressor runs).
k = 9
# Load favicon image
def load_and_pad_image(image_path, size=(64, 64)):
    """Open an image file and return a copy padded to ``size``.

    Args:
        image_path: Path handed to ``Image.open``.
        size: Target size passed to ``ImageOps.pad``; defaults to 64x64.

    Returns:
        The image produced by ``ImageOps.pad``.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager releases the handle as soon as the padded copy has been built.
    with Image.open(image_path) as img:
        return ImageOps.pad(img, size)
# The same padded image is used as the browser-tab icon here and as the
# in-page logo further down.
favicon_path = "medical.png"
favicon_image = load_and_pad_image(favicon_path)

# Streamlit page config
st.set_page_config(page_title="Chatbot", page_icon=favicon_image)
# Set up logo and title
# Two columns in a 1:8 width ratio: logo in the first, title text in the second.
col1, col2 = st.columns([1, 8])

with col1:
    st.image(favicon_image)

with col2:
    # Same text as the original triple-quoted literal (newline, "Chatbot",
    # newline), written with explicit escapes so indentation cannot leak in.
    st.markdown("\nChatbot\n", unsafe_allow_html=True)
# Model and Embedding Selection
# Chat models offered in the UI ("deepseek-chat" is currently disabled).
model_options = ["gpt-4o", "gpt-4o-mini"]
selected_model = st.selectbox(
    label="Choose a GPT model",
    options=model_options,
)

# Embedding backends offered in the UI ("Huggingface MedEmbed" is currently
# disabled).
embedding_model_options = ["OpenAI"]
selected_embedding_model = st.selectbox(
    label="Choose an Embedding model",
    options=embedding_model_options,
)
# Load the model
def get_llm(selected_model):
    """Build the chat-model client for the chosen model name.

    Args:
        selected_model: Model identifier forwarded to ``ChatOpenAI``.
            ``"deepseek-chat"`` is routed to DeepSeek with the
            ``DeepSeek_API_KEY`` environment variable; every other value
            uses ``OPENAI_API_KEY`` against the default endpoint.

    Returns:
        A ``ChatOpenAI`` instance configured with temperature 0 and no
        explicit ``max_tokens`` limit.
    """
    client_kwargs = {
        "model": selected_model,
        "temperature": 0,
        "max_tokens": None,
    }
    if selected_model == "deepseek-chat":
        # DeepSeek serves an OpenAI-compatible API on its own host. Without
        # an explicit base_url the DeepSeek key would be sent to the default
        # OpenAI endpoint and rejected.
        client_kwargs["api_key"] = os.getenv("DeepSeek_API_KEY")
        client_kwargs["base_url"] = "https://api.deepseek.com"
    else:
        client_kwargs["api_key"] = os.getenv("OPENAI_API_KEY")
    return ChatOpenAI(**client_kwargs)
# Cache the vector store loading
# @st.cache_resource
# def load_vector_store(selected_embedding_model):
# if selected_embedding_model == "OpenAI":
# embeddings = OpenAIEmbeddings(model="text-embedding-3-large", api_key=os.getenv("OPENAI_API_KEY"))
# return FAISS.load_local("faiss_index_medical_OpenAI", embeddings, allow_dangerous_deserialization=True)
# else:
# embeddings = HuggingFaceEmbeddings(model_name="abhinand/MedEmbed-large-v0.1")
# return FAISS.load_local("faiss_index_medical_MedEmbed", embeddings, allow_dangerous_deserialization=True)
@st.cache_resource
def load_vector_store(selected_embedding_model):
    """Load the on-disk FAISS index matching the chosen embedding backend.

    Args:
        selected_embedding_model: ``"OpenAI"`` selects the OpenAI embedder and
            the ``faiss_table`` index; any other value selects the MedEmbed
            HuggingFace embedder and its index.

    Returns:
        The FAISS vector store returned by ``FAISS.load_local``.
    """
    use_openai = selected_embedding_model == "OpenAI"
    if use_openai:
        index_dir = "faiss_table"
        embedder = OpenAIEmbeddings(
            model="text-embedding-3-large",
            api_key=os.getenv("OPENAI_API_KEY"),
        )
    else:
        index_dir = "faiss_index_medical_MedEmbed"
        embedder = HuggingFaceEmbeddings(model_name="abhinand/MedEmbed-large-v0.1")

    # NOTE(review): allow_dangerous_deserialization opts in to loading
    # serialized index data that is unsafe to read from untrusted sources;
    # only point this at index directories produced by a trusted pipeline.
    return FAISS.load_local(
        index_dir,
        embedder,
        allow_dangerous_deserialization=True,
    )
# Module-level handles read by main(). load_vector_store is wrapped in
# st.cache_resource; get_llm is not, so a new client object is built every
# time this line executes.
vector_store = load_vector_store(selected_embedding_model)
llm = get_llm(selected_model)
# Main App Logic
def _build_prompt():
    """Return the prompt template that restricts answers to retrieved context."""
    system_prompt = (
        "\n"
        "You are a friendly and knowledgeable assistant who is an expert in medical education who will only answer from the context provided. You need to understand the best context to answer the question.\n"
    )
    # Only the first fragment is an f-string; the plain fragments keep their
    # single braces literal so PromptTemplate sees {context} and {question}.
    template = (
        "\n"
        f"{system_prompt}\n"
        "-------------------------------\n"
        "Context: {context}\n"
        "Question: {question}\n"
        "Answer:\n"
    )
    return PromptTemplate(
        template=template,
        input_variables=['context', 'question'],
    )


def main():
    """Render the question form and answer submitted questions.

    Reads the module-level ``vector_store``, ``llm`` and ``k``. When the
    "Get Answer" button is pressed with a non-empty question, the top ``k``
    documents are retrieved, reranked with ``FlashrankRerank``, and passed to
    a ``RetrievalQA`` chain; the chain's ``result`` is written to the page.
    """
    st.session_state['knowledge_base'] = vector_store
    st.header("Ask a Question")
    question = st.text_input("Enter your question")

    if not st.button("Get Answer"):
        return

    # Do not spend a retrieval and an LLM call on an empty query.
    if not question or not question.strip():
        st.warning("Please enter a question first.")
        return

    knowledge_base = st.session_state['knowledge_base']
    retriever = knowledge_base.as_retriever(search_kwargs={"k": k})
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=FlashrankRerank(),
        base_retriever=retriever,
    )
    qa_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=compression_retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": _build_prompt()},
    )
    response = qa_chain.invoke({"query": question})
    st.write(f"**Answer:** {response['result']}")
# Run the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()