Spaces:
Running
on
Zero
Running
on
Zero
# -------------------------------------------------------- | |
# Octree-based Sparse Convolutional Neural Networks | |
# Copyright (c) 2022 Peng-Shuai Wang <[email protected]> | |
# Licensed under The MIT License [see LICENSE for details] | |
# Written by Peng-Shuai Wang | |
# -------------------------------------------------------- | |
import torch | |
from typing import Optional, Union | |
class KeyLUT:
    r"""Lookup tables for bit-interleaved (Morton-style) 3-D keys.

    The encode tables map every 8-bit coordinate value to its bits spread out
    for interleaving: bit ``i`` of x lands on bit ``3i+2``, of y on ``3i+1``
    and of z on ``3i``. The decode tables invert this for every 9-bit key
    chunk, yielding three 3-bit coordinate values. All tables are built once
    on the CPU; copies for other devices are made on first request and then
    cached per device.
    """

    def __init__(self):
        cpu = torch.device("cpu")
        byte_values = torch.arange(256, dtype=torch.int64)
        no_bits = torch.zeros(256, dtype=torch.int64)
        chunk_keys = torch.arange(512, dtype=torch.int64)

        # One 256-entry table per axis; the other two axes are held at zero,
        # so each table contributes only its own axis' bits to a key.
        per_axis = (
            self.xyz2key(byte_values, no_bits, no_bits, 8),
            self.xyz2key(no_bits, byte_values, no_bits, 8),
            self.xyz2key(no_bits, no_bits, byte_values, 8),
        )
        self._encode = {cpu: per_axis}
        # 512 = 2**9 covers every 9-bit chunk (3 levels x 3 axes).
        self._decode = {cpu: self.key2xyz(chunk_keys, 9)}

    def encode_lut(self, device=torch.device("cpu")):
        """Return the ``(EX, EY, EZ)`` encode tables on ``device``.

        The CPU tables are copied to ``device`` on first use and cached.
        """
        tables = self._encode.get(device)
        if tables is None:
            home = self._encode[torch.device("cpu")]
            tables = tuple(table.to(device) for table in home)
            self._encode[device] = tables
        return tables

    def decode_lut(self, device=torch.device("cpu")):
        """Return the ``(DX, DY, DZ)`` decode tables on ``device``.

        The CPU tables are copied to ``device`` on first use and cached.
        """
        tables = self._decode.get(device)
        if tables is None:
            home = self._decode[torch.device("cpu")]
            tables = tuple(table.to(device) for table in home)
            self._decode[device] = tables
        return tables

    def xyz2key(self, x, y, z, depth):
        """Interleave the low ``depth`` bits of x, y, z into one key (loop-based)."""
        key = torch.zeros_like(x)
        for level in range(depth):
            bit = 1 << level
            for offset, coord in ((2, x), (1, y), (0, z)):
                # Bit ``level`` of this axis moves to position 3*level + offset.
                key = key | ((coord & bit) << (2 * level + offset))
        return key

    def key2xyz(self, key, depth):
        """Split the low ``3 * depth`` bits of ``key`` back into (x, y, z) (loop-based)."""
        coords = [torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)]
        for level in range(depth):
            for axis, offset in enumerate((2, 1, 0)):
                picked = key & (1 << (3 * level + offset))
                coords[axis] = coords[axis] | (picked >> (2 * level + offset))
        return coords[0], coords[1], coords[2]


# Module-level table instance used by the xyz2key/key2xyz functions in this
# file; it is constructed when the module is imported.
_key_lut = KeyLUT()
def xyz2key(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    b: Optional[Union[torch.Tensor, int]] = None,
    depth: int = 16,
):
    r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
    based on pre-computed look up tables. The speed of this function is much
    faster than the method based on for-loop.

    Args:
      x (torch.Tensor): The x coordinate.
      y (torch.Tensor): The y coordinate.
      z (torch.Tensor): The z coordinate.
      b (torch.Tensor or int): The batch index of the coordinates, and should be
          smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
          :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
          A plain Python ``int`` applies the same batch index to every key.
      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).

    Returns:
      torch.Tensor: int64 keys. Coordinate bit ``i`` is placed at bit ``3i+2``
      (x), ``3i+1`` (y) and ``3i`` (z), occupying at most bits 0-47; the batch
      index, when given, is stored from bit 48 upward.
    """
    EX, EY, EZ = _key_lut.encode_lut(x.device)
    x, y, z = x.long(), y.long(), z.long()

    # The tables cover 8 bits per axis, so the low byte is encoded first.
    mask = 255 if depth > 8 else (1 << depth) - 1
    key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
    if depth > 8:
        # Reuse the same tables for bits 8..depth-1; 8 source bits per axis
        # expand to 24 interleaved bits, hence the shift by 24.
        mask = (1 << (depth - 8)) - 1
        key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
        key = (key16 << 24) | key

    if b is not None:
        if isinstance(b, torch.Tensor):
            key = (b.long() << 48) | key
        else:
            # A Python int has no ``.long()``; shift it as a scalar instead.
            key = key | (int(b) << 48)
    return key
def key2xyz(key: torch.Tensor, depth: int = 16):
    r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
    and the batch index based on pre-computed look up tables.

    Args:
      key (torch.Tensor): The shuffled key.
      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).

    Returns:
      tuple: ``(x, y, z, b)`` -- the decoded coordinates and the batch index
      read from bits 48 and above of :attr:`key`. Decoding proceeds in 9-bit
      chunks (3 levels per chunk), so ``(depth + 2) // 3`` chunks are read.
    """
    lut_x, lut_y, lut_z = _key_lut.decode_lut(key.device)
    batch = key >> 48
    payload = key & ((1 << 48) - 1)

    x = torch.zeros_like(key)
    y = torch.zeros_like(key)
    z = torch.zeros_like(key)
    for chunk in range((depth + 2) // 3):
        # Each 9-bit chunk decodes to 3 bits per axis.
        nine_bits = (payload >> (9 * chunk)) & 511
        offset = 3 * chunk
        x = x | (lut_x[nine_bits] << offset)
        y = y | (lut_y[nine_bits] << offset)
        z = z | (lut_z[nine_bits] << offset)
    return x, y, z, batch