|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tools for converting from regular grids on a sphere, to triangular meshes.""" |
|
|
|
from graphcast import icosahedral_mesh |
|
import numpy as np |
|
import scipy |
|
import trimesh |
|
|
|
|
|
def _grid_lat_lon_to_coordinates( |
|
grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray: |
|
"""Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3].""" |
|
|
|
|
|
phi_grid, theta_grid = np.meshgrid( |
|
np.deg2rad(grid_longitude), |
|
np.deg2rad(90 - grid_latitude)) |
|
|
|
|
|
|
|
|
|
return np.stack( |
|
[np.cos(phi_grid)*np.sin(theta_grid), |
|
np.sin(phi_grid)*np.sin(theta_grid), |
|
np.cos(theta_grid)], axis=-1) |
|
|
|
|
|
def radius_query_indices(
    *,
    grid_latitude: np.ndarray,
    grid_longitude: np.ndarray,
    mesh: icosahedral_mesh.TriangularMesh,
    radius: float) -> tuple[np.ndarray, np.ndarray]:
    """Returns mesh-grid edge indices for radius query.

    Args:
      grid_latitude: Latitude values for the grid [num_lat_points], in degrees.
      grid_longitude: Longitude values for the grid [num_lon_points], in degrees.
      mesh: Mesh object.
      radius: Radius of connectivity, measured as a straight-line distance in
        R^3 for a sphere of unit radius.

    Returns:
      tuple with `grid_indices` and `mesh_indices` indicating edges between the
      grid and the mesh such that the distances in a straight line (not geodesic)
      are smaller than or equal to `radius`.
      * grid_indices: Indices of shape [num_edges], that index into a
        [num_lat_points, num_lon_points] grid, after flattening the leading axes.
      * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
      Both arrays are empty (shape [0]) when the grid has no points or when no
      grid point has a mesh vertex within `radius`.
    """
    grid_positions = _grid_lat_lon_to_coordinates(
        grid_latitude, grid_longitude).reshape([-1, 3])

    if grid_positions.shape[0] == 0:
        # Empty grid: there are no edges. Return early because np.concatenate
        # below would raise on an empty sequence.
        return np.zeros([0], dtype=int), np.zeros([0], dtype=int)

    kd_tree = scipy.spatial.cKDTree(mesh.vertices)

    # query_indices[i] is the list of mesh-vertex indices whose Euclidean
    # distance to grid point i is <= radius.
    query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius)

    # Repeat each grid index once per neighbouring mesh vertex. This builds the
    # whole array with one vectorized np.repeat instead of one call per grid
    # point, and keeps the two edge arrays aligned element-by-element.
    num_neighbors = [len(neighbors) for neighbors in query_indices]
    grid_edge_indices = np.repeat(
        np.arange(len(query_indices)), num_neighbors).astype(int)

    # Convert each neighbour list with an explicit int dtype so that empty
    # lists do not promote the concatenated result to float.
    mesh_edge_indices = np.concatenate(
        [np.asarray(neighbors, dtype=int) for neighbors in query_indices])

    return grid_edge_indices, mesh_edge_indices
|
|
|
|
|
def in_mesh_triangle_indices(
    *,
    grid_latitude: np.ndarray,
    grid_longitude: np.ndarray,
    mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]:
    """Returns mesh-grid edge indices for grid points contained in mesh triangles.

    Args:
      grid_latitude: Latitude values for the grid [num_lat_points], in degrees.
      grid_longitude: Longitude values for the grid [num_lon_points], in degrees.
      mesh: Mesh object.

    Returns:
      tuple with `grid_indices` and `mesh_indices` indicating edges between the
      grid and the mesh vertices of the triangle that contain each grid point.
      The number of edges is always num_lat_points * num_lon_points * 3
      * grid_indices: Indices of shape [num_edges], that index into a
        [num_lat_points, num_lon_points] grid, after flattening the leading axes.
      * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
    """
    flat_positions = _grid_lat_lon_to_coordinates(
        grid_latitude, grid_longitude).reshape([-1, 3])
    num_grid_points = flat_positions.shape[0]

    surface = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)

    # closest_point returns (closest points, distances, triangle ids). Only the
    # id of the triangle each grid point projects onto is needed here.
    face_ids = trimesh.proximity.closest_point(surface, flat_positions)[2]

    # [num_grid_points, 3]: the three vertex indices of each grid point's
    # triangle.
    triangle_vertices = mesh.faces[face_ids]

    # Pair each grid point with each of its 3 triangle vertices. Repeating every
    # grid index 3 times matches the row-major flattening of triangle_vertices,
    # so the two edge arrays line up element-by-element.
    grid_edge_indices = np.repeat(np.arange(num_grid_points), 3)
    mesh_edge_indices = triangle_vertices.reshape([-1])

    return grid_edge_indices, mesh_edge_indices
|
|