narugo's picture
dev(narugo): init commit
f474bbc
raw
history blame
1.95 kB
import json
from functools import lru_cache
import numpy as np
import pandas as pd
from PIL import Image
from autofaiss import build_index
from hfutils.operate import get_hf_fs
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.metrics import ccip_batch_extract_features
# Hugging Face dataset repository holding the character index files
# (index/tag_infos.json, index/embeddings.npy and per-character sample images).
SRC_REPO = "deepghs/character_index"
# Filesystem client from hfutils, created once at import time and reused below.
hf_fs = get_hf_fs()
@lru_cache()
def _make_index():
    """Download the character index data and build a faiss index from it.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    download and index build happen on the first call only; later calls get
    the cached objects.

    :return: ``((index, index_infos), tag_infos)``. The inner pair is exactly
        what ``autofaiss.build_index`` returned. ``tag_infos`` is a numpy array
        built from the list parsed out of ``index/tag_infos.json``.
    """
    infos_text = hf_fs.read_text(f'datasets/{SRC_REPO}/index/tag_infos.json')
    # Presumably row i describes embedding row i: query_character indexes
    # this array with the ids returned by the faiss search -- verify upstream.
    tag_infos = np.array(json.loads(infos_text))

    embeddings_file = hf_hub_download(
        repo_id=SRC_REPO,
        repo_type='dataset',
        filename='index/embeddings.npy',
    )
    embeddings = np.load(embeddings_file)

    # save_on_disk=False: autofaiss option to skip writing the index to disk.
    faiss_index, faiss_meta = build_index(embeddings, save_on_disk=False)
    return (faiss_index, faiss_meta), tag_infos
def gender_predict(p, *, threshold: float = 0.1):
    """Turn a pair of boy/girl scores into a coarse gender label.

    :param p: Mapping with numeric ``'boy'`` and ``'girl'`` entries.
    :param threshold: Minimum margin by which one score must exceed the other
        for a decision to be made. Should be positive. Defaults to ``0.1``.
    :return: ``'male'`` if ``p['boy']`` leads ``p['girl']`` by at least
        ``threshold``, ``'female'`` if ``p['girl']`` leads by at least
        ``threshold``, otherwise ``'not_sure'``.
    """
    # Small tolerance so a margin mathematically equal to the threshold is
    # accepted despite float rounding: 0.6 - 0.5 evaluates to
    # 0.09999999999999998, which a plain ``>= 0.1`` would reject.
    cutoff = threshold - 1e-9
    margin = p['boy'] - p['girl']
    if margin >= cutoff:
        return 'male'
    if -margin >= cutoff:
        return 'female'
    return 'not_sure'
def query_character(image: Image.Image, count: int = 5):
    """Find the indexed characters whose CCIP embedding is closest to ``image``.

    :param image: Query image; it is embedded with ``ccip_batch_extract_features``.
    :param count: Maximum number of hits to return. Values below 1 return
        empty results without searching. Fewer than ``count`` hits are
        returned when the index holds fewer entries.
    :return: ``(images, df_records)``. ``images`` is a list of
        ``(preview_image, "<tag> (<score>)")`` pairs in search-result order,
        where ``preview_image`` is the character's ``1.webp`` sample
        downloaded from ``SRC_REPO``. ``df_records`` is a DataFrame with one
        row per hit and the columns ``id``, ``tag``, ``gender``,
        ``copyright`` and ``score``.
    """
    columns = ['id', 'tag', 'gender', 'copyright', 'score']
    if count < 1:
        # k < 1 is not a meaningful search (faiss indexes generally reject it);
        # return an empty result that still carries the expected columns.
        return [], pd.DataFrame([], columns=columns)

    (index, _index_infos), tag_infos = _make_index()
    query = ccip_batch_extract_features([image])
    assert query.shape == (1, 768)
    # L2-normalise the single query row. NOTE(review): presumably the index
    # was built with an inner-product metric so scores act as cosine
    # similarities -- confirm against the build_index settings.
    query = query / np.linalg.norm(query)
    all_dists, all_indices = index.search(query, k=count)

    images, records = [], []
    for dist, idx in zip(all_dists[0], all_indices[0]):
        # faiss pads results with id -1 when the index has fewer than `count`
        # entries; without this check tag_infos[-1] would silently yield the
        # last character with a meaningless score.
        if idx < 0:
            continue
        info = tag_infos[idx]
        preview = load_image(hf_hub_download(
            repo_id=SRC_REPO,
            repo_type='dataset',
            filename=f'{info["hprefix"]}/{info["short_tag"]}/1.webp'
        ))
        score = float(dist)  # plain Python float instead of a numpy scalar
        images.append((preview, f'{info["tag"]} ({score:.3f})'))
        records.append({
            'id': info['id'],
            'tag': info['tag'],
            'gender': gender_predict(info['gender']),
            'copyright': info['copyright'],
            'score': score,
        })

    # Explicit columns keep the schema stable even when no hits survive.
    return images, pd.DataFrame(records, columns=columns)