# Copyright 2024 Anton Obukhov and Kevin Qu, ETH Zurich. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # -------------------------------------------------------------------------- # If you find this code useful, we kindly ask you to cite our paper in your work. # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation # More information about the method can be found at https://marigoldcomputervision.github.io # -------------------------------------------------------------------------- from __future__ import annotations import functools import os import tempfile import warnings import spaces import gradio as gr import numpy as np import torch as torch from PIL import Image from gradio_imageslider import ImageSlider from huggingface_hub import login from gradio_patches.examples import Examples from gradio_patches.flagging import HuggingFaceDatasetSaver, FlagMethod from marigold_iid_appearance import MarigoldIIDAppearancePipeline from marigold_iid_lighting import MarigoldIIDLightingPipeline warnings.filterwarnings( "ignore", message=".*LoginButton created outside of a Blocks context.*" ) default_seed = 2024 default_image_denoise_steps = 4 default_image_ensemble_size = 1 default_image_processing_res = 768 default_image_reproducuble = True default_model_type = "appearance" default_share_always_show_hf_logout_btn = True default_share_always_show_accordion = False loaded_pipelines = {} # Cache to store loaded pipelines def process_with_loaded_pipeline( image_path, model_type=default_model_type, denoise_steps=default_image_denoise_steps, ensemble_size=default_image_ensemble_size, processing_res=default_image_processing_res, ): # Load and cache the pipeline based on the model type. if model_type not in loaded_pipelines.keys(): auth_token = os.environ.get("KEV_TOKEN") if model_type == "appearance": if "lighting" in loaded_pipelines.keys(): del loaded_pipelines[ "lighting" ] # to save GPU memory. Can be removed if enough memory is available for faster switching between models torch.cuda.empty_cache() loaded_pipelines[model_type] = ( MarigoldIIDAppearancePipeline.from_pretrained( "prs-eth/marigold-iid-appearance-v1-1", token=auth_token ) ) elif model_type == "lighting": if "appearance" in loaded_pipelines.keys(): del loaded_pipelines[ "appearance" ] # to save GPU memory. Can be removed if enough memory is available for faster switching between models torch.cuda.empty_cache() loaded_pipelines[model_type] = MarigoldIIDLightingPipeline.from_pretrained( "prs-eth/marigold-iid-lighting-v1-1", token=auth_token ) # Move the pipeline to GPU if available device = torch.device("cuda" if torch.cuda.is_available() else "cpu") loaded_pipelines[model_type] = loaded_pipelines[model_type].to(device) try: loaded_pipelines[model_type].enable_xformers_memory_efficient_attention() except: pass # run without xformers pipe = loaded_pipelines[model_type] # Process the image using the preloaded pipeline. return process_image( pipe=pipe, path_input=image_path, denoise_steps=denoise_steps, ensemble_size=ensemble_size, processing_res=processing_res, model_type=model_type, ) def process_image_check(path_input): if path_input is None: raise gr.Error( "Missing image in the first pane: upload a file or use one from the gallery below." ) def process_image( pipe, path_input, denoise_steps=default_image_denoise_steps, ensemble_size=default_image_ensemble_size, processing_res=default_image_processing_res, model_type=default_model_type, ): name_base, name_ext = os.path.splitext(os.path.basename(path_input)) print(f"Processing image {name_base}{name_ext}") path_output_dir = tempfile.mkdtemp() input_image = Image.open(path_input) pipe_out = pipe( input_image, denoising_steps=denoise_steps, ensemble_size=ensemble_size, processing_res=processing_res, batch_size=1 if processing_res == 0 else 0, # TODO: do we abuse "batch size" notation here? seed=default_seed, show_progress_bar=True, ) path_output_dir = os.path.splitext(path_input)[0] + "_output" os.makedirs(path_output_dir, exist_ok=True) if model_type == "appearance": path_albedo_out = os.path.join( path_output_dir, f"{name_base}_albedo_app_fp32.npy" ) path_albedo_out_vis = os.path.join( path_output_dir, f"{name_base}_albedo_app.png" ) path_material_out = os.path.join( path_output_dir, f"{name_base}_material_fp32.npy" ) path_material_out_vis = os.path.join( path_output_dir, f"{name_base}_material.png" ) albedo = pipe_out.albedo albedo_colored = pipe_out.albedo_colored material = pipe_out.material material_colored = pipe_out.material_colored np.save(path_albedo_out, albedo) albedo_colored.save(path_albedo_out_vis) np.save(path_material_out, material) material_colored.save(path_material_out_vis) return ( [path_input, path_albedo_out_vis], [path_input, path_material_out_vis], [path_input, path_material_out_vis], # placeholder which is not displayed [ path_albedo_out_vis, path_material_out_vis, path_albedo_out, path_material_out, ], ) elif model_type == "lighting": path_albedo_out = os.path.join( path_output_dir, f"{name_base}_albedo_res_fp32.npy" ) path_albedo_out_vis = os.path.join( path_output_dir, f"{name_base}_albedo_res.png" ) path_shading_out = os.path.join( path_output_dir, f"{name_base}_shading_fp32.npy" ) path_shading_out_vis = os.path.join(path_output_dir, f"{name_base}_shading.png") path_residual_out = os.path.join( path_output_dir, f"{name_base}_residual_fp32.npy" ) path_residual_out_vis = os.path.join( path_output_dir, f"{name_base}_residual.png" ) albedo = pipe_out.albedo albedo_colored = pipe_out.albedo_colored shading = pipe_out.shading shading_colored = pipe_out.shading_colored residual = pipe_out.residual residual_colored = pipe_out.residual_colored np.save(path_albedo_out, albedo) albedo_colored.save(path_albedo_out_vis) np.save(path_shading_out, shading) shading_colored.save(path_shading_out_vis) np.save(path_residual_out, residual) residual_colored.save(path_residual_out_vis) return ( [path_input, path_albedo_out_vis], [path_input, path_shading_out_vis], [path_input, path_residual_out_vis], [ path_albedo_out_vis, path_shading_out_vis, path_residual_out_vis, path_albedo_out, path_shading_out, path_residual_out, ], ) def run_demo_server(hf_writer=None): process_pipe_image = spaces.GPU( functools.partial(process_with_loaded_pipeline), duration=120 ) gradio_theme = gr.themes.Default() with gr.Blocks( theme=gradio_theme, title="Marigold Intrinsic Image Decomposition (Marigold-IID)", css=""" #download { height: 118px; } .slider .inner { width: 5px; background: #FFF; } .viewport { aspect-ratio: 4/3; } .tabs button.selected { font-size: 20px !important; color: crimson !important; } h1 { text-align: center; display: block; } h2 { text-align: center; display: block; } h3 { text-align: center; display: block; } .md_feedback li { margin-bottom: 0px !important; } """, head=""" """, ) as demo: if hf_writer is not None: print("Creating login button") share_login_btn = gr.LoginButton(size="sm", scale=1, render=False) print("Created login button") share_login_btn.activate() print("Activated login button") gr.Markdown( """ # Marigold Intrinsic Image Decomposition (IID)

badge-github-stars social

""" ) def get_share_instructions(is_full): out = ( "### Help us improve Marigold! If the output is not what you expected, " "you can help us by sharing it with us privately.\n" ) if is_full: out += ( "1. Sign into your Hugging Face account using the button below.\n" "1. Signing in may reset the demo and results; in that case, process the image again.\n" ) out += "1. Review and agree to the terms of usage and enter an optional message to us.\n" out += "1. Click the 'Share' button to submit the image to us privately.\n" return out def get_share_conditioned_on_login(profile: gr.OAuthProfile | None): state_logged_out = profile is None return get_share_instructions(is_full=state_logged_out), gr.Button( visible=(state_logged_out or default_share_always_show_hf_logout_btn) ) with gr.Row(): with gr.Column(): image_input = gr.Image( label="Input Image", type="filepath", ) model_type = gr.Radio( [ ("Appearance (albedo & material)", "appearance"), ("Lighting (albedo, shading & residual)", "lighting"), ], label="Model type: Marigold-IID-Appearance or Marigold IID-Lighting", value=default_model_type, ) with gr.Accordion("Advanced options", open=True): image_ensemble_size = gr.Slider( label="Ensemble size", minimum=1, maximum=5, step=1, value=default_image_ensemble_size, ) image_denoise_steps = gr.Slider( label="Number of denoising steps", minimum=1, maximum=10, step=1, value=default_image_denoise_steps, ) image_processing_res = gr.Radio( [ ("Native", 0), ("Recommended", 768), ], label="Processing resolution", value=default_image_processing_res, ) with gr.Row(): image_submit_btn = gr.Button(value="Compute IID", variant="primary") image_reset_btn = gr.Button(value="Reset") with gr.Column(): image_output_slider1 = ImageSlider( label="Predicted Albedo", type="filepath", show_download_button=True, show_share_button=True, interactive=False, elem_classes="slider", position=0.25, visible=True, ) image_output_slider2 = ImageSlider( label="Predicted Material", type="filepath", show_download_button=True, show_share_button=True, interactive=False, elem_classes="slider", position=0.25, visible=True, ) image_output_slider3 = ImageSlider( label="Predicted Residual", type="filepath", show_download_button=True, show_share_button=True, interactive=False, elem_classes="slider", position=0.25, visible=False, ) image_output_files = gr.Files( label="Output files", elem_id="download", interactive=False, ) if hf_writer is not None: with gr.Accordion( "Feedback", open=False, visible=default_share_always_show_accordion, ) as share_box: share_instructions = gr.Markdown( get_share_instructions(is_full=True), elem_classes="md_feedback", ) share_transfer_of_rights = gr.Checkbox( label="(Optional) I own or hold necessary rights to the submitted image. By " "checking this box, I grant an irrevocable, non-exclusive, transferable, " "royalty-free, worldwide license to use the uploaded image, including for " "publishing, reproducing, and model training. [transfer_of_rights]", scale=1, ) share_content_is_legal = gr.Checkbox( label="By checking this box, I acknowledge that my uploaded content is legal and " "safe, and that I am solely responsible for ensuring it complies with all " "applicable laws and regulations. Additionally, I am aware that my Hugging Face " "username is collected. [content_is_legal]", scale=1, ) share_reason = gr.Textbox( label="(Optional) Reason for feedback", max_lines=1, interactive=True, ) with gr.Row(): share_login_btn.render() share_share_btn = gr.Button( "Share", variant="stop", scale=1 ) # Function to toggle visibility and set dynamic labels def toggle_sliders_and_labels(model_type): if model_type == "appearance": return ( gr.update(visible=True, label="Predicted Albedo"), gr.update(visible=True, label="Predicted Material"), gr.update(visible=False), # Hide third slider ) elif model_type == "lighting": return ( gr.update(visible=True, label="Predicted Albedo"), gr.update(visible=True, label="Predicted Shading"), gr.update(visible=True, label="Predicted Residual"), ) # Attach the change event to update sliders model_type.change( fn=toggle_sliders_and_labels, inputs=[model_type], outputs=[image_output_slider1, image_output_slider2, image_output_slider3], show_progress=False, ) Examples( fn=process_pipe_image, examples=[ [os.path.join("files", "image", name), _model_type] for name in [ "livingroom.jpg", "books.jpg", "food_counter.png", "cat2.png", "costumes.png", "icecream.jpg", "juices.jpeg", "cat.jpg", "food.jpeg", "puzzle.jpeg", "screw.png", ] for _model_type in ["appearance", "lighting"] ], inputs=[image_input, model_type], outputs=[ image_output_slider1, image_output_slider2, image_output_slider3, image_output_files, ], cache_examples=True, # TODO: toggle later directory_name="examples_images", ) ### Image tab if hf_writer is not None: image_submit_btn.click( fn=process_image_check, inputs=image_input, outputs=None, preprocess=False, queue=False, ).success( get_share_conditioned_on_login, None, [share_instructions, share_login_btn], queue=False, ).then( lambda: ( gr.Button(value="Share", interactive=True), gr.Accordion(visible=True), False, False, "", ), None, [ share_share_btn, share_box, share_transfer_of_rights, share_content_is_legal, share_reason, ], queue=False, ).then( fn=process_pipe_image, inputs=[ image_input, model_type, image_denoise_steps, image_ensemble_size, image_processing_res, ], outputs=[ image_output_slider1, image_output_slider2, image_output_slider3, image_output_files, ], concurrency_limit=1, ) else: image_submit_btn.click( fn=process_image_check, inputs=image_input, outputs=None, preprocess=False, queue=False, ).success( fn=process_pipe_image, inputs=[ image_input, model_type, image_denoise_steps, image_ensemble_size, image_processing_res, ], outputs=[ image_output_slider1, image_output_slider2, image_output_slider3, image_output_files, ], concurrency_limit=1, ) image_reset_btn.click( fn=lambda: ( None, None, None, None, None, default_model_type, default_image_ensemble_size, default_image_denoise_steps, default_image_processing_res, ), inputs=[], outputs=[ image_input, image_output_slider1, image_output_slider2, image_output_slider3, image_output_files, model_type, image_ensemble_size, image_denoise_steps, image_processing_res, ], queue=False, ) if hf_writer is not None: image_reset_btn.click( fn=lambda: ( gr.Button(value="Share", interactive=True), gr.Accordion(visible=default_share_always_show_accordion), ), inputs=[], outputs=[ share_share_btn, share_box, ], queue=False, ) ### Share functionality if hf_writer is not None: share_components = [ image_input, image_denoise_steps, image_ensemble_size, image_processing_res, image_output_slider1, image_output_slider2, image_output_slider3, share_content_is_legal, share_transfer_of_rights, share_reason, ] hf_writer.setup(share_components, "shared_data") share_callback = FlagMethod(hf_writer, "Share", "", visual_feedback=True) def share_precheck( hf_content_is_legal, image_output_slider, profile: gr.OAuthProfile | None, ): if profile is None: raise gr.Error( "Log into the Space with your Hugging Face account first." ) if image_output_slider is None or image_output_slider[0] is None: raise gr.Error("No output detected; process the image first.") if not hf_content_is_legal: raise gr.Error( "You must consent that the uploaded content is legal." ) return gr.Button(value="Sharing in progress", interactive=False) share_share_btn.click( share_precheck, [share_content_is_legal, image_output_slider1], share_share_btn, preprocess=False, queue=False, ).success( share_callback, inputs=share_components, outputs=share_share_btn, preprocess=False, queue=False, ) demo.queue( api_open=False, ).launch( server_name="0.0.0.0", server_port=7860, ) def main(): CROWD_DATA = "crowddata-marigold-iid-appearance-v1-1-space-v1-1" os.system("pip freeze") if "HF_TOKEN_LOGIN" in os.environ: login(token=os.environ["HF_TOKEN_LOGIN"]) hf_writer = None if "HF_TOKEN_LOGIN_WRITE_CROWD" in os.environ: hf_writer = HuggingFaceDatasetSaver( os.getenv("HF_TOKEN_LOGIN_WRITE_CROWD"), CROWD_DATA, private=True, info_filename="dataset_info.json", separate_dirs=True, ) run_demo_server(hf_writer) if __name__ == "__main__": main()