cavargas10 committed on
Commit 72cf5bb · verified · 1 Parent(s): fe1ca13

Update app.py

Files changed (1)
  1. app.py +117 -286
app.py CHANGED
@@ -4,9 +4,7 @@ import imageio
  import numpy as np
  import tqdm
  import torch
- import torch.nn as nn
  import torch.nn.functional as F
- import torchvision.transforms.functional as TF
  from safetensors.torch import load_file
  import rembg
  import gradio as gr
@@ -14,15 +12,13 @@ import gradio as gr
  import kiui
  from kiui.op import recenter
  from kiui.cam import orbit_camera
- from core.utils import get_rays, grid_distortion, orbit_camera_jitter
-
  from core.options import AllConfigs, Options
- from core.models import LTRFM_Mesh,LTRFM_NeRF
  from core.instant_utils.mesh_util import save_obj, save_obj_with_mtl
  from mvdream.pipeline_mvdream import MVDreamPipeline
  from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
  from huggingface_hub import hf_hub_download
-
  import spaces

  IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
@@ -32,54 +28,51 @@ GRADIO_OBJ_PATH = 'gradio_output_rgb.obj'
  GRADIO_OBJ_ALBEDO_PATH = 'gradio_output_albedo.obj'
  GRADIO_OBJ_SHADING_PATH = 'gradio_output_shading.obj'

- #opt = tyro.cli(AllConfigs)
-
  ckpt_path = hf_hub_download(repo_id="rgxie/LDM", filename="LDM_6V_SDF.ckpt")

  opt = Options(
-     input_size=512,
      down_channels=(32, 64, 128, 256, 512),
      down_attention=(False, False, False, False, True),
      up_channels=(512, 256, 128),
      up_attention=(True, False, False, False),
      volume_mode='TRF_NeRF',
      splat_size=64,
-     output_size=62, #crop patch
      data_mode='s5',
      num_views=8,
-     gradient_accumulation_steps=1, #2
      mixed_precision='bf16',
      resume=ckpt_path,
  )

-
- # model
  if opt.volume_mode == 'TRF_Mesh':
      model = LTRFM_Mesh(opt)
  elif opt.volume_mode == 'TRF_NeRF':
      model = LTRFM_NeRF(opt)
  else:
-     model = LGM(opt)

- # resume pretrained checkpoint
- if opt.resume is not None:
      if opt.resume.endswith('safetensors'):
          ckpt = load_file(opt.resume, device='cpu')
-     else: #ckpt
          ckpt_dict = torch.load(opt.resume, map_location='cpu')
-         ckpt=ckpt_dict["model"]

      state_dict = model.state_dict()
      for k, v in ckpt.items():
-         k=k.replace('module.', '')
-         if k in state_dict:
              if state_dict[k].shape == v.shape:
                  state_dict[k].copy_(v)
              else:
                  print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
          else:
              print(f'[WARN] unexpected param {k}: {v.shape}')
-     print(f'[INFO] load resume success!')

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = model.half().to(device)
@@ -93,338 +86,176 @@ proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
  proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
  proj_matrix[2, 3] = 1

- # load dreams
  pipe_text = MVDreamPipeline.from_pretrained(
-     'ashawkey/mvdream-sd2.1-diffusers', # remote weights
      torch_dtype=torch.float16,
      trust_remote_code=True,
-     # local_files_only=True,
  )
  pipe_text = pipe_text.to(device)

- # mvdream
  pipe_image = MVDreamPipeline.from_pretrained(
-     "ashawkey/imagedream-ipmv-diffusers", # remote weights
      torch_dtype=torch.float16,
      trust_remote_code=True,
-     # local_files_only=True,
  )
  pipe_image = pipe_image.to(device)

-
- print('Loading 123plus model ...')
  pipe_image_plus = DiffusionPipeline.from_pretrained(
-     "sudo-ai/zero123plus-v1.2",
      custom_pipeline="zero123plus",
      torch_dtype=torch.float16,
      trust_remote_code=True,
-     #local_files_only=True,
  )
  pipe_image_plus.scheduler = EulerAncestralDiscreteScheduler.from_config(
      pipe_image_plus.scheduler.config, timestep_spacing='trailing'
  )

- unet_path='./pretrained/diffusion_pytorch_model.bin'

- print('Loading custom white-background unet ...')
  if os.path.exists(unet_path):
      unet_ckpt_path = unet_path
  else:
      unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
  state_dict = torch.load(unet_ckpt_path, map_location='cpu')
  pipe_image_plus.unet.load_state_dict(state_dict, strict=True)
  pipe_image_plus = pipe_image_plus.to(device)

- # load rembg
  bg_remover = rembg.new_session()

-
  @spaces.GPU
  def generate_mv(condition_input_image, prompt, prompt_neg='', input_elevation=0, input_num_steps=30, input_seed=42, mv_moedl_option=None):
-     # seed
      kiui.seed_everything(input_seed)
-
      os.makedirs(os.path.join(opt.workspace, "gradio"), exist_ok=True)
-
-     # text-conditioned
      if condition_input_image is None:
          mv_image_uint8 = pipe_text(prompt, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=7.5, elevation=input_elevation)
          mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8)
-         # bg removal
          mv_image = []
          for i in range(4):
-             image = rembg.remove(mv_image_uint8[i], session=bg_remover) # [H, W, 4]
-             # to white bg
              image = image.astype(np.float32) / 255
              image = recenter(image, image[..., 0] > 0, border_ratio=0.2)
              image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
              mv_image.append(image)
-
-         mv_image_grid = np.concatenate([mv_image[1], mv_image[2],mv_image[3], mv_image[0]],axis=1)
          input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
-
-         processed_image=None
-     # image-conditioned (may also input text, but no text usually works too)
      else:
-         condition_input_image = np.array(condition_input_image) # uint8
-         # bg removal
-         carved_image = rembg.remove(condition_input_image, session=bg_remover) # [H, W, 4]
          mask = carved_image[..., -1] > 0
          image = recenter(carved_image, mask, border_ratio=0.2)
          image = image.astype(np.float32) / 255.0
          processed_image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
-
-         if mv_moedl_option=='mvdream':
-             mv_image = pipe_image(prompt, processed_image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation)
-
-             mv_image_grid = np.concatenate([mv_image[1], mv_image[2],mv_image[3], mv_image[0]],axis=1)
              input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
          else:
              from PIL import Image
-             from einops import rearrange, repeat
-
-             # input_image=input_image* 255
              processed_image = Image.fromarray((processed_image * 255).astype(np.uint8))
              mv_image = pipe_image_plus(processed_image, num_inference_steps=input_num_steps).images[0]
              mv_image = np.asarray(mv_image, dtype=np.float32) / 255.0
-             mv_image = torch.from_numpy(mv_image).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
              mv_image_grid = rearrange(mv_image, 'c (n h) (m w) -> (m h) (n w) c', n=3, m=2).numpy()
              mv_image = rearrange(mv_image, 'c (n h) (m w) -> (n m) h w c', n=3, m=2).numpy()
              input_image = mv_image
-     return mv_image_grid, processed_image, input_image

  @spaces.GPU
  def generate_3d(input_image, condition_input_image, mv_moedl_option=None, input_seed=42):
      kiui.seed_everything(input_seed)
-
-     output_obj_rgb_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_PATH)
-     output_obj_albedo_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_ALBEDO_PATH)
-     output_obj_shading_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_SHADING_PATH)
-
-     output_video_path = os.path.join(opt.workspace,"gradio", GRADIO_VIDEO_PATH)
-     # generate gaussians
-     # [4, 256, 256, 3], float32
-     input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
-     input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)

-     images_input_vit = F.interpolate(input_image, size=(224, 224), mode='bilinear', align_corners=False)
-
-     data = {}
-     input_image = input_image.unsqueeze(0) # [1, 4, 9, H, W]
-     images_input_vit=images_input_vit.unsqueeze(0)
-     data['input_vit']=images_input_vit
-
-     elevation = 0
-     cam_poses =[]
-     if mv_moedl_option=='mvdream' or condition_input_image is None:
-         azimuth = np.arange(0, 360, 90, dtype=np.int32)
-         for azi in tqdm.tqdm(azimuth):
-             cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
-             cam_poses.append(cam_pose)
-     else:
-         azimuth = np.arange(30, 360, 60, dtype=np.int32)
-         cnt = 0
-         for azi in tqdm.tqdm(azimuth):
-             if (cnt+1) % 2!= 0:
-                 elevation=-20
-             else:
-                 elevation=30
-             cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
-             cam_poses.append(cam_pose)
-             cnt=cnt+1
-
-     cam_poses = torch.cat(cam_poses,0)
-     radius = torch.norm(cam_poses[0, :3, 3])
-     cam_poses[:, :3, 3] *= opt.cam_radius / radius
-     transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32).to(device) @ torch.inverse(cam_poses[0])
-     cam_poses = transform.unsqueeze(0) @ cam_poses
-
-     cam_poses=cam_poses.unsqueeze(0)
-     data['source_camera']=cam_poses
-
-     with torch.no_grad():
-         if opt.volume_mode == 'TRF_Mesh':
-             with torch.autocast(device_type='cuda', dtype=torch.float32):
-                 svd_volume = model.forward_svd_volume(input_image,data)
-         else:
-             with torch.autocast(device_type='cuda', dtype=torch.float16):
-                 svd_volume = model.forward_svd_volume(input_image,data)
-
-     #time-consuming
-     export_texmap=False
-
-     mesh_out = model.extract_mesh(svd_volume,use_texture_map=export_texmap)
-
-     if export_texmap:
-         vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
-
-         for i in range(len(tex_map)):
-             mesh_path=os.path.join(opt.workspace, name + str(i) + '_'+ str(seed)+ '.obj')
-             save_obj_with_mtl(
-                 vertices.data.cpu().numpy(),
-                 uvs.data.cpu().numpy(),
-                 faces.data.cpu().numpy(),
-                 mesh_tex_idx.data.cpu().numpy(),
-                 tex_map[i].permute(1, 2, 0).data.cpu().numpy(),
-                 mesh_path,
-             )
-     else:
-         vertices, faces, vertex_colors = mesh_out
-
-         save_obj(vertices, faces, vertex_colors[0], output_obj_rgb_path)
-         save_obj(vertices, faces, vertex_colors[1], output_obj_albedo_path)
-         save_obj(vertices, faces, vertex_colors[2], output_obj_shading_path)
-
-     # images=[]
-     # azimuth = np.arange(0, 360, 6, dtype=np.int32)
-     # for azi in tqdm.tqdm(azimuth):
-
-     # cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True))
-
-     # if opt.volume_mode == 'TRF_Mesh':
-     # cam_view = torch.inverse(cam_pose)
-     # cam_view=cam_view.unsqueeze(0).unsqueeze(0).to(device)
-     # data['w2c'] = cam_view
-     # with torch.autocast(device_type='cuda', dtype=torch.float32):
-     # render_images=model.render_frame(data)
-     # else:
-     # rays_o, rays_d = get_rays(cam_pose, opt.infer_render_size, opt.infer_render_size, opt.fovy) # [h, w, 3]
-     # rays_o=rays_o.unsqueeze(0).unsqueeze(0).to(device)# B,V,H,W,3
-     # rays_d=rays_d.unsqueeze(0).unsqueeze(0).to(device)
-     # data['all_rays_o']=rays_o
-     # data['all_rays_d']=rays_d
-     # with torch.autocast(device_type='cuda', dtype=torch.float16):
-     # render_images=model.render_frame(data)
-     # image=render_images['images_pred']
-
-     # images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))

-     # images = np.concatenate(images, axis=0)
-     # imageio.mimwrite(output_video_path, images, fps=30)
-
-
-     return output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path #, output_video_path
-
-
- # gradio UI
-
- _TITLE = '''LDM: Large Tensorial SDF Model for Textured Mesh Generation'''
-
- _DESCRIPTION = '''

- * Input can be text prompt, image.
- * The currently supported multi-view diffusion models include the image-conditioned MVdream and Zero123plus, as well as the text-conditioned Imagedream.
- * If you find the output unsatisfying, try using different multi-view diffusion models or seeds!
- '''

- block = gr.Blocks(title=_TITLE).queue()
- with block:
      with gr.Row():
-         with gr.Column(scale=1):
-             gr.Markdown('# ' + _TITLE)
-             gr.Markdown(_DESCRIPTION)
-
-     with gr.Row(variant='panel'):
-         with gr.Column(scale=1):
-             with gr.Tab("Image-to-3D"):
-                 # input image
-                 with gr.Row():
-                     condition_input_image = gr.Image(
-                         label="Input Image",
-                         image_mode="RGBA",
-                         type="pil"
-                     )
-
-                     processed_image = gr.Image(
-                         label="Processed Image",
-                         image_mode="RGBA",
-                         type="pil",
-                         interactive=False
-                     )
-
-
-                 with gr.Row():
-                     mv_moedl_option = gr.Radio([
-                         "zero123plus",
-                         "mvdream"
-                     ], value="zero123plus",
-                     label="Multi-view Diffusion")
-
-                 with gr.Row(variant="panel"):
-                     gr.Examples(
-                         examples=[
-                             os.path.join("example", img_name) for img_name in sorted(os.listdir("example"))
-                         ],
-                         inputs=[condition_input_image],
-                         fn=lambda x: process(condition_input_image=x, prompt=''),
-                         cache_examples=False,
-                         examples_per_page=20,
-                         label='Image-to-3D Examples'
-                     )
-
-             with gr.Tab("Text-to-3D"):
-                 # input prompt
-                 with gr.Row():
-                     input_text = gr.Textbox(label="prompt")
-                 # negative prompt
-                 with gr.Row():
-                     input_neg_text = gr.Textbox(label="negative prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate')
-
-                 with gr.Row(variant="panel"):
-                     gr.Examples(
-                         examples=[
-                             "a hamburger",
-                             "a furry red fox head",
-                             "a teddy bear",
-                             "a motorbike",
-                         ],
-                         inputs=[input_text],
-                         fn=lambda x: process(condition_input_image=None, prompt=x),
-                         cache_examples=False,
-                         label='Text-to-3D Examples'
-                     )
-
-             # elevation
-             input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
-             # inference steps
-             input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30)
-             # random seed
-             input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
-             # gen button
-             button_gen = gr.Button("Generate")
-
-
-         with gr.Column(scale=1):
-             with gr.Row():
-                 # multi-view results
-                 mv_image_grid = gr.Image(interactive=False, show_label=False)
-             # with gr.Row():
-             # output_video_path = gr.Video(label="video")
-             with gr.Row():
-                 output_obj_rgb_path = gr.Model3D(
-                     label="RGB Model (OBJ Format)",
-                     interactive=False,
-                 )
-             with gr.Row():
-                 output_obj_albedo_path = gr.Model3D(
-                     label="Albedo Model (OBJ Format)",
-                     interactive=False,
-                 )
-             with gr.Row():
-                 output_obj_shading_path = gr.Model3D(
-                     label="Shading Model (OBJ Format)",
-                     interactive=False,
-                 )
-
-
-     input_image = gr.State()
-     button_gen.click(fn=generate_mv, inputs=[condition_input_image, input_text, input_neg_text, input_elevation, input_num_steps, input_seed, mv_moedl_option],
-         outputs=[mv_image_grid, processed_image, input_image],).success(
-         fn=generate_3d,
-         inputs=[input_image, condition_input_image, mv_moedl_option, input_seed],
-         outputs=[output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path] , #output_video_path
-     )
-
-
- block.launch(server_name="0.0.0.0", share=False)
 
  import numpy as np
  import tqdm
  import torch
  import torch.nn.functional as F
  from safetensors.torch import load_file
  import rembg
  import gradio as gr
  import kiui
  from kiui.op import recenter
  from kiui.cam import orbit_camera
+ from core.utils import get_rays
  from core.options import AllConfigs, Options
+ from core.models import LTRFM_Mesh, LTRFM_NeRF
  from core.instant_utils.mesh_util import save_obj, save_obj_with_mtl
  from mvdream.pipeline_mvdream import MVDreamPipeline
  from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
  from huggingface_hub import hf_hub_download
  import spaces

  IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
  GRADIO_OBJ_ALBEDO_PATH = 'gradio_output_albedo.obj'
  GRADIO_OBJ_SHADING_PATH = 'gradio_output_shading.obj'

  ckpt_path = hf_hub_download(repo_id="rgxie/LDM", filename="LDM_6V_SDF.ckpt")

  opt = Options(
+     input_size=512,
      down_channels=(32, 64, 128, 256, 512),
      down_attention=(False, False, False, False, True),
      up_channels=(512, 256, 128),
      up_attention=(True, False, False, False),
      volume_mode='TRF_NeRF',
      splat_size=64,
+     output_size=62, # crop patch
      data_mode='s5',
      num_views=8,
+     gradient_accumulation_steps=1,
      mixed_precision='bf16',
      resume=ckpt_path,
  )

+ # Model selection
  if opt.volume_mode == 'TRF_Mesh':
      model = LTRFM_Mesh(opt)
  elif opt.volume_mode == 'TRF_NeRF':
      model = LTRFM_NeRF(opt)
  else:
+     model = None

+ # Resume pretrained checkpoint
+ if opt.resume:
      if opt.resume.endswith('safetensors'):
          ckpt = load_file(opt.resume, device='cpu')
+     else:
          ckpt_dict = torch.load(opt.resume, map_location='cpu')
+         ckpt = ckpt_dict["model"]

      state_dict = model.state_dict()
      for k, v in ckpt.items():
+         k = k.replace('module.', '')
+         if k in state_dict:
              if state_dict[k].shape == v.shape:
                  state_dict[k].copy_(v)
              else:
                  print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
          else:
              print(f'[WARN] unexpected param {k}: {v.shape}')
+     print('[INFO] load resume success!')

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = model.half().to(device)
  proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
  proj_matrix[2, 3] = 1

+ # Load dreams
  pipe_text = MVDreamPipeline.from_pretrained(
+     'ashawkey/mvdream-sd2.1-diffusers',
      torch_dtype=torch.float16,
      trust_remote_code=True,
  )
  pipe_text = pipe_text.to(device)

  pipe_image = MVDreamPipeline.from_pretrained(
+     "ashawkey/imagedream-ipmv-diffusers",
      torch_dtype=torch.float16,
      trust_remote_code=True,
  )
  pipe_image = pipe_image.to(device)

  pipe_image_plus = DiffusionPipeline.from_pretrained(
+     "sudo-ai/zero123plus-v1.2",
      custom_pipeline="zero123plus",
      torch_dtype=torch.float16,
      trust_remote_code=True,
  )
  pipe_image_plus.scheduler = EulerAncestralDiscreteScheduler.from_config(
      pipe_image_plus.scheduler.config, timestep_spacing='trailing'
  )

+ unet_path = './pretrained/diffusion_pytorch_model.bin'

  if os.path.exists(unet_path):
      unet_ckpt_path = unet_path
  else:
      unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
+
  state_dict = torch.load(unet_ckpt_path, map_location='cpu')
  pipe_image_plus.unet.load_state_dict(state_dict, strict=True)
  pipe_image_plus = pipe_image_plus.to(device)

+ # Load rembg
  bg_remover = rembg.new_session()

  @spaces.GPU
  def generate_mv(condition_input_image, prompt, prompt_neg='', input_elevation=0, input_num_steps=30, input_seed=42, mv_moedl_option=None):
      kiui.seed_everything(input_seed)
      os.makedirs(os.path.join(opt.workspace, "gradio"), exist_ok=True)
+
      if condition_input_image is None:
          mv_image_uint8 = pipe_text(prompt, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=7.5, elevation=input_elevation)
          mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8)
+
          mv_image = []
          for i in range(4):
+             image = rembg.remove(mv_image_uint8[i], session=bg_remover)
              image = image.astype(np.float32) / 255
              image = recenter(image, image[..., 0] > 0, border_ratio=0.2)
              image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
              mv_image.append(image)
+
+         mv_image_grid = np.concatenate([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=1)
          input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
+         processed_image = None
+
      else:
+         condition_input_image = np.array(condition_input_image)
+         carved_image = rembg.remove(condition_input_image, session=bg_remover)
          mask = carved_image[..., -1] > 0
          image = recenter(carved_image, mask, border_ratio=0.2)
          image = image.astype(np.float32) / 255.0
          processed_image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
+
+         if mv_moedl_option == 'mvdream':
+             mv_image = pipe_image(prompt, processed_image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation)
+             mv_image_grid = np.concatenate([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=1)
              input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
+
          else:
              from PIL import Image
+             from einops import rearrange
+
              processed_image = Image.fromarray((processed_image * 255).astype(np.uint8))
              mv_image = pipe_image_plus(processed_image, num_inference_steps=input_num_steps).images[0]
              mv_image = np.asarray(mv_image, dtype=np.float32) / 255.0
+             mv_image = torch.from_numpy(mv_image).permute(2, 0, 1).contiguous().float()
              mv_image_grid = rearrange(mv_image, 'c (n h) (m w) -> (m h) (n w) c', n=3, m=2).numpy()
              mv_image = rearrange(mv_image, 'c (n h) (m w) -> (n m) h w c', n=3, m=2).numpy()
              input_image = mv_image
+
+     return mv_image_grid, processed_image, input_image

  @spaces.GPU
  def generate_3d(input_image, condition_input_image, mv_moedl_option=None, input_seed=42):
      kiui.seed_everything(input_seed)

+     output_obj_rgb_path = os.path.join(opt.workspace, "gradio", GRADIO_OBJ_PATH)
+     output_obj_albedo_path = os.path.join(opt.workspace, "gradio", GRADIO_OBJ_ALBEDO_PATH)
+     output_obj_shading_path = os.path.join(opt.workspace, "gradio", GRADIO_OBJ_SHADING_PATH)
+     output_video_path = os.path.join(opt.workspace, "gradio", GRADIO_VIDEO_PATH)

+     input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device)
+     input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)

+     input_rays_o, input_rays_d = get_rays(opt, proj_matrix, device, 'center')

+     with torch.no_grad():
+         preds = model(
+             cond_img=input_image,
+             rays=(input_rays_o, input_rays_d)
+         )
+
+     pred_rgb = preds[0].permute(0, 2, 3, 1).contiguous().cpu().numpy()
+     pred_albedo = preds[1].permute(0, 2, 3, 1).contiguous().cpu().numpy()
+     pred_shading = preds[2].permute(0, 2, 3, 1).contiguous().cpu().numpy()
+
+     save_obj(output_obj_rgb_path, pred_rgb)
+     save_obj_with_mtl(output_obj_albedo_path, pred_albedo, mode="albedo")
+     save_obj_with_mtl(output_obj_shading_path, pred_shading, mode="shading")
+
+     camera_positions = orbit_camera(type="spherical", radius=2.5, h=3, w=2)
+     output_frames = []
+     for pose in tqdm.tqdm(camera_positions, ncols=0):
+         with torch.no_grad():
+             preds = model(cond_img=input_image, rays=get_rays(opt, proj_matrix, device, pose))
+         pred_rgb = preds[0].permute(0, 2, 3, 1).contiguous().cpu().numpy()
+         output_frames.append(pred_rgb)
+     output_frames = np.stack(output_frames, axis=0)
+
+     imageio.mimwrite(output_video_path, output_frames, fps=24, quality=8)
+     return output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path, output_video_path
+
+ def update_mv_model(mv_moedl_option):
+     if mv_moedl_option == 'mvdream':
+         return gr.update(visible=False)
+     else:
+         return gr.update(visible=True)

+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         "## Generate 3D object from text or image prompt"
+     )
      with gr.Row():
+         with gr.Column():
+             input_prompt = gr.Textbox(label="Prompt", lines=3)
+             input_image = gr.Image(label="Input Image", type='numpy', optional=True)
+             input_seed = gr.Slider(minimum=0, maximum=65535, step=1, label="Random Seed", value=42)
+             input_elevation = gr.Slider(minimum=-10, maximum=10, step=1, label="Elevation", value=0)
+             input_num_steps = gr.Slider(minimum=1, maximum=150, step=1, label="Number of Inference Steps", value=30)
+             mv_moedl_option = gr.Radio(
+                 ["mvdream", "zero123plus"],
+                 label="Model Option",
+                 value="mvdream",
+                 interactive=True
+             )
+             generate_mv_button = gr.Button(value="Generate Multi-View Images")
+             generate_mv_button.click(fn=generate_mv,
+                 inputs=[input_image, input_prompt, '', input_elevation, input_num_steps, input_seed, mv_moedl_option],
+                 outputs=['multi_view_output', 'processed_image_output', 'input_image_output'])
+
+         with gr.Column():
+             gr.Markdown("### Multi-View Images")
+             multi_view_output = gr.Image()
+             processed_image_output = gr.Image()
+             gr.Markdown("### Input Image (Processed)")
+             input_image_output = gr.Image()
+             generate_3d_button = gr.Button(value="Generate 3D Model")
+             generate_3d_button.click(fn=generate_3d,
+                 inputs=[input_image_output, processed_image_output, mv_moedl_option, input_seed],
+                 outputs=['output_obj_rgb', 'output_obj_albedo', 'output_obj_shading', 'output_video'])
+
+     output_obj_rgb = gr.File(label="RGB 3D Model (.obj)")
+     output_obj_albedo = gr.File(label="Albedo 3D Model (.obj)")
+     output_obj_shading = gr.File(label="Shading 3D Model (.obj)")
+     output_video = gr.Video(label="360° View of the Generated 3D Model (.mp4)")
+
+ demo.launch()