from search import SemanticSearch, GoogleSearch, Document from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import BitsAndBytesConfig from transformers.utils import is_flash_attn_2_available import yaml import torch import nltk def load_configs(config_file: str) -> dict: with open(config_file, "r") as f: configs = yaml.safe_load(f) return configs class RAGModel: def __init__(self, configs) -> None: self.configs = configs self.device = configs["model"]["device"] model_url = configs["model"]["genration_model"] # quantization_config = BitsAndBytesConfig( # load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16 # ) self.model = AutoModelForCausalLM.from_pretrained( model_url, torch_dtype=torch.float16, # quantization_config=quantization_config, low_cpu_mem_usage=False, attn_implementation="sdpa", ).to(self.device) self.tokenizer = AutoTokenizer.from_pretrained( model_url, ) def create_prompt(self, query, topk_items: list[str]): context = "\n-".join(c for c in topk_items) base_prompt = f"""You are an alternate to goole search. Your job is to answer the user query in as detailed manner as possible. you have access to the internet and other relevent data related to the user's question. Give time for yourself to read the context and user query and extract relevent data and then answer the query. make sure your answers is as detailed as posssbile. Do not return thinking process, just return the answer. Give the output structured as a Wikipedia article. Now use the following context items to answer the user query context: {context} user query : {query} """ dialog_template = [{"role": "user", "content": base_prompt}] prompt = self.tokenizer.apply_chat_template( conversation=dialog_template, tokenize=False, add_feneration_prompt=True ) return prompt def answer_query(self, query: str, topk_items: list[str]): prompt = self.create_prompt(query, topk_items) input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device) output = self.model.generate(**input_ids, temperature=0.7, max_new_tokens=512, do_sample=True) text = self.tokenizer.decode(output[0]) text = text.replace(prompt, "").replace("", "").replace("", "") return text if __name__ == "__main__": configs = load_configs(config_file="rag.configs.yml") query = "Explain F1 racing for a beginer" g = GoogleSearch(query) data = g.all_page_data d = Document(data, 512) doc_chunks = d.doc() s = SemanticSearch(doc_chunks, "all-mpnet-base-v2", "mps") topk, u = s.semantic_search(query=query, k=32) r = RAGModel(configs) output = r.answer_query(query=query, topk_items=topk) print(output)