Spaces:
Runtime error
Runtime error
salexashenko
commited on
Commit
Β·
dad47db
1
Parent(s):
20c7c47
Upload utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import base64
|
3 |
+
import gradio as gr
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from io import BytesIO
|
7 |
+
|
8 |
+
MAX_COLORS = 12
|
9 |
+
|
10 |
+
|
11 |
+
def create_binary_matrix(img_arr, target_color):
|
12 |
+
mask = np.all(img_arr == target_color, axis=-1)
|
13 |
+
binary_matrix = mask.astype(int)
|
14 |
+
return binary_matrix
|
15 |
+
|
16 |
+
def preprocess_mask(mask_, h, w, device):
|
17 |
+
mask = np.array(mask_)
|
18 |
+
mask = mask.astype(np.float32)
|
19 |
+
mask = mask[None, None]
|
20 |
+
mask[mask < 0.5] = 0
|
21 |
+
mask[mask >= 0.5] = 1
|
22 |
+
mask = torch.from_numpy(mask).to(device)
|
23 |
+
mask = torch.nn.functional.interpolate(mask, size=(h, w), mode='nearest')
|
24 |
+
return mask
|
25 |
+
|
26 |
+
def process_sketch(canvas_data):
|
27 |
+
binary_matrixes = []
|
28 |
+
base64_img = canvas_data['image']
|
29 |
+
image_data = base64.b64decode(base64_img.split(',')[1])
|
30 |
+
image = Image.open(BytesIO(image_data)).convert("RGB")
|
31 |
+
im2arr = np.array(image)
|
32 |
+
colors = [tuple(map(int, rgb[4:-1].split(','))) for rgb in canvas_data['colors']]
|
33 |
+
colors_fixed = []
|
34 |
+
|
35 |
+
r, g, b = 255, 255, 255
|
36 |
+
binary_matrix = create_binary_matrix(im2arr, (r,g,b))
|
37 |
+
binary_matrixes.append(binary_matrix)
|
38 |
+
binary_matrix_ = np.repeat(np.expand_dims(binary_matrix, axis=(-1)), 3, axis=(-1))
|
39 |
+
colored_map = binary_matrix_*(r,g,b) + (1-binary_matrix_)*(50,50,50)
|
40 |
+
colors_fixed.append(gr.update(value=colored_map.astype(np.uint8)))
|
41 |
+
|
42 |
+
for color in colors:
|
43 |
+
r, g, b = color
|
44 |
+
if any(c != 255 for c in (r, g, b)):
|
45 |
+
binary_matrix = create_binary_matrix(im2arr, (r,g,b))
|
46 |
+
binary_matrixes.append(binary_matrix)
|
47 |
+
binary_matrix_ = np.repeat(np.expand_dims(binary_matrix, axis=(-1)), 3, axis=(-1))
|
48 |
+
colored_map = binary_matrix_*(r,g,b) + (1-binary_matrix_)*(50,50,50)
|
49 |
+
colors_fixed.append(gr.update(value=colored_map.astype(np.uint8)))
|
50 |
+
|
51 |
+
visibilities = []
|
52 |
+
colors = []
|
53 |
+
for n in range(MAX_COLORS):
|
54 |
+
visibilities.append(gr.update(visible=False))
|
55 |
+
colors.append(gr.update())
|
56 |
+
for n in range(len(colors_fixed)):
|
57 |
+
visibilities[n] = gr.update(visible=True)
|
58 |
+
colors[n] = colors_fixed[n]
|
59 |
+
|
60 |
+
return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
|
61 |
+
|
62 |
+
def process_prompts(binary_matrixes, *seg_prompts):
|
63 |
+
return [gr.update(visible=True), gr.update(value=' , '.join(seg_prompts[:len(binary_matrixes)]))]
|