narugo's picture
dev(narugo): save project
3c7d8d9
raw
history blame
3.97 kB
import json
import os
from functools import lru_cache
from typing import List, Dict
import faiss
import gradio as gr
import numpy as np
from PIL import Image
from cheesechaser.datapool import DanbooruWebpDataPool
from hfutils.operate import get_hf_fs, get_hf_client
from hfutils.utils import TemporaryDirectory
from imgutils.tagging import wd14
# Hub repository that holds one sub-folder per search index.
_REPO_ID = 'deepghs/index_experiments'
# Index pre-selected in the UI dropdown.
_DEFAULT_MODEL_NAME = 'SwinV2_v3_danbooru_7001436_4GB'

# Shared Hugging Face filesystem / client handles used across this module.
hf_fs = get_hf_fs()
hf_client = get_hf_client()

# Every first-level folder of the repository that contains a ``knn.index``
# file is offered as a selectable index; the folder name is the model name.
_ALL_MODEL_NAMES = [
    os.path.dirname(os.path.relpath(index_path, _REPO_ID))
    for index_path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
]
def _get_from_ids(ids: List[int]) -> Dict[int, Image.Image]:
    """Download the images with the given IDs and return them keyed by ID.

    The files are fetched into a temporary directory through
    ``DanbooruWebpDataPool``. Each downloaded file's name (without extension)
    is parsed as the integer image ID.

    :param ids: Resource IDs to download.
    :return: Mapping of image ID to the opened image. Only files that
        actually appear in the download directory are included.
    """
    with TemporaryDirectory() as workdir:
        DanbooruWebpDataPool().batch_download_to_directory(
            resource_ids=ids,
            dst_dir=workdir,
        )

        loaded = {}
        for filename in os.listdir(workdir):
            stem, _ = os.path.splitext(filename)
            image_id = int(stem)
            picture = Image.open(os.path.join(workdir, filename))
            # load() is called while the temporary directory still exists,
            # presumably so the image stays usable after the files are removed.
            picture.load()
            loaded[image_id] = picture

    return loaded
def _x(x):
if isinstance(x, (int, np.integer)):
return int(x)
elif isinstance(x, (str, np.str_)):
return int(str(x).split('_')[-1])
else:
raise ValueError(f'Invalid ID: {x!r}, type: {type(x)!r}')
@lru_cache(maxsize=3)
def _get_index_info(repo_id: str, model_name: str):
    """Load the ID table and FAISS index for one model folder of a repository.

    Three files are fetched from ``<model_name>/`` inside ``repo_id``:
    ``ids.npy`` (the ID array), ``knn.index`` (the FAISS index) and
    ``infos.json`` (whose ``"index_param"`` entry is applied to the index as
    its search parameters). Results are memoised for the three most recently
    used ``(repo_id, model_name)`` pairs.

    :param repo_id: Hugging Face model repository holding the index folders.
    :param model_name: Name of the sub-folder to load.
    :return: Tuple ``(image_ids, knn_index)``.
    """

    def _fetch(filename: str) -> str:
        # Resolve one file of this model's folder to a local path.
        return hf_client.hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename=f'{model_name}/{filename}',
        )

    image_ids = np.load(_fetch('ids.npy'))
    knn_index = faiss.read_index(_fetch('knn.index'))

    # Use a context manager so the file handle is closed deterministically,
    # and decode explicitly as UTF-8 rather than the locale default.
    with open(_fetch('infos.json'), encoding='utf-8') as info_file:
        config = json.load(info_file)["index_param"]

    faiss.ParameterSpace().set_index_parameters(knn_index, config)
    return image_ids, knn_index
def search(model_name: str, img_input, n_neighbours: int):
    """Find the images most similar to ``img_input`` in the selected index.

    The input image is embedded with the ``SwinV2_v3`` tagger model, the
    embedding is L2-normalised, and the FAISS index for ``model_name`` is
    queried. The matching images are then downloaded for display.

    NOTE(review): the embedding model is fixed to ``SwinV2_v3`` regardless of
    ``model_name`` -- confirm every selectable index was built with it.

    :param model_name: Index folder to search (one of ``_ALL_MODEL_NAMES``).
    :param img_input: Query image as supplied by the Gradio image component.
    :param n_neighbours: Number of neighbours to request from the index.
    :return: List of ``(image, caption)`` pairs in index result order, where
        the caption is ``"<image_id>/<score>"``. Neighbours whose image could
        not be downloaded are omitted.
    :raises gr.Error: If no input image was provided.
    """
    if img_input is None:
        # Surface a readable message in the UI instead of an opaque failure
        # from inside the embedding code.
        raise gr.Error('Please provide an input image first.')

    images_ids, knn_index = _get_index_info(_REPO_ID, model_name)
    embeddings = wd14.get_wd14_tags(
        img_input,
        model_name="SwinV2_v3",
        fmt="embedding",
    )
    # FAISS operates on contiguous float32 matrices of shape (n_queries, dim).
    embeddings = np.ascontiguousarray(
        np.expand_dims(embeddings, 0), dtype=np.float32,
    )
    faiss.normalize_L2(embeddings)

    # The slider value may arrive as a float; FAISS needs an integer k.
    dists, indexes = knn_index.search(embeddings, k=int(n_neighbours))

    # FAISS pads missing results with label -1. Indexing the ID array with -1
    # would silently return the *last* ID, so drop those slots explicitly.
    hits = [
        (_x(images_ids[idx]), dist)
        for idx, dist in zip(indexes[0], dists[0])
        if idx >= 0
    ]
    if not hits:
        return []

    ids_to_images = _get_from_ids([image_id for image_id, _ in hits])
    return [
        (ids_to_images[image_id], f"{image_id}/{dist:.2f}")
        for image_id, dist in hits
        if image_id in ids_to_images
    ]
def _build_demo():
    """Assemble the Gradio Blocks UI and wire the search button to ``search``.

    Layout: one row with the input image on the left and, on the right, the
    index dropdown, the result-count slider and the search button; a second
    row below holds the result gallery.

    :return: The constructed ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                query_image = gr.Image(type="pil", label="Input")

            with gr.Column():
                with gr.Row():
                    index_choice = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    result_count = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        label="# of images",
                    )
                run_button = gr.Button("Find similar images")

        with gr.Row():
            gallery = gr.Gallery(label="Similar images", columns=[5])

        # Inputs are passed positionally to search(model_name, img_input,
        # n_neighbours), so the order here must match its signature.
        run_button.click(
            fn=search,
            inputs=[index_choice, query_image, result_count],
            outputs=[gallery],
        )

    return demo


if __name__ == "__main__":
    demo = _build_demo()
    demo.queue().launch()