Jose Benitez commited on
Commit
5bccfc0
·
1 Parent(s): 7f2e027

add video support

Browse files
Files changed (1) hide show
  1. app.py +84 -40
app.py CHANGED
@@ -26,7 +26,8 @@ css = """
26
  height: 62px;
27
  }
28
  """
29
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
30
  model_configs = {
31
  'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
32
  'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
@@ -55,48 +56,91 @@ Please refer to our [paper](https://arxiv.org/abs/2406.09414), [project page](ht
55
  def predict_depth(image):
56
  return model.infer_image(image)
57
 
58
- with gr.Blocks(css=css) as demo:
59
- gr.Markdown(title)
60
- gr.Markdown(description)
61
- gr.Markdown("### Depth Prediction demo")
62
-
63
- with gr.Row():
64
- input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
65
- depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
66
- submit = gr.Button(value="Compute Depth")
67
- gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
68
- raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download",)
69
-
70
- cmap = matplotlib.colormaps.get_cmap('Spectral_r')
71
-
72
- def on_submit(image):
73
- original_image = image.copy()
74
-
75
- h, w = image.shape[:2]
76
-
77
- depth = predict_depth(image[:, :, ::-1])
78
-
79
- raw_depth = Image.fromarray(depth.astype('uint16'))
80
- tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
81
- raw_depth.save(tmp_raw_depth.name)
82
-
83
  depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
84
  depth = depth.astype(np.uint8)
85
- colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
86
-
87
- gray_depth = Image.fromarray(depth)
88
- tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
89
- gray_depth.save(tmp_gray_depth.name)
90
-
91
- return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
92
-
93
- submit.click(on_submit, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file])
94
-
95
- example_files = os.listdir('assets/examples')
96
- example_files.sort()
97
- example_files = [os.path.join('assets/examples', filename) for filename in example_files]
98
- examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  if __name__ == '__main__':
102
  demo.queue().launch(share=True)
 
26
  height: 62px;
27
  }
28
  """
29
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
30
+
31
  model_configs = {
32
  'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
33
  'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
 
56
  def predict_depth(image):
57
  return model.infer_image(image)
58
 
59
+ def process_video(video_path):
60
+ input_size = 518
61
+ temp_output_path = tempfile.mktemp(suffix='.mp4')
62
+
63
+ raw_video = cv2.VideoCapture(video_path)
64
+ frame_width = int(raw_video.get(cv2.CAP_PROP_FRAME_WIDTH))
65
+ frame_height = int(raw_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
66
+ frame_rate = int(raw_video.get(cv2.CAP_PROP_FPS))
67
+
68
+ out = cv2.VideoWriter(temp_output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (frame_width, frame_height))
69
+
70
+ while raw_video.isOpened():
71
+ ret, raw_frame = raw_video.read()
72
+ if not ret:
73
+ break
74
+
75
+ depth = model.infer_image(raw_frame, input_size)
76
+
 
 
 
 
 
 
 
77
  depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
78
  depth = depth.astype(np.uint8)
79
+ colored_depth = (cmap(depth)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8)
80
+
81
+ out.write(colored_depth)
82
+
83
+ raw_video.release()
84
+ out.release()
85
+
86
+ return temp_output_path
 
 
 
 
 
 
87
 
88
+ with gr.Blocks(css=css) as demo:
89
+ gr.Markdown(title)
90
+ gr.Markdown(description)
91
+
92
+ with gr.Tabs():
93
+ with gr.TabItem("Image"):
94
+ gr.Markdown("### Depth Prediction demo")
95
+ with gr.Row():
96
+ input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
97
+ depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
98
+ submit = gr.Button(value="Compute Depth")
99
+ gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
100
+ raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download",)
101
+
102
+ cmap = matplotlib.colormaps.get_cmap('Spectral_r')
103
+
104
+ def on_submit(image):
105
+ original_image = image.copy()
106
+
107
+ h, w = image.shape[:2]
108
+
109
+ depth = predict_depth(image[:, :, ::-1])
110
+
111
+ raw_depth = Image.fromarray(depth.astype('uint16'))
112
+ tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
113
+ raw_depth.save(tmp_raw_depth.name)
114
+
115
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
116
+ depth = depth.astype(np.uint8)
117
+ colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
118
+
119
+ gray_depth = Image.fromarray(depth)
120
+ tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
121
+ gray_depth.save(tmp_gray_depth.name)
122
+
123
+ return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
124
+
125
+ submit.click(on_submit, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file])
126
+
127
+ example_files = os.listdir('assets/examples')
128
+ example_files.sort()
129
+ example_files = [os.path.join('assets/examples', filename) for filename in example_files]
130
+ examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
131
+
132
+ with gr.TabItem("Video"):
133
+ gr.Markdown("### Video Depth Prediction demo")
134
+ input_video = gr.Video(label="Input Video")
135
+ output_video = gr.Video(label="Output Video")
136
+ process_video_btn = gr.Button(value="Process Video")
137
+
138
+ process_video_btn.click(process_video, inputs=[input_video], outputs=[output_video])
139
+
140
+ example_files = os.listdir('assets/examples_video')
141
+ example_files.sort()
142
+ example_files = [os.path.join('assets/examples_video', filename) for filename in example_files]
143
+ examples = gr.Examples(examples=example_files, inputs=[input_video], outputs=[output_video], fn=process_video)
144
 
145
  if __name__ == '__main__':
146
  demo.queue().launch(share=True)