Spaces:
Running
on
Zero
Running
on
Zero
# | |
# Copyright (C) 2023, Inria | |
# GRAPHDECO research group, https://team.inria.fr/graphdeco | |
# All rights reserved. | |
# | |
# This software is free for non-commercial, research and evaluation use | |
# under the terms of the LICENSE.md file. | |
# | |
# For inquiries contact [email protected] | |
# | |
from typing import NamedTuple | |
import torch.nn as nn | |
import torch | |
from . import _C | |
def cpu_deep_copy_tuple(input_tuple): | |
copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple] | |
return tuple(copied_tensors) | |
def rasterize_gaussians( | |
means3D, | |
means2D, | |
sh, | |
colors_precomp, | |
opacities, | |
scales, | |
rotations, | |
cov3Ds_precomp, | |
raster_settings, | |
): | |
return _RasterizeGaussians.apply( | |
means3D, | |
means2D, | |
sh, | |
colors_precomp, | |
opacities, | |
scales, | |
rotations, | |
cov3Ds_precomp, | |
raster_settings, | |
) | |
class _RasterizeGaussians(torch.autograd.Function): | |
def forward( | |
ctx, | |
means3D, | |
means2D, | |
sh, | |
colors_precomp, | |
opacities, | |
scales, | |
rotations, | |
cov3Ds_precomp, | |
raster_settings, | |
): | |
# Restructure arguments the way that the C++ lib expects them | |
args = ( | |
raster_settings.bg, | |
means3D, | |
colors_precomp, | |
opacities, | |
scales, | |
rotations, | |
raster_settings.scale_modifier, | |
cov3Ds_precomp, | |
raster_settings.viewmatrix, | |
raster_settings.projmatrix, | |
raster_settings.tanfovx, | |
raster_settings.tanfovy, | |
raster_settings.image_height, | |
raster_settings.image_width, | |
sh, | |
raster_settings.sh_degree, | |
raster_settings.campos, | |
raster_settings.prefiltered, | |
raster_settings.debug | |
) | |
# Invoke C++/CUDA rasterizer | |
if raster_settings.debug: | |
cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted | |
try: | |
num_rendered, color, depth, alpha, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) | |
except Exception as ex: | |
torch.save(cpu_args, "snapshot_fw.dump") | |
print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.") | |
raise ex | |
else: | |
num_rendered, color, depth, alpha, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) | |
# Keep relevant tensors for backward | |
ctx.raster_settings = raster_settings | |
ctx.num_rendered = num_rendered | |
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer, alpha) | |
return color, radii, depth, alpha | |
def backward(ctx, grad_color, grad_radii, grad_depth, grad_alpha): | |
# Restore necessary values from context | |
num_rendered = ctx.num_rendered | |
raster_settings = ctx.raster_settings | |
colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer, alpha = ctx.saved_tensors | |
# Restructure args as C++ method expects them | |
args = (raster_settings.bg, | |
means3D, | |
radii, | |
colors_precomp, | |
scales, | |
rotations, | |
raster_settings.scale_modifier, | |
cov3Ds_precomp, | |
raster_settings.viewmatrix, | |
raster_settings.projmatrix, | |
raster_settings.tanfovx, | |
raster_settings.tanfovy, | |
grad_color, | |
grad_depth, | |
grad_alpha, | |
sh, | |
raster_settings.sh_degree, | |
raster_settings.campos, | |
geomBuffer, | |
num_rendered, | |
binningBuffer, | |
imgBuffer, | |
alpha, | |
raster_settings.debug) | |
# Compute gradients for relevant tensors by invoking backward method | |
if raster_settings.debug: | |
cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted | |
try: | |
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) | |
except Exception as ex: | |
torch.save(cpu_args, "snapshot_bw.dump") | |
print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n") | |
raise ex | |
else: | |
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) | |
grads = ( | |
grad_means3D, | |
grad_means2D, | |
grad_sh, | |
grad_colors_precomp, | |
grad_opacities, | |
grad_scales, | |
grad_rotations, | |
grad_cov3Ds_precomp, | |
None, | |
) | |
return grads | |
class GaussianRasterizationSettings(NamedTuple): | |
image_height: int | |
image_width: int | |
tanfovx : float | |
tanfovy : float | |
bg : torch.Tensor | |
scale_modifier : float | |
viewmatrix : torch.Tensor | |
projmatrix : torch.Tensor | |
sh_degree : int | |
campos : torch.Tensor | |
prefiltered : bool | |
debug : bool | |
class GaussianRasterizer(nn.Module): | |
def __init__(self, raster_settings): | |
super().__init__() | |
self.raster_settings = raster_settings | |
def markVisible(self, positions): | |
# Mark visible points (based on frustum culling for camera) with a boolean | |
with torch.no_grad(): | |
raster_settings = self.raster_settings | |
visible = _C.mark_visible( | |
positions, | |
raster_settings.viewmatrix, | |
raster_settings.projmatrix) | |
return visible | |
def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None): | |
raster_settings = self.raster_settings | |
if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None): | |
raise Exception('Please provide excatly one of either SHs or precomputed colors!') | |
if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None): | |
raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!') | |
if shs is None: | |
shs = torch.Tensor([]) | |
if colors_precomp is None: | |
colors_precomp = torch.Tensor([]) | |
if scales is None: | |
scales = torch.Tensor([]) | |
if rotations is None: | |
rotations = torch.Tensor([]) | |
if cov3D_precomp is None: | |
cov3D_precomp = torch.Tensor([]) | |
# Invoke C++/CUDA rasterization routine | |
return rasterize_gaussians( | |
means3D, | |
means2D, | |
shs, | |
colors_precomp, | |
opacities, | |
scales, | |
rotations, | |
cov3D_precomp, | |
raster_settings, | |
) | |