# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------

import functools
import os

# NOTE(review): `spaces` is kept imported before `torch`, as in the original
# ordering; presumably required by HF Spaces GPU handling — confirm before moving.
import spaces
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from PIL import Image
from scipy.ndimage import maximum_filter

from marigold_dc import MarigoldDepthCompletionPipeline
from gradio_imageslider import ImageSlider
from huggingface_hub import login

# Forwarded verbatim to the pipeline as `dry_run=`.
DRY_RUN = False


def dilate_rgb_image(image, kernel_size):
    """Dilate every channel of an (H, W, C) image with a maximum filter.

    Each channel is filtered independently with
    ``scipy.ndimage.maximum_filter`` so values never mix across channels. The
    app uses this to enlarge isolated sparse-depth samples for display.

    Args:
        image: Array-like of shape (H, W, C).
        kernel_size: ``size`` forwarded to ``maximum_filter`` for each 2-D
            channel (an int or a pair of ints).

    Returns:
        Array of shape (H, W, C) with the same dtype as ``image``.
    """
    image = np.asarray(image)
    dilated_channels = [
        maximum_filter(image[..., c], size=kernel_size)
        for c in range(image.shape[-1])
    ]
    return np.stack(dilated_channels, axis=-1)


def generate_rmse_plot(steps, metrics, denoise_steps):
    """Build a log-scale line plot of RMSE versus denoising step.

    Args:
        steps: x values (step indices).
        metrics: RMSE values, same length as ``steps``; must be non-empty.
        denoise_steps: Number of denoising steps chosen by the user; sets the
            x-axis extent and tick spacing.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    y_min = min(metrics)
    y_max = max(metrics)

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=steps,
            y=metrics,
            mode="lines+markers",
            line=dict(color="#af2928"),
            name="RMSE",
        )
    )

    # Ticks on every step for short runs, every 5th step otherwise.
    x_dtick = 1 if denoise_steps < 20 else 5

    fig.update_layout(
        autosize=False,
        height=300,
        xaxis_title="Steps",
        xaxis_range=[0, denoise_steps + 1],
        xaxis=dict(
            scaleanchor="y",
            scaleratio=1.5,
            dtick=x_dtick,
        ),
        yaxis_title="RMSE",
        # Bounds are given as log10 values to match the log-type y-axis; the
        # lower bound is clamped to 0.1 so log10 stays finite.
        yaxis_range=[np.log10(max(y_min - 0.1, 0.1)), np.log10(y_max + 1)],
        yaxis=dict(
            type="log",
        ),
        hovermode="x unified",
        template="plotly_white",
    )

    return fig


def _render_depth_pair(pipe, pred, sparse_depth, val_min, val_max):
    """Colorize the sparse input and the prediction on a shared depth range.

    Pixels without a sparse sample (``sparse_depth <= 0``) are painted black,
    and the remaining samples are dilated so they remain visible.

    Returns:
        ``(vis_sparse, vis_pred)``. ``vis_sparse`` is a PIL image; ``vis_pred``
        is the first element returned by ``visualize_depth`` (presumably a
        PIL image as well).
    """
    vis_pred = pipe.image_processor.visualize_depth(
        pred, val_min=val_min, val_max=val_max
    )[0]
    vis_sparse = pipe.image_processor.visualize_depth(
        sparse_depth, val_min=val_min, val_max=val_max
    )[0]
    vis_sparse = np.array(vis_sparse)
    vis_sparse[sparse_depth <= 0] = (0, 0, 0)
    vis_sparse = dilate_rgb_image(vis_sparse, kernel_size=5)
    return Image.fromarray(vis_sparse), vis_pred


def process(
    pipe,
    path_image,
    path_sparse,
    denoise_steps,
):
    """Run depth completion and stream intermediate results.

    Args:
        pipe: A ``MarigoldDepthCompletionPipeline``.
        path_image: Path to the input image.
        path_sparse: Path to a ``.npy`` sparse depth map; entries ``<= 0`` are
            treated as missing.
        denoise_steps: Number of denoising steps; the pipeline is called with
            ``denoise_steps + 1`` inference steps.

    Yields:
        ``([vis_sparse, vis_pred], plot)`` after every pipeline step, where
        ``plot`` shows the RMSE of all steps so far.

    Raises:
        gr.Error: If the sparse depth map has no positive (valid) entries.
    """
    # The file is a user upload: refuse pickled object arrays explicitly.
    sparse_depth = np.load(path_sparse, allow_pickle=False)
    sparse_depth_valid = sparse_depth[sparse_depth > 0]
    if sparse_depth_valid.size == 0:
        raise gr.Error("The sparse depth map contains no valid (positive) values.")
    sparse_depth_min = np.min(sparse_depth_valid)
    sparse_depth_max = np.max(sparse_depth_valid)

    # Open the image once: use it for its size and as the pipeline input, and
    # close the file handle when the generator finishes.
    with Image.open(path_image) as image:
        width, height = image.size
        # Large inputs are processed at 768; otherwise 0 is passed
        # (presumably "native resolution" — confirm in the pipeline).
        processing_resolution = 768 if max(width, height) > 768 else 0

        metrics = []
        steps = []
        for step, (pred, rmse) in enumerate(
            pipe(
                image=image,
                sparse_depth=sparse_depth,
                num_inference_steps=denoise_steps + 1,
                processing_resolution=processing_resolution,
                dry_run=DRY_RUN,
            )
        ):
            # Shared colour range spanning both the sparse input and the
            # current prediction, so the two visualizations are comparable.
            min_both = min(sparse_depth_min, pred.min().item())
            max_both = max(sparse_depth_max, pred.max().item())
            metrics.append(rmse)
            steps.append(step)

            vis_sparse, vis_pred = _render_depth_pair(
                pipe, pred, sparse_depth, min_both, max_both
            )
            plot = generate_rmse_plot(steps, metrics, denoise_steps)

            yield (
                [vis_sparse, vis_pred],
                plot,
            )


def run_demo_server(pipe):
    """Build the Gradio UI around ``pipe`` and launch it on 0.0.0.0:7860."""
    # spaces.GPU wraps the bound generator (presumably so each call gets a GPU
    # on HF Spaces).
    process_pipe = spaces.GPU(functools.partial(process, pipe))
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    # NOTE(review): `.viewport` below is a class selector, but the plot sets
    # `elem_id="viewport"`, so the rule likely never matches; switch to
    # `#viewport` or `elem_classes` if the 4/3 aspect ratio is intended.
    with gr.Blocks(
        analytics_enabled=False,
        title="Marigold Depth Completion",
        css="""
            #short {
                height: 130px;
            }
            .slider .inner {
                width: 4px;
                background: #FFF;
            }
            .slider .icon-wrap svg {
                fill: #FFF;
                stroke: #FFF;
                stroke-width: 3px;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
        """,
    ) as demo:
        # NOTE(review): this header text looks like it lost its HTML markup
        # (heading, badge links); restore from upstream if so.
        gr.HTML(
            """

⇆ Marigold-DC: Zero-Shot Monocular Depth Completion with Guided Diffusion

Website Badge arXiv Badge badge-github-stars social
Start exploring the interactive examples at the bottom of the page!

"""
        )

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                )
                input_sparse = gr.File(
                    label="Input sparse depth (numpy file)",
                    elem_id="short",
                )
                with gr.Accordion("Advanced options", open=False):
                    denoise_steps = gr.Slider(
                        label="Number of denoising steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=50,
                    )
                with gr.Row():
                    submit_btn = gr.Button(value="Compute Depth", variant="primary")
                    clear_btn = gr.Button(value="Clear")
            with gr.Column():
                output_slider = ImageSlider(
                    label="Completed depth (red-near, blue-far)",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                )
                plot = gr.Plot(
                    label="RMSE between input and result",
                    elem_id="viewport",
                )

        inputs = [
            input_image,
            input_sparse,
            denoise_steps,
        ]
        outputs = [
            output_slider,
            plot,
        ]

        def submit_depth_fn(path_image, path_sparse, denoise_steps):
            # Validate before invoking the GPU-wrapped pipeline so an
            # incomplete form fails fast with a readable message.
            if path_image is None or path_sparse is None:
                raise gr.Error(
                    "Please provide both an input image and a sparse depth file."
                )
            yield from process_pipe(path_image, path_sparse, denoise_steps)

        submit_btn.click(
            fn=submit_depth_fn,
            inputs=inputs,
            outputs=outputs,
        )

        gr.Examples(
            fn=submit_depth_fn,
            examples=[
                [
                    "files/kitti_1.png",
                    "files/kitti_1.npy",
                    10,  # denoise_steps
                ],
                [
                    "files/kitti_2.png",
                    "files/kitti_2.npy",
                    10,  # denoise_steps
                ],
                [
                    "files/teaser.png",
                    "files/teaser_1000.npy",
                    10,  # denoise_steps
                ],
                [
                    "files/teaser.png",
                    "files/teaser_100.npy",
                    10,  # denoise_steps
                ],
                [
                    "files/teaser.png",
                    "files/teaser_10.npy",
                    10,  # denoise_steps
                ],
            ],
            inputs=inputs,
            outputs=outputs,
            cache_examples="lazy",
        )

        def clear_fn():
            # Reset both inputs and the slider output (the plot is left as is).
            return [
                gr.Image(value=None, interactive=True),
                gr.File(None, interactive=True),
                None,
            ]

        clear_btn.click(
            fn=clear_fn,
            inputs=[],
            outputs=[
                input_image,
                input_sparse,
                output_slider,
            ],
        )

    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


def main():
    """Load the Marigold checkpoint, move it to the best device and serve the demo."""
    CHECKPOINT = "prs-eth/marigold-depth-v1-0"

    # Print the installed packages to the log for debugging.
    os.system("pip freeze")

    if "HF_TOKEN_LOGIN" in os.environ:
        login(token=os.environ["HF_TOKEN_LOGIN"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pipe = MarigoldDepthCompletionPipeline.from_pretrained(CHECKPOINT)
    # xformers is optional: enable memory-efficient attention when it can be
    # imported and enabled, otherwise keep going with the default attention.
    try:
        import xformers  # noqa: F401  (imported only to probe availability)

        pipe.enable_xformers_memory_efficient_attention()
    except Exception as e:
        print(f"Running without xformers: {e}")
    pipe = pipe.to(device)

    run_demo_server(pipe)


if __name__ == "__main__":
    main()