tagger / genTag.py
MonkeyJuice's picture
change lib
7d3b3b8
raw
history blame
4.1 kB
#!/usr/bin/env python
from __future__ import annotations
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
# Hugging Face Hub repository that the tagger model and its label table are
# downloaded from (see Predictor.download_model).
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
# File names fetched from that repository: the ONNX model and the tag CSV.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split a tag table into tag names and per-category row positions.

    Args:
        dataframe: Table with a ``name`` column and a ``category`` column.

    Returns:
        A 4-tuple ``(tag_names, rating_indexes, general_indexes,
        character_indexes)``. ``tag_names`` is the ``name`` column as a list;
        the three index lists hold the 0-based row positions whose
        ``category`` equals 9, 0 and 4 respectively, as plain ``int`` values.
    """
    tag_names = dataframe["name"].tolist()
    categories = dataframe["category"].to_numpy()

    def positions(category_id: int) -> list[int]:
        # Row positions (not index labels) of the rows in this category.
        return np.flatnonzero(categories == category_id).tolist()

    return tag_names, positions(9), positions(0), positions(4)
class Predictor:
    """Tag images with an ONNX tagger model fetched from the Hugging Face Hub.

    Constructing an instance downloads and loads the model named by
    ``EVA02_LARGE_MODEL_DSV3_REPO``; afterwards ``predict`` can be called
    repeatedly on PIL images.
    """

    def __init__(self):
        # Set by load_model() from the model's declared input shape.
        self.model_target_size = None
        self.load_model(EVA02_LARGE_MODEL_DSV3_REPO)

    def download_model(self, model_repo):
        """Fetch the label CSV and the ONNX model file from ``model_repo``.

        Returns:
            ``(csv_path, model_path)`` as returned by ``hf_hub_download``.
        """
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Download ``model_repo`` and set up labels, session and input size.

        Sets ``tag_names``, ``rating_indexes``, ``general_indexes``,
        ``character_indexes``, ``model`` and ``model_target_size``.
        """
        csv_path, model_path = self.download_model(model_repo)

        (
            self.tag_names,
            self.rating_indexes,
            self.general_indexes,
            self.character_indexes,
        ) = load_labels(pd.read_csv(csv_path))

        model = rt.InferenceSession(model_path)
        # The first input's shape is unpacked as four entries -- presumably
        # (batch, height, width, channels); TODO confirm against the model.
        # Only the height is kept: prepare_image() resizes to a square of
        # that size.
        _, height, _, _ = model.get_inputs()[0].shape
        self.model_target_size = height
        self.model = model

    def prepare_image(self, image):
        """Convert a PIL image into the float32 batch array fed to the model.

        The image is flattened onto a white background, padded with white to
        a centred square, resized to ``model_target_size`` and returned with
        its channel axis reversed (RGB -> BGR).

        Args:
            image: A PIL image in any mode convertible to RGBA.

        Returns:
            A float32 array of shape ``(1, size, size, 3)`` where ``size`` is
            ``model_target_size``.
        """
        target_size = self.model_target_size

        # alpha_composite() only accepts RGBA images; convert other modes
        # (RGB, L, P, ...) first instead of failing with ValueError.
        if image.mode != "RGBA":
            image = image.convert("RGBA")

        # Flatten any transparency onto a white canvas.
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad to a square, keeping the picture centred.
        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize only when the square is not already the model's input size.
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        image_array = np.asarray(padded_image, dtype=np.float32)
        # Convert PIL-native RGB to BGR by reversing the channel axis.
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def predict(self, image, general_thresh: float) -> dict:
        """Run the model on ``image`` and return the tags that apply.

        Args:
            image: PIL image, passed through ``prepare_image``.
            general_thresh: A general tag is kept only when its score is
                strictly greater than this value.

        Returns:
            A dict mapping tag name to score, ordered by descending score.
            It holds every general tag above the threshold plus one
            ``"rating:<name>"`` entry for the highest-scoring rating tag
            (``"general"`` is reported as ``"rating:safe"``). When the label
            table has no rating tags, the rating entry is omitted. Character
            tags are not reported.
        """
        batch = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        output_name = self.model.get_outputs()[0].name
        preds = self.model.run([output_name], {input_name: batch})[0]

        # Pair every tag name with its score for the single image in the batch.
        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # General tags: keep those whose score exceeds the threshold.
        tags = {}
        for i in self.general_indexes:
            name, score = labels[i]
            if score > general_thresh:
                tags[name] = score

        # Rating tags: report only the best one. max() returns the first
        # maximum, matching a stable descending sort.
        rating_scores = dict(labels[i] for i in self.rating_indexes)
        if rating_scores:
            top_rating, top_score = max(
                rating_scores.items(),
                key=lambda item: item[1],
            )
            if top_rating == "general":
                rating_key = "rating:safe"
            else:
                rating_key = "rating:" + top_rating
            tags[rating_key] = top_score

        ordered = sorted(tags.items(), key=lambda item: item[1], reverse=True)
        return dict(ordered)
# Module-level predictor used by genTag. Constructing it downloads and loads
# the model, so this runs when the module is imported.
predictor = Predictor()
def genTag(image: Image.Image, score_threshold: float) -> dict[str, float]:
    """Tag ``image`` using the module-level ``predictor``.

    Args:
        image: PIL image to tag.
        score_threshold: General tags are kept only when their score is
            strictly greater than this value.

    Returns:
        The result of ``predictor.predict``: a mapping of tag name to score,
        ordered by descending score.
    """
    # The annotation uses ``Image`` (imported via ``from PIL import Image``);
    # the bare name ``PIL`` is never imported in this module.
    return predictor.predict(image, score_threshold)