import numpy as np
import torch
from transformers import AutoProcessor
from PIL import Image
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_queries
import google.generativeai as genai

# Location of the retrieval checkpoint; used for both the model and its processor.
RETRIEVAL_MODEL_PATH = "inumulaisk/isk_rag_repo/retrieval_model"
# Location of the precomputed document embeddings.
DOCUMENT_EMBEDDINGS_PATH = "document_embeddings/embeddings.npy"


# Load models and processor
def load_models():
    """Load the retrieval model, its processor, and the document corpus.

    Returns:
        A 4-tuple ``(retrieval_model, paligemma_processor,
        document_embeddings, document_images)``.  ``document_images`` is an
        empty list until the placeholder below is filled in; it must be
        index-aligned with ``document_embeddings`` because the retrieval step
        indexes it with the position of the best-scoring embedding.
    """
    retrieval_model = ColPali.from_pretrained(RETRIEVAL_MODEL_PATH)
    paligemma_processor = AutoProcessor.from_pretrained(RETRIEVAL_MODEL_PATH)
    # SECURITY: allow_pickle=True lets np.load execute arbitrary pickled code.
    # Only point this at embedding files you produced yourself.
    document_embeddings = np.load(DOCUMENT_EMBEDDINGS_PATH, allow_pickle=True)
    document_images = []  # Load your document images here
    return retrieval_model, paligemma_processor, document_embeddings, document_images


# Function to retrieve top document
def retrieve_top_document(query, document_embeddings, document_images,
                          retrieval_model, paligemma_processor):
    """Return the document image that scores highest for ``query``.

    Args:
        query: The text query to embed.
        document_embeddings: Precomputed document embeddings, passed straight
            to ``CustomEvaluator.evaluate``.
        document_images: Sequence of images, index-aligned with
            ``document_embeddings``.
        retrieval_model: The loaded ColPali model.
        paligemma_processor: The processor matching ``retrieval_model``.

    Returns:
        A tuple ``(best_image, best_index)``.

    Raises:
        IndexError: If the best-scoring index has no entry in
            ``document_images`` (for example, the list was never populated).
    """
    # process_queries needs an image argument even for a text-only query,
    # so a blank white canvas is supplied.
    placeholder_image = Image.new("RGB", (448, 448), (255, 255, 255))
    query_batch = process_queries(paligemma_processor, [query], placeholder_image)

    # Run on whatever device the model already lives on rather than a
    # hard-coded 'cuda': the model is never moved after loading, so forcing
    # the inputs onto the GPU would mismatch devices (or fail without a GPU).
    device = next(retrieval_model.parameters()).device
    query_batch = {key: value.to(device) for key, value in query_batch.items()}

    with torch.no_grad():
        query_embeddings_tensor = retrieval_model(**query_batch)
    # One tensor per query, moved to the CPU for scoring.
    query_embeddings = list(torch.unbind(query_embeddings_tensor.to("cpu")))

    evaluator = CustomEvaluator(is_multi_vector=True)
    similarity_scores = evaluator.evaluate(query_embeddings, document_embeddings)
    # There is exactly one query, so argmax over axis 1 yields a single index.
    best_index = int(similarity_scores.argmax(axis=1).item())

    if best_index >= len(document_images):
        raise IndexError(
            f"best document index {best_index} is out of range for "
            f"{len(document_images)} document image(s); populate "
            "document_images so it lines up with document_embeddings"
        )
    return document_images[best_index], best_index  # Return image and index


# Function to generate an answer
def answer_query(query, prompt, model_name="gemini-1.5-flash"):
    """Retrieve the best document for ``query`` and ask Gemini about it.

    Note that the models and embeddings are reloaded on every call.

    Args:
        query: Text used to pick the most relevant document image.
        prompt: Instruction sent to the generative model with that image.
        model_name: Name of the generative model to use.

    Returns:
        A tuple ``(answer_text, best_image, best_index)``.
    """
    retrieval_model, paligemma_processor, document_embeddings, document_images = load_models()
    best_image, best_index = retrieve_top_document(
        query, document_embeddings, document_images, retrieval_model, paligemma_processor
    )

    # Generate an answer using the retrieved document.  The SDK exposes
    # generate_content on a GenerativeModel instance, not on the module.
    # NOTE(review): no genai.configure(api_key=...) call is visible in this
    # file; presumably the key is supplied via the environment -- confirm.
    generative_model = genai.GenerativeModel(model_name)
    response = generative_model.generate_content([prompt, best_image])
    return response.text, best_image, best_index  # Return text, image, and index


if __name__ == "__main__":
    query = "Your query here"  # Replace with actual query
    prompt = "Your prompt here"  # Replace with actual prompt
    answer_text, answer_image, retrieved_index = answer_query(query, prompt)
    print("Answer:", answer_text)
    answer_image.show()  # Display the image