# gs-dynamics / app.py
# Hugging Face Space source; last commit b05869c ("update") by kaifz.
import gradio as gr
import torch
import numpy as np
import os
from PIL import Image
class DynamicsVisualizer:
    """Gradio viewer for precomputed Gaussian-Splat dynamics results.

    Assets are read from disk rather than computed on request:

    * ``data/<example_id>/img_{0..3}.png`` and ``gs_orig.splat`` -- the
      initial state of an object, loaded by :meth:`load_example`.
    * ``data/<example_id>/action-<action_id>/img_{0..3}.png``,
      ``video_{0..3}.mp4`` and ``gs_pred.splat`` -- one action applied to that
      object, loaded by :meth:`load_action`.

    ``reset`` and the ``on_click_*`` methods are Gradio callbacks; each returns
    freshly built components that replace the corresponding ones on the page.
    """

    # Number of camera views stored per example/action (img_0..3, video_0..3).
    NUM_CAMERAS = 4
    ORIG_SPLAT_LABEL = 'Original Gaussian Splats'
    PRED_SPLAT_LABEL = 'Predicted Gaussian Splats'

    def __init__(self):
        device = torch.device("cpu")
        self.device = device
        self.width = 640
        self.height = 480
        self.vis_cam_id = 1
        self.bg_id = 0  # 0: black, 1: white
        self.imgs = None
        self.gs_orig = None
        self.gs_pred = None
        self.actions = None
        self.videos = None
        self.example_name = None
        self.action_name = None
        self.form_image_is_set = False
        self.form_video_is_set = False
        self.form_3dgs_orig_is_set = False
        self.form_3dgs_pred_is_set = False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _clear_color(self):
        """Return the RGBA background colour for the splat viewers (from ``bg_id``)."""
        return [self.bg_id, self.bg_id, self.bg_id, 0]

    def _image_component(self, value=None):
        """Build the 'Initial state and actions' image component."""
        return gr.Image(label='Initial state and actions', value=value,
                        width=self.width, height=self.height)

    def _video_component(self, value=None):
        """Build the 'Predicted video' component."""
        return gr.Video(label='Predicted video', value=value,
                        width=self.width, height=self.height)

    def _splat_component(self, label, value=None):
        """Build a Model3D splat viewer; a background colour is set only when a file is shown."""
        if value is None:
            return gr.Model3D(label=label, value=None)
        return gr.Model3D(label=label, value=value, clear_color=self._clear_color())

    def _clear_prediction(self):
        """Forget the selected action and everything derived from it."""
        self.action_name = None
        self.videos = None
        self.gs_pred = None
        self.form_video_is_set = False
        self.form_3dgs_pred_is_set = False

    def _load_images(self, directory):
        """Load ``img_0.png`` .. ``img_{NUM_CAMERAS-1}.png`` from *directory* into memory."""
        images = []
        for cam_id in range(self.NUM_CAMERAS):
            # Image.open is lazy and keeps the file open; copy() pulls the pixels
            # into memory so the handle is closed when the `with` block exits.
            with Image.open(os.path.join(directory, f'img_{cam_id}.png')) as img:
                images.append(img.copy())
        return images

    # ------------------------------------------------------------------
    # Asset loading
    # ------------------------------------------------------------------
    def load_example(self):
        """Load the per-camera images and original splat path for ``self.example_name``."""
        example_path = os.path.join('data', self.example_name)
        self.imgs = self._load_images(example_path)
        self.gs_orig = os.path.join(example_path, 'gs_orig.splat')

    def load_action(self):
        """Load per-camera images, video paths and predicted splat path for ``self.action_name``."""
        action_path = os.path.join('data', self.action_name)
        self.imgs = self._load_images(action_path)
        self.videos = [os.path.join(action_path, f'video_{i}.mp4')
                       for i in range(self.NUM_CAMERAS)]
        self.gs_pred = os.path.join(action_path, 'gs_pred.splat')

    # ------------------------------------------------------------------
    # Gradio callbacks
    # ------------------------------------------------------------------
    def reset(self):
        """Clear all state and return empty image, video and both splat viewers."""
        self.imgs = None
        self.gs_orig = None
        self.actions = None
        self.vis_cam_id = 1
        self.bg_id = 0  # 0: black, 1: white
        self.example_name = None
        self._clear_prediction()
        self.form_image_is_set = False
        self.form_3dgs_orig_is_set = False
        return (self._image_component(),
                self._video_component(),
                self._splat_component(self.ORIG_SPLAT_LABEL),
                self._splat_component(self.PRED_SPLAT_LABEL))

    def on_click_set_example(self, state):
        """Select the object ``state['example_id']`` and show its initial state.

        Returns (image, video, original splat, predicted splat); the video and
        predicted splat are emptied because no action has been chosen yet.
        """
        self.example_name = f"{int(state['example_id'])}"
        # A new object invalidates any previously chosen action; without this,
        # Run would replay the previous object's precomputed prediction.
        self._clear_prediction()
        self.load_example()
        self.form_image_is_set = True
        self.form_3dgs_orig_is_set = True
        return (self._image_component(self.imgs[self.vis_cam_id]),
                self._video_component(),
                self._splat_component(self.ORIG_SPLAT_LABEL, self.gs_orig),
                self._splat_component(self.PRED_SPLAT_LABEL))

    def on_click_set_action(self, state):
        """Select action ``state['action_id']`` for the current object and show its image.

        Raises:
            gr.Error: if no object has been selected yet.
        """
        if self.example_name is None:
            raise gr.Error('Please select an object (Step 1) before choosing an action.')
        self.action_name = f"{self.example_name}/action-{int(state['action_id'])}"
        self.load_action()
        self.form_image_is_set = True
        return self._image_component(self.imgs[self.vis_cam_id])

    def on_click_run(self):
        """Show the precomputed video and predicted splat for the selected action.

        Raises:
            gr.Error: if no action has been selected yet.
        """
        if self.videos is None or self.gs_pred is None:
            raise gr.Error('Please select an object and an action (Steps 1-2) before clicking Run.')
        self.form_video_is_set = True
        self.form_3dgs_pred_is_set = True
        return (self._video_component(self.videos[self.vis_cam_id]),
                self._splat_component(self.PRED_SPLAT_LABEL, self.gs_pred))

    def on_click_change_view(self, state):
        """Switch to camera ``state['view_id']`` and return (image, video) for it.

        The camera choice is remembered even before anything is loaded; in that
        case empty components are returned instead of failing.
        """
        self.vis_cam_id = int(state['view_id'])
        image = self.imgs[self.vis_cam_id] if self.imgs is not None else None
        video = None
        if self.form_video_is_set and self.videos is not None:
            video = self.videos[self.vis_cam_id]
        return self._image_component(image), self._video_component(video)

    # ------------------------------------------------------------------
    # UI
    # ------------------------------------------------------------------
    def launch(self, share=False):
        """Build the Gradio Blocks UI, wire the callbacks, and start the server.

        Args:
            share: forwarded unchanged to ``app.launch``.
        """
        with gr.Blocks() as app:
            with gr.Row():
                gr.Markdown("# Dynamic 3D Gaussian Tracking for Graph-Based Neural Dynamics Modeling")
            with gr.Row():
                gr.Markdown('Project page: [https://gs-dynamics.github.io/](https://gs-dynamics.github.io/)')
            # Empty rows used as vertical spacing.
            for _ in range(2):
                with gr.Row():
                    gr.Markdown()

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Step 0**: click **Clear All** to clear all window and reset the visualizer.")
                with gr.Column(scale=1):
                    run_reset = gr.Button('Clear All')
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Step 1**: select the object.")
                example_buttons = []
                for label in ('Rope', 'Rope - Long', 'Toy Animal'):
                    with gr.Column(scale=1):
                        example_buttons.append(gr.Button(label))
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Step 2**: select the action.")
                action_buttons = []
                for label in ('Action 1', 'Action 2', 'Action 3'):
                    with gr.Column(scale=1):
                        action_buttons.append(gr.Button(label))
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**Step 3**: click **Run** to visualize the predicted video and Splats.")
                with gr.Column(scale=1):
                    run_run = gr.Button('Run')

            with gr.Row():
                # Left sidebar: spacing, then the camera-view buttons.
                with gr.Column(scale=1, min_width=20):
                    for _ in range(4):
                        with gr.Row():
                            gr.Markdown()
                    with gr.Row():
                        gr.Markdown("Our model uses only 4 cameras for reconstructing the Gaussian Splats. Click the buttons below to change the view.")
                    view_buttons = []
                    for label in ('View 0', 'View 1', 'View 2', 'View 3'):
                        with gr.Row():
                            view_buttons.append(gr.Button(label))
                # Main area: image/video on top, original/predicted splats below.
                with gr.Column(scale=4):
                    with gr.Row():
                        with gr.Column(scale=2):
                            form_image = self._image_component()
                        with gr.Column(scale=2):
                            form_video = self._video_component()
                    # NOTE: a black/white background toggle driven by bg_id was
                    # prototyped here; bg_id is currently always 0 (black).
                    with gr.Row():
                        with gr.Column(scale=2):
                            form_3dgs_orig = self._splat_component(self.ORIG_SPLAT_LABEL)
                        with gr.Column(scale=2):
                            form_3dgs_pred = self._splat_component(self.PRED_SPLAT_LABEL)

            with gr.Row():
                gr.Markdown("## Notes:")
            with gr.Row():
                gr.Markdown("- Due to the computation constraints of Hugging Face Space, all results are precomputed. ")
            with gr.Row():
                gr.Markdown("- Training a GS for an object takes around 30 seconds. Prediction typically takes only 1-2 seconds for each push!")
            with gr.Row():
                gr.Markdown("- More examples may be added in the future. Stay tuned!")

            # Set up callbacks. Fixed ids are passed through gr.State inputs.
            all_outputs = [form_image, form_video, form_3dgs_orig, form_3dgs_pred]
            run_reset.click(self.reset, inputs=[], outputs=all_outputs)
            for example_id, button in enumerate(example_buttons):
                button.click(self.on_click_set_example,
                             inputs=[gr.State({'example_id': example_id})],
                             outputs=all_outputs)
            for action_id, button in enumerate(action_buttons):
                button.click(self.on_click_set_action,
                             inputs=[gr.State({'action_id': action_id})],
                             outputs=[form_image])
            run_run.click(self.on_click_run,
                          inputs=[],
                          outputs=[form_video, form_3dgs_pred])
            # Button 'View k' maps to camera (k + 1) % 4, as originally wired, so the
            # default vis_cam_id (1) is 'View 0'. NOTE(review): confirm offset is intended.
            for button, cam_id in zip(view_buttons, (1, 2, 3, 0)):
                button.click(self.on_click_change_view,
                             inputs=[gr.State({'view_id': cam_id})],
                             outputs=[form_image, form_video])
        app.launch(share=share)
if __name__ == '__main__':
    # Script entry point: build the viewer and serve it, forwarding share=True
    # to DynamicsVisualizer.launch.
    app_visualizer = DynamicsVisualizer()
    app_visualizer.launch(share=True)