kxhit commited on
Commit
0054ddf
·
1 Parent(s): f7fc9cc
Files changed (2) hide show
  1. README.md +1 -1
  2. gradio_demo/gradio_demo.py +0 -782
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: EscherNet
3
- app_file: ./gradio_demo/gradio_demo.py
4
  sdk: gradio
5
  sdk_version: 4.19.2
6
  ---
 
1
  ---
2
  title: EscherNet
3
+ app_file: app.py
4
  sdk: gradio
5
  sdk_version: 4.19.2
6
  ---
gradio_demo/gradio_demo.py DELETED
@@ -1,782 +0,0 @@
1
- import spaces
2
- import torch
3
- print("cuda is available: ", torch.cuda.is_available())
4
-
5
- import gradio as gr
6
- import os
7
- import shutil
8
- import rembg
9
- import numpy as np
10
- import math
11
- import open3d as o3d
12
- from PIL import Image
13
- import torchvision
14
- import trimesh
15
- from skimage.io import imsave
16
- import imageio
17
- import cv2
18
- import matplotlib.pyplot as pl
19
- pl.ion()
20
-
21
- CaPE_TYPE = "6DoF"
22
- device = 'cuda' #if torch.cuda.is_available() else 'cpu'
23
- weight_dtype = torch.float16
24
- torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
25
-
26
- # EscherNet
27
- # create angles in archimedean spiral with N steps
28
- def get_archimedean_spiral(sphere_radius, num_steps=250):
29
- # x-z plane, around upper y
30
- '''
31
- https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
32
- '''
33
- a = 40
34
- r = sphere_radius
35
-
36
- translations = []
37
- angles = []
38
-
39
- # i = a / 2
40
- i = 0.01
41
- while i < a:
42
- theta = i / a * math.pi
43
- x = r * math.sin(theta) * math.cos(-i)
44
- z = r * math.sin(-theta + math.pi) * math.sin(-i)
45
- y = r * - math.cos(theta)
46
-
47
- # translations.append((x, y, z)) # origin
48
- translations.append((x, z, -y))
49
- angles.append([np.rad2deg(-i), np.rad2deg(theta)])
50
-
51
- # i += a / (2 * num_steps)
52
- i += a / (1 * num_steps)
53
-
54
- return np.array(translations), np.stack(angles)
55
-
56
- def look_at(origin, target, up):
57
- forward = (target - origin)
58
- forward = forward / np.linalg.norm(forward)
59
- right = np.cross(up, forward)
60
- right = right / np.linalg.norm(right)
61
- new_up = np.cross(forward, right)
62
- rotation_matrix = np.column_stack((right, new_up, -forward, target))
63
- matrix = np.row_stack((rotation_matrix, [0, 0, 0, 1]))
64
- return matrix
65
-
66
- import einops
67
- import sys
68
-
69
- sys.path.insert(0, "./6DoF/") # TODO change it when deploying
70
- # use the customized diffusers modules
71
- from diffusers import DDIMScheduler
72
- from dataset import get_pose
73
- from CN_encoder import CN_encoder
74
- from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
75
-
76
- pretrained_model_name_or_path = "kxic/EscherNet_demo"
77
- resolution = 256
78
- h,w = resolution,resolution
79
- guidance_scale = 3.0
80
- radius = 2.2
81
- bg_color = [1., 1., 1., 1.]
82
- image_transforms = torchvision.transforms.Compose(
83
- [
84
- torchvision.transforms.Resize((resolution, resolution)), # 256, 256
85
- torchvision.transforms.ToTensor(),
86
- torchvision.transforms.Normalize([0.5], [0.5])
87
- ]
88
- )
89
- xyzs_spiral, angles_spiral = get_archimedean_spiral(1.5, 200)
90
- # only half toop
91
- xyzs_spiral = xyzs_spiral[:100]
92
- angles_spiral = angles_spiral[:100]
93
-
94
- # Init pipeline
95
- scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", revision=None)
96
- image_encoder = CN_encoder.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=None)
97
- pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
98
- pretrained_model_name_or_path,
99
- revision=None,
100
- scheduler=scheduler,
101
- image_encoder=None,
102
- safety_checker=None,
103
- feature_extractor=None,
104
- torch_dtype=weight_dtype,
105
- )
106
- pipeline.image_encoder = image_encoder.to(weight_dtype)
107
- pipeline.set_progress_bar_config(disable=False)
108
-
109
- # pipeline.enable_xformers_memory_efficient_attention()
110
- # enable vae slicing
111
- pipeline.enable_vae_slicing()
112
- # pipeline = pipeline.to(device)
113
-
114
-
115
-
116
- @spaces.GPU(duration=120)
117
- def run_eschernet(tmpdirname, eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
118
- # set the random seed
119
- generator = torch.Generator(device=device).manual_seed(sample_seed)
120
- T_out = nvs_num
121
- T_in = len(eschernet_input_dict['imgs'])
122
- ####### output pose
123
- # TODO choose T_out number of poses sequentially from the spiral
124
- xyzs = xyzs_spiral[::(len(xyzs_spiral) // T_out)]
125
- angles_out = angles_spiral[::(len(xyzs_spiral) // T_out)]
126
-
127
- ####### input's max radius for translation scaling
128
- radii = eschernet_input_dict['radii']
129
- max_t = np.max(radii)
130
- min_t = np.min(radii)
131
-
132
- ####### input pose
133
- pose_in = []
134
- for T_in_index in range(T_in):
135
- pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
136
- pose[1:3, :] *= -1 # coordinate system conversion
137
- pose[3, 3] *= 1. / max_t * radius # scale radius to [1.5, 2.2]
138
- pose_in.append(torch.from_numpy(pose))
139
-
140
- ####### input image
141
- img = eschernet_input_dict['imgs'] / 255.
142
- img[img[:, :, :, -1] == 0.] = bg_color
143
- # TODO batch image_transforms
144
- input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]
145
-
146
- ####### nvs pose
147
- pose_out = []
148
- for T_out_index in range(T_out):
149
- azimuth, polar = angles_out[T_out_index]
150
- if CaPE_TYPE == "4DoF":
151
- pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
152
- elif CaPE_TYPE == "6DoF":
153
- pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
154
- pose = np.linalg.inv(pose)
155
- pose[2, :] *= -1
156
- pose_out.append(torch.from_numpy(get_pose(pose)))
157
-
158
-
159
-
160
- # [B, T, C, H, W]
161
- input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
162
- # [B, T, 4]
163
- pose_in = np.stack(pose_in)
164
- pose_out = np.stack(pose_out)
165
-
166
- if CaPE_TYPE == "6DoF":
167
- pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
168
- pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
169
- pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
170
- pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)
171
-
172
- pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
173
- pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)
174
-
175
- input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
176
- assert T_in == input_image.shape[0]
177
- assert T_in == pose_in.shape[1]
178
- assert T_out == pose_out.shape[1]
179
-
180
- # run inference
181
- pipeline.to(device)
182
- pipeline.enable_xformers_memory_efficient_attention()
183
- if CaPE_TYPE == "6DoF":
184
- with torch.autocast("cuda"):
185
- image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
186
- poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
187
- height=h, width=w, T_in=T_in, T_out=T_out,
188
- guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
189
- output_type="numpy").images
190
- elif CaPE_TYPE == "4DoF":
191
- with torch.autocast("cuda"):
192
- image = pipeline(input_imgs=input_image, prompt_imgs=input_image, poses=[pose_out, pose_in],
193
- height=h, width=w, T_in=T_in, T_out=T_out,
194
- guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
195
- output_type="numpy").images
196
-
197
- # save output image
198
- output_dir = os.path.join(tmpdirname, "eschernet")
199
- if os.path.exists(output_dir):
200
- shutil.rmtree(output_dir)
201
- os.makedirs(output_dir, exist_ok=True)
202
- # save to N imgs
203
- for i in range(T_out):
204
- imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
205
- # make a gif
206
- frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
207
- frame_one = frames[0]
208
- frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames,
209
- save_all=True, duration=50, loop=1)
210
-
211
- # get a video
212
- video_path = os.path.join(output_dir, "output.mp4")
213
- imageio.mimwrite(video_path, np.stack(frames), fps=10, codec='h264')
214
-
215
-
216
- return image, video_path
217
-
218
- # TODO mesh it
219
- @spaces.GPU(duration=120)
220
- def make3d():
221
- pass
222
-
223
-
224
-
225
- ############################ Dust3r as Pose Estimation ############################
226
- from scipy.spatial.transform import Rotation
227
- import copy
228
-
229
- from dust3r.inference import inference
230
- from dust3r.model import AsymmetricCroCo3DStereo
231
- from dust3r.image_pairs import make_pairs
232
- from dust3r.utils.image import load_images, rgb
233
- from dust3r.utils.device import to_numpy
234
- from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
235
- from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
236
-
237
- import functools
238
- import math
239
-
240
- @spaces.GPU
241
- def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
242
- cam_color=None, as_pointcloud=False,
243
- transparent_cams=False, silent=False, same_focals=False):
244
- assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
245
- if not same_focals:
246
- assert (len(cams2world) == len(focals))
247
- pts3d = to_numpy(pts3d)
248
- imgs = to_numpy(imgs)
249
- focals = to_numpy(focals)
250
- cams2world = to_numpy(cams2world)
251
-
252
- scene = trimesh.Scene()
253
-
254
- # add axes
255
- scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))
256
-
257
- # full pointcloud
258
- if as_pointcloud:
259
- pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
260
- col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
261
- pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
262
- scene.add_geometry(pct)
263
- else:
264
- meshes = []
265
- for i in range(len(imgs)):
266
- meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
267
- mesh = trimesh.Trimesh(**cat_meshes(meshes))
268
- scene.add_geometry(mesh)
269
-
270
- # add each camera
271
- for i, pose_c2w in enumerate(cams2world):
272
- if isinstance(cam_color, list):
273
- camera_edge_color = cam_color[i]
274
- else:
275
- camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
276
- if same_focals:
277
- focal = focals[0]
278
- else:
279
- focal = focals[i]
280
- add_scene_cam(scene, pose_c2w, camera_edge_color,
281
- None if transparent_cams else imgs[i], focal,
282
- imsize=imgs[i].shape[1::-1], screen_width=cam_size)
283
-
284
- rot = np.eye(4)
285
- rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
286
- scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
287
- outfile = os.path.join(outdir, 'scene.glb')
288
- if not silent:
289
- print('(exporting 3D scene to', outfile, ')')
290
- scene.export(file_obj=outfile)
291
- return outfile
292
-
293
- @spaces.GPU(duration=120)
294
- def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
295
- clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
296
- """
297
- extract 3D_model (glb file) from a reconstructed scene
298
- """
299
- if scene is None:
300
- return None
301
- # post processes
302
- if clean_depth:
303
- scene = scene.clean_pointcloud()
304
- if mask_sky:
305
- scene = scene.mask_sky()
306
-
307
- # get optimized values from scene
308
- rgbimg = to_numpy(scene.imgs)
309
- focals = to_numpy(scene.get_focals().cpu())
310
- # cams2world = to_numpy(scene.get_im_poses().cpu())
311
- # TODO use the vis_poses
312
- cams2world = scene.vis_poses
313
-
314
- # 3D pointcloud from depthmap, poses and intrinsics
315
- # pts3d = to_numpy(scene.get_pts3d())
316
- # TODO use the vis_poses
317
- pts3d = scene.vis_pts3d
318
- scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
319
- msk = to_numpy(scene.get_masks())
320
-
321
- return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
322
- transparent_cams=transparent_cams, cam_size=cam_size, silent=silent,
323
- same_focals=same_focals)
324
-
325
- @spaces.GPU(duration=120)
326
- def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr,
327
- as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
328
- scenegraph_type, winsize, refid, same_focals):
329
- """
330
- from a list of images, run dust3r inference, global aligner.
331
- then run get_3D_model_from_scene
332
- """
333
- # remove the directory if it already exists
334
- if os.path.exists(outdir):
335
- shutil.rmtree(outdir)
336
- os.makedirs(outdir, exist_ok=True)
337
- imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True)
338
- if len(imgs) == 1:
339
- imgs = [imgs[0], copy.deepcopy(imgs[0])]
340
- imgs[1]['idx'] = 1
341
- if scenegraph_type == "swin":
342
- scenegraph_type = scenegraph_type + "-" + str(winsize)
343
- elif scenegraph_type == "oneref":
344
- scenegraph_type = scenegraph_type + "-" + str(refid)
345
-
346
- pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
347
- output = inference(pairs, model, device, batch_size=1, verbose=not silent)
348
-
349
- mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
350
- scene = global_aligner(output, device=device, mode=mode, verbose=not silent, same_focals=same_focals)
351
- lr = 0.01
352
-
353
- if mode == GlobalAlignerMode.PointCloudOptimizer:
354
- loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
355
-
356
- # outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
357
- # clean_depth, transparent_cams, cam_size, same_focals=same_focals)
358
-
359
- # also return rgb, depth and confidence imgs
360
- # depth is normalized with the max value for all images
361
- # we apply the jet colormap on the confidence maps
362
- rgbimg = scene.imgs
363
- # depths = to_numpy(scene.get_depthmaps())
364
- # confs = to_numpy([c for c in scene.im_conf])
365
- # cmap = pl.get_cmap('jet')
366
- # depths_max = max([d.max() for d in depths])
367
- # depths = [d / depths_max for d in depths]
368
- # confs_max = max([d.max() for d in confs])
369
- # confs = [cmap(d / confs_max) for d in confs]
370
-
371
- imgs = []
372
- rgbaimg = []
373
- for i in range(len(rgbimg)): # when only 1 image, scene.imgs is two
374
- imgs.append(rgbimg[i])
375
- # imgs.append(rgb(depths[i]))
376
- # imgs.append(rgb(confs[i]))
377
- # imgs.append(imgs_rgba[i])
378
- if len(imgs_rgba) == 1 and i == 1:
379
- imgs.append(imgs_rgba[0])
380
- rgbaimg.append(np.array(imgs_rgba[0]))
381
- else:
382
- imgs.append(imgs_rgba[i])
383
- rgbaimg.append(np.array(imgs_rgba[i]))
384
-
385
- rgbaimg = np.array(rgbaimg)
386
-
387
- # for eschernet
388
- # get optimized values from scene
389
- rgbimg = to_numpy(scene.imgs)
390
- focals = to_numpy(scene.get_focals().cpu())
391
- cams2world = to_numpy(scene.get_im_poses().cpu())
392
-
393
- # 3D pointcloud from depthmap, poses and intrinsics
394
- pts3d = to_numpy(scene.get_pts3d())
395
- scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
396
- msk = to_numpy(scene.get_masks())
397
- obj_mask = rgbaimg[..., 3] > 0
398
-
399
- # TODO set global coordinate system at the center of the scene, z-axis is up
400
- pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
401
- pts_obj = np.concatenate([p[m&obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
402
- centroid = np.mean(pts_obj, axis=0) # obj center
403
- obj2world = np.eye(4)
404
- obj2world[:3, 3] = -centroid # T_wc
405
-
406
- # get z_up vector
407
- # TODO fit a plane and get the normal vector
408
- pcd = o3d.geometry.PointCloud()
409
- pcd.points = o3d.utility.Vector3dVector(pts)
410
- plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
411
- # get the normalised normal vector dim = 3
412
- normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
413
- # the normal direction should be pointing up
414
- if normal[1] < 0:
415
- normal = -normal
416
- # print("normal", normal)
417
-
418
- # # TODO z-up 180
419
- # z_up = np.array([[1,0,0,0],
420
- # [0,-1,0,0],
421
- # [0,0,-1,0],
422
- # [0,0,0,1]])
423
- # obj2world = z_up @ obj2world
424
-
425
- # # avg the y
426
- # z_up_avg = cams2world[:,:3,3].sum(0) / np.linalg.norm(cams2world[:,:3,3].sum(0), axis=-1) # average direction in cam coordinate
427
- # # import pdb; pdb.set_trace()
428
- # rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
429
- # rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
430
- # rot = Rotation.from_rotvec(rot_angle * rot_axis)
431
- # z_up = np.eye(4)
432
- # z_up[:3, :3] = rot.as_matrix()
433
-
434
- # get the rotation matrix from normal to z-axis
435
- z_axis = np.array([0, 0, 1])
436
- rot_axis = np.cross(normal, z_axis)
437
- rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
438
- rot = Rotation.from_rotvec(rot_angle * rot_axis)
439
- z_up = np.eye(4)
440
- z_up[:3, :3] = rot.as_matrix()
441
- obj2world = z_up @ obj2world
442
- # flip 180
443
- flip_rot = np.array([[1, 0, 0, 0],
444
- [0, -1, 0, 0],
445
- [0, 0, -1, 0],
446
- [0, 0, 0, 1]])
447
- obj2world = flip_rot @ obj2world
448
-
449
- # get new cams2obj
450
- cams2obj = []
451
- for i, cam2world in enumerate(cams2world):
452
- cams2obj.append(obj2world @ cam2world)
453
- # TODO transform pts3d to the new coordinate system
454
- for i, pts in enumerate(pts3d):
455
- pts3d[i] = (obj2world @ np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1).transpose(2, 0, 1).reshape(4,
456
- -1)) \
457
- .reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
458
- cams2world = np.array(cams2obj)
459
- # TODO rewrite hack
460
- scene.vis_poses = cams2world.copy()
461
- scene.vis_pts3d = pts3d.copy()
462
-
463
- # TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
464
- for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
465
- np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
466
- pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
467
- pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
468
- # np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
469
- # save the min/max radius of camera
470
- radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3])
471
- np.save(os.path.join(outdir, "radii.npy"), radii)
472
-
473
- eschernet_input = {"poses": cams2world,
474
- "radii": radii,
475
- "imgs": rgbaimg}
476
-
477
- outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
478
- clean_depth, transparent_cams, cam_size, same_focals=same_focals)
479
-
480
- return scene, outfile, imgs, eschernet_input
481
-
482
-
483
- def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
484
- num_files = len(inputfiles) if inputfiles is not None else 1
485
- max_winsize = max(1, math.ceil((num_files - 1) / 2))
486
- if scenegraph_type == "swin":
487
- winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
488
- minimum=1, maximum=max_winsize, step=1, visible=True)
489
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
490
- maximum=num_files - 1, step=1, visible=False)
491
- elif scenegraph_type == "oneref":
492
- winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
493
- minimum=1, maximum=max_winsize, step=1, visible=False)
494
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
495
- maximum=num_files - 1, step=1, visible=True)
496
- else:
497
- winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
498
- minimum=1, maximum=max_winsize, step=1, visible=False)
499
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
500
- maximum=num_files - 1, step=1, visible=False)
501
- return winsize, refid
502
-
503
-
504
- def get_examples(path):
505
- objs = []
506
- for obj_name in sorted(os.listdir(path)):
507
- img_files = []
508
- for img_file in sorted(os.listdir(os.path.join(path, obj_name))):
509
- img_files.append(os.path.join(path, obj_name, img_file))
510
- objs.append([img_files])
511
- print("objs = ", objs)
512
- return objs
513
-
514
- def preview_input(inputfiles):
515
- if inputfiles is None:
516
- return None
517
- imgs = []
518
- for img_file in inputfiles:
519
- img = pl.imread(img_file)
520
- imgs.append(img)
521
- return imgs
522
-
523
- def main():
524
- # dustr init
525
- silent = False
526
- image_size = 224
527
- weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
528
- model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
529
- # dust3r will write the 3D model inside tmpdirname
530
- # with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
531
- tmpdirname = os.path.join('logs/user_object')
532
- # remove the directory if it already exists
533
- if os.path.exists(tmpdirname):
534
- shutil.rmtree(tmpdirname)
535
- os.makedirs(tmpdirname, exist_ok=True)
536
- if not silent:
537
- print('Outputing stuff in', tmpdirname)
538
-
539
- recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
540
- model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
541
-
542
- generate_mvs = functools.partial(run_eschernet, tmpdirname)
543
-
544
- _HEADER_ = '''
545
- <h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
546
- <b>EscherNet</b> is a multiview diffusion model for scalable generative any-to-any number/pose novel view synthesis.
547
-
548
- Image views are treated as tokens and the camera pose is encoded by <b>CaPE (Camera Positional Encoding)</b>.
549
-
550
- <a href='https://kxhit.github.io/EscherNet' target='_blank'>Project</a> <b>|</b>
551
- <a href='https://github.com/kxhit/EscherNet' target='_blank'>GitHub</a> <b>|</b>
552
- <a href='https://arxiv.org/abs/2402.03908' target='_blank'>ArXiv</a>
553
-
554
- <h4><b>Tips:</b></h4>
555
-
556
- - Our model can take <b>any number input images</b>. The more images you provide, the better the results.
557
-
558
- - Our model can generate <b>any number and any pose</b> novel views. You can specify the number of views you want to generate. In this demo, we set novel views on an <b>archemedian spiral</b> for simplicity.
559
-
560
- - The pose estimation is done using <a href='https://github.com/naver/dust3r' target='_blank'>DUSt3R</a>. You can also provide your own poses or get pose via any SLAM system.
561
-
562
- - The current checkpoint supports 6DoF camera pose and is trained on 30k 3D <a href='https://objaverse.allenai.org/' target='_blank'>Objaverse</a> objects for demo. Scaling is on the roadmap!
563
-
564
- '''
565
-
566
- _CITE_ = r"""
567
- 📝 <b>Citation</b>:
568
- ```bibtex
569
- @article{kong2024eschernet,
570
- title={EscherNet: A Generative Model for Scalable View Synthesis},
571
- author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J},
572
- journal={arXiv preprint arXiv:2402.03908},
573
- year={2024}
574
- }
575
- ```
576
- """
577
-
578
- with gr.Blocks() as demo:
579
- gr.Markdown(_HEADER_)
580
- mv_images = gr.State()
581
- scene = gr.State(None)
582
- eschernet_input = gr.State(None)
583
- with gr.Row(variant="panel"):
584
- # left column
585
- with gr.Column():
586
- with gr.Row():
587
- input_image = gr.File(file_count="multiple")
588
- # with gr.Row():
589
- # # set the size of the window
590
- # preview_image = gr.Gallery(label='Input Views', rows=1,
591
- with gr.Row():
592
- run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
593
- with gr.Row():
594
- processed_image = gr.Gallery(label='Input Views', columns=2, height="100%")
595
- with gr.Row(variant="panel"):
596
- # input examples under "examples" folder
597
- gr.Examples(
598
- examples=get_examples('examples'),
599
- # examples=[
600
- # [['examples/controller/frame000077.jpg', 'examples/controller/frame000032.jpg', 'examples/controller/frame000172.jpg']],
601
- # [['examples/hairdryer/frame000081.jpg', 'examples/hairdryer/frame000162.jpg', 'examples/hairdryer/frame000003.jpg']],
602
- # ],
603
- inputs=[input_image],
604
- label="Examples (click one set of images to start!)",
605
- examples_per_page=20
606
- )
607
-
608
-
609
-
610
-
611
-
612
- # right column
613
- with gr.Column():
614
-
615
- with gr.Row():
616
- outmodel = gr.Model3D()
617
-
618
- with gr.Row():
619
- gr.Markdown('''
620
- <h4><b>Check if the pose and segmentation looks correct. If not, remove the incorrect images and try again.</b></h4>
621
- ''')
622
-
623
- with gr.Row():
624
- with gr.Group():
625
- do_remove_background = gr.Checkbox(
626
- label="Remove Background", value=True
627
- )
628
- sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
629
-
630
- sample_steps = gr.Slider(
631
- label="Sample Steps",
632
- minimum=30,
633
- maximum=75,
634
- value=50,
635
- step=5,
636
- visible=False
637
- )
638
-
639
- nvs_num = gr.Slider(
640
- label="Number of Novel Views",
641
- minimum=5,
642
- maximum=100,
643
- value=30,
644
- step=1
645
- )
646
-
647
- nvs_mode = gr.Dropdown(["archimedes circle"], # "fixed 4 views", "fixed 8 views"
648
- value="archimedes circle", label="Novel Views Pose Chosen", visible=True)
649
-
650
- with gr.Row():
651
- gr.Markdown('''
652
- <h4><b>Choose your desired novel view poses number and generate! The more output images the longer it takes.</b></h4>
653
- ''')
654
-
655
- with gr.Row():
656
- submit = gr.Button("Submit", elem_id="eschernet", variant="primary")
657
-
658
- with gr.Row():
659
- # mv_show_images = gr.Image(
660
- # label="Generated Multi-views",
661
- # type="pil",
662
- # width=379,
663
- # interactive=False
664
- # )
665
- with gr.Column():
666
- output_video = gr.Video(
667
- label="video", format="mp4",
668
- width=379,
669
- autoplay=True,
670
- interactive=False
671
- )
672
-
673
- # with gr.Row():
674
- # with gr.Tab("OBJ"):
675
- # output_model_obj = gr.Model3D(
676
- # label="Output Model (OBJ Format)",
677
- # #width=768,
678
- # interactive=False,
679
- # )
680
- # gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
681
- # with gr.Tab("GLB"):
682
- # output_model_glb = gr.Model3D(
683
- # label="Output Model (GLB Format)",
684
- # #width=768,
685
- # interactive=False,
686
- # )
687
- # gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
688
-
689
- with gr.Row():
690
- gr.Markdown('''The novel views are generated on an archimedean spiral. You can download the video''')
691
-
692
- gr.Markdown(_CITE_)
693
-
694
- # set dust3r parameter invisible to be clean
695
- with gr.Column():
696
- with gr.Row():
697
- schedule = gr.Dropdown(["linear", "cosine"],
698
- value='linear', label="schedule", info="For global alignment!", visible=False)
699
- niter = gr.Number(value=300, precision=0, minimum=0, maximum=5000,
700
- label="num_iterations", info="For global alignment!", visible=False)
701
- scenegraph_type = gr.Dropdown(["complete", "swin", "oneref"],
702
- value='complete', label="Scenegraph",
703
- info="Define how to make pairs",
704
- interactive=True, visible=False)
705
- same_focals = gr.Checkbox(value=True, label="Focal", info="Use the same focal for all cameras", visible=False)
706
- winsize = gr.Slider(label="Scene Graph: Window Size", value=1,
707
- minimum=1, maximum=1, step=1, visible=False)
708
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
709
-
710
- with gr.Row():
711
- # adjust the confidence threshold
712
- min_conf_thr = gr.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
713
- # adjust the camera size in the output pointcloud
714
- cam_size = gr.Slider(label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False)
715
- with gr.Row():
716
- as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
717
- # two post process implemented
718
- mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
719
- clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
720
- transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)
721
-
722
- # events
723
- # scenegraph_type.change(set_scenegraph_options,
724
- # inputs=[input_image, winsize, refid, scenegraph_type],
725
- # outputs=[winsize, refid])
726
- input_image.change(set_scenegraph_options,
727
- inputs=[input_image, winsize, refid, scenegraph_type],
728
- outputs=[winsize, refid])
729
- # min_conf_thr.release(fn=model_from_scene_fun,
730
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
731
- # clean_depth, transparent_cams, cam_size, same_focals],
732
- # outputs=outmodel)
733
- # cam_size.change(fn=model_from_scene_fun,
734
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
735
- # clean_depth, transparent_cams, cam_size, same_focals],
736
- # outputs=outmodel)
737
- # as_pointcloud.change(fn=model_from_scene_fun,
738
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
739
- # clean_depth, transparent_cams, cam_size, same_focals],
740
- # outputs=outmodel)
741
- # mask_sky.change(fn=model_from_scene_fun,
742
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
743
- # clean_depth, transparent_cams, cam_size, same_focals],
744
- # outputs=outmodel)
745
- # clean_depth.change(fn=model_from_scene_fun,
746
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
747
- # clean_depth, transparent_cams, cam_size, same_focals],
748
- # outputs=outmodel)
749
- # transparent_cams.change(model_from_scene_fun,
750
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
751
- # clean_depth, transparent_cams, cam_size, same_focals],
752
- # outputs=outmodel)
753
- run_dust3r.click(fn=recon_fun,
754
- inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
755
- mask_sky, clean_depth, transparent_cams, cam_size,
756
- scenegraph_type, winsize, refid, same_focals],
757
- outputs=[scene, outmodel, processed_image, eschernet_input])
758
-
759
-
760
- # events
761
- # preview images on input change
762
- input_image.change(fn=preview_input,
763
- inputs=[input_image],
764
- outputs=[processed_image])
765
-
766
- submit.click(fn=generate_mvs,
767
- inputs=[eschernet_input, sample_steps, sample_seed,
768
- nvs_num, nvs_mode],
769
- outputs=[mv_images, output_video],
770
- )#.success(
771
- # # fn=make3d,
772
- # # inputs=[mv_images],
773
- # # outputs=[output_video, output_model_obj, output_model_glb]
774
- # # )
775
-
776
-
777
-
778
- demo.queue(max_size=10)
779
- demo.launch(share=True, server_name="0.0.0.0", server_port=None)
780
-
781
- if __name__ == '__main__':
782
- main()