from dataclasses import dataclass
from typing import Callable, Optional

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DiffusionPipeline
from diffusers.utils import BaseOutput


def pad_camera_extrinsics_4x4(extrinsics):
    """Append a ``[0, 0, 0, 1]`` row to ``(..., 3, 4)`` extrinsics.

    Args:
        extrinsics: Tensor whose last two dimensions are ``(3, 4)`` or
            ``(4, 4)``. Any number of leading batch dimensions is accepted.

    Returns:
        The input unchanged when it already has 4 rows, otherwise a new
        ``(..., 4, 4)`` tensor with the same dtype and device.
    """
    if extrinsics.shape[-2] == 4:
        return extrinsics
    # `.to(extrinsics)` matches both dtype and device of the input.
    padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
    # Broadcast the single row over every leading batch dimension.
    padding = padding.expand(*extrinsics.shape[:-2], 1, 4)
    return torch.cat([extrinsics, padding], dim=-2)


def center_looking_at_camera_pose(
    camera_position: torch.Tensor,
    look_at: Optional[torch.Tensor] = None,
    up_world: Optional[torch.Tensor] = None,
):
    """Build 4x4 look-at pose matrices for one or several camera positions.

    The z axis points from ``look_at`` towards the camera, the x axis is
    ``up_world x z`` and the y axis is ``z x x``. The matrix columns are
    ``[x, y, z, camera_position]``.

    Args:
        camera_position: ``(3,)`` or ``(M, 3)`` tensor of camera locations.
        look_at: ``(3,)`` target point. Defaults to the origin.
        up_world: ``(3,)`` world up vector. Defaults to ``+z``.

    Returns:
        ``(4, 4)`` or ``(M, 4, 4)`` tensor of poses.

    Note:
        When the viewing direction is parallel to ``up_world`` the cross
        product is zero and the resulting x/y axes are degenerate (zeros).
    """
    # Create the defaults on the same device as the positions so that the
    # arithmetic below does not fail for non-CPU inputs.
    device = camera_position.device
    if look_at is None:
        look_at = torch.tensor([0, 0, 0], dtype=torch.float32, device=device)
    if up_world is None:
        up_world = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
    if camera_position.ndim == 2:
        num_cameras = camera_position.shape[0]
        look_at = look_at.unsqueeze(0).repeat(num_cameras, 1)
        up_world = up_world.unsqueeze(0).repeat(num_cameras, 1)

    z_axis = F.normalize(camera_position - look_at, dim=-1).float()
    x_axis = F.normalize(
        torch.linalg.cross(up_world, z_axis, dim=-1), dim=-1
    ).float()
    y_axis = F.normalize(
        torch.linalg.cross(z_axis, x_axis, dim=-1), dim=-1
    ).float()

    # Stacking on the last dim makes each vector a column: (..., 3, 4).
    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
    return pad_camera_extrinsics_4x4(extrinsics)


def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
    """Return look-at poses for cameras placed on a sphere around the origin.

    Args:
        azimuths: Azimuth angles in degrees.
        elevations: Elevation angles in degrees, same shape as ``azimuths``.
        radius: Distance of every camera from the origin.

    Returns:
        Float32 tensor of 4x4 poses, one per (azimuth, elevation) pair, as
        produced by :func:`center_looking_at_camera_pose`.
    """
    azimuths = np.deg2rad(azimuths)
    elevations = np.deg2rad(elevations)

    # Spherical -> Cartesian with z as the vertical axis.
    xs = radius * np.cos(elevations) * np.cos(azimuths)
    ys = radius * np.cos(elevations) * np.sin(azimuths)
    zs = radius * np.sin(elevations)

    cam_locations = torch.from_numpy(np.stack([xs, ys, zs], axis=-1)).float()
    return center_looking_at_camera_pose(cam_locations)


def FOV_to_intrinsics(fov, device="cpu"):
    """Return a 3x3 intrinsics matrix for a field of view given in degrees.

    The focal length is ``0.5 / tan(fov / 2)`` and the principal point is
    ``(0.5, 0.5)``, i.e. the matrix is expressed for a unit-sized image.

    Args:
        fov: Field of view in degrees.
        device: Device on which the tensor is created.

    Returns:
        ``(3, 3)`` intrinsics tensor.
    """
    focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5)
    intrinsics = torch.tensor(
        [[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]],
        device=device,
    )
    return intrinsics


def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
    """Return the flattened camera parameters of the six fixed input views.

    Each view is encoded as 16 numbers: the first three rows of its 4x4 pose
    (12 values) followed by ``fx, fy, cx, cy`` taken from the intrinsics.

    Args:
        batch_size: Number of times the camera set is repeated.
        radius: Camera distance from the origin.
        fov: Field of view in degrees.

    Returns:
        Float32 tensor of shape ``(batch_size, 6, 16)``.
    """
    azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
    elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)
    num_views = len(azimuths)

    c2ws = spherical_camera_pose(azimuths, elevations, radius)
    c2ws = c2ws.float().flatten(-2)

    Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(num_views, 1, 1)
    Ks = Ks.float().flatten(-2)

    # First 12 entries of the row-major 4x4 == its top three rows.
    extrinsics = c2ws[:, :12]
    # Row-major 3x3 indices: 0 -> fx, 4 -> fy, 2 -> cx, 5 -> cy.
    intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1)

    cameras = torch.cat([extrinsics, intrinsics], dim=-1)
    return cameras.unsqueeze(0).repeat(batch_size, 1, 1)


@dataclass
class InstantMeshPipelineOutput(BaseOutput):
    """Mesh produced by :class:`InstantMeshPipeline`.

    ``uvs`` and ``texture`` are ``None`` when the pipeline is called with
    ``generate_texture=False``.
    """

    vertices: np.ndarray
    faces: np.ndarray
    uvs: Optional[np.ndarray]
    texture: Optional[np.ndarray]


def _weld_vertices_and_uvs(vertices, vertex_indices, uvs, uv_indices):
    """Merge separate position/UV index buffers into one shared index buffer.

    Every distinct (vertex index, uv index) pair becomes one output vertex,
    so positions and UVs can be addressed by a single triangle index array.

    Returns:
        Tuple ``(vertices, uvs, faces)`` where ``faces`` has shape ``(-1, 3)``.
    """
    pairs = np.stack(
        [vertex_indices.reshape(-1), uv_indices.reshape(-1)], axis=1
    )
    unique_pairs, inverse = np.unique(pairs, axis=0, return_inverse=True)
    welded_vertices = vertices[unique_pairs[:, 0]]
    welded_uvs = uvs[unique_pairs[:, 1]]
    # `inverse` maps each original corner to its welded vertex.
    faces = inverse.reshape(-1, 3)
    return welded_vertices, welded_uvs, faces


def _finalize_texture(texture, lo=0, hi=1):
    """Convert an ``(H, W, C)`` texture in ``[lo, hi]`` to a uint8 image.

    Near-black texels (channel sum <= 3 on the 0-255 scale) are replaced by
    the 3x3 dilation of the image so neighbouring colours bleed into them,
    and the result is flipped vertically.
    """
    img = np.asarray(texture, dtype=np.float32)
    img = (img - lo) * (255 / (hi - lo))
    img = img.clip(0, 255)

    mask = np.sum(img, axis=-1, keepdims=True)
    mask = (mask <= 3.0).astype(np.float32)

    kernel = np.ones((3, 3), "uint8")
    dilate_img = cv2.dilate(img, kernel, iterations=1)
    img = img * (1 - mask) + dilate_img * mask
    img = img.clip(0, 255).astype(np.uint8)

    # Reverse the row order; the slice is a view, so make it contiguous.
    return np.ascontiguousarray(img[::-1, :, :])


class InstantMeshPipeline(DiffusionPipeline):
    """Pipeline that turns multi-view images into a mesh using an LRM model."""

    def __init__(self, lrm):
        """Register the reconstruction model ``lrm`` with the pipeline."""
        super().__init__()
        self.lrm = lrm
        self.register_modules(lrm=self.lrm)

    @torch.no_grad()
    def __call__(
        self,
        images: torch.Tensor,
        generate_texture: bool = False,
        progress_callback: Optional[Callable[[float], None]] = None,
    ) -> InstantMeshPipelineOutput:
        """Reconstruct a mesh from the input views.

        Args:
            images: Input views passed to ``lrm.forward_planes`` together
                with the fixed camera set from
                :func:`get_zero123plus_input_cameras` (batch size 1).
                NOTE(review): expected layout is not visible here - confirm
                against the ``lrm`` implementation.
            generate_texture: When ``True`` also extract UVs and a texture
                map (resolution 1024); otherwise only geometry is returned.
            progress_callback: Optional callable forwarded to
                ``lrm.extract_mesh``.

        Returns:
            :class:`InstantMeshPipelineOutput` with NumPy arrays. ``uvs`` and
            ``texture`` are ``None`` when ``generate_texture`` is ``False``.
        """
        device = self._execution_device
        self.lrm.init_flexicubes_geometry(device, fovy=30.0)
        cameras = get_zero123plus_input_cameras().to(device)
        planes = self.lrm.forward_planes(images, cameras)

        if not generate_texture:
            mesh_out = self.lrm.extract_mesh(
                planes,
                use_texture_map=False,
                progress_callback=progress_callback,
            )
            vertices, faces, _ = mesh_out
            return InstantMeshPipelineOutput(
                vertices=vertices.cpu().numpy(),
                faces=faces.cpu().numpy(),
                uvs=None,
                texture=None,
            )

        mesh_out = self.lrm.extract_mesh(
            planes,
            use_texture_map=True,
            texture_resolution=1024,
            progress_callback=progress_callback,
        )
        vertices, vertex_indices, uvs, uv_indices, texture = mesh_out

        vertices, uvs, faces = _weld_vertices_and_uvs(
            vertices.cpu().numpy(),
            vertex_indices.cpu().numpy(),
            uvs.cpu().numpy(),
            uv_indices.cpu().numpy(),
        )
        # Channels-first -> channels-last before the image post-processing.
        texture = _finalize_texture(texture.permute(1, 2, 0).cpu().numpy())

        return InstantMeshPipelineOutput(
            vertices=vertices,
            faces=faces,
            uvs=uvs,
            texture=texture,
        )