Spaces: Running on Zero

inflation and mask blur sliders fixed
app.py CHANGED
@@ -1,11 +1,12 @@
+import cv2
 import random
-from typing import Tuple
+from typing import Tuple, Optional
 
 import gradio as gr
 import numpy as np
 import spaces
 import torch
-from PIL import Image, ImageFilter
+from PIL import Image, ImageFilter
 from diffusers import FluxInpaintPipeline
 from gradio_client import Client, handle_file
 
@@ -20,23 +21,24 @@ for taking it to the next level by enabling inpainting with the FLUX.
 MAX_SEED = np.iinfo(np.int32).max
 IMAGE_SIZE = 1024
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+PIPE = FluxInpaintPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
+CLIENT = Client("SkalskiP/florence-sam-masking")
 
-client = Client("SkalskiP/florence-sam-masking")
-
-
-def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
-    image = image.convert("RGBA")
-    data = image.getdata()
-    new_data = []
-    for item in data:
-        avg = sum(item[:3]) / 3
-        if avg < threshold:
-            new_data.append((0, 0, 0, 0))
-        else:
-            new_data.append(item)
 
-    image.putdata(new_data)
-    return image
+# def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
+#     image = image.convert("RGBA")
+#     data = image.getdata()
+#     new_data = []
+#     for item in data:
+#         avg = sum(item[:3]) / 3
+#         if avg < threshold:
+#             new_data.append((0, 0, 0, 0))
+#         else:
+#             new_data.append(item)
+#
+#     image.putdata(new_data)
+#     return image
 
 
 # EXAMPLES = [
@@ -68,11 +70,8 @@ def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
 # ]
 # ]
 
-pipe = FluxInpaintPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
-
 
-def resize_image_dimensions(
+def calculate_image_dimensions_for_flux(
     original_resolution_wh: Tuple[int, int],
     maximum_dimension: int = IMAGE_SIZE
 ) -> Tuple[int, int]:
@@ -92,12 +91,40 @@ def resize_image_dimensions(
     return new_width, new_height
 
 
-def
+def is_mask_empty(image: Image.Image) -> bool:
     gray_img = image.convert("L")
     pixels = list(gray_img.getdata())
     return all(pixel == 0 for pixel in pixels)
 
 
+def process_mask(
+    mask: Image.Image,
+    mask_inflation: Optional[int] = None,
+    mask_blur: Optional[int] = None
+) -> Image.Image:
+    """
+    Inflates and blurs the white regions of a mask.
+
+    Args:
+        mask (Image.Image): The input mask image.
+        mask_inflation (Optional[int]): The number of pixels to inflate the mask by.
+        mask_blur (Optional[int]): The radius of the Gaussian blur to apply.
+
+    Returns:
+        Image.Image: The processed mask with inflated and/or blurred regions.
+    """
+    if mask_inflation and mask_inflation > 0:
+        mask_array = np.array(mask)
+        kernel = np.ones((mask_inflation, mask_inflation), np.uint8)
+        mask_array = cv2.dilate(mask_array, kernel, iterations=1)
+        mask = Image.fromarray(mask_array)
+
+    if mask_blur and mask_blur > 0:
+        mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))
+
+    return mask
+
+
 @spaces.GPU(duration=100)
 def process(
     input_image_editor: dict,
@@ -125,26 +152,29 @@ def process(
         gr.Info("Please upload an image.")
         return None, None
 
-    if
+    if is_mask_empty(mask) and not masking_prompt_text:
         gr.Info("Please draw a mask or enter a masking prompt.")
         return None, None
 
-    if not
+    if not is_mask_empty(mask) and masking_prompt_text:
         gr.Info("Both mask and masking prompt are provided. Please provide only one.")
         return None, None
 
-    if
-        mask =
+    if is_mask_empty(mask):
+        mask = CLIENT.predict(
            image_input=handle_file(image_path),
            text_input=masking_prompt_text,
            api_name="/process_image")
        mask = Image.open(mask)
 
-    width, height = resize_image_dimensions(original_resolution_wh=image.size)
+    width, height = calculate_image_dimensions_for_flux(original_resolution_wh=image.size)
     image = image.resize((width, height), Image.LANCZOS)
     mask = mask.resize((width, height), Image.LANCZOS)
     if mask_inflation_slider:
-
+        mask_array = np.array(mask)
+        kernel = np.ones((mask_inflation_slider, mask_inflation_slider), np.uint8)
+        mask_array = cv2.dilate(mask_array, kernel, iterations=1)
+        mask = Image.fromarray(mask_array)
    if mask_blur_slider:
        mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur_slider))
 
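For reference, a minimal standalone sketch (not part of the commit) of the mask post-processing that the fixed sliders now drive: inflation dilates the white mask region with a square kernel via cv2.dilate, and blur feathers the edge with a Gaussian filter. The file names and slider values below are made up for illustration.

import cv2
import numpy as np
from PIL import Image, ImageFilter

# Hypothetical inputs: a single-channel (L-mode) mask and two slider values.
mask = Image.open("mask.png").convert("L")
mask_inflation = 10  # pixels, plays the role of mask_inflation_slider
mask_blur = 5        # radius, plays the role of mask_blur_slider

if mask_inflation > 0:
    mask_array = np.array(mask)
    # Square structuring element: larger values grow the white region further.
    kernel = np.ones((mask_inflation, mask_inflation), np.uint8)
    mask_array = cv2.dilate(mask_array, kernel, iterations=1)
    mask = Image.fromarray(mask_array)

if mask_blur > 0:
    # Soften the mask boundary so the inpainted region blends into the surroundings.
    mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))

mask.save("mask_processed.png")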