|
from modules.vectorstore.vectorstore import VectorStore |
|
from modules.dataloader.helpers import get_urls_from_file |
|
from modules.dataloader.webpage_crawler import WebpageCrawler |
|
from modules.dataloader.data_loader import DataLoader |
|
from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader |
|
import logging |
|
import os |
|
import time |
|
import asyncio |
|
|
|
|
|
class VectorStoreManager:
    """Coordinate creation, loading and download of the configured vector store.

    The manager wraps a ``VectorStore`` built from ``config`` and drives it
    with data collected from the configured data directory and URL file.

    Config keys read by this class:
        * ``config["log_dir"]`` (only when no logger is passed in)
        * ``config["vectorstore"]["data_path"]``
        * ``config["vectorstore"]["url_file_path"]``
        * ``config["vectorstore"]["expand_urls"]``
        * ``config["vectorstore"]["db_option"]``
    """

    # ``db_option`` values for which an embedding model is loaded before the
    # store is created or loaded; for any other value ``None`` is passed on.
    _EMBEDDING_DB_OPTIONS = ("FAISS", "Chroma", "RAPTOR")

    def __init__(self, config, logger=None):
        """Build the crawler and vector store for ``config``.

        Args:
            config: Configuration mapping (see the class docstring for the
                keys that are read).
            logger: Optional logger. When falsy, one is created by
                ``_setup_logging``.
        """
        self.config = config
        self.document_names = None
        # Set by initialize_database() / load_database(); defined here so the
        # attributes always exist on an instance.
        self.embedding_model = None
        self.loaded_vector_db = None

        self.logger = logger or self._setup_logging()
        self.webpage_crawler = WebpageCrawler()
        self.vector_db = VectorStore(self.config)

        self.logger.info("VectorDB instance instantiated")

    def _setup_logging(self):
        """Return the module logger, attaching handlers if none are present.

        When ``logger.hasHandlers()`` is false (neither this logger nor its
        ancestors have a handler), the level is set to INFO and two INFO-level
        handlers are added: a console stream handler and a file handler that
        writes ``vector_db.log`` inside ``config["log_dir"]``. The directory is
        created if missing and the file is opened with ``mode="w"``, so an
        existing log file is overwritten.
        """
        logger = logging.getLogger(__name__)
        if not logger.hasHandlers():
            logger.setLevel(logging.INFO)
            formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

            # Console output.
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)
            console_handler.setFormatter(formatter)
            logger.addHandler(console_handler)

            # Make sure the log directory exists before opening the file.
            log_directory = self.config["log_dir"]
            os.makedirs(log_directory, exist_ok=True)

            # File output.
            log_file_path = os.path.join(log_directory, "vector_db.log")
            file_handler = logging.FileHandler(log_file_path, mode="w")
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

        return logger

    def load_files(self):
        """Collect local file paths and URLs from the configured locations.

        Returns:
            A ``(files, urls)`` tuple. ``files`` holds every entry of
            ``config["vectorstore"]["data_path"]`` joined with that directory.
            ``urls`` is the result of ``get_urls_from_file`` on
            ``config["vectorstore"]["url_file_path"]``; when
            ``config["vectorstore"]["expand_urls"]`` is truthy it is replaced
            by the concatenated results of
            ``webpage_crawler.get_all_pages(url, url)`` for each URL.
        """
        data_path = self.config["vectorstore"]["data_path"]
        files = [os.path.join(data_path, name) for name in os.listdir(data_path)]

        urls = get_urls_from_file(self.config["vectorstore"]["url_file_path"])
        if self.config["vectorstore"]["expand_urls"]:
            all_urls = []
            for url in urls:
                # The crawler call is a coroutine; run it to completion on the
                # current event loop before moving to the next URL.
                loop = asyncio.get_event_loop()
                all_urls.extend(
                    loop.run_until_complete(
                        self.webpage_crawler.get_all_pages(url, url)
                    )
                )
            urls = all_urls
        return files, urls

    def create_embedding_model(self):
        """Load and return the embedding model via ``EmbeddingModelLoader``."""
        self.logger.info("Creating embedding function")
        embedding_model_loader = EmbeddingModelLoader(self.config)
        return embedding_model_loader.load_embedding_model()

    def _select_embedding_model(self):
        """Return an embedding model if the configured ``db_option`` uses one, else ``None``."""
        if self.config["vectorstore"]["db_option"] in self._EMBEDDING_DB_OPTIONS:
            return self.create_embedding_model()
        return None

    def initialize_database(
        self,
        document_chunks: list,
        document_names: list,
        documents: list,
        document_metadata: list,
    ):
        """Select the embedding model and create the database from loaded data.

        All four lists are forwarded unchanged to
        ``self.vector_db._create_database`` together with
        ``self.embedding_model`` (``None`` when the configured ``db_option``
        is not one of ``_EMBEDDING_DB_OPTIONS``).
        """
        self.embedding_model = self._select_embedding_model()

        self.logger.info("Initializing vector_db")
        self.logger.info(
            f"\tUsing {self.config['vectorstore']['db_option']} as db_option"
        )
        self.vector_db._create_database(
            document_chunks,
            document_names,
            documents,
            document_metadata,
            self.embedding_model,
        )

    def create_database(self):
        """Load the source data, chunk it and build the vector database.

        Logs the number of files, webpages and chunks, the metadata keys of
        the first metadata entry (if any), and the total elapsed time.
        """
        start_time = time.time()
        data_loader = DataLoader(self.config, self.logger)
        self.logger.info("Loading data")

        files, urls = self.load_files()
        # NOTE(review): ``files`` from load_files() is replaced here by the
        # first value returned from clean_url_list(urls), so the local data
        # directory listing is not passed to the data loader. This mirrors the
        # original behavior -- confirm that it is intended.
        files, webpages = self.webpage_crawler.clean_url_list(urls)
        self.logger.info(f"Number of files: {len(files)}")
        self.logger.info(f"Number of webpages: {len(webpages)}")

        # The URL list file itself must not be ingested as a document.
        # (Fix: the removal previously looked up the misspelt key
        # 'vectorstores' and raised KeyError whenever this branch ran.)
        url_file_path = f"{self.config['vectorstore']['url_file_path']}"
        if url_file_path in files:
            files.remove(url_file_path)

        (
            document_chunks,
            document_names,
            documents,
            document_metadata,
        ) = data_loader.get_chunks(files, webpages)

        num_documents = len(document_chunks)
        self.logger.info(f"Number of documents in the DB: {num_documents}")
        metadata_keys = list(document_metadata[0].keys()) if document_metadata else []
        self.logger.info(f"Metadata keys: {metadata_keys}")
        self.logger.info("Completed loading data")

        self.initialize_database(
            document_chunks, document_names, documents, document_metadata
        )

        end_time = time.time()
        self.logger.info("Created database")
        self.logger.info(
            f"Time taken to create database: {end_time - start_time} seconds"
        )

    def load_database(self):
        """Load an existing database and return it.

        Returns:
            The object returned by ``self.vector_db._load_database``; it is
            also stored on ``self.loaded_vector_db``.

        Raises:
            ValueError: If loading fails for any reason. The original
                exception is attached as ``__cause__``.
        """
        start_time = time.time()
        self.embedding_model = self._select_embedding_model()
        try:
            self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
        except Exception as e:
            # Chain explicitly so the root cause stays visible in tracebacks.
            raise ValueError(
                f"Error loading database, check if it exists. if not run python -m modules.vectorstore.store_manager / Resteart the HF Space: {e}"
            ) from e

        end_time = time.time()
        self.logger.info(
            f"Time taken to load database {self.config['vectorstore']['db_option']}: {end_time - start_time} seconds"
        )
        self.logger.info("Loaded database")
        return self.loaded_vector_db

    def load_from_HF(self, HF_PATH):
        """Fetch the database via ``self.vector_db._load_from_HF(HF_PATH)`` and log the elapsed time."""
        start_time = time.time()
        self.vector_db._load_from_HF(HF_PATH)
        end_time = time.time()
        self.logger.info(
            f"Time taken to Download database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
        )
        self.logger.info("Downloaded database")

    def __len__(self):
        """Return ``len(self.vector_db)``."""
        return len(self.vector_db)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: merge the two YAML configs, obtain the
    # database (download from Hugging Face or build locally), then load it
    # back through a fresh manager instance.
    import yaml
    import argparse

    # Both config file paths are mandatory command-line arguments.
    parser = argparse.ArgumentParser(description="Load configuration files.")
    parser.add_argument(
        "--config_file", type=str, help="Path to the main config file", required=True
    )
    parser.add_argument(
        "--project_config_file",
        type=str,
        help="Path to the project config file",
        required=True,
    )
    args = parser.parse_args()

    # yaml.safe_load returns None for an empty document; fall back to an
    # empty mapping so the merge below cannot fail on a None value.
    with open(args.config_file, "r") as f:
        config = yaml.safe_load(f) or {}
    with open(args.project_config_file, "r") as f:
        project_config = yaml.safe_load(f) or {}

    # Top-level keys of the project config override those of the main config.
    config.update(project_config)
    print(config)
    print(f"Trying to create database with config: {config}")
    vector_db = VectorStoreManager(config)

    db_option = config["vectorstore"]["db_option"]
    if config["vectorstore"]["load_from_HF"]:
        hf_paths = config["retriever"]["retriever_hf_paths"]
        if db_option not in hf_paths:
            raise ValueError(f"HF_PATH not available for {db_option}")
        vector_db.load_from_HF(HF_PATH=hf_paths[db_option])
    else:
        vector_db.create_database()
    # NOTE(review): the source's indentation was lost; this message is emitted
    # after either branch above -- confirm it was not meant for the else only.
    print("Created database")

    print("Trying to load the database")
    vector_db = VectorStoreManager(config)
    vector_db.load_database()
    print("Loaded database")

    print(f"View the logs at {config['log_dir']}/vector_db.log")
|
|