|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Iterator |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from .gaussian_diffusion import GaussianDiffusion |
|
|
|
|
|
class PointCloudSampler:
    """
    A wrapper around a model that produces conditional sample tensors.

    Samples are drawn progressively via the diffusion's DDIM loop. When
    ``guidance_scale`` is neither 0 nor 1, classifier-free guidance is used:
    the batch is internally doubled into a conditional half and an
    unconditional half (zeroed condition), and the model's epsilon outputs
    are recombined according to the guidance scale.
    """

    def __init__(
        self,
        model: nn.Module,
        diffusion: GaussianDiffusion,
        num_points: int,
        point_dim: int = 3,
        guidance_scale: float = 3.0,
        clip_denoised: bool = True,
        sigma_min: float = 1e-3,
        sigma_max: float = 120,
        s_churn: float = 3,
    ):
        """
        Args:
            model: Denoising network; called as ``model(x_t, ts, **kwargs)``
                and expected to return epsilon in the first ``point_dim``
                channels (any extra channels are passed through unguided).
            diffusion: Diffusion process providing
                ``ddim_sample_loop_progressive``.
            num_points: Number of points per sampled cloud.
            point_dim: Channels per point (default 3, i.e. xyz).
            guidance_scale: Default classifier-free guidance scale; values
                of 0 or 1 disable the guidance batch-doubling.
            clip_denoised: Whether the diffusion clips predicted x_0.
            sigma_min: Stored sampler hyperparameter.
            sigma_max: Stored sampler hyperparameter.
            s_churn: Stored sampler hyperparameter.
        """
        self.model = model
        self.diffusion = diffusion
        self.num_points = num_points
        self.point_dim = point_dim
        self.guidance_scale = guidance_scale
        self.clip_denoised = clip_denoised
        # NOTE(review): sigma_min/sigma_max/s_churn are stored but never read
        # by the DDIM path below — presumably consumed by a stochastic/karras
        # sampler elsewhere; confirm before removing.
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.s_churn = s_churn

    def sample_batch_progressive(
        self,
        batch_size: int,
        condition: torch.Tensor,
        noise=None,
        device=None,
        guidance_scale=None,
    ) -> Iterator[Dict[str, torch.Tensor]]:
        """
        Generate samples progressively using classifier-free guidance.

        Args:
            batch_size: Number of samples to generate.
            condition: Conditioning tensor (leading dim must be batch_size).
            noise: Optional initial noise tensor.
            device: Device to run on.
            guidance_scale: Optional override for the instance's scale.

        Yields:
            Dicts with ``"xstart"`` (predicted x_0) and ``"xprev"`` (current
            sample), each of leading dimension ``batch_size``.
        """
        if guidance_scale is None:
            guidance_scale = self.guidance_scale

        sample_shape = (batch_size, self.point_dim, self.num_points)

        use_guidance = guidance_scale != 1 and guidance_scale != 0
        if use_guidance:
            # Double the batch: first half conditional, second half
            # unconditional (zero condition) for classifier-free guidance.
            condition = torch.cat([condition, torch.zeros_like(condition)], dim=0)
            if noise is not None:
                noise = torch.cat([noise, noise], dim=0)
        model_kwargs = {"condition": condition}

        internal_batch_size = batch_size
        if use_guidance:
            model = self._uncond_guide_model(self.model, guidance_scale)
            internal_batch_size *= 2
        else:
            model = self.model

        samples_it = self.diffusion.ddim_sample_loop_progressive(
            model,
            shape=(internal_batch_size, *sample_shape[1:]),
            model_kwargs=model_kwargs,
            device=device,
            clip_denoised=self.clip_denoised,
            noise=noise,
        )

        for x in samples_it:
            # BUG FIX: the fallback x["x"] was previously yielded unsliced,
            # returning 2*batch_size rows under guidance. Slice the chosen
            # tensor uniformly so both keys always have batch_size rows.
            xprev = x["sample"] if "sample" in x else x["x"]
            yield {
                "xstart": x["pred_xstart"][:batch_size],
                "xprev": xprev[:batch_size],
            }

    def _uncond_guide_model(self, model: nn.Module, scale: float) -> nn.Module:
        """
        Wraps the model for classifier-free guidance.

        The returned callable expects a doubled batch whose second half is
        the unconditional copy; it evaluates the model once on the doubled
        batch, blends the conditional/unconditional epsilon predictions, and
        re-duplicates the blended epsilon so the output batch shape matches
        the input.
        """

        def model_fn(x_t, ts, **kwargs):
            # Both halves share identical x_t; only the condition in kwargs
            # differs, so run the model on the duplicated first half.
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = model(combined, ts, **kwargs)

            # First point_dim channels are epsilon; the rest (e.g. variance)
            # pass through unguided.
            eps, rest = model_out[:, : self.point_dim], model_out[:, self.point_dim :]
            cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
            half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        return model_fn
|
|