|
import json |
|
from functools import lru_cache |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from PIL import Image |
|
from autofaiss import build_index |
|
from hfutils.operate import get_hf_fs |
|
from huggingface_hub import hf_hub_download |
|
from imgutils.data import load_image |
|
from imgutils.metrics import ccip_batch_extract_features |
|
|
|
# Hugging Face dataset repository that the index files and preview images
# used in this module are fetched from.
SRC_REPO = 'deepghs/character_index'

# Module-wide filesystem handle returned by hfutils; used in this module to
# read text files out of the repository.
hf_fs = get_hf_fs()
|
|
|
|
|
@lru_cache()
def _make_index():
    """
    Load the character index data and build the search index.

    The tag metadata is read from ``index/tag_infos.json`` and the embedding
    matrix from ``index/embeddings.npy`` in the ``SRC_REPO`` dataset. The
    embeddings are then passed to ``build_index`` with ``save_on_disk=False``.
    Because of ``lru_cache``, repeated calls reuse the first call's result.

    :return: ``((index, index_infos), tag_infos)``, where the inner pair is
        what ``build_index`` returned and ``tag_infos`` is a numpy array
        built from the decoded JSON document.
    """
    # Tag metadata first, as a numpy array so it can be indexed by the
    # integer positions the search index hands back.
    tag_infos_text = hf_fs.read_text(f'datasets/{SRC_REPO}/index/tag_infos.json')
    tag_infos = np.array(json.loads(tag_infos_text))

    # Then the embedding matrix, downloaded through the hub client.
    embeddings_path = hf_hub_download(
        repo_id=SRC_REPO,
        repo_type='dataset',
        filename='index/embeddings.npy',
    )
    embeddings = np.load(embeddings_path)

    built_index, built_index_infos = build_index(embeddings, save_on_disk=False)
    return (built_index, built_index_infos), tag_infos
|
|
|
|
|
def gender_predict(p):
    """
    Reduce a pair of gender scores to a coarse label.

    :param p: Mapping providing numeric ``'boy'`` and ``'girl'`` entries.
    :return: ``'male'`` when the ``boy`` score leads ``girl`` by at least
        ``0.1``, ``'female'`` when ``girl`` leads ``boy`` by at least ``0.1``,
        otherwise ``'not_sure'``.
    """
    # Minimum lead one score needs over the other to be considered decisive.
    margin = 0.1

    if p['boy'] - p['girl'] >= margin:
        return 'male'
    if p['girl'] - p['boy'] >= margin:
        return 'female'
    return 'not_sure'
|
|
|
|
|
def query_character(image: Image.Image, count: int = 5):
    """
    Look up the indexed characters closest to the one shown in ``image``.

    A CCIP feature vector is extracted from the image, scaled to unit norm
    and searched against the index produced by ``_make_index``. For every
    hit, the character's ``1.webp`` preview is downloaded from ``SRC_REPO``.

    :param image: Image of the character to look up.
    :param count: Maximum number of neighbours requested from the index.
    :return: ``(images, df_records)``. ``images`` is a list of
        ``(preview_image, label)`` tuples, the label being
        ``'<tag> (<score>)'``. ``df_records`` is a ``pandas.DataFrame`` with
        the columns ``id``, ``tag``, ``gender``, ``copyright`` and ``score``.
        Both can hold fewer than ``count`` entries (or none) when the index
        returns fewer neighbours.
    """
    (index, _), tag_infos = _make_index()

    query = ccip_batch_extract_features([image])
    # Sanity check on the extractor output: one 768-dimensional vector.
    assert query.shape == (1, 768)
    # Scale to unit norm (for a single row the default matrix norm is the
    # row's L2 norm).
    query = query / np.linalg.norm(query)

    all_dists, all_indices = index.search(query, k=count)
    # Only one query row was submitted, so only the first result row matters.
    dists, indices = all_dists[0], all_indices[0]

    images, records = [], []
    for dist, idx in zip(dists, indices):
        # faiss-style indexes pad the result with label -1 when fewer than
        # ``k`` neighbours are available. Indexing ``tag_infos`` with -1
        # would silently return the LAST entry as a bogus match, so skip it.
        if idx < 0:
            continue

        info = tag_infos[idx]
        current_image = load_image(hf_hub_download(
            repo_id=SRC_REPO,
            repo_type='dataset',
            filename=f'{info["hprefix"]}/{info["short_tag"]}/1.webp'
        ))
        images.append((current_image, f'{info["tag"]} ({dist:.3f})'))
        records.append({
            'id': info['id'],
            'tag': info['tag'],
            'gender': gender_predict(info['gender']),
            'copyright': info['copyright'],
            # Raw value reported by the index for this neighbour.
            'score': dist,
        })

    # Explicit columns keep the frame's schema stable even with zero hits;
    # the order matches the key order of the record dicts above.
    df_records = pd.DataFrame(
        records,
        columns=['id', 'tag', 'gender', 'copyright', 'score'],
    )
    return images, df_records
|
|