File size: 1,951 Bytes
f474bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import json
from functools import lru_cache

import numpy as np
import pandas as pd
from PIL import Image
from autofaiss import build_index
from hfutils.operate import get_hf_fs
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.metrics import ccip_batch_extract_features

# Hugging Face dataset repository holding the index files and preview images.
SRC_REPO = "deepghs/character_index"

# Hugging Face filesystem handle created once at import time; used below to read
# text files from the dataset repository.
hf_fs = get_hf_fs()


@lru_cache()
def _make_index():
    """Fetch the character index data from ``SRC_REPO`` and build a search index.

    The result is memoised with ``lru_cache``, so subsequent calls reuse the
    objects built on the first call instead of re-downloading.

    :return: ``((index, index_infos), tag_infos)``. ``index`` and
        ``index_infos`` are the two values returned by
        ``autofaiss.build_index``. ``tag_infos`` is a numpy array built from
        the parsed ``index/tag_infos.json`` file.
    """
    tag_info_path = f'datasets/{SRC_REPO}/index/tag_infos.json'
    raw_tag_infos = json.loads(hf_fs.read_text(tag_info_path))
    tag_infos = np.array(raw_tag_infos)

    embeddings_file = hf_hub_download(
        repo_id=SRC_REPO,
        repo_type='dataset',
        filename='index/embeddings.npy',
    )
    embeddings = np.load(embeddings_file)

    # Build purely in memory; nothing is written to disk.
    faiss_index, faiss_infos = build_index(embeddings, save_on_disk=False)
    return (faiss_index, faiss_infos), tag_infos


def gender_predict(p):
    """Classify gender from the ``'boy'`` / ``'girl'`` scores in ``p``.

    :param p: Mapping with numeric ``'boy'`` and ``'girl'`` entries.
    :return: ``'male'`` if the boy score exceeds the girl score by at least
        0.1, ``'female'`` if the girl score exceeds the boy score by at least
        0.1, and ``'not_sure'`` otherwise.
    """
    boy_score, girl_score = p['boy'], p['girl']
    if boy_score - girl_score >= 0.1:
        return 'male'
    if girl_score - boy_score >= 0.1:
        return 'female'
    # Neither side leads by the required margin.
    return 'not_sure'


def query_character(image: Image.Image, count: int = 5):
    """Find the indexed characters whose CCIP features best match ``image``.

    The image is embedded with ``ccip_batch_extract_features`` and divided by
    its L2 norm. It is then searched against the index built by
    ``_make_index``. For each hit, the preview file
    ``{hprefix}/{short_tag}/1.webp`` is downloaded from ``SRC_REPO``.

    :param image: Query image.
    :param count: Number of neighbours to request from the index. Fewer rows
        may be returned if the index yields fewer valid hits.
    :return: ``(images, df_records)``. ``images`` is a list of
        ``(preview_image, caption)`` tuples, where the caption is
        ``"<tag> (<score>)"``. ``df_records`` is a DataFrame with columns
        ``id``, ``tag``, ``gender``, ``copyright`` and ``score``, one row per
        hit, in the order returned by the index.
    """
    (index, index_infos), tag_infos = _make_index()
    query = ccip_batch_extract_features([image])
    assert query.shape == (1, 768)
    query = query / np.linalg.norm(query)
    all_dists, all_indices = index.search(query, k=count)
    dists, indices = all_dists[0], all_indices[0]

    images, records = [], []
    for dist, idx in zip(dists, indices):
        # Faiss pads the results with label -1 when it finds fewer than ``k``
        # neighbours (e.g. ``count`` larger than the index, or an IVF-style
        # index probing too few lists). Indexing ``tag_infos`` with -1 would
        # silently return the last tag as a bogus hit, so skip those slots.
        if idx < 0:
            continue
        info = tag_infos[idx]
        current_image = load_image(hf_hub_download(
            repo_id=SRC_REPO,
            repo_type='dataset',
            filename=f'{info["hprefix"]}/{info["short_tag"]}/1.webp'
        ))
        images.append((current_image, f'{info["tag"]} ({dist:.3f})'))
        records.append({
            'id': info['id'],
            'tag': info['tag'],
            'gender': gender_predict(info['gender']),
            'copyright': info['copyright'],
            'score': dist,
        })

    # Explicit columns keep the schema stable even when no hit survives;
    # for non-empty results this matches the dict key order above.
    df_records = pd.DataFrame(records, columns=['id', 'tag', 'gender', 'copyright', 'score'])
    return images, df_records