cavargas10
committed
Update app.py
app.py
CHANGED
@@ -1,430 +1,382 @@
 import os
 import tyro
 import imageio
 import numpy as np
 import tqdm
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
 from safetensors.torch import load_file
 import rembg
 import gradio as gr

 import kiui
 from kiui.op import recenter
 from kiui.cam import orbit_camera
 from core.utils import get_rays, grid_distortion, orbit_camera_jitter

 from core.options import AllConfigs, Options
 from core.models import LTRFM_Mesh,LTRFM_NeRF
 from core.instant_utils.mesh_util import save_obj, save_obj_with_mtl
 from mvdream.pipeline_mvdream import MVDreamPipeline
 from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
 from huggingface_hub import hf_hub_download

 import spaces

 IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
 GRADIO_VIDEO_PATH = 'gradio_output.mp4'
 GRADIO_OBJ_PATH = 'gradio_output_rgb.obj'
 GRADIO_OBJ_ALBEDO_PATH = 'gradio_output_albedo.obj'
 GRADIO_OBJ_SHADING_PATH = 'gradio_output_shading.obj'

 #opt = tyro.cli(AllConfigs)

 ckpt_path = hf_hub_download(repo_id="rgxie/LDM", filename="LDM6v01.ckpt")

 opt = Options(
     input_size=512,
     down_channels=(32, 64, 128, 256, 512),
     down_attention=(False, False, False, False, True),
     up_channels=(512, 256, 128),
     up_attention=(True, False, False, False),
     volume_mode='TRF_NeRF',
     splat_size=64,
     output_size=62, #crop patch
     data_mode='s5',
     num_views=8,
     gradient_accumulation_steps=1, #2
     mixed_precision='bf16',
     resume=ckpt_path,
 )


 # model
 if opt.volume_mode == 'TRF_Mesh':
     model = LTRFM_Mesh(opt)
 elif opt.volume_mode == 'TRF_NeRF':
     model = LTRFM_NeRF(opt)
 else:
     model = LGM(opt)

 # resume pretrained checkpoint
 if opt.resume is not None:
     if opt.resume.endswith('safetensors'):
         ckpt = load_file(opt.resume, device='cpu')
     else: #ckpt
         ckpt_dict = torch.load(opt.resume, map_location='cpu')
         ckpt=ckpt_dict["model"]

     state_dict = model.state_dict()
     for k, v in ckpt.items():
         k=k.replace('module.', '')
         if k in state_dict:
             if state_dict[k].shape == v.shape:
                 state_dict[k].copy_(v)
             else:
                 print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
         else:
             print(f'[WARN] unexpected param {k}: {v.shape}')
     print(f'[INFO] load resume success!')

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = model.half().to(device)
 model.eval()

 tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
 proj_matrix = torch.zeros(4, 4, dtype=torch.float32).to(device)
 proj_matrix[0, 0] = 1 / tan_half_fov
 proj_matrix[1, 1] = 1 / tan_half_fov
 proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
 proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
 proj_matrix[2, 3] = 1

 # load dreams
 pipe_text = MVDreamPipeline.from_pretrained(
     'ashawkey/mvdream-sd2.1-diffusers', # remote weights
     torch_dtype=torch.float16,
     trust_remote_code=True,
     # local_files_only=True,
 )
 pipe_text = pipe_text.to(device)

 # mvdream
 pipe_image = MVDreamPipeline.from_pretrained(
     "ashawkey/imagedream-ipmv-diffusers", # remote weights
     torch_dtype=torch.float16,
     trust_remote_code=True,
     # local_files_only=True,
 )
 pipe_image = pipe_image.to(device)


 print('Loading 123plus model ...')
 pipe_image_plus = DiffusionPipeline.from_pretrained(
     "sudo-ai/zero123plus-v1.2",
     custom_pipeline="zero123plus",
     torch_dtype=torch.float16,
     trust_remote_code=True,
     #local_files_only=True,
 )
 pipe_image_plus.scheduler = EulerAncestralDiscreteScheduler.from_config(
     pipe_image_plus.scheduler.config, timestep_spacing='trailing'
 )

 unet_path='./pretrained/diffusion_pytorch_model.bin'

 print('Loading custom white-background unet ...')
 if os.path.exists(unet_path):
     unet_ckpt_path = unet_path
 else:
     unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
 state_dict = torch.load(unet_ckpt_path, map_location='cpu')
 pipe_image_plus.unet.load_state_dict(state_dict, strict=True)
 pipe_image_plus = pipe_image_plus.to(device)

 # load rembg
 bg_remover = rembg.new_session()


 @spaces.GPU
 def generate_mv(condition_input_image, prompt, prompt_neg='', input_elevation=0, input_num_steps=30, input_seed=42, mv_moedl_option=None):
     # seed
     kiui.seed_everything(input_seed)

     os.makedirs(os.path.join(opt.workspace, "gradio"), exist_ok=True)

     # text-conditioned
     if condition_input_image is None:
         mv_image_uint8 = pipe_text(prompt, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=7.5, elevation=input_elevation)
         mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8)
         # bg removal
         mv_image = []
         for i in range(4):
             image = rembg.remove(mv_image_uint8[i], session=bg_remover) # [H, W, 4]
             # to white bg
             image = image.astype(np.float32) / 255
             image = recenter(image, image[..., 0] > 0, border_ratio=0.2)
             image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
             mv_image.append(image)

         mv_image_grid = np.concatenate([mv_image[1], mv_image[2],mv_image[3], mv_image[0]],axis=1)
         input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)

         processed_image=None
     # image-conditioned (may also input text, but no text usually works too)
     else:
         condition_input_image = np.array(condition_input_image) # uint8
         # bg removal
         carved_image = rembg.remove(condition_input_image, session=bg_remover) # [H, W, 4]
         mask = carved_image[..., -1] > 0
         image = recenter(carved_image, mask, border_ratio=0.2)
         image = image.astype(np.float32) / 255.0
         processed_image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])

         if mv_moedl_option=='mvdream':
             mv_image = pipe_image(prompt, processed_image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation)

             mv_image_grid = np.concatenate([mv_image[1], mv_image[2],mv_image[3], mv_image[0]],axis=1)
             input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
         else:
             from PIL import Image
             from einops import rearrange, repeat

             # input_image=input_image* 255
             processed_image = Image.fromarray((processed_image * 255).astype(np.uint8))
             mv_image = pipe_image_plus(processed_image, num_inference_steps=input_num_steps).images[0]
             mv_image = np.asarray(mv_image, dtype=np.float32) / 255.0
             mv_image = torch.from_numpy(mv_image).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
             mv_image_grid = rearrange(mv_image, 'c (n h) (m w) -> (m h) (n w) c', n=3, m=2).numpy()
             mv_image = rearrange(mv_image, 'c (n h) (m w) -> (n m) h w c', n=3, m=2).numpy()
             input_image = mv_image
     return mv_image_grid, processed_image, input_image

 @spaces.GPU
 def generate_3d(input_image, condition_input_image, mv_moedl_option=None, input_seed=42):
     kiui.seed_everything(input_seed)

     output_obj_rgb_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_PATH)
     output_obj_albedo_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_ALBEDO_PATH)
     output_obj_shading_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_SHADING_PATH)

     output_video_path = os.path.join(opt.workspace,"gradio", GRADIO_VIDEO_PATH)
     # generate gaussians
     # [4, 256, 256, 3], float32
     input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
     input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)

     images_input_vit = F.interpolate(input_image, size=(224, 224), mode='bilinear', align_corners=False)

     data = {}
     input_image = input_image.unsqueeze(0) # [1, 4, 9, H, W]
     images_input_vit=images_input_vit.unsqueeze(0)
     data['input_vit']=images_input_vit

     elevation = 0
     cam_poses =[]
     if mv_moedl_option=='mvdream' or condition_input_image is None:
         azimuth = np.arange(0, 360, 90, dtype=np.int32)
         for azi in tqdm.tqdm(azimuth):
             cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
             cam_poses.append(cam_pose)
     else:
         azimuth = np.arange(30, 360, 60, dtype=np.int32)
         cnt = 0
         for azi in tqdm.tqdm(azimuth):
             if (cnt+1) % 2!= 0:
                 elevation=-20
             else:
                 elevation=30
             cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
             cam_poses.append(cam_pose)
             cnt=cnt+1

     cam_poses = torch.cat(cam_poses,0)
     radius = torch.norm(cam_poses[0, :3, 3])
     cam_poses[:, :3, 3] *= opt.cam_radius / radius
     transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32).to(device) @ torch.inverse(cam_poses[0])
     cam_poses = transform.unsqueeze(0) @ cam_poses

     cam_poses=cam_poses.unsqueeze(0)
     data['source_camera']=cam_poses

     with torch.no_grad():
         if opt.volume_mode == 'TRF_Mesh':
             with torch.autocast(device_type='cuda', dtype=torch.float32):
                 svd_volume = model.forward_svd_volume(input_image,data)
         else:
             with torch.autocast(device_type='cuda', dtype=torch.float16):
                 svd_volume = model.forward_svd_volume(input_image,data)

         #time-consuming
         export_texmap=False

         mesh_out = model.extract_mesh(svd_volume,use_texture_map=export_texmap)

         if export_texmap:
             vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out

             for i in range(len(tex_map)):
                 mesh_path=os.path.join(opt.workspace, name + str(i) + '_'+ str(seed)+ '.obj')
                 save_obj_with_mtl(
                     vertices.data.cpu().numpy(),
                     uvs.data.cpu().numpy(),
                     faces.data.cpu().numpy(),
                     mesh_tex_idx.data.cpu().numpy(),
                     tex_map[i].permute(1, 2, 0).data.cpu().numpy(),
                     mesh_path,
                 )
         else:
             vertices, faces, vertex_colors = mesh_out

             save_obj(vertices, faces, vertex_colors[0], output_obj_rgb_path)
             save_obj(vertices, faces, vertex_colors[1], output_obj_albedo_path)
             save_obj(vertices, faces, vertex_colors[2], output_obj_shading_path)

         # images=[]
         # azimuth = np.arange(0, 360, 6, dtype=np.int32)
         # for azi in tqdm.tqdm(azimuth):

         # cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True))

         # if opt.volume_mode == 'TRF_Mesh':
         # cam_view = torch.inverse(cam_pose)
         # cam_view=cam_view.unsqueeze(0).unsqueeze(0).to(device)
         # data['w2c'] = cam_view
         # with torch.autocast(device_type='cuda', dtype=torch.float32):
         # render_images=model.render_frame(data)
         # else:
         # rays_o, rays_d = get_rays(cam_pose, opt.infer_render_size, opt.infer_render_size, opt.fovy) # [h, w, 3]
         # rays_o=rays_o.unsqueeze(0).unsqueeze(0).to(device)# B,V,H,W,3
         # rays_d=rays_d.unsqueeze(0).unsqueeze(0).to(device)
         # data['all_rays_o']=rays_o
         # data['all_rays_d']=rays_d
         # with torch.autocast(device_type='cuda', dtype=torch.float16):
         # render_images=model.render_frame(data)
         # image=render_images['images_pred']

         # images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))

         # images = np.concatenate(images, axis=0)
         # imageio.mimwrite(output_video_path, images, fps=30)


     return output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path #, output_video_path


 # gradio UI

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-with gr.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+with block:
+    with gr.Row():
+        with gr.Column(scale=1):
+
+            with gr.Row(variant='panel'):
+                with gr.Column(scale=1):
+
+                    with gr.Tab("Text-to-3D"):
+                        # input prompt
+                        with gr.Row():
+                            input_text = gr.Textbox(label="prompt")
+                        # negative prompt
+                        with gr.Row():
+                            input_neg_text = gr.Textbox(label="negative prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate')
+
+                        with gr.Row(variant="panel"):
+                            gr.Examples(
+                                examples=[
+                                    "a hamburger",
+                                    "a furry red fox head",
+                                    "a teddy bear",
+                                    "a motorbike",
+                                ],
                                 inputs=[input_text],
                                 fn=lambda x: process(condition_input_image=None, prompt=x),
                                 cache_examples=False,
                                 label='Text-to-3D Examples'
                             )

                     # elevation
                     input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
                     # inference steps
                     input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30)
                     # random seed
                     input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
                     # gen button
                     button_gen = gr.Button("Generate")


             with gr.Column(scale=1):
                 with gr.Row():
                     # multi-view results
                     mv_image_grid = gr.Image(interactive=False, show_label=False)
                 # with gr.Row():
                 #     output_video_path = gr.Video(label="video")
                 with gr.Row():
                     output_obj_rgb_path = gr.Model3D(
                         label="RGB Model (OBJ Format)",
                         interactive=False,
                     )
                 with gr.Row():
                     output_obj_albedo_path = gr.Model3D(
                         label="Albedo Model (OBJ Format)",
                         interactive=False,
                     )
                 with gr.Row():
                     output_obj_shading_path = gr.Model3D(
                         label="Shading Model (OBJ Format)",
                         interactive=False,
                     )


     input_image = gr.State()
     button_gen.click(fn=generate_mv, inputs=[condition_input_image, input_text, input_neg_text, input_elevation, input_num_steps, input_seed, mv_moedl_option],
                      outputs=[mv_image_grid, processed_image, input_image],).success(
         fn=generate_3d,
         inputs=[input_image, condition_input_image, mv_moedl_option, input_seed],
         outputs=[output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path] , #output_video_path
     )


 block.launch(server_name="0.0.0.0", share=False)
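
One part of the unchanged code that is easy to misread is the pair of `rearrange` calls near the end of `generate_mv`: the zero123plus pipeline returns a single 960x640 image that tiles six 320x320 views in a 3x2 grid, and the two calls unpack it into the preview grid shown in the UI and the stack of per-view inputs used for reconstruction. The sketch below is not part of the commit; it only illustrates that reshaping, with a random tensor standing in for the pipeline output.

```python
# Illustration only: how the 3x2-tiled zero123plus output is split into views.
# The random tensor stands in for the real (3, 960, 640) image produced above.
import torch
from einops import rearrange

fake_grid = torch.rand(3, 960, 640)  # (C, n*h, m*w) with n=3 rows, m=2 columns of 320x320 tiles

views = rearrange(fake_grid, 'c (n h) (m w) -> (n m) h w c', n=3, m=2)
grid = rearrange(fake_grid, 'c (n h) (m w) -> (m h) (n w) c', n=3, m=2)

print(views.shape)  # torch.Size([6, 320, 320, 3]) -> six per-view images
print(grid.shape)   # torch.Size([640, 960, 3])   -> single preview image
```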
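
The camera handling in `generate_3d` is also terse: after building the orbit poses, the code rescales every camera translation to `opt.cam_radius` and then multiplies all poses by a canonical pose times the inverse of the first pose, so the first camera lands exactly on the +z axis at that radius while the others keep their relative arrangement. A rough, self-contained sketch of that normalization follows; it is not from the commit, the radius value is a stand-in for `opt.cam_radius`, and `orbit_camera` is the same kiui helper the app imports.

```python
# Illustration only: the pose normalization used in generate_3d, on dummy cameras.
import torch
from kiui.cam import orbit_camera  # same helper the app uses

cam_radius = 1.5  # stand-in for opt.cam_radius
poses = torch.stack([
    torch.from_numpy(orbit_camera(0, azi, radius=cam_radius, opengl=True)).float()
    for azi in (0, 90, 180, 270)
])  # [4, 4, 4] camera-to-world matrices

# force every camera onto a sphere of the target radius
radius = torch.norm(poses[0, :3, 3])
poses[:, :3, 3] *= cam_radius / radius

# re-express all poses relative to the first one, which becomes the canonical view
canonical = torch.tensor([[1, 0, 0, 0],
                          [0, 1, 0, 0],
                          [0, 0, 1, cam_radius],
                          [0, 0, 0, 1]], dtype=torch.float32)
poses = (canonical @ torch.inverse(poses[0])).unsqueeze(0) @ poses

print(poses[0])  # identity rotation with translation (0, 0, cam_radius)
```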
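
The commit's actual change is the restored `with block:` UI, whose key wiring is the chained event at the bottom: the Generate button first runs the multi-view step and, only if it succeeds, feeds its result through a `gr.State` into the reconstruction step. The toy script below is not the app; it only sketches that click-then-success pattern with placeholder functions and components.

```python
# Illustration only: Gradio's click(...).success(...) chaining through a gr.State.
import gradio as gr

def step_one(prompt):
    # placeholder for generate_mv: returns something to display and something to pass on
    return f"preview for '{prompt}'", ["view1", "view2"]

def step_two(views):
    # placeholder for generate_3d: consumes the state produced by step_one
    return f"mesh built from {len(views)} views"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="prompt")
    button = gr.Button("Generate")
    preview = gr.Textbox(label="multi-view preview")
    views_state = gr.State()  # hidden carrier between the two steps
    result = gr.Textbox(label="reconstruction")

    button.click(fn=step_one, inputs=[prompt], outputs=[preview, views_state]) \
          .success(fn=step_two, inputs=[views_state], outputs=[result])

if __name__ == "__main__":
    demo.launch()
```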