import logging
import warnings

import diffusers
import numpy as np
import torch
from diffusers import MarigoldDepthPipeline

warnings.simplefilter(action="ignore", category=FutureWarning)
diffusers.utils.logging.disable_progress_bar()


class MarigoldDepthCompletionPipeline(MarigoldDepthPipeline):
    def __call__(
        self,
        image,
        sparse_depth,
        num_inference_steps=50,
        processing_resolution=0,
        seed=2024,
        dry_run=False,
    ):
        # Resolve variables
        device = self._execution_device
        generator = torch.Generator(device=device).manual_seed(seed)

        if dry_run:
            logging.warning("Dry run mode")
            for i in range(num_inference_steps):
                yield np.array(image)[:, :, 0].astype(float), float(np.log(i + 1))
            return

        # Check inputs
        if num_inference_steps is None:
            raise ValueError("Invalid num_inference_steps")
        if type(sparse_depth) is not np.ndarray or sparse_depth.ndim != 2:
            raise ValueError(
                "Sparse depth should be a 2D numpy ndarray with zeros at missing positions"
            )

        with torch.no_grad():
            # Prepare empty text conditioning
            if self.empty_text_embedding is None:
                prompt = ""
                text_inputs = self.tokenizer(
                    prompt,
                    padding="do_not_pad",
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids.to(device)
                self.empty_text_embedding = self.text_encoder(text_input_ids)[
                    0
                ]  # [1,2,1024]

        # Preprocess input images
        image, padding, original_resolution = self.image_processor.preprocess(
            image,
            processing_resolution=processing_resolution,
            device=device,
            dtype=self.dtype,
        )  # [N,3,PPH,PPW]

        if sparse_depth.shape != original_resolution:
            raise ValueError(
                f"Sparse depth dimensions ({sparse_depth.shape}) must match those "
                f"of the image ({original_resolution})"
            )

        with torch.no_grad():
            # Encode the input image into latent space; the prediction latent is
            # initialized with noise and optimized during the denoising loop
            image_latent, pred_latent = self.prepare_latents(
                image, None, generator, 1, 1
            )  # [N*E,4,h,w], [N*E,4,h,w]
        del image

        # Preprocess sparse depth: zeros mark missing positions
        sparse_depth = torch.from_numpy(sparse_depth)[None, None].float()
        sparse_depth = sparse_depth.to(device)
        sparse_mask = sparse_depth > 0

        # Set up optimization targets; scale and shift are squared later, which
        # keeps the affine mapping non-negative without constrained optimization
        scale = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=True)
        shift = torch.nn.Parameter(torch.ones(1, device=device), requires_grad=True)
        pred_latent = torch.nn.Parameter(pred_latent, requires_grad=True)

        sparse_range = (
            sparse_depth[sparse_mask].max() - sparse_depth[sparse_mask].min()
        ).item()
        sparse_lower = (sparse_depth[sparse_mask].min()).item()

        def affine_to_metric(depth):
            # Map affine-invariant depth into the metric range of the sparse points
            return (scale**2) * sparse_range * depth + (shift**2) * sparse_lower

        def latent_to_metric(latent):
            affine_invariant_prediction = self.decode_prediction(
                latent
            )  # [E,1,PPH,PPW]
            prediction = affine_to_metric(affine_invariant_prediction)
            prediction = self.image_processor.unpad_image(
                prediction, padding
            )  # [E,1,PH,PW]
            prediction = self.image_processor.resize_antialias(
                prediction, original_resolution, "bilinear", is_aa=False
            )  # [1,1,H,W]
            return prediction

        def loss_l1l2(input, target):
            out_l1 = torch.nn.functional.l1_loss(input, target)
            out_l2 = torch.nn.functional.mse_loss(input, target)
            out = out_l1 + out_l2
            return out, out_l2.sqrt()

        optimizer = torch.optim.Adam(
            [
                {"params": [scale, shift], "lr": 0.005},
                {"params": [pred_latent], "lr": 0.05},
            ]
        )

        # Process the denoising loop
        self.scheduler.set_timesteps(num_inference_steps, device=device)
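        # Guided denoising, schematically, per timestep t:
        #   eps   = UNet(cat[image_latent, pred_latent], t)   # noise estimate
        #   x0    = scheduler.step(eps, t, pred_latent).pred_original_sample
        #   depth = affine_to_metric(decode(x0))              # metric preview
        #   loss  = L1 + L2 between depth[sparse_mask] and the sparse points
        # The gradient w.r.t. pred_latent is rescaled to the norm of pred_epsilon
        # before the Adam step, so guidance stays commensurate with the
        # magnitude of the diffusion update itself.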
        for t in self.progress_bar(
            self.scheduler.timesteps, desc=f"Marigold-DC steps ({device})..."
        ):
            optimizer.zero_grad()

            batch_latent = torch.cat([image_latent, pred_latent], dim=1)  # [1,8,h,w]
            noise = self.unet(
                batch_latent,
                t,
                encoder_hidden_states=self.empty_text_embedding,
                return_dict=False,
            )[0]  # [1,4,h,w]

            # Compute pred_epsilon to later rescale the depth latent gradient
            with torch.no_grad():
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                beta_prod_t = 1 - alpha_prod_t
                pred_epsilon = (alpha_prod_t**0.5) * noise + (
                    beta_prod_t**0.5
                ) * pred_latent

            step_output = self.scheduler.step(
                noise, t, pred_latent, generator=generator
            )

            # Preview the final output depth, compute loss with guidance, backprop
            pred_original_sample = step_output.pred_original_sample
            current_metric_estimate = latent_to_metric(pred_original_sample)
            loss, rmse = loss_l1l2(
                current_metric_estimate[sparse_mask], sparse_depth[sparse_mask]
            )
            loss.backward()

            # Scale gradients up
            with torch.no_grad():
                pred_epsilon_norm = torch.linalg.norm(pred_epsilon).item()
                depth_latent_grad_norm = torch.linalg.norm(pred_latent.grad).item()
                scaling_factor = pred_epsilon_norm / max(depth_latent_grad_norm, 1e-8)
                pred_latent.grad *= scaling_factor

            optimizer.step()

            with torch.no_grad():
                pred_latent.data = self.scheduler.step(
                    noise, t, pred_latent, generator=generator
                ).prev_sample

            yield current_metric_estimate, rmse.item()

            del (
                pred_original_sample,
                current_metric_estimate,
                step_output,
                pred_epsilon,
                noise,
            )
            torch.cuda.empty_cache()

        # Offload all models
        self.maybe_free_model_hooks()
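

# A minimal usage sketch (not part of the pipeline). The checkpoint name is the
# assumed public Marigold depth checkpoint; the image path and the randomly
# sampled sparse points below are hypothetical placeholders for illustration.
if __name__ == "__main__":
    from PIL import Image

    pipe = MarigoldDepthCompletionPipeline.from_pretrained(
        "prs-eth/marigold-depth-v1-0"  # assumed checkpoint; swap in your own
    ).to("cuda" if torch.cuda.is_available() else "cpu")

    rgb = Image.open("example.jpg")  # hypothetical input image

    # Zeros mark missing pixels; here we fake 200 metric samples (e.g., LiDAR hits)
    rng = np.random.default_rng(0)
    sparse = np.zeros((rgb.height, rgb.width), dtype=np.float32)
    ys = rng.integers(0, rgb.height, 200)
    xs = rng.integers(0, rgb.width, 200)
    sparse[ys, xs] = rng.uniform(1.0, 10.0, 200).astype(np.float32)

    # The pipeline is a generator: every denoising step yields the current
    # densified depth estimate and the RMSE against the sparse points.
    for depth, rmse in pipe(rgb, sparse, num_inference_steps=50):
        print(f"RMSE vs. sparse points: {rmse:.4f}")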