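"""SET50RAG Gradio Space.

Answers questions about SET50 Thai public companies with retrieval-augmented
generation over a persisted ChromaDB index (GTE-Large embeddings,
sentence-window retrieval with reranking, and a Mistral-7B LLM)."""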
import logging
import os
import sys
import zipfile

import chromadb
import gradio as gr
import requests
import torch
from llama_index.core import Settings, StorageContext, VectorStoreIndex
from llama_index.core.postprocessor import (
    MetadataReplacementPostProcessor,
    SentenceTransformerRerank,
    SimilarityPostprocessor,
)
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.vector_stores.chroma import ChromaVectorStore
# Retrieval configuration
enable_rerank = True
retrieval_strategy = "sentence_window"  # sentence_window, naive, recursive_retrieval
base_embedding_source = "hf"  # local, openai, hf
# intfloat/multilingual-e5-small local:BAAI/bge-small-en-v1.5 text-embedding-3-small nvidia/NV-Embed-v2 Alibaba-NLP/gte-large-en-v1.5
base_embedding_model = "Alibaba-NLP/gte-large-en-v1.5"
# meta-llama/Llama-3.1-8B meta-llama/Llama-3.2-3B-Instruct meta-llama/Llama-2-7b-chat-hf google/gemma-2-9b CohereForAI/c4ai-command-r-plus CohereForAI/aya-23-8B
base_llm_model = "mistralai/Mistral-7B-Instruct-v0.3"
# AdaptLLM/finance-chat
base_llm_source = "hf"  # cohere, hf, anthropic
base_similarity_top_k = 20

# ChromaDB configuration
env_extension = "_large"  # _large _dev_window _large_window
db_collection = f"gte{env_extension}"  # intfloat gte
read_db = True
active_chroma = True
root_path = "."
chroma_db_path = f"{root_path}/chroma_db"  # ./chroma_db
# ./processed_files.json
processed_files_log = f"{root_path}/processed_files{env_extension}.json"

# Validate hyperparameters
if retrieval_strategy not in ["sentence_window", "naive"]:  # recursive_retrieval
    raise Exception(f"{retrieval_strategy} retrieval_strategy is not supported")
os.environ["OPENAI_API_KEY"] = 'sk-xxxxxxxxxx' | |
hf_api_key = os.getenv("HF_API_KEY") | |
logging.basicConfig(stream=sys.stdout, level=logging.INFO) | |
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout)) | |
torch.cuda.empty_cache() | |
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' | |
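# The expandable_segments setting above lets the CUDA caching allocator grow its
# segments in place, reducing fragmentation when loading large models.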
print(f"loading embedding ..{base_embedding_model}") | |
if base_embedding_source == 'hf': | |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
Settings.embed_model = HuggingFaceEmbedding( | |
model_name=base_embedding_model, trust_remote_code=True) # , | |
else: | |
raise Exception("embedding model is invalid") | |
# Set up the query wrapper prompt
if base_llm_source == 'hf':
    from llama_index.core import PromptTemplate
    # This wraps the default prompts that are internal to llama-index
    # (prompt format adapted from https://huggingface.co/Writer/camel-5b-hf)
    query_wrapper_prompt = PromptTemplate(
        "Below is an instruction that describes a task. "
        "Make sure the user's question and the retrieved context mention the same stock symbol; "
        "if they do not, give no answer to the user. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{query_str}\n\n### Response:"
    )
if base_llm_source == 'hf':
    llm = HuggingFaceLLM(
        context_window=2048,
        max_new_tokens=512,
        generate_kwargs={"temperature": 0.1, "do_sample": False},
        query_wrapper_prompt=query_wrapper_prompt,
        tokenizer_name=base_llm_model,
        model_name=base_llm_model,
        device_map="auto",
        tokenizer_kwargs={"max_length": 2048},
        # load weights in float16 to reduce GPU memory usage
        model_kwargs={"torch_dtype": torch.float16}
    )
Settings.chunk_size = 512
Settings.llm = llm
"""#### Load documents, build the VectorStoreIndex""" | |
def download_and_extract_chroma_db(url, destination): | |
"""Download and extract ChromaDB from Hugging Face Datasets.""" | |
# Create destination folder if it doesn't exist | |
if not os.path.exists(destination): | |
os.makedirs(destination) | |
else: | |
# If the folder exists, remove it to ensure a fresh extract | |
print("Destination folder exists. Removing it...") | |
for root, dirs, files in os.walk(destination, topdown=False): | |
for file in files: | |
os.remove(os.path.join(root, file)) | |
for dir in dirs: | |
os.rmdir(os.path.join(root, dir)) | |
print("Destination folder cleared.") | |
db_zip_path = os.path.join(destination, "chroma_db.zip") | |
if not os.path.exists(db_zip_path): | |
# Download the ChromaDB zip file | |
print("Downloading ChromaDB from Hugging Face Datasets...") | |
headers = { | |
"Authorization": f"Bearer {hf_api_key}" | |
} | |
response = requests.get(url, headers=headers, stream=True) | |
response.raise_for_status() | |
with open(db_zip_path, "wb") as f: | |
for chunk in response.iter_content(chunk_size=8192): | |
f.write(chunk) | |
print("Download completed.") | |
else: | |
print("Zip file already exists, skipping download.") | |
# Extract the zip file | |
print("Extracting ChromaDB...") | |
with zipfile.ZipFile(db_zip_path, 'r') as zip_ref: | |
zip_ref.extractall(destination) | |
print("Extraction completed. Zip file retained.") | |
# URL to the dataset hosted on Hugging Face
chroma_db_url = "https://huggingface.co/datasets/iamboolean/set50-db/resolve/main/chroma_db.zip"
# Local destination for the ChromaDB
chroma_db_path_extract = "./"  # You can change this to your desired path

# Download and extract the ChromaDB
download_and_extract_chroma_db(chroma_db_url, chroma_db_path_extract)

# Define ChromaDB client (persistent mode)
db = chromadb.PersistentClient(path=chroma_db_path)
print(f"db path:{chroma_db_path}")
chroma_collection = db.get_or_create_collection(db_collection)
print(f"db collection:{db_collection}")

# Set up ChromaVectorStore and embeddings
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

document_count = chroma_collection.count()
print(f"Total documents in the collection: {document_count}")

index = VectorStoreIndex.from_vector_store(
    vector_store=vector_store,
    # embed_model=embed_model,
)
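# from_vector_store rebuilds the index on top of the persisted Chroma collection;
# query-time embeddings come from the HuggingFace model set in Settings.embed_model above.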
"""#### Query Index""" | |
rerank = SentenceTransformerRerank( | |
model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=10 | |
) | |
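# The cross-encoder re-scores the retrieved nodes and keeps only the 10 most relevant.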
node_postprocessors = []
# node_postprocessors.append(SimilarityPostprocessor(similarity_cutoff=0.6))
if retrieval_strategy == 'sentence_window':
    # Replace each retrieved sentence with the surrounding window stored in its metadata;
    # the target key defaults to `window` to match the node_parser's default.
    node_postprocessors.append(
        MetadataReplacementPostProcessor(target_metadata_key="window"))
if enable_rerank:
    node_postprocessors.append(rerank)

query_engine = index.as_query_engine(
    similarity_top_k=base_similarity_top_k,
    node_postprocessors=node_postprocessors,
)
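# Optional sanity check (hypothetical, for local debugging only; not run in the Space):
# print(query_engine.query("What is the revenue of PTTOR in 2022?"))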
def metadata_formatter(metadata):
    """Format source metadata, assuming file names like '<SYMBOL>-<YEAR>.pdf'."""
    company_symbol = metadata['file_name'].split('-')[0]  # part before '-' is the stock symbol
    year = metadata['file_name'].split('-')[1].split('.')[0]  # part after '-' (before the extension) is the year
    page_number = metadata['page_label']
    return f"Company File: {company_symbol}, Year: {year}, Page Number: {page_number}"
def query_journal(question):
    response = query_engine.query(question)  # Query the index
    matched_nodes = response.source_nodes  # Extract matched nodes
    # Prepare the matched nodes details
    retrieved_context = "\n".join([
        # f"Node ID: {node.node_id}\n"
        # f"Matched Content: {node.node.text}\n"
        # f"Metadata: {node.node.metadata if node.node.metadata else 'None'}"
        f"Metadata: {metadata_formatter(node.node.metadata) if node.node.metadata else 'None'}"
        for node in matched_nodes
    ])
    generated_answer = str(response)
    # Return both the retrieved context and the generated answer
    return retrieved_context, generated_answer
# Define the Gradio interface
with gr.Blocks() as app:
    # Title
    gr.Markdown(
        """
        <div style="text-align: center;">
        <h1>SET50RAG: Retrieval-Augmented Generation for Thai Public Companies Question Answering</h1>
        </div>
        """
    )
    # Description
    gr.Markdown(
        """
        The **SET50RAG** tool provides an interactive way to analyze and extract insights from **243 annual reports** of Thai public companies spanning **5 years**.
        It combines **Retrieval-Augmented Generation** with a **GTE-Large embedding model**, **sentence-window retrieval with reranking**, and a **Large Language Model (LLM)**, **Mistral-7B**, to retrieve and answer complex financial queries.
        This scalable and cost-effective approach reduces reliance on the model's parametric knowledge and grounds answers in the retrieved reports.
        """
    )
    # How to Use Section
    gr.Markdown(
        """
        ### How to Use
        1. Type your question in the box or select an example question below.
        2. Click **Submit** to retrieve the context and get an AI-generated answer.
        3. Review the retrieved context and the generated answer to gain insights.
        ---
        """
    )
    # Example Questions Section
    gr.Markdown(
        """
        ### Example Questions
        - What is the revenue of PTTOR in 2022?
        - What was the effect of COVID-19 on BDMS? Show it as a timeline from 2019 to 2023.
        - How does CPALL plan for electric vehicles?
        """
    )
    # Interactive Section (RAG Box)
    with gr.Row():
        with gr.Column():
            user_question = gr.Textbox(
                label="Ask a Question",
                placeholder="Type your question here, e.g., 'What is the revenue of PTTOR in 2022?'",
            )
            example_question_button = gr.Button("Use Example Question")
        with gr.Column():
            generated_answer = gr.Textbox(
                label="Generated Answer",
                placeholder="The AI-generated answer will appear here.",
                interactive=False,
            )
            retrieved_context = gr.Textbox(
                label="Retrieved Context",
                placeholder="Relevant context will appear here.",
                interactive=False,
            )

    # Button for user interaction
    submit_button = gr.Button("Submit")

    # Example question logic
    def use_example_question():
        return "What is the revenue of PTTOR in 2022?"

    example_question_button.click(
        use_example_question, inputs=[], outputs=[user_question]
    )

    # Interaction logic for submitting user queries
    submit_button.click(
        query_journal, inputs=[user_question], outputs=[
            retrieved_context, generated_answer]
    )
    # Footer
    gr.Markdown(
        """
        ---
        ### Limitations and Bias:
        - Optimized for Thai financial reports from SET50 companies. Results may vary for other domains.
        - Retrieval and accuracy depend on data quality and embedding models.
        """
    )

# Launch the app
# app.launch()
app.launch(server_name="0.0.0.0")  # , server_port=7860