import gradio as gr
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from llama_index.core import StorageContext
import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore
import zipfile
import requests
import torch
from llama_index.core import Settings
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
import sys
import logging
import os

enable_rerank = True
# sentence_window, naive, recursive_retrieval
retrieval_strategy = "sentence_window"
base_embedding_source = "hf"  # local, openai, hf
# intfloat/multilingual-e5-small local:BAAI/bge-small-en-v1.5 text-embedding-3-small nvidia/NV-Embed-v2 Alibaba-NLP/gte-large-en-v1.5
base_embedding_model = "Alibaba-NLP/gte-large-en-v1.5"
# meta-llama/Llama-3.1-8B meta-llama/Llama-3.2-3B-Instruct meta-llama/Llama-2-7b-chat-hf google/gemma-2-9b CohereForAI/c4ai-command-r-plus CohereForAI/aya-23-8B
base_llm_model = "mistralai/Mistral-7B-Instruct-v0.3"  # AdaptLLM/finance-chat
base_llm_source = "hf"  # cohere, hf, anthropic
base_similarity_top_k = 20

# ChromaDB
env_extension = "_large"  # _large _dev_window _large_window
db_collection = f"gte{env_extension}"  # intfloat gte
read_db = True
active_chroma = True
root_path = "."
chroma_db_path = f"{root_path}/chroma_db"  # ./chroma_db
# ./processed_files.json
processed_files_log = f"{root_path}/processed_files{env_extension}.json"

# Check hyperparameters
if retrieval_strategy not in ["sentence_window", "naive"]:  # recursive_retrieval
    raise Exception(f"{retrieval_strategy} retrieval_strategy is not supported")

os.environ["OPENAI_API_KEY"] = 'sk-xxxxxxxxxx'
hf_api_key = os.getenv("HF_API_KEY")

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

torch.cuda.empty_cache()
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

print(f"loading embedding ..{base_embedding_model}")
if base_embedding_source == 'hf':
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
    Settings.embed_model = HuggingFaceEmbedding(
        model_name=base_embedding_model, trust_remote_code=True)
else:
    raise Exception("embedding model is invalid")

# Set up the instruction-style prompt wrapper for the LLM
if base_llm_source == 'hf':
    from llama_index.core import PromptTemplate

    # This will wrap the default prompts that are internal to llama-index
    # taken from https://huggingface.co/Writer/camel-5b-hf
    query_wrapper_prompt = PromptTemplate(
        "Below is an instruction that describes a task. "
        "You need to make sure that the user's question and the retrieved context "
        "mention the same stock symbol; if they do not, give no answer to the user. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{query_str}\n\n### Response:"
    )

if base_llm_source == 'hf':
    llm = HuggingFaceLLM(
        context_window=2048,
        max_new_tokens=512,  # 256
        generate_kwargs={"temperature": 0.1, "do_sample": False},  # 0.25
        query_wrapper_prompt=query_wrapper_prompt,
        tokenizer_name=base_llm_model,
        model_name=base_llm_model,
        device_map="auto",
        tokenizer_kwargs={"max_length": 2048},
        # float16 weights reduce memory usage when running on CUDA
        model_kwargs={"torch_dtype": torch.float16}
    )

Settings.chunk_size = 512
Settings.llm = llm

"""#### Load documents, build the VectorStoreIndex"""


def download_and_extract_chroma_db(url, destination):
    """Download and extract ChromaDB from Hugging Face Datasets."""
    # Create destination folder if it doesn't exist
    if not os.path.exists(destination):
        os.makedirs(destination)
    else:
        # If the folder exists, clear its contents to ensure a fresh extract
        print("Destination folder exists. Removing it...")
        for root, dirs, files in os.walk(destination, topdown=False):
            for file in files:
                os.remove(os.path.join(root, file))
            for dir in dirs:
                os.rmdir(os.path.join(root, dir))
        print("Destination folder cleared.")

    db_zip_path = os.path.join(destination, "chroma_db.zip")
    if not os.path.exists(db_zip_path):
        # Download the ChromaDB zip file
        print("Downloading ChromaDB from Hugging Face Datasets...")
        headers = {"Authorization": f"Bearer {hf_api_key}"}
        response = requests.get(url, headers=headers, stream=True)
        response.raise_for_status()
        with open(db_zip_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print("Download completed.")
    else:
        print("Zip file already exists, skipping download.")

    # Extract the zip file
    print("Extracting ChromaDB...")
    with zipfile.ZipFile(db_zip_path, 'r') as zip_ref:
        zip_ref.extractall(destination)
    print("Extraction completed. Zip file retained.")
# URL to your dataset hosted on Hugging Face
chroma_db_url = "https://huggingface.co/datasets/iamboolean/set50-db/resolve/main/chroma_db.zip"

# Local destination for the ChromaDB
chroma_db_path_extract = "./"  # You can change this to your desired path

# Download and extract the ChromaDB
download_and_extract_chroma_db(chroma_db_url, chroma_db_path_extract)

# Define ChromaDB client (persistent mode)
db = chromadb.PersistentClient(path=chroma_db_path)
print(f"db path:{chroma_db_path}")

chroma_collection = db.get_or_create_collection(db_collection)
print(f"db collection:{db_collection}")

# Set up ChromaVectorStore and embeddings
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

document_count = chroma_collection.count()
print(f"Total documents in the collection: {document_count}")

index = VectorStoreIndex.from_vector_store(
    vector_store=vector_store,
    # embed_model=embed_model,
)

"""#### Query Index"""

rerank = SentenceTransformerRerank(
    model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=10
)

node_postprocessors = []
# node_postprocessors.append(SimilarityPostprocessor(similarity_cutoff=0.6))
if retrieval_strategy == 'sentence_window':
    node_postprocessors.append(
        MetadataReplacementPostProcessor(target_metadata_key="window"))
if enable_rerank:
    node_postprocessors.append(rerank)

query_engine = index.as_query_engine(
    similarity_top_k=base_similarity_top_k,
    # the target key defaults to `window` to match the node_parser's default
    node_postprocessors=node_postprocessors,
)


def metadata_formatter(metadata):
    company_symbol = metadata['file_name'].split('-')[0]  # Split at '-' and take the first part
    year = metadata['file_name'].split('-')[1].split('.')[0]  # Split at '-' and then '.' to extract the year
    page_number = metadata['page_label']
    return f"Company File: {company_symbol}, Year: {year}, Page Number: {page_number}"


def query_journal(question):
    response = query_engine.query(question)  # Query the index
    matched_nodes = response.source_nodes  # Extract matched nodes

    # Prepare the matched nodes details
    retrieved_context = "\n".join([
        # f"Node ID: {node.node_id}\n"
        # f"Matched Content: {node.node.text}\n"
        # f"Metadata: {node.node.metadata if node.node.metadata else 'None'}"
        f"Metadata: {metadata_formatter(node.node.metadata) if node.node.metadata else 'None'}"
        for node in matched_nodes
    ])
    generated_answer = str(response)

    # Return both retrieved context and detailed matched nodes
    return retrieved_context, generated_answer
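
# A minimal sketch of exercising the pipeline without the Gradio UI, e.g. as a
# quick smoke test after the index loads. It only reuses query_journal and the
# example question shown in the app below; the variable names are illustrative.
# Uncomment to print the retrieved source metadata and the generated answer:
# retrieved, answer = query_journal("What is the revenue of PTTOR in 2022?")
# print(retrieved)
# print(answer)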

# Define the Gradio interface
with gr.Blocks() as app:
    # Title
    gr.Markdown(
        """
# SET50RAG: Retrieval-Augmented Generation for Thai Public Companies Question Answering
"""
    )

    # Description
    gr.Markdown(
        """
The **SET50RAG** tool provides an interactive way to analyze and extract insights from **243 annual reports** of Thai public companies spanning **5 years**. By leveraging advanced **Retrieval-Augmented Generation**, including the **GTE-Large embedding model**, **Sentence Window retrieval with Reranking**, and powerful **Large Language Models (LLMs)** like **Mistral-7B**, the system efficiently retrieves and answers complex financial queries. This scalable and cost-effective solution reduces reliance on parametric knowledge, ensuring contextually accurate and relevant responses.
"""
    )

    # How to Use Section
    gr.Markdown(
        """
### How to Use
1. Type your question in the box or select an example question below.
2. Click **Submit** to retrieve the context and get an AI-generated answer.
3. Review the retrieved context and the generated answer to gain insights.

---
"""
    )

    # Example Questions Section
    gr.Markdown(
        """
### Example Questions
- What is the revenue of PTTOR in 2022?
- What is the effect of COVID-19 on BDMS? Show it as a timeline from 2019 to 2023.
- How does CPALL plan for electric vehicles?
"""
    )

    # Interactive Section (RAG Box)
    with gr.Row():
        with gr.Column():
            user_question = gr.Textbox(
                label="Ask a Question",
                placeholder="Type your question here, e.g., 'What is the revenue of PTTOR in 2022?'",
            )
            example_question_button = gr.Button("Use Example Question")

        with gr.Column():
            generated_answer = gr.Textbox(
                label="Generated Answer",
                placeholder="The AI-generated answer will appear here.",
                interactive=False,
            )
            retrieved_context = gr.Textbox(
                label="Retrieved Context",
                placeholder="Relevant context will appear here.",
                interactive=False,
            )

    # Button for user interaction
    submit_button = gr.Button("Submit")

    # Example question logic
    def use_example_question():
        return "What is the revenue of PTTOR in 2022?"

    example_question_button.click(
        use_example_question, inputs=[], outputs=[user_question]
    )

    # Interaction logic for submitting user queries
    submit_button.click(
        query_journal,
        inputs=[user_question],
        outputs=[retrieved_context, generated_answer]
    )

    # Footer
    gr.Markdown(
        """
---
### Limitations and Bias
- Optimized for Thai financial reports from SET50 companies. Results may vary for other domains.
- Retrieval and accuracy depend on data quality and embedding models.
"""
    )

# Launch the app
# app.launch()
app.launch(server_name="0.0.0.0")  # , server_port=7860