# Extracted from a Hugging Face file page: commit 38dbec8 by jammmmm,
# "Add spar3d demo files" (raw file, 27 kB).
import os
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import trimesh
from einops import rearrange
from huggingface_hub import hf_hub_download
from jaxtyping import Float
from omegaconf import OmegaConf
from PIL import Image
from safetensors.torch import load_model
from torch import Tensor
from spar3d.models.diffusion.gaussian_diffusion import (
SpacedDiffusion,
get_named_beta_schedule,
space_timesteps,
)
from spar3d.models.diffusion.sampler import PointCloudSampler
from spar3d.models.isosurface import MarchingTetrahedraHelper
from spar3d.models.mesh import Mesh
from spar3d.models.utils import (
BaseModule,
ImageProcessor,
convert_data,
dilate_fill,
find_class,
float32_to_uint8_np,
normalize,
scale_tensor,
)
from spar3d.utils import (
create_intrinsic_from_fov_rad,
default_cond_c2w,
get_device,
normalize_pc_bbox,
)
try:
    from texture_baker import TextureBaker
except ImportError as err:
    import logging

    logging.warning(
        "Could not import texture_baker. Please install it via `pip install texture-baker/`"
    )
    # Exit early to avoid further errors: SPAR3D.configure() instantiates
    # TextureBaker unconditionally, so the model cannot be built without it.
    # Chain the original error so the underlying cause (e.g. a failing
    # native extension inside the package) stays visible in the traceback.
    raise ImportError("texture_baker not found") from err
class SPAR3D(BaseModule):
    """Image-to-textured-mesh model assembled from config-selected submodules.

    High-level flow (see ``run_image`` / ``generate_mesh``):

    1. Unless a point cloud is supplied, a point-diffusion sampler
       (``pdiff_*`` modules + ``SpacedDiffusion``) samples a 512-point,
       6-channel (xyz + rgb) point cloud conditioned on the input image.
    2. Image tokens (modulated by a camera embedding) and the embedded point
       cloud are passed to the backbone as ``encoder_hidden_states``; the
       backbone's tokens are detokenized and post-processed into triplane
       scene codes.
    3. Marching tetrahedra turns the decoded density field into a mesh, and
       ``TextureBaker`` bakes albedo and tangent-space normal maps into a
       textured ``trimesh.Trimesh``.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Side length in pixels that conditioning images are resized to.
        cond_image_size: int
        # Tetrahedral grid resolution; selects "<res>_tets.npz" in configure().
        isosurface_resolution: int
        # Subtracted from the decoded density to obtain the SDF for extraction.
        isosurface_threshold: float = 10.0
        # Half-extent of the cubic bounding box [-radius, radius]^3.
        radius: float = 1.0
        # RGB colour the input image is alpha-composited over; on the same
        # [0, 1] scale as the normalised image pixels in prepare_image().
        background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
        # Default conditioning camera: vertical field of view (radians) and
        # camera distance.
        default_fovy_rad: float = 0.591627
        default_distance: float = 2.2

        # Each "<name>_cls" is resolved with find_class() and instantiated with
        # the matching "<name>" dict in configure().
        camera_embedder_cls: str = ""
        camera_embedder: dict = field(default_factory=dict)

        image_tokenizer_cls: str = ""
        image_tokenizer: dict = field(default_factory=dict)

        point_embedder_cls: str = ""
        point_embedder: dict = field(default_factory=dict)

        tokenizer_cls: str = ""
        tokenizer: dict = field(default_factory=dict)

        backbone_cls: str = ""
        backbone: dict = field(default_factory=dict)

        post_processor_cls: str = ""
        post_processor: dict = field(default_factory=dict)

        decoder_cls: str = ""
        decoder: dict = field(default_factory=dict)

        image_estimator_cls: str = ""
        image_estimator: dict = field(default_factory=dict)

        global_estimator_cls: str = ""
        global_estimator: dict = field(default_factory=dict)

        # Point diffusion modules
        pdiff_camera_embedder_cls: str = ""
        pdiff_camera_embedder: dict = field(default_factory=dict)

        pdiff_image_tokenizer_cls: str = ""
        pdiff_image_tokenizer: dict = field(default_factory=dict)

        pdiff_backbone_cls: str = ""
        pdiff_backbone: dict = field(default_factory=dict)

        # Per-channel scales / biases for the 6 point channels (3 xyz + 3 rgb),
        # forwarded to SpacedDiffusion as channel_scales / channel_biases.
        scale_factor_xyz: float = 1.0
        scale_factor_rgb: float = 1.0
        bias_xyz: float = 0.0
        bias_rgb: float = 0.0

        # Steps of the beta schedule, and the number of spaced ("ddim<N>")
        # steps used at inference.
        train_time_steps: int = 1024
        inference_time_steps: int = 64

        # Forwarded to SpacedDiffusion as model_mean_type / model_var_type.
        mean_type: str = "epsilon"
        var_type: str = "fixed_small"
        # Beta schedule name and exponent passed to get_named_beta_schedule().
        diffu_sched: str = "cosine"
        diffu_sched_exp: float = 12.0

        # Sampler settings forwarded to PointCloudSampler.
        guidance_scale: float = 3.0
        sigma_max: float = 120.0
        s_churn: float = 3.0

    cfg: Config

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
    ):
        """Build a model from an OmegaConf config and a safetensors checkpoint.

        ``pretrained_model_name_or_path`` is first joined onto the directory one
        level above this file. If that path is an existing directory, both files
        are read from it. Because ``os.path.join`` discards the base for an
        absolute path, absolute directories work too. Otherwise the argument is
        treated as a Hugging Face Hub repo id, and both files are fetched with
        ``hf_hub_download``.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id.
            config_name: File name of the config inside that location.
            weight_name: File name of the safetensors weights inside it.

        Returns:
            The constructed model with weights loaded using ``strict=False``.
        """
        base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        local_dir = os.path.join(base_dir, pretrained_model_name_or_path)
        if os.path.isdir(local_dir):
            config_path = os.path.join(local_dir, config_name)
            weight_path = os.path.join(local_dir, weight_name)
        else:
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config_name
            )
            weight_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=weight_name
            )

        cfg = OmegaConf.load(config_path)
        # Resolve interpolations up front so the model sees plain values.
        OmegaConf.resolve(cfg)
        model = cls(cfg)
        load_model(model, weight_path, strict=False)
        return model

    @property
    def device(self):
        """Device of the model's first registered parameter."""
        return next(self.parameters()).device

    def configure(self):
        """Instantiate submodules from ``self.cfg`` and set up fixed helpers.

        Builds every ``<name>_cls`` submodule and registers the ``bbox`` buffer.
        Loads the marching-tetrahedra grid and creates the texture baker and
        image processor. Finally sets up the spaced diffusion process and the
        point-cloud sampler used by ``generate_mesh``.
        """
        cfg = self.cfg

        self.image_tokenizer = find_class(cfg.image_tokenizer_cls)(
            cfg.image_tokenizer
        )
        self.point_embedder = find_class(cfg.point_embedder_cls)(cfg.point_embedder)
        self.tokenizer = find_class(cfg.tokenizer_cls)(cfg.tokenizer)
        self.camera_embedder = find_class(cfg.camera_embedder_cls)(
            cfg.camera_embedder
        )
        self.backbone = find_class(cfg.backbone_cls)(cfg.backbone)
        self.post_processor = find_class(cfg.post_processor_cls)(cfg.post_processor)
        self.decoder = find_class(cfg.decoder_cls)(cfg.decoder)
        self.image_estimator = find_class(cfg.image_estimator_cls)(
            cfg.image_estimator
        )
        self.global_estimator = find_class(cfg.global_estimator_cls)(
            cfg.global_estimator
        )

        # point diffusion modules
        self.pdiff_image_tokenizer = find_class(cfg.pdiff_image_tokenizer_cls)(
            cfg.pdiff_image_tokenizer
        )
        self.pdiff_camera_embedder = find_class(cfg.pdiff_camera_embedder_cls)(
            cfg.pdiff_camera_embedder
        )
        self.pdiff_backbone = find_class(cfg.pdiff_backbone_cls)(cfg.pdiff_backbone)

        # Axis-aligned box [-radius, radius]^3 as (min corner, max corner).
        self.bbox: Float[Tensor, "2 3"]
        self.register_buffer(
            "bbox",
            torch.as_tensor(
                [
                    [-cfg.radius, -cfg.radius, -cfg.radius],
                    [cfg.radius, cfg.radius, cfg.radius],
                ],
                dtype=torch.float32,
            ),
        )
        self.isosurface_helper = MarchingTetrahedraHelper(
            cfg.isosurface_resolution,
            os.path.join(
                os.path.dirname(__file__),
                "..",
                "load",
                "tets",
                f"{cfg.isosurface_resolution}_tets.npz",
            ),
        )

        self.baker = TextureBaker()
        self.image_processor = ImageProcessor()

        # Channel order is xyz (3) followed by rgb (3).
        channel_scales = np.array(
            [cfg.scale_factor_xyz] * 3 + [cfg.scale_factor_rgb] * 3
        )
        channel_biases = np.array([cfg.bias_xyz] * 3 + [cfg.bias_rgb] * 3)

        betas = get_named_beta_schedule(
            cfg.diffu_sched, cfg.train_time_steps, cfg.diffu_sched_exp
        )
        diffusion_kwargs = dict(
            betas=betas,
            model_mean_type=cfg.mean_type,
            model_var_type=cfg.var_type,
            channel_scales=channel_scales,
            channel_biases=channel_biases,
        )
        self.diffusion_spaced = SpacedDiffusion(
            use_timesteps=space_timesteps(
                cfg.train_time_steps,
                f"ddim{cfg.inference_time_steps}",
            ),
            **diffusion_kwargs,
        )
        # Samples 512 points with 6 channels (xyz + rgb) per batch element.
        self.sampler = PointCloudSampler(
            model=self.pdiff_backbone,
            diffusion=self.diffusion_spaced,
            num_points=512,
            point_dim=6,
            guidance_scale=cfg.guidance_scale,
            clip_denoised=True,
            sigma_min=1e-3,
            sigma_max=cfg.sigma_max,
            s_churn=cfg.s_churn,
        )

    def triplane_to_meshes(
        self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
    ) -> list[Mesh]:
        """Extract one mesh per batch element from triplane scene codes.

        The tetrahedral grid vertices are mapped from the helper's
        ``points_range`` into ``self.bbox`` and the triplane is queried there.
        The decoder's ``density`` minus ``isosurface_threshold`` is used as the
        SDF, and its ``vertex_offset`` output as the grid deformation, for the
        marching-tetrahedra helper. Resulting vertex positions are mapped from
        ``points_range`` into ``self.bbox``.

        Args:
            triplanes: Batched triplane scene codes.

        Returns:
            A list with one ``Mesh`` per batch element (may have no vertices).
        """
        # The query grid is identical for every batch element; build it once.
        grid_vertices = scale_tensor(
            self.isosurface_helper.grid_vertices.to(triplanes.device),
            self.isosurface_helper.points_range,
            self.bbox,
        )

        meshes = []
        for triplane in triplanes:
            values = self.query_triplane(grid_vertices, triplane)
            decoded = self.decoder(values, include=["vertex_offset", "density"])
            sdf = decoded["density"] - self.cfg.isosurface_threshold
            deform = decoded["vertex_offset"].squeeze(0)

            mesh: Mesh = self.isosurface_helper(
                sdf.view(-1, 1), deform.view(-1, 3) if deform is not None else None
            )
            mesh.v_pos = scale_tensor(
                mesh.v_pos, self.isosurface_helper.points_range, self.bbox
            )
            meshes.append(mesh)

        return meshes

    def query_triplane(
        self,
        positions: Float[Tensor, "*B N 3"],
        triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
    ) -> Float[Tensor, "B N F"]:
        """Bilinearly sample triplane features at 3-D positions.

        Positions in ``[-radius, radius]`` are rescaled to ``[-1, 1]`` and
        projected onto the xy, xz and yz planes (in that order).
        Each plane is sampled with ``F.grid_sample`` (``align_corners=True``),
        and the three per-plane features are concatenated along the last dim.

        Args:
            positions: Query points, with or without a leading batch dim.
            triplanes: Triplanes matching ``positions``' batching.

        Returns:
            Features of shape ``(B, N, 3 * Cp)``. An unbatched input still gets
            a leading batch dimension of size 1; callers index ``[0]`` when
            they need it removed.
        """
        batched = positions.ndim == 3
        if not batched:
            # no batch dimension
            triplanes = triplanes[None, ...]
            positions = positions[None, ...]
        assert triplanes.ndim == 5 and positions.ndim == 3

        positions = scale_tensor(
            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
        )

        indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
            (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
            dim=-3,
        ).to(triplanes.dtype)
        out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
            rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(),
            rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(),
            align_corners=True,
            mode="bilinear",
        )
        return rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)

    @staticmethod
    def _ensure_view_dim(batch: Dict[str, Any]) -> None:
        """Insert a view axis (dim 1) into a single-view batch, in place.

        A 4-D ``rgb_cond`` (``B H W C``) is taken to mean "no view axis". In that
        case ``rgb_cond``, ``mask_cond``, ``c2w_cond``, ``intrinsic_cond`` and
        ``intrinsic_normed_cond`` are all unsqueezed at dim 1.
        """
        if batch["rgb_cond"].ndim == 4:
            for key in (
                "rgb_cond",
                "mask_cond",
                "c2w_cond",
                "intrinsic_cond",
                "intrinsic_normed_cond",
            ):
                batch[key] = batch[key].unsqueeze(1)

    def get_scene_codes(
        self, batch
    ) -> Tuple[Float[Tensor, "B 3 C H W"], Tensor]:
        """Encode a conditioning batch into triplane scene codes.

        The batch must contain ``rgb_cond``, ``mask_cond``, ``c2w_cond``,
        ``intrinsic_cond``, ``intrinsic_normed_cond`` and ``pc_cond``. Single-view
        inputs get a view axis added in place.

        Returns:
            ``(scene_codes, direct_codes)``: ``direct_codes`` is the tokenizer's
            detokenized backbone output, and ``scene_codes`` is that output
            after ``post_processor``.
        """
        self._ensure_view_dim(batch)
        batch_size, n_input_views = batch["rgb_cond"].shape[:2]

        camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
        camera_embeds = self.camera_embedder(**batch)

        pc_embeds = self.point_embedder(batch["pc_cond"])

        input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
            rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
            modulation_cond=camera_embeds,
        )
        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
        )

        tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)

        # Image tokens and point-cloud embeddings are concatenated along the
        # token axis and given to the backbone as encoder_hidden_states.
        cross_tokens = torch.cat([input_image_tokens, pc_embeds], dim=1)
        tokens = self.backbone(
            tokens,
            encoder_hidden_states=cross_tokens,
            modulation_cond=None,
        )

        direct_codes = self.tokenizer.detokenize(tokens)
        scene_codes = self.post_processor(direct_codes)
        return scene_codes, direct_codes

    def forward_pdiff_cond(self, batch: Dict[str, Any]) -> Float[Tensor, "B Nt C"]:
        """Compute the image-token conditioning for the point-diffusion sampler.

        Single-view inputs get a view axis added to ``batch`` in place. Returns
        the ``pdiff_image_tokenizer`` tokens, modulated by the
        ``pdiff_camera_embedder`` output and flattened over views:
        ``(B, Nv * Nt, C)``.
        """
        self._ensure_view_dim(batch)
        n_input_views = batch["rgb_cond"].shape[1]

        # Camera modulation
        camera_embeds: Float[Tensor, "B Nv Cc"] = self.pdiff_camera_embedder(**batch)

        input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.pdiff_image_tokenizer(
            rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
            modulation_cond=camera_embeds,
        )
        return rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
        )

    def run_image(
        self,
        image: Union[Image.Image, List[Image.Image]],
        bake_resolution: int,
        pointcloud: Optional[Union[List[np.ndarray], np.ndarray, Tensor]] = None,
        remesh: Literal["none", "triangle", "quad"] = "none",
        vertex_count: int = -1,
        estimate_illumination: bool = False,
        return_points: bool = False,
    ) -> Tuple[Union[trimesh.Trimesh, List[trimesh.Trimesh]], dict[str, Any]]:
        """Reconstruct textured mesh(es) from RGBA image(s).

        Every image is conditioned with the default camera from the config
        (``default_distance``, ``default_fovy_rad``).

        Args:
            image: One RGBA PIL image, or a non-empty list of them (one mesh
                per image).
            bake_resolution: Side length of the baked texture maps.
            pointcloud: Optional conditioning point cloud; when given, point
                diffusion is skipped. See ``generate_mesh`` for accepted forms.
            remesh: ``"triangle"`` / ``"quad"`` remeshing, or ``"none"``.
            vertex_count: Target vertex count for remeshing.
            estimate_illumination: Also run ``global_estimator``.
            return_points: Also store the conditioning point clouds as
                ``trimesh.PointCloud`` objects under ``"point_clouds"``.

        Returns:
            ``(mesh, global_dict)`` for a single non-list image, otherwise
            ``(list_of_meshes, global_dict)``.

        Raises:
            ValueError: If an image is not RGBA or ``image`` is an empty list.
        """
        if isinstance(image, list):
            if not image:
                raise ValueError("image list must not be empty")
            mask_list, rgb_list = [], []
            for img in image:
                mask, rgb = self.prepare_image(img)
                mask_list.append(mask)
                rgb_list.append(rgb)
            rgb_cond = torch.stack(rgb_list, 0)
            mask_cond = torch.stack(mask_list, 0)
            batch_size = rgb_cond.shape[0]
        else:
            mask_cond, rgb_cond = self.prepare_image(image)
            batch_size = 1

        c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
        intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_rad(
            self.cfg.default_fovy_rad,
            self.cfg.cond_image_size,
            self.cfg.cond_image_size,
        )

        # The same camera is repeated for every batch element (one view each).
        batch = {
            "rgb_cond": rgb_cond,
            "mask_cond": mask_cond,
            "c2w_cond": c2w_cond.view(1, 1, 4, 4).repeat(batch_size, 1, 1, 1),
            "intrinsic_cond": intrinsic.to(self.device)
            .view(1, 1, 3, 3)
            .repeat(batch_size, 1, 1, 1),
            "intrinsic_normed_cond": intrinsic_normed_cond.to(self.device)
            .view(1, 1, 3, 3)
            .repeat(batch_size, 1, 1, 1),
        }

        meshes, global_dict = self.generate_mesh(
            batch,
            bake_resolution,
            pointcloud,
            remesh,
            vertex_count,
            estimate_illumination,
        )

        if return_points:
            # NOTE(review): a user-supplied pointcloud is stored with batch size
            # 1, so combining it with a list of several images fails here at
            # index 1 — confirm whether that combination is meant to be
            # supported.
            point_clouds = []
            for i in range(batch_size):
                pc = batch["pc_cond"][i]
                xyz = pc[:, :3].cpu().numpy()
                # Clamp so out-of-range colours cannot wrap around in uint8.
                color_rgb = (
                    (pc[:, 3:6].clamp(0, 1) * 255).cpu().numpy().astype(np.uint8)
                )
                point_clouds.append(
                    trimesh.PointCloud(vertices=xyz, colors=color_rgb)
                )
            global_dict["point_clouds"] = point_clouds

        if batch_size == 1:
            return meshes[0], global_dict
        return meshes, global_dict

    def prepare_image(self, image):
        """Convert an RGBA PIL image into conditioning tensors on ``self.device``.

        The image is resized to ``cond_image_size`` squared and scaled to
        ``[0, 1]``. The RGB channels are alpha-composited over
        ``cfg.background_color``.

        Returns:
            ``(mask_cond, rgb_cond)`` with shapes ``(H, W, 1)`` and ``(H, W, 3)``.

        Raises:
            ValueError: If the image is not in RGBA mode.
        """
        if image.mode != "RGBA":
            raise ValueError("Image must be in RGBA mode")

        resized = image.resize((self.cfg.cond_image_size, self.cfg.cond_image_size))
        img_cond = (
            torch.from_numpy(np.asarray(resized).astype(np.float32) / 255.0)
            .float()
            .clip(0, 1)
            .to(self.device)
        )
        mask_cond = img_cond[:, :, -1:]
        background = torch.tensor(self.cfg.background_color, device=self.device)
        # lerp(bg, rgb, alpha) = bg + alpha * (rgb - bg)
        rgb_cond = torch.lerp(background[None, None, :], img_cond[:, :, :3], mask_cond)
        return mask_cond, rgb_cond

    def _pointcloud_to_tensor(
        self, pointcloud: Union[List[float], np.ndarray, Tensor]
    ) -> Tensor:
        """Convert a user point cloud to a ``(1, N, C)`` float tensor on ``self.device``.

        A list is converted with ``torch.tensor`` and reshaped to ``(-1, 6)``
        (xyz + rgb). An array or tensor is expected to be ``(N, C)`` with xyz in
        the first three columns.

        Raises:
            ValueError: For any other input type.
        """
        if isinstance(pointcloud, list):
            points = torch.tensor(pointcloud).float().to(self.device).view(-1, 6)
        elif isinstance(pointcloud, (np.ndarray, Tensor)):
            points = torch.as_tensor(pointcloud).float().to(self.device)
        else:
            raise ValueError("Invalid point cloud type")
        # Split/re-join keeps the original xyz-then-colour layout check: 2-D
        # input is required, so a malformed 1-D array fails here.
        xyz = points[:, :3]
        color_rgb = points[:, 3:]
        return torch.cat([xyz, color_rgb], dim=-1).unsqueeze(0)

    def generate_mesh(
        self,
        batch,
        bake_resolution: int,
        pointcloud: Optional[Union[List[float], np.ndarray, Tensor]] = None,
        remesh: Literal["none", "triangle", "quad"] = "none",
        vertex_count: int = -1,
        estimate_illumination: bool = False,
    ) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
        """Run the full pipeline on a prepared conditioning batch.

        ``batch`` is modified in place. ``rgb_cond`` and ``mask_cond`` are
        passed through ``image_processor``. ``pc_cond`` is set from
        ``pointcloud`` if given, or sampled by point diffusion if absent. A
        pre-existing ``pc_cond`` with ``pointcloud=None`` is used as-is.

        Args:
            batch: Conditioning dict as built by ``run_image``.
            bake_resolution: Side length of the baked texture maps.
            pointcloud: Optional point cloud: a flat list of xyzrgb values, or
                an ``(N, C)`` numpy array / torch tensor.
            remesh: ``"triangle"`` / ``"quad"`` remeshing, or ``"none"``.
            vertex_count: Target vertex count for remeshing; with
                ``remesh="none"`` a positive value only prints a warning.
            estimate_illumination: Also run ``global_estimator`` on the
                non-post-processed scene codes.

        Returns:
            ``(meshes, global_dict)``. ``meshes`` holds one ``trimesh.Trimesh``
            per batch element. ``global_dict`` holds estimator outputs plus the
            conditioning point cloud under ``"pointcloud"``.
        """
        batch["rgb_cond"] = self.image_processor(
            batch["rgb_cond"], self.cfg.cond_image_size
        )
        batch["mask_cond"] = self.image_processor(
            batch["mask_cond"], self.cfg.cond_image_size
        )
        batch_size = batch["rgb_cond"].shape[0]

        if pointcloud is not None:
            batch["pc_cond"] = self._pointcloud_to_tensor(pointcloud)

        if "pc_cond" not in batch:
            cond_tokens = self.forward_pdiff_cond(batch)
            sample_iter = self.sampler.sample_batch_progressive(
                batch_size, cond_tokens, device=self.device
            )
            # Only the "xstart" of the last item yielded by the sampler is kept.
            samples = None
            for x in sample_iter:
                samples = x["xstart"]
            if samples is None:
                raise RuntimeError("Point diffusion sampler produced no samples")

            denoised_pc = samples.permute(0, 2, 1).float()  # [B, C, N] -> [B, N, C]
            denoised_pc = normalize_pc_bbox(denoised_pc)

            # predict the full 3D conditioned on the denoised point cloud
            batch["pc_cond"] = denoised_pc

        scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)

        global_dict = {}
        if self.image_estimator is not None:
            global_dict.update(
                self.image_estimator(
                    torch.cat([batch["rgb_cond"], batch["mask_cond"]], dim=-1)
                )
            )
        if self.global_estimator is not None and estimate_illumination:
            global_dict.update(self.global_estimator(non_postprocessed_codes))

        global_dict["pointcloud"] = batch["pc_cond"]

        device = get_device()
        # Mesh extraction and baking run without grad; on CUDA autocast is
        # explicitly disabled so this part runs at full precision.
        autocast_ctx = (
            torch.autocast(device_type=device, enabled=False)
            if "cuda" in device
            else nullcontext()
        )
        with torch.no_grad(), autocast_ctx:
            meshes = self.triplane_to_meshes(scene_codes)
            rets = [
                self._bake_mesh(
                    mesh,
                    scene_codes[i],
                    global_dict,
                    i,
                    bake_resolution,
                    remesh,
                    vertex_count,
                )
                for i, mesh in enumerate(meshes)
            ]

        return rets, global_dict

    def _bake_mesh(
        self,
        mesh: Mesh,
        scene_code: Tensor,
        global_dict: dict[str, Any],
        index: int,
        bake_resolution: int,
        remesh: str,
        vertex_count: int,
    ) -> trimesh.Trimesh:
        """Remesh, UV-unwrap and texture one extracted mesh.

        Args:
            mesh: Mesh from ``triplane_to_meshes`` for batch element ``index``.
            scene_code: Triplane scene code of the same batch element.
            global_dict: Estimator outputs. Entries whose key starts with
                ``"decoder_"`` override the matching decoder output (prefix
                removed), indexed with ``index``.
            index: Batch index of this mesh.
            bake_resolution: Side length of the baked texture maps.
            remesh: ``"triangle"``, ``"quad"``, or anything else for none.
            vertex_count: Target vertex count for remeshing.

        Returns:
            A textured ``trimesh.Trimesh`` with a PBR material. An empty
            ``trimesh.Trimesh()`` is returned if ``mesh`` has no vertices.
        """
        # Check for empty mesh
        if mesh.v_pos.shape[0] == 0:
            return trimesh.Trimesh()

        if remesh == "triangle":
            mesh = mesh.triangle_remesh(triangle_vertex_count=vertex_count)
        elif remesh == "quad":
            mesh = mesh.quad_remesh(quad_vertex_count=vertex_count)
        elif vertex_count > 0:
            print("Warning: vertex_count is ignored when remesh is none")

        if remesh != "none":
            print(
                f"After {remesh} remesh the mesh has {mesh.v_pos.shape[0]} verts and {mesh.t_pos_idx.shape[0]} faces",
            )

        mesh.unwrap_uv()

        # Rasterise the UV layout at bake resolution. bake_mask presumably
        # marks texels covered by a triangle; only those texels get values.
        rast = self.baker.rasterize(mesh.v_tex, mesh.t_pos_idx, bake_resolution)
        bake_mask = self.baker.get_mask(rast)

        # Query the triplane at each covered texel's interpolated 3-D position.
        pos_bake = self.baker.interpolate(mesh.v_pos, rast, mesh.t_pos_idx)
        gb_pos = pos_bake[bake_mask]
        tri_query = self.query_triplane(gb_pos, scene_code)[0]
        decoded = self.decoder(tri_query, exclude=["density", "vertex_offset"])

        nrm = self.baker.interpolate(mesh.v_nrm, rast, mesh.t_pos_idx)
        gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
        decoded["normal"] = gb_nrm

        # Let estimator outputs keyed "decoder_<name>" override decoded[<name>].
        for k, v in global_dict.items():
            if k.startswith("decoder_"):
                decoded[k.replace("decoder_", "")] = v[index]

        mat_out = {
            "albedo": decoded["features"],
            "roughness": decoded["roughness"],
            "metallic": decoded["metallic"],
            "normal": normalize(decoded["perturb_normal"]),
            "bump": None,
        }

        # Scatter per-texel values into bake_resolution^2 maps. A value with a
        # leading dim of 1 is kept as a single value instead. "bump" starts as
        # None and is filled while processing "normal". By the time the loop
        # reaches "bump" it already has the full map shape, so the
        # shape-equality check skips it. Only existing keys are reassigned,
        # which is safe during iteration.
        for k, v in mat_out.items():
            if v is None:
                continue
            if v.shape[0] == 1:
                # Skip and directly add a single value
                mat_out[k] = v[0]
                continue

            f = torch.zeros(
                bake_resolution,
                bake_resolution,
                v.shape[-1],
                dtype=v.dtype,
                device=v.device,
            )
            if v.shape == f.shape:
                continue

            if k == "normal":
                # Build a per-texel tangent frame. Tangents are interpolated
                # from the vertices first and normalised per texel afterwards.
                tng = self.baker.interpolate(mesh.v_tng, rast, mesh.t_pos_idx)
                gb_tng = F.normalize(tng[bake_mask], dim=-1)
                gb_btng = F.normalize(torch.cross(gb_nrm, gb_tng, dim=-1), dim=-1)
                normal = F.normalize(mat_out["normal"], dim=-1)

                # Columns are (T, B, N). Multiplying by the transpose expresses
                # the decoded normal in the per-texel tangent frame.
                tangent_matrix = torch.stack([gb_tng, gb_btng, gb_nrm], dim=-1)
                normal_tangent = torch.bmm(
                    tangent_matrix.transpose(1, 2), normal.unsqueeze(-1)
                ).squeeze(-1)

                # Convert from [-1,1] to [0,1] range for storage
                normal_tangent = (normal_tangent * 0.5 + 0.5).clamp(0, 1)

                f[bake_mask] = normal_tangent.view(-1, 3)
                mat_out["bump"] = f
            else:
                f[bake_mask] = v.view(-1, v.shape[-1])
                mat_out[k] = f

        def uv_padding(arr):
            """Dilate map values outward from the bake_mask region.

            This is presumably to hide seams at UV-island borders. 1-D
            (single-value) inputs are returned unchanged.
            """
            if arr.ndim == 1:
                return arr
            return (
                dilate_fill(
                    arr.permute(2, 0, 1)[None, ...].contiguous(),
                    bake_mask.unsqueeze(0).unsqueeze(0),
                    iterations=bake_resolution // 150,
                )
                .squeeze(0)
                .permute(1, 2, 0)
                .contiguous()
            )

        verts_np = convert_data(mesh.v_pos)
        faces = convert_data(mesh.t_pos_idx)
        uvs = convert_data(mesh.v_tex)

        basecolor_tex = Image.fromarray(
            float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
        ).convert("RGB")
        basecolor_tex.format = "JPEG"

        # NOTE(review): .item() assumes metallic/roughness were single values
        # (leading dim 1 above); a per-texel map would fail here.
        metallic = mat_out["metallic"].squeeze().cpu().item()
        roughness = mat_out["roughness"].squeeze().cpu().item()

        if mat_out.get("bump") is not None:
            bump_np = convert_data(uv_padding(mat_out["bump"]))
            # A flat tangent-space normal (0, 0, 1) encodes to (0.5, 0.5, 1).
            bump_up = np.ones_like(bump_np)
            bump_up[..., :2] = 0.5
            bump_up[..., 2:] = 1
            bump_tex = Image.fromarray(
                float32_to_uint8_np(
                    bump_np,
                    dither=True,
                    # Do not dither if something is perfectly flat
                    dither_mask=np.all(
                        bump_np == bump_up, axis=-1, keepdims=True
                    ).astype(np.float32),
                )
            ).convert("RGB")
            bump_tex.format = "JPEG"  # PNG would be better but the assets are larger
        else:
            bump_tex = None

        material = trimesh.visual.material.PBRMaterial(
            baseColorTexture=basecolor_tex,
            roughnessFactor=roughness,
            metallicFactor=metallic,
            normalTexture=bump_tex,
        )

        tmesh = trimesh.Trimesh(
            vertices=verts_np,
            faces=faces,
            visual=trimesh.visual.texture.TextureVisuals(uv=uvs, material=material),
        )
        # Rotate -90 deg about X, then +90 deg about Y, and reverse the face
        # winding. Presumably this converts to the output coordinate
        # convention; confirm against consumers.
        tmesh.apply_transform(
            trimesh.transformations.rotation_matrix(np.radians(-90), [1, 0, 0])
        )
        tmesh.apply_transform(
            trimesh.transformations.rotation_matrix(np.radians(90), [0, 1, 0])
        )
        tmesh.invert()
        return tmesh