# --------------------------------------------------------
# Adapted from: https://github.com/openai/point-e
# Licensed under the MIT License
# Copyright (c) 2022 OpenAI
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# --------------------------------------------------------
from typing import Callable, Dict, Iterator

import torch
import torch.nn as nn

from .gaussian_diffusion import GaussianDiffusion


class PointCloudSampler:
"""
A wrapper around a model that produces conditional sample tensors.
"""
def __init__(
self,
model: nn.Module,
diffusion: GaussianDiffusion,
num_points: int,
point_dim: int = 3,
guidance_scale: float = 3.0,
clip_denoised: bool = True,
sigma_min: float = 1e-3,
sigma_max: float = 120,
s_churn: float = 3,
):
self.model = model
self.num_points = num_points
self.point_dim = point_dim
self.guidance_scale = guidance_scale
self.clip_denoised = clip_denoised
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.s_churn = s_churn
self.diffusion = diffusion

    def sample_batch_progressive(
        self,
        batch_size: int,
        condition: torch.Tensor,
        noise=None,
        device=None,
        guidance_scale=None,
    ) -> Iterator[Dict[str, torch.Tensor]]:
        """
        Generate samples progressively using classifier-free guidance.

        Args:
            batch_size: Number of samples to generate.
            condition: Conditioning tensor.
            noise: Optional initial noise tensor.
            device: Device to run on.
            guidance_scale: Optional override for the guidance scale.

        Returns:
            Iterator of dicts containing intermediate samples.
        """
        if guidance_scale is None:
            guidance_scale = self.guidance_scale

        sample_shape = (batch_size, self.point_dim, self.num_points)

        # Double the batch for classifier-free guidance: the first half is
        # conditioned on `condition`, the second half on a zeroed-out condition.
        if guidance_scale != 1 and guidance_scale != 0:
            condition = torch.cat([condition, torch.zeros_like(condition)], dim=0)
            if noise is not None:
                noise = torch.cat([noise, noise], dim=0)

        model_kwargs = {"condition": condition}

        internal_batch_size = batch_size
        if guidance_scale != 1 and guidance_scale != 0:
            model = self._uncond_guide_model(self.model, guidance_scale)
            internal_batch_size *= 2
        else:
            model = self.model

        samples_it = self.diffusion.ddim_sample_loop_progressive(
            model,
            shape=(internal_batch_size, *sample_shape[1:]),
            model_kwargs=model_kwargs,
            device=device,
            clip_denoised=self.clip_denoised,
            noise=noise,
        )

        for x in samples_it:
            # Return only the conditional half of the (possibly doubled) batch.
            samples = {
                "xstart": x["pred_xstart"][:batch_size],
                "xprev": x["sample"][:batch_size] if "sample" in x else x["x"][:batch_size],
            }
            yield samples
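
    # Note added for clarity (not in the upstream point-e code): each yielded
    # dict exposes "xstart", the current prediction of the fully denoised point
    # cloud, and "xprev", the partially denoised sample at the current step,
    # both shaped (batch_size, point_dim, num_points); callers typically keep
    # the final "xstart" as the result.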

    def _uncond_guide_model(self, model: nn.Module, scale: float) -> Callable:
        """
        Wraps the model for classifier-free guidance.
        """

        def model_fn(x_t, ts, **kwargs):
            # Both halves of the doubled batch carry identical noised inputs,
            # so keep the first half and duplicate it before the forward pass.
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = model(combined, ts, **kwargs)
            # The first `point_dim` output channels are the eps prediction; any
            # remaining channels are passed through unchanged.
            eps, rest = model_out[:, : self.point_dim], model_out[:, self.point_dim :]
            cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
            # Classifier-free guidance: eps = eps_uncond + scale * (eps_cond - eps_uncond).
            half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        return model_fn
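

# Minimal usage sketch (added for illustration; not part of the original file).
# It assumes a trained point-cloud `model` and a configured `diffusion` object
# are already available; the conditioning shape and sampler settings shown here
# are purely hypothetical:
#
#     sampler = PointCloudSampler(
#         model=model,
#         diffusion=diffusion,
#         num_points=1024,
#         point_dim=3,
#         guidance_scale=3.0,
#     )
#     condition = torch.randn(1, cond_dim)  # hypothetical conditioning tensor
#     final = None
#     for out in sampler.sample_batch_progressive(batch_size=1, condition=condition):
#         final = out["xstart"]  # (1, point_dim, num_points) after the last step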