From be883dcf73109807fb11a32dd064bb30128390e3 Mon Sep 17 00:00:00 2001 From: davidwilby <24752124+davidwilby@users.noreply.github.com> Date: Mon, 27 Jan 2025 11:09:17 +0000 Subject: [PATCH] lint --- deepsensor/model/model.py | 45 ++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 6ddf4dc2..e297dabf 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -810,9 +810,10 @@ def overlap_index( ), ) - def get_coordinate_extent(ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool, x2_ascend: bool) -> tuple: - """ - Get coordinate extent of dataset. This method is applied to either X_t or patchwise predictions. + def get_coordinate_extent( + ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool, x2_ascend: bool + ) -> tuple: + """Get coordinate extent of dataset. This method is applied to either X_t or patchwise predictions. Parameters ---------- @@ -825,12 +826,11 @@ def get_coordinate_extent(ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool, x2_ascend : bool Whether the x2 coordinates ascend (increase) from left to right. - Returns + Returns: ------- tuple of tuples: Extents of x1 and x2 coordinates as ((min_x1, max_x1), (min_x2, max_x2)). """ - if x1_ascend: ds_x1_coords = ( ds.coords[orig_x1_name].min().values, @@ -853,7 +853,6 @@ def get_coordinate_extent(ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool, ) return ds_x1_coords, ds_x2_coords - def get_index(*args, x1=True) -> Union[int, Tuple[List[int], List[int]]]: """Convert coordinates into pixel row/column (index). @@ -930,22 +929,27 @@ def stitch_clipped_predictions( combined: dict Dictionary object containing the stitched model predictions. """ - - # Get row/col index values of X_t. - data_x1_coords, data_x2_coords= get_coordinate_extent(X_t, x1_ascend, x2_ascend) + # Get row/col index values of X_t. + data_x1_coords, data_x2_coords = get_coordinate_extent( + X_t, x1_ascend, x2_ascend + ) data_x1_index, data_x2_index = get_index(data_x1_coords, data_x2_coords) # Iterate through patchwise predictions and slice edges prior to stitchin. patches_clipped = [] for i, patch_pred in enumerate(patch_preds): # get one variable name to use for coordinates and extent - first_key = list(patch_pred.keys())[0] + first_key = list(patch_pred.keys())[0] # Get row/col index values of each patch. - patch_x1_coords, patch_x2_coords= get_coordinate_extent(patch_pred[first_key], x1_ascend, x2_ascend) - patch_x1_index, patch_x2_index = get_index(patch_x1_coords, patch_x2_coords) + patch_x1_coords, patch_x2_coords = get_coordinate_extent( + patch_pred[first_key], x1_ascend, x2_ascend + ) + patch_x1_index, patch_x2_index = get_index( + patch_x1_coords, patch_x2_coords + ) # Calculate size of border to slice of each edge of patchwise predictions. - # Initially set the size of all borders to the size of the overlap. + # Initially set the size of all borders to the size of the overlap. b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] @@ -983,10 +987,10 @@ def stitch_clipped_predictions( else: b_x2_max = b_x2_max - # Repeat process as above for x1 coordinates. + # Repeat process as above for x1 coordinates. if patch_x1_index[0] == data_x1_index[0]: b_x1_min = 0 - + elif abs(patch_x1_index[1] - data_x1_index[1]) < 2: b_x1_max = 0 b_x1_max = b_x1_max @@ -1033,12 +1037,12 @@ def stitch_clipped_predictions( } patches_clipped.append(patch_clip) - + # Create blank prediction object to stitch prediction values onto. stitched_prediction = copy.deepcopy(patch_preds[0]) # Set prediction object extent to the same as X_t. for var_name, data_array in stitched_prediction.items(): - blank_ds= xr.Dataset( + blank_ds = xr.Dataset( coords={ orig_x1_name: X_t[orig_x1_name], orig_x2_name: X_t[orig_x2_name], @@ -1051,9 +1055,12 @@ def stitch_clipped_predictions( blank_ds[data_var] = data_array[data_var] blank_ds[data_var][:] = np.nan stitched_prediction[var_name] = blank_ds - + # Restructure prediction objects for merging - restructured_patches = {key: [item[key] for item in patches_clipped] for key in patches_clipped[0].keys()} + restructured_patches = { + key: [item[key] for item in patches_clipped] + for key in patches_clipped[0].keys() + } # Merge patchwise predictions to create final stiched prediction. # Iterate over each variable (key) in the prediction dictionary