from modules.chat.helpers import get_prompt
from modules.chat.chat_model_loader import ChatModelLoader
from modules.vectorstore.store_manager import VectorStoreManager
from modules.retriever.retriever import Retriever
from modules.chat.langchain.langchain_rag import (
    Langchain_RAG_V2,
    QuestionGenerator,
)


class LLMTutor:
    """
    Assemble a retrieval QA chain from a configuration dictionary.

    The instance holds the chat model, the loaded vector database and the
    QA / rephrase prompts, and wires them together in `qa_bot`.
    """

    def __init__(self, config, user, logger=None):
        """
        Initialize the LLMTutor class.

        Loads the chat model, the vector database and both prompts from
        `config`.

        Args:
            config (dict): Configuration dictionary.
            user (str): User identifier.
            logger (Logger, optional): Logger instance. Defaults to None.
        """
        self.config = config
        # load_llm() reads self.config, so it must run after the line above.
        self.llm = self.load_llm()
        self.user = user
        self.logger = logger
        self.vector_db = VectorStoreManager(
            config, logger=self.logger
        ).load_database()
        self.qa_prompt = get_prompt(config, "qa")
        self.rephrase_prompt = get_prompt(config, "rephrase")

        # TODO: creating and saving the database here when
        # config["vectorstore"]["embedd_files"] is set was removed for now;
        # decide whether that functionality is still needed.

    def update_llm(self, old_config, new_config):
        """
        Reload the components whose configuration changed.

        Only three dotted keys are acted upon:
        `llm_params.llm_loader` (reloads the LLM),
        `vectorstore.db_option` (reloads the vector database) and
        `llm_params.llm_style` (reloads the QA prompt).

        Args:
            old_config (dict): Configuration in effect before the change.
            new_config (dict): New configuration dictionary. It becomes
                `self.config`, so every reload below uses the new values.
        """
        changes = self.get_config_changes(old_config, new_config)

        # Adopt the new configuration before reloading anything: the
        # loaders below read self.config, so without this they would be
        # rebuilt from stale settings whenever the caller passes a config
        # object other than the one this instance already holds.
        self.config = new_config

        if "llm_params.llm_loader" in changes:
            self.llm = self.load_llm()

        if "vectorstore.db_option" in changes:
            self.vector_db = VectorStoreManager(
                self.config, logger=self.logger
            ).load_database()
            # TODO: creating and saving the database here when
            # config["vectorstore"]["embedd_files"] is set was removed for
            # now; decide whether that functionality is still needed.

        if "llm_params.llm_style" in changes:
            self.qa_prompt = get_prompt(self.config, "qa")

    def get_config_changes(self, old_config, new_config):
        """
        Get the changes between the old and new configuration.

        Nested dictionaries are compared recursively and reported under
        dotted key paths (for example `llm_params.llm_loader`).

        Args:
            old_config (dict): Old configuration dictionary.
            new_config (dict): New configuration dictionary.

        Returns:
            dict: Maps each changed key path to an `(old_value, new_value)`
                tuple. Keys present only in the old configuration are
                reported with `None` as the new value.

        Note:
            A missing old key is read as `None`, so a key that is added
            with the value `None` is not reported as a change.
        """
        changes = {}

        def compare_dicts(old, new, parent_key=""):
            for key in new:
                full_key = f"{parent_key}.{key}" if parent_key else key
                old_value = old.get(key)
                new_value = new[key]
                if isinstance(new_value, dict) and isinstance(old_value, dict):
                    # Both sides are dicts: descend instead of reporting
                    # the whole sub-dictionary as one change.
                    compare_dicts(old_value, new_value, full_key)
                elif old_value != new_value:
                    changes[full_key] = (old_value, new_value)

            # Include keys that are in old but not in new.
            for key in old:
                if key not in new:
                    full_key = f"{parent_key}.{key}" if parent_key else key
                    changes[full_key] = (old[key], None)

        compare_dicts(old_config, new_config)
        return changes

    def retrieval_qa_chain(
        self, llm, qa_prompt, rephrase_prompt, db, memory=None, callbacks=None
    ):
        """
        Create a Retrieval QA Chain.

        Stores the chain on `self.qa_chain` and a fresh `QuestionGenerator`
        on `self.question_generator`.

        Args:
            llm (LLM): The language model instance.
            qa_prompt (str): The QA prompt string.
            rephrase_prompt (str): The rephrase prompt string.
            db (VectorStore): The vector store instance.
            memory (Memory, optional): Memory instance. Defaults to None.
            callbacks (optional): Passed through to the chain unchanged.
                Defaults to None.

        Returns:
            Chain: The retrieval QA chain instance.

        Raises:
            ValueError: If `config["llm_params"]["llm_arch"]` is not
                "langchain".
        """
        llm_arch = self.config["llm_params"]["llm_arch"]
        # Reject an unsupported architecture before building the retriever.
        if llm_arch != "langchain":
            raise ValueError(f"Invalid LLM Architecture: {llm_arch}")

        retriever = Retriever(self.config)._return_retriever(db)
        self.qa_chain = Langchain_RAG_V2(
            llm=llm,
            memory=memory,
            retriever=retriever,
            qa_prompt=qa_prompt,
            rephrase_prompt=rephrase_prompt,
            config=self.config,
            callbacks=callbacks,
        )
        self.question_generator = QuestionGenerator()
        return self.qa_chain

    def load_llm(self):
        """
        Load the language model described by `self.config`.

        Returns:
            LLM: The loaded language model instance.
        """
        chat_model_loader = ChatModelLoader(self.config)
        return chat_model_loader.load_chat_model()

    def qa_bot(self, memory=None, callbacks=None):
        """
        Create a QA bot instance.

        Builds the chain from the LLM, prompts and vector database already
        held on this instance.

        Args:
            memory (Memory, optional): Memory instance. Defaults to None.
            callbacks (optional): Passed through to the chain unchanged.
                Defaults to None.

        Returns:
            Chain: The QA bot chain instance.

        Raises:
            ValueError: If the vector database contains no documents.
        """
        # Sanity check to see if there are any documents in the database.
        if len(self.vector_db) == 0:
            raise ValueError(
                "No documents in the database. Populate the database first."
            )

        return self.retrieval_qa_chain(
            self.llm,
            self.qa_prompt,
            self.rephrase_prompt,
            self.vector_db,
            memory,
            callbacks=callbacks,
        )