# --------------------------------------------------------
# Adapted from: https://github.com/openai/point-e
# Licensed under the MIT License
# Copyright (c) 2022 OpenAI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# --------------------------------------------------------

from typing import Callable, Dict, Iterator

import torch
import torch.nn as nn

from .gaussian_diffusion import GaussianDiffusion


class PointCloudSampler:
    """
    A wrapper around a model that produces conditional sample tensors.
    """

    def __init__(
        self,
        model: nn.Module,
        diffusion: GaussianDiffusion,
        num_points: int,
        point_dim: int = 3,
        guidance_scale: float = 3.0,
        clip_denoised: bool = True,
        sigma_min: float = 1e-3,
        sigma_max: float = 120,
        s_churn: float = 3,
    ):
        self.model = model
        self.diffusion = diffusion
        self.num_points = num_points
        self.point_dim = point_dim
        self.guidance_scale = guidance_scale
        self.clip_denoised = clip_denoised
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.s_churn = s_churn

    def sample_batch_progressive(
        self,
        batch_size: int,
        condition: torch.Tensor,
        noise=None,
        device=None,
        guidance_scale=None,
    ) -> Iterator[Dict[str, torch.Tensor]]:
        """
        Generate samples progressively using classifier-free guidance.

        Args:
            batch_size: Number of samples to generate.
            condition: Conditioning tensor.
            noise: Optional initial noise tensor.
            device: Device to run on.
            guidance_scale: Optional override for the configured guidance scale.

        Yields:
            Dicts with "xstart" (the current prediction of the final sample)
            and "xprev" (the current noisy sample) for each denoising step.
        """
        if guidance_scale is None:
            guidance_scale = self.guidance_scale

        sample_shape = (batch_size, self.point_dim, self.num_points)

        # For classifier-free guidance, double the batch: the first half keeps
        # the real condition, the second half gets a zeroed (null) condition.
        if guidance_scale != 1 and guidance_scale != 0:
            condition = torch.cat([condition, torch.zeros_like(condition)], dim=0)
            if noise is not None:
                noise = torch.cat([noise, noise], dim=0)

        model_kwargs = {"condition": condition}

        internal_batch_size = batch_size
        if guidance_scale != 1 and guidance_scale != 0:
            model = self._uncond_guide_model(self.model, guidance_scale)
            internal_batch_size *= 2
        else:
            model = self.model

        samples_it = self.diffusion.ddim_sample_loop_progressive(
            model,
            shape=(internal_batch_size, *sample_shape[1:]),
            model_kwargs=model_kwargs,
            device=device,
            clip_denoised=self.clip_denoised,
            noise=noise,
        )
        for x in samples_it:
            # Keep only the conditional half of the (possibly doubled) batch.
            samples = {
                "xstart": x["pred_xstart"][:batch_size],
                "xprev": x["sample"][:batch_size] if "sample" in x else x["x"][:batch_size],
            }
            yield samples

    def _uncond_guide_model(self, model: nn.Module, scale: float) -> Callable:
        """
        Wraps the model for classifier-free guidance.
""" def model_fn(x_t, ts, **kwargs): half = x_t[: len(x_t) // 2] combined = torch.cat([half, half], dim=0) model_out = model(combined, ts, **kwargs) eps, rest = model_out[:, : self.point_dim], model_out[:, self.point_dim :] cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0) half_eps = uncond_eps + scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) return torch.cat([eps, rest], dim=1) return model_fn