File size: 2,844 Bytes
f474bbc
 
 
 
 
 
 
 
 
 
dfcc607
f474bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfcc607
f474bbc
 
 
 
 
 
 
dfcc607
f474bbc
 
 
 
 
 
 
dfcc607
 
 
 
 
 
 
f474bbc
 
 
 
 
dfcc607
 
 
f474bbc
 
 
dfcc607
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import json
from functools import lru_cache

import numpy as np
import pandas as pd
from PIL import Image
from autofaiss import build_index
from hfutils.operate import get_hf_fs
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.metrics import ccip_batch_extract_features, ccip_batch_differences, ccip_default_threshold

# Hugging Face repository id that holds the character index files used in this module.
SRC_REPO = 'deepghs/character_index'

# Shared filesystem handle obtained from hfutils; used in this module to read text files from SRC_REPO.
hf_fs = get_hf_fs()


@lru_cache()
def _make_index():
    """
    Load the tag metadata and embeddings from ``SRC_REPO`` and build a search index.

    The result is memoized with ``lru_cache``, so the remote reads and the index
    construction only happen on the first call.

    :return: A pair ``((index, index_infos), tag_infos)`` where ``index`` and
        ``index_infos`` are the two values returned by ``build_index`` and
        ``tag_infos`` is a numpy array built from ``index/tag_infos.json``.
    """
    # Tag metadata is a JSON document read through the HF filesystem handle.
    tag_infos_text = hf_fs.read_text(f'datasets/{SRC_REPO}/index/tag_infos.json')
    tag_infos = np.array(json.loads(tag_infos_text))

    # The embedding matrix is downloaded as a file and loaded from that local path.
    embeddings_path = hf_hub_download(
        repo_id=SRC_REPO,
        repo_type='dataset',
        filename='index/embeddings.npy',
    )
    embeddings = np.load(embeddings_path)

    # save_on_disk=False is passed so the index is not written out by build_index.
    built_index = build_index(embeddings, save_on_disk=False)
    index, index_infos = built_index
    return (index, index_infos), tag_infos


def gender_predict(p):
    """
    Turn a pair of ``'boy'`` / ``'girl'`` scores into a coarse gender label.

    :param p: Mapping with numeric ``'boy'`` and ``'girl'`` entries.
    :return: ``'male'`` when the boy score leads by at least 0.1, ``'female'``
        when the girl score leads by at least 0.1, otherwise ``'not_sure'``.
    """
    boy_score, girl_score = p['boy'], p['girl']
    if boy_score - girl_score >= 0.1:
        return 'male'
    if girl_score - boy_score >= 0.1:
        return 'female'
    # Neither side leads by the required margin.
    return 'not_sure'


def query_character(image: Image.Image, count: int = 5, order_by: str = 'same_ratio', threshold: float = 0.7):
    """
    Look up the indexed characters most similar to the given image.

    The image's CCIP feature is extracted, L2-normalized and searched against the
    cached index. For each hit, the character's sample image (``1.webp``) and
    stored features (``feat.npy``) are downloaded from ``SRC_REPO``, and the CCIP
    differences between the query and those stored features are summarized.

    :param image: Query image.
    :param count: Number of nearest neighbours requested from the index.
    :param order_by: Record column used for ranking and for the threshold filter;
        one of ``'same_ratio'``, ``'mean_diff'`` or ``'index_score'``.
    :param threshold: Rows whose ``order_by`` value is below this are dropped.
    :return: A pair ``(ret_images, df_records)``. ``ret_images`` is a list of
        ``(image, caption)`` tuples with caption ``'<tag> (<order_by value>)'``;
        ``df_records`` is a DataFrame with columns ``id``, ``tag``, ``gender``,
        ``copyright``, ``index_score``, ``mean_diff`` and ``same_ratio``, sorted
        descending by ``order_by`` (ties broken by ``index_score``) and filtered
        by ``threshold``. Both are empty when the index yields no usable hit.
    :raises KeyError: If ``order_by`` is not a record column (and there is at
        least one hit).

    NOTE(review): sorting is always descending and the filter is ``>=``; that
    suits ``'same_ratio'`` / ``'index_score'``, but for ``'mean_diff'`` (a
    difference measure) please confirm this direction is intended.
    """
    (index, _), tag_infos = _make_index()
    query = ccip_batch_extract_features([image])
    # Sanity check on the feature extractor output: one 768-dim vector.
    assert query.shape == (1, 768)
    query = query / np.linalg.norm(query)
    all_dists, all_indices = index.search(query, k=count)
    dists, indices = all_dists[0], all_indices[0]

    # Loop-invariant; fetched once instead of once per candidate.
    same_threshold = ccip_default_threshold()

    images, records = {}, []
    for dist, idx in zip(dists, indices):
        # faiss fills missing neighbours with label -1 when fewer than ``count``
        # results exist. A negative index would silently wrap around to the last
        # entry of ``tag_infos`` and report a bogus character, so skip it.
        if idx < 0:
            continue

        info = tag_infos[idx]
        current_image = load_image(hf_hub_download(
            repo_id=SRC_REPO,
            repo_type='dataset',
            filename=f'{info["hprefix"]}/{info["short_tag"]}/1.webp'
        ))
        feats = np.load(hf_hub_download(
            repo_id=SRC_REPO,
            repo_type='dataset',
            filename=f'{info["hprefix"]}/{info["short_tag"]}/feat.npy'
        ))
        # Row 0 of the pairwise matrix is the query; drop column 0 (query vs itself).
        diffs = ccip_batch_differences([query[0], *feats])[0, 1:]
        images[info['tag']] = current_image
        records.append({
            'id': info['id'],
            'tag': info['tag'],
            'gender': gender_predict(info['gender']),
            'copyright': info['copyright'],
            'index_score': dist,
            'mean_diff': diffs.mean(),
            'same_ratio': (diffs < same_threshold).mean(),
        })

    if not records:
        # A DataFrame built from an empty list has no columns, which would make
        # the sort below raise KeyError; return an empty, well-formed result.
        return [], pd.DataFrame(columns=[
            'id', 'tag', 'gender', 'copyright', 'index_score', 'mean_diff', 'same_ratio',
        ])

    # Rank by the requested column, using index_score as a tie-breaker.
    sort_keys = ['index_score'] if order_by == 'index_score' else [order_by, 'index_score']
    df_records = pd.DataFrame(records).sort_values(by=sort_keys, ascending=False)
    df_records = df_records[df_records[order_by] >= threshold]

    ret_images = [
        (images[row_item['tag']], f'{row_item["tag"]} ({row_item[order_by]:.3f})')
        for row_item in df_records.to_dict('records')
    ]
    return ret_images, df_records