InstantMesh / pipeline.py
dylanebert
optional progress callback
c8a48ed
from dataclasses import dataclass
from typing import Callable, Optional
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DiffusionPipeline
from diffusers.utils import BaseOutput
def pad_camera_extrinsics_4x4(extrinsics):
    """Append the homogeneous row ``[0, 0, 0, 1]`` to 3x4 extrinsics.

    Args:
        extrinsics: Tensor of shape ``(..., 3, 4)`` or ``(..., 4, 4)`` with any
            number of leading batch dimensions.

    Returns:
        The input itself when its second-to-last dimension is already 4,
        otherwise a new ``(..., 4, 4)`` tensor with the same dtype and device.
    """
    if extrinsics.shape[-2] == 4:
        return extrinsics
    # `.to(extrinsics)` matches both dtype and device of the input.
    bottom_row = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
    # Broadcast the row over every leading dimension, not just a single batch
    # axis, so inputs such as (B, V, 3, 4) are handled as well.
    bottom_row = bottom_row.expand(*extrinsics.shape[:-2], 1, 4)
    return torch.cat([extrinsics, bottom_row], dim=-2)
def center_looking_at_camera_pose(
    camera_position: torch.Tensor,
    look_at: Optional[torch.Tensor] = None,
    up_world: Optional[torch.Tensor] = None,
):
    """Build 4x4 pose matrices for cameras that look at a target point.

    The columns of the result are the camera x, y and z axes followed by the
    camera position; the z axis points from ``look_at`` towards the camera.

    Args:
        camera_position: Camera centres, shape ``(3,)`` or ``(..., 3)``.
        look_at: Target point(s), broadcastable against ``camera_position``.
            Defaults to the origin.
        up_world: World up vector(s), broadcastable against
            ``camera_position``. Defaults to ``+z``.

    Returns:
        Tensor of shape ``(..., 4, 4)``.

    Note:
        When the viewing direction is parallel to ``up_world`` the cross
        product is zero and the resulting x/y axes are degenerate.
    """
    if look_at is None:
        look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
    if up_world is None:
        up_world = torch.tensor([0, 0, 1], dtype=torch.float32)

    # The defaults are created on CPU; move everything to the camera's device
    # so CUDA inputs do not trigger a device-mismatch error.
    look_at = look_at.to(camera_position.device)
    up_world = up_world.to(camera_position.device)

    # Bring all three to a common shape. This covers a single vector shared by
    # a batch of cameras as well as per-camera targets / up vectors.
    camera_position, look_at, up_world = torch.broadcast_tensors(
        camera_position, look_at, up_world
    )

    z_axis = F.normalize(camera_position - look_at, dim=-1).float()
    x_axis = F.normalize(torch.linalg.cross(up_world, z_axis, dim=-1), dim=-1).float()
    y_axis = F.normalize(torch.linalg.cross(z_axis, x_axis, dim=-1), dim=-1).float()

    # Stacking on the last dim makes the axes (and position) matrix columns.
    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
    return pad_camera_extrinsics_4x4(extrinsics)
def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
    """Return poses for cameras placed on a sphere and aimed at the origin.

    Args:
        azimuths: Azimuth angles in degrees.
        elevations: Elevation angles in degrees.
        radius: Distance of every camera from the origin.

    Returns:
        The result of ``center_looking_at_camera_pose`` for the computed
        camera positions (default target and up vector).
    """
    azimuth_rad = np.deg2rad(azimuths)
    elevation_rad = np.deg2rad(elevations)

    # Spherical -> Cartesian with z as the vertical axis.
    horizontal = radius * np.cos(elevation_rad)
    positions = np.stack(
        [
            horizontal * np.cos(azimuth_rad),
            horizontal * np.sin(azimuth_rad),
            radius * np.sin(elevation_rad),
        ],
        axis=-1,
    )
    return center_looking_at_camera_pose(torch.from_numpy(positions).float())
def FOV_to_intrinsics(fov, device="cpu"):
    """Build a 3x3 intrinsics matrix from a field of view given in degrees.

    The principal point is fixed at ``(0.5, 0.5)`` and the focal length is
    ``0.5 / tan(fov / 2)``, which suggests image coordinates normalised to
    ``[0, 1]`` (NOTE(review): confirm against the consumer of this matrix).

    Args:
        fov: Field of view in degrees.
        device: Device on which the tensor is created.

    Returns:
        A ``(3, 3)`` tensor ``[[f, 0, 0.5], [0, f, 0.5], [0, 0, 1]]``.
    """
    half_angle = np.deg2rad(fov) * 0.5
    focal = 0.5 / np.tan(half_angle)
    rows = [
        [focal, 0, 0.5],
        [0, focal, 0.5],
        [0, 0, 1],
    ]
    return torch.tensor(rows, device=device)
def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
    """Return the packed camera parameters for the six fixed input views.

    The views sit at azimuths 30/90/150/210/270/330 degrees with elevations
    alternating between 20 and -10 degrees. Each camera is encoded as 16
    numbers: the first 12 entries of the flattened 4x4 pose (its top three
    rows) followed by ``fx, fy, cx, cy``.

    Args:
        batch_size: Number of times the camera set is repeated along dim 0.
        radius: Camera distance from the origin.
        fov: Field of view in degrees used for the intrinsics.

    Returns:
        Float tensor of shape ``(batch_size, 6, 16)``.
    """
    azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
    elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)
    num_views = len(azimuths)

    poses = spherical_camera_pose(azimuths, elevations, radius).float()
    # Drop the constant homogeneous row: keep 12 of the 16 flattened entries.
    extrinsics = poses.flatten(-2)[:, :12]

    K = FOV_to_intrinsics(fov).unsqueeze(0).repeat(num_views, 1, 1).float()
    K = K.flatten(-2)
    # Flattened 3x3 indices: 0 -> fx, 4 -> fy, 2 -> cx, 5 -> cy.
    fx, fy, cx, cy = K[:, 0], K[:, 4], K[:, 2], K[:, 5]
    intrinsics = torch.stack([fx, fy, cx, cy], dim=-1)

    cameras = torch.cat([extrinsics, intrinsics], dim=-1)
    return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
@dataclass
class InstantMeshPipelineOutput(BaseOutput):
    """Mesh produced by ``InstantMeshPipeline``.

    ``uvs`` and ``texture`` are ``None`` when the pipeline is run without
    texture generation, so both are optional and default to ``None``.
    """

    # Vertex array converted from the model output with ``.cpu().numpy()``.
    vertices: np.ndarray
    # Face index array; the textured path reshapes it to ``(-1, 3)``.
    faces: np.ndarray
    # UV coordinates aligned with ``vertices``; only set with a texture.
    uvs: Optional[np.ndarray] = None
    # uint8 texture image, flipped vertically; only set with a texture.
    texture: Optional[np.ndarray] = None
class InstantMeshPipeline(DiffusionPipeline):
    """Diffusers pipeline that reconstructs a mesh from multi-view images.

    The wrapped ``lrm`` model must provide ``init_flexicubes_geometry``,
    ``forward_planes`` and ``extract_mesh``; this class only drives those
    calls and converts the results to NumPy.
    """

    def __init__(self, lrm):
        super().__init__()
        self.lrm = lrm
        self.register_modules(lrm=self.lrm)

    @staticmethod
    def _reindex_by_vertex_uv_pair(vertices, vertex_indices, uvs, uv_indices):
        """Merge separate position/UV index buffers into one shared buffer.

        Every distinct (vertex index, uv index) pair becomes one output
        vertex, so positions and UVs can be addressed by the same faces.
        """
        corner_pairs = np.stack(
            [vertex_indices.reshape(-1), uv_indices.reshape(-1)], axis=1
        )
        unique_pairs, inverse = np.unique(
            corner_pairs, axis=0, return_inverse=True
        )
        merged_vertices = vertices[unique_pairs[:, 0]]
        merged_uvs = uvs[unique_pairs[:, 1]]
        # `inverse` maps each face corner to its unique pair; the reshape also
        # absorbs NumPy-version differences in the shape of `inverse`.
        faces = inverse.reshape(-1, 3)
        return merged_vertices, merged_uvs, faces

    @staticmethod
    def _texture_to_uint8_image(texture):
        """Convert a channel-last float texture to a flipped uint8 image.

        Values are scaled from ``[0, 1]`` to ``[0, 255]``. Texels whose channel
        sum is at most 3 are treated as empty and replaced by a 3x3 dilation
        of the image.
        """
        lo, hi = 0, 1
        img = np.asarray(texture, dtype=np.float32)
        img = (img - lo) * (255 / (hi - lo))
        img = img.clip(0, 255)

        empty = np.sum(img.astype(np.float32), axis=-1, keepdims=True)
        empty = (empty <= 3.0).astype(np.float32)
        kernel = np.ones((3, 3), "uint8")
        dilated = cv2.dilate(img, kernel, iterations=1)
        img = img * (1 - empty) + dilated * empty

        img = img.clip(0, 255).astype(np.uint8)
        # Flip vertically and return a contiguous copy rather than a
        # negative-stride view.
        return np.ascontiguousarray(img[::-1, :, :])

    @torch.no_grad()
    def __call__(
        self,
        images: torch.Tensor,
        generate_texture: bool = False,
        progress_callback: Optional[Callable[[float], None]] = None,
    ):
        """Reconstruct a mesh from ``images``.

        Args:
            images: Input views, forwarded unchanged to ``lrm.forward_planes``
                together with the six fixed cameras (camera batch size 1).
                NOTE(review): the expected tensor layout is defined by the
                LRM model and is not visible here.
            generate_texture: When true, also extract UVs and a 1024-px
                texture map; otherwise ``uvs`` and ``texture`` are ``None``.
            progress_callback: Optional callable forwarded to
                ``lrm.extract_mesh``.

        Returns:
            ``InstantMeshPipelineOutput`` holding NumPy arrays.
        """
        device = self._execution_device
        self.lrm.init_flexicubes_geometry(device, fovy=30.0)
        cameras = get_zero123plus_input_cameras().to(device)
        planes = self.lrm.forward_planes(images, cameras)

        if not generate_texture:
            vertices, faces, _ = self.lrm.extract_mesh(
                planes,
                use_texture_map=False,
                progress_callback=progress_callback,
            )
            return InstantMeshPipelineOutput(
                vertices=vertices.cpu().numpy(),
                faces=faces.cpu().numpy(),
                uvs=None,
                texture=None,
            )

        vertices, vertex_indices, uvs, uv_indices, texture = self.lrm.extract_mesh(
            planes,
            use_texture_map=True,
            texture_resolution=1024,
            progress_callback=progress_callback,
        )
        vertices, uvs, faces = self._reindex_by_vertex_uv_pair(
            vertices.cpu().numpy(),
            vertex_indices.cpu().numpy(),
            uvs.cpu().numpy(),
            uv_indices.cpu().numpy(),
        )
        # permute(1, 2, 0) moves the first axis last, so the helper sees the
        # channels on the final axis (presumably (C, H, W) -> (H, W, C)).
        texture = self._texture_to_uint8_image(
            texture.permute(1, 2, 0).cpu().numpy()
        )
        return InstantMeshPipelineOutput(
            vertices=vertices,
            faces=faces,
            uvs=uvs,
            texture=texture,
        )