ECON / lib /common /seg3d_lossless.py
Yuliang's picture
Support TEXTure
487ee6d
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch3d.ops.marching_cubes import marching_cubes
from .seg3d_utils import SmoothConv3D, create_grid3D, plot_mask3D
logging.getLogger("lightning").setLevel(logging.ERROR)
class Seg3dLossless(nn.Module):
def __init__(
self,
query_func,
b_min,
b_max,
resolutions,
channels=1,
balance_value=0.5,
align_corners=False,
visualize=False,
debug=False,
use_cuda_impl=False,
faster=False,
use_shadow=False,
**kwargs,
):
"""
align_corners: same with how you process gt. (grid_sample / interpolate)
"""
super().__init__()
self.query_func = query_func
self.register_buffer("b_min", torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3]
self.register_buffer("b_max", torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3]
# ti.init(arch=ti.cuda)
# self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1)
if type(resolutions[0]) is int:
resolutions = torch.tensor([(res, res, res) for res in resolutions])
else:
resolutions = torch.tensor(resolutions)
self.register_buffer("resolutions", resolutions)
self.batchsize = self.b_min.size(0)
assert self.batchsize == 1
self.balance_value = balance_value
self.channels = channels
assert self.channels == 1
self.align_corners = align_corners
self.visualize = visualize
self.debug = debug
self.use_cuda_impl = use_cuda_impl
self.faster = faster
self.use_shadow = use_shadow
for resolution in resolutions:
assert (
resolution[0] % 2 == 1 and resolution[1] % 2 == 1
), f"resolution {resolution} need to be odd becuase of align_corner."
# init first resolution
init_coords = create_grid3D(0, resolutions[-1] - 1, steps=resolutions[0]) # [N, 3]
init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1, 1) # [bz, N, 3]
self.register_buffer("init_coords", init_coords)
# some useful tensors
calculated = torch.zeros(
(self.resolutions[-1][2], self.resolutions[-1][1], self.resolutions[-1][0]),
dtype=torch.bool,
)
self.register_buffer("calculated", calculated)
gird8_offsets = (
torch.stack(
torch.meshgrid(
[
torch.tensor([-1, 0, 1]),
torch.tensor([-1, 0, 1]),
torch.tensor([-1, 0, 1]),
],
indexing="ij",
)
).int().view(3, -1).t()
) # [27, 3]
self.register_buffer("gird8_offsets", gird8_offsets)
# smooth convs
self.smooth_conv3x3 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=3)
self.smooth_conv5x5 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=5)
self.smooth_conv7x7 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=7)
self.smooth_conv9x9 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=9)
@torch.no_grad()
def batch_eval(self, coords, **kwargs):
"""
coords: in the coordinates of last resolution
**kwargs: for query_func
"""
coords = coords.detach()
# normalize coords to fit in [b_min, b_max]
if self.align_corners:
coords2D = coords.float() / (self.resolutions[-1] - 1)
else:
step = 1.0 / self.resolutions[-1].float()
coords2D = coords.float() / self.resolutions[-1] + step / 2
coords2D = coords2D * (self.b_max - self.b_min) + self.b_min
# query function
occupancys = self.query_func(**kwargs, points=coords2D)
if type(occupancys) is list:
occupancys = torch.stack(occupancys) # [bz, C, N]
assert (
len(occupancys.size()) == 3
), "query_func should return a occupancy with shape of [bz, C, N]"
return occupancys
@torch.no_grad()
def forward(self, **kwargs):
if self.faster:
return self._forward_faster(**kwargs)
else:
return self._forward(**kwargs)
@torch.no_grad()
def _forward_faster(self, **kwargs):
"""
In faster mode, we make following changes to exchange accuracy for speed:
1. no conflict checking: 4.88 fps -> 6.56 fps
2. smooth_conv9x9 ~ smooth_conv3x3 for different resolution
3. last step no examine
"""
final_W = self.resolutions[-1][0]
final_H = self.resolutions[-1][1]
final_D = self.resolutions[-1][2]
for resolution in self.resolutions:
W, H, D = resolution
stride = (self.resolutions[-1] - 1) / (resolution - 1)
# first step
if torch.equal(resolution, self.resolutions[0]):
coords = self.init_coords.clone() # torch.long
occupancys = self.batch_eval(coords, **kwargs)
occupancys = occupancys.view(self.batchsize, self.channels, D, H, W)
if (occupancys > 0.5).sum() == 0:
# return F.interpolate(
# occupancys, size=(final_D, final_H, final_W),
# mode="linear", align_corners=True)
return None
if self.visualize:
self.plot(occupancys, coords, final_D, final_H, final_W)
with torch.no_grad():
coords_accum = coords / stride
# last step
elif torch.equal(resolution, self.resolutions[-1]):
with torch.no_grad():
# here true is correct!
valid = F.interpolate(
(occupancys > self.balance_value).float(),
size=(D, H, W),
mode="trilinear",
align_corners=True,
)
# here true is correct!
occupancys = F.interpolate(
occupancys.float(),
size=(D, H, W),
mode="trilinear",
align_corners=True,
)
# is_boundary = (valid > 0.0) & (valid < 1.0)
is_boundary = valid == 0.5
# next steps
else:
coords_accum *= 2
with torch.no_grad():
# here true is correct!
valid = F.interpolate(
(occupancys > self.balance_value).float(),
size=(D, H, W),
mode="trilinear",
align_corners=True,
)
# here true is correct!
occupancys = F.interpolate(
occupancys.float(),
size=(D, H, W),
mode="trilinear",
align_corners=True,
)
is_boundary = (valid > 0.0) & (valid < 1.0)
with torch.no_grad():
if torch.equal(resolution, self.resolutions[1]):
is_boundary = (self.smooth_conv9x9(is_boundary.float()) > 0)[0, 0]
elif torch.equal(resolution, self.resolutions[2]):
is_boundary = (self.smooth_conv7x7(is_boundary.float()) > 0)[0, 0]
else:
is_boundary = (self.smooth_conv3x3(is_boundary.float()) > 0)[0, 0]
coords_accum = coords_accum.long()
is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
coords_accum[0, :, 0], ] = False
point_coords = (
is_boundary.permute(2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
)
point_indices = (
point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W +
point_coords[:, :, 0]
)
R, C, D, H, W = occupancys.shape
# inferred value
coords = point_coords * stride
if coords.size(1) == 0:
continue
occupancys_topk = self.batch_eval(coords, **kwargs)
# put mask point predictions to the right places on the upsampled grid.
R, C, D, H, W = occupancys.shape
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
occupancys = (
occupancys.reshape(R, C,
D * H * W).scatter_(2, point_indices,
occupancys_topk).view(R, C, D, H, W)
)
with torch.no_grad():
voxels = coords / stride
coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1)
return occupancys[0, 0]
@torch.no_grad()
def _forward(self, **kwargs):
"""
output occupancy field would be:
(bz, C, res, res)
"""
final_W = self.resolutions[-1][0]
final_H = self.resolutions[-1][1]
final_D = self.resolutions[-1][2]
calculated = self.calculated.clone()
for resolution in self.resolutions:
W, H, D = resolution
stride = (self.resolutions[-1] - 1) / (resolution - 1)
if self.visualize:
this_stage_coords = []
# first step
if torch.equal(resolution, self.resolutions[0]):
coords = self.init_coords.clone() # torch.long
occupancys = self.batch_eval(coords, **kwargs)
occupancys = occupancys.view(self.batchsize, self.channels, D, H, W)
if self.visualize:
self.plot(occupancys, coords, final_D, final_H, final_W)
with torch.no_grad():
coords_accum = coords / stride
calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True
# next steps
else:
coords_accum *= 2
with torch.no_grad():
# here true is correct!
valid = F.interpolate(
(occupancys > self.balance_value).float(),
size=(D, H, W),
mode="trilinear",
align_corners=True,
)
# here true is correct!
occupancys = F.interpolate(
occupancys.float(),
size=(D, H, W),
mode="trilinear",
align_corners=True,
)
is_boundary = (valid > 0.0) & (valid < 1.0)
with torch.no_grad():
# TODO
if self.use_shadow and torch.equal(resolution, self.resolutions[-1]):
# larger z means smaller depth here
depth_res = resolution[2].item()
depth_index = torch.linspace(0, depth_res - 1,
steps=depth_res).type_as(occupancys.device)
depth_index_max = (
torch.max(
(occupancys > self.balance_value) * (depth_index + 1),
dim=-1,
keepdim=True,
)[0] - 1
)
shadow = depth_index < depth_index_max
is_boundary[shadow] = False
is_boundary = is_boundary[0, 0]
else:
is_boundary = (self.smooth_conv3x3(is_boundary.float()) > 0)[0, 0]
# is_boundary = is_boundary[0, 0]
is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
coords_accum[0, :, 0], ] = False
point_coords = (
is_boundary.permute(2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
)
point_indices = (
point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W +
point_coords[:, :, 0]
)
R, C, D, H, W = occupancys.shape
# interpolated value
occupancys_interp = torch.gather(
occupancys.reshape(R, C, D * H * W),
2,
point_indices.unsqueeze(1),
)
# inferred value
coords = point_coords * stride
if coords.size(1) == 0:
continue
occupancys_topk = self.batch_eval(coords, **kwargs)
if self.visualize:
this_stage_coords.append(coords)
# put mask point predictions to the right places on the upsampled grid.
R, C, D, H, W = occupancys.shape
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
occupancys = (
occupancys.reshape(R, C,
D * H * W).scatter_(2, point_indices,
occupancys_topk).view(R, C, D, H, W)
)
with torch.no_grad():
# conflicts
conflicts = ((occupancys_interp - self.balance_value) *
(occupancys_topk - self.balance_value) < 0)[0, 0]
if self.visualize:
self.plot(occupancys, coords, final_D, final_H, final_W)
voxels = coords / stride
coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1)
calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True
while conflicts.sum() > 0:
if self.use_shadow and torch.equal(resolution, self.resolutions[-1]):
break
with torch.no_grad():
conflicts_coords = coords[0, conflicts, :]
if self.debug:
self.plot(
occupancys,
conflicts_coords.unsqueeze(0),
final_D,
final_H,
final_W,
title="conflicts",
)
conflicts_boundary = ((
conflicts_coords.int() + self.gird8_offsets.unsqueeze(1) * stride.int()
).reshape(-1, 3).long().unique(dim=0))
conflicts_boundary[:, 0] = conflicts_boundary[:, 0].clamp(
0,
calculated.size(2) - 1
)
conflicts_boundary[:, 1] = conflicts_boundary[:, 1].clamp(
0,
calculated.size(1) - 1
)
conflicts_boundary[:, 2] = conflicts_boundary[:, 2].clamp(
0,
calculated.size(0) - 1
)
coords = conflicts_boundary[calculated[conflicts_boundary[:, 2],
conflicts_boundary[:, 1],
conflicts_boundary[:, 0], ] == False]
if self.debug:
self.plot(
occupancys,
coords.unsqueeze(0),
final_D,
final_H,
final_W,
title="coords",
)
coords = coords.unsqueeze(0)
point_coords = coords / stride
point_indices = (
point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W +
point_coords[:, :, 0]
)
R, C, D, H, W = occupancys.shape
# interpolated value
occupancys_interp = torch.gather(
occupancys.reshape(R, C, D * H * W),
2,
point_indices.unsqueeze(1),
)
# inferred value
coords = point_coords * stride
if coords.size(1) == 0:
break
occupancys_topk = self.batch_eval(coords, **kwargs)
if self.visualize:
this_stage_coords.append(coords)
with torch.no_grad():
# conflicts
conflicts = ((occupancys_interp - self.balance_value) *
(occupancys_topk - self.balance_value) < 0)[0, 0]
# put mask point predictions to the right places on the upsampled grid.
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
occupancys = (
occupancys.reshape(R, C,
D * H * W).scatter_(2, point_indices,
occupancys_topk).view(R, C, D, H, W)
)
with torch.no_grad():
voxels = coords / stride
coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1)
calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True
if self.visualize:
this_stage_coords = torch.cat(this_stage_coords, dim=1)
self.plot(occupancys, this_stage_coords, final_D, final_H, final_W)
return occupancys[0, 0]
def plot(self, occupancys, coords, final_D, final_H, final_W, title="", **kwargs):
final = F.interpolate(
occupancys.float(),
size=(final_D, final_H, final_W),
mode="trilinear",
align_corners=True,
) # here true is correct!
x = coords[0, :, 0].to("cpu")
y = coords[0, :, 1].to("cpu")
z = coords[0, :, 2].to("cpu")
plot_mask3D(final[0, 0].to("cpu"), title, (x, y, z), **kwargs)
@torch.no_grad()
def find_vertices(self, sdf, direction="front"):
"""
- direction: "front" | "back" | "left" | "right"
"""
resolution = sdf.size(2)
if direction == "front":
pass
elif direction == "left":
sdf = sdf.permute(2, 1, 0)
elif direction == "back":
inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
sdf = sdf[inv_idx, :, :]
elif direction == "right":
inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
sdf = sdf[:, :, inv_idx]
sdf = sdf.permute(2, 1, 0)
inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
sdf = sdf[inv_idx, :, :]
sdf_all = sdf.permute(2, 1, 0)
# shadow
grad_v = (sdf_all > 0.5) * torch.linspace(resolution, 1, steps=resolution).to(sdf.device)
grad_c = torch.ones_like(sdf_all) * torch.linspace(0, resolution - 1,
steps=resolution).to(sdf.device)
max_v, max_c = grad_v.max(dim=2)
shadow = grad_c > max_c.view(resolution, resolution, 1)
keep = (sdf_all > 0.5) & (~shadow)
p1 = keep.nonzero(as_tuple=False).t() # [3, N]
p2 = p1.clone() # z
p2[2, :] = (p2[2, :] - 2).clamp(0, resolution)
p3 = p1.clone() # y
p3[1, :] = (p3[1, :] - 2).clamp(0, resolution)
p4 = p1.clone() # x
p4[0, :] = (p4[0, :] - 2).clamp(0, resolution)
v1 = sdf_all[p1[0, :], p1[1, :], p1[2, :]]
v2 = sdf_all[p2[0, :], p2[1, :], p2[2, :]]
v3 = sdf_all[p3[0, :], p3[1, :], p3[2, :]]
v4 = sdf_all[p4[0, :], p4[1, :], p4[2, :]]
X = p1[0, :].long() # [N,]
Y = p1[1, :].long() # [N,]
Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + p1[2, :].float() * (v2 - 0.5
) / (v2 - v1) # [N,]
Z = Z.clamp(0, resolution)
# normal
norm_z = v2 - v1
norm_y = v3 - v1
norm_x = v4 - v1
# print (v2.min(dim=0)[0], v2.max(dim=0)[0], v3.min(dim=0)[0], v3.max(dim=0)[0])
norm = torch.stack([norm_x, norm_y, norm_z], dim=1)
norm = norm / torch.norm(norm, p=2, dim=1, keepdim=True)
return X, Y, Z, norm
@torch.no_grad()
def render_normal(self, resolution, X, Y, Z, norm):
image = torch.ones((1, 3, resolution, resolution), dtype=torch.float32).to(norm.device)
color = (norm + 1) / 2.0
color = color.clamp(0, 1)
image[0, :, Y, X] = color.t()
return image
@torch.no_grad()
def display(self, sdf):
# render
X, Y, Z, norm = self.find_vertices(sdf, direction="front")
image1 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
X, Y, Z, norm = self.find_vertices(sdf, direction="left")
image2 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
X, Y, Z, norm = self.find_vertices(sdf, direction="right")
image3 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
X, Y, Z, norm = self.find_vertices(sdf, direction="back")
image4 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
image = torch.cat([image1, image2, image3, image4], axis=3)
image = image.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255.0
return np.uint8(image)
@torch.no_grad()
def export_mesh(self, occupancys):
final = occupancys[1:, 1:, 1:].contiguous()
verts, faces = marching_cubes(final.unsqueeze(0), isolevel=0.5)
verts = verts[0].cpu().float()
faces = faces[0].cpu().long()[:, [0, 2, 1]]
return verts, faces