# Hugging Face Space status: Running
import json | |
import os | |
from functools import lru_cache | |
from typing import List, Dict | |
import faiss | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
from cheesechaser.datapool import DanbooruWebpDataPool | |
from hfutils.operate import get_hf_fs, get_hf_client | |
from hfutils.utils import TemporaryDirectory | |
from imgutils.tagging import wd14 | |
# Hub repository that holds one sub-directory per searchable index.
_REPO_ID = 'deepghs/index_experiments'

# Shared Hub filesystem / client handles used by the helpers below
# (presumably authenticated via the environment — confirm in hfutils).
hf_fs = get_hf_fs()
hf_client = get_hf_client()

# Index pre-selected in the UI dropdown.
_DEFAULT_MODEL_NAME = 'SwinV2_v3_danbooru_7001436_4GB'

# Every `<repo>/<name>/knn.index` found on the Hub contributes `<name>` as a
# selectable index: relpath strips the `_REPO_ID/` prefix and dirname drops
# the `knn.index` filename. NOTE: this performs a remote glob at import time.
_ALL_MODEL_NAMES = [
    os.path.dirname(os.path.relpath(path, _REPO_ID))
    for path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
]
def _get_from_ids(ids: List[int]) -> Dict[int, Image.Image]:
    """
    Download Danbooru images by ID and load them fully into memory.

    The images are fetched into a temporary directory, which is removed on
    return, so every image is detached from its on-disk file before the
    directory goes away.

    :param ids: Danbooru IDs to fetch.
    :return: Mapping of ID to a loaded, file-independent PIL image. IDs for
        which no file was downloaded are simply absent from the mapping.
    """
    with TemporaryDirectory() as td:
        datapool = DanbooruWebpDataPool()
        datapool.batch_download_to_directory(
            resource_ids=ids,
            dst_dir=td,
        )

        retval = {}
        for file in os.listdir(td):
            stem = os.path.splitext(file)[0]
            if not stem.isdigit():
                # Only `<id>.<ext>` files are results; ignore anything else
                # instead of aborting the whole request on int() failure.
                continue
            # Close the file handle deterministically (the temp dir is deleted
            # right after). Pillow invalidates an image's data on close(), so
            # keep a detached in-memory copy.
            with Image.open(os.path.join(td, file)) as image:
                image.load()
                retval[int(stem)] = image.copy()
        return retval
def _x(x): | |
if isinstance(x, (int, np.integer)): | |
return int(x) | |
elif isinstance(x, (str, np.str_)): | |
return int(str(x).split('_')[-1]) | |
else: | |
raise ValueError(f'Invalid ID: {x!r}, type: {type(x)!r}') | |
# Bounded because each loaded index lives in RAM and may be large (the default
# model name hints at ~4GB) — tune maxsize to the host's memory.
@lru_cache(maxsize=4)
def _get_index_info(repo_id: str, model_name: str):
    """
    Load the ID table and FAISS index for ``model_name`` from ``repo_id``.

    Results are memoised per ``(repo_id, model_name)`` so repeated searches
    do not re-read and re-tune the index on every request. The returned
    objects are shared between calls; callers must not mutate them.

    :param repo_id: Hub model repository holding the indices.
    :param model_name: Repo sub-directory containing ``ids.npy``,
        ``knn.index`` and ``infos.json``.
    :return: ``(image_ids, knn_index)`` where ``image_ids[i]`` is the ID
        stored for row ``i`` of ``knn_index``.
    """
    def _download(filename: str) -> str:
        # Local path of `<model_name>/<filename>` fetched from the repo.
        return hf_client.hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename=f'{model_name}/{filename}',
        )

    image_ids = np.load(_download('ids.npy'))
    knn_index = faiss.read_index(_download('knn.index'))

    # Read the config through a context manager so the handle is closed.
    with open(_download('infos.json'), 'r', encoding='utf-8') as f:
        config = json.load(f)["index_param"]
    # Apply the search-time parameters recorded alongside the index.
    faiss.ParameterSpace().set_index_parameters(knn_index, config)
    return image_ids, knn_index
def search(model_name: str, img_input, n_neighbours: int):
    """
    Find the images most similar to ``img_input`` in the selected index.

    :param model_name: Index directory name within ``_REPO_ID``.
    :param img_input: Query image from the Gradio ``Image`` component (PIL).
    :param n_neighbours: Number of neighbours to retrieve.
    :return: List of ``(image, caption)`` pairs for ``gr.Gallery``, where the
        caption is ``"<id>/<distance>"`` as reported by the index. Neighbours
        whose image could not be downloaded are omitted.
    :raises gr.Error: if no query image was provided.
    """
    if img_input is None:
        raise gr.Error('Please provide an input image.')

    images_ids, knn_index = _get_index_info(_REPO_ID, model_name)

    embeddings = wd14.get_wd14_tags(
        img_input,
        model_name="SwinV2_v3",
        fmt="embedding",
    )
    # Presumably a 1-D vector; make it a (1, d) batch. FAISS requires a
    # contiguous float32 matrix, and normalize_L2 works in place.
    embeddings = np.ascontiguousarray(
        np.expand_dims(embeddings, 0), dtype=np.float32,
    )
    faiss.normalize_L2(embeddings)

    # Gradio's Slider may deliver a float; FAISS needs an integer k.
    dists, indexes = knn_index.search(embeddings, k=int(n_neighbours))

    # FAISS pads unfilled result slots with -1; `images_ids[-1]` would
    # silently map those to the last ID, so drop them.
    valid = indexes[0] >= 0
    neighbours_ids = [_x(x) for x in images_ids[indexes[0][valid]]]
    neighbour_dists = dists[0][valid]

    ids_to_images = _get_from_ids(neighbours_ids)
    return [
        (ids_to_images[image_id], f"{image_id}/{dist:.2f}")
        for image_id, dist in zip(neighbours_ids, neighbour_dists)
        if image_id in ids_to_images
    ]
if __name__ == "__main__":
    # Gradio UI: query image on the left, index/neighbour controls on the
    # right, result gallery underneath. Layout follows the nesting order of
    # the `with` blocks.
    # NOTE(review): indentation reconstructed from a whitespace-stripped copy
    # — confirm find_btn / click placement against the original.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                # Query image, handed to `search` as a PIL image.
                img_input = gr.Image(type="pil", label="Input")
            with gr.Column():
                with gr.Row():
                    # Which repo sub-directory (FAISS index) to query.
                    n_model = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    # Number of neighbours to retrieve (1-50).
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")
        with gr.Row():
            similar_images = gr.Gallery(label="Similar images", columns=[5])

        # Inputs map positionally onto search(model_name, img_input, n_neighbours).
        find_btn.click(
            fn=search,
            inputs=[
                n_model,
                img_input,
                n_neighbours,
            ],
            outputs=[similar_images],
        )

    demo.queue().launch()