import tempfile
import fitz  # PyMuPDF
import streamlit as st
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import CTransformers
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from PIL import Image
from streamlit_chat import message
st.set_page_config(
page_title="AskBot",
page_icon=":robot_face:",
layout="wide"
)
st.sidebar.title("""
Ask A Bot :robot_face: \n Talk with your PDFs
""")
st.sidebar.write("""
###### A Q&A chatbot for you to talk with your PDFs.
###### Upload the PDF you want to talk to and start asking questions. The display will show the page where the answer was found.
###### When you upload a new PDF, the chat history is reset for you to start fresh.
###### The chatbot is built with LangChain and a quantized Llama 2 13B chat model, obtained from [here](https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML).
###### Answer quality is limited by the model's small size; for better results, a larger LLM should be used.
###### :warning: Sometimes the Streamlit app will not re-run and refresh the PDF. If this happens, refresh the page.
###### Developed by [Carlos Pereira](https://linkedin.com/in/carlos-miguel-pereira/).
""")
if 'pdf_page' not in st.session_state:
st.session_state['pdf_page'] = 0
if 'chat_history' not in st.session_state:
st.session_state['chat_history'] = []
if 'generated' not in st.session_state:
st.session_state['generated'] = []
if 'past' not in st.session_state:
st.session_state['past'] = []
def update_state():
"""
Reset state when a new PDF is uploaded
"""
    st.session_state['pdf_page'] = 0
    st.session_state['chat_history'] = []
    st.session_state['generated'] = []
    st.session_state['past'] = []
@st.cache_resource(show_spinner=False)
def load_llm():
"""
Load Llama LLM
"""
llm_model = CTransformers(
model="llama-2-13b-chat.ggmlv3.q3_K_L.bin",
model_type="llama",
max_new_tokens=150,
temperature=0.2
)
return llm_model
@st.cache_resource(show_spinner=False)
def gen_embeddings():
"""
Generate embeddings
"""
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2",
model_kwargs={'device': 'cpu'})
return embeddings
def load_pdf(file):
"""
Load PDF and process for Search
"""
# create tempfile to load pdf to PyPDFLoader
temp_file = tempfile.NamedTemporaryFile()
    temp_file.write(file.getbuffer())
    temp_file.flush()  # make sure all bytes are on disk before other readers open the file
    loader = PyPDFLoader(temp_file.name)
documents = loader.load()
pdf_file = fitz.open(temp_file.name)
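    # closing the NamedTemporaryFile deletes it; fitz already holds its own handle,
    # so rendering keeps working (on POSIX systems)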
temp_file.close()
    # split the document into ~1000-character chunks for embedding
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
    # get the (cached) embedding model
embeddings = gen_embeddings()
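    # build an ephemeral, in-memory Chroma index over the chunks for similarity search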
pdf_search = Chroma.from_documents(texts, embeddings)
return pdf_search, pdf_file
def generate_chain(pdf_vector, llm):
"""
Generate Retrieval chain
"""
chain = ConversationalRetrievalChain.from_llm(llm,
chain_type="stuff",
retriever=pdf_vector.as_retriever(search_kwargs={"k": 1}),
return_source_documents=True)
return chain
def get_answer(chain, query, chat_history):
"""
Get an answer from the chain
"""
result = chain({"question": query, 'chat_history': chat_history}, return_only_outputs=True)
answer = result["answer"]
    # to keep conversational history, uncomment the line below
# st.session_state.chat_history += [(query, answer)]
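    # remember which page the top source chunk came from so the viewer can show it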
    st.session_state.pdf_page = result['source_documents'][0].metadata['page']
return answer
def render_page_file(file, page):
"""
Render page from PDF file
"""
    try:
        page = file[page]
    except IndexError:
        # the stored page index can be stale after a new upload; fall back to page 0
        page = file[0]
        st.session_state.pdf_page = 0
# Render the PDF page as an image
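    # fitz.Matrix(300/72, 300/72) upscales the render from the 72 dpi default to 300 dpi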
pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72))
image = Image.frombytes('RGB', [pix.width, pix.height], pix.samples)
return image
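# on_change=update_state resets the chat and viewer whenever a new PDF is chosen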
uploaded_file = st.file_uploader("Upload your PDF", type=["pdf"],
accept_multiple_files=False,
on_change=update_state)
def app():
"""
Main app
"""
if uploaded_file:
# Load LLM
with st.spinner('Loading LLM...'):
llm = load_llm()
# Load and process the uploaded PDF file
with st.spinner('Loading PDF...'):
pdf_vector, pdf_file = load_pdf(uploaded_file)
with st.spinner('Generating chain...'):
chain = generate_chain(pdf_vector, llm)
col1, col2 = st.columns(2)
with col1:
# Question and answering
with st.form(key='question_form', clear_on_submit=True):
question = st.text_input('Enter your question:', value="", key='text_value')
submit_question = st.form_submit_button(label="Enter")
if submit_question:
with st.spinner('Getting answer...'):
answer = get_answer(chain, question,
st.session_state.chat_history)
st.session_state.past.append(question)
st.session_state.generated.append(answer)
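            # render the transcript newest-first; each message widget needs a unique key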
if st.session_state['generated']:
for i in range(len(st.session_state['generated'])-1, -1, -1):
message(st.session_state["generated"][i], is_user=False,
avatar_style="bottts", key=str(i))
message(st.session_state['past'][i], is_user=True,
avatar_style="adventurer", key=str(i) + '_user')
with col2:
# Render PDF page
if pdf_file:
st.image(render_page_file(pdf_file, st.session_state.pdf_page))
if __name__ == "__main__":
app()