"""Nearest-character lookup backed by the ``deepghs/character_index`` dataset."""
import json
from functools import lru_cache

import numpy as np
import pandas as pd
from PIL import Image
from autofaiss import build_index
from hfutils.operate import get_hf_fs
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.metrics import ccip_batch_extract_features

SRC_REPO = 'deepghs/character_index'

hf_fs = get_hf_fs()

# Column order of the DataFrame returned by ``query_character``. Fixing it
# explicitly keeps the schema stable even when no neighbours are returned.
_RECORD_COLUMNS = ['id', 'tag', 'gender', 'copyright', 'score']


@lru_cache()
def _make_index():
    """Download the character index data and build a FAISS index over it.

    The result is cached, so the download and index construction happen at
    most once per process.

    :return: ``((index, index_infos), tag_infos)``. ``index`` and
        ``index_infos`` are the outputs of :func:`autofaiss.build_index`.
        ``tag_infos`` is a NumPy object array of the metadata dicts from
        ``index/tag_infos.json``. It is presumably row-aligned with
        ``index/embeddings.npy`` (TODO confirm against the dataset).
    """
    tag_infos = np.array(json.loads(hf_fs.read_text(f'datasets/{SRC_REPO}/index/tag_infos.json')))
    embeddings = np.load(hf_hub_download(
        repo_id=SRC_REPO,
        repo_type='dataset',
        filename='index/embeddings.npy',
    ))
    index, index_infos = build_index(embeddings, save_on_disk=False)
    return (index, index_infos), tag_infos


def gender_predict(p, threshold: float = 0.1):
    """Turn a ``{'boy': ..., 'girl': ...}`` score mapping into a gender label.

    :param p: Mapping with numeric ``'boy'`` and ``'girl'`` entries.
    :param threshold: Minimum margin by which one score must exceed the other
        for a decisive label (default ``0.1``).
    :return: ``'male'`` if ``boy - girl >= threshold``, ``'female'`` if
        ``girl - boy >= threshold``, otherwise ``'not_sure'``.
    """
    if p['boy'] - p['girl'] >= threshold:
        return 'male'
    elif p['girl'] - p['boy'] >= threshold:
        return 'female'
    else:
        return 'not_sure'


def _load_preview_image(info):
    """Download and load the ``1.webp`` image stored for one index entry."""
    return load_image(hf_hub_download(
        repo_id=SRC_REPO,
        repo_type='dataset',
        filename=f'{info["hprefix"]}/{info["short_tag"]}/1.webp',
    ))


def query_character(image: Image.Image, count: int = 5):
    """Find the indexed characters closest to ``image``.

    :param image: Query image. It is embedded with
        :func:`imgutils.metrics.ccip_batch_extract_features`.
    :param count: Maximum number of neighbours to return. Must be >= 1.
    :return: ``(images, df_records)``.

        - ``images`` is a list of ``(preview_image, caption)`` tuples, where
          the caption is ``"<tag> (<score>)"``.
        - ``df_records`` is a DataFrame with columns ``id``, ``tag``,
          ``gender``, ``copyright`` and ``score``.

        ``score`` is the raw value reported by ``index.search``. With
        autofaiss's default inner-product metric, higher presumably means
        more similar (TODO confirm). Fewer than ``count`` rows are returned
        if the index cannot supply that many neighbours.
    :raises ValueError: If ``count`` is less than 1.
    :raises RuntimeError: If the extracted feature does not have shape
        ``(1, 768)``.
    """
    if count < 1:
        raise ValueError(f'count must be a positive integer, but {count!r} given.')

    (index, _), tag_infos = _make_index()

    query = ccip_batch_extract_features([image])
    # Explicit check rather than ``assert``, so it still runs under ``python -O``.
    if query.shape != (1, 768):
        raise RuntimeError(f'Unexpected CCIP feature shape {query.shape!r}, expected (1, 768).')
    # Normalise the single query row to unit length before searching.
    query = query / np.linalg.norm(query)

    all_dists, all_indices = index.search(query, k=count)

    images, records = [], []
    for dist, idx in zip(all_dists[0], all_indices[0]):
        # FAISS pads the result with index -1 when it has fewer than ``k``
        # hits. Without this check, ``tag_infos[-1]`` would silently return
        # the last character as a bogus match.
        if idx < 0:
            continue
        info = tag_infos[idx]
        images.append((_load_preview_image(info), f'{info["tag"]} ({dist:.3f})'))
        records.append({
            'id': info['id'],
            'tag': info['tag'],
            'gender': gender_predict(info['gender']),
            'copyright': info['copyright'],
            'score': dist,
        })

    df_records = pd.DataFrame(records, columns=_RECORD_COLUMNS)
    return images, df_records