#!/usr/bin/env python3 | |
import argparse | |
import logging | |
from dataclasses import dataclass | |
from os import PathLike | |
from pathlib import Path | |
from typing import Generator, Optional, Tuple | |
import numpy as np | |
import onnxruntime as rt | |
from huggingface_hub import hf_hub_download | |
from huggingface_hub.utils import HfHubHTTPError | |
from pandas import DataFrame, read_csv | |
from PIL import Image | |
from torch.utils.data import DataLoader, Dataset | |
from tqdm import tqdm | |
# Image file extensions we will attempt to load (lower-case, with leading dot).
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
# Edge length in pixels of the square model input.
IMAGE_SIZE = 448
# Short variant name -> HuggingFace Hub repo id of the WD v3 tagger model.
MODEL_VARIANTS: dict[str, str] = {
    "swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
    "convnext": "SmilingWolf/wd-convnext-tagger-v3",
    "vit": "SmilingWolf/wd-vit-tagger-v3",
}
@dataclass
class LabelData:
    """Tag names plus per-category index lists into the model output vector.

    Attributes:
        names: all tag names, in model-output order.
        rating: indices of rating tags (CSV category 9).
        general: indices of general tags (CSV category 0).
        character: indices of character tags (CSV category 4).
    """

    # The @dataclass decorator was missing: the class was instantiated with
    # keyword arguments (see load_labels_hf), which raised a TypeError.
    names: list[str]
    rating: list[np.int64]
    general: list[np.int64]
    character: list[np.int64]
@dataclass
class ImageLabels:
    """Tagging result for a single image.

    Attributes:
        caption: comma-separated tag names suitable as a training caption.
        booru: caption with underscores spaced and parentheses escaped.
        rating: name of the highest-confidence rating tag.
        general: general tags above threshold, name -> confidence.
        character: character tags above threshold, name -> confidence.
        ratings: all rating tags, name -> confidence.
    """

    # The @dataclass decorator was missing: ImageLabeler.label_images builds
    # this with keyword arguments, which raised a TypeError.
    caption: str
    booru: str
    rating: str
    general: dict[str, float]
    character: dict[str, float]
    ratings: dict[str, float]
# Configure root logging at WARNING (quiets chatty libraries), then raise the
# root logger itself to INFO so this script's own messages are emitted.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger()
logger.setLevel(logging.INFO)
## Model loading functions | |
def download_onnx(
    repo_id: str,
    filename: str = "model.onnx",
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> Path:
    """Download an ONNX model file from the HuggingFace Hub.

    A missing ``.onnx`` suffix on *filename* is appended automatically.
    Returns the resolved path of the locally cached file.
    """
    if not filename.endswith(".onnx"):
        filename = f"{filename}.onnx"
    cached = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, token=token)
    return Path(cached).resolve()
def create_session(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> rt.InferenceSession:
    """Download the model for *repo_id* and open an ONNX Runtime session.

    Prefers the CUDA execution provider, falling back to CPU.

    Raises:
        FileNotFoundError: if no model file exists after download.
    """
    model_path = download_onnx(repo_id, revision=revision, token=token)
    if not model_path.is_file():
        # the download may resolve to a directory; look for the model inside
        model_path = model_path.joinpath("model.onnx")
        if not model_path.is_file():
            raise FileNotFoundError(f"Model not found: {model_path}")
    return rt.InferenceSession(
        str(model_path),
        providers=[("CUDAExecutionProvider", {}), "CPUExecutionProvider"],
    )
## Label loading function | |
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from *repo_id* and split it by category.

    Category codes in the CSV: 9 = rating, 0 = general, 4 = character.

    Raises:
        FileNotFoundError: if the CSV cannot be downloaded.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
        csv_path = Path(downloaded).resolve()
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
    tags: DataFrame = read_csv(csv_path, usecols=["name", "category"])
    return LabelData(
        names=tags["name"].tolist(),
        rating=list(np.where(tags["category"] == 9)[0]),
        general=list(np.where(tags["category"] == 0)[0]),
        character=list(np.where(tags["category"] == 4)[0]),
    )
## Image preprocessing functions | |
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Normalize any PIL image to RGB.

    Palette and other exotic modes are converted first (via RGBA when they
    carry transparency); any alpha channel is then flattened onto a white
    background.
    """
    if image.mode not in ["RGB", "RGBA"]:
        has_transparency = "transparency" in image.info
        image = image.convert("RGBA" if has_transparency else "RGB")
    if image.mode == "RGBA":
        # composite onto opaque white, then drop the alpha channel
        background = Image.new("RGBA", image.size, (255, 255, 255))
        background.alpha_composite(image)
        image = background.convert("RGB")
    return image
def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Center *image* on a square canvas whose side is its longer dimension.

    Padding pixels take the *fill* color (white by default).
    """
    width, height = image.size
    side = max(width, height)
    padded = Image.new("RGB", (side, side), fill)
    padded.paste(image, ((side - width) // 2, (side - height) // 2))
    return padded
def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Preprocess an image to be square and centered on a white background.

    Args:
        image: source image, any PIL mode.
        size_px: target edge length, or an explicit (width, height) pair.
        upscale: when False, refuse to enlarge images below the target size.

    Raises:
        ValueError: if the image is too small and *upscale* is disabled.
    """
    target = (size_px, size_px) if isinstance(size_px, int) else size_px
    # normalize to RGB and pad out to a square before resizing
    image = pil_pad_square(pil_ensure_rgb(image))
    if image.size[0] < target[0] or image.size[1] < target[1]:
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        image = image.resize(target, Image.LANCZOS)
    if image.size[0] > target[0] or image.size[1] > target[1]:
        # shrink in place, preserving aspect ratio
        image.thumbnail(target, Image.BICUBIC)
    return image
## Dataset for DataLoader | |
class ImageDataset(Dataset):
    """Torch dataset yielding preprocessed images as float32 arrays.

    Each item is a dict with keys ``"image"`` (H,W,C float32, BGR channel
    order — what the WD taggers expect) and ``"path"`` (UTF-8 bytes of the
    source file path). Unreadable images yield None so a filtering
    collate_fn can drop them.
    """

    def __init__(self, image_paths: list[Path], size_px: int = IMAGE_SIZE, upscale: bool = True):
        self.size_px = size_px
        self.upscale = upscale
        # keep only files whose extension is a known image type
        self.images = [candidate for candidate in image_paths if candidate.suffix.lower() in IMAGE_EXTENSIONS]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path: Path = self.images[idx]
        try:
            loaded = Image.open(image_path)
            loaded = preprocess_image(loaded, self.size_px, self.upscale)
            # convert to BGR24 so the stacked batch is N,H,W,C in BGR order
            image = np.array(loaded.convert("RGB").convert("BGR;24")).astype(np.float32)
        except Exception as e:
            logging.exception(f"Could not load image from {image_path}, error: {e}")
            return None
        return {"image": image, "path": np.array(str(image_path).encode("utf-8"), dtype=np.bytes_)}
def collate_fn_remove_corrupted(batch):
    """Collate function that drops corrupted samples from a DataLoader batch.

    The dataset returns None for images it could not load; those entries are
    removed here. Returns None when every sample was corrupted, otherwise a
    dict of stacked numpy arrays keyed like the individual samples.
    """
    valid = [sample for sample in batch if sample is not None]
    if not valid:
        return None
    # stack each field across the surviving samples
    return {key: np.array([sample[key] for sample in valid]) for key in valid[0]}
## Main function | |
class ImageLabeler:
    """Wraps a WD v3 tagger ONNX session and converts raw output to tags."""

    def __init__(
        self,
        repo_id: Optional[str] = None,
        general_threshold: float = 0.35,
        character_threshold: float = 0.35,
        banned_tags: Optional[list[str]] = None,
    ):
        """
        Args:
            repo_id: HuggingFace Hub repo id of the tagger model.
            general_threshold: minimum confidence for general tags.
            character_threshold: minimum confidence for character tags.
            banned_tags: tag names to exclude from all output (default: none).
                NOTE: was a mutable default argument ([]); now None-sentinel.
        """
        self.repo_id = repo_id
        # create some object attributes for convenience
        self.general_threshold = general_threshold
        self.character_threshold = character_threshold
        self.banned_tags = banned_tags if banned_tags is not None else []
        # actually load the model
        logging.info(f"Loading model from path: {self.repo_id}")
        self.model = create_session(self.repo_id)
        # Get input dimensions (model input is N,H,W,C)
        _, self.height, self.width, _ = self.model.get_inputs()[0].shape
        logging.info(f"Model loaded, input dimensions {self.height}x{self.width}")
        # Load labels, dropping banned tags by *name*. The index lists hold
        # np.int64 positions, so the original comparison `i not in banned_tags`
        # matched indices against tag strings and never removed anything.
        self.labels = load_labels_hf(self.repo_id)
        self.labels.general = [
            i for i in self.labels.general if self.labels.names[i] not in self.banned_tags
        ]
        self.labels.character = [
            i for i in self.labels.character if self.labels.names[i] not in self.banned_tags
        ]
        logging.info(f"Loaded labels from {self.repo_id}")

    def input_size(self) -> Tuple[int, int]:
        """Return the (height, width) the model expects."""
        return (self.height, self.width)

    def input_name(self) -> str:
        """Return the name of the model's input tensor (None if no model)."""
        return self.model.get_inputs()[0].name if self.model is not None else None

    def output_name(self) -> str:
        """Return the name of the model's output tensor (None if no model)."""
        return self.model.get_outputs()[0].name if self.model is not None else None

    def label_images(self, images: np.ndarray) -> list[ImageLabels]:
        """Run the model on a batch and convert probabilities to tags.

        Args:
            images: float32 batch shaped (N, H, W, C) in BGR channel order.

        Returns:
            One ImageLabels per input image.
        """
        # Run the ONNX model. input_name/output_name are methods and must be
        # *called* — the original passed the bound methods themselves, which
        # onnxruntime rejects at runtime.
        probs: np.ndarray = self.model.run([self.output_name()], {self.input_name(): images})[0]
        # Convert to labels
        results = []
        for sample in list(probs):
            labels = list(zip(self.labels.names, sample.astype(float)))
            # Rating labels: pick the single highest-confidence one
            rating_labels = dict([labels[i] for i in self.labels.rating])
            rating = max(rating_labels, key=rating_labels.get)
            # General labels: keep any above threshold, sorted by confidence
            gen_labels = [labels[i] for i in self.labels.general]
            gen_labels = dict([x for x in gen_labels if x[1] > self.general_threshold])
            gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
            # Character labels: keep any above threshold, sorted by confidence
            char_labels = [labels[i] for i in self.labels.character]
            char_labels = dict([x for x in char_labels if x[1] > self.character_threshold])
            char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
            # Combine general and character tag names into a caption string
            combined_names = [x for x in gen_labels]
            combined_names.extend([x for x in char_labels])
            caption = ", ".join(combined_names)
            # booru style: spaces instead of underscores, escaped parentheses
            # (escapes now written as "\\(" — "\(" is an invalid escape sequence)
            booru = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")
            results.append(
                ImageLabels(
                    caption=caption,
                    booru=booru,
                    rating=rating,
                    general=gen_labels,
                    character=char_labels,
                    ratings=rating_labels,
                )
            )
        return results

    def __call__(self, images: list[Image.Image]) -> Generator[ImageLabels, None, None]:
        # NOTE(review): label_images expects a preprocessed N,H,W,C float32
        # batch, but this passes each raw PIL image straight through — callers
        # appear to be expected to preprocess first (see ImageDataset).
        # Confirm intended usage before relying on this entry point.
        for x in images:
            yield self.label_images(x)
def main(args):
    """Tag every image under args.images_dir and write caption sidecar files.

    Raises:
        FileNotFoundError: if images_dir does not exist.
        ValueError: if args.variant is not a known model variant.
    """
    images_dir: Path = Path(args.images_dir).resolve()
    if not images_dir.is_dir():
        raise FileNotFoundError(f"Directory not found: {images_dir}")

    variant: str = args.variant
    recursive: bool = args.recursive or False
    # drop empty entries so the default "" does not put "" into the ban set
    banned_tags: set[str] = {t for t in args.banned_tags.split(",") if t}
    caption_extension: str = str(args.caption_extension).lower()
    # Path.with_suffix() raises ValueError without a leading dot
    if caption_extension and not caption_extension.startswith("."):
        caption_extension = "." + caption_extension
    print_freqs: bool = args.print_freqs or False
    num_workers: int = args.num_workers
    batch_size: int = args.batch_size
    remove_underscore: bool = args.remove_underscore or False
    # explicit `is not None` so a deliberate threshold of 0.0 is honored
    # (the original `x or args.thresh` silently discarded falsy 0.0)
    general_threshold: float = (
        args.general_threshold if args.general_threshold is not None else args.thresh
    )
    character_threshold: float = (
        args.character_threshold if args.character_threshold is not None else args.thresh
    )
    debug: bool = args.debug or False

    # turn base model variant into a HF Hub repo id
    repo_id: str = MODEL_VARIANTS.get(variant, None)
    if repo_id is None:
        raise ValueError(f"Unknown base model '{variant}'")

    # Collect image files. rglob("*") is already fully recursive; the
    # original rglob("**/*") applied ** twice.
    print(f"Loading images from {images_dir}...", end=" ")
    glob_fn = images_dir.rglob if recursive is True else images_dir.glob
    image_paths = [p for p in glob_fn("*") if p.suffix.lower() in IMAGE_EXTENSIONS]
    n_images = len(image_paths)
    print(f"found {n_images} images to process, creating DataLoader...")

    # sort by filename if we have a small number of images
    if n_images < 10000:
        image_paths = sorted(image_paths, key=lambda x: x.stem)
    dataset = ImageDataset(image_paths)

    # Create the data loader. prefetch_factor is only valid with worker
    # processes; passing it with num_workers=0 raises in torch.
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn_remove_corrupted,
        drop_last=False,
        prefetch_factor=3 if num_workers > 0 else None,
    )

    # Create the image labeler
    labeler: ImageLabeler = ImageLabeler(
        repo_id=repo_id,
        character_threshold=character_threshold,
        general_threshold=general_threshold,
        banned_tags=banned_tags,
    )

    # overall tag frequencies (only collected with --print_freqs)
    tag_freqs = {}
    for batch in tqdm(dataloader, ncols=100, unit="image", unit_scale=batch_size):
        if batch is None:
            # collate_fn returns None when every image in the batch was
            # corrupted; the original crashed here on batch["image"]
            continue
        images = batch["image"]
        paths = batch["path"]
        # label the images
        batch_labels = labeler.label_images(images)
        # save the labels
        for image_labels, image_path in zip(batch_labels, paths):
            if isinstance(image_path, (np.bytes_, bytes)):
                image_path = Path(image_path.decode("utf-8"))
            caption = image_labels.caption
            if remove_underscore is True:
                caption = caption.replace("_", " ")
            # write the caption next to the image file
            Path(image_path).with_suffix(caption_extension).write_text(caption + "\n", encoding="utf-8")
            # accumulate tag frequencies
            if print_freqs is True:
                for tag in caption.split(", "):
                    if tag in banned_tags:
                        continue
                    tag_freqs[tag] = tag_freqs.get(tag, 0) + 1
            # debug
            if debug is True:
                print(
                    f"{image_path}:"
                    + f"\n Character tags: {image_labels.character}"
                    + f"\n General tags: {image_labels.general}"
                )

    if print_freqs:
        sorted_tags = sorted(tag_freqs.items(), key=lambda x: x[1], reverse=True)
        print("\nTag frequencies:")
        for tag, freq in sorted_tags:
            print(f"{tag}: {freq}")

    print("done!")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # nargs="?" makes the directory optional so the temp/test fallback below
    # is actually reachable (a required positional can never be None, which
    # made the original `if args.images_dir is None` check dead code)
    parser.add_argument(
        "images_dir",
        type=str,
        nargs="?",
        default=None,
        help="directory to tag image files in",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default="swinv2",
        help="name of base model to use (one of 'swinv2', 'convnext', 'vit')",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="number of threads to use in Torch DataLoader (4 should be plenty)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="batch size for Torch DataLoader (use 1 for cpu, 4-32 for gpu)",
    )
    parser.add_argument(
        "--caption_extension",
        type=str,
        default=".txt",
        help="extension of caption files to write (e.g. '.txt', '.caption')",
    )
    parser.add_argument(
        "--thresh",
        type=float,
        default=0.35,
        help="confidence threshold for adding tags",
    )
    parser.add_argument(
        "--general_threshold",
        type=float,
        default=None,
        help="confidence threshold for general tags - defaults to --thresh",
    )
    parser.add_argument(
        "--character_threshold",
        type=float,
        default=None,
        help="confidence threshold for character tags - defaults to --thresh",
    )
    parser.add_argument(
        "--recursive",
        action="store_true",
        help="whether to recurse into subdirectories of images_dir",
    )
    parser.add_argument(
        "--remove_underscore",
        action="store_true",
        help="whether to remove underscores from tags (e.g. 'long_hair' -> 'long hair')",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="enable debug logging mode",
    )
    parser.add_argument(
        "--banned_tags",
        type=str,
        default="",
        help="tags to filter out (comma-separated)",
    )
    parser.add_argument(
        "--print_freqs",
        action="store_true",
        help="Print overall tag frequencies at the end",
    )
    args = parser.parse_args()
    # fall back to a local test directory when no path was given
    if args.images_dir is None:
        args.images_dir = Path.cwd().joinpath("temp/test")
    main(args)