SkalskiP committed on
Commit
95ab1e6
·
1 Parent(s): cfdaacd

inflation and mask blur sliders fixed

Browse files
Files changed (1) hide show
  1. app.py +58 -28
app.py CHANGED
@@ -1,11 +1,12 @@
 
1
  import random
2
- from typing import Tuple
3
 
4
  import gradio as gr
5
  import numpy as np
6
  import spaces
7
  import torch
8
- from PIL import Image, ImageFilter, ImageOps
9
  from diffusers import FluxInpaintPipeline
10
  from gradio_client import Client, handle_file
11
 
@@ -20,23 +21,24 @@ for taking it to the next level by enabling inpainting with the FLUX.
20
  MAX_SEED = np.iinfo(np.int32).max
21
  IMAGE_SIZE = 1024
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
23
 
24
- client = Client("SkalskiP/florence-sam-masking")
25
-
26
-
27
- def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
28
- image = image.convert("RGBA")
29
- data = image.getdata()
30
- new_data = []
31
- for item in data:
32
- avg = sum(item[:3]) / 3
33
- if avg < threshold:
34
- new_data.append((0, 0, 0, 0))
35
- else:
36
- new_data.append(item)
37
 
38
- image.putdata(new_data)
39
- return image
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  # EXAMPLES = [
@@ -68,11 +70,8 @@ def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
68
  # ]
69
  # ]
70
 
71
- pipe = FluxInpaintPipeline.from_pretrained(
72
- "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
73
-
74
 
75
- def resize_image_dimensions(
76
  original_resolution_wh: Tuple[int, int],
77
  maximum_dimension: int = IMAGE_SIZE
78
  ) -> Tuple[int, int]:
@@ -92,12 +91,40 @@ def resize_image_dimensions(
92
  return new_width, new_height
93
 
94
 
95
- def is_image_empty(image: Image.Image) -> bool:
96
  gray_img = image.convert("L")
97
  pixels = list(gray_img.getdata())
98
  return all(pixel == 0 for pixel in pixels)
99
 
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  @spaces.GPU(duration=100)
102
  def process(
103
  input_image_editor: dict,
@@ -125,26 +152,29 @@ def process(
125
  gr.Info("Please upload an image.")
126
  return None, None
127
 
128
- if is_image_empty(mask) and not masking_prompt_text:
129
  gr.Info("Please draw a mask or enter a masking prompt.")
130
  return None, None
131
 
132
- if not is_image_empty(mask) and masking_prompt_text:
133
  gr.Info("Both mask and masking prompt are provided. Please provide only one.")
134
  return None, None
135
 
136
- if is_image_empty(mask):
137
- mask = client.predict(
138
  image_input=handle_file(image_path),
139
  text_input=masking_prompt_text,
140
  api_name="/process_image")
141
  mask = Image.open(mask)
142
 
143
- width, height = resize_image_dimensions(original_resolution_wh=image.size)
144
  image = image.resize((width, height), Image.LANCZOS)
145
  mask = mask.resize((width, height), Image.LANCZOS)
146
  if mask_inflation_slider:
147
- mask = ImageOps.expand(mask, border=mask_inflation_slider, fill=255)
 
 
 
148
  if mask_blur_slider:
149
  mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur_slider))
150
 
 
1
+ import cv2
2
  import random
3
+ from typing import Tuple, Optional
4
 
5
  import gradio as gr
6
  import numpy as np
7
  import spaces
8
  import torch
9
+ from PIL import Image, ImageFilter
10
  from diffusers import FluxInpaintPipeline
11
  from gradio_client import Client, handle_file
12
 
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
  IMAGE_SIZE = 1024
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
+ PIPE = FluxInpaintPipeline.from_pretrained(
25
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
26
+ CLIENT = Client("SkalskiP/florence-sam-masking")
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ # def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
30
+ # image = image.convert("RGBA")
31
+ # data = image.getdata()
32
+ # new_data = []
33
+ # for item in data:
34
+ # avg = sum(item[:3]) / 3
35
+ # if avg < threshold:
36
+ # new_data.append((0, 0, 0, 0))
37
+ # else:
38
+ # new_data.append(item)
39
+ #
40
+ # image.putdata(new_data)
41
+ # return image
42
 
43
 
44
  # EXAMPLES = [
 
70
  # ]
71
  # ]
72
 
 
 
 
73
 
74
+ def calculate_image_dimensions_for_flux(
75
  original_resolution_wh: Tuple[int, int],
76
  maximum_dimension: int = IMAGE_SIZE
77
  ) -> Tuple[int, int]:
 
91
  return new_width, new_height
92
 
93
 
94
def is_mask_empty(image: Image.Image) -> bool:
    """
    Reports whether a mask contains no drawn (non-black) pixels.

    The mask is converted to 8-bit grayscale and is considered empty only
    when every resulting pixel value is 0.

    Args:
        image (Image.Image): The mask image to inspect.

    Returns:
        bool: True if every grayscale pixel is 0, False otherwise.
    """
    gray_img = image.convert("L")
    # getbbox() returns None exactly when the image has no non-zero pixels,
    # which avoids materialising every pixel value as a Python list.
    return gray_img.getbbox() is None
98
 
99
 
100
def process_mask(
    mask: Image.Image,
    mask_inflation: Optional[int] = None,
    mask_blur: Optional[int] = None
) -> Image.Image:
    """
    Inflates and blurs the white regions of a mask.

    Inflation is a single morphological dilation with a square kernel of
    `mask_inflation` x `mask_inflation` pixels, so the value is the kernel
    side length rather than the per-side growth of the white regions.
    Blurring, when requested, is applied after inflation.

    Args:
        mask (Image.Image): The input mask image.
        mask_inflation (Optional[int]): Side length, in pixels, of the square
            dilation kernel. `None` or values below 1 skip inflation.
        mask_blur (Optional[int]): The radius of the Gaussian blur to apply.
            `None` or non-positive values skip blurring.

    Returns:
        Image.Image: The processed mask with inflated and/or blurred regions.
    """
    # Callers may hand over float values (e.g. from a UI slider); NumPy needs
    # integer kernel dimensions, so coerce once and skip sizes that round to 0.
    kernel_size = int(mask_inflation) if mask_inflation else 0
    if kernel_size > 0:
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        dilated = cv2.dilate(np.array(mask), kernel, iterations=1)
        mask = Image.fromarray(dilated)

    if mask_blur and mask_blur > 0:
        mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))

    return mask
126
+
127
+
128
  @spaces.GPU(duration=100)
129
  def process(
130
  input_image_editor: dict,
 
152
  gr.Info("Please upload an image.")
153
  return None, None
154
 
155
+ if is_mask_empty(mask) and not masking_prompt_text:
156
  gr.Info("Please draw a mask or enter a masking prompt.")
157
  return None, None
158
 
159
+ if not is_mask_empty(mask) and masking_prompt_text:
160
  gr.Info("Both mask and masking prompt are provided. Please provide only one.")
161
  return None, None
162
 
163
+ if is_mask_empty(mask):
164
+ mask = CLIENT.predict(
165
  image_input=handle_file(image_path),
166
  text_input=masking_prompt_text,
167
  api_name="/process_image")
168
  mask = Image.open(mask)
169
 
170
+ width, height = calculate_image_dimensions_for_flux(original_resolution_wh=image.size)
171
  image = image.resize((width, height), Image.LANCZOS)
172
  mask = mask.resize((width, height), Image.LANCZOS)
173
  if mask_inflation_slider:
174
+ mask_array = np.array(mask)
175
+ kernel = np.ones((mask_inflation_slider, mask_inflation_slider), np.uint8)
176
+ mask_array = cv2.dilate(mask_array, kernel, iterations=1)
177
+ mask = Image.fromarray(mask_array)
178
  if mask_blur_slider:
179
  mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur_slider))
180