from langchain_community.llms import CTransformers
from langchain.agents import AgentType, Tool, initialize_agent
from langchain.chains import RetrievalQA
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import tempfile
import os
import streamlit as st
import timeit

def main():
    # Map file extensions to (loader class, loader kwargs).
    FILE_LOADER_MAPPING = {
        "pdf": (PyPDFLoader, {}),
        # Add more mappings for other file extensions and loaders as needed.
    }
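    # For example, plain-text support could be wired in with LangChain's
    # TextLoader (a sketch only, not enabled here):
    # FILE_LOADER_MAPPING["txt"] = (TextLoader, {"encoding": "utf-8"})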
    st.title("Document Comparison with Q&A using Agents")

    # Upload files
    uploaded_files = st.file_uploader(
        "Upload your documents", type=["pdf"], accept_multiple_files=True
    )
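    # With accept_multiple_files=True, st.file_uploader returns a list of
    # in-memory UploadedFile objects exposing `.name` and `.read()`.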
    loaded_documents = []
    if uploaded_files:
        # Copy the uploaded files into a temporary directory and process
        # each one with the loader mapped to its extension.
        with tempfile.TemporaryDirectory() as td:
            for uploaded_file in uploaded_files:
                st.write(f"Uploaded: {uploaded_file.name}")
                ext = os.path.splitext(uploaded_file.name)[-1][1:].lower()
                if ext in FILE_LOADER_MAPPING:
                    loader_class, loader_args = FILE_LOADER_MAPPING[ext]
                    # Save the uploaded file to the temporary directory.
                    file_path = os.path.join(td, uploaded_file.name)
                    with open(file_path, "wb") as temp_file:
                        temp_file.write(uploaded_file.read())
                    # Use the LangChain loader to process the file.
                    loader = loader_class(file_path, **loader_args)
                    loaded_documents.extend(loader.load())
                else:
                    st.warning(f"Unsupported file extension: {ext}; the app currently only supports PDF.")
st.write("Ask question to get comparison from the documents:")
query = st.text_input("Ask a question:")
    if st.button("Get Answer"):
        if query:
            # Load the model, build the vector store, and retrieve an answer.
            try:
                start = timeit.default_timer()
                llm = CTransformers(
                    # Alternative: model="TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
                    model="TheBloke/Llama-2-7B-Chat-GGUF",
                    model_file="llama-2-7b-chat.Q3_K_S.gguf",
                    model_type="llama",
                    # Generation parameters go in `config` per the ctransformers
                    # API; other options include repetition_penalty, top_k,
                    # top_p, stream, and threads.
                    config={"max_new_tokens": 300, "temperature": 0.3},
                    lib="avx2",  # for CPU
                )
                print("LLM Initialized...")
model_name = "BAAI/bge-large-en"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
embeddings = HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
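                # CharacterTextSplitter splits on "\n\n" by default and caps
                # chunks at roughly chunk_size characters; note the FAISS index
                # is rebuilt in memory on every query rather than cached.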
                text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
                chunked_documents = text_splitter.split_documents(loaded_documents)
                retriever = FAISS.from_documents(chunked_documents, embeddings).as_retriever()
                # Wrap the retrieval chain in a Tool. Tool.func must be a
                # callable, so pass the chain's run method, not the chain itself.
                qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
                tools = [
                    Tool(
                        name="Comparison tool",
                        description="useful when you want to answer questions about the uploaded documents",
                        func=qa_chain.run,
                    )
                ]
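                # A ZERO_SHOT_REACT_DESCRIPTION agent picks tools purely from
                # their `description` strings via ReAct-style prompting.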
                agent = initialize_agent(
                    tools=tools,
                    llm=llm,
                    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                    verbose=True,
                    # Small local models often emit malformed ReAct output;
                    # let the executor recover instead of raising.
                    handle_parsing_errors=True,
                )
                response = agent.run(query)
                # Stop the timer after the agent call so the elapsed time
                # includes inference, not just setup.
                end = timeit.default_timer()

                st.write("Elapsed time:")
                st.write(end - start)
                st.write("Bot Response:")
                st.write(response)
            except Exception as e:
                st.error(f"An error occurred: {str(e)}")
        else:
            st.warning("Please enter a question.")

if __name__ == "__main__":
    main()