ZeeAI1 commited on
Commit
08557bb
·
verified ·
1 Parent(s): eb5fde5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -43
app.py CHANGED
@@ -1,49 +1,168 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
- from sentence_transformers import SentenceTransformer
3
- from datasets import load_dataset
4
- import faiss
5
- import numpy as np
6
  import streamlit as st
 
 
 
 
 
 
7
 
8
- # Load a public legal guidance dataset
9
- dataset = load_dataset("lex_glue", "ecthr_a")
10
- texts = dataset['train']['text'][:100] # Limiting to 100 samples for efficiency
11
-
12
- # Initialize Sentence-BERT for document encoding and T5 for summarization
13
- sbert_model = SentenceTransformer("all-mpnet-base-v2")
14
- t5_tokenizer = AutoTokenizer.from_pretrained("t5-small")
15
- t5_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
16
-
17
- # Encode the legal guidance texts and build FAISS index
18
- case_embeddings = sbert_model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
19
- index = faiss.IndexFlatL2(case_embeddings.shape[1])
20
- index.add(np.array(case_embeddings.cpu()))
21
-
22
- # Function to retrieve similar cases
23
- def retrieve_cases(query, top_k=3):
24
- query_embedding = sbert_model.encode(query, convert_to_tensor=True)
25
- _, indices = index.search(np.array([query_embedding.cpu()]), top_k)
26
- return [(texts[i], i) for i in indices[0]]
27
-
28
- # Function to summarize a given text
29
- def summarize_text(text):
30
- inputs = t5_tokenizer("summarize: " + text, return_tensors="pt", max_length=512, truncation=True)
31
- outputs = t5_model.generate(inputs["input_ids"], max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
32
- return t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
33
-
34
- # Streamlit UI for LawyerGuide App
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def main():
36
- st.title("LawyerGuide App: Legal Guidance for False Accusations")
37
- query = st.text_input("Describe your situation or legal concern:")
38
- top_k = st.slider("Number of similar cases to retrieve:", 1, 5, 3)
39
-
40
- if st.button("Get Guidance"):
41
- results = retrieve_cases(query, top_k=top_k)
42
- for i, (case_text, index) in enumerate(results):
43
- st.subheader(f"Guidance {i+1}")
44
- st.write("Relevant Text:", case_text)
45
- summary = summarize_text(case_text)
46
- st.write("Summary of Legal Guidance:", summary)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  if __name__ == "__main__":
49
  main()
 
1
+ import os
 
 
 
 
2
  import streamlit as st
3
+ import pdfplumber
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from langchain.vectorstores import FAISS
8
+ from transformers import pipeline, M2M100ForConditionalGeneration, AutoTokenizer
9
 
10
+ # Set up the page configuration
11
+ st.set_page_config(page_title="RAG-based PDF Chat", layout="centered", page_icon="📄")
12
+
13
+ # Load the summarization pipeline model
14
+ @st.cache_resource
15
+ def load_summarization_pipeline():
16
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
17
+ return summarizer
18
+
19
+ summarizer = load_summarization_pipeline()
20
+
21
+ # Load the translation model
22
+ @st.cache_resource
23
+ def load_translation_model():
24
+ model = M2M100ForConditionalGeneration.from_pretrained("alirezamsh/small100")
25
+ tokenizer = AutoTokenizer.from_pretrained("alirezamsh/small100")
26
+ return model, tokenizer
27
+
28
+ translation_model, translation_tokenizer = load_translation_model()
29
+
30
+ # Define available languages for translation
31
+ LANGUAGES = {
32
+ "English": "en",
33
+ "French": "fr",
34
+ "Spanish": "es",
35
+ "Chinese": "zh",
36
+ "Hindi": "hi",
37
+ "Urdu": "ur",
38
+ }
39
+
40
+ # Split text into manageable chunks
41
+ @st.cache_data
42
+ def get_text_chunks(text):
43
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=1000)
44
+ chunks = text_splitter.split_text(text)
45
+ return chunks
46
+
47
+ # Initialize embedding function
48
+ embedding_function = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
49
+
50
+ # Create a FAISS vector store with embeddings
51
+ @st.cache_resource
52
+ def load_or_create_vector_store(text_chunks):
53
+ if not text_chunks:
54
+ st.error("No valid text chunks found to create a vector store. Please check your PDF files.")
55
+ return None
56
+ vector_store = FAISS.from_texts(text_chunks, embedding=embedding_function)
57
+ return vector_store
58
+
59
+ # Helper function to process a single PDF
60
+ def process_single_pdf(file_path):
61
+ text = ""
62
+ try:
63
+ with pdfplumber.open(file_path) as pdf:
64
+ for page in pdf.pages:
65
+ page_text = page.extract_text()
66
+ if page_text:
67
+ text += page_text
68
+ except Exception as e:
69
+ st.error(f"Failed to read PDF: {file_path} - {e}")
70
+ return text
71
+
72
+ # Load PDFs with progress display
73
+ def load_pdfs_with_progress(folder_path):
74
+ all_text = ""
75
+ pdf_files = [os.path.join(folder_path, filename) for filename in os.listdir(folder_path) if filename.endswith('.pdf')]
76
+ num_files = len(pdf_files)
77
+
78
+ if num_files == 0:
79
+ st.error("No PDF files found in the specified folder.")
80
+ st.session_state['vector_store'] = None
81
+ st.session_state['loading'] = False
82
+ return
83
+
84
+ st.markdown("### Loading data...")
85
+ progress_bar = st.progress(0)
86
+ status_text = st.empty()
87
+
88
+ processed_count = 0
89
+
90
+ for file_path in pdf_files:
91
+ result = process_single_pdf(file_path)
92
+ all_text += result
93
+ processed_count += 1
94
+ progress_percentage = int((processed_count / num_files) * 100)
95
+ progress_bar.progress(processed_count / num_files)
96
+ status_text.text(f"Loading documents: {progress_percentage}% completed")
97
+
98
+ progress_bar.empty()
99
+ status_text.text("Document loading completed!")
100
+
101
+ if all_text:
102
+ text_chunks = get_text_chunks(all_text)
103
+ vector_store = load_or_create_vector_store(text_chunks)
104
+ st.session_state['vector_store'] = vector_store
105
+ else:
106
+ st.session_state['vector_store'] = None
107
+
108
+ st.session_state['loading'] = False
109
+
110
+ # Generate summary based on retrieved text
111
+ def generate_summary_with_huggingface(query, retrieved_text):
112
+ summarization_input = f"{query} Related information:{retrieved_text}"
113
+ max_input_length = 1024
114
+ summarization_input = summarization_input[:max_input_length]
115
+ summary = summarizer(summarization_input, max_length=500, min_length=50, do_sample=False)
116
+ return summary[0]["summary_text"]
117
+
118
+ # Generate response for user query
119
+ def user_input(user_question):
120
+ vector_store = st.session_state.get('vector_store')
121
+ if vector_store is None:
122
+ return "The app is still loading documents or no documents were successfully loaded."
123
+ docs = vector_store.similarity_search(user_question)
124
+ context_text = " ".join([doc.page_content for doc in docs])
125
+ return generate_summary_with_huggingface(user_question, context_text)
126
+
127
+ # Translate text to selected language
128
+ def translate_text(text, target_lang):
129
+ translation_tokenizer.tgt_lang = target_lang
130
+ encoded_text = translation_tokenizer(text, return_tensors="pt")
131
+ generated_tokens = translation_model.generate(**encoded_text)
132
+ translated_text = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
133
+ return translated_text
134
+
135
+ # Main function to run the Streamlit app
136
  def main():
137
+ st.markdown(
138
+ """
139
+ <h1 style="font-size:30px; text-align: center;">
140
+ 📄 JusticeCompass: Your AI-Powered Legal Navigator for Swift, Accurate Guidance.
141
+ </h1>
142
+ """,
143
+ unsafe_allow_html=True
144
+ )
145
+
146
+ if 'loading' not in st.session_state or st.session_state['loading']:
147
+ st.session_state['loading'] = True
148
+ load_pdfs_with_progress('documents1')
149
+
150
+ user_question = st.text_input("Ask a Question:", placeholder="Type your question here...")
151
+
152
+ # Display language selection dropdown
153
+ selected_language = st.selectbox("Select output language:", list(LANGUAGES.keys()))
154
+
155
+ if st.session_state.get('loading', True):
156
+ st.info("The app is loading documents in the background. You can type your question now and submit once loading is complete.")
157
+
158
+ # Only display "Get Response" button after user enters a question
159
+ if user_question:
160
+ if st.button("Get Response"):
161
+ with st.spinner("Generating response..."):
162
+ answer = user_input(user_question)
163
+ target_lang_code = LANGUAGES[selected_language]
164
+ translated_answer = translate_text(answer, target_lang_code)
165
+ st.markdown(f"**🤖 AI ({selected_language}):** {translated_answer}")
166
 
167
  if __name__ == "__main__":
168
  main()