salihmarangoz's picture
added examples
933f49e
#!/usr/bin/env python3
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, AutoModel, AutoTokenizer
import torch
import faiss
import glob
import numpy as np
# --- SigLIP model setup ---------------------------------------------------
# Prefer the first CUDA device; otherwise everything runs on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Model, image processor and tokenizer all come from the same multilingual
# SigLIP checkpoint; loaded in the same order as before (model first).
_MODEL_ID = "google/siglip-base-patch16-256-multilingual"
model = AutoModel.from_pretrained(_MODEL_ID).to(device)
processor = AutoProcessor.from_pretrained(_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)

# Width of the embedding vectors stored in the FAISS index
# (768 for this base checkpoint).
num_dimensions = model.vision_model.config.hidden_size
# How many nearest neighbours each search returns.
num_k = 30
# Example queries offered under the search box. The checkpoint is
# multilingual, so the list deliberately mixes languages.
text_examples = [
    # English
    "Frog waiting on a rock",
    "Bird with open mouth",
    "Bridge and a ship",
    "Bike for two people",
] + [
    "Biene auf der Blume",  # German: "bee on the flower"
    "Hesap makinesi",  # Turkish: "calculator"
]
def preprocess_images(pathname="images/*", index_file="index.faiss",
                      image_filenames_file="image_filenames.txt"):
    """Embed every image matching *pathname* with SigLIP and save a FAISS index.

    Each image is converted to RGB, embedded with ``model.get_image_features``,
    L2-normalised and added to an inner-product index, so an inner-product
    search over the index ranks by cosine similarity.

    Args:
        pathname: glob pattern selecting the images to index.
        index_file: path the FAISS index is written to.
        image_filenames_file: path of the text file listing the indexed image
            paths, one per line, in the same order as the index rows.

    Returns:
        A ``(index, image_filenames)`` tuple.

    Raises:
        SystemExit: with status 1 if any matched file cannot be embedded.
        ValueError: if *pathname* matches no files.
    """
    print("Preprocessing images...")
    # Inner product on L2-normalised vectors equals cosine similarity.
    index = faiss.IndexFlatIP(num_dimensions)
    image_filenames = []
    image_features = []
    # Sorted so the index row order (and the filenames file) is reproducible.
    for image_filename in sorted(glob.glob(pathname)):
        try:
            # `with` closes the file handle; convert() returns a new, fully
            # loaded image that stays valid after the source is closed.
            with Image.open(image_filename) as image_raw:
                image_rgb = image_raw.convert('RGB')
            inputs = processor(images=image_rgb, return_tensors="pt").to(device)
            with torch.no_grad():
                image_embedding = model.get_image_features(**inputs).to("cpu")
            image_embedding_n = image_embedding / image_embedding.norm(p=2, dim=-1, keepdim=True)
        except Exception as e:
            print(f"Error processing {image_filename}")
            print(e)
            raise SystemExit(1) from e
        image_features.append(image_embedding_n.numpy())
        # Record the path only after a successful embedding so that
        # image_filenames[i] always describes index row i.
        image_filenames.append(image_filename)

    if not image_features:
        # np.concatenate([]) would otherwise fail with an opaque message.
        raise ValueError(f"No images matched {pathname!r}; nothing to index.")

    print("Indexing images...")
    index.add(np.concatenate(image_features, axis=0))
    print("Saving index...")
    faiss.write_index(index, index_file)
    with open(image_filenames_file, "w") as f:
        f.writelines(name + "\n" for name in image_filenames)
    print("Preprocessing complete.")
    return index, image_filenames
def load_processed_images(index_file="index.faiss", image_filenames_file="image_filenames.txt"):
    """Load a saved FAISS index together with its list of image paths.

    Args:
        index_file: path of the FAISS index (as written by ``preprocess_images``).
        image_filenames_file: text file with one image path per line, in the
            same order as the index rows.

    Returns:
        A ``(index, image_filenames)`` tuple.

    Raises:
        ValueError: if the number of paths differs from the number of vectors
            in the index (search results could not be mapped to files).
    """
    print("Loading index...")
    index = faiss.read_index(index_file)
    with open(image_filenames_file) as f:
        image_filenames = [line.strip() for line in f]
    # A stale or hand-edited filenames file would make search return the
    # wrong images (or raise IndexError); fail early with a clear message.
    if len(image_filenames) != index.ntotal:
        raise ValueError(
            f"{image_filenames_file} lists {len(image_filenames)} images but "
            f"{index_file} contains {index.ntotal} vectors; rebuild the index."
        )
    return index, image_filenames
@torch.no_grad()
def search_using_text(text):
    """Return the ``num_k`` indexed images that best match *text*.

    The query is embedded with SigLIP's text tower, L2-normalised and searched
    against the module-level FAISS ``index``. Each hit's cosine similarity is
    turned into a match probability using the model's ``logit_scale`` and
    ``logit_bias`` followed by a sigmoid.

    Args:
        text: the query string.

    Returns:
        A list of ``(PIL.Image, "NN.NN%")`` tuples, best match first.
    """
    # SigLIP expects max_length padding. truncation=True caps over-long queries
    # at that same length instead of overflowing the text position embeddings.
    inputs = tokenizer(text, padding="max_length", truncation=True, return_tensors="pt").to(device)
    text_features = model.get_text_features(**inputs).to("cpu")
    text_features_n = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
    D, I = index.search(text_features_n.numpy(), num_k)
    scale = model.logit_scale.exp().cpu().numpy()
    bias = model.logit_bias.cpu().numpy()
    result = []
    for dist, idx in zip(D[0], I[0]):
        # FAISS pads with -1 when the index holds fewer than num_k vectors;
        # image_filenames[-1] would otherwise silently yield the last image.
        if idx < 0:
            continue
        score_logit = dist * scale + bias
        score_probability = torch.sigmoid(torch.tensor(score_logit)).item()
        found_image = Image.open(image_filenames[idx])
        # Force decoding now; Pillow then releases the file handle of
        # single-frame images.
        found_image.load()
        result.append((found_image, "{:.2f}%".format(score_probability * 100)))
    return result
@torch.no_grad()
def search_using_image(image):
    """Return the ``num_k`` indexed images most similar to the query *image*.

    Args:
        image: the query as an array accepted by ``PIL.Image.fromarray``, or
            ``None`` when the user pressed search without uploading anything.

    Returns:
        A list of PIL images, best match first; empty when *image* is None.
    """
    if image is None:
        # Nothing uploaded: show an empty gallery instead of crashing.
        return []
    image_rgb = Image.fromarray(image).convert('RGB')
    inputs = processor(images=image_rgb, return_tensors="pt").to(device)
    image_embedding = model.get_image_features(**inputs).to("cpu")
    image_embedding_n = image_embedding / image_embedding.norm(p=2, dim=-1, keepdim=True)
    _, I = index.search(image_embedding_n.numpy(), num_k)
    result = []
    for idx in I[0]:
        # FAISS pads with -1 when the index holds fewer than num_k vectors.
        if idx < 0:
            continue
        found_image = Image.open(image_filenames[idx])
        found_image.load()
        result.append(found_image)
    return result
if __name__ == "__main__":
    # Build the search index once by running preprocess_images(); afterwards the
    # saved index and filename list are simply reloaded.
    #index, image_filenames = preprocess_images() # uncomment this line to preprocess images
    # `index` and `image_filenames` become module globals, which
    # search_using_text and search_using_image read directly.
    index, image_filenames = load_processed_images()
    with gr.Blocks() as demo:
        gr.Markdown("# Image Search Engine Demo")
        # Left: description and the two query tabs. Right: the shared result gallery.
        with gr.Row(equal_height=False):
            with gr.Column():
                gr.Markdown("This app is powered by [SigLIP](https://huggingface.co/google/siglip-base-patch16-256-multilingual) with multilingual support and [GPR1200 Dataset](https://www.kaggle.com/datasets/mathurinache/gpr1200-dataset) image contents. Enter your query in the text box or upload an image to search for similar images.")
                with gr.Tab("Text-Image Search"):
                    text_input = gr.Textbox(label="Type a word or a sentence")
                    search_using_text_btn = gr.Button("Search with text", scale=0)
                    # Clicking an example fills text_input; the search still needs the button.
                    gr.Examples(
                        examples = text_examples,
                        inputs = [text_input]
                    )
                with gr.Tab("Image-Image Search"):
                    image_input = gr.Image()
                    search_using_image_btn = gr.Button("Search with image", scale=0)
            # Both search buttons write their results into this one gallery;
            # scale makes it wider than the query column next to it.
            gallery = gr.Gallery(label="Generated images", show_label=False,
                        elem_id="gallery", columns=3,
                        object_fit="contain", interactive=False, scale=2.75)
        # Event wiring must happen inside the Blocks context.
        search_using_text_btn.click(search_using_text, inputs=text_input, outputs=gallery)
        search_using_image_btn.click(search_using_image, inputs=image_input, outputs=gallery)
    demo.launch(share=False)