linoyts HF staff committed on
Commit
7cb2038
·
verified ·
1 Parent(s): b7b00e2

add scribble controlnet

Browse files
Files changed (1) hide show
  1. app.py +59 -3
app.py CHANGED
@@ -16,10 +16,41 @@ from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
  os.makedirs(TMP_DIR, exist_ok=True)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  def start_session(req: gr.Request):
25
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
@@ -31,7 +62,11 @@ def end_session(req: gr.Request):
31
  shutil.rmtree(user_dir)
32
 
33
 
34
- def preprocess_image(image: Image.Image) -> Image.Image:
 
 
 
 
35
  """
36
  Preprocess the input image.
37
 
@@ -41,6 +76,21 @@ def preprocess_image(image: Image.Image) -> Image.Image:
41
  Returns:
42
  Image.Image: The preprocessed image.
43
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  processed_image = pipeline.preprocess_image(image)
45
  return processed_image
46
 
@@ -268,7 +318,9 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
268
  with gr.Column():
269
  with gr.Tabs() as input_tabs:
270
  with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
271
- image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
 
 
272
  with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
273
  multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
274
  gr.Markdown("""
@@ -352,7 +404,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
352
 
353
  image_prompt.upload(
354
  preprocess_image,
355
- inputs=[image_prompt],
356
  outputs=[image_prompt],
357
  )
358
  multiimage_prompt.upload(
@@ -365,6 +417,10 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
365
  get_seed,
366
  inputs=[randomize_seed, seed],
367
  outputs=[seed],
 
 
 
 
368
  ).then(
369
  image_to_3d,
370
  inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
 
16
  from trellis.utils import render_utils, postprocessing_utils
17
 
18
 
19
+
20
+ import os
21
+ import random
22
+ import torch
23
+ import torchvision.transforms.functional as TF
24
+
25
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
26
+ from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
27
+ from controlnet_aux import PidiNetDetector, HEDdetector
28
+ from diffusers.utils import load_image
29
+ from huggingface_hub import HfApi
30
+ from pathlib import Path
31
+ from PIL import Image, ImageOps
32
+ import torch
33
+ import numpy as np
34
+ import cv2
35
+ import os
36
+ import random
37
+
38
  MAX_SEED = np.iinfo(np.int32).max
39
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
40
  os.makedirs(TMP_DIR, exist_ok=True)
41
 
42
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
43
+
44
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
45
+ "sd-community/sdxl-flash",
46
+ controlnet=controlnet,
47
+ vae=vae,
48
+ torch_dtype=torch.float16,
49
+ # scheduler=eulera_scheduler,
50
+ )
51
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
52
+
53
+ pipe.to(device)
54
 
55
  def start_session(req: gr.Request):
56
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
62
  shutil.rmtree(user_dir)
63
 
64
 
65
+ def preprocess_image(image: Image.Image,
66
+ prompt: str,
67
+ num_steps: int = 25,
68
+ guidance_scale: float = 5,
69
+ controlnet_conditioning_scale: float = 1.0,) -> Image.Image:
70
  """
71
  Preprocess the input image.
72
 
 
76
  Returns:
77
  Image.Image: The preprocessed image.
78
  """
79
+ width, height = image['composite'].size
80
+ ratio = np.sqrt(1024. * 1024. / (width * height))
81
+ new_width, new_height = int(width * ratio), int(height * ratio)
82
+ image = image['composite'].resize((new_width, new_height))
83
+
84
+ image = pipe(
85
+ prompt=prompt,
86
+ negative_prompt=negative_prompt,
87
+ image=image,
88
+ num_inference_steps=num_steps,
89
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
90
+ guidance_scale=guidance_scale,
91
+ width=new_width,
92
+ height=new_height,).images[0]
93
+
94
  processed_image = pipeline.preprocess_image(image)
95
  return processed_image
96
 
 
318
  with gr.Column():
319
  with gr.Tabs() as input_tabs:
320
  with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
321
+ #image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
322
+ image_prompt = image = gr.ImageEditor(type="pil", image_mode="L", crop_size=(512, 512))
323
+ prompt = gr.Textbox(label="Prompt")
324
  with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
325
  multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
326
  gr.Markdown("""
 
404
 
405
  image_prompt.upload(
406
  preprocess_image,
407
+ inputs=[image_prompt, prompt],
408
  outputs=[image_prompt],
409
  )
410
  multiimage_prompt.upload(
 
417
  get_seed,
418
  inputs=[randomize_seed, seed],
419
  outputs=[seed],
420
+ ).then(
421
+ preprocess_image,
422
+ inputs=[image_prompt, prompt],
423
+ outputs=[image_prompt],
424
  ).then(
425
  image_to_3d,
426
  inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],