SmilingWolf commited on
Commit
9ee88e4
·
1 Parent(s): 104e60a

Update: switch to new models with character support

Browse files
Files changed (1) hide show
  1. app.py +86 -30
app.py CHANGED
@@ -20,7 +20,12 @@ from Utils import dbimutils
20
 
21
  TITLE = "WaifuDiffusion v1.4 Tags"
22
  DESCRIPTION = """
23
- Demo for [SmilingWolf/wd-v1-4-vit-tagger](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger) and [SmilingWolf/wd-v1-4-convnext-tagger](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger) with "ready to copy" prompt and a prompt analyzer.
 
 
 
 
 
24
 
25
  Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
26
  Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
@@ -31,8 +36,9 @@ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
31
  """
32
 
33
  HF_TOKEN = os.environ["HF_TOKEN"]
34
- VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger"
35
- CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger"
 
36
  MODEL_FILENAME = "model.onnx"
37
  LABEL_FILENAME = "selected_tags.csv"
38
 
@@ -40,7 +46,8 @@ LABEL_FILENAME = "selected_tags.csv"
40
  def parse_args() -> argparse.Namespace:
41
  parser = argparse.ArgumentParser()
42
  parser.add_argument("--score-slider-step", type=float, default=0.05)
43
- parser.add_argument("--score-threshold", type=float, default=0.35)
 
44
  parser.add_argument("--share", action="store_true")
45
  return parser.parse_args()
46
 
@@ -53,12 +60,31 @@ def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
53
  return model
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def load_labels() -> list[str]:
57
  path = huggingface_hub.hf_hub_download(
58
- VIT_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
59
  )
60
- df = pd.read_csv(path)["name"].tolist()
61
- return df
 
 
 
 
 
62
 
63
 
64
  def plaintext_to_html(text):
@@ -70,14 +96,22 @@ def plaintext_to_html(text):
70
 
71
  def predict(
72
  image: PIL.Image.Image,
73
- selected_model: str,
74
- score_threshold: float,
75
- models: dict,
76
- labels: list[str],
 
 
 
77
  ):
 
 
78
  rawimage = image
79
 
80
- model = models[selected_model]
 
 
 
81
  _, height, width, _ = model.get_inputs()[0].shape
82
 
83
  # Alpha to white
@@ -99,18 +133,23 @@ def predict(
99
  label_name = model.get_outputs()[0].name
100
  probs = model.run([label_name], {input_name: image})[0]
101
 
102
- labels = list(zip(labels, probs[0].astype(float)))
103
 
104
  # First 4 labels are actually ratings: pick one with argmax
105
- ratings_names = labels[:4]
106
  rating = dict(ratings_names)
107
 
108
- # Everything else is tags: pick any where prediction confidence > threshold
109
- tags_names = labels[4:]
110
- res = [x for x in tags_names if x[1] > score_threshold]
111
- res = dict(res)
 
 
 
 
 
112
 
113
- b = dict(sorted(res.items(), key=lambda item: item[1], reverse=True))
114
  a = (
115
  ", ".join(list(b.keys()))
116
  .replace("_", " ")
@@ -167,40 +206,57 @@ def predict(
167
  message = "Nothing found in the image."
168
  info = f"<div><p>{message}<p></div>"
169
 
170
- return (a, c, rating, res, info)
171
 
172
 
173
  def main():
 
 
 
174
  args = parse_args()
175
- vit_model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
176
- conv_model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
177
- labels = load_labels()
178
 
179
- models = {"ViT": vit_model, "ConvNext": conv_model}
 
 
 
180
 
181
- func = functools.partial(predict, models=models, labels=labels)
 
 
 
 
 
 
182
 
183
  gr.Interface(
184
  fn=func,
185
  inputs=[
186
  gr.Image(type="pil", label="Input"),
187
- gr.Radio(["ViT", "ConvNext"], value="ViT", label="Model"),
 
 
 
 
 
 
 
188
  gr.Slider(
189
  0,
190
  1,
191
  step=args.score_slider_step,
192
- value=args.score_threshold,
193
- label="Score Threshold",
194
  ),
195
  ],
196
  outputs=[
197
  gr.Textbox(label="Output (string)"),
198
  gr.Textbox(label="Output (raw string)"),
199
  gr.Label(label="Rating"),
200
- gr.Label(label="Output (label)"),
 
201
  gr.HTML(),
202
  ],
203
- examples=[["power.jpg", "ViT", 0.5]],
204
  title=TITLE,
205
  description=DESCRIPTION,
206
  allow_flagging="never",
 
20
 
21
  TITLE = "WaifuDiffusion v1.4 Tags"
22
  DESCRIPTION = """
23
+ Demo for:
24
+ - [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
25
+ - [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
26
+ - [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
27
+
28
+ Includes "ready to copy" prompt and a prompt analyzer.
29
 
30
  Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
31
  Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
 
36
  """
37
 
38
  HF_TOKEN = os.environ["HF_TOKEN"]
39
+ SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
40
+ CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
41
+ VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
42
  MODEL_FILENAME = "model.onnx"
43
  LABEL_FILENAME = "selected_tags.csv"
44
 
 
46
  def parse_args() -> argparse.Namespace:
47
  parser = argparse.ArgumentParser()
48
  parser.add_argument("--score-slider-step", type=float, default=0.05)
49
+ parser.add_argument("--score-general-threshold", type=float, default=0.35)
50
+ parser.add_argument("--score-character-threshold", type=float, default=0.85)
51
  parser.add_argument("--share", action="store_true")
52
  return parser.parse_args()
53
 
 
60
  return model
61
 
62
 
63
+ def change_model(model_name):
64
+ global loaded_models
65
+
66
+ if model_name == "SwinV2":
67
+ model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
68
+ elif model_name == "ConvNext":
69
+ model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
70
+ elif model_name == "ViT":
71
+ model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
72
+
73
+ loaded_models[model_name] = model
74
+ return loaded_models[model_name]
75
+
76
+
77
  def load_labels() -> list[str]:
78
  path = huggingface_hub.hf_hub_download(
79
+ SWIN_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
80
  )
81
+ df = pd.read_csv(path)
82
+
83
+ tag_names = df["name"].tolist()
84
+ rating_indexes = list(np.where(df["category"] == 9)[0])
85
+ general_indexes = list(np.where(df["category"] == 0)[0])
86
+ character_indexes = list(np.where(df["category"] == 4)[0])
87
+ return tag_names, rating_indexes, general_indexes, character_indexes
88
 
89
 
90
  def plaintext_to_html(text):
 
96
 
97
  def predict(
98
  image: PIL.Image.Image,
99
+ model_name: str,
100
+ general_threshold: float,
101
+ character_threshold: float,
102
+ tag_names: list[str],
103
+ rating_indexes: list[np.int64],
104
+ general_indexes: list[np.int64],
105
+ character_indexes: list[np.int64],
106
  ):
107
+ global loaded_models
108
+
109
  rawimage = image
110
 
111
+ model = loaded_models[model_name]
112
+ if model is None:
113
+ model = change_model(model_name)
114
+
115
  _, height, width, _ = model.get_inputs()[0].shape
116
 
117
  # Alpha to white
 
133
  label_name = model.get_outputs()[0].name
134
  probs = model.run([label_name], {input_name: image})[0]
135
 
136
+ labels = list(zip(tag_names, probs[0].astype(float)))
137
 
138
  # First 4 labels are actually ratings: pick one with argmax
139
+ ratings_names = [labels[i] for i in rating_indexes]
140
  rating = dict(ratings_names)
141
 
142
+ # Then we have general tags: pick any where prediction confidence > threshold
143
+ general_names = [labels[i] for i in general_indexes]
144
+ general_res = [x for x in general_names if x[1] > general_threshold]
145
+ general_res = dict(general_res)
146
+
147
+ # Everything else is characters: pick any where prediction confidence > threshold
148
+ character_names = [labels[i] for i in character_indexes]
149
+ character_res = [x for x in character_names if x[1] > character_threshold]
150
+ character_res = dict(character_res)
151
 
152
+ b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True))
153
  a = (
154
  ", ".join(list(b.keys()))
155
  .replace("_", " ")
 
206
  message = "Nothing found in the image."
207
  info = f"<div><p>{message}<p></div>"
208
 
209
+ return (a, c, rating, character_res, general_res, info)
210
 
211
 
212
  def main():
213
+ global loaded_models
214
+ loaded_models = {"SwinV2": None, "ConvNext": None, "ViT": None}
215
+
216
  args = parse_args()
 
 
 
217
 
218
+ swin_model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
219
+ loaded_models["SwinV2"] = swin_model
220
+
221
+ tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
222
 
223
+ func = functools.partial(
224
+ predict,
225
+ tag_names=tag_names,
226
+ rating_indexes=rating_indexes,
227
+ general_indexes=general_indexes,
228
+ character_indexes=character_indexes,
229
+ )
230
 
231
  gr.Interface(
232
  fn=func,
233
  inputs=[
234
  gr.Image(type="pil", label="Input"),
235
+ gr.Radio(["SwinV2", "ConvNext", "ViT"], value="SwinV2", label="Model"),
236
+ gr.Slider(
237
+ 0,
238
+ 1,
239
+ step=args.score_slider_step,
240
+ value=args.score_general_threshold,
241
+ label="General Tags Threshold",
242
+ ),
243
  gr.Slider(
244
  0,
245
  1,
246
  step=args.score_slider_step,
247
+ value=args.score_character_threshold,
248
+ label="Character Tags Threshold",
249
  ),
250
  ],
251
  outputs=[
252
  gr.Textbox(label="Output (string)"),
253
  gr.Textbox(label="Output (raw string)"),
254
  gr.Label(label="Rating"),
255
+ gr.Label(label="Output (characters)"),
256
+ gr.Label(label="Output (tags)"),
257
  gr.HTML(),
258
  ],
259
+ examples=[["power.jpg", "SwinV2", 0.5]],
260
  title=TITLE,
261
  description=DESCRIPTION,
262
  allow_flagging="never",