Spaces:
Build error
Build error
from modules.vectorstore.vectorstore import VectorStore | |
from modules.dataloader.helpers import get_urls_from_file | |
from modules.dataloader.webpage_crawler import WebpageCrawler | |
from modules.dataloader.data_loader import DataLoader | |
from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader | |
import logging | |
import os | |
import time | |
import asyncio | |
class VectorStoreManager:
    """Create, load, and download the vector store described by ``config``.

    ``config`` is a nested mapping. Keys read directly in this class are
    ``log_dir`` and, under ``config["vectorstore"]``: ``data_path``,
    ``url_file_path``, ``expand_urls`` and ``db_option``. The whole mapping is
    also handed to ``VectorStore``, ``DataLoader`` and ``EmbeddingModelLoader``.
    """

    # db_option values for which an embedding model is built before creating
    # or loading the store; any other option gets ``None``.
    _EMBEDDING_DB_OPTIONS = ("FAISS", "Chroma", "RAPTOR")

    def __init__(self, config, logger=None):
        """Create the manager and its ``VectorStore``.

        Args:
            config: Nested configuration mapping (see the class docstring).
            logger: Logger to use. When falsy, ``_setup_logging`` provides
                this module's logger instead.
        """
        self.config = config
        self.document_names = None
        # Set up logging to both console and a file
        self.logger = logger or self._setup_logging()
        self.webpage_crawler = WebpageCrawler()
        self.vector_db = VectorStore(self.config)
        self.logger.info("VectorDB instance instantiated")

    def _setup_logging(self):
        """Return this module's logger, attaching handlers if none are visible.

        ``logger.hasHandlers()`` also looks at ancestor loggers (e.g. the root
        logger), so handlers are only attached when nothing up the chain has
        one. In that case the level is set to INFO, and two handlers are added:
        a console handler and a file handler writing
        ``<log_dir>/vector_db.log``. The file is truncated on open
        (``mode="w"``).
        """
        logger = logging.getLogger(__name__)
        if not logger.hasHandlers():
            logger.setLevel(logging.INFO)
            formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

            # Console handler
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)
            console_handler.setFormatter(formatter)
            logger.addHandler(console_handler)

            # File handler; create the log directory first if needed
            log_directory = self.config["log_dir"]
            os.makedirs(log_directory, exist_ok=True)
            log_file_path = os.path.join(log_directory, "vector_db.log")
            # Explicit encoding so non-ASCII log text (URLs, titles) does not
            # depend on the platform's locale encoding.
            file_handler = logging.FileHandler(log_file_path, mode="w", encoding="utf-8")
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        return logger

    def load_files(self):
        """Collect the local data files and the URLs to ingest.

        Returns:
            ``(files, urls)``:
            - ``files`` holds every entry of ``vectorstore.data_path``, joined
              with that directory.
            - ``urls`` is the list produced by ``get_urls_from_file`` for
              ``vectorstore.url_file_path``. When ``vectorstore.expand_urls``
              is truthy, each URL is replaced by the pages the crawler returns
              for it, in order.
        """
        data_path = self.config["vectorstore"]["data_path"]
        files = [os.path.join(data_path, name) for name in os.listdir(data_path)]
        urls = get_urls_from_file(self.config["vectorstore"]["url_file_path"])
        if self.config["vectorstore"]["expand_urls"]:
            # Crawl on a private event loop, created once for all URLs.
            # asyncio.get_event_loop() is deprecated (and raising on newer
            # Pythons) when no loop is running. A private loop also leaves the
            # thread's current-loop setting untouched for later callers.
            loop = asyncio.new_event_loop()
            try:
                urls = loop.run_until_complete(self._expand_urls(urls))
            finally:
                loop.run_until_complete(loop.shutdown_asyncgens())
                loop.close()
        return files, urls

    async def _expand_urls(self, urls):
        """Return the pages the crawler finds for each URL in ``urls``, in order."""
        expanded = []
        for url in urls:
            # Passing the URL itself as the base URL only collects its child
            # pages. To collect all pages, pass the site's base URL instead.
            expanded.extend(await self.webpage_crawler.get_all_pages(url, url))
        return expanded

    def create_embedding_model(self):
        """Load and return the embedding model configured in ``self.config``."""
        self.logger.info("Creating embedding function")
        embedding_model_loader = EmbeddingModelLoader(self.config)
        embedding_model = embedding_model_loader.load_embedding_model()
        return embedding_model

    def _get_embedding_model(self):
        """Return an embedding model for db options that need one, else ``None``."""
        if self.config["vectorstore"]["db_option"] in self._EMBEDDING_DB_OPTIONS:
            return self.create_embedding_model()
        return None

    def _remove_url_file(self, files):
        """Return ``files`` without entries that refer to ``vectorstore.url_file_path``.

        That file is the URL list read by ``load_files``, not a document to
        ingest. Paths are compared after ``os.path.normpath``, so
        ``./data/urls.txt`` and ``data/urls.txt`` count as the same file.
        """
        url_file = os.path.normpath(self.config["vectorstore"]["url_file_path"])
        return [path for path in files if os.path.normpath(path) != url_file]

    def initialize_database(
        self,
        document_chunks: list,
        document_names: list,
        documents: list,
        document_metadata: list,
    ):
        """Build the vector store from already-chunked documents.

        The four lists are forwarded unchanged to ``VectorStore._create_database``
        together with the embedding model. The model is ``None`` for db options
        outside ``_EMBEDDING_DB_OPTIONS`` and is kept on
        ``self.embedding_model``.
        """
        self.embedding_model = self._get_embedding_model()
        self.logger.info("Initializing vector_db")
        self.logger.info(
            "\tUsing {} as db_option".format(self.config["vectorstore"]["db_option"])
        )
        self.vector_db._create_database(
            document_chunks,
            document_names,
            documents,
            document_metadata,
            self.embedding_model,
        )

    def create_database(self):
        """Load every configured source, chunk it, and build the vector store.

        Local files from ``vectorstore.data_path`` are combined with the file
        entries ``clean_url_list`` separates out of the URL list. The URL list
        file itself is excluded. Progress and elapsed time are logged.
        """
        start_time = time.perf_counter()  # Start time for creating database
        data_loader = DataLoader(self.config, self.logger)
        self.logger.info("Loading data")
        local_files, urls = self.load_files()
        url_files, webpages = self.webpage_crawler.clean_url_list(urls)
        # Merge the local files with the URL-derived file entries rather than
        # replacing them, then drop the URL list file (it is a source list).
        files = self._remove_url_file([*local_files, *url_files])
        self.logger.info(f"Number of files: {len(files)}")
        self.logger.info(f"Number of webpages: {len(webpages)}")

        (
            document_chunks,
            document_names,
            documents,
            document_metadata,
        ) = data_loader.get_chunks(files, webpages)
        num_documents = len(document_chunks)
        self.logger.info(f"Number of documents in the DB: {num_documents}")
        metadata_keys = list(document_metadata[0].keys()) if document_metadata else []
        self.logger.info(f"Metadata keys: {metadata_keys}")
        self.logger.info("Completed loading data")

        self.initialize_database(
            document_chunks, document_names, documents, document_metadata
        )
        end_time = time.perf_counter()  # End time for creating database
        self.logger.info("Created database")
        self.logger.info(
            f"Time taken to create database: {end_time - start_time} seconds"
        )

    def load_database(self):
        """Load the existing vector store and return it.

        The loaded store is also kept on ``self.loaded_vector_db``.

        Raises:
            ValueError: If ``VectorStore._load_database`` raises. The original
                exception is chained as ``__cause__``.
        """
        start_time = time.perf_counter()  # Start time for loading database
        self.embedding_model = self._get_embedding_model()
        try:
            self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
        except Exception as e:
            raise ValueError(
                "Error loading database, check if it exists. if not run "
                "python -m modules.vectorstore.store_manager / Restart the HF Space: "
                f"{e}"
            ) from e
        end_time = time.perf_counter()  # End time for loading database
        self.logger.info(
            f"Time taken to load database {self.config['vectorstore']['db_option']}: {end_time - start_time} seconds"
        )
        self.logger.info("Loaded database")
        return self.loaded_vector_db

    def load_from_HF(self, HF_PATH):
        """Download the database from ``HF_PATH`` via ``VectorStore._load_from_HF`` and log the time taken."""
        start_time = time.perf_counter()  # Start time for downloading database
        self.vector_db._load_from_HF(HF_PATH)
        end_time = time.perf_counter()
        self.logger.info(
            f"Time taken to Download database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
        )
        self.logger.info("Downloaded database")

    def __len__(self):
        """Return ``len()`` of the underlying vector store."""
        return len(self.vector_db)
if __name__ == "__main__":
    import argparse

    import yaml

    # Parse the paths of the two config files
    parser = argparse.ArgumentParser(description="Load configuration files.")
    parser.add_argument(
        "--config_file", type=str, help="Path to the main config file", required=True
    )
    parser.add_argument(
        "--project_config_file",
        type=str,
        help="Path to the project config file",
        required=True,
    )
    args = parser.parse_args()

    # safe_load returns None for an empty file; treat that as an empty config.
    with open(args.config_file, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}
    with open(args.project_config_file, "r", encoding="utf-8") as f:
        project_config = yaml.safe_load(f) or {}

    # Combine the two configs. This is a shallow merge: a top-level key in the
    # project config replaces that whole section of the main config.
    config.update(project_config)
    print(f"Trying to create database with config: {config}")

    vector_db = VectorStoreManager(config)
    if config["vectorstore"]["load_from_HF"]:
        db_option = config["vectorstore"]["db_option"]
        hf_paths = config["retriever"]["retriever_hf_paths"]
        if db_option not in hf_paths:
            raise ValueError(f"HF_PATH not available for {db_option}")
        vector_db.load_from_HF(HF_PATH=hf_paths[db_option])
    else:
        vector_db.create_database()
        print("Created database")

    # Load through a fresh manager to check the stored database is usable.
    print("Trying to load the database")
    vector_db = VectorStoreManager(config)
    vector_db.load_database()
    print("Loaded database")

    print(f"View the logs at {os.path.join(config['log_dir'], 'vector_db.log')}")