jiuface commited on
Commit
9fee3bb
·
1 Parent(s): 5273050
Files changed (1) hide show
  1. app.py +27 -18
app.py CHANGED
@@ -18,7 +18,8 @@ import time
18
  import boto3
19
  from io import BytesIO
20
  from datetime import datetime
21
- from diffusers.utils import load_image
 
22
  import json
23
  from preprocessor import Preprocessor
24
  from diffusers.pipelines.flux.pipeline_flux_controlnet_inpaint import FluxControlNetInpaintPipeline
@@ -32,17 +33,14 @@ MAX_SEED = np.iinfo(np.int32).max
32
  IMAGE_SIZE = 1024
33
 
34
  # init
35
- dtype = torch.bfloat16
36
  device = "cuda" if torch.cuda.is_available() else "cpu"
37
  base_model = "black-forest-labs/FLUX.1-dev"
38
 
39
  controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
40
  controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
41
 
42
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
43
- good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
44
- pipe = FluxControlNetInpaintPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=dtype, vae=taef1).to(device)
45
 
 
46
 
47
 
48
  control_mode_ids = {
@@ -176,7 +174,7 @@ def run_flux(
176
  generator = torch.Generator().manual_seed(seed_slicer)
177
 
178
  with calculateDuration("run pipe"):
179
- genearte_image = pipe(
180
  prompt=prompt,
181
  image=image,
182
  mask_image=mask,
@@ -191,7 +189,7 @@ def run_flux(
191
  joint_attention_kwargs={"scale": lora_scale}
192
  ).images[0]
193
 
194
- return genearte_image
195
 
196
  @spaces.GPU(duration=120)
197
  def process(
@@ -306,19 +304,30 @@ def process(
306
  resolution_wh=(width, height),
307
  progress=progress
308
  )
309
- except:
 
310
  result["message"] = "generate image failed"
311
- return None, json.dumps(result)
 
312
 
313
  print("run flux finish")
314
- if upload_to_r2:
315
- with calculateDuration("upload image"):
316
- url = upload_image_to_r2(generated_image, account_id, access_key, secret_key, bucket)
317
- result = {"status": "success", "message": "upload image success", "url": url}
318
- else:
319
- result = {"status": "success", "message": "Image generated but not uploaded"}
 
 
320
 
321
- return generated_image, json.dumps(result)
 
 
 
 
 
 
 
322
 
323
 
324
  with gr.Blocks() as demo:
@@ -449,7 +458,7 @@ with gr.Blocks() as demo:
449
  secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
450
 
451
  with gr.Column():
452
- generated_image = gr.Image(label="Result", show_label=False)
453
  output_json_component = gr.Code(label="JSON Result", language="json")
454
 
455
  submit_button_component.click(
@@ -475,7 +484,7 @@ with gr.Blocks() as demo:
475
  bucket
476
  ],
477
  outputs=[
478
- generated_image,
479
  output_json_component
480
  ]
481
  )
 
18
  import boto3
19
  from io import BytesIO
20
  from datetime import datetime
21
+ from diffusers.utils import load_image, make_image_grid
22
+
23
  import json
24
  from preprocessor import Preprocessor
25
  from diffusers.pipelines.flux.pipeline_flux_controlnet_inpaint import FluxControlNetInpaintPipeline
 
33
  IMAGE_SIZE = 1024
34
 
35
  # init
 
36
  device = "cuda" if torch.cuda.is_available() else "cpu"
37
  base_model = "black-forest-labs/FLUX.1-dev"
38
 
39
  controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
40
  controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
41
 
 
 
 
42
 
43
+ pipe = FluxControlNetInpaintPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)
44
 
45
 
46
  control_mode_ids = {
 
174
  generator = torch.Generator().manual_seed(seed_slicer)
175
 
176
  with calculateDuration("run pipe"):
177
+ generated_image = pipe(
178
  prompt=prompt,
179
  image=image,
180
  mask_image=mask,
 
189
  joint_attention_kwargs={"scale": lora_scale}
190
  ).images[0]
191
 
192
+ return generated_image
193
 
194
  @spaces.GPU(duration=120)
195
  def process(
 
304
  resolution_wh=(width, height),
305
  progress=progress
306
  )
307
+ except Exception as e:
308
+ result["status"] = "faield"
309
  result["message"] = "generate image failed"
310
+ print(e)
311
+ generated_image = None
312
 
313
  print("run flux finish")
314
+ if generated_image:
315
+ if upload_to_r2:
316
+ with calculateDuration("upload image"):
317
+ url = upload_image_to_r2(generated_image, account_id, access_key, secret_key, bucket)
318
+ result = {"status": "success", "message": "upload image success", "url": url}
319
+ else:
320
+ result = {"status": "success", "message": "Image generated but not uploaded"}
321
+
322
 
323
+ final_images = []
324
+ final_images.append(image)
325
+ final_images.append(mask)
326
+ final_images.append(control_image)
327
+ if generated_image:
328
+ final_images.append(generated_image)
329
+
330
+ return final_images, json.dumps(result)
331
 
332
 
333
  with gr.Blocks() as demo:
 
458
  secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
459
 
460
  with gr.Column():
461
+ generated_images = gr.Gallery(label="Result", show_label=True)
462
  output_json_component = gr.Code(label="JSON Result", language="json")
463
 
464
  submit_button_component.click(
 
484
  bucket
485
  ],
486
  outputs=[
487
+ generated_images,
488
  output_json_component
489
  ]
490
  )