# File: weather/graphcast/grid_mesh_connectivity.py
# Uploaded by Gary0205 in commit 6d70ed4 ("Upload 25 files", verified).
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for converting from regular grids on a sphere, to triangular meshes."""
from graphcast import icosahedral_mesh
import numpy as np
import scipy
import trimesh
def _grid_lat_lon_to_coordinates(
    grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray:
  """Maps a lat/lon grid onto Cartesian points of the unit sphere.

  Args:
    grid_latitude: Latitudes in degrees, [num_lat].
    grid_longitude: Longitudes in degrees, [num_lon].

  Returns:
    Array of shape [num_lat, num_lon, 3] holding (x, y, z) for every grid
    point. The sphere has unit radius: the earth is modelled as a unit sphere
    and any vertical dimension is kept as a separate regular grid.
  """
  # Spherical angles: phi is the azimuth (longitude), theta the polar angle
  # measured from the north pole (colatitude = 90 - latitude).
  longitude_rad = np.deg2rad(grid_longitude)
  colatitude_rad = np.deg2rad(90 - grid_latitude)

  # Both are [num_lat, num_lon]: meshgrid's default "xy" indexing places the
  # first argument (longitude) along the second axis.
  phi, theta = np.meshgrid(longitude_rad, colatitude_rad)

  sin_theta = np.sin(theta)
  x = np.cos(phi) * sin_theta
  y = np.sin(phi) * sin_theta
  z = np.cos(theta)
  return np.stack([x, y, z], axis=-1)
def radius_query_indices(
    *,
    grid_latitude: np.ndarray,
    grid_longitude: np.ndarray,
    mesh: icosahedral_mesh.TriangularMesh,
    radius: float) -> tuple[np.ndarray, np.ndarray]:
  """Returns mesh-grid edge indices for radius query.

  Args:
    grid_latitude: Latitude values for the grid [num_lat_points]
    grid_longitude: Longitude values for the grid [num_lon_points]
    mesh: Mesh object.
    radius: Radius of connectivity in R^3, for a sphere of unit radius.

  Returns:
    tuple with `grid_indices` and `mesh_indices` indicating edges between the
    grid and the mesh such that the distances in a straight line (not geodesic)
    are smaller than or equal to `radius`. Edges are grouped by grid point in
    increasing grid-index order. Both arrays are empty (rather than raising)
    when there are no grid points or no grid point has a mesh neighbour.
    * grid_indices: Indices of shape [num_edges], that index into a
      [num_lat_points, num_lon_points] grid, after flattening the leading axes.
    * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
  """
  # `import scipy` alone does not load `scipy.spatial` on older SciPy releases
  # (lazy submodule loading was added later), so import it explicitly.
  import itertools  # pylint: disable=g-import-not-at-top
  import scipy.spatial  # pylint: disable=g-import-not-at-top

  # [num_grid_points=num_lat_points * num_lon_points, 3]
  grid_positions = _grid_lat_lon_to_coordinates(
      grid_latitude, grid_longitude).reshape([-1, 3])
  # [num_mesh_points, 3]
  mesh_positions = mesh.vertices
  kd_tree = scipy.spatial.cKDTree(mesh_positions)

  # [num_grid_points] sequence of neighbour lists. The number of mesh points
  # per grid point varies, so this is a ragged sequence, not a 2d array.
  query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius)
  num_grid_points = len(query_indices)

  # [num_grid_points] number of mesh neighbours of each grid point.
  num_neighbors = np.fromiter(
      (len(neighbors) for neighbors in query_indices),
      dtype=int, count=num_grid_points)
  num_edges = int(num_neighbors.sum())

  # [num_edges] grid index i repeated once per neighbour of grid point i.
  # Unlike concatenating per-point arrays, this also works with zero grid points.
  grid_edge_indices = np.repeat(
      np.arange(num_grid_points, dtype=int), num_neighbors)
  # [num_edges] neighbour lists flattened in the same grid-point order.
  mesh_edge_indices = np.fromiter(
      itertools.chain.from_iterable(query_indices),
      dtype=int, count=num_edges)
  return grid_edge_indices, mesh_edge_indices
def in_mesh_triangle_indices(
    *,
    grid_latitude: np.ndarray,
    grid_longitude: np.ndarray,
    mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]:
  """Returns mesh-grid edge indices for grid points contained in mesh triangles.

  Each grid point is connected to the three vertices of the mesh face that
  `trimesh.proximity.closest_point` reports for it.

  Args:
    grid_latitude: Latitude values for the grid [num_lat_points]
    grid_longitude: Longitude values for the grid [num_lon_points]
    mesh: Mesh object.

  Returns:
    tuple with `grid_indices` and `mesh_indices` indicating edges between the
    grid and the mesh vertices of the triangle that contain each grid point.
    The number of edges is always num_lat_points * num_lon_points * 3
    * grid_indices: Indices of shape [num_edges], that index into a
      [num_lat_points, num_lon_points] grid, after flattening the leading axes.
    * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
  """
  # Flattened Cartesian grid positions,
  # [num_grid_points=num_lat_points * num_lon_points, 3].
  points = _grid_lat_lon_to_coordinates(
      grid_latitude, grid_longitude).reshape([-1, 3])
  num_points = points.shape[0]

  surface = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
  # closest_point returns a 3-tuple; only its third element is needed: one
  # mesh face index per grid point, [num_grid_points].
  face_per_point = trimesh.proximity.closest_point(surface, points)[2]

  # Each grid point contributes three consecutive edges (one per face vertex),
  # so its index appears three times in a row: [0, 0, 0, 1, 1, 1, ...].
  grid_edge_indices = np.repeat(np.arange(num_points), 3)
  # [num_grid_points, 3] vertex ids flattened row-major to [num_edges], which
  # matches the ordering of grid_edge_indices above.
  mesh_edge_indices = mesh.faces[face_per_point].reshape([-1])
  return grid_edge_indices, mesh_edge_indices