jiuface commited on
Commit
daf6c0f
·
1 Parent(s): a4f92f5

remove mask generataion

Browse files
Files changed (1) hide show
  1. app.py +14 -75
app.py CHANGED
@@ -39,9 +39,12 @@ dtype = torch.bfloat16
39
  device = "cuda" if torch.cuda.is_available() else "cpu"
40
  base_model = "black-forest-labs/FLUX.1-dev"
41
 
 
 
 
42
 
43
- FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=device)
44
- SAM_IMAGE_MODEL = load_sam_image_model(device=device)
45
 
46
 
47
  class calculateDuration:
@@ -147,9 +150,7 @@ def run_flux(
147
  ) -> Image.Image:
148
  print("Running FLUX...")
149
 
150
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
151
- good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
152
- pipe = FluxInpaintPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
153
 
154
  with calculateDuration("load lora"):
155
  print("start to load lora", lora_path, lora_weights)
@@ -178,62 +179,10 @@ def run_flux(
178
 
179
  return genearte_image
180
 
181
- @spaces.GPU(duration=10)
182
- def genearte_mask(image_input: Image.Image, masking_prompt_text: str) -> Image.Image:
183
- # generate mask by florence & sam
184
- print("Generating mask...")
185
- task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
186
-
187
- with calculateDuration("FLORENCE"):
188
- print(task_prompt, masking_prompt_text)
189
- _, result = run_florence_inference(
190
- model=FLORENCE_MODEL,
191
- processor=FLORENCE_PROCESSOR,
192
- device=device,
193
- image=image_input,
194
- task=task_prompt,
195
- text=masking_prompt_text
196
- )
197
-
198
- with calculateDuration("sv.Detections"):
199
- # start to dectect
200
- detections = sv.Detections.from_lmm(
201
- lmm=sv.LMM.FLORENCE_2,
202
- result=result,
203
- resolution_wh=image_input.size
204
- )
205
-
206
- images = []
207
-
208
- with calculateDuration("generate segmenet mask"):
209
- # using sam generate segments images
210
- detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
211
- if len(detections) == 0:
212
- gr.Info("No objects detected.")
213
- return None
214
- print("mask generated:", len(detections.mask))
215
- kernel_size = dilate
216
- kernel = np.ones((kernel_size, kernel_size), np.uint8)
217
-
218
- for i in range(len(detections.mask)):
219
- mask = detections.mask[i].astype(np.uint8) * 255
220
- images.append(mask)
221
-
222
- # merge mark into on image
223
- merged_mask = np.zeros_like(images[0], dtype=np.uint8)
224
- for mask in images:
225
- merged_mask = cv2.bitwise_or(merged_mask, mask)
226
-
227
- images = [merged_mask]
228
-
229
- return images[0]
230
-
231
-
232
-
233
  def process(
234
  image_url: str,
 
235
  inpainting_prompt_text: str,
236
- masking_prompt_text: str,
237
  mask_inflation_slider: int,
238
  mask_blur_slider: int,
239
  seed_slicer: int,
@@ -260,26 +209,16 @@ def process(
260
  result["message"] = "invalid inpainting prompt"
261
  return json.dumps(result)
262
 
263
- if not masking_prompt_text:
264
- gr.Info("Please enter masking_prompt_text.")
265
- result["message"] = "invalid masking prompt"
266
- return json.dumps(result)
267
 
268
  with calculateDuration("load image"):
269
  image = load_image(image_url)
 
270
 
271
- mask = genearte_mask(image, masking_prompt_text)
272
-
273
- if not image:
274
- gr.Info("Please upload an image.")
275
  result["message"] = "can not load image"
276
  return json.dumps(result)
277
 
278
- if is_mask_empty(mask):
279
- gr.Info("Please draw a mask or enter a masking prompt.")
280
- result["message"] = "can not generate mask"
281
- return json.dumps(result)
282
-
283
  # generate
284
  width, height = calculate_image_dimensions_for_flux(original_resolution_wh=image.size)
285
  image = image.resize((width, height), Image.LANCZOS)
@@ -321,11 +260,11 @@ with gr.Blocks() as demo:
321
  container=False,
322
  )
323
 
324
- masking_prompt_text_component = gr.Text(
325
- label="Masking prompt",
326
  show_label=False,
327
  max_lines=1,
328
- placeholder="Enter text to generate masking",
329
  container=False,
330
  )
331
 
@@ -439,8 +378,8 @@ with gr.Blocks() as demo:
439
  fn=process,
440
  inputs=[
441
  image_url,
 
442
  inpainting_prompt_text_component,
443
- masking_prompt_text_component,
444
  mask_inflation_slider_component,
445
  mask_blur_slider_component,
446
  seed_slicer_component,
 
39
  device = "cuda" if torch.cuda.is_available() else "cpu"
40
  base_model = "black-forest-labs/FLUX.1-dev"
41
 
42
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
43
+ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
44
+ pipe = FluxInpaintPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
45
 
46
+ # FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=device)
47
+ # SAM_IMAGE_MODEL = load_sam_image_model(device=device)
48
 
49
 
50
  class calculateDuration:
 
150
  ) -> Image.Image:
151
  print("Running FLUX...")
152
 
153
+
 
 
154
 
155
  with calculateDuration("load lora"):
156
  print("start to load lora", lora_path, lora_weights)
 
179
 
180
  return genearte_image
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  def process(
183
  image_url: str,
184
+ mask_url: str,
185
  inpainting_prompt_text: str,
 
186
  mask_inflation_slider: int,
187
  mask_blur_slider: int,
188
  seed_slicer: int,
 
209
  result["message"] = "invalid inpainting prompt"
210
  return json.dumps(result)
211
 
 
 
 
 
212
 
213
  with calculateDuration("load image"):
214
  image = load_image(image_url)
215
+ mask = load_image(mask_url)
216
 
217
+ if not image or not mask:
218
+ gr.Info("Please upload an image & mask by url.")
 
 
219
  result["message"] = "can not load image"
220
  return json.dumps(result)
221
 
 
 
 
 
 
222
  # generate
223
  width, height = calculate_image_dimensions_for_flux(original_resolution_wh=image.size)
224
  image = image.resize((width, height), Image.LANCZOS)
 
260
  container=False,
261
  )
262
 
263
+ mask_url = gr.Text(
264
+ label="image url of masking",
265
  show_label=False,
266
  max_lines=1,
267
+ placeholder="Enter url of masking",
268
  container=False,
269
  )
270
 
 
378
  fn=process,
379
  inputs=[
380
  image_url,
381
+ mask_url,
382
  inpainting_prompt_text_component,
 
383
  mask_inflation_slider_component,
384
  mask_blur_slider_component,
385
  seed_slicer_component,