|
from __future__ import annotations |
|
|
|
import math |
|
from typing import Any, Dict, Optional |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import trimesh |
|
from jaxtyping import Float, Integer |
|
from torch import Tensor |
|
|
|
from spar3d.models.utils import dot |
|
|
|
import logging

# uv_unwrapper is a hard requirement: Mesh.__init__ instantiates Unwrapper,
# so a missing package is logged and then re-raised (chained to the cause).
try:
    from uv_unwrapper import Unwrapper
except ImportError as err:
    logging.warning(
        "Could not import uv_unwrapper. Please install it via `pip install uv_unwrapper/`"
    )
    raise ImportError("uv_unwrapper not found") from err

# gpytoolbox is optional; Mesh.triangle_remesh checks this flag before use.
try:
    import gpytoolbox

    TRIANGLE_REMESH_AVAILABLE = True
except ImportError:
    TRIANGLE_REMESH_AVAILABLE = False
    logging.warning(
        "Could not import gpytoolbox. Triangle remeshing functionality will be disabled. "
        "Install via `pip install gpytoolbox`"
    )

# pynim is optional; Mesh.quad_remesh checks this flag before use.
try:
    import pynim

    QUAD_REMESH_AVAILABLE = True
except ImportError:
    QUAD_REMESH_AVAILABLE = False
    logging.warning(
        "Could not import pynim. Quad remeshing functionality will be disabled. "
        "Install via `pip install git+https://github.com/vork/PyNanoInstantMeshes.git@v0.0.3`"
    )
|
|
|
|
|
class Mesh:
    """Triangle mesh held as torch tensors, with lazily cached derived data.

    Vertex normals, tangents, UV coordinates and the unique edge list are
    computed on first access through the matching property and then cached.
    The first access to ``v_tex`` (directly, or indirectly through ``v_tng``)
    runs :meth:`unwrap_uv`, which rewrites ``v_pos`` and ``t_pos_idx`` in
    place.
    """

    def __init__(
        self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
    ) -> None:
        """Store the geometry and register every extra keyword in ``extras``.

        Args:
            v_pos: Vertex positions.
            t_pos_idx: Per-face vertex indices into ``v_pos``.
            **kwargs: Arbitrary additional data, stored via :meth:`add_extra`.
        """
        self.v_pos: Float[Tensor, "Nv 3"] = v_pos
        self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
        # Lazily filled caches; ``None`` means "not computed yet".
        self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
        # UVs are stored with two columns (see the reshape in unwrap_uv).
        self._v_tex: Optional[Float[Tensor, "Nt 2"]] = None
        self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.extras: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.add_extra(k, v)

        self.unwrapper = Unwrapper()

    def add_extra(self, k, v) -> None:
        """Store ``v`` under key ``k`` in ``extras``, replacing any old value."""
        self.extras[k] = v

    @property
    def requires_grad(self):
        """Whether the vertex position tensor requires gradients."""
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        """Per-vertex normals, computed on first access and cached."""
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        """Per-vertex tangents, computed on first access and cached."""
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        """Per-vertex UVs; the first access unwraps the mesh in place."""
        if self._v_tex is None:
            self.unwrap_uv()
        return self._v_tex

    @property
    def edges(self):
        """Unique undirected edges, computed on first access and cached."""
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Return unit per-vertex normals accumulated from incident faces.

        The un-normalized cross product of two edges of each face is summed
        into the face's three vertices, so faces with a larger cross product
        (i.e. larger area) weigh more. Vertices whose accumulated vector is
        (near) zero fall back to ``(0, 0, 1)``.
        """
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]

        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]

        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat each face normal onto the three vertices of that face.
        v_nrm = torch.zeros_like(self.v_pos)
        for corner in (i0, i1, i2):
            v_nrm.scatter_add_(0, corner[:, None].repeat(1, 3), face_normals)

        # Replace degenerate sums with a fixed direction before normalizing.
        # NOTE(review): `dot` is a project helper; it is used here as a
        # row-wise dot product that broadcasts against v_nrm - confirm.
        fallback = torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        v_nrm = torch.where(dot(v_nrm, v_nrm) > 1e-20, v_nrm, fallback)
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    @staticmethod
    def _signed_clamp(denom, eps: float = 1e-6):
        """Move ``denom`` at least ``eps`` away from zero, keeping its sign.

        Non-negative entries are raised to at least ``+eps``; negative
        entries are lowered to at most ``-eps``. Exact zeros become ``+eps``.
        """
        return torch.where(denom >= 0, denom.clamp(min=eps), denom.clamp(max=-eps))

    def _compute_vertex_tangent(self):
        """Return unit per-vertex tangents derived from the UV coordinates.

        A tangent is computed per face from its edge vectors and UV deltas,
        averaged over the faces touching each vertex, normalized, and finally
        stripped of its component along the vertex normal and re-normalized.
        """
        # Read the UVs first: this may trigger unwrap_uv(), which replaces
        # v_pos / t_pos_idx, so the topology has to be read afterwards.
        v_tex = self.v_tex
        v_nrm = self.v_nrm

        corner_idx = [self.t_pos_idx[:, i] for i in range(3)]
        pos = [self.v_pos[idx] for idx in corner_idx]
        tex = [v_tex[idx] for idx in corner_idx]

        tangents = torch.zeros_like(v_nrm)
        tansum = torch.zeros_like(v_nrm)

        # Per-face edge vectors in UV space and in 3D.
        duv1 = tex[1] - tex[0]
        duv2 = tex[2] - tex[0]
        dpos1 = pos[1] - pos[0]
        dpos2 = pos[2] - pos[0]

        tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]

        # Determinant of the UV edge matrix; negative for mirrored UVs.
        denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

        # Avoid division by ~0 without destroying the sign of the determinant
        # (a plain lower clamp would turn every negative value into +eps).
        tang = tng_nom / self._signed_clamp(denom)

        # Accumulate face tangents and a per-vertex face count.
        ones = torch.ones_like(tang)
        for idx in corner_idx:
            scatter_idx = idx[:, None].repeat(1, 3)
            tangents.scatter_add_(0, scatter_idx, tang)
            tansum.scatter_add_(0, scatter_idx, ones)

        # Average; vertices not referenced by any face have a count of zero
        # and keep a zero tangent instead of becoming NaN.
        tangents = tangents / tansum.clamp(min=1.0)

        # Normalize, then remove the component along the normal.
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, v_nrm) * v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    def quad_remesh(
        self,
        quad_vertex_count: int = -1,
        quad_rosy: int = 4,
        quad_crease_angle: float = -1.0,
        quad_smooth_iter: int = 2,
        quad_align_to_boundaries: bool = False,
    ) -> Mesh:
        """Remesh with pynim and return the result as a new ``Mesh``.

        Args:
            quad_vertex_count: Requested vertex budget; a negative value uses
                the current vertex count. pynim receives this value ``// 4``.
            quad_rosy: Forwarded to pynim as ``rosy``.
            quad_crease_angle: Forwarded to pynim as ``creaseAngle``.
            quad_smooth_iter: Forwarded to pynim as ``smooth_iter``.
            quad_align_to_boundaries: Forwarded as ``align_to_boundaries``.

        Returns:
            A new ``Mesh`` on the same device/dtype as this one; ``self`` is
            not modified.

        Raises:
            ImportError: If pynim could not be imported.
        """
        if not QUAD_REMESH_AVAILABLE:
            raise ImportError("Quad remeshing requires pynim to be installed")
        if quad_vertex_count < 0:
            quad_vertex_count = self.v_pos.shape[0]
        v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
        t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.uint32)

        new_vert, new_faces = pynim.remesh(
            v_pos,
            t_pos_idx,
            quad_vertex_count // 4,
            rosy=quad_rosy,
            posy=4,
            creaseAngle=quad_crease_angle,
            align_to_boundaries=quad_align_to_boundaries,
            smooth_iter=quad_smooth_iter,
            deterministic=False,
        )

        # Route the result through trimesh and read vertices/faces back.
        # NOTE(review): presumably this also triangulates quad faces - confirm.
        mesh = trimesh.Trimesh(vertices=new_vert, faces=new_faces.astype(np.int32))

        v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx).contiguous()

        return Mesh(v_pos, t_pos_idx)

    def triangle_remesh(
        self,
        triangle_average_edge_length_multiplier: Optional[float] = None,
        triangle_remesh_steps: int = 10,
        triangle_vertex_count: int = -1,
    ) -> Mesh:
        """Remesh with gpytoolbox and return the result as a new ``Mesh``.

        When ``triangle_vertex_count`` is positive the mesh is first brought
        towards that vertex count (subdividing if it has too few vertices,
        then decimating). This step overwrites ``self.v_pos`` and
        ``self.t_pos_idx`` and ignores the edge-length multiplier.

        Args:
            triangle_average_edge_length_multiplier: If given, the target
                edge length passed to ``remesh_botsch`` is the current mean
                edge length times this factor; otherwise ``None`` is passed.
            triangle_remesh_steps: Iteration count passed to ``remesh_botsch``.
            triangle_vertex_count: Target vertex count; non-positive disables
                the subdivide/decimate pre-pass.

        Returns:
            A new ``Mesh`` on the same device/dtype as this one.

        Raises:
            ImportError: If gpytoolbox could not be imported.
        """
        if not TRIANGLE_REMESH_AVAILABLE:
            raise ImportError("Triangle remeshing requires gpytoolbox to be installed")
        if triangle_vertex_count > 0:
            reduction = triangle_vertex_count / self.v_pos.shape[0]
            print("Triangle reduction:", reduction)
            v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
            t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
            if reduction > 1.0:
                # Too few vertices: subdivide until the mesh can be decimated
                # down to the target, then recompute the ratio.
                subdivide_iters = int(math.ceil(math.log(reduction) / math.log(2)))
                print("Subdivide iters:", subdivide_iters)
                v_pos, t_pos_idx = gpytoolbox.subdivide(
                    v_pos,
                    t_pos_idx,
                    iters=subdivide_iters,
                )
                reduction = triangle_vertex_count / v_pos.shape[0]

            points_out, faces_out, _, _ = gpytoolbox.decimate(
                v_pos,
                t_pos_idx,
                face_ratio=reduction,
            )

            self.v_pos = torch.from_numpy(points_out).to(self.v_pos)
            self.t_pos_idx = torch.from_numpy(faces_out).to(self.t_pos_idx)
            # The geometry changed, so every cached derived quantity is stale.
            self._edges = None
            self._v_nrm = None
            self._v_tng = None
            self._v_tex = None
            triangle_average_edge_length_multiplier = None

        if triangle_average_edge_length_multiplier is None:
            h = None
        else:
            # Only build the edge list when a target length is needed.
            edges = self.edges
            edge_vectors = self.v_pos[edges[:, 0]] - self.v_pos[edges[:, 1]]
            mean_edge_length = torch.linalg.norm(edge_vectors, dim=1).mean().item()
            h = float(mean_edge_length * triangle_average_edge_length_multiplier)

        v_pos = self.v_pos.detach().cpu().numpy().astype(np.float64)
        t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)

        v_remesh, f_remesh = gpytoolbox.remesh_botsch(
            v_pos,
            t_pos_idx,
            triangle_remesh_steps,
            h,
        )

        v_pos = torch.from_numpy(v_remesh).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(f_remesh).to(self.t_pos_idx).contiguous()

        return Mesh(v_pos, t_pos_idx)

    @torch.no_grad()
    def unwrap_uv(
        self,
        island_padding: float = 0.02,
    ) -> Mesh:
        """Unwrap the mesh in place and store per-vertex UVs.

        Every face gets its own three vertices (so vertices are duplicated),
        which lets each face corner carry its own UV coordinate. Normals and
        tangents are recomputed for the new layout and the edge cache is
        reset.

        Args:
            island_padding: Forwarded to the unwrapper.

        Returns:
            ``self``, after modification.
        """
        # NOTE(review): the unwrapper is assumed to return a UV table plus
        # per-face-corner indices into it - confirm against uv_unwrapper.
        uv, indices = self.unwrapper(
            self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
        )

        # One vertex per face corner; face i uses vertices 3i, 3i+1, 3i+2.
        individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
        individual_faces = torch.arange(
            individual_vertices.shape[0],
            device=individual_vertices.device,
            dtype=self.t_pos_idx.dtype,
        ).reshape(-1, 3)
        uv_flat = uv[indices].reshape((-1, 2))

        self.v_pos = individual_vertices
        self.t_pos_idx = individual_faces
        self._v_tex = uv_flat
        # Edges cached for the old topology no longer apply.
        self._edges = None
        self._v_nrm = self._compute_vertex_normal()
        self._v_tng = self._compute_vertex_tangent()

        return self

    def _compute_edges(self):
        """Return the unique undirected edges as rows of two vertex indices.

        Each row is sorted so that the smaller index comes first, and
        duplicate rows (edges shared between faces) are removed.
        """
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        # Sort within each row so (a, b) and (b, a) compare equal.
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges
|
|