SmilingWolf commited on
Commit
b1ae84d
·
1 Parent(s): 1b58573
Files changed (7) hide show
  1. README.md +2 -2
  2. Utils/dbimutils.py +54 -0
  3. app.py +134 -115
  4. miku.jpg +0 -0
  5. miku2.jpg +0 -0
  6. power.jpg +0 -0
  7. requirements.txt +3 -3
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: DeepDanbooru String
3
  emoji: 💬
4
  colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.0.5
8
  app_file: app.py
9
  pinned: false
10
  duplicated_from: NoCrypt/DeepDanbooru_string
 
1
  ---
2
+ title: WaifuDiffusion v1.4 Tags
3
  emoji: 💬
4
  colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.6
8
  app_file: app.py
9
  pinned: false
10
  duplicated_from: NoCrypt/DeepDanbooru_string
Utils/dbimutils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DanBooru IMage Utility functions
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+
8
+ def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
9
+ if img.endswith(".gif"):
10
+ img = Image.open(img)
11
+ img = img.convert("RGB")
12
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
13
+ else:
14
+ img = cv2.imread(img, flag)
15
+ return img
16
+
17
+
18
+ def smart_24bit(img):
19
+ if img.dtype is np.dtype(np.uint16):
20
+ img = (img / 257).astype(np.uint8)
21
+
22
+ if len(img.shape) == 2:
23
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
24
+ elif img.shape[2] == 4:
25
+ trans_mask = img[:, :, 3] == 0
26
+ img[trans_mask] = [255, 255, 255, 255]
27
+ img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
28
+ return img
29
+
30
+
31
+ def make_square(img, target_size):
32
+ old_size = img.shape[:2]
33
+ desired_size = max(old_size)
34
+ desired_size = max(desired_size, target_size)
35
+
36
+ delta_w = desired_size - old_size[1]
37
+ delta_h = desired_size - old_size[0]
38
+ top, bottom = delta_h // 2, delta_h - (delta_h // 2)
39
+ left, right = delta_w // 2, delta_w - (delta_w // 2)
40
+
41
+ color = [255, 255, 255]
42
+ new_im = cv2.copyMakeBorder(
43
+ img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
44
+ )
45
+ return new_im
46
+
47
+
48
+ def smart_resize(img, size):
49
+ # Assumes the image has already gone through make_square
50
+ if img.shape[0] > size:
51
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
52
+ elif img.shape[0] < size:
53
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
54
+ return img
app.py CHANGED
@@ -4,135 +4,164 @@ from __future__ import annotations
4
 
5
  import argparse
6
  import functools
7
- import os
8
  import html
9
- import pathlib
10
- import tarfile
11
 
12
- import deepdanbooru as dd
13
  import gradio as gr
14
  import huggingface_hub
15
  import numpy as np
16
- import PIL.Image
17
- import tensorflow as tf
18
  import piexif
19
  import piexif.helper
 
 
 
20
 
21
- TITLE = 'DeepDanbooru String'
 
 
 
 
 
 
 
 
22
 
23
- TOKEN = os.environ['TOKEN']
24
- MODEL_REPO = 'NoCrypt/DeepDanbooru_string'
25
- MODEL_FILENAME = 'model-resnet_custom_v3.h5'
26
- LABEL_FILENAME = 'tags.txt'
27
 
28
 
29
  def parse_args() -> argparse.Namespace:
30
  parser = argparse.ArgumentParser()
31
- parser.add_argument('--score-slider-step', type=float, default=0.05)
32
- parser.add_argument('--score-threshold', type=float, default=0.5)
33
- parser.add_argument('--theme', type=str, default='dark-grass')
34
- parser.add_argument('--live', action='store_true')
35
- parser.add_argument('--share', action='store_true')
36
- parser.add_argument('--port', type=int)
37
- parser.add_argument('--disable-queue',
38
- dest='enable_queue',
39
- action='store_false')
40
- parser.add_argument('--allow-flagging', type=str, default='never')
41
  return parser.parse_args()
42
 
43
 
44
- def load_sample_image_paths() -> list[pathlib.Path]:
45
- image_dir = pathlib.Path('images')
46
- if not image_dir.exists():
47
- dataset_repo = 'hysts/sample-images-TADNE'
48
- path = huggingface_hub.hf_hub_download(dataset_repo,
49
- 'images.tar.gz',
50
- repo_type='dataset',
51
- use_auth_token=TOKEN)
52
- with tarfile.open(path) as f:
53
- f.extractall()
54
- return sorted(image_dir.glob('*'))
55
-
56
-
57
- def load_model() -> tf.keras.Model:
58
- path = huggingface_hub.hf_hub_download(MODEL_REPO,
59
- MODEL_FILENAME,
60
- use_auth_token=TOKEN)
61
- model = tf.keras.models.load_model(path)
62
  return model
63
 
64
 
65
  def load_labels() -> list[str]:
66
- path = huggingface_hub.hf_hub_download(MODEL_REPO,
67
- LABEL_FILENAME,
68
- use_auth_token=TOKEN)
69
- with open(path) as f:
70
- labels = [line.strip() for line in f.readlines()]
71
- return labels
72
 
73
  def plaintext_to_html(text):
74
- text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
 
 
75
  return text
76
 
77
- def predict(image: PIL.Image.Image, score_threshold: float,
78
- model: tf.keras.Model, labels: list[str]) -> dict[str, float]:
 
 
 
 
 
79
  rawimage = image
80
- _, height, width, _ = model.input_shape
 
 
 
 
 
 
81
  image = np.asarray(image)
82
- image = tf.image.resize(image,
83
- size=(height, width),
84
- method=tf.image.ResizeMethod.AREA,
85
- preserve_aspect_ratio=True)
86
- image = image.numpy()
87
- image = dd.image.transform_and_pad_image(image, width, height)
88
- image = image / 255.
89
- probs = model.predict(image[None, ...])[0]
90
- probs = probs.astype(float)
91
- res = dict()
92
- for prob, label in zip(probs.tolist(), labels):
93
- if prob < score_threshold:
94
- continue
95
- res[label] = prob
96
- b = dict(sorted(res.items(),key=lambda item:item[1], reverse=True))
97
- a = ', '.join(list(b.keys())).replace('_',' ').replace('(','\(').replace(')','\)')
98
- c = ', '.join(list(b.keys()))
99
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  items = rawimage.info
101
- geninfo = ''
102
-
103
  if "exif" in rawimage.info:
104
  exif = piexif.load(rawimage.info["exif"])
105
- exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
106
  try:
107
  exif_comment = piexif.helper.UserComment.load(exif_comment)
108
  except ValueError:
109
- exif_comment = exif_comment.decode('utf8', errors="ignore")
110
-
111
- items['exif comment'] = exif_comment
112
  geninfo = exif_comment
113
-
114
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
115
- 'loop', 'background', 'timestamp', 'duration']:
 
 
 
 
 
 
 
 
 
 
116
  items.pop(field, None)
117
-
118
- geninfo = items.get('parameters', geninfo)
119
-
120
  info = f"""
121
  <p><h4>PNG Info</h4></p>
122
  """
123
  for key, text in items.items():
124
- info += f"""
 
125
  <div>
126
  <p><b>{plaintext_to_html(str(key))}</b></p>
127
  <p>{plaintext_to_html(str(text))}</p>
128
  </div>
129
- """.strip()+"\n"
130
-
 
 
131
  if len(info) == 0:
132
  message = "Nothing found in the image."
133
  info = f"<div><p>{message}<p></div>"
134
-
135
- return (a,c,res,info)
136
 
137
 
138
  def main():
@@ -141,45 +170,35 @@ def main():
141
  labels = load_labels()
142
 
143
  func = functools.partial(predict, model=model, labels=labels)
144
- func = functools.update_wrapper(func, predict)
145
 
146
  gr.Interface(
147
- func,
148
- [
149
- gr.inputs.Image(type='pil', label='Input'),
150
- gr.inputs.Slider(0,
151
- 1,
152
- step=args.score_slider_step,
153
- default=args.score_threshold,
154
- label='Score Threshold'),
155
- ],
156
- [
157
- gr.outputs.Textbox(label='Output (string)'),
158
- gr.outputs.Textbox(label='Output (raw string)'),
159
- gr.outputs.Label(label='Output (label)'),
160
- gr.outputs.HTML()
161
  ],
162
- examples=[
163
- ['miku.jpg',0.5],
164
- ['miku2.jpg',0.5]
 
 
 
165
  ],
 
166
  title=TITLE,
167
- description='''
168
- Demo for [KichangKim/DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) with "ready to copy" prompt and a prompt analyzer.
169
-
170
- Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
171
-
172
- PNG Info code forked from [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
173
- ''',
174
- theme=args.theme,
175
- allow_flagging=args.allow_flagging,
176
- live=args.live,
177
  ).launch(
178
- enable_queue=args.enable_queue,
179
- server_port=args.port,
180
  share=args.share,
181
  )
182
 
183
 
184
- if __name__ == '__main__':
185
  main()
 
4
 
5
  import argparse
6
  import functools
 
7
  import html
8
+ import os
 
9
 
 
10
  import gradio as gr
11
  import huggingface_hub
12
  import numpy as np
13
+ import onnxruntime as rt
14
+ import pandas as pd
15
  import piexif
16
  import piexif.helper
17
+ import PIL.Image
18
+
19
+ from Utils import dbimutils
20
 
21
+ TITLE = "WaifuDiffusion v1.4 Tags"
22
+ DESCRIPTION = """
23
+ Demo for [SmilingWolf/wd-v1-4-vit-tagger](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger) with "ready to copy" prompt and a prompt analyzer.
24
+
25
+ Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
26
+ Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
27
+
28
+ PNG Info code forked from [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
29
+ """
30
 
31
+ HF_TOKEN = os.environ["HF_TOKEN"]
32
+ MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger"
33
+ MODEL_FILENAME = "ViTB16_11_07_2022_18h19m14s.onnx"
34
+ LABEL_FILENAME = "selected_tags.csv"
35
 
36
 
37
  def parse_args() -> argparse.Namespace:
38
  parser = argparse.ArgumentParser()
39
+ parser.add_argument("--score-slider-step", type=float, default=0.05)
40
+ parser.add_argument("--score-threshold", type=float, default=0.35)
41
+ parser.add_argument("--share", action="store_true")
 
 
 
 
 
 
 
42
  return parser.parse_args()
43
 
44
 
45
+ def load_model() -> rt.InferenceSession:
46
+ path = huggingface_hub.hf_hub_download(
47
+ MODEL_REPO, MODEL_FILENAME, use_auth_token=HF_TOKEN
48
+ )
49
+ model = rt.InferenceSession(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  return model
51
 
52
 
53
  def load_labels() -> list[str]:
54
+ path = huggingface_hub.hf_hub_download(
55
+ MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
56
+ )
57
+ df = pd.read_csv(path)["name"].tolist()
58
+ return df
59
+
60
 
61
  def plaintext_to_html(text):
62
+ text = (
63
+ "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split("\n")]) + "</p>"
64
+ )
65
  return text
66
 
67
+
68
+ def predict(
69
+ image: PIL.Image.Image,
70
+ score_threshold: float,
71
+ model: rt.InferenceSession,
72
+ labels: list[str],
73
+ ):
74
  rawimage = image
75
+ _, height, width, _ = model.get_inputs()[0].shape
76
+
77
+ # Alpha to white
78
+ image = image.convert("RGBA")
79
+ new_image = PIL.Image.new("RGBA", image.size, "WHITE")
80
+ new_image.paste(image, mask=image)
81
+ image = new_image.convert("RGB")
82
  image = np.asarray(image)
83
+
84
+ # PIL RGB to OpenCV BGR
85
+ image = image[:, :, ::-1]
86
+
87
+ image = dbimutils.make_square(image, height)
88
+ image = dbimutils.smart_resize(image, height)
89
+ image = image.astype(np.float32)
90
+ image = np.expand_dims(image, 0)
91
+
92
+ input_name = model.get_inputs()[0].name
93
+ label_name = model.get_outputs()[0].name
94
+ probs = model.run([label_name], {input_name: image})[0]
95
+
96
+ labels = list(zip(labels, probs[0].astype(float)))
97
+
98
+ # First 4 labels are actually ratings: pick one with argmax
99
+ ratings_names = labels[:4]
100
+ rating = dict(ratings_names)
101
+
102
+ # Everything else is tags: pick any where prediction confidence > threshold
103
+ tags_names = labels[4:]
104
+ res = [x for x in tags_names if x[1] > score_threshold]
105
+ res = dict(res)
106
+
107
+ b = dict(sorted(res.items(), key=lambda item: item[1], reverse=True))
108
+ a = (
109
+ ", ".join(list(b.keys()))
110
+ .replace("_", " ")
111
+ .replace("(", "\(")
112
+ .replace(")", "\)")
113
+ )
114
+ c = ", ".join(list(b.keys()))
115
+
116
  items = rawimage.info
117
+ geninfo = ""
118
+
119
  if "exif" in rawimage.info:
120
  exif = piexif.load(rawimage.info["exif"])
121
+ exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b"")
122
  try:
123
  exif_comment = piexif.helper.UserComment.load(exif_comment)
124
  except ValueError:
125
+ exif_comment = exif_comment.decode("utf8", errors="ignore")
126
+
127
+ items["exif comment"] = exif_comment
128
  geninfo = exif_comment
129
+
130
+ for field in [
131
+ "jfif",
132
+ "jfif_version",
133
+ "jfif_unit",
134
+ "jfif_density",
135
+ "dpi",
136
+ "exif",
137
+ "loop",
138
+ "background",
139
+ "timestamp",
140
+ "duration",
141
+ ]:
142
  items.pop(field, None)
143
+
144
+ geninfo = items.get("parameters", geninfo)
145
+
146
  info = f"""
147
  <p><h4>PNG Info</h4></p>
148
  """
149
  for key, text in items.items():
150
+ info += (
151
+ f"""
152
  <div>
153
  <p><b>{plaintext_to_html(str(key))}</b></p>
154
  <p>{plaintext_to_html(str(text))}</p>
155
  </div>
156
+ """.strip()
157
+ + "\n"
158
+ )
159
+
160
  if len(info) == 0:
161
  message = "Nothing found in the image."
162
  info = f"<div><p>{message}<p></div>"
163
+
164
+ return (a, c, rating, res, info)
165
 
166
 
167
  def main():
 
170
  labels = load_labels()
171
 
172
  func = functools.partial(predict, model=model, labels=labels)
 
173
 
174
  gr.Interface(
175
+ fn=func,
176
+ inputs=[
177
+ gr.Image(type="pil", label="Input"),
178
+ gr.Slider(
179
+ 0,
180
+ 1,
181
+ step=args.score_slider_step,
182
+ value=args.score_threshold,
183
+ label="Score Threshold",
184
+ ),
 
 
 
 
185
  ],
186
+ outputs=[
187
+ gr.Textbox(label="Output (string)"),
188
+ gr.Textbox(label="Output (raw string)"),
189
+ gr.Label(label="Rating"),
190
+ gr.Label(label="Output (label)"),
191
+ gr.HTML(),
192
  ],
193
+ examples=[["power.jpg", 0.5]],
194
  title=TITLE,
195
+ description=DESCRIPTION,
196
+ allow_flagging="never",
 
 
 
 
 
 
 
 
197
  ).launch(
198
+ enable_queue=True,
 
199
  share=args.share,
200
  )
201
 
202
 
203
+ if __name__ == "__main__":
204
  main()
miku.jpg DELETED
Binary file (125 kB)
 
miku2.jpg DELETED
Binary file (220 kB)
 
power.jpg ADDED
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
  pillow>=9.0.0
2
- tensorflow>=2.7.0
3
- git+https://github.com/KichangKim/DeepDanbooru@v3-20200915-sgd-e30#egg=deepdanbooru
4
- piexif>=1.1.3
 
1
  pillow>=9.0.0
2
+ piexif>=1.1.3
3
+ onnxruntime>=1.12.0
4
+ opencv-python