linoyts HF staff commited on
Commit
06ee083
·
verified ·
1 Parent(s): 4285cd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -38
app.py CHANGED
@@ -129,7 +129,8 @@ def preprocess_image(image: Image.Image,
129
  style_name: str = "",
130
  num_steps: int = 25,
131
  guidance_scale: float = 5,
132
- controlnet_conditioning_scale: float = 1.0,) -> Image.Image:
 
133
  """
134
  Preprocess the input image.
135
 
@@ -140,32 +141,35 @@ def preprocess_image(image: Image.Image,
140
  Image.Image: The preprocessed image.
141
  """
142
 
143
- width, height = image['composite'].size
144
- ratio = np.sqrt(1024. * 1024. / (width * height))
145
- new_width, new_height = int(width * ratio), int(height * ratio)
146
- image = image['composite'].resize((new_width, new_height))
147
-
148
- print("image:",type(image))
149
-
150
- prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
151
 
152
- print("params:", prompt, negative_prompt, style_name, num_steps, guidance_scale, controlnet_conditioning_scale)
153
- image = pipe_control(
154
- prompt=prompt,
155
- negative_prompt=negative_prompt,
156
- image=image,
157
- num_inference_steps=num_steps,
158
- controlnet_conditioning_scale=controlnet_conditioning_scale,
159
- guidance_scale=guidance_scale,
160
- width=new_width,
161
- height=new_height).images[0]
162
 
 
163
 
164
- processed_image = pipeline.preprocess_image(image)
165
- return processed_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
 
168
- def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
169
  """
170
  Preprocess a list of input images.
171
 
@@ -177,7 +181,7 @@ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image
177
  """
178
  images = [image[0] for image in images]
179
  processed_images = [pipeline.preprocess_image(image) for image in images]
180
- return processed_images
181
 
182
 
183
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
@@ -327,6 +331,9 @@ def extract_glb(
327
  return glb_path, glb_path
328
 
329
 
 
 
 
330
  @spaces.GPU
331
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
332
  """
@@ -377,11 +384,8 @@ def split_image(image: Image.Image) -> List[Image.Image]:
377
 
378
  with gr.Blocks(delete_cache=(600, 600), js=js_func) as demo:
379
  gr.Markdown("""
380
- ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
381
- * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
382
- * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
383
-
384
- ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
385
  """)
386
 
387
  with gr.Row():
@@ -438,6 +442,7 @@ with gr.Blocks(delete_cache=(600, 600), js=js_func) as demo:
438
  download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
439
 
440
  is_multiimage = gr.State(False)
 
441
  output_buf = gr.State()
442
 
443
  #Example images at the bottom of the page
@@ -476,15 +481,19 @@ with gr.Blocks(delete_cache=(600, 600), js=js_func) as demo:
476
  outputs=[is_multiimage, single_image_example, multiimage_example]
477
  )
478
 
479
- image_prompt.upload(
480
- preprocess_image,
481
- inputs=[image_prompt, prompt],
482
- outputs=[image_prompt],
 
 
 
 
483
  )
484
  multiimage_prompt.upload(
485
  preprocess_images,
486
  inputs=[multiimage_prompt],
487
- outputs=[multiimage_prompt],
488
  )
489
 
490
  generate_btn.click(
@@ -493,12 +502,12 @@ with gr.Blocks(delete_cache=(600, 600), js=js_func) as demo:
493
  outputs=[seed],
494
  ).then(
495
  preprocess_image,
496
- inputs=[image_prompt, prompt, negative_prompt, style],
497
- outputs=[image_prompt],
498
  ).then(
499
  image_to_3d,
500
- inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
501
- outputs=[output_buf, video_output],
502
  ).then(
503
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
504
  outputs=[extract_glb_btn, extract_gs_btn],
@@ -552,7 +561,6 @@ if __name__ == "__main__":
552
  controlnet=controlnet,
553
  vae=vae,
554
  torch_dtype=torch.float16,
555
- # scheduler=eulera_scheduler,
556
  )
557
  pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
558
  pipe_control.to(device)
 
129
  style_name: str = "",
130
  num_steps: int = 25,
131
  guidance_scale: float = 5,
132
+ controlnet_conditioning_scale: float = 1.0,
133
+ do_preprocess: bool = True) -> Image.Image:
134
  """
135
  Preprocess the input image.
136
 
 
141
  Image.Image: The preprocessed image.
142
  """
143
 
144
+ if do_preprocess:
145
+ width, height = image['composite'].size
146
+ ratio = np.sqrt(1024. * 1024. / (width * height))
147
+ new_width, new_height = int(width * ratio), int(height * ratio)
148
+ image = image['composite'].resize((new_width, new_height))
 
 
 
149
 
150
+ print("image:",type(image))
 
 
 
 
 
 
 
 
 
151
 
152
+ prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
153
 
154
+ print("params:", prompt, negative_prompt, style_name, num_steps, guidance_scale, controlnet_conditioning_scale)
155
+ image = pipe_control(
156
+ prompt=prompt,
157
+ negative_prompt=negative_prompt,
158
+ image=image,
159
+ num_inference_steps=num_steps,
160
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
161
+ guidance_scale=guidance_scale,
162
+ width=new_width,
163
+ height=new_height).images[0]
164
+
165
+
166
+ processed_image = pipeline.preprocess_image(image)
167
+ return processed_image, False
168
+ else:
169
+ return image, False
170
 
171
 
172
+ def preprocess_images(images: List[Tuple[Image.Image, str]], do_preprocess = True) -> List[Image.Image]:
173
  """
174
  Preprocess a list of input images.
175
 
 
181
  """
182
  images = [image[0] for image in images]
183
  processed_images = [pipeline.preprocess_image(image) for image in images]
184
+ return processed_images, False
185
 
186
 
187
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
 
331
  return glb_path, glb_path
332
 
333
 
334
+ def reset_do_preprocess():
335
+ return True
336
+
337
  @spaces.GPU
338
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
339
  """
 
384
 
385
  with gr.Blocks(delete_cache=(600, 600), js=js_func) as demo:
386
  gr.Markdown("""
387
+ ## Sketch to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
388
+ * draw or upload a sketch and click "Generate" to create a 3D asset.
 
 
 
389
  """)
390
 
391
  with gr.Row():
 
442
  download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
443
 
444
  is_multiimage = gr.State(False)
445
+ do_preprocess = gr.State(True)
446
  output_buf = gr.State()
447
 
448
  #Example images at the bottom of the page
 
481
  outputs=[is_multiimage, single_image_example, multiimage_example]
482
  )
483
 
484
+ # image_prompt.upload(
485
+ # preprocess_image,
486
+ # inputs=[image_prompt, prompt, negative_prompt, style, do_preprocess],
487
+ # outputs=[image_prompt, do_preprocess],
488
+ # )
489
+ image_prompt.change(
490
+ reset_do_preprocess,
491
+ outputs=[do_preprocess]
492
  )
493
  multiimage_prompt.upload(
494
  preprocess_images,
495
  inputs=[multiimage_prompt],
496
+ outputs=[multiimage_prompt, do_preprocess],
497
  )
498
 
499
  generate_btn.click(
 
502
  outputs=[seed],
503
  ).then(
504
  preprocess_image,
505
+ inputs=[image_prompt, prompt, negative_prompt, style, do_preprocess],
506
+ outputs=[image_prompt, do_preprocess],
507
  ).then(
508
  image_to_3d,
509
+ inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, do_preprocess],
510
+ outputs=[output_buf, video_output, do_preprocess],
511
  ).then(
512
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
513
  outputs=[extract_glb_btn, extract_gs_btn],
 
561
  controlnet=controlnet,
562
  vae=vae,
563
  torch_dtype=torch.float16,
 
564
  )
565
  pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
566
  pipe_control.to(device)