CaramelTEQUILA committed
Commit 184eac4 · verified · 1 parent: bacad8f

Upload 3 files

Files changed (3)
  1. gradio_app.py +187 -0
  2. requirements.txt +10 -0
  3. run.py +162 -0
gradio_app.py ADDED
@@ -0,0 +1,187 @@
+ import logging
+ import os
+ import tempfile
+ import time
+
+ import gradio as gr
+ import numpy as np
+ import rembg
+ import torch
+ from PIL import Image
+ from functools import partial
+
+ from tsr.system import TSR
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
+
+ import argparse
+
+
+ if torch.cuda.is_available():
+     device = "cuda:0"
+ else:
+     device = "cpu"
+
+ model = TSR.from_pretrained(
+     "stabilityai/TripoSR",
+     config_name="config.yaml",
+     weight_name="model.ckpt",
+ )
+
+ # adjust the chunk size to balance between speed and memory usage
+ model.renderer.set_chunk_size(8192)
+ model.to(device)
+
+ rembg_session = rembg.new_session()
+
+
+ def check_input_image(input_image):
+     if input_image is None:
+         raise gr.Error("No image uploaded!")
+
+
+ def preprocess(input_image, do_remove_background, foreground_ratio):
+     def fill_background(image):
+         image = np.array(image).astype(np.float32) / 255.0
+         image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
+         image = Image.fromarray((image * 255.0).astype(np.uint8))
+         return image
+
+     if do_remove_background:
+         image = input_image.convert("RGB")
+         image = remove_background(image, rembg_session)
+         image = resize_foreground(image, foreground_ratio)
+         image = fill_background(image)
+     else:
+         image = input_image
+         if image.mode == "RGBA":
+             image = fill_background(image)
+     return image
+
+
+ def generate(image, mc_resolution, formats=("obj", "glb")):
+     scene_codes = model(image, device=device)
+     mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
+     mesh = to_gradio_3d_orientation(mesh)
+     rv = []
+     for fmt in formats:
+         mesh_path = tempfile.NamedTemporaryFile(suffix=f".{fmt}", delete=False)
+         mesh.export(mesh_path.name)
+         rv.append(mesh_path.name)
+     return rv
+
+
+ def run_example(image_pil):
+     preprocessed = preprocess(image_pil, False, 0.9)
+     mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
+     return preprocessed, mesh_name_obj, mesh_name_glb
+
+
+ with gr.Blocks(title="TripoSR") as interface:
+     gr.Markdown(
+         """
+ # TripoSR Demo
+ [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
+
+ **Tips:**
+ 1. If the result is unsatisfactory, try changing the foreground ratio; it may improve the result.
+ 2. It's better to disable "Remove Background" for the provided examples (except for the last one), since they have already been preprocessed.
+ 3. Otherwise, disable the "Remove Background" option only if your input image is RGBA with a transparent background, and its content is centered and occupies more than 70% of the image width or height.
+ """
+     )
+     with gr.Row(variant="panel"):
+         with gr.Column():
+             with gr.Row():
+                 input_image = gr.Image(
+                     label="Input Image",
+                     image_mode="RGBA",
+                     sources="upload",
+                     type="pil",
+                     elem_id="content_image",
+                 )
+                 processed_image = gr.Image(label="Processed Image", interactive=False)
+             with gr.Row():
+                 with gr.Group():
+                     do_remove_background = gr.Checkbox(
+                         label="Remove Background", value=True
+                     )
+                     foreground_ratio = gr.Slider(
+                         label="Foreground Ratio",
+                         minimum=0.5,
+                         maximum=1.0,
+                         value=0.85,
+                         step=0.05,
+                     )
+                     mc_resolution = gr.Slider(
+                         label="Marching Cubes Resolution",
+                         minimum=32,
+                         maximum=320,
+                         value=256,
+                         step=32,
+                     )
+             with gr.Row():
+                 submit = gr.Button("Generate", elem_id="generate", variant="primary")
+         with gr.Column():
+             with gr.Tab("OBJ"):
+                 output_model_obj = gr.Model3D(
+                     label="Output Model (OBJ Format)",
+                     interactive=False,
+                 )
+                 gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
+             with gr.Tab("GLB"):
+                 output_model_glb = gr.Model3D(
+                     label="Output Model (GLB Format)",
+                     interactive=False,
+                 )
+                 gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
+     with gr.Row(variant="panel"):
+         gr.Examples(
+             examples=[
+                 "examples/hamburger.png",
+                 "examples/poly_fox.png",
+                 "examples/robot.png",
+                 "examples/teapot.png",
+                 "examples/tiger_girl.png",
+                 "examples/horse.png",
+                 "examples/flamingo.png",
+                 "examples/unicorn.png",
+                 "examples/chair.png",
+                 "examples/iso_house.png",
+                 "examples/marble.png",
+                 "examples/police_woman.png",
+                 "examples/captured.jpeg",
+             ],
+             inputs=[input_image],
+             outputs=[processed_image, output_model_obj, output_model_glb],
+             cache_examples=False,
+             fn=partial(run_example),
+             label="Examples",
+             examples_per_page=20,
+         )
+     submit.click(fn=check_input_image, inputs=[input_image]).success(
+         fn=preprocess,
+         inputs=[input_image, do_remove_background, foreground_ratio],
+         outputs=[processed_image],
+     ).success(
+         fn=generate,
+         inputs=[processed_image, mc_resolution],
+         outputs=[output_model_obj, output_model_glb],
+     )
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--username', type=str, default=None, help='Username for authentication')
+     parser.add_argument('--password', type=str, default=None, help='Password for authentication')
+     parser.add_argument('--port', type=int, default=7860, help='Port to run the server listener on')
+     parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing it to respond to network requests")
+     parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
+     parser.add_argument("--queuesize", type=int, default=1, help="maximum size of the gradio request queue")
+     args = parser.parse_args()
+     interface.queue(max_size=args.queuesize)
+     interface.launch(
+         auth=(args.username, args.password) if (args.username and args.password) else None,
+         share=args.share,
+         server_name="0.0.0.0" if args.listen else None,
+         server_port=args.port,
+     )
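
For reference, gradio_app.py registers the CLI flags above, so a typical launch might look like this (flag values here are illustrative, not part of the commit):

    python gradio_app.py --listen --port 7860 --queuesize 4

Supplying both --username and --password enables Gradio's basic authentication, and --share additionally requests a temporary public link.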
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ omegaconf==2.3.0
+ Pillow==10.1.0
+ einops==0.7.0
+ git+https://github.com/tatsy/torchmcubes.git
+ transformers==4.35.0
+ trimesh==4.0.5
+ rembg
+ huggingface-hub
+ imageio[ffmpeg]
+ gradio
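
Note that torchmcubes is built from source via the git URL, so a working compiler toolchain is assumed, and torch itself is imported by both scripts but not pinned here, so it must be installed separately. A minimal install sketch:

    pip install torch  # choose the build matching your CUDA version
    pip install -r requirements.txt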
run.py ADDED
@@ -0,0 +1,162 @@
+ import argparse
+ import logging
+ import os
+ import time
+
+ import numpy as np
+ import rembg
+ import torch
+ from PIL import Image
+
+ from tsr.system import TSR
+ from tsr.utils import remove_background, resize_foreground, save_video
+
+
+ class Timer:
+     def __init__(self):
+         self.items = {}
+         self.time_scale = 1000.0  # ms
+         self.time_unit = "ms"
+
+     def start(self, name: str) -> None:
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()
+         self.items[name] = time.time()
+         logging.info(f"{name} ...")
+
+     def end(self, name: str) -> None:
+         if name not in self.items:
+             return
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()
+         start_time = self.items.pop(name)
+         delta = time.time() - start_time
+         t = delta * self.time_scale
+         logging.info(f"{name} finished in {t:.2f}{self.time_unit}.")
+
+
+ timer = Timer()
+
+
+ logging.basicConfig(
+     format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
+ )
+ parser = argparse.ArgumentParser()
+ parser.add_argument("image", type=str, nargs="+", help="Path to input image(s).")
+ parser.add_argument(
+     "--device",
+     default="cuda:0",
+     type=str,
+     help="Device to use. If no CUDA-compatible device is found, falls back to 'cpu'. Default: 'cuda:0'",
+ )
+ parser.add_argument(
+     "--pretrained-model-name-or-path",
+     default="stabilityai/TripoSR",
+     type=str,
+     help="Path to the pretrained model. Can be either a Hugging Face model ID or a local path. Default: 'stabilityai/TripoSR'",
+ )
+ parser.add_argument(
+     "--chunk-size",
+     default=8192,
+     type=int,
+     help="Evaluation chunk size for surface extraction and rendering. A smaller chunk size reduces VRAM usage but increases computation time. 0 for no chunking. Default: 8192",
+ )
+ parser.add_argument(
+     "--mc-resolution",
+     default=256,
+     type=int,
+     help="Marching cubes grid resolution. Default: 256",
+ )
+ parser.add_argument(
+     "--no-remove-bg",
+     action="store_true",
+     help="If specified, the background will NOT be automatically removed from the input image; the input image should then be an RGB image with a gray background and a properly sized foreground. Default: false",
+ )
+ parser.add_argument(
+     "--foreground-ratio",
+     default=0.85,
+     type=float,
+     help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
+ )
+ parser.add_argument(
+     "--output-dir",
+     default="output/",
+     type=str,
+     help="Output directory to save the results. Default: 'output/'",
+ )
+ parser.add_argument(
+     "--model-save-format",
+     default="obj",
+     type=str,
+     choices=["obj", "glb"],
+     help="Format to save the extracted mesh. Default: 'obj'",
+ )
+ parser.add_argument(
+     "--render",
+     action="store_true",
+     help="If specified, save a NeRF-rendered video. Default: false",
+ )
+ args = parser.parse_args()
+
+ output_dir = args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+
+ device = args.device
+ if not torch.cuda.is_available():
+     device = "cpu"
+
+ timer.start("Initializing model")
+ model = TSR.from_pretrained(
+     args.pretrained_model_name_or_path,
+     config_name="config.yaml",
+     weight_name="model.ckpt",
+ )
+ model.renderer.set_chunk_size(args.chunk_size)
+ model.to(device)
+ timer.end("Initializing model")
+
+ timer.start("Processing images")
+ images = []
+
+ if args.no_remove_bg:
+     rembg_session = None
+ else:
+     rembg_session = rembg.new_session()
+
+ for i, image_path in enumerate(args.image):
+     if args.no_remove_bg:
+         image = np.array(Image.open(image_path).convert("RGB"))
+     else:
+         image = remove_background(Image.open(image_path), rembg_session)
+         image = resize_foreground(image, args.foreground_ratio)
+         image = np.array(image).astype(np.float32) / 255.0
+         image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
+         image = Image.fromarray((image * 255.0).astype(np.uint8))
+     if not os.path.exists(os.path.join(output_dir, str(i))):
+         os.makedirs(os.path.join(output_dir, str(i)))
+     image.save(os.path.join(output_dir, str(i), "input.png"))
+     images.append(image)
+ timer.end("Processing images")
+
+ for i, image in enumerate(images):
+     logging.info(f"Running image {i + 1}/{len(images)} ...")
+
+     timer.start("Running model")
+     with torch.no_grad():
+         scene_codes = model([image], device=device)
+     timer.end("Running model")
+
+     if args.render:
+         timer.start("Rendering")
+         render_images = model.render(scene_codes, n_views=30, return_type="pil")
+         for ri, render_image in enumerate(render_images[0]):
+             render_image.save(os.path.join(output_dir, str(i), f"render_{ri:03d}.png"))
+         save_video(
+             render_images[0], os.path.join(output_dir, str(i), "render.mp4"), fps=30
+         )
+         timer.end("Rendering")
+
+     timer.start("Exporting mesh")
+     meshes = model.extract_mesh(scene_codes, resolution=args.mc_resolution)
+     meshes[0].export(os.path.join(output_dir, str(i), f"mesh.{args.model_save_format}"))
+     timer.end("Exporting mesh")
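
As a usage sketch for the CLI defined above (the example image path is illustrative):

    python run.py examples/chair.png --model-save-format glb --render

Per the code, results land in output/<index>/: the preprocessed input.png, the extracted mesh.glb, and, with --render, 30 render_*.png views plus a render.mp4 turntable video.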