diff --git a/uxarray/grid/dual.py b/uxarray/grid/dual.py index faf784853..15eda87e0 100644 --- a/uxarray/grid/dual.py +++ b/uxarray/grid/dual.py @@ -1,5 +1,5 @@ import numpy as np -from numba import njit +from numba import njit, prange from uxarray.constants import INT_DTYPE, INT_FILL_VALUE @@ -26,7 +26,6 @@ def construct_dual(grid): dual_node_z = grid.face_z.values # Get other information from the grid needed - n_node = grid.n_node node_x = grid.node_x.values node_y = grid.node_y.values node_z = grid.node_z.values @@ -35,10 +34,18 @@ def construct_dual(grid): # Get an array with the number of edges for each face n_edges_mask = node_face_connectivity != INT_FILL_VALUE n_edges = np.sum(n_edges_mask, axis=1) + max_edges = node_face_connectivity.shape[1] + + # Only nodes with 3+ edges can form valid dual faces + valid_node_indices = np.where(n_edges >= 3)[0] + + construct_node_face_connectivity = np.full( + (len(valid_node_indices), max_edges), INT_FILL_VALUE, dtype=INT_DTYPE + ) # Construct and return the faces new_node_face_connectivity = construct_faces( - n_node, + valid_node_indices, n_edges, dual_node_x, dual_node_y, @@ -47,14 +54,16 @@ def construct_dual(grid): node_x, node_y, node_z, + construct_node_face_connectivity, + max_edges, ) return new_node_face_connectivity -@njit(cache=True) +@njit(cache=True, parallel=True) def construct_faces( - n_node, + valid_node_indices, n_edges, dual_node_x, dual_node_y, @@ -63,61 +72,67 @@ def construct_faces( node_x, node_y, node_z, + construct_node_face_connectivity, + max_edges, ): """Construct the faces of the dual mesh based on a given node_face_connectivity. Parameters ---------- - n_node: np.ndarray - number of nodes in the primal mesh + valid_node_indices: np.ndarray + Array of node indices with at least 3 connections in the primal mesh n_edges: np.ndarray - array of the number of edges for each dual face + Array of the number of edges for each node in the primal mesh dual_node_x: np.ndarray - x node coordinates for the dual mesh + x coordinates for the dual mesh nodes (face centers of primal mesh) dual_node_y: np.ndarray - y node coordinates for the dual mesh + y coordinates for the dual mesh nodes (face centers of primal mesh) dual_node_z: np.ndarray - z node coordinates for the dual mesh + z coordinates for the dual mesh nodes (face centers of primal mesh) node_face_connectivity: np.ndarray - `node_face_connectivity` of the primal mesh + Node-to-face connectivity of the primal mesh node_x: np.ndarray - x node coordinates from the primal mesh + x coordinates of nodes from the primal mesh node_y: np.ndarray - y node coordinates from the primal mesh + y coordinates of nodes from the primal mesh node_z: np.ndarray - z node coordinates from the primal mesh + z coordinates of nodes from the primal mesh + construct_node_face_connectivity: np.ndarray + Pre-allocated array to store the dual mesh connectivity + max_edges: int + The max number of edges in a face Returns -------- - node_face_connectivity : ndarray + construct_node_face_connectivity : ndarray Constructed node_face_connectivity for the dual mesh + + Notes + ----- + In dual mesh construction, the "valid node indices" are face indices from + the primal mesh's node_face_connectivity that are not fill values. These + represent the actual faces that each primal node connects to, which become + the nodes of the dual mesh faces. """ - correction = 0 - max_edges = len(node_face_connectivity[0]) - construct_node_face_connectivity = np.full( - (np.sum(n_edges > 2), max_edges), INT_FILL_VALUE, dtype=INT_DTYPE - ) - for i in range(n_node): - # If we have less than 3 edges, we can't construct anything but a line - if n_edges[i] < 3: - correction += 1 - continue + n_valid = valid_node_indices.shape[0] + + for out_idx in prange(n_valid): + i = valid_node_indices[out_idx] # Construct temporary face to hold unordered face nodes temp_face = np.array( [INT_FILL_VALUE for _ in range(n_edges[i])], dtype=INT_DTYPE ) - # Get a list of the valid non fill value nodes - valid_node_indices = node_face_connectivity[i][0 : n_edges[i]] - index = 0 + # Get the face indices this node connects to (these become dual face nodes) + connected_faces = node_face_connectivity[i][0 : n_edges[i]] # Connect the face centers around the node to make dual face - for node_idx in valid_node_indices: - temp_face[index] = node_idx - index += 1 + for index, node_idx in enumerate(connected_faces): + if node_idx != INT_FILL_VALUE: + temp_face[index] = node_idx # Order the nodes using the angles so the faces have nodes in counter-clockwise sequence node_central = np.array([node_x[i], node_y[i], node_z[i]]) @@ -130,7 +145,7 @@ def construct_faces( ) # Order the face nodes properly in a counter-clockwise fashion - if temp_face[0] is not INT_FILL_VALUE: + if temp_face[0] != INT_FILL_VALUE: _face = _order_nodes( temp_face, node_0, @@ -141,7 +156,8 @@ def construct_faces( dual_node_z, max_edges, ) - construct_node_face_connectivity[i - correction] = _face + construct_node_face_connectivity[out_idx] = _face + return construct_node_face_connectivity @@ -183,10 +199,18 @@ def _order_nodes( final_face : np.ndarray The face in proper counter-clockwise order """ + # Add numerical stability check for degenerate cases + if n_edges < 3: + return np.full(max_edges, INT_FILL_VALUE, dtype=INT_DTYPE) + node_zero = node_0 - node_central + node_zero_mag = np.linalg.norm(node_zero) + + # Check for numerical stability + if node_zero_mag < 1e-15: + return np.full(max_edges, INT_FILL_VALUE, dtype=INT_DTYPE) node_cross = np.cross(node_0, node_central) - node_zero_mag = np.linalg.norm(node_zero) d_angles = np.zeros(n_edges, dtype=np.float64) d_angles[0] = 0.0 @@ -205,11 +229,16 @@ def _order_nodes( node_diff = sub - node_central node_diff_mag = np.linalg.norm(node_diff) + # Skip if node difference is too small (numerical stability) + if node_diff_mag < 1e-15: + d_angles[j] = 0.0 + continue + d_side = np.dot(node_cross, node_diff) d_dot_norm = np.dot(node_zero, node_diff) / (node_zero_mag * node_diff_mag) - if d_dot_norm > 1.0: - d_dot_norm = 1.0 + # Clamp to valid range for arccos to avoid numerical errors + d_dot_norm = max(-1.0, min(1.0, d_dot_norm)) d_angles[j] = np.arccos(d_dot_norm) diff --git a/uxarray/remap/bilinear.py b/uxarray/remap/bilinear.py index 7655a55c2..34ec94a15 100644 --- a/uxarray/remap/bilinear.py +++ b/uxarray/remap/bilinear.py @@ -4,7 +4,7 @@ import numpy as np import xarray as xr -from numba import njit +from numba import njit, prange if TYPE_CHECKING: from uxarray.core.dataarray import UxDataArray @@ -140,6 +140,7 @@ def _barycentric_weights(point_xyz, dual, data_size, source_grid): dual.node_z.values, dual.face_node_connectivity.values, dual.n_nodes_per_face.values, + dual.n_face, all_weights, all_indices, ) @@ -147,7 +148,7 @@ def _barycentric_weights(point_xyz, dual, data_size, source_grid): return all_weights, all_indices -@njit(cache=True) +@njit(cache=True, parallel=True) def _calculate_weights( valid_idxs, point_xyz, @@ -157,13 +158,16 @@ def _calculate_weights( z, face_node_conn, n_nodes_per_face, + n_faces, all_weights, all_indices, ): - for idx in valid_idxs: - fidx = int(face_indices[idx, 0]) + for idx in prange(len(valid_idxs)): + fidx = int(face_indices[valid_idxs[idx], 0]) + # bounds check: ensure face index is within valid range (0 to n_faces-1) + if fidx < 0 or fidx >= n_faces: + continue nverts = int(n_nodes_per_face[fidx]) - polygon_xyz = np.zeros((nverts, 3), dtype=np.float64) polygon_face_indices = np.empty(nverts, dtype=np.int32) for j in range(nverts): @@ -174,18 +178,18 @@ def _calculate_weights( polygon_face_indices[j] = node # snap check - match = _find_matching_node_index(polygon_xyz, point_xyz[idx]) + match = _find_matching_node_index(polygon_xyz, point_xyz[valid_idxs[idx]]) if match[0] != -1: - all_weights[idx, 0] = 1.0 - all_indices[idx, 0] = polygon_face_indices[match[0]] + all_weights[valid_idxs[idx], 0] = 1.0 + all_indices[valid_idxs[idx], 0] = polygon_face_indices[match[0]] continue weights, node_idxs = barycentric_coordinates_cartesian( - polygon_xyz, point_xyz[idx] + polygon_xyz, point_xyz[valid_idxs[idx]] ) for k in range(len(weights)): - all_weights[idx, k] = weights[k] - all_indices[idx, k] = polygon_face_indices[node_idxs[k]] + all_weights[valid_idxs[idx], k] = weights[k] + all_indices[valid_idxs[idx], k] = polygon_face_indices[node_idxs[k]] @njit(cache=True)