jadechoghari
commited on
Create math_utils.py
Browse fileswe are flattening the directory, since HF only supports flat imports
- math_utils.py +123 -0
math_utils.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# MIT License
|
7 |
+
|
8 |
+
# Copyright (c) 2022 Petr Kellnhofer
|
9 |
+
|
10 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
11 |
+
# of this software and associated documentation files (the "Software"), to deal
|
12 |
+
# in the Software without restriction, including without limitation the rights
|
13 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
14 |
+
# copies of the Software, and to permit persons to whom the Software is
|
15 |
+
# furnished to do so, subject to the following conditions:
|
16 |
+
|
17 |
+
# The above copyright notice and this permission notice shall be included in all
|
18 |
+
# copies or substantial portions of the Software.
|
19 |
+
|
20 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
21 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
22 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
23 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
24 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
25 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
26 |
+
# SOFTWARE.
|
27 |
+
|
28 |
+
import torch
|
29 |
+
|
30 |
+
def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
|
31 |
+
"""
|
32 |
+
Left-multiplies MxM @ NxM. Returns NxM.
|
33 |
+
"""
|
34 |
+
res = torch.matmul(vectors4, matrix.T)
|
35 |
+
return res
|
36 |
+
|
37 |
+
|
38 |
+
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
|
39 |
+
"""
|
40 |
+
Normalize vector lengths.
|
41 |
+
"""
|
42 |
+
return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
|
43 |
+
|
44 |
+
def torch_dot(x: torch.Tensor, y: torch.Tensor):
|
45 |
+
"""
|
46 |
+
Dot product of two tensors.
|
47 |
+
"""
|
48 |
+
return (x * y).sum(-1)
|
49 |
+
|
50 |
+
|
51 |
+
def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
|
52 |
+
"""
|
53 |
+
Author: Petr Kellnhofer
|
54 |
+
Intersects rays with the [-1, 1] NDC volume.
|
55 |
+
Returns min and max distance of entry.
|
56 |
+
Returns -1 for no intersection.
|
57 |
+
https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
|
58 |
+
"""
|
59 |
+
o_shape = rays_o.shape
|
60 |
+
rays_o = rays_o.detach().reshape(-1, 3)
|
61 |
+
rays_d = rays_d.detach().reshape(-1, 3)
|
62 |
+
|
63 |
+
|
64 |
+
bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
|
65 |
+
bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
|
66 |
+
bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
|
67 |
+
is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
|
68 |
+
|
69 |
+
# Precompute inverse for stability.
|
70 |
+
invdir = 1 / rays_d
|
71 |
+
sign = (invdir < 0).long()
|
72 |
+
|
73 |
+
# Intersect with YZ plane.
|
74 |
+
tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
|
75 |
+
tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
|
76 |
+
|
77 |
+
# Intersect with XZ plane.
|
78 |
+
tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
|
79 |
+
tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
|
80 |
+
|
81 |
+
# Resolve parallel rays.
|
82 |
+
is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
|
83 |
+
|
84 |
+
# Use the shortest intersection.
|
85 |
+
tmin = torch.max(tmin, tymin)
|
86 |
+
tmax = torch.min(tmax, tymax)
|
87 |
+
|
88 |
+
# Intersect with XY plane.
|
89 |
+
tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
|
90 |
+
tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
|
91 |
+
|
92 |
+
# Resolve parallel rays.
|
93 |
+
is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
|
94 |
+
|
95 |
+
# Use the shortest intersection.
|
96 |
+
tmin = torch.max(tmin, tzmin)
|
97 |
+
tmax = torch.min(tmax, tzmax)
|
98 |
+
|
99 |
+
# Mark invalid.
|
100 |
+
tmin[torch.logical_not(is_valid)] = -1
|
101 |
+
tmax[torch.logical_not(is_valid)] = -2
|
102 |
+
|
103 |
+
return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
|
104 |
+
|
105 |
+
|
106 |
+
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
|
107 |
+
"""
|
108 |
+
Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
|
109 |
+
Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
|
110 |
+
"""
|
111 |
+
# create a tensor of 'num' steps from 0 to 1
|
112 |
+
steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
|
113 |
+
|
114 |
+
# reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
|
115 |
+
# - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
|
116 |
+
# "cannot statically infer the expected size of a list in this contex", hence the code below
|
117 |
+
for i in range(start.ndim):
|
118 |
+
steps = steps.unsqueeze(-1)
|
119 |
+
|
120 |
+
# the output starts at 'start' and increments until 'stop' in each dimension
|
121 |
+
out = start[None] + steps * (stop - start)[None]
|
122 |
+
|
123 |
+
return out
|