Update demo with latest changes
Browse filesCo-authored-by: Aaryaman Vasishta <[email protected]>
- gradio_app.py +98 -13
- requirements.txt +1 -0
- run.py +2 -2
- spar3d/models/global_estimator/reni_estimator.py +7 -3
- spar3d/system.py +14 -10
gradio_app.py
CHANGED
@@ -2,10 +2,12 @@ import os
|
|
2 |
import random
|
3 |
import tempfile
|
4 |
import time
|
|
|
5 |
from contextlib import nullcontext
|
6 |
from functools import lru_cache
|
7 |
from typing import Any
|
8 |
|
|
|
9 |
import gradio as gr
|
10 |
import numpy as np
|
11 |
import torch
|
@@ -62,6 +64,23 @@ example_files = [
|
|
62 |
]
|
63 |
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
def forward_model(
|
66 |
batch,
|
67 |
system,
|
@@ -105,11 +124,16 @@ def forward_model(
|
|
105 |
|
106 |
# forward for the final mesh
|
107 |
trimesh_mesh, _glob_dict = model.generate_mesh(
|
108 |
-
batch,
|
|
|
|
|
|
|
|
|
109 |
)
|
110 |
trimesh_mesh = trimesh_mesh[0]
|
|
|
111 |
|
112 |
-
return trimesh_mesh, pc_rgb_trimesh
|
113 |
|
114 |
|
115 |
def run_model(
|
@@ -169,7 +193,7 @@ def run_model(
|
|
169 |
dim=1,
|
170 |
)
|
171 |
|
172 |
-
trimesh_mesh, trimesh_pc = forward_model(
|
173 |
model_batch,
|
174 |
model,
|
175 |
guidance_scale=guidance_scale,
|
@@ -191,9 +215,13 @@ def run_model(
|
|
191 |
trimesh_pc.export(tmp_file_pc)
|
192 |
generated_files.append(tmp_file_pc)
|
193 |
|
|
|
|
|
|
|
|
|
194 |
print("Generation took:", time.time() - start, "s")
|
195 |
|
196 |
-
return tmp_file, tmp_file_pc, trimesh_pc
|
197 |
|
198 |
|
199 |
def create_batch(input_image: Image) -> dict[str, Any]:
|
@@ -272,7 +300,7 @@ def process_model_run(
|
|
272 |
f"Final vertex count: {final_vertex_count} with type {vertex_count_type} and vertex count {vertex_count}"
|
273 |
)
|
274 |
|
275 |
-
glb_file, pc_file, pc_plot = run_model(
|
276 |
background_state,
|
277 |
guidance_scale,
|
278 |
random_seed,
|
@@ -295,7 +323,7 @@ def process_model_run(
|
|
295 |
]
|
296 |
)
|
297 |
|
298 |
-
return glb_file, pc_file, point_list
|
299 |
|
300 |
|
301 |
def regenerate_run(
|
@@ -308,7 +336,7 @@ def regenerate_run(
|
|
308 |
vertex_count,
|
309 |
texture_resolution,
|
310 |
):
|
311 |
-
glb_file, pc_file, point_list = process_model_run(
|
312 |
background_state,
|
313 |
guidance_scale,
|
314 |
random_seed,
|
@@ -318,6 +346,8 @@ def regenerate_run(
|
|
318 |
vertex_count,
|
319 |
texture_resolution,
|
320 |
)
|
|
|
|
|
321 |
return (
|
322 |
gr.update(), # run_btn
|
323 |
gr.update(), # img_proc_state
|
@@ -325,10 +355,12 @@ def regenerate_run(
|
|
325 |
gr.update(), # preview_removal
|
326 |
gr.update(value=glb_file, visible=True), # output_3d
|
327 |
gr.update(visible=True), # hdr_row
|
|
|
328 |
gr.update(visible=True), # point_cloud_row
|
329 |
gr.update(value=point_list), # point_cloud_editor
|
330 |
gr.update(value=pc_file), # pc_download
|
331 |
gr.update(visible=False), # regenerate_btn
|
|
|
332 |
)
|
333 |
|
334 |
|
@@ -362,7 +394,7 @@ def run_button(
|
|
362 |
else:
|
363 |
pc_cond = None
|
364 |
|
365 |
-
glb_file, pc_file, pc_list = process_model_run(
|
366 |
background_state,
|
367 |
guidance_scale,
|
368 |
random_seed,
|
@@ -373,6 +405,8 @@ def run_button(
|
|
373 |
texture_resolution,
|
374 |
)
|
375 |
|
|
|
|
|
376 |
if torch.cuda.is_available():
|
377 |
print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
|
378 |
elif torch.backends.mps.is_available():
|
@@ -387,10 +421,12 @@ def run_button(
|
|
387 |
gr.update(), # preview_removal
|
388 |
gr.update(value=glb_file, visible=True), # output_3d
|
389 |
gr.update(visible=True), # hdr_row
|
|
|
390 |
gr.update(visible=True), # point_cloud_row
|
391 |
gr.update(value=pc_list), # point_cloud_editor
|
392 |
gr.update(value=pc_file), # pc_download
|
393 |
gr.update(visible=False), # regenerate_btn
|
|
|
394 |
)
|
395 |
|
396 |
elif run_btn == "Remove Background":
|
@@ -410,10 +446,12 @@ def run_button(
|
|
410 |
gr.update(value=show_mask_img(fr_res), visible=True), # preview_removal
|
411 |
gr.update(value=None, visible=False), # output_3d
|
412 |
gr.update(visible=False), # hdr_row
|
|
|
413 |
gr.update(visible=False), # point_cloud_row
|
414 |
gr.update(value=None), # point_cloud_editor
|
415 |
gr.update(value=None), # pc_download
|
416 |
gr.update(visible=False), # regenerate_btn
|
|
|
417 |
)
|
418 |
|
419 |
|
@@ -425,11 +463,13 @@ def requires_bg_remove(image, fr, no_crop):
|
|
425 |
None, # background_remove_state
|
426 |
gr.update(value=None, visible=False), # preview_removal
|
427 |
gr.update(value=None, visible=False), # output_3d
|
428 |
-
gr.update(visible=False), # hdr_row
|
|
|
429 |
gr.update(visible=False), # point_cloud_row
|
430 |
gr.update(value=None), # point_cloud_editor
|
431 |
gr.update(value=None), # pc_download
|
432 |
gr.update(visible=False), # regenerate_btn
|
|
|
433 |
)
|
434 |
alpha_channel = np.array(image.getchannel("A"))
|
435 |
min_alpha = alpha_channel.min()
|
@@ -446,10 +486,12 @@ def requires_bg_remove(image, fr, no_crop):
|
|
446 |
gr.update(value=show_mask_img(fr_res), visible=True), # preview_removal
|
447 |
gr.update(value=None, visible=False), # output_3d
|
448 |
gr.update(visible=False), # hdr_row
|
|
|
449 |
gr.update(visible=False), # point_cloud_row
|
450 |
gr.update(value=None), # point_cloud_editor
|
451 |
gr.update(value=None), # pc_download
|
452 |
gr.update(visible=False), # regenerate_btn
|
|
|
453 |
)
|
454 |
return (
|
455 |
gr.update(value="Remove Background", visible=True), # run_Btn
|
@@ -458,10 +500,12 @@ def requires_bg_remove(image, fr, no_crop):
|
|
458 |
gr.update(value=None, visible=False), # preview_removal
|
459 |
gr.update(value=None, visible=False), # output_3d
|
460 |
gr.update(visible=False), # hdr_row
|
|
|
461 |
gr.update(visible=False), # point_cloud_row
|
462 |
gr.update(value=None), # point_cloud_editor
|
463 |
gr.update(value=None), # pc_download
|
464 |
gr.update(visible=False), # regenerate_btn
|
|
|
465 |
)
|
466 |
|
467 |
|
@@ -487,6 +531,7 @@ def update_resolution_controls(remesh_choice, vertex_count_type):
|
|
487 |
with gr.Blocks() as demo:
|
488 |
img_proc_state = gr.State()
|
489 |
background_remove_state = gr.State()
|
|
|
490 |
gr.Markdown(
|
491 |
"""
|
492 |
# SPAR3D: Stable Point-Aware Reconstruction of 3D Objects from Single Images
|
@@ -699,12 +744,46 @@ with gr.Blocks() as demo:
|
|
699 |
inputs=hdr_illumination_file,
|
700 |
)
|
701 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
702 |
hdr_illumination_file.change(
|
703 |
-
|
704 |
-
inputs=hdr_illumination_file,
|
705 |
-
outputs=[output_3d],
|
706 |
)
|
707 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
708 |
examples = gr.Examples(
|
709 |
examples=example_files, inputs=input_img, examples_per_page=11
|
710 |
)
|
@@ -719,10 +798,12 @@ with gr.Blocks() as demo:
|
|
719 |
preview_removal,
|
720 |
output_3d,
|
721 |
hdr_row,
|
|
|
722 |
point_cloud_row,
|
723 |
point_cloud_editor,
|
724 |
pc_download,
|
725 |
regenerate_btn,
|
|
|
726 |
],
|
727 |
)
|
728 |
|
@@ -751,10 +832,12 @@ with gr.Blocks() as demo:
|
|
751 |
preview_removal,
|
752 |
output_3d,
|
753 |
hdr_row,
|
|
|
754 |
point_cloud_row,
|
755 |
point_cloud_editor,
|
756 |
pc_download,
|
757 |
regenerate_btn,
|
|
|
758 |
],
|
759 |
)
|
760 |
|
@@ -782,11 +865,13 @@ with gr.Blocks() as demo:
|
|
782 |
preview_removal,
|
783 |
output_3d,
|
784 |
hdr_row,
|
|
|
785 |
point_cloud_row,
|
786 |
point_cloud_editor,
|
787 |
pc_download,
|
788 |
regenerate_btn,
|
|
|
789 |
],
|
790 |
)
|
791 |
|
792 |
-
demo.queue().launch()
|
|
|
2 |
import random
|
3 |
import tempfile
|
4 |
import time
|
5 |
+
import zipfile
|
6 |
from contextlib import nullcontext
|
7 |
from functools import lru_cache
|
8 |
from typing import Any
|
9 |
|
10 |
+
import cv2
|
11 |
import gradio as gr
|
12 |
import numpy as np
|
13 |
import torch
|
|
|
64 |
]
|
65 |
|
66 |
|
67 |
+
def create_zip_file(glb_file, pc_file, illumination_file):
|
68 |
+
if not all([glb_file, pc_file, illumination_file]):
|
69 |
+
return None
|
70 |
+
|
71 |
+
# Create a temporary zip file
|
72 |
+
temp_dir = tempfile.mkdtemp()
|
73 |
+
zip_path = os.path.join(temp_dir, "spar3d_output.zip")
|
74 |
+
|
75 |
+
with zipfile.ZipFile(zip_path, "w") as zipf:
|
76 |
+
zipf.write(glb_file, "mesh.glb")
|
77 |
+
zipf.write(pc_file, "points.ply")
|
78 |
+
zipf.write(illumination_file, "illumination.hdr")
|
79 |
+
|
80 |
+
generated_files.append(zip_path)
|
81 |
+
return zip_path
|
82 |
+
|
83 |
+
|
84 |
def forward_model(
|
85 |
batch,
|
86 |
system,
|
|
|
124 |
|
125 |
# forward for the final mesh
|
126 |
trimesh_mesh, _glob_dict = model.generate_mesh(
|
127 |
+
batch,
|
128 |
+
texture_resolution,
|
129 |
+
remesh=remesh_option,
|
130 |
+
vertex_count=vertex_count,
|
131 |
+
estimate_illumination=True,
|
132 |
)
|
133 |
trimesh_mesh = trimesh_mesh[0]
|
134 |
+
illumination = _glob_dict["illumination"]
|
135 |
|
136 |
+
return trimesh_mesh, pc_rgb_trimesh, illumination.cpu().detach().numpy()[0]
|
137 |
|
138 |
|
139 |
def run_model(
|
|
|
193 |
dim=1,
|
194 |
)
|
195 |
|
196 |
+
trimesh_mesh, trimesh_pc, illumination_map = forward_model(
|
197 |
model_batch,
|
198 |
model,
|
199 |
guidance_scale=guidance_scale,
|
|
|
215 |
trimesh_pc.export(tmp_file_pc)
|
216 |
generated_files.append(tmp_file_pc)
|
217 |
|
218 |
+
tmp_file_illumination = os.path.join(temp_dir, "illumination.hdr")
|
219 |
+
cv2.imwrite(tmp_file_illumination, illumination_map)
|
220 |
+
generated_files.append(tmp_file_illumination)
|
221 |
+
|
222 |
print("Generation took:", time.time() - start, "s")
|
223 |
|
224 |
+
return tmp_file, tmp_file_pc, tmp_file_illumination, trimesh_pc
|
225 |
|
226 |
|
227 |
def create_batch(input_image: Image) -> dict[str, Any]:
|
|
|
300 |
f"Final vertex count: {final_vertex_count} with type {vertex_count_type} and vertex count {vertex_count}"
|
301 |
)
|
302 |
|
303 |
+
glb_file, pc_file, illumination_file, pc_plot = run_model(
|
304 |
background_state,
|
305 |
guidance_scale,
|
306 |
random_seed,
|
|
|
323 |
]
|
324 |
)
|
325 |
|
326 |
+
return glb_file, pc_file, illumination_file, point_list
|
327 |
|
328 |
|
329 |
def regenerate_run(
|
|
|
336 |
vertex_count,
|
337 |
texture_resolution,
|
338 |
):
|
339 |
+
glb_file, pc_file, illumination_file, point_list = process_model_run(
|
340 |
background_state,
|
341 |
guidance_scale,
|
342 |
random_seed,
|
|
|
346 |
vertex_count,
|
347 |
texture_resolution,
|
348 |
)
|
349 |
+
zip_file = create_zip_file(glb_file, pc_file, illumination_file)
|
350 |
+
|
351 |
return (
|
352 |
gr.update(), # run_btn
|
353 |
gr.update(), # img_proc_state
|
|
|
355 |
gr.update(), # preview_removal
|
356 |
gr.update(value=glb_file, visible=True), # output_3d
|
357 |
gr.update(visible=True), # hdr_row
|
358 |
+
illumination_file, # hdr_file
|
359 |
gr.update(visible=True), # point_cloud_row
|
360 |
gr.update(value=point_list), # point_cloud_editor
|
361 |
gr.update(value=pc_file), # pc_download
|
362 |
gr.update(visible=False), # regenerate_btn
|
363 |
+
gr.update(value=zip_file, visible=True), # download_all_btn
|
364 |
)
|
365 |
|
366 |
|
|
|
394 |
else:
|
395 |
pc_cond = None
|
396 |
|
397 |
+
glb_file, pc_file, illumination_file, pc_list = process_model_run(
|
398 |
background_state,
|
399 |
guidance_scale,
|
400 |
random_seed,
|
|
|
405 |
texture_resolution,
|
406 |
)
|
407 |
|
408 |
+
zip_file = create_zip_file(glb_file, pc_file, illumination_file)
|
409 |
+
|
410 |
if torch.cuda.is_available():
|
411 |
print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
|
412 |
elif torch.backends.mps.is_available():
|
|
|
421 |
gr.update(), # preview_removal
|
422 |
gr.update(value=glb_file, visible=True), # output_3d
|
423 |
gr.update(visible=True), # hdr_row
|
424 |
+
illumination_file, # hdr_file
|
425 |
gr.update(visible=True), # point_cloud_row
|
426 |
gr.update(value=pc_list), # point_cloud_editor
|
427 |
gr.update(value=pc_file), # pc_download
|
428 |
gr.update(visible=False), # regenerate_btn
|
429 |
+
gr.update(value=zip_file, visible=True), # download_all_btn
|
430 |
)
|
431 |
|
432 |
elif run_btn == "Remove Background":
|
|
|
446 |
gr.update(value=show_mask_img(fr_res), visible=True), # preview_removal
|
447 |
gr.update(value=None, visible=False), # output_3d
|
448 |
gr.update(visible=False), # hdr_row
|
449 |
+
None, # hdr_file
|
450 |
gr.update(visible=False), # point_cloud_row
|
451 |
gr.update(value=None), # point_cloud_editor
|
452 |
gr.update(value=None), # pc_download
|
453 |
gr.update(visible=False), # regenerate_btn
|
454 |
+
gr.update(value=None, visible=False), # download_all_btn
|
455 |
)
|
456 |
|
457 |
|
|
|
463 |
None, # background_remove_state
|
464 |
gr.update(value=None, visible=False), # preview_removal
|
465 |
gr.update(value=None, visible=False), # output_3d
|
466 |
+
gr.update(value=None, visible=False), # hdr_row
|
467 |
+
None, # hdr_file
|
468 |
gr.update(visible=False), # point_cloud_row
|
469 |
gr.update(value=None), # point_cloud_editor
|
470 |
gr.update(value=None), # pc_download
|
471 |
gr.update(visible=False), # regenerate_btn
|
472 |
+
gr.update(value=None, visible=False), # download_all_btn
|
473 |
)
|
474 |
alpha_channel = np.array(image.getchannel("A"))
|
475 |
min_alpha = alpha_channel.min()
|
|
|
486 |
gr.update(value=show_mask_img(fr_res), visible=True), # preview_removal
|
487 |
gr.update(value=None, visible=False), # output_3d
|
488 |
gr.update(visible=False), # hdr_row
|
489 |
+
None, # hdr_file
|
490 |
gr.update(visible=False), # point_cloud_row
|
491 |
gr.update(value=None), # point_cloud_editor
|
492 |
gr.update(value=None), # pc_download
|
493 |
gr.update(visible=False), # regenerate_btn
|
494 |
+
gr.update(value=None, visible=False), # download_all_btn
|
495 |
)
|
496 |
return (
|
497 |
gr.update(value="Remove Background", visible=True), # run_Btn
|
|
|
500 |
gr.update(value=None, visible=False), # preview_removal
|
501 |
gr.update(value=None, visible=False), # output_3d
|
502 |
gr.update(visible=False), # hdr_row
|
503 |
+
None, # hdr_file
|
504 |
gr.update(visible=False), # point_cloud_row
|
505 |
gr.update(value=None), # point_cloud_editor
|
506 |
gr.update(value=None), # pc_download
|
507 |
gr.update(visible=False), # regenerate_btn
|
508 |
+
gr.update(value=None, visible=False), # download_all_btn
|
509 |
)
|
510 |
|
511 |
|
|
|
531 |
with gr.Blocks() as demo:
|
532 |
img_proc_state = gr.State()
|
533 |
background_remove_state = gr.State()
|
534 |
+
hdr_illumination_file_state = gr.State()
|
535 |
gr.Markdown(
|
536 |
"""
|
537 |
# SPAR3D: Stable Point-Aware Reconstruction of 3D Objects from Single Images
|
|
|
744 |
inputs=hdr_illumination_file,
|
745 |
)
|
746 |
|
747 |
+
def update_hdr_illumination_file(state, cur_update):
|
748 |
+
# If the current value of hdr_illumination_file is the same as cur_update, then we don't need to update
|
749 |
+
if (
|
750 |
+
hdr_illumination_file.value is not None
|
751 |
+
and hdr_illumination_file.value == cur_update
|
752 |
+
):
|
753 |
+
return (
|
754 |
+
gr.update(),
|
755 |
+
gr.update(),
|
756 |
+
)
|
757 |
+
update_value = cur_update if cur_update is not None else state
|
758 |
+
if update_value is not None:
|
759 |
+
return (
|
760 |
+
gr.update(value=update_value),
|
761 |
+
gr.update(
|
762 |
+
env_map=(
|
763 |
+
update_value.name
|
764 |
+
if isinstance(update_value, gr.File)
|
765 |
+
else update_value
|
766 |
+
)
|
767 |
+
),
|
768 |
+
)
|
769 |
+
return (gr.update(value=None), gr.update(env_map=None))
|
770 |
+
|
771 |
hdr_illumination_file.change(
|
772 |
+
update_hdr_illumination_file,
|
773 |
+
inputs=[hdr_illumination_file_state, hdr_illumination_file],
|
774 |
+
outputs=[hdr_illumination_file, output_3d],
|
775 |
)
|
776 |
|
777 |
+
download_all_btn = gr.File(
|
778 |
+
label="Download All Files (ZIP)", file_count="single", visible=False
|
779 |
+
)
|
780 |
+
|
781 |
+
hdr_illumination_file_state.change(
|
782 |
+
fn=lambda x: gr.update(value=x),
|
783 |
+
inputs=hdr_illumination_file_state,
|
784 |
+
outputs=hdr_illumination_file,
|
785 |
+
)
|
786 |
+
|
787 |
examples = gr.Examples(
|
788 |
examples=example_files, inputs=input_img, examples_per_page=11
|
789 |
)
|
|
|
798 |
preview_removal,
|
799 |
output_3d,
|
800 |
hdr_row,
|
801 |
+
hdr_illumination_file_state,
|
802 |
point_cloud_row,
|
803 |
point_cloud_editor,
|
804 |
pc_download,
|
805 |
regenerate_btn,
|
806 |
+
download_all_btn,
|
807 |
],
|
808 |
)
|
809 |
|
|
|
832 |
preview_removal,
|
833 |
output_3d,
|
834 |
hdr_row,
|
835 |
+
hdr_illumination_file_state,
|
836 |
point_cloud_row,
|
837 |
point_cloud_editor,
|
838 |
pc_download,
|
839 |
regenerate_btn,
|
840 |
+
download_all_btn,
|
841 |
],
|
842 |
)
|
843 |
|
|
|
865 |
preview_removal,
|
866 |
output_3d,
|
867 |
hdr_row,
|
868 |
+
hdr_illumination_file_state,
|
869 |
point_cloud_row,
|
870 |
point_cloud_editor,
|
871 |
pc_download,
|
872 |
regenerate_btn,
|
873 |
+
download_all_btn,
|
874 |
],
|
875 |
)
|
876 |
|
877 |
+
demo.queue().launch(share=False)
|
requirements.txt
CHANGED
@@ -16,6 +16,7 @@ transparent-background==1.3.3
|
|
16 |
gradio==4.43.0
|
17 |
gradio-litmodel3d==0.0.1
|
18 |
gradio-pointcloudeditor==0.0.9
|
|
|
19 |
gpytoolbox==0.2.0
|
20 |
# ./texture_baker/
|
21 |
# ./uv_unwrapper/
|
|
|
16 |
gradio==4.43.0
|
17 |
gradio-litmodel3d==0.0.1
|
18 |
gradio-pointcloudeditor==0.0.9
|
19 |
+
opencv-python==4.10.0.84
|
20 |
gpytoolbox==0.2.0
|
21 |
# ./texture_baker/
|
22 |
# ./uv_unwrapper/
|
run.py
CHANGED
@@ -32,9 +32,9 @@ if __name__ == "__main__":
|
|
32 |
)
|
33 |
parser.add_argument(
|
34 |
"--pretrained-model",
|
35 |
-
default="stabilityai/
|
36 |
type=str,
|
37 |
-
help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/
|
38 |
)
|
39 |
parser.add_argument(
|
40 |
"--foreground-ratio",
|
|
|
32 |
)
|
33 |
parser.add_argument(
|
34 |
"--pretrained-model",
|
35 |
+
default="stabilityai/stable-point-aware-3d",
|
36 |
type=str,
|
37 |
+
help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/stable-point-aware-3d'",
|
38 |
)
|
39 |
parser.add_argument(
|
40 |
"--foreground-ratio",
|
spar3d/models/global_estimator/reni_estimator.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from dataclasses import dataclass, field
|
2 |
-
from typing import Any
|
3 |
|
4 |
import torch
|
5 |
import torch.nn as nn
|
@@ -95,6 +95,7 @@ class ReniLatentCodeEstimator(BaseModule):
|
|
95 |
def forward(
|
96 |
self,
|
97 |
triplane: Float[Tensor, "B 3 F Ht Wt"],
|
|
|
98 |
) -> dict[str, Any]:
|
99 |
x = self.layers(
|
100 |
triplane.reshape(
|
@@ -104,9 +105,12 @@ class ReniLatentCodeEstimator(BaseModule):
|
|
104 |
x = x.mean(dim=[-2, -1])
|
105 |
|
106 |
latents = self.fc_latents(x).reshape(-1, self.latent_dim, 3)
|
107 |
-
rotations = self.fc_rotations(x)
|
108 |
scale = self.fc_scale(x)
|
109 |
|
110 |
-
|
|
|
|
|
|
|
111 |
|
112 |
return {"illumination": env_map["rgb"]}
|
|
|
1 |
from dataclasses import dataclass, field
|
2 |
+
from typing import Any, Optional
|
3 |
|
4 |
import torch
|
5 |
import torch.nn as nn
|
|
|
95 |
def forward(
|
96 |
self,
|
97 |
triplane: Float[Tensor, "B 3 F Ht Wt"],
|
98 |
+
rotation: Optional[Float[Tensor, "B 3 3"]] = None,
|
99 |
) -> dict[str, Any]:
|
100 |
x = self.layers(
|
101 |
triplane.reshape(
|
|
|
105 |
x = x.mean(dim=[-2, -1])
|
106 |
|
107 |
latents = self.fc_latents(x).reshape(-1, self.latent_dim, 3)
|
108 |
+
rotations = rotation_6d_to_matrix(self.fc_rotations(x))
|
109 |
scale = self.fc_scale(x)
|
110 |
|
111 |
+
if rotation is not None:
|
112 |
+
rotations = rotations @ rotation.to(dtype=rotations.dtype)
|
113 |
+
|
114 |
+
env_map = self.reni_env_map(latents, rotations, scale)
|
115 |
|
116 |
return {"illumination": env_map["rgb"]}
|
spar3d/system.py
CHANGED
@@ -506,6 +506,11 @@ class SPAR3D(BaseModule):
|
|
506 |
|
507 |
scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)
|
508 |
|
|
|
|
|
|
|
|
|
|
|
509 |
global_dict = {}
|
510 |
if self.image_estimator is not None:
|
511 |
global_dict.update(
|
@@ -514,7 +519,14 @@ class SPAR3D(BaseModule):
|
|
514 |
)
|
515 |
)
|
516 |
if self.global_estimator is not None and estimate_illumination:
|
517 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
518 |
|
519 |
global_dict["pointcloud"] = batch["pc_cond"]
|
520 |
|
@@ -700,15 +712,7 @@ class SPAR3D(BaseModule):
|
|
700 |
uv=uvs, material=material
|
701 |
),
|
702 |
)
|
703 |
-
|
704 |
-
np.radians(-90), [1, 0, 0]
|
705 |
-
)
|
706 |
-
tmesh.apply_transform(rot)
|
707 |
-
tmesh.apply_transform(
|
708 |
-
trimesh.transformations.rotation_matrix(
|
709 |
-
np.radians(90), [0, 1, 0]
|
710 |
-
)
|
711 |
-
)
|
712 |
|
713 |
tmesh.invert()
|
714 |
|
|
|
506 |
|
507 |
scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)
|
508 |
|
509 |
+
# Create a rotation matrix for the final output domain
|
510 |
+
rotation = trimesh.transformations.rotation_matrix(np.radians(-90), [1, 0, 0])
|
511 |
+
rotation2 = trimesh.transformations.rotation_matrix(np.radians(90), [0, 1, 0])
|
512 |
+
output_rotation = rotation2 @ rotation
|
513 |
+
|
514 |
global_dict = {}
|
515 |
if self.image_estimator is not None:
|
516 |
global_dict.update(
|
|
|
519 |
)
|
520 |
)
|
521 |
if self.global_estimator is not None and estimate_illumination:
|
522 |
+
rotation_torch = (
|
523 |
+
torch.tensor(output_rotation)
|
524 |
+
.to(self.device, dtype=torch.float32)[:3, :3]
|
525 |
+
.unsqueeze(0)
|
526 |
+
)
|
527 |
+
global_dict.update(
|
528 |
+
self.global_estimator(non_postprocessed_codes, rotation=rotation_torch)
|
529 |
+
)
|
530 |
|
531 |
global_dict["pointcloud"] = batch["pc_cond"]
|
532 |
|
|
|
712 |
uv=uvs, material=material
|
713 |
),
|
714 |
)
|
715 |
+
tmesh.apply_transform(output_rotation)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
716 |
|
717 |
tmesh.invert()
|
718 |
|