Spaces:
Runtime error
Runtime error
bram-w
commited on
Commit
·
93d448d
1
Parent(s):
47a05f1
safety check
Browse files- edict_functions.py +29 -1
edict_functions.py
CHANGED
@@ -17,6 +17,8 @@ import os
|
|
17 |
from torchvision import datasets
|
18 |
import pickle
|
19 |
|
|
|
|
|
20 |
# StableDiffusion P2P implementation originally from https://github.com/bloc97/CrossAttentionControl
|
21 |
use_half_prec = True
|
22 |
if use_half_prec:
|
@@ -66,7 +68,30 @@ else:
|
|
66 |
clip.double().to(device)
|
67 |
print("Loaded all models")
|
68 |
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
|
72 |
def EDICT_editing(im_path,
|
@@ -597,6 +622,9 @@ def baseline_stablediffusion(prompt="",
|
|
597 |
|
598 |
image = (image / 2 + 0.5).clamp(0, 1)
|
599 |
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
|
|
|
|
|
|
600 |
image = (image[0] * 255).round().astype("uint8")
|
601 |
return Image.fromarray(image)
|
602 |
####################################
|
|
|
17 |
from torchvision import datasets
|
18 |
import pickle
|
19 |
|
20 |
+
|
21 |
+
|
22 |
# StableDiffusion P2P implementation originally from https://github.com/bloc97/CrossAttentionControl
|
23 |
use_half_prec = True
|
24 |
if use_half_prec:
|
|
|
68 |
clip.double().to(device)
|
69 |
print("Loaded all models")
|
70 |
|
71 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
72 |
+
from transformers import AutoFeatureExtractor
|
73 |
+
# load safety model
|
74 |
+
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
75 |
+
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
76 |
+
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
77 |
+
def load_replacement(x):
|
78 |
+
try:
|
79 |
+
hwc = x.shape
|
80 |
+
y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
|
81 |
+
y = (np.array(y)/255.0).astype(x.dtype)
|
82 |
+
assert y.shape == x.shape
|
83 |
+
return y
|
84 |
+
except Exception:
|
85 |
+
return x
|
86 |
+
def check_safety(x_image):
|
87 |
+
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
88 |
+
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
89 |
+
assert x_checked_image.shape[0] == len(has_nsfw_concept)
|
90 |
+
for i in range(len(has_nsfw_concept)):
|
91 |
+
if has_nsfw_concept[i]:
|
92 |
+
# x_checked_image[i] = load_replacement(x_checked_image[i])
|
93 |
+
x_checked_image[i] *= 0 # load_replacement(x_checked_image[i])
|
94 |
+
return x_checked_image, has_nsfw_concept
|
95 |
|
96 |
|
97 |
def EDICT_editing(im_path,
|
|
|
622 |
|
623 |
image = (image / 2 + 0.5).clamp(0, 1)
|
624 |
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
625 |
+
|
626 |
+
image, _ = check_safety(image)
|
627 |
+
|
628 |
image = (image[0] * 255).round().astype("uint8")
|
629 |
return Image.fromarray(image)
|
630 |
####################################
|