import tempfile

import fitz
import streamlit as st
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import CTransformers
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from PIL import Image
from streamlit_chat import message

st.set_page_config(
    page_title="AskBot",
    page_icon=":robot_face:",
    layout="wide"
)

st.sidebar.title("""
Ask A Bot :robot_face: \n Talk with your PDFs
""")

st.sidebar.write("""
###### A Q&A chatbot for talking with your PDFs.
###### Upload the PDF you want to talk to and start asking questions. The viewer will show the page where each answer was found.
###### When you upload a new PDF, the chat history is reset so you can start fresh.
###### The chatbot is built with LangChain and a quantized Llama 2 13B chat model, obtained from [here](https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML).
###### Performance is limited by the model's small quantized size; a larger LLM would give better answers.
###### :warning: Sometimes the Streamlit app will not re-run and refresh the PDF. If this happens, refresh the page.
###### Developed by [Carlos Pereira](https://linkedin.com/in/carlos-miguel-pereira/).
""")
if 'pdf_page' not in st.session_state:
    st.session_state['pdf_page'] = 0

if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

if 'generated' not in st.session_state:
    st.session_state['generated'] = []

if 'past' not in st.session_state:
    st.session_state['past'] = []


def update_state():
    """
    Reset state when a new PDF is uploaded
    """
    st.session_state['pdf_page'] = 0
    st.session_state['chat_history'] = []
    st.session_state['generated'] = []
    st.session_state['past'] = []


@st.cache_resource(show_spinner=False)
def load_llm():
    """
    Load the Llama 2 LLM (cached so it is only loaded once per session)
    """
    llm_model = CTransformers(
        model="llama-2-13b-chat.ggmlv3.q3_K_L.bin",
        model_type="llama",
        max_new_tokens=150,
        temperature=0.2
    )
    return llm_model
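
# The GGML weights file is expected in the working directory. As an assumed
# alternative (not in the original app), ctransformers can also fetch the
# weights from the Hugging Face Hub:
#   CTransformers(model="TheBloke/Llama-2-13B-chat-GGML",
#                 model_file="llama-2-13b-chat.ggmlv3.q3_K_L.bin",
#                 model_type="llama")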


@st.cache_resource(show_spinner=False)
def gen_embeddings():
    """
    Generate embeddings
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={'device': 'cpu'}
    )
    return embeddings
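
# all-mpnet-base-v2 produces 768-dimensional sentence embeddings and runs
# acceptably on CPU; st.cache_resource keeps a single copy per process.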


def load_pdf(file):
    """
    Load PDF and process for search
    """
    # Write the uploaded bytes to a temporary file so the loaders can read a path
    temp_file = tempfile.NamedTemporaryFile()
    temp_file.write(file.getbuffer())
    temp_file.flush()  # make sure buffered bytes hit disk before loading
    loader = PyPDFLoader(temp_file.name)
    documents = loader.load()
    pdf_file = fitz.open(temp_file.name)
    temp_file.close()  # deletes the temp file; fitz already holds its own handle

    # Split into ~1000-character chunks for retrieval
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)

    # Embed the chunks and index them in an in-memory Chroma vector store
    embeddings = gen_embeddings()
    pdf_search = Chroma.from_documents(texts, embeddings)

    return pdf_search, pdf_file
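
# Note: load_pdf runs on every Streamlit re-run, so the PDF is re-embedded
# after each interaction. A possible tweak (assumed, not in the original app)
# is to persist the index, e.g.:
#   Chroma.from_documents(texts, embeddings, persist_directory="chroma_db")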


def generate_chain(pdf_vector, llm):
    """
    Generate retrieval chain
    """
    chain = ConversationalRetrievalChain.from_llm(
        llm,
        chain_type="stuff",
        retriever=pdf_vector.as_retriever(search_kwargs={"k": 1}),
        return_source_documents=True
    )
    return chain
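
# search_kwargs={"k": 1} retrieves only the single best-matching chunk, which
# keeps the "stuff" prompt short for the small quantized model; raising k
# would give the model more evidence per question at the cost of a longer prompt.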


def get_answer(chain, query, chat_history):
    """
    Get an answer from the chain
    """
    result = chain({"question": query, "chat_history": chat_history}, return_only_outputs=True)
    answer = result["answer"]

    # Remember the page of the top source document so the viewer can jump to it
    st.session_state.pdf_page = result['source_documents'][0].metadata['page']

    return answer


def render_page_file(file, page):
    """
    Render page from PDF file
    """
    try:
        page = file[page]
    except (IndexError, ValueError):
        # Fall back to the first page if the stored page number is stale
        page = file[0]
        st.session_state.pdf_page = 0

    # Render at 300 dpi (PDF user space is 72 dpi, hence the 300/72 scale)
    pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
    image = Image.frombytes('RGB', [pix.width, pix.height], pix.samples)

    return image


uploaded_file = st.file_uploader("Upload your PDF", type=["pdf"],
                                 accept_multiple_files=False,
                                 on_change=update_state)


def app():
    """
    Main app
    """
    if uploaded_file:

        with st.spinner('Loading LLM...'):
            llm = load_llm()

        with st.spinner('Loading PDF...'):
            pdf_vector, pdf_file = load_pdf(uploaded_file)

        with st.spinner('Generating chain...'):
            chain = generate_chain(pdf_vector, llm)

        col1, col2 = st.columns(2)

        with col1:
            with st.form(key='question_form', clear_on_submit=True):
                question = st.text_input('Enter your question:', value="", key='text_value')
                submit_question = st.form_submit_button(label="Enter")

            if submit_question:
                with st.spinner('Getting answer...'):
                    answer = get_answer(chain, question,
                                        st.session_state.chat_history)
                st.session_state.past.append(question)
                st.session_state.generated.append(answer)
                # Record the turn so follow-up questions carry context
                st.session_state.chat_history.append((question, answer))

            # Display the conversation, newest exchange first
            if st.session_state['generated']:
                for i in range(len(st.session_state['generated']) - 1, -1, -1):
                    message(st.session_state["generated"][i], is_user=False,
                            avatar_style="bottts", key=str(i))
                    message(st.session_state['past'][i], is_user=True,
                            avatar_style="adventurer", key=str(i) + '_user')

        with col2:
            # Show the page where the latest answer was found
            if pdf_file:
                st.image(render_page_file(pdf_file, st.session_state.pdf_page))


if __name__ == "__main__":
    app()