neggles committed on
Commit
810354a
·
1 Parent(s): c34a266

add single file script

Browse files
Files changed (1) hide show
  1. scripts/wdtagger3-onnx.py +475 -0
scripts/wdtagger3-onnx.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from os import PathLike
6
+ from pathlib import Path
7
+ from typing import Generator, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import onnxruntime as rt
11
+ from huggingface_hub import hf_hub_download
12
+ from huggingface_hub.utils import HfHubHTTPError
13
+ from pandas import DataFrame, read_csv
14
+ from PIL import Image
15
+ from torch.utils.data import DataLoader, Dataset
16
+ from tqdm import tqdm
17
+
18
# File suffixes (lower-case, dot included) that are accepted as images.
IMAGE_EXTENSIONS = [
    ".jpg",
    ".jpeg",
    ".png",
    ".gif",
    ".webp",
    ".bmp",
    ".tiff",
    ".tif",
]

# Edge length in pixels of the square model input; default target size
# used by ImageDataset.
IMAGE_SIZE = 448

# Short variant name (as given on the command line) -> Hugging Face repo id.
MODEL_VARIANTS: dict[str, str] = dict(
    swinv2="SmilingWolf/wd-swinv2-tagger-v3",
    convnext="SmilingWolf/wd-convnext-tagger-v3",
    vit="SmilingWolf/wd-vit-tagger-v3",
)
28
+
29
+
30
@dataclass
class LabelData:
    """Tag names plus the row indices belonging to each tag category.

    Built by ``load_labels_hf`` from ``selected_tags.csv``; the index lists
    hold row positions into ``names``.
    """

    # tag name of every CSV row, in file order
    names: list[str]
    # row indices whose category is 9 (rating tags)
    rating: list[np.int64]
    # row indices whose category is 0 (general tags)
    general: list[np.int64]
    # row indices whose category is 4 (character tags)
    character: list[np.int64]
36
+
37
+
38
@dataclass
class ImageLabels:
    """Tagging result for a single image, as produced by ``ImageLabeler``."""

    # comma-separated general + character tag names
    caption: str
    # caption with underscores replaced by spaces and parentheses escaped
    booru: str
    # name of the highest-scoring rating tag
    rating: str
    # general tag -> confidence, only tags above the general threshold
    general: dict[str, float]
    # character tag -> confidence, only tags above the character threshold
    character: dict[str, float]
    # every rating tag -> confidence
    ratings: dict[str, float]
46
+
47
+
48
# Module-wide logging: configure the root logger's handler/format, then
# raise its verbosity so INFO records are emitted as well.
logger = logging.getLogger()
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.WARNING,
)
logger.setLevel(logging.INFO)
51
+
52
+
53
+ ## Model loading functions
54
def download_onnx(
    repo_id: str,
    filename: str = "model.onnx",
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> Path:
    """Fetch an ONNX file from a Hugging Face repo and return its local path.

    Args:
        repo_id: Hub repository to download from.
        filename: File name inside the repo; ".onnx" is appended when missing.
        revision: Optional revision forwarded to ``hf_hub_download``.
        token: Optional access token forwarded to ``hf_hub_download``.

    Returns:
        The resolved local path reported by ``hf_hub_download``.
    """
    onnx_name = filename if filename.endswith(".onnx") else filename + ".onnx"
    local_file = hf_hub_download(
        repo_id=repo_id,
        filename=onnx_name,
        revision=revision,
        token=token,
    )
    return Path(local_file).resolve()
65
+
66
+
67
def create_session(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> rt.InferenceSession:
    """Download the repo's ONNX model and open an onnxruntime session on it.

    Args:
        repo_id: Hub repository holding the model.
        revision: Optional revision forwarded to ``download_onnx``.
        token: Optional access token forwarded to ``download_onnx``.

    Returns:
        An ``InferenceSession`` created with the CUDA execution provider
        listed first and the CPU execution provider listed second.

    Raises:
        FileNotFoundError: If no model file exists at the downloaded path,
            nor at ``<path>/model.onnx``.
    """
    onnx_file = download_onnx(repo_id, revision=revision, token=token)

    if not onnx_file.is_file():
        # the download may have handed back a directory; look inside it
        onnx_file = onnx_file / "model.onnx"
        if not onnx_file.is_file():
            raise FileNotFoundError(f"Model not found: {onnx_file}")

    providers = [("CUDAExecutionProvider", {}), "CPUExecutionProvider"]
    return rt.InferenceSession(str(onnx_file), providers=providers)
83
+
84
+
85
+ ## Label loading function
86
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from a repo and index it by category.

    Args:
        repo_id: Hub repository holding ``selected_tags.csv``.
        revision: Optional revision forwarded to ``hf_hub_download``.
        token: Optional access token forwarded to ``hf_hub_download``.

    Returns:
        A ``LabelData`` with all tag names and the row indices of the rating
        (category 9), general (category 0) and character (category 4) tags.

    Raises:
        FileNotFoundError: If the download fails with ``HfHubHTTPError``.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
    csv_path = Path(downloaded).resolve()

    df: DataFrame = read_csv(csv_path, usecols=["name", "category"])
    categories = df["category"]

    def rows_in(category_id: int) -> list:
        # positions of the rows whose category column equals category_id
        return list(np.where(categories == category_id)[0])

    return LabelData(
        names=df["name"].tolist(),
        rating=rows_in(9),
        general=rows_in(0),
        character=rows_in(4),
    )
108
+
109
+
110
+ ## Image preprocessing functions
111
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` in RGB mode, flattening any alpha onto white.

    Images that are neither RGB nor RGBA are first converted: to RGBA when
    their info dict carries a "transparency" entry, otherwise to RGB.
    RGBA images are then composited over an opaque white canvas.
    """
    if image.mode not in ("RGB", "RGBA"):
        # palette / greyscale / etc.: keep transparency only if there is any
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)

    if image.mode != "RGBA":
        return image

    # flatten the alpha channel against a white background
    background = Image.new("RGBA", image.size, (255, 255, 255))
    background.alpha_composite(image)
    return background.convert("RGB")
121
+
122
+
123
def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Center ``image`` on a new square RGB canvas filled with ``fill``.

    The canvas edge equals the larger of the image's width and height, so
    the image itself is never scaled or cropped.
    """
    width, height = image.size
    side = max(width, height)
    # top-left corner that centers the image on the square canvas
    offset = ((side - width) // 2, (side - height) // 2)

    square = Image.new("RGB", (side, side), fill)
    square.paste(image, offset)
    return square
134
+
135
+
136
def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Convert an image to RGB, pad it to a white square and fit it to size_px.

    Args:
        image: Source image in any PIL mode.
        size_px: Target size; a single int is used for both dimensions.
        upscale: Whether images smaller than the target may be enlarged.

    Returns:
        The padded image, resized with LANCZOS when it was smaller than the
        target or shrunk in place with a BICUBIC thumbnail when larger.

    Raises:
        ValueError: If the image is smaller than the target and ``upscale``
            is False.
    """
    target = (size_px, size_px) if isinstance(size_px, int) else size_px

    # normalise colour mode first, then pad to a square canvas
    image = pil_pad_square(pil_ensure_rgb(image))

    too_small = image.size[0] < target[0] or image.size[1] < target[1]
    if too_small:
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        image = image.resize(target, Image.LANCZOS)

    too_large = image.size[0] > target[0] or image.size[1] > target[1]
    if too_large:
        # thumbnail() shrinks the (already copied) canvas in place
        image.thumbnail(target, Image.BICUBIC)

    return image
160
+
161
+
162
+ ## Dataset for DataLoader
163
class ImageDataset(Dataset):
    """Dataset yielding preprocessed BGR float32 images plus their paths.

    Args:
        image_paths: Candidate files; only those whose suffix is listed in
            IMAGE_EXTENSIONS are kept.
        size_px: Target size handed to ``preprocess_image``.
        upscale: Whether ``preprocess_image`` may enlarge small images.
    """

    def __init__(self, image_paths: list[Path], size_px: int = IMAGE_SIZE, upscale: bool = True):
        self.size_px = size_px
        self.upscale = upscale
        self.images = [p for p in image_paths if p.suffix.lower() in IMAGE_EXTENSIONS]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load one image.

        Returns:
            ``{"image": float32 array of shape (H, W, 3) in BGR channel order,
            "path": 0-d bytes array holding the UTF-8 encoded path}``, or
            ``None`` when the file cannot be loaded (the error is logged) so
            the collate function can drop it.
        """
        image_path: Path = self.images[idx]
        try:
            # the context manager closes the file handle once the pixel data
            # has been copied onto the padded canvas by preprocess_image
            with Image.open(image_path) as raw:
                image = preprocess_image(raw, self.size_px, self.upscale)
            # Reverse the channel axis with numpy to get BGR; this avoids
            # Pillow's experimental "BGR;24" mode, which newer releases drop.
            rgb = np.asarray(image.convert("RGB"))
            image = np.ascontiguousarray(rgb[:, :, ::-1], dtype=np.float32)
        except Exception as e:
            # best effort: a broken file must not abort the whole run
            logging.exception(f"Could not load image from {image_path}, error: {e}")
            return None

        return {"image": image, "path": np.array(str(image_path).encode("utf-8"), dtype=np.bytes_)}
185
+
186
+
187
def collate_fn_remove_corrupted(batch):
    """Stack a batch of sample dicts while skipping failed samples.

    The dataset signals an unreadable example by returning ``None``; those
    entries are discarded here. Returns ``None`` when nothing is left,
    otherwise a dict with one stacked numpy array per key of the first
    surviving sample.
    """
    valid = [sample for sample in batch if sample is not None]
    if not valid:
        return None

    stacked = {}
    for key in valid[0]:
        stacked[key] = np.array([sample[key] for sample in valid])
    return stacked
197
+
198
+
199
+ ## Main function
200
class ImageLabeler:
    """Tag images with a WD tagger ONNX model downloaded from Hugging Face.

    Args:
        repo_id: Hub repository providing ``model.onnx`` and
            ``selected_tags.csv``.
        general_threshold: Minimum confidence (exclusive) for general tags.
        character_threshold: Minimum confidence (exclusive) for character tags.
        banned_tags: Tag names that must never be emitted. ``None`` (the
            default) means no tags are banned.
    """

    def __init__(
        self,
        repo_id: Optional[PathLike] = None,
        general_threshold: float = 0.35,
        character_threshold: float = 0.35,
        banned_tags: Optional[list[str]] = None,
    ):
        self.repo_id = repo_id

        # create some object attributes for convenience
        self.general_threshold = general_threshold
        self.character_threshold = character_threshold
        # copy so that neither a shared default nor the caller's collection
        # is aliased by this instance
        self.banned_tags = list(banned_tags) if banned_tags is not None else []

        # actually load the model
        logging.info(f"Loading model from path: {self.repo_id}")
        self.model = create_session(self.repo_id)

        # Get input dimensions (the input shape has four entries, with
        # height and width in the middle two)
        _, self.height, self.width, _ = self.model.get_inputs()[0].shape
        logging.info(f"Model loaded, input dimensions {self.height}x{self.width}")

        # load labels and drop banned tags; the label lists hold row indices,
        # so the ban has to be checked against the corresponding tag *name*
        self.labels = load_labels_hf(self.repo_id)
        banned = set(self.banned_tags)
        names = self.labels.names
        self.labels.general = [i for i in self.labels.general if names[i] not in banned]
        self.labels.character = [i for i in self.labels.character if names[i] not in banned]
        logging.info(f"Loaded labels from {self.repo_id}")

    @property
    def input_size(self) -> Tuple[int, int]:
        """Model input size as ``(height, width)``."""
        return (self.height, self.width)

    @property
    def input_name(self) -> Optional[str]:
        """Name of the model's first input, or ``None`` without a model."""
        return self.model.get_inputs()[0].name if self.model is not None else None

    @property
    def output_name(self) -> Optional[str]:
        """Name of the model's first output, or ``None`` without a model."""
        return self.model.get_outputs()[0].name if self.model is not None else None

    def _select_tags(self, scores: np.ndarray, indices: list, threshold: float) -> dict:
        """Return ``{tag: score}`` for indices scoring above threshold, best first."""
        names = self.labels.names
        picked = {names[i]: scores[i] for i in indices if scores[i] > threshold}
        return dict(sorted(picked.items(), key=lambda item: item[1], reverse=True))

    def label_images(self, images: np.ndarray) -> list[ImageLabels]:
        """Run the model on a batch and convert every output row to labels.

        Args:
            images: Batch array fed to the model's first input.

        Returns:
            One ``ImageLabels`` per row of the model's first output.
        """
        # Run the ONNX model
        probs: np.ndarray = self.model.run([self.output_name], {self.input_name: images})[0]
        names = self.labels.names

        # Convert to labels
        results = []
        for sample in probs:
            scores = sample.astype(float)

            # Rating labels: keep all scores, pick the best one with argmax
            rating_labels = {names[i]: scores[i] for i in self.labels.rating}
            rating = max(rating_labels, key=rating_labels.get)

            # General / character labels: keep any with confidence > threshold
            gen_labels = self._select_tags(scores, self.labels.general, self.general_threshold)
            char_labels = self._select_tags(scores, self.labels.character, self.character_threshold)

            # General tags first, then character tags (each group is already
            # ordered by descending confidence)
            combined_names = [*gen_labels, *char_labels]

            # Convert to a string suitable for use as a training caption
            caption = ", ".join(combined_names)
            # backslashes are doubled so the literals are valid escapes
            booru = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")

            results.append(
                ImageLabels(
                    caption=caption,
                    booru=booru,
                    rating=rating,
                    general=gen_labels,
                    character=char_labels,
                    ratings=rating_labels,
                )
            )

        return results

    def __call__(self, images: list[np.ndarray]) -> Generator[list[ImageLabels], None, None]:
        """Lazily label a sequence of batches.

        Each element of ``images`` is passed unchanged to ``label_images``,
        so it must already be a preprocessed batch array; one list of
        ``ImageLabels`` is yielded per element.
        """
        for x in images:
            yield self.label_images(x)
289
+
290
+
291
def main(args):
    """Tag every image in ``args.images_dir`` and write caption files.

    For each image a caption file with the same stem and the configured
    extension is written next to it. Optionally prints per-image debug
    output and overall tag frequencies.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of the file).

    Raises:
        FileNotFoundError: If ``args.images_dir`` is not a directory.
        ValueError: If ``args.variant`` is not a key of MODEL_VARIANTS.
    """
    images_dir: Path = Path(args.images_dir).resolve()
    if not images_dir.is_dir():
        raise FileNotFoundError(f"Directory not found: {images_dir}")

    variant: str = args.variant
    recursive: bool = bool(args.recursive)
    # strip whitespace and drop empty entries, so an empty --banned_tags
    # gives an empty set rather than {""}
    banned_tags: set[str] = {tag.strip() for tag in args.banned_tags.split(",") if tag.strip()}
    caption_extension: str = str(args.caption_extension).lower()
    if caption_extension and not caption_extension.startswith("."):
        # Path.with_suffix() rejects suffixes without a leading dot
        caption_extension = "." + caption_extension
    print_freqs: bool = bool(args.print_freqs)
    num_workers: int = args.num_workers
    batch_size: int = args.batch_size

    remove_underscore: bool = bool(args.remove_underscore)
    # compare against None so an explicit threshold of 0.0 is honoured
    general_threshold: float = args.thresh if args.general_threshold is None else args.general_threshold
    character_threshold: float = args.thresh if args.character_threshold is None else args.character_threshold
    debug: bool = bool(args.debug)

    # turn base model into a repo id and model path
    repo_id: str = MODEL_VARIANTS.get(variant, None)
    if repo_id is None:
        raise ValueError(f"Unknown base model '{variant}'")

    # instantiate the dataset
    print(f"Loading images from {images_dir}...", end=" ")
    candidates = images_dir.rglob("*") if recursive else images_dir.glob("*")
    image_paths = [p for p in candidates if p.suffix.lower() in IMAGE_EXTENSIONS]

    n_images = len(image_paths)
    print(f"found {n_images} images to process, creating DataLoader...")
    # sort by filename if we have a small number of images
    if n_images < 10000:
        image_paths = sorted(image_paths, key=lambda x: x.stem)
    dataset = ImageDataset(image_paths)

    # Create the data loader; prefetch_factor is only accepted by DataLoader
    # when worker processes are used, so omit it for num_workers == 0
    loader_kwargs = {"prefetch_factor": 3} if num_workers > 0 else {}
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn_remove_corrupted,
        drop_last=False,
        **loader_kwargs,
    )

    # Create the image labeler
    labeler: ImageLabeler = ImageLabeler(
        repo_id=repo_id,
        character_threshold=character_threshold,
        general_threshold=general_threshold,
        banned_tags=sorted(banned_tags),
    )

    # object to save tag frequencies
    tag_freqs = {}

    # iterate
    for batch in tqdm(dataloader, ncols=100, unit="image", unit_scale=batch_size):
        if batch is None:
            # every image of this batch failed to load; nothing to label
            continue
        images = batch["image"]
        paths = batch["path"]

        # label the images
        batch_labels = labeler.label_images(images)

        for image_labels, image_path in zip(batch_labels, paths):
            if isinstance(image_path, (np.bytes_, bytes)):
                image_path = Path(image_path.decode("utf-8"))

            # save the labels
            caption = image_labels.caption
            if remove_underscore is True:
                caption = caption.replace("_", " ")
            Path(image_path).with_suffix(caption_extension).write_text(caption + "\n", encoding="utf-8")

            # save the tag frequencies
            if print_freqs is True:
                for tag in caption.split(", "):
                    # an empty caption splits into [""]; skip that as well
                    if not tag or tag in banned_tags:
                        continue
                    tag_freqs[tag] = tag_freqs.get(tag, 0) + 1

            # debug
            if debug is True:
                print(
                    f"{image_path}:"
                    + f"\n Character tags: {image_labels.character}"
                    + f"\n General tags: {image_labels.general}"
                )

    if print_freqs:
        sorted_tags = sorted(tag_freqs.items(), key=lambda x: x[1], reverse=True)
        print("\nTag frequencies:")
        for tag, freq in sorted_tags:
            print(f"{tag}: {freq}")

    print("done!")
393
+
394
+
395
if __name__ == "__main__":
    # Command-line entry point: build the parser, parse, and hand over to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "images_dir",
        type=str,
        help="directory to tag image files in",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default="swinv2",
        # reject unknown variants at parse time with a usage message
        choices=list(MODEL_VARIANTS),
        help="name of base model to use (one of 'swinv2', 'convnext', 'vit')",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="number of threads to use in Torch DataLoader (4 should be plenty)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="batch size for Torch DataLoader (use 1 for cpu, 4-32 for gpu)",
    )
    parser.add_argument(
        "--caption_extension",
        type=str,
        default=".txt",
        help="extension of caption files to write (e.g. '.txt', '.caption')",
    )
    parser.add_argument(
        "--thresh",
        type=float,
        default=0.35,
        help="confidence threshold for adding tags",
    )
    parser.add_argument(
        "--general_threshold",
        type=float,
        default=None,
        help="confidence threshold for general tags - defaults to --thresh",
    )
    parser.add_argument(
        "--character_threshold",
        type=float,
        default=None,
        help="confidence threshold for character tags - defaults to --thresh",
    )
    parser.add_argument(
        "--recursive",
        action="store_true",
        help="whether to recurse into subdirectories of images_dir",
    )
    parser.add_argument(
        "--remove_underscore",
        action="store_true",
        help="whether to remove underscores from tags (e.g. 'long_hair' -> 'long hair')",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="enable debug logging mode",
    )
    parser.add_argument(
        "--banned_tags",
        type=str,
        default="",
        help="tags to filter out (comma-separated)",
    )
    parser.add_argument(
        "--print_freqs",
        action="store_true",
        help="Print overall tag frequencies at the end",
    )

    # images_dir is a required positional, so it is always set after parsing
    # and needs no fallback value
    args = parser.parse_args()

    main(args)