Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
davidwilby committed Jan 27, 2025
1 parent 1f0fb32 commit be883dc
Showing 1 changed file with 26 additions and 19 deletions.
45 changes: 26 additions & 19 deletions deepsensor/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,9 +810,10 @@ def overlap_index(
),
)

def get_coordinate_extent(ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool, x2_ascend: bool) -> tuple:
"""
Get coordinate extent of dataset. This method is applied to either X_t or patchwise predictions.
def get_coordinate_extent(
ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool, x2_ascend: bool
) -> tuple:
"""Get coordinate extent of dataset. This method is applied to either X_t or patchwise predictions.
Parameters
----------
Expand All @@ -825,12 +826,11 @@ def get_coordinate_extent(ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool,
x2_ascend : bool
Whether the x2 coordinates ascend (increase) from left to right.
Returns
Returns:
-------
tuple of tuples:
Extents of x1 and x2 coordinates as ((min_x1, max_x1), (min_x2, max_x2)).
"""

if x1_ascend:
ds_x1_coords = (
ds.coords[orig_x1_name].min().values,
Expand All @@ -853,7 +853,6 @@ def get_coordinate_extent(ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool,
)
return ds_x1_coords, ds_x2_coords


def get_index(*args, x1=True) -> Union[int, Tuple[List[int], List[int]]]:
"""Convert coordinates into pixel row/column (index).
Expand Down Expand Up @@ -930,22 +929,27 @@ def stitch_clipped_predictions(
combined: dict
Dictionary object containing the stitched model predictions.
"""

# Get row/col index values of X_t.
data_x1_coords, data_x2_coords= get_coordinate_extent(X_t, x1_ascend, x2_ascend)
# Get row/col index values of X_t.
data_x1_coords, data_x2_coords = get_coordinate_extent(
X_t, x1_ascend, x2_ascend
)
data_x1_index, data_x2_index = get_index(data_x1_coords, data_x2_coords)

# Iterate through patchwise predictions and slice edges prior to stitchin.
patches_clipped = []
for i, patch_pred in enumerate(patch_preds):
# get one variable name to use for coordinates and extent
first_key = list(patch_pred.keys())[0]
first_key = list(patch_pred.keys())[0]
# Get row/col index values of each patch.
patch_x1_coords, patch_x2_coords= get_coordinate_extent(patch_pred[first_key], x1_ascend, x2_ascend)
patch_x1_index, patch_x2_index = get_index(patch_x1_coords, patch_x2_coords)
patch_x1_coords, patch_x2_coords = get_coordinate_extent(
patch_pred[first_key], x1_ascend, x2_ascend
)
patch_x1_index, patch_x2_index = get_index(
patch_x1_coords, patch_x2_coords
)

# Calculate size of border to slice of each edge of patchwise predictions.
# Initially set the size of all borders to the size of the overlap.
# Initially set the size of all borders to the size of the overlap.
b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0]
b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1]

Expand Down Expand Up @@ -983,10 +987,10 @@ def stitch_clipped_predictions(
else:
b_x2_max = b_x2_max

# Repeat process as above for x1 coordinates.
# Repeat process as above for x1 coordinates.
if patch_x1_index[0] == data_x1_index[0]:
b_x1_min = 0

elif abs(patch_x1_index[1] - data_x1_index[1]) < 2:
b_x1_max = 0
b_x1_max = b_x1_max
Expand Down Expand Up @@ -1033,12 +1037,12 @@ def stitch_clipped_predictions(
}

patches_clipped.append(patch_clip)

# Create blank prediction object to stitch prediction values onto.
stitched_prediction = copy.deepcopy(patch_preds[0])
# Set prediction object extent to the same as X_t.
for var_name, data_array in stitched_prediction.items():
blank_ds= xr.Dataset(
blank_ds = xr.Dataset(
coords={
orig_x1_name: X_t[orig_x1_name],
orig_x2_name: X_t[orig_x2_name],
Expand All @@ -1051,9 +1055,12 @@ def stitch_clipped_predictions(
blank_ds[data_var] = data_array[data_var]
blank_ds[data_var][:] = np.nan
stitched_prediction[var_name] = blank_ds

# Restructure prediction objects for merging
restructured_patches = {key: [item[key] for item in patches_clipped] for key in patches_clipped[0].keys()}
restructured_patches = {
key: [item[key] for item in patches_clipped]
for key in patches_clipped[0].keys()
}

# Merge patchwise predictions to create final stiched prediction.
# Iterate over each variable (key) in the prediction dictionary
Expand Down

0 comments on commit be883dc

Please sign in to comment.