from __future__ import annotations

import logging
import math
from typing import Any, Dict, Optional

import numpy as np
import torch
import torch.nn.functional as F
import trimesh
from jaxtyping import Float, Integer
from torch import Tensor

from spar3d.models.utils import dot

try:
    from uv_unwrapper import Unwrapper
except ImportError as err:
    logging.warning(
        "Could not import uv_unwrapper. Please install it via `pip install uv_unwrapper/`"
    )
    # Exit early to avoid further errors
    raise ImportError("uv_unwrapper not found") from err

try:
    import gpytoolbox

    TRIANGLE_REMESH_AVAILABLE = True
except ImportError:
    TRIANGLE_REMESH_AVAILABLE = False
    logging.warning(
        "Could not import gpytoolbox. Triangle remeshing functionality will be disabled. "
        "Install via `pip install gpytoolbox`"
    )

try:
    import pynim

    QUAD_REMESH_AVAILABLE = True
except ImportError:
    QUAD_REMESH_AVAILABLE = False
    logging.warning(
        "Could not import pynim. Quad remeshing functionality will be disabled. "
        "Install via `pip install git+https://github.com/vork/PyNanoInstantMeshes.git@v0.0.3`"
    )


class Mesh:
    """Triangle mesh with lazily computed, cached per-vertex attributes.

    The mesh stores vertex positions (``v_pos``) and triangle indices
    (``t_pos_idx``). Normals, tangents, UV coordinates and unique edges are
    derived on first access through the corresponding properties and cached
    in the underscore-prefixed attributes.

    Note that accessing ``v_tex`` (directly or through ``v_tng``) may call
    ``unwrap_uv``, which replaces ``v_pos`` and ``t_pos_idx`` in place.
    """

    def __init__(
        self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
    ) -> None:
        """Create a mesh.

        Args:
            v_pos: Vertex positions.
            t_pos_idx: Per-triangle vertex indices into ``v_pos``.
            **kwargs: Arbitrary extra entries stored in ``self.extras``.
        """
        self.v_pos: Float[Tensor, "Nv 3"] = v_pos
        self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
        # Lazily computed caches; ``None`` means "not computed yet".
        self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
        # UVs are two-dimensional (see the reshape to (-1, 2) in ``unwrap_uv``).
        self._v_tex: Optional[Float[Tensor, "Nt 2"]] = None
        self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.extras: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.add_extra(k, v)

        self.unwrapper = Unwrapper()

    def add_extra(self, k, v) -> None:
        """Store ``v`` under key ``k`` in ``self.extras``."""
        self.extras[k] = v

    @property
    def requires_grad(self):
        """Whether the vertex position tensor requires gradients."""
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        """Per-vertex unit normals, computed on first access."""
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        """Per-vertex unit tangents, computed on first access."""
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        """Per-vertex UV coordinates; triggers ``unwrap_uv`` on first access."""
        if self._v_tex is None:
            self.unwrap_uv()
        return self._v_tex

    @property
    def edges(self):
        """Unique undirected edges as index pairs, computed on first access."""
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Compute unit vertex normals by accumulating face normals.

        Face normals are the un-normalized cross products of two triangle
        edges, so each face contributes with a weight that grows with the
        length of that cross product.
        """
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]

        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]

        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals to vertices
        v_nrm = torch.zeros_like(self.v_pos)
        for corner in (i0, i1, i2):
            v_nrm.scatter_add_(0, corner[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    def _compute_vertex_tangent(self):
        """Compute unit vertex tangents from positions and UV coordinates.

        A tangent is derived per triangle from its position and UV deltas,
        averaged onto the triangle's three vertices, normalized, and finally
        made perpendicular to the vertex normal.
        """
        # Resolve the UVs first: this may call ``unwrap_uv``, which replaces
        # ``v_pos`` / ``t_pos_idx``. Everything below must be gathered from the
        # (possibly updated) topology.
        v_tex = self.v_tex

        # t_nrm_idx is always the same as t_pos_idx
        vn_idx = [self.t_pos_idx[:, i] for i in range(3)]
        pos = [self.v_pos[idx] for idx in vn_idx]
        tex = [v_tex[idx] for idx in vn_idx]

        tangents = torch.zeros_like(self.v_nrm)
        tansum = torch.zeros_like(self.v_nrm)

        # Compute tangent space for each triangle
        duv1 = tex[1] - tex[0]
        duv2 = tex[2] - tex[0]
        dpos1 = pos[1] - pos[0]
        dpos2 = pos[2] - pos[0]

        tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]

        denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

        # Avoid division by zero for degenerated texture coordinates while
        # keeping the sign of the determinant. A plain lower clamp would turn
        # every negative determinant into +1e-6, which flips the tangent and
        # blows up its magnitude.
        denom_safe = torch.where(
            denom >= 0.0, denom.clamp(min=1e-6), denom.clamp(max=-1e-6)
        )
        tang = tng_nom / denom_safe

        # Update all 3 vertices
        for i in range(3):
            idx = vn_idx[i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
            tansum.scatter_add_(
                0, idx, torch.ones_like(tang)
            )  # tansum[n_i] = tansum[n_i] + 1

        # Average the accumulated tangents. The per-triangle tangents are not
        # normalized before accumulation, so their magnitudes act as weights.
        # Vertices that no face references have a count of zero; clamping the
        # count keeps them at a finite zero vector instead of NaN.
        tangents = tangents / tansum.clamp(min=1.0)

        # Normalize and make sure tangent is perpendicular to normal
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    def quad_remesh(
        self,
        quad_vertex_count: int = -1,
        quad_rosy: int = 4,
        quad_crease_angle: float = -1.0,
        quad_smooth_iter: int = 2,
        quad_align_to_boundaries: bool = False,
    ) -> Mesh:
        """Remesh with ``pynim.remesh`` and return the result as a new mesh.

        Args:
            quad_vertex_count: Vertex budget; a negative value uses the
                current vertex count. A quarter of it is passed to pynim.
            quad_rosy: Forwarded to pynim as ``rosy``.
            quad_crease_angle: Forwarded to pynim as ``creaseAngle``.
            quad_smooth_iter: Forwarded to pynim as ``smooth_iter``.
            quad_align_to_boundaries: Forwarded as ``align_to_boundaries``.

        Returns:
            A new ``Mesh`` on the same device/dtype as this one. Entries of
            ``self.extras`` are not carried over.

        Raises:
            ImportError: If pynim is not installed.
        """
        if not QUAD_REMESH_AVAILABLE:
            raise ImportError("Quad remeshing requires pynim to be installed")
        if quad_vertex_count < 0:
            quad_vertex_count = self.v_pos.shape[0]
        v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
        t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.uint32)

        new_vert, new_faces = pynim.remesh(
            v_pos,
            t_pos_idx,
            quad_vertex_count // 4,
            rosy=quad_rosy,
            posy=4,
            creaseAngle=quad_crease_angle,
            align_to_boundaries=quad_align_to_boundaries,
            smooth_iter=quad_smooth_iter,
            deterministic=False,
        )

        # Briefly load in trimesh
        mesh = trimesh.Trimesh(vertices=new_vert, faces=new_faces.astype(np.int32))

        v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx).contiguous()

        # Create new mesh
        return Mesh(v_pos, t_pos_idx)

    def triangle_remesh(
        self,
        triangle_average_edge_length_multiplier: Optional[float] = None,
        triangle_remesh_steps: int = 10,
        triangle_vertex_count: int = -1,
    ) -> Mesh:
        """Remesh with gpytoolbox and return the result as a new mesh.

        When ``triangle_vertex_count`` is positive, the mesh is first
        subdivided (if it has fewer vertices than requested) and decimated
        towards that count. In that case this mesh is updated IN PLACE with
        the decimated geometry and the edge-length multiplier is ignored.
        Afterwards ``gpytoolbox.remesh_botsch`` is run on the geometry.

        Args:
            triangle_average_edge_length_multiplier: If given (and no vertex
                count is requested), the target edge length passed to the
                remesher is the current mean edge length times this factor.
                ``None`` passes ``None`` as the target edge length.
            triangle_remesh_steps: Iteration count passed to the remesher.
            triangle_vertex_count: Desired vertex count; values <= 0 disable
                the subdivide/decimate stage.

        Returns:
            A new ``Mesh`` on the same device/dtype as this one. Entries of
            ``self.extras`` are not carried over.

        Raises:
            ImportError: If gpytoolbox is not installed.
        """
        if not TRIANGLE_REMESH_AVAILABLE:
            raise ImportError("Triangle remeshing requires gpytoolbox to be installed")
        if triangle_vertex_count > 0:
            reduction = triangle_vertex_count / self.v_pos.shape[0]
            print("Triangle reduction:", reduction)
            v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
            t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
            if reduction > 1.0:
                # log2 is exact for powers of two, whereas log(x) / log(2) can
                # land slightly above an integer and cost an extra iteration.
                subdivide_iters = int(math.ceil(math.log2(reduction)))
                print("Subdivide iters:", subdivide_iters)
                v_pos, t_pos_idx = gpytoolbox.subdivide(
                    v_pos,
                    t_pos_idx,
                    iters=subdivide_iters,
                )
                reduction = triangle_vertex_count / v_pos.shape[0]

            # Simplify
            points_out, faces_out, _, _ = gpytoolbox.decimate(
                v_pos,
                t_pos_idx,
                face_ratio=reduction,
            )

            # Convert back to torch
            self.v_pos = torch.from_numpy(points_out).to(self.v_pos)
            self.t_pos_idx = torch.from_numpy(faces_out).to(self.t_pos_idx)
            # The geometry changed, so every derived cache is stale now.
            self._v_nrm = None
            self._v_tng = None
            self._v_tex = None
            self._edges = None
            triangle_average_edge_length_multiplier = None

        if triangle_average_edge_length_multiplier is None:
            h = None
        else:
            edges = self.edges
            h = float(
                torch.linalg.norm(
                    self.v_pos[edges[:, 0]] - self.v_pos[edges[:, 1]], dim=1
                )
                .mean()
                .item()
                * triangle_average_edge_length_multiplier
            )

        # Convert to numpy
        v_pos = self.v_pos.detach().cpu().numpy().astype(np.float64)
        t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)

        # Remesh
        v_remesh, f_remesh = gpytoolbox.remesh_botsch(
            v_pos,
            t_pos_idx,
            triangle_remesh_steps,
            h,
        )

        # Convert back to torch
        v_pos = torch.from_numpy(v_remesh).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(f_remesh).to(self.t_pos_idx).contiguous()

        # Create new mesh
        return Mesh(v_pos, t_pos_idx)

    @torch.no_grad()
    def unwrap_uv(
        self,
        island_padding: float = 0.02,
    ) -> None:
        """Compute UVs with the unwrapper and rebuild the mesh in place.

        Every triangle receives its own three vertices, so ``v_pos`` and
        ``t_pos_idx`` are replaced and vertices are no longer shared between
        faces. Normals and tangents are recomputed for the new layout and the
        edge cache is invalidated. Entries of ``self.extras`` are not updated.

        Args:
            island_padding: Forwarded to the unwrapper.
        """
        uv, indices = self.unwrapper(
            self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
        )

        # Do store per vertex UVs.
        # This means we need to duplicate some vertices at the seams
        individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
        individual_faces = torch.arange(
            individual_vertices.shape[0],
            device=individual_vertices.device,
            dtype=self.t_pos_idx.dtype,
        ).reshape(-1, 3)
        uv_flat = uv[indices].reshape((-1, 2))
        # uv_flat[:, 1] = 1 - uv_flat[:, 1]

        self.v_pos = individual_vertices
        self.t_pos_idx = individual_faces
        self._v_tex = uv_flat
        # Edges cached for the previous topology no longer apply.
        self._edges = None
        self._v_nrm = self._compute_vertex_normal()
        self._v_tng = self._compute_vertex_tangent()

    def _compute_edges(self):
        """Return the unique undirected edges of all triangles."""
        # Compute edges
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        # Sort each pair so (a, b) and (b, a) collapse to the same row.
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges