Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -22,6 +22,8 @@ Demo for the WaifuDiffusion tagger models
|
|
22 |
Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
|
23 |
"""
|
24 |
|
|
|
|
|
25 |
|
26 |
# Dataset v3 series of models:
|
27 |
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
|
@@ -37,6 +39,12 @@ CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
|
37 |
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
|
38 |
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
# Files to download from the repos
|
41 |
MODEL_FILENAME = "model.onnx"
|
42 |
LABEL_FILENAME = "selected_tags.csv"
|
@@ -110,10 +118,12 @@ class Predictor:
|
|
110 |
csv_path = huggingface_hub.hf_hub_download(
|
111 |
model_repo,
|
112 |
LABEL_FILENAME,
|
|
|
113 |
)
|
114 |
model_path = huggingface_hub.hf_hub_download(
|
115 |
model_repo,
|
116 |
MODEL_FILENAME,
|
|
|
117 |
)
|
118 |
return csv_path, model_path
|
119 |
|
@@ -223,7 +233,7 @@ class Predictor:
|
|
223 |
" ".join(sorted_general_strings)
|
224 |
)
|
225 |
sorted_general_strings = (
|
226 |
-
", ".join(sorted_general_strings).replace("(", "\(").replace(")", "\)")
|
227 |
)
|
228 |
# sorted_general_strings = sorted_general_strings.map(
|
229 |
# lambda x: x.replace("_", " ") if x not in kaomojis else x
|
@@ -243,11 +253,15 @@ def main():
|
|
243 |
VIT_MODEL_DSV3_REPO,
|
244 |
VIT_LARGE_MODEL_DSV3_REPO,
|
245 |
EVA02_LARGE_MODEL_DSV3_REPO,
|
|
|
246 |
MOAT_MODEL_DSV2_REPO,
|
247 |
SWIN_MODEL_DSV2_REPO,
|
248 |
CONV_MODEL_DSV2_REPO,
|
249 |
CONV2_MODEL_DSV2_REPO,
|
250 |
VIT_MODEL_DSV2_REPO,
|
|
|
|
|
|
|
251 |
]
|
252 |
|
253 |
with gr.Blocks(title=TITLE) as demo:
|
|
|
22 |
Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
|
23 |
"""
|
24 |
|
25 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
26 |
+
|
27 |
|
28 |
# Dataset v3 series of models:
|
29 |
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
|
|
|
39 |
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
|
40 |
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
|
41 |
|
42 |
+
# IdolSankaku series of models:
|
43 |
+
EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
|
44 |
+
SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
# Files to download from the repos
|
49 |
MODEL_FILENAME = "model.onnx"
|
50 |
LABEL_FILENAME = "selected_tags.csv"
|
|
|
118 |
csv_path = huggingface_hub.hf_hub_download(
|
119 |
model_repo,
|
120 |
LABEL_FILENAME,
|
121 |
+
use_auth_token=HF_TOKEN,
|
122 |
)
|
123 |
model_path = huggingface_hub.hf_hub_download(
|
124 |
model_repo,
|
125 |
MODEL_FILENAME,
|
126 |
+
use_auth_token=HF_TOKEN,
|
127 |
)
|
128 |
return csv_path, model_path
|
129 |
|
|
|
233 |
" ".join(sorted_general_strings)
|
234 |
)
|
235 |
sorted_general_strings = (
|
236 |
+
", ".join(sorted_general_strings).replace("(", r"\(").replace(")", r"\)")
|
237 |
)
|
238 |
# sorted_general_strings = sorted_general_strings.map(
|
239 |
# lambda x: x.replace("_", " ") if x not in kaomojis else x
|
|
|
253 |
VIT_MODEL_DSV3_REPO,
|
254 |
VIT_LARGE_MODEL_DSV3_REPO,
|
255 |
EVA02_LARGE_MODEL_DSV3_REPO,
|
256 |
+
# ---
|
257 |
MOAT_MODEL_DSV2_REPO,
|
258 |
SWIN_MODEL_DSV2_REPO,
|
259 |
CONV_MODEL_DSV2_REPO,
|
260 |
CONV2_MODEL_DSV2_REPO,
|
261 |
VIT_MODEL_DSV2_REPO,
|
262 |
+
# ---
|
263 |
+
SWINV2_MODEL_IS_DSV1_REPO,
|
264 |
+
EVA02_LARGE_MODEL_IS_DSV1_REPO,
|
265 |
]
|
266 |
|
267 |
with gr.Blocks(title=TITLE) as demo:
|