|
|
|
from langchain_community.llms import CTransformers |
|
from ctransformers import AutoModelForCausalLM |
|
from langchain.agents import Tool |
|
from langchain.agents import AgentType, initialize_agent |
|
from langchain.chains import RetrievalQA |
|
from langchain.text_splitter import CharacterTextSplitter |
|
from langchain_community.document_loaders import PyPDFLoader |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_community.embeddings import HuggingFaceBgeEmbeddings |
|
import tempfile |
|
import os |
|
import streamlit as st |
|
import timeit |
|
from langchain.callbacks.tracers import ConsoleCallbackHandler |
|
|
|
|
|
# Maps a lowercase file extension to its (document loader class, loader kwargs).
# Only PDF is supported for now; add entries here to accept more formats.
FILE_LOADER_MAPPING = {
    "pdf": (PyPDFLoader, {}),
}


def _load_documents(uploaded_files):
    """Persist uploaded files to a temp directory and load them into documents.

    Files are written to disk first because the loaders (e.g. PyPDFLoader)
    expect a filesystem path, not an in-memory buffer. Unsupported extensions
    produce a Streamlit warning and are skipped.

    Args:
        uploaded_files: list of Streamlit UploadedFile objects.

    Returns:
        list: LangChain Document objects from every supported file.
    """
    loaded_documents = []
    with tempfile.TemporaryDirectory() as td:
        for uploaded_file in uploaded_files:
            st.write(f"Uploaded: {uploaded_file.name}")
            # splitext gives ".pdf"; strip the dot and normalize case.
            ext = os.path.splitext(uploaded_file.name)[-1][1:].lower()
            st.write(f"Uploaded: {ext}")
            if ext in FILE_LOADER_MAPPING:
                loader_class, loader_args = FILE_LOADER_MAPPING[ext]
                file_path = os.path.join(td, uploaded_file.name)
                with open(file_path, 'wb') as temp_file:
                    temp_file.write(uploaded_file.read())
                loader = loader_class(file_path, **loader_args)
                loaded_documents.extend(loader.load())
            else:
                st.warning(f"Unsupported file extension: {ext}, the app currently only supports pdf")
    return loaded_documents


def _build_agent(loaded_documents):
    """Construct the LLM, retriever and a zero-shot ReAct agent over the docs.

    Loads a quantized Llama-2 chat model via CTransformers, embeds the
    documents with BGE embeddings into a FAISS index, and wraps a
    RetrievalQA chain as the agent's single tool.

    Args:
        loaded_documents: non-empty list of LangChain Documents.

    Returns:
        An initialized LangChain agent executor.
    """
    llm = CTransformers(
        model="TheBloke/Llama-2-7B-Chat-GGUF",
        model_file="llama-2-7b-chat.Q3_K_S.gguf",
        model_type="llama",
        max_new_tokens=300,
        temperature=0.3,
        lib="avx2",  # CPU build; requires AVX2 support on the host
    )
    print("LLM Initialized...")

    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-large-en",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
    )

    # Chunk the documents and index them for similarity search.
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunked_documents = text_splitter.split_documents(loaded_documents)
    retriever = FAISS.from_documents(chunked_documents, embeddings).as_retriever()

    # BUG FIX: Tool.func must be a callable that returns a string observation.
    # The original passed the RetrievalQA chain object itself; pass its .run
    # method so the agent receives the answer text, not a result dict.
    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
    tools = [
        Tool(
            name="Comparison tool",
            description="useful when you want to answer questions about the uploaded documents",
            func=qa_chain.run,
        )
    ]

    return initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )


def main():
    """Streamlit entry point: upload PDFs, then answer comparison questions.

    Flow: collect uploads -> load documents -> on button press, build the
    agent (timed) and run the user's query against it.
    """
    st.title("Document Comparison with Q&A using Agents")

    uploaded_files = st.file_uploader("Upload your documents", type=["pdf"], accept_multiple_files=True)
    loaded_documents = []
    if uploaded_files:
        loaded_documents = _load_documents(uploaded_files)

    st.write("Ask question to get comparison from the documents:")
    query = st.text_input("Ask a question:")

    if st.button("Get Answer"):
        if not query:
            st.warning("Please enter a question.")
        elif not loaded_documents:
            # Guard: FAISS.from_documents raises an opaque error on an empty
            # list; fail early with an actionable message instead.
            st.warning("Please upload at least one supported document before asking a question.")
        else:
            try:
                start = timeit.default_timer()
                # NOTE: models are rebuilt on every click (as in the original
                # flow); the timer measures setup cost, not the query itself.
                agent = _build_agent(loaded_documents)
                end = timeit.default_timer()
                st.write("Elapsed time:")
                st.write(end - start)

                st.write("Bot Response:")
                st.write(agent.run({"input": query}))
            except Exception as e:
                # Surface any model/indexing failure in the UI rather than crashing.
                st.error(f"An error occurred: {str(e)}")
|
|
|
|
|
# Run the Streamlit app when executed as a script (e.g. `streamlit run app.py`).
if __name__ == "__main__":

    main()
|
|
|
|