# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # -------------------------------------------------------------------------- # If you find this code useful, we kindly ask you to cite our paper in your work. # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation # More information about the method can be found at https://marigoldmonodepth.github.io # -------------------------------------------------------------------------- import functools import os import spaces import gradio as gr import numpy as np import plotly.graph_objects as go import torch as torch from PIL import Image from scipy.ndimage import maximum_filter from marigold_dc import MarigoldDepthCompletionPipeline from gradio_imageslider import ImageSlider from huggingface_hub import login DRY_RUN = False def dilate_rgb_image(image, kernel_size): r_channel, g_channel, b_channel = image[..., 0], image[..., 1], image[..., 2] r_dilated = maximum_filter(r_channel, size=kernel_size) g_dilated = maximum_filter(g_channel, size=kernel_size) b_dilated = maximum_filter(b_channel, size=kernel_size) dilated_image = np.stack([r_dilated, g_dilated, b_dilated], axis=-1) return dilated_image def generate_rmse_plot(steps, metrics, denoise_steps): y_min = min(metrics) y_max = max(metrics) fig = go.Figure() fig.add_trace( go.Scatter( x=steps, y=metrics, mode="lines+markers", line=dict(color="#af2928"), name="RMSE", ) ) if denoise_steps < 20: x_dtick = 1 else: x_dtick = 5 fig.update_layout( autosize=False, height=300, xaxis_title="Steps", xaxis_range=[0, denoise_steps + 1], xaxis=dict( scaleanchor="y", scaleratio=1.5, dtick=x_dtick, ), yaxis_title="RMSE", yaxis_range=[np.log10(max(y_min - 0.1, 0.1)), np.log10(y_max + 1)], yaxis=dict( type="log", ), hovermode="x unified", template="plotly_white", ) return fig def process( pipe, path_image, path_sparse, denoise_steps, ): image = Image.open(path_image) sparse_depth = np.load(path_sparse) sparse_depth_valid = sparse_depth[sparse_depth > 0] sparse_depth_min = np.min(sparse_depth_valid) sparse_depth_max = np.max(sparse_depth_valid) width, height = image.size max_dim = max(width, height) processing_resolution = 0 if max_dim > 768: processing_resolution = 768 metrics = [] steps = [] for step, (pred, rmse) in enumerate( pipe( image=Image.open(path_image), sparse_depth=sparse_depth, num_inference_steps=denoise_steps + 1, processing_resolution=processing_resolution, dry_run=DRY_RUN, ) ): min_both = min(sparse_depth_min, pred.min().item()) max_both = min(sparse_depth_max, pred.max().item()) metrics.append(rmse) steps.append(step) vis_pred = pipe.image_processor.visualize_depth( pred, val_min=min_both, val_max=max_both )[0] vis_sparse = pipe.image_processor.visualize_depth( sparse_depth, val_min=min_both, val_max=max_both )[0] vis_sparse = np.array(vis_sparse) vis_sparse[sparse_depth <= 0] = (0, 0, 0) vis_sparse = dilate_rgb_image(vis_sparse, kernel_size=5) vis_sparse = Image.fromarray(vis_sparse) plot = generate_rmse_plot(steps, metrics, denoise_steps) yield ( [vis_sparse, vis_pred], plot, ) def run_demo_server(pipe): process_pipe = spaces.GPU(functools.partial(process, pipe)) os.environ["GRADIO_ALLOW_FLAGGING"] = "never" with gr.Blocks( analytics_enabled=False, title="Marigold Depth Completion", css=""" #short { height: 130px; } .slider .inner { width: 4px; background: #FFF; } .slider .icon-wrap svg { fill: #FFF; stroke: #FFF; stroke-width: 3px; } .viewport { aspect-ratio: 4/3; } h1 { text-align: center; display: block; } h2 { text-align: center; display: block; } h3 { text-align: center; display: block; } """, ) as demo: gr.HTML( """
Start exploring the interactive examples at the bottom of the page!