File size: 5,303 Bytes
6d70ed4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for converting from regular grids on a sphere, to triangular meshes."""
from graphcast import icosahedral_mesh
import numpy as np
import scipy
import trimesh
def _grid_lat_lon_to_coordinates(
    grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray:
  """Maps a latitude/longitude grid to 3d points on the unit sphere.

  Args:
    grid_latitude: Latitudes in degrees, [num_lat].
    grid_longitude: Longitudes in degrees, [num_lon].

  Returns:
    Array of shape [num_lat, num_lon, 3] holding the (x, y, z) position of
    every grid point. A unit radius is assumed: for now the earth is modelled
    as a sphere of radius 1, and any vertical dimension is kept as a regular
    grid.
  """
  # Polar angle (colatitude) varies along rows and azimuth (longitude) along
  # columns. Broadcasting a column against a row yields the [num_lat, num_lon]
  # grid without materializing a meshgrid.
  polar = np.deg2rad(90 - grid_latitude).reshape(-1, 1)
  azimuth = np.deg2rad(grid_longitude).reshape(1, -1)

  sin_polar = np.sin(polar)
  x = np.cos(azimuth) * sin_polar
  y = np.sin(azimuth) * sin_polar
  # z depends only on the polar angle; expand it to the full grid shape.
  z = np.broadcast_to(np.cos(polar), x.shape)

  # [num_lat, num_lon, 3]
  return np.stack([x, y, z], axis=-1)
def radius_query_indices(
    *,
    grid_latitude: np.ndarray,
    grid_longitude: np.ndarray,
    mesh: icosahedral_mesh.TriangularMesh,
    radius: float) -> tuple[np.ndarray, np.ndarray]:
  """Returns mesh-grid edge indices for radius query.

  Args:
    grid_latitude: Latitude values for the grid [num_lat_points], in degrees.
    grid_longitude: Longitude values for the grid [num_lon_points], in degrees.
    mesh: Mesh object.
    radius: Radius of connectivity in R3, for a sphere of unit radius.

  Returns:
    tuple with `grid_indices` and `mesh_indices` indicating edges between the
    grid and the mesh such that the distances in a straight line (not geodesic)
    are smaller than or equal to `radius`.
    * grid_indices: Indices of shape [num_edges], that index into a
      [num_lat_points, num_lon_points] grid, after flattening the leading axes.
    * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
    Edges are grouped by grid point in increasing grid-index order. Both
    arrays are empty (but still integer-typed) when no grid point has a mesh
    vertex within `radius`, including when the grid itself is empty.
  """
  # `import scipy` alone does not load the `spatial` submodule on older SciPy
  # releases (before lazy submodule loading), so import it explicitly here.
  from scipy import spatial  # pylint: disable=g-import-not-at-top

  # [num_grid_points=num_lat_points * num_lon_points, 3]
  grid_positions = _grid_lat_lon_to_coordinates(
      grid_latitude, grid_longitude).reshape([-1, 3])

  # [num_mesh_points, 3]
  mesh_positions = mesh.vertices
  kd_tree = spatial.cKDTree(mesh_positions)

  # [num_grid_points, num_mesh_points_per_grid_point]
  # Note `num_mesh_points_per_grid_point` is not constant, so this is a
  # sequence of lists, rather than a 2d array.
  query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius)

  # [num_grid_points] number of mesh neighbours found for each grid point.
  num_neighbors = np.array(
      [len(neighbors) for neighbors in query_indices], dtype=int)

  # [num_edges]
  # Build each index array in one shot instead of concatenating one small
  # array per grid point. Unlike `np.concatenate` on an empty list, this also
  # works when there are no grid points at all.
  grid_edge_indices = np.repeat(
      np.arange(len(num_neighbors), dtype=int), num_neighbors)
  mesh_edge_indices = np.array(
      [index for neighbors in query_indices for index in neighbors],
      dtype=int)
  return grid_edge_indices, mesh_edge_indices
def in_mesh_triangle_indices(
    *,
    grid_latitude: np.ndarray,
    grid_longitude: np.ndarray,
    mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]:
  """Returns mesh-grid edge indices for grid points contained in mesh triangles.

  Every grid point is linked to the three vertices of the mesh face that
  `trimesh.proximity.closest_point` reports as closest to it.

  Args:
    grid_latitude: Latitude values for the grid [num_lat_points], in degrees.
    grid_longitude: Longitude values for the grid [num_lon_points], in degrees.
    mesh: Mesh object.

  Returns:
    tuple with `grid_indices` and `mesh_indices` indicating edges between the
    grid and the mesh vertices of the triangle that contain each grid point.
    The number of edges is always num_lat_points * num_lon_points * 3.
    * grid_indices: Indices of shape [num_edges], that index into a
      [num_lat_points, num_lon_points] grid, after flattening the leading axes.
      Each grid index appears three consecutive times.
    * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
  """
  # [num_points=num_lat_points * num_lon_points, 3]
  flat_points = _grid_lat_lon_to_coordinates(
      grid_latitude, grid_longitude).reshape([-1, 3])
  num_points = flat_points.shape[0]

  surface = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
  # Only the third output (the index of the closest face per point) is used.
  _, _, triangle_ids = trimesh.proximity.closest_point(surface, flat_points)

  # [num_edges=num_points*3]: the three vertex ids of each point's face,
  # laid out point by point.
  mesh_edge_indices = mesh.faces[triangle_ids].reshape([-1])
  # [num_edges=num_points*3]: 0, 0, 0, 1, 1, 1, ... matching the layout above.
  grid_edge_indices = np.repeat(np.arange(num_points), 3)
  return grid_edge_indices, mesh_edge_indices
|