Spaces:
Running
Running
import gradio as gr | |
import torch | |
import numpy as np | |
import os | |
from PIL import Image | |
class DynamicsVisualizer:
    """Gradio front-end for browsing precomputed GS-Dynamics predictions.

    Nothing is computed at runtime; every result is read from disk under
    ``data_root``:

    * ``<data_root>/<example_id>/`` holds ``img_{cam}.png`` for each camera and
      ``gs_orig.splat`` (the original Gaussian Splats).
    * ``<data_root>/<example_id>/action-<action_id>/`` holds ``img_{cam}.png``,
      ``video_{cam}.mp4`` and ``gs_pred.splat`` (the predicted Gaussian Splats).

    NOTE(review): all UI state lives on this single instance and the callbacks
    are bound methods of it, so concurrent browser sessions share (and
    overwrite) each other's selections. Isolating users would require moving
    this state into per-session ``gr.State`` objects.
    """

    NUM_CAMS = 4  # files are indexed img_0..img_3 / video_0..video_3
    # Button label for example_id 0, 1, 2 (the folder name under data_root).
    EXAMPLE_LABELS = ('Rope', 'Rope - Long', 'Toy Animal')
    NUM_ACTIONS = 3  # buttons 'Action 1'..'Action 3' map to action_id 0..2
    # Camera id shown by the buttons 'View 0'..'View 3'. Camera 1, the default
    # vis_cam_id, is labelled 'View 0'.
    VIEW_BUTTON_CAMS = (1, 2, 3, 0)

    IMAGE_LABEL = 'Initial state and actions'
    VIDEO_LABEL = 'Predicted video'
    ORIG_SPLAT_LABEL = 'Original Gaussian Splats'
    PRED_SPLAT_LABEL = 'Predicted Gaussian Splats'

    def __init__(self, data_root='data'):
        """Create an empty visualizer.

        Args:
            data_root: directory containing the precomputed examples.
        """
        self.device = torch.device("cpu")  # not used by this class; kept for compatibility
        self.data_root = data_root
        self.width = 640
        self.height = 480
        self.vis_cam_id = 1  # camera shown in the image/video panes
        self.bg_id = 0  # 0: black, 1: white (used for the Model3D clear colour)
        self.imgs = None
        self.gs_orig = None
        self.gs_pred = None
        self.actions = None  # not read anywhere in this class
        self.videos = None
        self.example_name = None
        self.action_name = None
        # Which output panes currently hold content.
        self.form_image_is_set = False
        self.form_video_is_set = False
        self.form_3dgs_orig_is_set = False
        self.form_3dgs_pred_is_set = False

    # ------------------------------------------------------------------
    # Component factories: one place for labels, sizes and background.
    # ------------------------------------------------------------------
    def _image_component(self, value=None):
        """Build the initial-state image pane showing ``value``."""
        return gr.Image(label=self.IMAGE_LABEL, value=value, width=self.width, height=self.height)

    def _video_component(self, value=None):
        """Build the predicted-video pane showing ``value``."""
        return gr.Video(label=self.VIDEO_LABEL, value=value, width=self.width, height=self.height)

    def _splat_component(self, label, value=None):
        """Build a Model3D pane; the background colour is applied only when it has content."""
        if value is None:
            return gr.Model3D(label=label, value=None)
        return gr.Model3D(label=label, value=value,
                          clear_color=[self.bg_id, self.bg_id, self.bg_id, 0])

    # ------------------------------------------------------------------
    # Data loading
    # ------------------------------------------------------------------
    @staticmethod
    def _load_images(folder, num_cams=NUM_CAMS):
        """Read ``img_0.png`` .. ``img_{num_cams-1}.png`` from ``folder`` fully into memory."""
        images = []
        for cam in range(num_cams):
            with Image.open(os.path.join(folder, f'img_{cam}.png')) as img:
                # copy() forces the pixel data to be read so the file can be closed.
                images.append(img.copy())
        return images

    def load_example(self):
        """Load the per-camera images and original splat path for ``self.example_name``."""
        example_path = os.path.join(self.data_root, self.example_name)
        self.imgs = self._load_images(example_path, self.NUM_CAMS)
        self.gs_orig = os.path.join(example_path, 'gs_orig.splat')

    def load_action(self):
        """Load the images, predicted videos and predicted splat path for ``self.action_name``."""
        action_path = os.path.join(self.data_root, self.action_name)
        self.imgs = self._load_images(action_path, self.NUM_CAMS)
        self.videos = [os.path.join(action_path, f'video_{cam}.mp4') for cam in range(self.NUM_CAMS)]
        self.gs_pred = os.path.join(action_path, 'gs_pred.splat')

    def _clear_action_state(self):
        """Forget the selected action and its predictions (videos and predicted splats)."""
        self.action_name = None
        self.videos = None
        self.gs_pred = None
        self.form_video_is_set = False
        self.form_3dgs_pred_is_set = False

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------
    def reset(self):
        """Clear all state and return empty (image, video, original splats, predicted splats) panes."""
        self.imgs = None
        self.gs_orig = None
        self.actions = None
        self.vis_cam_id = 1
        self.bg_id = 0  # 0: black, 1: white
        self.example_name = None
        self._clear_action_state()
        self.form_image_is_set = False
        self.form_3dgs_orig_is_set = False
        return (self._image_component(),
                self._video_component(),
                self._splat_component(self.ORIG_SPLAT_LABEL),
                self._splat_component(self.PRED_SPLAT_LABEL))

    def on_click_set_example(self, state):
        """Select the object ``state['example_id']`` and show its initial state.

        Returns the (image, video, original splats, predicted splats) panes.
        The video and predicted panes are emptied.
        """
        self.example_name = f"{int(state['example_id'])}"
        self.load_example()
        # A new object invalidates any action/prediction chosen for the previous
        # one; otherwise Run would replay the previous object's results.
        self._clear_action_state()
        self.form_image_is_set = True
        self.form_3dgs_orig_is_set = True
        return (self._image_component(self.imgs[self.vis_cam_id]),
                self._video_component(),
                self._splat_component(self.ORIG_SPLAT_LABEL, self.gs_orig),
                self._splat_component(self.PRED_SPLAT_LABEL))

    def on_click_set_action(self, state):
        """Select action ``state['action_id']`` for the current object and show its image.

        Raises:
            gr.Error: if no object has been selected yet.
        """
        if self.example_name is None:
            raise gr.Error('Please select an object (Step 1) before choosing an action.')
        self.action_name = f"{self.example_name}/action-{int(state['action_id'])}"
        self.load_action()
        self.form_image_is_set = True
        return self._image_component(self.imgs[self.vis_cam_id])

    def on_click_run(self):
        """Show the precomputed video and predicted splats for the selected action.

        Raises:
            gr.Error: if no action has been selected yet.
        """
        if self.videos is None or self.gs_pred is None:
            raise gr.Error('Please select an object (Step 1) and an action (Step 2) before clicking Run.')
        self.form_video_is_set = True
        self.form_3dgs_pred_is_set = True
        return (self._video_component(self.videos[self.vis_cam_id]),
                self._splat_component(self.PRED_SPLAT_LABEL, self.gs_pred))

    def on_click_change_view(self, state):
        """Switch to camera ``state['view_id']`` and return the updated (image, video) panes.

        Panes with nothing loaded yet are returned empty instead of raising.
        """
        self.vis_cam_id = int(state['view_id'])
        image = self.imgs[self.vis_cam_id] if self.imgs is not None else None
        # Only reveal a video once Run has been clicked for the current action.
        video = self.videos[self.vis_cam_id] if self.form_video_is_set else None
        return self._image_component(image), self._video_component(video)

    # ------------------------------------------------------------------
    # UI
    # ------------------------------------------------------------------
    def launch(self, share=False):
        """Build the Gradio Blocks UI, wire the callbacks and start the server.

        Args:
            share: forwarded to ``app.launch`` (request a public share link).
        """
        with gr.Blocks() as app:
            with gr.Row():
                gr.Markdown("# Dynamic 3D Gaussian Tracking for Graph-Based Neural Dynamics Modeling")
            with gr.Row():
                gr.Markdown('Project page: [https://gs-dynamics.github.io/](https://gs-dynamics.github.io/)')
            for _ in range(2):  # vertical spacing
                with gr.Row():
                    gr.Markdown()

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Step 0**: click **Clear All** to clear all windows and reset the visualizer.")
                with gr.Column(scale=1):
                    run_reset = gr.Button('Clear All')
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Step 1**: select the object.")
                example_buttons = []
                for label in self.EXAMPLE_LABELS:
                    with gr.Column(scale=1):
                        example_buttons.append(gr.Button(label))
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Step 2**: select the action.")
                action_buttons = []
                for action_id in range(self.NUM_ACTIONS):
                    with gr.Column(scale=1):
                        action_buttons.append(gr.Button(f'Action {action_id + 1}'))
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Step 3**: click **Run** to visualize the predicted video and Splats.")
                with gr.Column(scale=1):
                    run_run = gr.Button('Run')

            with gr.Row():
                # Left: camera selection. Right: output panes.
                with gr.Column(scale=1, min_width=20):
                    for _ in range(4):  # vertical spacing
                        with gr.Row():
                            gr.Markdown()
                    with gr.Row():
                        gr.Markdown("Our model uses only 4 cameras for reconstructing the Gaussian Splats. Click the buttons below to change the view.")
                    view_buttons = []
                    for view_idx in range(len(self.VIEW_BUTTON_CAMS)):
                        with gr.Row():
                            view_buttons.append(gr.Button(f'View {view_idx}'))
                with gr.Column(scale=4):
                    with gr.Row():
                        with gr.Column(scale=2):
                            form_image = self._image_component()
                        with gr.Column(scale=2):
                            form_video = self._video_component()
                    with gr.Row():
                        with gr.Column(scale=2):
                            form_3dgs_orig = self._splat_component(self.ORIG_SPLAT_LABEL)
                        with gr.Column(scale=2):
                            form_3dgs_pred = self._splat_component(self.PRED_SPLAT_LABEL)

            with gr.Row():
                gr.Markdown("## Notes:")
            for note in (
                "- Due to the computation constraints of Hugging Face Space, all results are precomputed. ",
                "- Training a GS for an object takes around 30 seconds. Prediction typically takes only 1-2 seconds for each push!",
                "- More examples may be added in the future. Stay tuned!",
            ):
                with gr.Row():
                    gr.Markdown(note)

            # Callbacks. Each button gets its own gr.State carrying its id.
            all_panes = [form_image, form_video, form_3dgs_orig, form_3dgs_pred]
            run_reset.click(self.reset, inputs=[], outputs=all_panes)
            for example_id, button in enumerate(example_buttons):
                button.click(self.on_click_set_example,
                             inputs=[gr.State({'example_id': example_id})],
                             outputs=all_panes)
            for action_id, button in enumerate(action_buttons):
                button.click(self.on_click_set_action,
                             inputs=[gr.State({'action_id': action_id})],
                             outputs=[form_image])
            run_run.click(self.on_click_run,
                          inputs=[],
                          outputs=[form_video, form_3dgs_pred])
            for cam_id, button in zip(self.VIEW_BUTTON_CAMS, view_buttons):
                button.click(self.on_click_change_view,
                             inputs=[gr.State({'view_id': cam_id})],
                             outputs=[form_image, form_video])
        app.launch(share=share)
if __name__ == '__main__':
    # Script entry point: build the UI and start serving it (with a share link).
    DynamicsVisualizer().launch(share=True)