"""Gradio demo: find visually similar Danbooru posts for an uploaded image.

An uploaded image is embedded with the WD14 SwinV2 tagger, the embedding is
searched against a prebuilt FAISS KNN index hosted on HuggingFace, and the
nearest posts are downloaded and shown in a gallery.
"""
import json
import os
from functools import lru_cache
from typing import List, Dict

import faiss
import gradio as gr
import numpy as np
from PIL import Image

from cheesechaser.datapool import DanbooruWebpDataPool
from hfutils.operate import get_hf_fs, get_hf_client
from hfutils.utils import TemporaryDirectory
from imgutils.tagging import wd14

# HuggingFace repository that holds the prebuilt FAISS indices.
_REPO_ID = 'deepghs/index_experiments'
hf_fs = get_hf_fs()
hf_client = get_hf_client()

_DEFAULT_MODEL_NAME = 'SwinV2_v3_danbooru_7001436_4GB'
# Every top-level directory in the repo containing a 'knn.index' file is a usable index.
_ALL_MODEL_NAMES = [
    os.path.dirname(os.path.relpath(path, _REPO_ID))
    for path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
]


def _get_from_ids(ids: List[int]) -> Dict[int, Image.Image]:
    """Download the Danbooru images for ``ids`` and return them keyed by post id.

    Posts that fail to download are simply absent from the result, so the
    returned dict may contain fewer entries than ``ids``.

    :param ids: Danbooru post ids to fetch.
    :return: mapping of post id -> fully-loaded PIL image.
    """
    with TemporaryDirectory() as td:
        datapool = DanbooruWebpDataPool()
        datapool.batch_download_to_directory(
            resource_ids=ids,
            dst_dir=td,
        )

        retval = {}
        for file in os.listdir(td):
            id_ = int(os.path.splitext(file)[0])
            image = Image.open(os.path.join(td, file))
            # Force pixel data into memory before the temp directory is removed.
            image.load()
            retval[id_] = image

        return retval


def _x(x) -> int:
    """Normalize a raw index id to an integer post id.

    Index id arrays may store plain ints or strings of the form
    ``'<prefix>_<post_id>'``; in the latter case the trailing segment is the id.

    :raises ValueError: if ``x`` is neither an int nor a string.
    """
    if isinstance(x, (int, np.integer)):
        return int(x)
    elif isinstance(x, (str, np.str_)):
        return int(str(x).split('_')[-1])
    else:
        raise ValueError(f'Invalid ID: {x!r}, type: {type(x)!r}')


@lru_cache(maxsize=3)
def _get_index_info(repo_id: str, model_name: str):
    """Load (and cache, up to 3 indices) the id array and FAISS index for a model.

    :param repo_id: HuggingFace repository id hosting the index files.
    :param model_name: subdirectory of the repo holding ids.npy / knn.index / infos.json.
    :return: tuple of (image_ids ndarray, configured FAISS index).
    """
    image_ids = np.load(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/ids.npy',
    ))
    knn_index = faiss.read_index(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/knn.index',
    ))
    # BUGFIX: use a context manager so the config file handle is not leaked.
    with open(hf_client.hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename=f'{model_name}/infos.json',
    )) as f:
        config = json.load(f)["index_param"]
    # Apply the stored search-time parameters (e.g. nprobe/efSearch) to the index.
    faiss.ParameterSpace().set_index_parameters(knn_index, config)
    return image_ids, knn_index


def search(model_name: str, img_input, n_neighbours: int):
    """Return up to ``n_neighbours`` (image, caption) pairs most similar to the input.

    :param model_name: which prebuilt index to query (one of ``_ALL_MODEL_NAMES``).
    :param img_input: PIL image from the Gradio input component.
    :param n_neighbours: number of nearest neighbours to retrieve.
    :return: list of (PIL image, "<post_id>/<distance>") tuples for the gallery.
    """
    images_ids, knn_index = _get_index_info(_REPO_ID, model_name)
    # Embed the query image with the same tagger family the index was built from.
    embeddings = wd14.get_wd14_tags(
        img_input,
        model_name="SwinV2_v3",
        fmt="embedding",
    )
    embeddings = np.expand_dims(embeddings, 0)
    faiss.normalize_L2(embeddings)

    # BUGFIX: Gradio sliders may deliver the value as a float; FAISS needs int k.
    dists, indexes = knn_index.search(embeddings, k=int(n_neighbours))
    neighbours_ids = [_x(x) for x in images_ids[indexes][0]]

    images = []
    captions = []
    ids_to_images = _get_from_ids(neighbours_ids)
    for image_id, dist in zip(neighbours_ids, dists[0]):
        # Skip neighbours whose image failed to download.
        if image_id in ids_to_images:
            images.append(ids_to_images[image_id])
            captions.append(f"{image_id}/{dist:.2f}")

    return list(zip(images, captions))


if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="pil", label="Input")
            with gr.Column():
                with gr.Row():
                    n_model = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")

        with gr.Row():
            similar_images = gr.Gallery(label="Similar images", columns=[5])

        find_btn.click(
            fn=search,
            inputs=[
                n_model,
                img_input,
                n_neighbours,
            ],
            outputs=[similar_images],
        )

    demo.queue().launch()