cavargas10 committed on
Commit 9b13c64 · verified · 1 Parent(s): c1c0440

Update app.py

Files changed (1)
  1. app.py +381 -429
app.py CHANGED
@@ -1,430 +1,382 @@
- import os
- import tyro
- import imageio
- import numpy as np
- import tqdm
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torchvision.transforms.functional as TF
- from safetensors.torch import load_file
- import rembg
- import gradio as gr
-
- import kiui
- from kiui.op import recenter
- from kiui.cam import orbit_camera
- from core.utils import get_rays, grid_distortion, orbit_camera_jitter
-
- from core.options import AllConfigs, Options
- from core.models import LTRFM_Mesh,LTRFM_NeRF
- from core.instant_utils.mesh_util import save_obj, save_obj_with_mtl
- from mvdream.pipeline_mvdream import MVDreamPipeline
- from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
- from huggingface_hub import hf_hub_download
-
- import spaces
-
- IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
- IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
- GRADIO_VIDEO_PATH = 'gradio_output.mp4'
- GRADIO_OBJ_PATH = 'gradio_output_rgb.obj'
- GRADIO_OBJ_ALBEDO_PATH = 'gradio_output_albedo.obj'
- GRADIO_OBJ_SHADING_PATH = 'gradio_output_shading.obj'
-
- #opt = tyro.cli(AllConfigs)
-
- ckpt_path = hf_hub_download(repo_id="rgxie/LDM", filename="LDM6v01.ckpt")
-
- opt = Options(
-     input_size=512,
-     down_channels=(32, 64, 128, 256, 512),
-     down_attention=(False, False, False, False, True),
-     up_channels=(512, 256, 128),
-     up_attention=(True, False, False, False),
-     volume_mode='TRF_NeRF',
-     splat_size=64,
-     output_size=62, #crop patch
-     data_mode='s5',
-     num_views=8,
-     gradient_accumulation_steps=1, #2
-     mixed_precision='bf16',
-     resume=ckpt_path,
- )
-
-
- # model
- if opt.volume_mode == 'TRF_Mesh':
-     model = LTRFM_Mesh(opt)
- elif opt.volume_mode == 'TRF_NeRF':
-     model = LTRFM_NeRF(opt)
- else:
-     model = LGM(opt)
-
- # resume pretrained checkpoint
- if opt.resume is not None:
-     if opt.resume.endswith('safetensors'):
-         ckpt = load_file(opt.resume, device='cpu')
-     else: #ckpt
-         ckpt_dict = torch.load(opt.resume, map_location='cpu')
-         ckpt=ckpt_dict["model"]
-
-     state_dict = model.state_dict()
-     for k, v in ckpt.items():
-         k=k.replace('module.', '')
-         if k in state_dict:
-             if state_dict[k].shape == v.shape:
-                 state_dict[k].copy_(v)
-             else:
-                 print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
-         else:
-             print(f'[WARN] unexpected param {k}: {v.shape}')
-     print(f'[INFO] load resume success!')
-
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- model = model.half().to(device)
- model.eval()
-
- tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
- proj_matrix = torch.zeros(4, 4, dtype=torch.float32).to(device)
- proj_matrix[0, 0] = 1 / tan_half_fov
- proj_matrix[1, 1] = 1 / tan_half_fov
- proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
- proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
- proj_matrix[2, 3] = 1
-
- # load dreams
- pipe_text = MVDreamPipeline.from_pretrained(
-     'ashawkey/mvdream-sd2.1-diffusers', # remote weights
-     torch_dtype=torch.float16,
-     trust_remote_code=True,
-     # local_files_only=True,
- )
- pipe_text = pipe_text.to(device)
-
- # mvdream
- pipe_image = MVDreamPipeline.from_pretrained(
-     "ashawkey/imagedream-ipmv-diffusers", # remote weights
-     torch_dtype=torch.float16,
-     trust_remote_code=True,
-     # local_files_only=True,
- )
- pipe_image = pipe_image.to(device)
-
-
- print('Loading 123plus model ...')
- pipe_image_plus = DiffusionPipeline.from_pretrained(
-     "sudo-ai/zero123plus-v1.2",
-     custom_pipeline="zero123plus",
-     torch_dtype=torch.float16,
-     trust_remote_code=True,
-     #local_files_only=True,
- )
- pipe_image_plus.scheduler = EulerAncestralDiscreteScheduler.from_config(
-     pipe_image_plus.scheduler.config, timestep_spacing='trailing'
- )
-
- unet_path='./pretrained/diffusion_pytorch_model.bin'
-
- print('Loading custom white-background unet ...')
- if os.path.exists(unet_path):
-     unet_ckpt_path = unet_path
- else:
-     unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
- state_dict = torch.load(unet_ckpt_path, map_location='cpu')
- pipe_image_plus.unet.load_state_dict(state_dict, strict=True)
- pipe_image_plus = pipe_image_plus.to(device)
-
- # load rembg
- bg_remover = rembg.new_session()
-
-
- @spaces.GPU
- def generate_mv(condition_input_image, prompt, prompt_neg='', input_elevation=0, input_num_steps=30, input_seed=42, mv_moedl_option=None):
-     # seed
-     kiui.seed_everything(input_seed)
-
-     os.makedirs(os.path.join(opt.workspace, "gradio"), exist_ok=True)
-
-     # text-conditioned
-     if condition_input_image is None:
-         mv_image_uint8 = pipe_text(prompt, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=7.5, elevation=input_elevation)
-         mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8)
-         # bg removal
-         mv_image = []
-         for i in range(4):
-             image = rembg.remove(mv_image_uint8[i], session=bg_remover) # [H, W, 4]
-             # to white bg
-             image = image.astype(np.float32) / 255
-             image = recenter(image, image[..., 0] > 0, border_ratio=0.2)
-             image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
-             mv_image.append(image)
-
-         mv_image_grid = np.concatenate([mv_image[1], mv_image[2],mv_image[3], mv_image[0]],axis=1)
-         input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
-
-         processed_image=None
-     # image-conditioned (may also input text, but no text usually works too)
-     else:
-         condition_input_image = np.array(condition_input_image) # uint8
-         # bg removal
-         carved_image = rembg.remove(condition_input_image, session=bg_remover) # [H, W, 4]
-         mask = carved_image[..., -1] > 0
-         image = recenter(carved_image, mask, border_ratio=0.2)
-         image = image.astype(np.float32) / 255.0
-         processed_image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
-
-         if mv_moedl_option=='mvdream':
-             mv_image = pipe_image(prompt, processed_image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation)
-
-             mv_image_grid = np.concatenate([mv_image[1], mv_image[2],mv_image[3], mv_image[0]],axis=1)
-             input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
-         else:
-             from PIL import Image
-             from einops import rearrange, repeat
-
-             # input_image=input_image* 255
-             processed_image = Image.fromarray((processed_image * 255).astype(np.uint8))
-             mv_image = pipe_image_plus(processed_image, num_inference_steps=input_num_steps).images[0]
-             mv_image = np.asarray(mv_image, dtype=np.float32) / 255.0
-             mv_image = torch.from_numpy(mv_image).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
-             mv_image_grid = rearrange(mv_image, 'c (n h) (m w) -> (m h) (n w) c', n=3, m=2).numpy()
-             mv_image = rearrange(mv_image, 'c (n h) (m w) -> (n m) h w c', n=3, m=2).numpy()
-             input_image = mv_image
-     return mv_image_grid, processed_image, input_image
-
- @spaces.GPU
- def generate_3d(input_image, condition_input_image, mv_moedl_option=None, input_seed=42):
-     kiui.seed_everything(input_seed)
-
-     output_obj_rgb_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_PATH)
-     output_obj_albedo_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_ALBEDO_PATH)
-     output_obj_shading_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_SHADING_PATH)
-
-     output_video_path = os.path.join(opt.workspace,"gradio", GRADIO_VIDEO_PATH)
-     # generate gaussians
-     # [4, 256, 256, 3], float32
-     input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
-     input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
-
-     images_input_vit = F.interpolate(input_image, size=(224, 224), mode='bilinear', align_corners=False)
-
-     data = {}
-     input_image = input_image.unsqueeze(0) # [1, 4, 9, H, W]
-     images_input_vit=images_input_vit.unsqueeze(0)
-     data['input_vit']=images_input_vit
-
-     elevation = 0
-     cam_poses =[]
-     if mv_moedl_option=='mvdream' or condition_input_image is None:
-         azimuth = np.arange(0, 360, 90, dtype=np.int32)
-         for azi in tqdm.tqdm(azimuth):
-             cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
-             cam_poses.append(cam_pose)
-     else:
-         azimuth = np.arange(30, 360, 60, dtype=np.int32)
-         cnt = 0
-         for azi in tqdm.tqdm(azimuth):
-             if (cnt+1) % 2!= 0:
-                 elevation=-20
-             else:
-                 elevation=30
-             cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
-             cam_poses.append(cam_pose)
-             cnt=cnt+1
-
-     cam_poses = torch.cat(cam_poses,0)
-     radius = torch.norm(cam_poses[0, :3, 3])
-     cam_poses[:, :3, 3] *= opt.cam_radius / radius
-     transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32).to(device) @ torch.inverse(cam_poses[0])
-     cam_poses = transform.unsqueeze(0) @ cam_poses
-
-     cam_poses=cam_poses.unsqueeze(0)
-     data['source_camera']=cam_poses
-
-     with torch.no_grad():
-         if opt.volume_mode == 'TRF_Mesh':
-             with torch.autocast(device_type='cuda', dtype=torch.float32):
-                 svd_volume = model.forward_svd_volume(input_image,data)
-         else:
-             with torch.autocast(device_type='cuda', dtype=torch.float16):
-                 svd_volume = model.forward_svd_volume(input_image,data)
-
-         #time-consuming
-         export_texmap=False
-
-         mesh_out = model.extract_mesh(svd_volume,use_texture_map=export_texmap)
-
-         if export_texmap:
-             vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
-
-             for i in range(len(tex_map)):
-                 mesh_path=os.path.join(opt.workspace, name + str(i) + '_'+ str(seed)+ '.obj')
-                 save_obj_with_mtl(
-                     vertices.data.cpu().numpy(),
-                     uvs.data.cpu().numpy(),
-                     faces.data.cpu().numpy(),
-                     mesh_tex_idx.data.cpu().numpy(),
-                     tex_map[i].permute(1, 2, 0).data.cpu().numpy(),
-                     mesh_path,
-                 )
-         else:
-             vertices, faces, vertex_colors = mesh_out
-
-             save_obj(vertices, faces, vertex_colors[0], output_obj_rgb_path)
-             save_obj(vertices, faces, vertex_colors[1], output_obj_albedo_path)
-             save_obj(vertices, faces, vertex_colors[2], output_obj_shading_path)
-
-         # images=[]
-         # azimuth = np.arange(0, 360, 6, dtype=np.int32)
-         # for azi in tqdm.tqdm(azimuth):
-
-         #     cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True))
-
-         #     if opt.volume_mode == 'TRF_Mesh':
-         #         cam_view = torch.inverse(cam_pose)
-         #         cam_view=cam_view.unsqueeze(0).unsqueeze(0).to(device)
-         #         data['w2c'] = cam_view
-         #         with torch.autocast(device_type='cuda', dtype=torch.float32):
-         #             render_images=model.render_frame(data)
-         #     else:
-         #         rays_o, rays_d = get_rays(cam_pose, opt.infer_render_size, opt.infer_render_size, opt.fovy) # [h, w, 3]
-         #         rays_o=rays_o.unsqueeze(0).unsqueeze(0).to(device)# B,V,H,W,3
-         #         rays_d=rays_d.unsqueeze(0).unsqueeze(0).to(device)
-         #         data['all_rays_o']=rays_o
-         #         data['all_rays_d']=rays_d
-         #         with torch.autocast(device_type='cuda', dtype=torch.float16):
-         #             render_images=model.render_frame(data)
-         #     image=render_images['images_pred']
-
-         #     images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
-
-         # images = np.concatenate(images, axis=0)
-         # imageio.mimwrite(output_video_path, images, fps=30)
-
-
-     return output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path #, output_video_path
-
-
- # gradio UI
-
- _TITLE = '''LDM: Large Tensorial SDF Model for Textured Mesh Generation'''
-
- _DESCRIPTION = '''
-
-
- * Input can be text prompt, image.
- * The currently supported multi-view diffusion models include the image-conditioned MVdream and Zero123plus, as well as the text-conditioned Imagedream.
- * If you find the output unsatisfying, try using different multi-view diffusion models or seeds!
- '''
-
- block = gr.Blocks(title=_TITLE).queue()
- with block:
-     with gr.Row():
-         with gr.Column(scale=1):
-             gr.Markdown('# ' + _TITLE)
-             gr.Markdown(_DESCRIPTION)
-
-     with gr.Row(variant='panel'):
-         with gr.Column(scale=1):
-             with gr.Tab("Image-to-3D"):
-                 # input image
-                 with gr.Row():
-                     condition_input_image = gr.Image(
-                         label="Input Image",
-                         image_mode="RGBA",
-                         type="pil"
-                     )
-
-                     processed_image = gr.Image(
-                         label="Processed Image",
-                         image_mode="RGBA",
-                         type="pil",
-                         interactive=False
-                     )
-
-
-                 with gr.Row():
-                     mv_moedl_option = gr.Radio([
-                             "zero123plus",
-                             "mvdream"
-                         ], value="zero123plus",
-                         label="Multi-view Diffusion")
-
-                 with gr.Row(variant="panel"):
-                     gr.Examples(
-                         examples=[
-                             os.path.join("example", img_name) for img_name in sorted(os.listdir("example"))
-                         ],
-                         inputs=[condition_input_image],
-                         fn=lambda x: process(condition_input_image=x, prompt=''),
-                         cache_examples=False,
-                         examples_per_page=20,
-                         label='Image-to-3D Examples'
-                     )
-
-             with gr.Tab("Text-to-3D"):
-                 # input prompt
-                 with gr.Row():
-                     input_text = gr.Textbox(label="prompt")
-                 # negative prompt
-                 with gr.Row():
-                     input_neg_text = gr.Textbox(label="negative prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate')
-
-                 with gr.Row(variant="panel"):
-                     gr.Examples(
-                         examples=[
-                             "a hamburger",
-                             "a furry red fox head",
-                             "a teddy bear",
-                             "a motorbike",
-                         ],
-                         inputs=[input_text],
-                         fn=lambda x: process(condition_input_image=None, prompt=x),
-                         cache_examples=False,
-                         label='Text-to-3D Examples'
-                     )
-
-             # elevation
-             input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
-             # inference steps
-             input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30)
-             # random seed
-             input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
-             # gen button
-             button_gen = gr.Button("Generate")
-
-
-         with gr.Column(scale=1):
-             with gr.Row():
-                 # multi-view results
-                 mv_image_grid = gr.Image(interactive=False, show_label=False)
-             # with gr.Row():
-             #     output_video_path = gr.Video(label="video")
-             with gr.Row():
-                 output_obj_rgb_path = gr.Model3D(
-                     label="RGB Model (OBJ Format)",
-                     interactive=False,
-                 )
-             with gr.Row():
-                 output_obj_albedo_path = gr.Model3D(
-                     label="Albedo Model (OBJ Format)",
-                     interactive=False,
-                 )
-             with gr.Row():
-                 output_obj_shading_path = gr.Model3D(
-                     label="Shading Model (OBJ Format)",
-                     interactive=False,
-                 )
-
-
-     input_image = gr.State()
-     button_gen.click(fn=generate_mv, inputs=[condition_input_image, input_text, input_neg_text, input_elevation, input_num_steps, input_seed, mv_moedl_option],
-                      outputs=[mv_image_grid, processed_image, input_image],).success(
-         fn=generate_3d,
-         inputs=[input_image, condition_input_image, mv_moedl_option, input_seed],
-         outputs=[output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path] , #output_video_path
-     )
-
-
+ import os
+ import tyro
+ import imageio
+ import numpy as np
+ import tqdm
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.transforms.functional as TF
+ from safetensors.torch import load_file
+ import rembg
+ import gradio as gr
+
+ import kiui
+ from kiui.op import recenter
+ from kiui.cam import orbit_camera
+ from core.utils import get_rays, grid_distortion, orbit_camera_jitter
+
+ from core.options import AllConfigs, Options
+ from core.models import LTRFM_Mesh,LTRFM_NeRF
+ from core.instant_utils.mesh_util import save_obj, save_obj_with_mtl
+ from mvdream.pipeline_mvdream import MVDreamPipeline
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
+ from huggingface_hub import hf_hub_download
+
+ import spaces
+
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+ GRADIO_VIDEO_PATH = 'gradio_output.mp4'
+ GRADIO_OBJ_PATH = 'gradio_output_rgb.obj'
+ GRADIO_OBJ_ALBEDO_PATH = 'gradio_output_albedo.obj'
+ GRADIO_OBJ_SHADING_PATH = 'gradio_output_shading.obj'
+
+ #opt = tyro.cli(AllConfigs)
+
+ ckpt_path = hf_hub_download(repo_id="rgxie/LDM", filename="LDM6v01.ckpt")
+
+ opt = Options(
+     input_size=512,
+     down_channels=(32, 64, 128, 256, 512),
+     down_attention=(False, False, False, False, True),
+     up_channels=(512, 256, 128),
+     up_attention=(True, False, False, False),
+     volume_mode='TRF_NeRF',
+     splat_size=64,
+     output_size=62, #crop patch
+     data_mode='s5',
+     num_views=8,
+     gradient_accumulation_steps=1, #2
+     mixed_precision='bf16',
+     resume=ckpt_path,
+ )
+
+
+ # model
+ if opt.volume_mode == 'TRF_Mesh':
+     model = LTRFM_Mesh(opt)
+ elif opt.volume_mode == 'TRF_NeRF':
+     model = LTRFM_NeRF(opt)
+ else:
+     model = LGM(opt)
+
+ # resume pretrained checkpoint
+ if opt.resume is not None:
+     if opt.resume.endswith('safetensors'):
+         ckpt = load_file(opt.resume, device='cpu')
+     else: #ckpt
+         ckpt_dict = torch.load(opt.resume, map_location='cpu')
+         ckpt=ckpt_dict["model"]
+
+     state_dict = model.state_dict()
+     for k, v in ckpt.items():
+         k=k.replace('module.', '')
+         if k in state_dict:
+             if state_dict[k].shape == v.shape:
+                 state_dict[k].copy_(v)
+             else:
+                 print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
+         else:
+             print(f'[WARN] unexpected param {k}: {v.shape}')
+     print(f'[INFO] load resume success!')
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model = model.half().to(device)
+ model.eval()
+
+ tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
+ proj_matrix = torch.zeros(4, 4, dtype=torch.float32).to(device)
+ proj_matrix[0, 0] = 1 / tan_half_fov
+ proj_matrix[1, 1] = 1 / tan_half_fov
+ proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
+ proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
+ proj_matrix[2, 3] = 1
+
+ # load dreams
+ pipe_text = MVDreamPipeline.from_pretrained(
+     'ashawkey/mvdream-sd2.1-diffusers', # remote weights
+     torch_dtype=torch.float16,
+     trust_remote_code=True,
+     # local_files_only=True,
+ )
+ pipe_text = pipe_text.to(device)
+
+ # mvdream
+ pipe_image = MVDreamPipeline.from_pretrained(
+     "ashawkey/imagedream-ipmv-diffusers", # remote weights
+     torch_dtype=torch.float16,
+     trust_remote_code=True,
+     # local_files_only=True,
+ )
+ pipe_image = pipe_image.to(device)
+
+
+ print('Loading 123plus model ...')
+ pipe_image_plus = DiffusionPipeline.from_pretrained(
+     "sudo-ai/zero123plus-v1.2",
+     custom_pipeline="zero123plus",
+     torch_dtype=torch.float16,
+     trust_remote_code=True,
+     #local_files_only=True,
+ )
+ pipe_image_plus.scheduler = EulerAncestralDiscreteScheduler.from_config(
+     pipe_image_plus.scheduler.config, timestep_spacing='trailing'
+ )
+
+ unet_path='./pretrained/diffusion_pytorch_model.bin'
+
+ print('Loading custom white-background unet ...')
+ if os.path.exists(unet_path):
+     unet_ckpt_path = unet_path
+ else:
+     unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
+ pipe_image_plus.unet.load_state_dict(state_dict, strict=True)
+ pipe_image_plus = pipe_image_plus.to(device)
+
+ # load rembg
+ bg_remover = rembg.new_session()
+
+
+ @spaces.GPU
+ def generate_mv(condition_input_image, prompt, prompt_neg='', input_elevation=0, input_num_steps=30, input_seed=42, mv_moedl_option=None):
+     # seed
+     kiui.seed_everything(input_seed)
+
+     os.makedirs(os.path.join(opt.workspace, "gradio"), exist_ok=True)
+
+     # text-conditioned
+     if condition_input_image is None:
+         mv_image_uint8 = pipe_text(prompt, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=7.5, elevation=input_elevation)
+         mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8)
+         # bg removal
+         mv_image = []
+         for i in range(4):
+             image = rembg.remove(mv_image_uint8[i], session=bg_remover) # [H, W, 4]
+             # to white bg
+             image = image.astype(np.float32) / 255
+             image = recenter(image, image[..., 0] > 0, border_ratio=0.2)
+             image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
+             mv_image.append(image)
+
+         mv_image_grid = np.concatenate([mv_image[1], mv_image[2],mv_image[3], mv_image[0]],axis=1)
+         input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
+
+         processed_image=None
+     # image-conditioned (may also input text, but no text usually works too)
+     else:
+         condition_input_image = np.array(condition_input_image) # uint8
+         # bg removal
+         carved_image = rembg.remove(condition_input_image, session=bg_remover) # [H, W, 4]
+         mask = carved_image[..., -1] > 0
+         image = recenter(carved_image, mask, border_ratio=0.2)
+         image = image.astype(np.float32) / 255.0
+         processed_image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
+
+         if mv_moedl_option=='mvdream':
+             mv_image = pipe_image(prompt, processed_image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation)
+
+             mv_image_grid = np.concatenate([mv_image[1], mv_image[2],mv_image[3], mv_image[0]],axis=1)
+             input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
+         else:
+             from PIL import Image
+             from einops import rearrange, repeat
+
+             # input_image=input_image* 255
+             processed_image = Image.fromarray((processed_image * 255).astype(np.uint8))
+             mv_image = pipe_image_plus(processed_image, num_inference_steps=input_num_steps).images[0]
+             mv_image = np.asarray(mv_image, dtype=np.float32) / 255.0
+             mv_image = torch.from_numpy(mv_image).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
+             mv_image_grid = rearrange(mv_image, 'c (n h) (m w) -> (m h) (n w) c', n=3, m=2).numpy()
+             mv_image = rearrange(mv_image, 'c (n h) (m w) -> (n m) h w c', n=3, m=2).numpy()
+             input_image = mv_image
+     return mv_image_grid, processed_image, input_image
+
+ @spaces.GPU
+ def generate_3d(input_image, condition_input_image, mv_moedl_option=None, input_seed=42):
+     kiui.seed_everything(input_seed)
+
+     output_obj_rgb_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_PATH)
+     output_obj_albedo_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_ALBEDO_PATH)
+     output_obj_shading_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_SHADING_PATH)
+
+     output_video_path = os.path.join(opt.workspace,"gradio", GRADIO_VIDEO_PATH)
+     # generate gaussians
+     # [4, 256, 256, 3], float32
+     input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
+     input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
+
+     images_input_vit = F.interpolate(input_image, size=(224, 224), mode='bilinear', align_corners=False)
+
+     data = {}
+     input_image = input_image.unsqueeze(0) # [1, 4, 9, H, W]
+     images_input_vit=images_input_vit.unsqueeze(0)
+     data['input_vit']=images_input_vit
+
+     elevation = 0
+     cam_poses =[]
+     if mv_moedl_option=='mvdream' or condition_input_image is None:
+         azimuth = np.arange(0, 360, 90, dtype=np.int32)
+         for azi in tqdm.tqdm(azimuth):
+             cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+             cam_poses.append(cam_pose)
+     else:
+         azimuth = np.arange(30, 360, 60, dtype=np.int32)
+         cnt = 0
+         for azi in tqdm.tqdm(azimuth):
+             if (cnt+1) % 2!= 0:
+                 elevation=-20
+             else:
+                 elevation=30
+             cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+             cam_poses.append(cam_pose)
+             cnt=cnt+1
+
+     cam_poses = torch.cat(cam_poses,0)
+     radius = torch.norm(cam_poses[0, :3, 3])
+     cam_poses[:, :3, 3] *= opt.cam_radius / radius
+     transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32).to(device) @ torch.inverse(cam_poses[0])
+     cam_poses = transform.unsqueeze(0) @ cam_poses
+
+     cam_poses=cam_poses.unsqueeze(0)
+     data['source_camera']=cam_poses
+
+     with torch.no_grad():
+         if opt.volume_mode == 'TRF_Mesh':
+             with torch.autocast(device_type='cuda', dtype=torch.float32):
+                 svd_volume = model.forward_svd_volume(input_image,data)
+         else:
+             with torch.autocast(device_type='cuda', dtype=torch.float16):
+                 svd_volume = model.forward_svd_volume(input_image,data)
+
+         #time-consuming
+         export_texmap=False
+
+         mesh_out = model.extract_mesh(svd_volume,use_texture_map=export_texmap)
+
+         if export_texmap:
+             vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
+
+             for i in range(len(tex_map)):
+                 mesh_path=os.path.join(opt.workspace, name + str(i) + '_'+ str(seed)+ '.obj')
+                 save_obj_with_mtl(
+                     vertices.data.cpu().numpy(),
+                     uvs.data.cpu().numpy(),
+                     faces.data.cpu().numpy(),
+                     mesh_tex_idx.data.cpu().numpy(),
+                     tex_map[i].permute(1, 2, 0).data.cpu().numpy(),
+                     mesh_path,
+                 )
+         else:
+             vertices, faces, vertex_colors = mesh_out
+
+             save_obj(vertices, faces, vertex_colors[0], output_obj_rgb_path)
+             save_obj(vertices, faces, vertex_colors[1], output_obj_albedo_path)
+             save_obj(vertices, faces, vertex_colors[2], output_obj_shading_path)
+
+         # images=[]
+         # azimuth = np.arange(0, 360, 6, dtype=np.int32)
+         # for azi in tqdm.tqdm(azimuth):
+
+         #     cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True))
+
+         #     if opt.volume_mode == 'TRF_Mesh':
+         #         cam_view = torch.inverse(cam_pose)
+         #         cam_view=cam_view.unsqueeze(0).unsqueeze(0).to(device)
+         #         data['w2c'] = cam_view
+         #         with torch.autocast(device_type='cuda', dtype=torch.float32):
+         #             render_images=model.render_frame(data)
+         #     else:
+         #         rays_o, rays_d = get_rays(cam_pose, opt.infer_render_size, opt.infer_render_size, opt.fovy) # [h, w, 3]
+         #         rays_o=rays_o.unsqueeze(0).unsqueeze(0).to(device)# B,V,H,W,3
+         #         rays_d=rays_d.unsqueeze(0).unsqueeze(0).to(device)
+         #         data['all_rays_o']=rays_o
+         #         data['all_rays_d']=rays_d
+         #         with torch.autocast(device_type='cuda', dtype=torch.float16):
+         #             render_images=model.render_frame(data)
+         #     image=render_images['images_pred']
+
+         #     images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
+
+         # images = np.concatenate(images, axis=0)
+         # imageio.mimwrite(output_video_path, images, fps=30)
+
+
+     return output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path #, output_video_path
+
+
+ # gradio UI
+
+ with block:
+     with gr.Row():
+         with gr.Column(scale=1):
+
+     with gr.Row(variant='panel'):
+         with gr.Column(scale=1):
+
+             with gr.Tab("Text-to-3D"):
+                 # input prompt
+                 with gr.Row():
+                     input_text = gr.Textbox(label="prompt")
+                 # negative prompt
+                 with gr.Row():
+                     input_neg_text = gr.Textbox(label="negative prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate')
+
+                 with gr.Row(variant="panel"):
+                     gr.Examples(
+                         examples=[
+                             "a hamburger",
+                             "a furry red fox head",
+                             "a teddy bear",
+                             "a motorbike",
+                         ],
+                         inputs=[input_text],
+                         fn=lambda x: process(condition_input_image=None, prompt=x),
+                         cache_examples=False,
+                         label='Text-to-3D Examples'
+                     )
+
+             # elevation
+             input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
+             # inference steps
+             input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30)
+             # random seed
+             input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
+             # gen button
+             button_gen = gr.Button("Generate")
+
+
+         with gr.Column(scale=1):
+             with gr.Row():
+                 # multi-view results
+                 mv_image_grid = gr.Image(interactive=False, show_label=False)
+             # with gr.Row():
+             #     output_video_path = gr.Video(label="video")
+             with gr.Row():
+                 output_obj_rgb_path = gr.Model3D(
+                     label="RGB Model (OBJ Format)",
+                     interactive=False,
+                 )
+             with gr.Row():
+                 output_obj_albedo_path = gr.Model3D(
+                     label="Albedo Model (OBJ Format)",
+                     interactive=False,
+                 )
+             with gr.Row():
+                 output_obj_shading_path = gr.Model3D(
+                     label="Shading Model (OBJ Format)",
+                     interactive=False,
+                 )
+
+
+     input_image = gr.State()
+     button_gen.click(fn=generate_mv, inputs=[condition_input_image, input_text, input_neg_text, input_elevation, input_num_steps, input_seed, mv_moedl_option],
+                      outputs=[mv_image_grid, processed_image, input_image],).success(
+         fn=generate_3d,
+         inputs=[input_image, condition_input_image, mv_moedl_option, input_seed],
+         outputs=[output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path] , #output_video_path
+     )
+
+
  block.launch(server_name="0.0.0.0", share=False)