corumbus committed
Commit d1c8f90 · 1 Parent(s): e63a4b0
Files changed (3)
  1. Dockerfile +20 -0
  2. app.py +312 -0
  3. requirements.txt +13 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # You will also find guides there on how best to write your Dockerfile.
+
+ FROM python:3.9
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+
+ EXPOSE 7860
+
+ # Run the Gradio app
+ CMD ["python", "app.py"]
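+
+ # To try the image locally (illustrative sketch; the image tag "set50rag" is arbitrary):
+ #   docker build -t set50rag .
+ #   docker run -p 7860:7860 set50rag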
app.py ADDED
@@ -0,0 +1,312 @@
+ import gradio as gr
+ from llama_index.core.postprocessor import SimilarityPostprocessor
+ from llama_index.core.postprocessor import SentenceTransformerRerank
+ from llama_index.core.postprocessor import MetadataReplacementPostProcessor
+ from llama_index.core import StorageContext
+ import chromadb
+ from llama_index.vector_stores.chroma import ChromaVectorStore
+ import zipfile
+ import requests
+ import torch
+ from llama_index.core import Settings
+ from llama_index.llms.huggingface import HuggingFaceLLM
+ from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
+ import sys
+ import logging
+ import os
+
+
+ enable_rerank = True
+ # Retrieval strategy: sentence_window, naive, recursive_retrieval
+ retrieval_strategy = "sentence_window"
+ base_embedding_source = "hf"  # local, openai, hf
+ # Alternatives: intfloat/multilingual-e5-small, local:BAAI/bge-small-en-v1.5,
+ # text-embedding-3-small, nvidia/NV-Embed-v2
+ base_embedding_model = "Alibaba-NLP/gte-large-en-v1.5"
+ # Alternatives: meta-llama/Llama-3.1-8B, meta-llama/Llama-3.2-3B-Instruct,
+ # meta-llama/Llama-2-7b-chat-hf, google/gemma-2-9b, CohereForAI/c4ai-command-r-plus,
+ # CohereForAI/aya-23-8B, AdaptLLM/finance-chat
+ base_llm_model = "mistralai/Mistral-7B-Instruct-v0.3"
+ base_llm_source = "hf"  # cohere, hf, anthropic
+ base_similarity_top_k = 20
+
+
+ # ChromaDB settings
+ env_extension = "_large"  # _large, _dev_window, _large_window
+ db_collection = f"gte{env_extension}"  # intfloat, gte
+ read_db = True
+ active_chroma = True
+ root_path = "."
+ chroma_db_path = f"{root_path}/chroma_db"  # ./chroma_db
+ # e.g. ./processed_files_large.json
+ processed_files_log = f"{root_path}/processed_files{env_extension}.json"
+
+
+ # Validate hyperparameters
+ if retrieval_strategy not in ["sentence_window", "naive"]:  # recursive_retrieval not yet supported
+     raise Exception(f"{retrieval_strategy} retrieval_strategy is not supported")
+
+
+ os.environ["OPENAI_API_KEY"] = 'sk-xxxxxxxxxx'
+ hf_api_key = os.getenv("HF_API_KEY")
+
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+ logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
+
+
+ torch.cuda.empty_cache()
+
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+ print(f"loading embedding: {base_embedding_model}")
+ if base_embedding_source == 'hf':
+     from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+     Settings.embed_model = HuggingFaceEmbedding(
+         model_name=base_embedding_model, trust_remote_code=True)
+ else:
+     raise Exception("embedding source is invalid")
+
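+ # Optional sanity check (illustrative, not part of the app flow): embed a short
+ # string and confirm the vector dimension matches the model (1024 for gte-large-en-v1.5).
+ # vec = Settings.embed_model.get_text_embedding("PTTOR revenue 2022")
+ # assert len(vec) == 1024
+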
+ # Prompt setup: wrap the default prompts that are internal to llama-index
+ # (template adapted from https://huggingface.co/Writer/camel-5b-hf)
+ if base_llm_source == 'hf':
+     from llama_index.core import PromptTemplate
+
+     query_wrapper_prompt = PromptTemplate(
+         "Below is an instruction that describes a task. "
+         "Make sure the user's question and the retrieved context mention the "
+         "same stock symbol; if they do not, give no answer to the user. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{query_str}\n\n### Response:"
+     )
+
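+ # For illustration: llama-index places the fully formatted retrieval prompt
+ # (retrieved context plus the user question) into the {query_str} slot, so the
+ # model sees roughly:
+ #   Below is an instruction that describes a task. ...
+ #   ### Instruction:
+ #   <retrieved context + user question>
+ #   ### Response:
+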
+ if base_llm_source == 'hf':
+     llm = HuggingFaceLLM(
+         context_window=2048,
+         max_new_tokens=512,
+         generate_kwargs={"temperature": 0.1, "do_sample": False},
+         query_wrapper_prompt=query_wrapper_prompt,
+         tokenizer_name=base_llm_model,
+         model_name=base_llm_model,
+         device_map="auto",
+         tokenizer_kwargs={"max_length": 2048},
+         # float16 halves weight memory on CUDA
+         model_kwargs={"torch_dtype": torch.float16}
+     )
+
+ Settings.chunk_size = 512
+ Settings.llm = llm
+
+ """#### Load documents, build the VectorStoreIndex"""
+
+
+ def download_and_extract_chroma_db(url, destination):
+     """Download and extract ChromaDB from Hugging Face Datasets."""
+     # Create the destination folder if it doesn't exist
+     if not os.path.exists(destination):
+         os.makedirs(destination)
+     else:
+         # If the folder exists, clear it to ensure a fresh extract
+         print("Destination folder exists. Removing it...")
+         for root, dirs, files in os.walk(destination, topdown=False):
+             for file in files:
+                 os.remove(os.path.join(root, file))
+             for dir in dirs:
+                 os.rmdir(os.path.join(root, dir))
+         print("Destination folder cleared.")
+
+     db_zip_path = os.path.join(destination, "chroma_db.zip")
+     if not os.path.exists(db_zip_path):
+         # Download the ChromaDB zip file
+         print("Downloading ChromaDB from Hugging Face Datasets...")
+         headers = {
+             "Authorization": f"Bearer {hf_api_key}"
+         }
+         response = requests.get(url, headers=headers, stream=True)
+         response.raise_for_status()
+         with open(db_zip_path, "wb") as f:
+             for chunk in response.iter_content(chunk_size=8192):
+                 f.write(chunk)
+         print("Download completed.")
+     else:
+         print("Zip file already exists, skipping download.")
+
+     # Extract the zip file
+     print("Extracting ChromaDB...")
+     with zipfile.ZipFile(db_zip_path, 'r') as zip_ref:
+         zip_ref.extractall(destination)
+     print("Extraction completed. Zip file retained.")
+
+
+ # URL to the dataset hosted on Hugging Face
+ chroma_db_url = "https://huggingface.co/datasets/iamboolean/set50-db/resolve/main/chroma_db.zip"
+
+ # Local destination for the ChromaDB
+ chroma_db_path_extract = "./"  # change this to your desired path
+
+ # Download and extract the ChromaDB
+ download_and_extract_chroma_db(chroma_db_url, chroma_db_path_extract)
+
+ # Define the ChromaDB client (persistent mode)
+ db = chromadb.PersistentClient(path=chroma_db_path)
+ print(f"db path: {chroma_db_path}")
+ chroma_collection = db.get_or_create_collection(db_collection)
+ print(f"db collection: {db_collection}")
+
+
+ # Set up the ChromaVectorStore and storage context
+ vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+ storage_context = StorageContext.from_defaults(vector_store=vector_store)
+
+ document_count = chroma_collection.count()
+ print(f"Total documents in the collection: {document_count}")
+
+ # The index reads precomputed embeddings from the vector store;
+ # Settings.embed_model is used to embed incoming queries.
+ index = VectorStoreIndex.from_vector_store(
+     vector_store=vector_store,
+ )
+
+ """#### Query Index"""
+
+
+ rerank = SentenceTransformerRerank(
+     model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=10
+ )
+ node_postprocessors = []
+ # node_postprocessors.append(SimilarityPostprocessor(similarity_cutoff=0.6))
+
+ if retrieval_strategy == 'sentence_window':
+     node_postprocessors.append(
+         MetadataReplacementPostProcessor(target_metadata_key="window"))
+
+ if enable_rerank:
+     node_postprocessors.append(rerank)
+
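+ # Postprocessors run in list order: the sentence-window replacement is applied
+ # first, so the reranker scores the full window text rather than the single
+ # matched sentence.
+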
+ query_engine = index.as_query_engine(
+     similarity_top_k=base_similarity_top_k,
+     # the target key defaults to `window` to match the node_parser's default
+     node_postprocessors=node_postprocessors,
+ )
+
+
+ def metadata_formatter(metadata):
+     # File names are assumed to follow the pattern <SYMBOL>-<YEAR>.<ext>
+     company_symbol = metadata['file_name'].split('-')[0]  # split at '-' and take the first part
+     year = metadata['file_name'].split('-')[1].split('.')[0]  # split at '-' then '.' to get the year
+     page_number = metadata['page_label']
+
+     return f"Company File: {company_symbol}, Year: {year}, Page Number: {page_number}"
+
+
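+ # Illustrative example (hypothetical file name, assuming the pattern above):
+ # metadata_formatter({'file_name': 'PTTOR-2022.pdf', 'page_label': '12'})
+ # -> "Company File: PTTOR, Year: 2022, Page Number: 12"
+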
+ def query_journal(question):
+     response = query_engine.query(question)  # Query the index
+     matched_nodes = response.source_nodes  # Extract matched nodes
+
+     # Prepare the matched nodes' details (one metadata line per source node)
+     retrieved_context = "\n".join([
+         # f"Node ID: {node.node_id}\n"
+         # f"Matched Content: {node.node.text}\n"
+         f"Metadata: {metadata_formatter(node.node.metadata) if node.node.metadata else 'None'}"
+         for node in matched_nodes
+     ])
+
+     generated_answer = str(response)
+
+     # Return both the retrieved context and the generated answer
+     return retrieved_context, generated_answer
+
+
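+ # Quick smoke test (illustrative; bypasses the Gradio UI):
+ # context, answer = query_journal("What is the revenue of PTTOR in 2022?")
+ # print(answer)
+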
+ # Define the Gradio interface
+ with gr.Blocks() as app:
+     # Title
+     gr.Markdown(
+         """
+         <div style="text-align: center;">
+         <h1>SET50RAG: Retrieval-Augmented Generation for Thai Public Companies Question Answering</h1>
+         </div>
+         """
+     )
+
+     # Description
+     gr.Markdown(
+         """
+         The **SET50RAG** tool provides an interactive way to analyze and extract insights from **243 annual reports** of Thai public companies spanning **5 years**.
+         It combines **Retrieval-Augmented Generation** components, a **GTE-Large embedding model**, **sentence-window retrieval with reranking**, and a **large language model (Mistral-7B)**, to retrieve and answer complex financial queries.
+         This scalable and cost-effective approach reduces reliance on the LLM's parametric knowledge, keeping responses contextually accurate and relevant.
+         """
+     )
+
+     # How to Use Section
+     gr.Markdown(
+         """
+         ### How to Use
+         1. Type your question in the box or select an example question below.
+         2. Click **Submit** to retrieve the context and get an AI-generated answer.
+         3. Review the retrieved context and the generated answer to gain insights.
+         ---
+         """
+     )
+
+     # Example Questions Section
+     gr.Markdown(
+         """
+         ### Example Questions
+         - What is the revenue of PTTOR in 2022?
+         - What is the effect of COVID-19 on BDMS? Show it as a timeline from 2019 to 2023.
+         - How does CPALL plan for electric vehicles?
+         """
+     )
+
+     # Interactive Section (RAG Box)
+     with gr.Row():
+         with gr.Column():
+             user_question = gr.Textbox(
+                 label="Ask a Question",
+                 placeholder="Type your question here, e.g., 'What is the revenue of PTTOR in 2022?'",
+             )
+             example_question_button = gr.Button("Use Example Question")
+         with gr.Column():
+             generated_answer = gr.Textbox(
+                 label="Generated Answer",
+                 placeholder="The AI-generated answer will appear here.",
+                 interactive=False,
+             )
+             retrieved_context = gr.Textbox(
+                 label="Retrieved Context",
+                 placeholder="Relevant context will appear here.",
+                 interactive=False,
+             )
+
+     # Button for user interaction
+     submit_button = gr.Button("Submit")
+
+     # Example question logic
+     def use_example_question():
+         return "What is the revenue of PTTOR in 2022?"
+
+     example_question_button.click(
+         use_example_question, inputs=[], outputs=[user_question]
+     )
+
+     # Interaction logic for submitting user queries
+     submit_button.click(
+         query_journal, inputs=[user_question], outputs=[retrieved_context, generated_answer]
+     )
+
+     # Footer
+     gr.Markdown(
+         """
+         ---
+         ### Limitations and Bias
+         - Optimized for Thai financial reports from SET50 companies. Results may vary for other domains.
+         - Retrieval quality and answer accuracy depend on the underlying data and embedding models.
+         """
+     )
+
+ # Launch the app, binding to all interfaces so it is reachable inside the container
+ app.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ ragas==0.1.22
+ gradio==4.44.1
+ llama-index
+ llama-index-llms-huggingface
+ llama-index-embeddings-huggingface
+ llama-index-llms-cohere
+ llama-index-embeddings-instructor
+ datasets
+ transformers
+ chromadb
+ llama-index-vector-stores-chroma
+ sentencepiece