# ZeeAI1's picture
# Update app.py
# 28296a9 verified
import os
import threading

import streamlit as st
import pdfplumber
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from transformers import pipeline, M2M100ForConditionalGeneration, AutoTokenizer
# Configure the browser tab (title and icon) and the page layout.
# The icon is the 📄 emoji; the previous value was mojibake ("πŸ“„", the UTF-8
# bytes of 📄 mis-decoded), which is not a valid page icon.
st.set_page_config(page_title="RAG-based PDF Chat", layout="centered", page_icon="📄")
# Summarization model used by generate_summary() to condense retrieved passages.
# st.cache_resource keeps one shared instance instead of reloading it per rerun.
@st.cache_resource
def load_summarization_pipeline():
    """Create the Hugging Face summarization pipeline backed by BART-large-CNN."""
    checkpoint = "facebook/bart-large-cnn"
    summarization_pipe = pipeline("summarization", model=checkpoint)
    return summarization_pipe


summarizer = load_summarization_pipeline()
# Translation model (SMALL-100, loaded with the M2M100 architecture) and its tokenizer.
@st.cache_resource
def load_translation_model():
    """Load and return ``(model, tokenizer)`` for the ``alirezamsh/small100`` checkpoint.

    Both objects are cached with ``st.cache_resource``, so a single shared
    instance is reused across reruns (and sessions).

    NOTE(review): the SMALL-100 model card loads the tokenizer with its custom
    ``SMALL100Tokenizer``; confirm that ``AutoTokenizer`` resolves to a class
    that honours ``tgt_lang`` the same way.
    """
    checkpoint = "alirezamsh/small100"
    model = M2M100ForConditionalGeneration.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer


translation_model, translation_tokenizer = load_translation_model()
# Output languages offered in the UI, mapped to the language code handed to
# translate_text(). Insertion order is the order shown in the dropdown.
LANGUAGES = dict(
    English="en",
    French="fr",
    Spanish="es",
    Chinese="zh",
    Hindi="hi",
    Urdu="ur",
)
# Split text into manageable chunks for embedding and retrieval.
@st.cache_data
def get_text_chunks(text, chunk_size=10000, chunk_overlap=1000):
    """Split ``text`` into overlapping chunks with LangChain's recursive splitter.

    Args:
        text: The full text to split.
        chunk_size: Maximum chunk length, measured by the splitter's length
            function (characters by default).
        chunk_overlap: Overlap between consecutive chunks, in the same units.

    Returns:
        A list of chunk strings. ``st.cache_data`` caches the result per
        argument combination.

    NOTE(review): the all-MiniLM-L6-v2 embedder used for the index truncates
    its input (256 word pieces per its model card), so with 10,000-character
    chunks most of each chunk never influences retrieval. A much smaller
    ``chunk_size`` is probably better; the default is kept unchanged here
    because changing it alters what the app retrieves — confirm before tuning.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_text(text)
# Embedding model used both to index the document chunks and to embed queries
# during similarity search. Wrapped in st.cache_resource, like the other models
# in this app, so it is loaded once instead of re-instantiated on every rerun.
@st.cache_resource
def load_embedding_function():
    """Return the sentence-transformers embedder used for indexing and search."""
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


embedding_function = load_embedding_function()
# Build a FAISS index over the text chunks, cached as a shared resource.
@st.cache_resource
def load_or_create_vector_store(text_chunks):
    """Embed ``text_chunks`` with ``embedding_function`` into a new FAISS store.

    Returns ``None`` when ``text_chunks`` is empty (falsy), so callers can treat
    "no text" and "no index" the same way.
    """
    if not text_chunks:
        return None
    return FAISS.from_texts(text_chunks, embedding=embedding_function)
# Helper function to process a single PDF
def process_single_pdf(file_path):
    """Extract the text of every page in one PDF.

    Pages are joined with a newline so the last word of one page is not fused
    with the first word of the next. Pages for which ``extract_text()`` returns
    nothing (``None`` or ``""``) are skipped.

    If the file cannot be opened or parsed, an error is shown in the Streamlit
    UI and the text extracted before the failure is returned (possibly ``""``).

    Args:
        file_path: Path to the PDF file.

    Returns:
        The extracted text as a single string.
    """
    page_texts = []
    try:
        with pdfplumber.open(file_path) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                if page_text:
                    page_texts.append(page_text)
    except Exception as e:  # UI boundary: report the bad file, let the caller continue
        st.error(f"Failed to read PDF: {file_path} - {e}")
    # Join once rather than growing a string with repeated +=.
    return "\n".join(page_texts)
# Load PDFs with progress display
def load_pdfs_with_progress(folder_path):
    """Read every PDF in ``folder_path`` and build a FAISS vector store from the text.

    Files are processed in sorted path order (``os.listdir`` order is arbitrary)
    and matched on a case-insensitive ``.pdf`` extension. A progress bar is
    shown while the files are read.

    Args:
        folder_path: Directory containing the PDF files.

    Returns:
        The vector store, or ``None`` when the folder is missing, holds no PDF
        files, or no usable text could be extracted (an error is shown in the
        UI in each of those cases).
    """
    if not os.path.isdir(folder_path):
        st.error(f"The folder '{folder_path}' does not exist. Please create it and add PDF files.")
        return None
    pdf_files = sorted(
        os.path.join(folder_path, filename)
        for filename in os.listdir(folder_path)
        if filename.lower().endswith(".pdf")
        and os.path.isfile(os.path.join(folder_path, filename))
    )
    if not pdf_files:
        st.error("No PDF files found in the specified folder.")
        return None
    st.markdown("### Loading data...")
    progress_bar = st.progress(0)
    pdf_texts = []
    for done, file_path in enumerate(pdf_files, start=1):
        pdf_texts.append(process_single_pdf(file_path))
        progress_bar.progress(done / len(pdf_files))
    progress_bar.empty()
    # Separate documents with a newline so one file's last word is not fused
    # with the next file's first word; drop files that yielded no text.
    all_text = "\n".join(text for text in pdf_texts if text)
    if not all_text.strip():
        st.error("No text could be extracted from the PDF files.")
        return None
    return load_or_create_vector_store(get_text_chunks(all_text))
# Generate summary based on retrieved text
def generate_summary(query, retrieved_text):
    """Summarize the retrieved passages, prefixed by the user's question.

    The pipeline's tokenizer truncates the input to the model's maximum length
    (``truncation=True``). The previous character slice (``[:1024]``) treated
    characters as tokens. That under-filled BART's 1024-token window for
    English. For non-ASCII text it could still produce more than 1024 tokens.

    Args:
        query: The user's question.
        retrieved_text: Concatenated text of the retrieved chunks.

    Returns:
        The generated summary string.
    """
    summarization_input = f"{query} Related information: {retrieved_text}"
    summary = summarizer(
        summarization_input,
        max_length=500,
        min_length=50,
        do_sample=False,
        truncation=True,
    )
    return summary[0]["summary_text"]
# Translate text to selected language.
# The tokenizer comes from an st.cache_resource loader, so every Streamlit
# session shares the same object, and translate_text mutates it (tgt_lang).
# This lock makes "set target language + encode" atomic so concurrent sessions
# cannot pick up each other's target language.
_translation_tokenizer_lock = threading.Lock()


def translate_text(text, target_lang):
    """Translate ``text`` into ``target_lang`` with the SMALL-100 model.

    Args:
        text: Source text to translate.
        target_lang: Target language code, e.g. a value from ``LANGUAGES``.

    Returns:
        The decoded translation (first sequence of the generated batch).

    NOTE(review): this follows the SMALL-100 model-card recipe (set ``tgt_lang``
    on the tokenizer, then ``generate``). That recipe assumes the custom
    ``SMALL100Tokenizer``. If ``AutoTokenizer`` resolves to another class,
    ``tgt_lang`` may not steer the output language — confirm.
    """
    with _translation_tokenizer_lock:
        translation_tokenizer.tgt_lang = target_lang
        encoded_text = translation_tokenizer(text, return_tensors="pt")
    generated_tokens = translation_model.generate(**encoded_text)
    return translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Main function to run the Streamlit app
def main():
    """Render the app: build the PDF index once per session, then answer questions.

    The vector store is built on the session's first run and kept in
    ``st.session_state``. If it could not be built, nothing else is rendered.
    When "Get Response" is clicked:
      1. The question is matched against the index.
      2. The retrieved chunks are summarized.
      3. The summary is translated into the selected language.
    """
    st.markdown(
        """
<h1 style="font-size:30px; text-align: center;">
📄 JusticeCompass: Your AI-Powered Legal Navigator for Swift, Accurate Guidance.
</h1>
""",
        unsafe_allow_html=True,
    )
    if "vector_store" not in st.session_state:
        st.session_state["vector_store"] = load_pdfs_with_progress("documents1")
    vector_store = st.session_state["vector_store"]
    if vector_store is None:
        return

    # Prompt input
    user_question = st.text_input("Ask a Question:", placeholder="Type your question here...")
    # Language selection dropdown
    selected_language = st.selectbox("Select output language:", list(LANGUAGES.keys()))

    # Always render the button. With `user_question and st.button(...)` the
    # button was not created until a non-empty question had been committed
    # (Enter/blur), so it was missing while the user was still typing.
    if not st.button("Get Response"):
        return
    if not user_question:
        st.warning("Please enter a question first.")
        return

    with st.spinner("Generating response..."):
        docs = vector_store.similarity_search(user_question)
        context_text = " ".join(doc.page_content for doc in docs)
        answer = generate_summary(user_question, context_text)
        translated_answer = translate_text(answer, LANGUAGES[selected_language])
    st.markdown(f"**🤖 AI ({selected_language}):** {translated_answer}")


if __name__ == "__main__":
    main()