"""Gradio demo: draw a sketch and retrieve the closest WikiArt artworks.

The drawing is embedded with Google's SigLIP image encoder and searched
against a prebuilt FAISS index of 10k WikiArt images downloaded from the
Hugging Face Hub.
"""

import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, SiglipModel
import faiss
import numpy as np
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import pandas as pd
import requests
from io import BytesIO
import spaces

ASSET_REPO_ID = "merve/siglip-faiss-wikiart"
MODEL_ID = "google/siglip-base-patch16-224"
NUM_RESULTS = 3  # how many nearest artworks to show in the gallery
REQUEST_TIMEOUT_S = 30  # seconds; prevents a slow image host from hanging a request

# Download the FAISS index and the metadata CSV. hf_hub_download returns the
# local path of the file, so use it instead of repeating the filename.
index_path = hf_hub_download(ASSET_REPO_ID, "siglip_10k_latest.index", local_dir="./")
csv_path = hf_hub_download(ASSET_REPO_ID, "wikiart_10k_latest.csv", local_dir="./")

# Read index and dataset, then load the SigLIP model and processor.
index = faiss.read_index(index_path)
df = pd.read_csv(csv_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = SiglipModel.from_pretrained(MODEL_ID).to(device)


def read_image_from_url(url, timeout=REQUEST_TIMEOUT_S):
    """Download the image at ``url`` and return it as an RGB PIL image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with a non-2xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on HTTP errors; otherwise PIL would try to decode an error
    # page and raise an unrelated-looking UnidentifiedImageError.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


# NOTE: when deploying on Hugging Face ZeroGPU hardware, decorate this
# function with @spaces.GPU so a GPU is attached for the call.
def extract_features_siglip(image):
    """Return the SigLIP image embedding of ``image`` as a torch tensor.

    Runs under ``torch.no_grad()`` because no gradients are needed at
    inference time.
    """
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt").to(device)
        image_features = model.get_image_features(**inputs)
    return image_features


def infer(input_image):
    """Return the ``NUM_RESULTS`` artworks closest to the user's drawing.

    Args:
        input_image: Value produced by ``gr.ImageEditor(type="pil")``; the
            flattened drawing is read from its ``"composite"`` entry.

    Returns:
        A list of RGB PIL images, in the order returned by the FAISS search.

    Raises:
        gr.Error: If nothing was drawn, so the UI shows a friendly message.
    """
    if input_image is None or input_image.get("composite") is None:
        raise gr.Error("Please draw something before searching.")

    features = extract_features_siglip(input_image["composite"].convert("RGB"))
    # FAISS needs a C-contiguous float32 array; normalize_L2 works in place.
    query = np.ascontiguousarray(features.detach().cpu().numpy(), dtype=np.float32)
    faiss.normalize_L2(query)
    _distances, indices = index.search(query, NUM_RESULTS)

    gallery_output = []
    for row in indices[0]:
        # FAISS pads missing results with -1; df.iloc[-1] would otherwise
        # silently return the last artwork in the CSV.
        if row < 0:
            continue
        gallery_output.append(read_image_from_url(df.iloc[row]["Link"]))
    return gallery_output


description = "This is an application where you can draw an image and find the closest artwork among 10k art from wikiart dataset. This is built on 🤗 transformers integration of SIGLIP model by Google, and FAISS for indexing."
# Build the UI: a PIL-backed sketch editor as input and an image gallery as output.
sketchpad = gr.ImageEditor(type="pil")
demo = gr.Interface(
    fn=infer,
    inputs=sketchpad,
    outputs="gallery",
    description=description,
    title="Draw to Search Art 🖼️",
)
demo.launch()