Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
import gradio as gr | |
from PIL import Image | |
from transformers import AutoProcessor, AutoModel, AutoTokenizer | |
import torch | |
import faiss | |
import glob | |
import numpy as np | |
# Checkpoint id shared by the model, the image processor and the tokenizer.
MODEL_NAME = "google/siglip-base-patch16-256-multilingual"

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = AutoModel.from_pretrained(MODEL_NAME).to(device)
processor = AutoProcessor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Width of the vectors stored in the FAISS index.
num_dimensions = model.vision_model.config.hidden_size  # 768

# Number of nearest neighbours requested from the index per query.
num_k = 30

# Sample queries offered in the UI (English, German and Turkish).
text_examples = [
    "Frog waiting on a rock",
    "Bird with open mouth",
    "Bridge and a ship",
    "Bike for two people",
    "Biene auf der Blume",
    "Hesap makinesi",
]
def preprocess_images(pathname="images/*", index_file="index.faiss",
                      image_filenames_file="image_filenames.txt"):
    """Embed every image matching ``pathname`` and save a FAISS index of them.

    Each image is converted to RGB, embedded with the module-level ``model``
    and L2-normalised before being added to an inner-product index. The index
    is written to ``index_file`` and the image paths, one per line and in
    index order, to ``image_filenames_file``.

    Args:
        pathname: Glob pattern selecting the images to index.
        index_file: Destination path of the serialized FAISS index.
        image_filenames_file: Destination path of the filename list.

    Returns:
        A ``(index, image_filenames)`` tuple; row ``i`` of the index is the
        embedding of ``image_filenames[i]``.

    Raises:
        SystemExit: With status 1 if any matched file cannot be processed.
        ValueError: If ``pathname`` matches no files.
    """
    print("Preprocessing images...")
    # Inner product over unit-length vectors is cosine similarity.
    index = faiss.IndexFlatIP(num_dimensions)
    image_filenames = []
    image_features = []
    # Sorted so the row order of the index is reproducible between runs.
    for image_filename in sorted(glob.glob(pathname)):
        try:
            # convert() returns an independent copy, so the file handle can be
            # released immediately.
            with Image.open(image_filename) as image_raw:
                image_rgb = image_raw.convert("RGB")
            inputs = processor(images=image_rgb, return_tensors="pt").to(device)
            with torch.no_grad():
                image_embedding = model.get_image_features(**inputs).to("cpu")
            image_embedding_n = image_embedding / image_embedding.norm(p=2, dim=-1, keepdim=True)
            image_features.append(image_embedding_n.numpy())
            # Recorded only after the embedding succeeded, keeping both lists
            # aligned row for row.
            image_filenames.append(image_filename)
        except Exception as e:
            print(f"Error processing {image_filename}")
            print(e)
            # Same effect as exit(1), without depending on the `site` builtin.
            raise SystemExit(1) from e

    if not image_features:
        # np.concatenate would fail on an empty list with an opaque message.
        raise ValueError(f"No images found matching {pathname!r}; nothing to index.")

    print("Indexing images...")
    # FAISS expects a float32 matrix.
    feature_matrix = np.concatenate(image_features, axis=0).astype(np.float32, copy=False)
    index.add(feature_matrix)

    print("Saving index...")
    faiss.write_index(index, index_file)
    with open(image_filenames_file, "w") as f:
        f.writelines(f"{name}\n" for name in image_filenames)

    print("Preprocessing complete.")
    return index, image_filenames
def load_processed_images(index_file="index.faiss", image_filenames_file="image_filenames.txt"):
    """Load a saved FAISS index and its accompanying list of image paths.

    Args:
        index_file: Path of the serialized FAISS index.
        image_filenames_file: Text file holding one image path per line.

    Returns:
        A ``(index, image_filenames)`` tuple, where ``image_filenames`` is the
        list of lines of ``image_filenames_file`` with surrounding whitespace
        stripped.
    """
    print("Loading index...")
    loaded_index = faiss.read_index(index_file)
    with open(image_filenames_file) as names_file:
        names = [line.strip() for line in names_file]
    return loaded_index, names
def search_using_text(text):
    """Return the indexed images that best match a text query.

    The query is embedded with the module-level ``model``, L2-normalised and
    looked up in the module-level ``index``. Each similarity is converted to
    a probability with the model's logit scale and bias followed by a sigmoid.

    Args:
        text: Query string; ``None``, empty or whitespace-only yields no
            results.

    Returns:
        A list of ``(PIL.Image, caption)`` tuples, best match first, where
        ``caption`` is the match probability formatted like ``"12.34%"``.
    """
    if not text or not text.strip():
        return []

    # truncation=True keeps over-long queries within the tokenizer's maximum
    # length instead of feeding an oversized sequence to the text encoder.
    inputs = tokenizer(
        text, padding="max_length", truncation=True, return_tensors="pt"
    ).to(device)

    # Without no_grad the outputs require grad and Tensor.numpy() raises.
    with torch.no_grad():
        text_features = model.get_text_features(**inputs).to("cpu")
        text_features_n = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
        scale = model.logit_scale.exp().item()
        bias = model.logit_bias.item()

    D, I = index.search(text_features_n.numpy(), num_k)

    result = []
    for dist, idx in zip(D[0], I[0]):
        if idx < 0:
            # FAISS pads with -1 when the index holds fewer than num_k
            # vectors; -1 would silently select the last filename.
            continue
        score_logit = float(dist) * scale + bias
        score_probability = torch.sigmoid(torch.tensor(score_logit)).item()
        found_image = Image.open(image_filenames[idx])
        found_image.load()  # read the pixel data now rather than lazily
        result.append((found_image, "{:.2f}%".format(score_probability * 100)))
    return result
def search_using_image(image):
    """Return the indexed images most similar to a query image.

    The query is converted to RGB, embedded with the module-level ``model``,
    L2-normalised and looked up in the module-level ``index``.

    Args:
        image: Query image as an array accepted by ``PIL.Image.fromarray``,
            or ``None`` when no image was supplied.

    Returns:
        A list of ``PIL.Image`` objects, most similar first; empty when
        ``image`` is ``None``.
    """
    if image is None:
        # Nothing was uploaded; Image.fromarray(None) would raise.
        return []

    image_rgb = Image.fromarray(image).convert("RGB")
    inputs = processor(images=image_rgb, return_tensors="pt").to(device)

    # Without no_grad the embedding requires grad and Tensor.numpy() raises.
    with torch.no_grad():
        image_embedding = model.get_image_features(**inputs).to("cpu")
        image_embedding_n = image_embedding / image_embedding.norm(p=2, dim=-1, keepdim=True)

    _, I = index.search(image_embedding_n.numpy(), num_k)

    result = []
    for idx in I[0]:
        if idx < 0:
            # FAISS pads with -1 when the index holds fewer than num_k
            # vectors; -1 would silently select the last filename.
            continue
        found_image = Image.open(image_filenames[idx])
        found_image.load()  # read the pixel data now rather than lazily
        result.append(found_image)
    return result
if __name__ == "__main__":
    # `index` and `image_filenames` are bound at module level here because the
    # search callbacks read them as globals.
    # To (re)build the index from the image files, uncomment the next line:
    # index, image_filenames = preprocess_images()
    index, image_filenames = load_processed_images()

    with gr.Blocks() as demo:
        gr.Markdown("# Image Search Engine Demo")
        with gr.Row(equal_height=False):
            # Gradio warns on non-integer `scale` values. 4:11 is the same
            # 1:2.75 column-to-gallery ratio expressed in integers.
            with gr.Column(scale=4):
                gr.Markdown("This app is powered by [SigLIP](https://huggingface.co/google/siglip-base-patch16-256-multilingual) with multilingual support and [GPR1200 Dataset](https://www.kaggle.com/datasets/mathurinache/gpr1200-dataset) image contents. Enter your query in the text box or upload an image to search for similar images.")
                with gr.Tab("Text-Image Search"):
                    text_input = gr.Textbox(label="Type a word or a sentence")
                    search_using_text_btn = gr.Button("Search with text", scale=0)
                    gr.Examples(
                        examples=text_examples,
                        inputs=[text_input],
                    )
                with gr.Tab("Image-Image Search"):
                    image_input = gr.Image()
                    search_using_image_btn = gr.Button("Search with image", scale=0)
            # Both tabs render their results into this one gallery.
            gallery = gr.Gallery(label="Generated images", show_label=False,
                                 elem_id="gallery", columns=3,
                                 object_fit="contain", interactive=False, scale=11)
        search_using_text_btn.click(search_using_text, inputs=text_input, outputs=gallery)
        search_using_image_btn.click(search_using_image, inputs=image_input, outputs=gallery)

    demo.launch(share=False)