neggles commited on
Commit
2b6048b
·
1 Parent(s): 7534598
.editorconfig ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # http://editorconfig.org
2
+
3
+ root = true
4
+
5
+ [*]
6
+ indent_style = space
7
+ indent_size = 4
8
+ trim_trailing_whitespace = true
9
+ insert_final_newline = true
10
+ charset = utf-8
11
+ end_of_line = lf
12
+
13
+ [*.bat]
14
+ indent_style = tab
15
+ end_of_line = crlf
16
+
17
+ [*.{json,jsonc}]
18
+ indent_style = space
19
+ indent_size = 2
20
+
21
+ [.vscode/*.{json,jsonc}]
22
+ indent_style = space
23
+ indent_size = 4
24
+
25
+ [*.{yml,yaml,toml}]
26
+ indent_style = space
27
+ indent_size = 2
28
+
29
+ [*.md]
30
+ trim_trailing_whitespace = false
31
+
32
+ [Makefile]
33
+ indent_style = tab
34
+ indent_size = 8
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python
3
+
4
+ ### Linux ###
5
+ *~
6
+
7
+ # temporary files which can be created if a process still has a handle open of a deleted file
8
+ .fuse_hidden*
9
+
10
+ # KDE directory preferences
11
+ .directory
12
+
13
+ # Linux trash folder which might appear on any partition or disk
14
+ .Trash-*
15
+
16
+ # .nfs files are created when an open file is removed but is still being accessed
17
+ .nfs*
18
+
19
+ ### macOS ###
20
+ # General
21
+ .DS_Store
22
+ .AppleDouble
23
+ .LSOverride
24
+
25
+ # Icon must end with two \r
26
+ Icon
27
+
28
+
29
+ # Thumbnails
30
+ ._*
31
+
32
+ # Files that might appear in the root of a volume
33
+ .DocumentRevisions-V100
34
+ .fseventsd
35
+ .Spotlight-V100
36
+ .TemporaryItems
37
+ .Trashes
38
+ .VolumeIcon.icns
39
+ .com.apple.timemachine.donotpresent
40
+
41
+ # Directories potentially created on remote AFP share
42
+ .AppleDB
43
+ .AppleDesktop
44
+ Network Trash Folder
45
+ Temporary Items
46
+ .apdisk
47
+
48
+ ### Python ###
49
+ # Byte-compiled / optimized / DLL files
50
+ __pycache__/
51
+ *.py[cod]
52
+ *$py.class
53
+
54
+ # C extensions
55
+ *.so
56
+
57
+ # Distribution / packaging
58
+ .Python
59
+ build/
60
+ develop-eggs/
61
+ dist/
62
+ downloads/
63
+ eggs/
64
+ .eggs/
65
+ lib/
66
+ lib64/
67
+ parts/
68
+ sdist/
69
+ var/
70
+ wheels/
71
+ share/python-wheels/
72
+ *.egg-info/
73
+ .installed.cfg
74
+ *.egg
75
+ MANIFEST
76
+
77
+ # PyInstaller
78
+ # Usually these files are written by a python script from a template
79
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
80
+ *.manifest
81
+ *.spec
82
+
83
+ # Installer logs
84
+ pip-log.txt
85
+ pip-delete-this-directory.txt
86
+
87
+ # Unit test / coverage reports
88
+ htmlcov/
89
+ .tox/
90
+ .nox/
91
+ .coverage
92
+ .coverage.*
93
+ .cache
94
+ nosetests.xml
95
+ coverage.xml
96
+ *.cover
97
+ *.py,cover
98
+ .hypothesis/
99
+ .pytest_cache/
100
+ cover/
101
+
102
+ # Translations
103
+ *.mo
104
+ *.pot
105
+
106
+ # Django stuff:
107
+ *.log
108
+ local_settings.py
109
+ db.sqlite3
110
+ db.sqlite3-journal
111
+
112
+ # Flask stuff:
113
+ instance/
114
+ .webassets-cache
115
+
116
+ # Scrapy stuff:
117
+ .scrapy
118
+
119
+ # Sphinx documentation
120
+ docs/_build/
121
+
122
+ # PyBuilder
123
+ .pybuilder/
124
+ target/
125
+
126
+ # Jupyter Notebook
127
+ .ipynb_checkpoints
128
+
129
+ # IPython
130
+ profile_default/
131
+ ipython_config.py
132
+
133
+ # pyenv
134
+ # For a library or package, you might want to ignore these files since the code is
135
+ # intended to run in multiple environments; otherwise, check them in:
136
+ # .python-version
137
+
138
+ # pipenv
139
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
140
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
141
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
142
+ # install all needed dependencies.
143
+ #Pipfile.lock
144
+
145
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
146
+ __pypackages__/
147
+
148
+ # Celery stuff
149
+ celerybeat-schedule
150
+ celerybeat.pid
151
+
152
+ # SageMath parsed files
153
+ *.sage.py
154
+
155
+ # Environments
156
+ .env
157
+ .venv
158
+ env/
159
+ venv/
160
+ ENV/
161
+ env.bak/
162
+ venv.bak/
163
+
164
+ # Spyder project settings
165
+ .spyderproject
166
+ .spyproject
167
+
168
+ # Rope project settings
169
+ .ropeproject
170
+
171
+ # mkdocs documentation
172
+ /site
173
+
174
+ # mypy
175
+ .mypy_cache/
176
+ .dmypy.json
177
+ dmypy.json
178
+
179
+ # Pyre type checker
180
+ .pyre/
181
+
182
+ # pytype static type analyzer
183
+ .pytype/
184
+
185
+ # Cython debug symbols
186
+ cython_debug/
187
+
188
+ ### VisualStudioCode ###
189
+ .vscode/*
190
+ !.vscode/settings.json
191
+ !.vscode/tasks.json
192
+ !.vscode/launch.json
193
+ !.vscode/extensions.json
194
+ *.code-workspace
195
+
196
+ # Local History for Visual Studio Code
197
+ .history/
198
+
199
+ ### VisualStudioCode Patch ###
200
+ # Ignore all local history of files
201
+ .history
202
+ .ionide
203
+
204
+ ### Windows ###
205
+ # Windows thumbnail cache files
206
+ Thumbs.db
207
+ Thumbs.db:encryptable
208
+ ehthumbs.db
209
+ ehthumbs_vista.db
210
+
211
+ # Dump file
212
+ *.stackdump
213
+
214
+ # Folder config file
215
+ [Dd]esktop.ini
216
+
217
+ # Recycle Bin used on file shares
218
+ $RECYCLE.BIN/
219
+
220
+ # Windows Installer files
221
+ *.cab
222
+ *.msi
223
+ *.msix
224
+ *.msm
225
+ *.msp
226
+
227
+ # Windows shortcuts
228
+ *.lnk
229
+
230
+ # End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
231
+
232
+ # temp and misc
233
+ /misc/
234
+ /temp/
235
+
236
+ # outputs and such
237
+ /logs/
238
+ /cache/
239
+
240
+ # direnv
241
+ .envrc
242
+ .envrc.*
243
+
244
+ # dotenv
245
+ .env
246
+ .env.*
247
+
248
+ # temp files
249
+ **/tmp_*.*
250
+ **/*.tmp.*
251
+
252
+ # but keep examples
253
+ !*.example
.pre-commit-config.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # See https://pre-commit.com for more information
2
+ ci:
3
+ autofix_prs: true
4
+ autoupdate_branch: "main"
5
+ autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate"
6
+ autoupdate_schedule: weekly
7
+
8
+ repos:
9
+ - repo: https://github.com/astral-sh/ruff-pre-commit
10
+ rev: v0.2.0
11
+ hooks:
12
+ # Run the linter.
13
+ - id: ruff
14
+ types_or: [python, pyi, jupyter]
15
+ args: [--fix, --exit-non-zero-on-fix]
16
+ # Run the formatter.
17
+ - id: ruff-format
18
+ types_or: [python, pyi, jupyter]
19
+
20
+ - repo: https://github.com/pre-commit/pre-commit-hooks
21
+ rev: v4.5.0
22
+ hooks:
23
+ - id: trailing-whitespace
24
+ exclude_types:
25
+ - "markdown"
26
+ - id: end-of-file-fixer
27
+ - id: check-yaml
.vscode/settings.json ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "editor.insertSpaces": true,
3
+ "editor.tabSize": 4,
4
+ "files.trimTrailingWhitespace": true,
5
+ "editor.rulers": [100, 120],
6
+
7
+ "files.associations": {
8
+ "*.yaml": "yaml"
9
+ },
10
+ "files.exclude": {
11
+ "**/.git": true,
12
+ "**/.svn": true,
13
+ "**/.hg": true,
14
+ "**/CVS": true,
15
+ "**/.DS_Store": true,
16
+ "**/Thumbs.db": true,
17
+ "**/.ruff_cache": true,
18
+ "**/__pycache__": true,
19
+ "**/*.egg-info": true
20
+ },
21
+
22
+ "[shellscript]": {
23
+ "files.eol": "\n",
24
+ "editor.tabSize": 4,
25
+ "editor.detectIndentation": false
26
+ },
27
+
28
+ "[python]": {
29
+ "editor.wordBasedSuggestions": "off",
30
+ "editor.formatOnSave": true,
31
+ "editor.defaultFormatter": "charliermarsh.ruff",
32
+ "editor.codeActionsOnSave": {
33
+ "source.organizeImports": "always"
34
+ }
35
+ },
36
+ "ruff.format.args": ["--line-length", "110"],
37
+
38
+ "[json]": {
39
+ "editor.defaultFormatter": "esbenp.prettier-vscode",
40
+ "editor.detectIndentation": false,
41
+ "editor.formatOnSaveMode": "file",
42
+ "editor.formatOnSave": true,
43
+ "editor.tabSize": 2
44
+ },
45
+ "[jsonc]": {
46
+ "editor.defaultFormatter": "esbenp.prettier-vscode",
47
+ "editor.detectIndentation": false,
48
+ "editor.formatOnSaveMode": "file",
49
+ "editor.formatOnSave": true,
50
+ "editor.tabSize": 2
51
+ },
52
+
53
+ "[toml]": {
54
+ "editor.tabSize": 2,
55
+ "editor.detectIndentation": false,
56
+ "editor.formatOnSave": true,
57
+ "editor.formatOnSaveMode": "file",
58
+ "editor.defaultFormatter": "tamasfe.even-better-toml",
59
+ "editor.rulers": [80, 100]
60
+ },
61
+ "evenBetterToml.formatter.columnWidth": 88,
62
+
63
+ "[yaml]": {
64
+ "editor.detectIndentation": false,
65
+ "editor.tabSize": 2,
66
+ "editor.formatOnSave": true,
67
+ "editor.formatOnSaveMode": "file",
68
+ "diffEditor.ignoreTrimWhitespace": false,
69
+ "editor.defaultFormatter": "redhat.vscode-yaml"
70
+ },
71
+ "yaml.format.bracketSpacing": true,
72
+ "yaml.format.proseWrap": "preserve",
73
+ "yaml.format.singleQuote": false,
74
+ "yaml.format.printWidth": 110,
75
+
76
+ "[hcl]": {
77
+ "editor.detectIndentation": false,
78
+ "editor.formatOnSave": true,
79
+ "editor.formatOnSaveMode": "file",
80
+ "editor.defaultFormatter": "fredwangwang.vscode-hcl-format"
81
+ },
82
+
83
+ "[markdown]": {
84
+ "files.trimTrailingWhitespace": false
85
+ },
86
+
87
+ "css.lint.validProperties": ["dock", "content-align", "content-justify"],
88
+ "[css]": {
89
+ "editor.formatOnSave": true
90
+ },
91
+
92
+ "remote.autoForwardPorts": false,
93
+ "remote.autoForwardPortsSource": "process"
94
+ }
LICENSE.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+ =====================
3
+
4
+ Copyright © 2024 Andi Powers-Holmes <[email protected]>
5
+
6
+ Permission is hereby granted, free of charge, to any person
7
+ obtaining a copy of this software and associated documentation
8
+ files (the “Software”), to deal in the Software without
9
+ restriction, including without limitation the rights to use,
10
+ copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the
12
+ Software is furnished to do so, subject to the following
13
+ conditions:
14
+
15
+ The above copyright notice and this permission notice shall be
16
+ included in all copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND,
19
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
20
+ OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
21
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
22
+ HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
23
+ WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
24
+ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
25
+ OTHER DEALINGS IN THE SOFTWARE.
README.md CHANGED
@@ -1,12 +1,21 @@
1
  ---
2
- title: Pi Tagger
3
  emoji: 🌖
4
- colorFrom: purple
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.19.2
8
  app_file: app.py
9
  pinned: false
 
 
 
 
 
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
1
  ---
2
+ title: pi-chan tagger
3
  emoji: 🌖
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.19.2
8
  app_file: app.py
9
  pinned: false
10
+ short_description: A WD Tagger Space for pi-chan to use
11
+ preload_from_hub:
12
+ - SmilingWolf/wd-v1-4-moat-tagger-v2 model.onnx
13
+ - SmilingWolf/wd-v1-4-swinv2-tagger-v2 model.onnx
14
+ - SmilingWolf/wd-v1-4-convnext-tagger-v2 model.onnx
15
+ - SmilingWolf/wd-v1-4-convnextv2-tagger-v2 model.onnx
16
+ - SmilingWolf/wd-v1-4-vit-tagger-v2 model.onnx
17
  ---
18
 
19
+ # pi-chan tagger
20
+
21
+ WD Tagger space for a prompt inspector to use as a backend.
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import getenv
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import onnxruntime as rt
8
+ from PIL import Image
9
+
10
+ from tagger.common import LabelData, load_labels, preprocess_image
11
+ from tagger.model import create_session
12
+
13
+ HF_TOKEN = getenv("HF_TOKEN", None)
14
+ WORK_DIR = Path.cwd().resolve()
15
+
16
+ MODEL_VARIANTS: dict[str, str] = {
17
+ "MOAT": "SmilingWolf/wd-v1-4-moat-tagger-v2",
18
+ "SwinV2": "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
19
+ "ConvNeXT": "SmilingWolf/wd-v1-4-convnext-tagger-v2",
20
+ "ConvNeXTv2": "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
21
+ "ViT": "SmilingWolf/wd-v1-4-vit-tagger-v2",
22
+ }
23
+
24
+ # allowed extensions
25
+ IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
26
+
27
+ # model input shape
28
+ IMAGE_SIZE = 448
29
+ example_images = sorted(
30
+ [
31
+ str(x.relative_to(WORK_DIR))
32
+ for x in WORK_DIR.joinpath("examples").iterdir()
33
+ if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
34
+ ]
35
+ )
36
+ loaded_models: dict[str, Optional[rt.InferenceSession]] = {k: None for k, _ in MODEL_VARIANTS.items()}
37
+
38
+
39
+ def load_model(variant: str) -> rt.InferenceSession:
40
+ global loaded_models
41
+
42
+ # resolve the repo name
43
+ model_repo = MODEL_VARIANTS.get(variant, None)
44
+ if model_repo is None:
45
+ raise ValueError(f"Unknown model variant: {variant}")
46
+
47
+ if loaded_models.get(variant, None) is None:
48
+ # save model to cache
49
+ loaded_models[variant] = create_session(model_repo, token=HF_TOKEN)
50
+
51
+ return loaded_models[variant]
52
+
53
+
54
+ def predict(
55
+ image: Image.Image,
56
+ variant: str,
57
+ general_threshold: float = 0.35,
58
+ character_threshold: float = 0.85,
59
+ ):
60
+ # Load model
61
+ model: rt.InferenceSession = load_model(variant)
62
+ # load labels
63
+ labels: LabelData = load_labels()
64
+
65
+ # get input size and name
66
+ _, h, w, _ = model.get_inputs()[0].shape
67
+ input_name = model.get_inputs()[0].name
68
+ output_name = model.get_outputs()[0].name
69
+
70
+ # preprocess image
71
+ image = preprocess_image(image, (h, w))
72
+
73
+ # turn into BGR24 numpy array of N,H,W,C since thats what these want
74
+ inputs = image.convert("RGB").convert("BGR;24")
75
+ inputs = np.array(inputs).astype(np.float32)
76
+ inputs = np.expand_dims(inputs, axis=0)
77
+
78
+ # Run the ONNX model
79
+ probs = model.run([output_name], {input_name: inputs})
80
+
81
+ # Convert indices+probs to labels
82
+ probs = list(zip(labels.names, probs[0][0].astype(float)))
83
+
84
+ # First 4 labels are actually ratings
85
+ rating_labels = dict([probs[i] for i in labels.rating])
86
+
87
+ # General labels, pick any where prediction confidence > threshold
88
+ gen_labels = [probs[i] for i in labels.general]
89
+ gen_labels = dict([x for x in gen_labels if x[1] > general_threshold])
90
+ gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
91
+
92
+ # Character labels, pick any where prediction confidence > threshold
93
+ char_labels = [probs[i] for i in labels.character]
94
+ char_labels = dict([x for x in char_labels if x[1] > character_threshold])
95
+ char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
96
+
97
+ # Combine general and character labels, sort by confidence
98
+ combined_names = [x for x in gen_labels]
99
+ combined_names.extend([x for x in char_labels])
100
+
101
+ # Convert to a string suitable for use as a training caption
102
+ caption = ", ".join(combined_names)
103
+ booru = caption.replace("_", " ").replace("(", "\(").replace(")", "\)")
104
+
105
+ return image, caption, booru, rating_labels, char_labels, gen_labels
106
+
107
+
108
+ with gr.Blocks(title="pi-chan's tagger") as demo:
109
+ with gr.Row(equal_height=False):
110
+ with gr.Column():
111
+ img_input = gr.Image(
112
+ label="Input",
113
+ type="pil",
114
+ image_mode="RGB",
115
+ sources=["upload", "clipboard"],
116
+ )
117
+ variant = gr.Radio(choices=list(MODEL_VARIANTS.keys()), label="Model Variant", value="MOAT")
118
+ gen_thresh = gr.Slider(0.0, 1.0, value=0.35, label="General Tag Threshold")
119
+ char_thresh = gr.Slider(0.0, 1.0, value=0.85, label="Character Tag Threshold")
120
+ show_processed = gr.Checkbox(label="Show Preprocessed", value=False)
121
+ with gr.Row():
122
+ submit = gr.Button(value="Submit", variant="primary", size="lg")
123
+ clear = gr.ClearButton(
124
+ components=[],
125
+ variant="secondary",
126
+ size="lg",
127
+ )
128
+ with gr.Row():
129
+ examples = gr.Examples(
130
+ examples=[
131
+ [imgpath, var, 0.35, 0.85]
132
+ for imgpath in example_images
133
+ for var in ["MOAT", "ConvNeXTv2"]
134
+ ],
135
+ inputs=[img_input, variant, gen_thresh, char_thresh],
136
+ )
137
+ with gr.Column():
138
+ img_output = gr.Image(label="Preprocessed", type="pil", image_mode="RGB", scale=1, visible=False)
139
+ with gr.Group():
140
+ tags_string = gr.Textbox(
141
+ label="Caption", placeholder="Caption will appear here", show_copy_button=True
142
+ )
143
+ tags_booru = gr.Textbox(
144
+ label="Tags", placeholder="Tag string will appear here", show_copy_button=True
145
+ )
146
+ rating = gr.Label(label="Rating")
147
+ character = gr.Label(label="Character")
148
+ general = gr.Label(label="General")
149
+
150
+ # tell clear button which components to clear
151
+ clear.add([img_input, img_output, tags_string, rating, character, general])
152
+
153
+ # show/hide processed image
154
+ def on_select_show_processed(evt: gr.SelectData):
155
+ return gr.update(visible=evt.selected)
156
+
157
+ show_processed.select(on_select_show_processed, inputs=[], outputs=[img_output])
158
+
159
+ submit.click(
160
+ predict,
161
+ inputs=[img_input, variant, gen_thresh, char_thresh],
162
+ outputs=[img_output, tags_string, tags_booru, rating, character, general],
163
+ api_name="predict",
164
+ )
165
+
166
+ if __name__ == "__main__":
167
+ demo.queue(max_size=10)
168
+ demo.launch(server_name="0.0.0.0", server_port=7871)
data/selected_tags.csv ADDED
The diff for this file is too large to render. See raw diff
 
examples/img-01.png ADDED

Git LFS Details

  • SHA256: 37a2bec1c653272457c6b6e5fec6da8ac4676d973f7cd87c545a6e1ab6be288c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.53 MB
examples/img-02.png ADDED

Git LFS Details

  • SHA256: 90ee6035ce0caec46bbda3a9d48bdcd2cd7384487781615c4251301ab5422d45
  • Pointer size: 131 Bytes
  • Size of remote file: 434 kB
pyproject.toml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "pi-tagger-space"
3
+ version = "0.1.0"
4
+ authors = [
5
+ { name = "Andi Powers-Holmes", email = "[email protected]" },
6
+ ]
7
+ maintainers = [
8
+ { name = "Andi Powers-Holmes", email = "[email protected]" },
9
+ ]
10
+ description = "pi-tagger Gradio Space"
11
+ readme = "README.md"
12
+ requires-python = ">=3.9, <3.11"
13
+ keywords = [
14
+ "deep-learning",
15
+ "machine-learning",
16
+ "pytorch",
17
+ ]
18
+ license = { file = "LICENSE.md" }
19
+ classifiers = [
20
+ "Programming Language :: Python :: 3",
21
+ "License :: OSI Approved :: MIT License",
22
+ ]
23
+ dependencies = [
24
+ "gradio >=4.19.2, < 5.0.0",
25
+ "numpy >= 1.23.5",
26
+ "onnxruntime-gpu >= 1.14.1",
27
+ "pandas >= 2.0.0",
28
+ "Pillow >= 9.5.0",
29
+ "PyYAML",
30
+ "safetensors",
31
+ "simple-parsing >= 0.1.0",
32
+ ]
33
+
34
+ [project.urls]
35
+ Repository = "https://huggingface.co/spaces/neggles/pi-tagger"
36
+
37
+ [project.optional-dependencies]
38
+ dev = [
39
+ "ruff >=0.0.289",
40
+ "setuptools-scm >= 8.0.0",
41
+ "pre-commit >= 3.0.0", # remember to run `pre-commit install` after installing
42
+ "tabulate >= 0.8.9", # for inductor log prettyprinting
43
+ ]
44
+ all = [
45
+ "pi-tagger-space[dev]",
46
+ ]
47
+
48
+ [build-system]
49
+ build-backend = "setuptools.build_meta"
50
+ requires = ["setuptools>=64", "wheel"]
51
+
52
+ [tool.setuptools.packages.find]
53
+ namespaces = true
54
+ where = ["."]
55
+ include = ["pi-tagger"]
56
+
57
+
58
+ [tool.ruff]
59
+ line-length = 110
60
+ target-version = "py310"
61
+ extend-exclude = ["/usr/lib/*"]
62
+
63
+ [tool.ruff.lint]
64
+ ignore = [
65
+ "F841", # local variable assigned but never used
66
+ "F842", # local variable annotated but never used
67
+ "E501", # line too long - will be fixed in format
68
+ ]
69
+
70
+ [tool.ruff.format]
71
+ quote-style = "double"
72
+ indent-style = "space"
73
+ line-ending = "auto"
74
+ skip-magic-trailing-comma = false
75
+ docstring-code-format = true
76
+
77
+ [tool.ruff.lint.isort]
78
+ combine-as-imports = true
79
+ force-wrap-aliases = true
80
+ known-local-folder = ["pi-tagger"]
81
+ known-first-party = ["pi-tagger"]
82
+
83
+
84
+ [tool.pyright]
85
+ include = ["src/**"]
86
+ exclude = ["/usr/lib/**"]
87
+ stubPath = "./typings"
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio >=4.19.2, < 5.0.0
2
+ numpy >= 1.23.5
3
+ onnxruntime-gpu >= 1.14.1
4
+ pandas >= 2.0.0
5
+ Pillow >= 9.5.0
6
+ safetensors
7
+ simple-parsing >= 0.1.0
8
+ huggingface-hub >= 0.14.0
9
+ hf-transfer
tagger/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .common import LabelData, load_labels, preprocess_image
2
+ from .model import create_session, download_onnx
3
+
4
+ __all__ = [
5
+ "create_session",
6
+ "download_onnx",
7
+ "LabelData",
8
+ "load_labels",
9
+ "preprocess_image",
10
+ ]
tagger/common.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from dataclasses import asdict, dataclass
3
+ from functools import lru_cache
4
+ from os import PathLike
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ from PIL import Image
11
+
12
+
13
+ class DictJsonMixin:
14
+ def asdict(self, *args, **kwargs) -> dict[str, Any]:
15
+ return asdict(self, *args, **kwargs)
16
+
17
+ def asjson(self, *args, **kwargs):
18
+ return json.dumps(asdict(self, *args, **kwargs))
19
+
20
+
21
+ @dataclass
22
+ class LabelData(DictJsonMixin):
23
+ names: list[str]
24
+ rating: list[np.int64]
25
+ general: list[np.int64]
26
+ character: list[np.int64]
27
+
28
+
29
+ @dataclass
30
+ class ImageLabels(DictJsonMixin):
31
+ caption: str
32
+ booru: str
33
+ rating: dict[str, float]
34
+ general: dict[str, float]
35
+ character: dict[str, float]
36
+
37
+
38
+ @lru_cache(maxsize=5)
39
+ def load_labels(csv_path: PathLike = "data/selected_tags.csv") -> LabelData:
40
+ csv_path = Path(csv_path).resolve()
41
+ if not csv_path.is_file():
42
+ raise FileNotFoundError("No selected_tags.csv found")
43
+
44
+ df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
45
+ tag_data = LabelData(
46
+ names=df["name"].tolist(),
47
+ rating=list(np.where(df["category"] == 9)[0]),
48
+ general=list(np.where(df["category"] == 0)[0]),
49
+ character=list(np.where(df["category"] == 4)[0]),
50
+ )
51
+
52
+ return tag_data
53
+
54
+
55
+ def pil_ensure_rgb(image: Image.Image) -> Image.Image:
56
+ # convert to RGB/RGBA if not already (deals with palette images etc.)
57
+ if image.mode not in ["RGB", "RGBA"]:
58
+ image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
59
+ # convert RGBA to RGB with white background
60
+ if image.mode == "RGBA":
61
+ canvas = Image.new("RGBA", image.size, (255, 255, 255))
62
+ canvas.alpha_composite(image)
63
+ image = canvas.convert("RGB")
64
+ return image
65
+
66
+
67
+ def pil_pad_square(
68
+ image: Image.Image,
69
+ fill: tuple[int, int, int] = (255, 255, 255),
70
+ ) -> Image.Image:
71
+ w, h = image.size
72
+ # get the largest dimension so we can pad to a square
73
+ px = max(image.size)
74
+ # pad to square with white background
75
+ canvas = Image.new("RGB", (px, px), fill)
76
+ canvas.paste(image, ((px - w) // 2, (px - h) // 2))
77
+ return canvas
78
+
79
+
80
+ def preprocess_image(
81
+ image: Image.Image,
82
+ size_px: int | tuple[int, int],
83
+ upscale: bool = True,
84
+ ) -> Image.Image:
85
+ """
86
+ Preprocess an image to be square and centered on a white background.
87
+ """
88
+ if isinstance(size_px, int):
89
+ size_px = (size_px, size_px)
90
+
91
+ # ensure RGB and pad to square
92
+ image = pil_ensure_rgb(image)
93
+ image = pil_pad_square(image)
94
+
95
+ # resize to target size
96
+ if image.size[0] < size_px[0] or image.size[1] < size_px[1]:
97
+ if upscale is False:
98
+ raise ValueError("Image is smaller than target size, and upscaling is disabled")
99
+ image = image.resize(size_px, Image.LANCZOS)
100
+ if image.size[0] > size_px[0] or image.size[1] > size_px[1]:
101
+ image.thumbnail(size_px, Image.BICUBIC)
102
+
103
+ return image
tagger/model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Optional
3
+
4
+ import onnxruntime as rt
5
+ from huggingface_hub import hf_hub_download
6
+
7
+
8
+ def download_onnx(
9
+ repo_id: str,
10
+ filename: str = "model.onnx",
11
+ revision: Optional[str] = None,
12
+ token: Optional[str] = None,
13
+ ) -> Path:
14
+ if not filename.endswith(".onnx"):
15
+ filename += ".onnx"
16
+
17
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, token=token)
18
+ return Path(model_path).resolve()
19
+
20
+
21
+ def create_session(
22
+ repo_id: str,
23
+ revision: Optional[str] = None,
24
+ token: Optional[str] = None,
25
+ ) -> rt.InferenceSession:
26
+ model_path = download_onnx(repo_id, revision=revision, token=token)
27
+ if not model_path.is_file():
28
+ model_path = model_path.joinpath("model.onnx")
29
+ if not model_path.is_file():
30
+ raise FileNotFoundError(f"Model not found: {model_path}")
31
+
32
+ model = rt.InferenceSession(
33
+ str(model_path),
34
+ providers=[("CUDAExecutionProvider", {}), "CPUExecutionProvider"],
35
+ )
36
+ return model