#!/usr/bin/env python3
"""Image search engine demo.

Embeds a folder of images with the multilingual SigLIP model, stores the
L2-normalized embeddings in a FAISS inner-product index (so inner product on
unit vectors equals cosine similarity), and serves a Gradio UI for
text->image and image->image search.
"""
import glob
import os
import sys

import faiss
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer

MODEL_NAME = "google/siglip-base-patch16-256-multilingual"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
processor = AutoProcessor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Dimensionality of the FAISS index; image and text embeddings are both
# searched against it, so they must have this size.
num_dimensions = model.vision_model.config.hidden_size  # 768
# Number of nearest neighbours requested per query.
num_k = 30

text_examples = [
    "Frog waiting on a rock",
    "Bird with open mouth",
    "Bridge and a ship",
    "Bike for two people",
    "Biene auf der Blume",
    "Hesap makinesi",
]


def _l2_normalize(features):
    """Return a CPU tensor of embeddings L2-normalized along the last axis, as NumPy."""
    return (features / features.norm(p=2, dim=-1, keepdim=True)).numpy()


@torch.no_grad()
def _embed_image(image):
    """Embed one PIL image with SigLIP and return its normalized embedding as NumPy."""
    inputs = processor(images=image.convert("RGB"), return_tensors="pt").to(device)
    return _l2_normalize(model.get_image_features(**inputs).to("cpu"))


def _valid_hits(distances, indices):
    """Yield (score, idx) pairs for the first query of a FAISS search result.

    FAISS fills result slots it cannot satisfy (k larger than the number of
    indexed vectors) with label -1; those are skipped so that -1 is never used
    as a Python list index (which would silently return the last image).
    """
    for dist, idx in zip(distances[0], indices[0]):
        if idx >= 0:
            yield dist, idx


def _load_result_image(idx):
    """Open the indexed image at position *idx* and force its pixels into memory."""
    found_image = Image.open(image_filenames[idx])
    found_image.load()
    return found_image


def preprocess_images(pathname="images/*", index_file="index.faiss",
                      image_filenames_file="image_filenames.txt"):
    """Embed every image matching *pathname*, build a FAISS index and save it.

    Args:
        pathname: glob pattern selecting the images to index.
        index_file: where the FAISS index is written.
        image_filenames_file: where the image paths are written, one per line,
            in the same order as the vectors in the index.

    Returns:
        (index, image_filenames)

    Exits the process if an image cannot be processed or nothing matches.
    """
    print("Preprocessing images...")
    # Inner Product (IP) on L2-normalized vectors == cosine similarity.
    index = faiss.IndexFlatIP(num_dimensions)
    image_filenames = []
    image_features = []
    # Sorted so the index order is reproducible across runs/filesystems.
    for image_filename in sorted(glob.glob(pathname)):
        try:
            with Image.open(image_filename) as image_raw:
                image_embedding_n = _embed_image(image_raw)
        except Exception as e:
            print(f"Error processing {image_filename}", file=sys.stderr)
            print(e, file=sys.stderr)
            sys.exit(1)
        # Appended only after success so filenames and vectors stay aligned.
        image_filenames.append(image_filename)
        image_features.append(image_embedding_n)

    if not image_features:
        sys.exit(f"No images found matching {pathname!r}")

    print("Indexing images...")
    index.add(np.concatenate(image_features, axis=0))
    print("Saving index...")
    faiss.write_index(index, index_file)
    with open(image_filenames_file, "w", encoding="utf-8") as f:
        f.writelines(name + "\n" for name in image_filenames)
    print("Preprocessing complete.")
    return index, image_filenames


def load_processed_images(index_file="index.faiss",
                          image_filenames_file="image_filenames.txt"):
    """Load a FAISS index and its filename list saved by preprocess_images.

    Returns:
        (index, image_filenames)

    Raises:
        ValueError: if the number of filenames does not match the number of
            vectors in the index (results would map to the wrong images).
    """
    print("Loading index...")
    index = faiss.read_index(index_file)
    with open(image_filenames_file, encoding="utf-8") as f:
        image_filenames = [line.rstrip("\n") for line in f if line.strip()]
    if index.ntotal != len(image_filenames):
        raise ValueError(
            f"{index_file} holds {index.ntotal} vectors but "
            f"{image_filenames_file} lists {len(image_filenames)} files"
        )
    return index, image_filenames


@torch.no_grad()
def search_using_text(text):
    """Return up to num_k (image, "NN.NN%") pairs most similar to *text*.

    The percentage is SigLIP's sigmoid match probability,
    sigmoid(cosine * exp(logit_scale) + logit_bias).
    """
    if not text or not text.strip():
        return []
    # SigLIP was trained with max_length padding; truncate over-long queries
    # instead of overflowing the text model's position embeddings.
    inputs = tokenizer(text, padding="max_length", truncation=True,
                       return_tensors="pt").to(device)
    text_features_n = _l2_normalize(model.get_text_features(**inputs).to("cpu"))
    D, I = index.search(text_features_n, num_k)

    scale = model.logit_scale.exp().item()
    bias = model.logit_bias.item()
    result = []
    for dist, idx in _valid_hits(D, I):
        score_probability = torch.sigmoid(torch.tensor(float(dist) * scale + bias)).item()
        result.append((_load_result_image(idx), "{:.2f}%".format(score_probability * 100)))
    return result


@torch.no_grad()
def search_using_image(image):
    """Return up to num_k indexed images most similar to *image* (a NumPy array from Gradio)."""
    if image is None:  # button clicked with nothing uploaded
        return []
    query = _embed_image(Image.fromarray(image))
    D, I = index.search(query, num_k)
    return [_load_result_image(idx) for _, idx in _valid_hits(D, I)]


if __name__ == "__main__":
    # Build the index on first run; later runs reuse the saved files.
    # (Delete them to force a rebuild.)
    if os.path.exists("index.faiss") and os.path.exists("image_filenames.txt"):
        index, image_filenames = load_processed_images()
    else:
        index, image_filenames = preprocess_images()

    with gr.Blocks() as demo:
        gr.Markdown("# Image Search Engine Demo")
        with gr.Row(equal_height=False):
            with gr.Column():
                gr.Markdown(
                    "This app is powered by "
                    "[SigLIP](https://huggingface.co/google/siglip-base-patch16-256-multilingual) "
                    "with multilingual support and "
                    "[GPR1200 Dataset](https://www.kaggle.com/datasets/mathurinache/gpr1200-dataset) "
                    "image contents. Enter your query in the text box or upload an image "
                    "to search for similar images."
                )
                with gr.Tab("Text-Image Search"):
                    text_input = gr.Textbox(label="Type a word or a sentence")
                    search_using_text_btn = gr.Button("Search with text", scale=0)
                    gr.Examples(
                        examples=text_examples,
                        inputs=[text_input],
                    )
                with gr.Tab("Image-Image Search"):
                    image_input = gr.Image()
                    search_using_image_btn = gr.Button("Search with image", scale=0)

            gallery = gr.Gallery(label="Generated images", show_label=False,
                                 elem_id="gallery", columns=3, object_fit="contain",
                                 interactive=False, scale=2.75)

        search_using_text_btn.click(search_using_text, inputs=text_input, outputs=gallery)
        search_using_image_btn.click(search_using_image, inputs=image_input, outputs=gallery)

    demo.launch(share=False)