Skip to content

Commit

Permalink
Merge pull request #19 from davidwilby/simplify_stitching
Browse files Browse the repository at this point in the history
Simplify stitching process
  • Loading branch information
davidwilby authored Jan 21, 2025
2 parents b4e9ff5 + 9943e99 commit 53ee50f
Showing 1 changed file with 81 additions and 99 deletions.
180 changes: 81 additions & 99 deletions deepsensor/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,51 @@ def overlap_index(
),
)

def get_coordinate_extent(ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool, x2_ascend: bool) -> tuple:
    """
    Get coordinate extent of dataset. This method is applied to either X_t or patchwise predictions.
    Parameters
    ----------
    ds : Data object
        The dataset or data array to determine coordinate extent for.
        Refer to `X_t` in `predict_patchwise()` for supported types.
    x1_ascend : bool
        Whether the x1 coordinates ascend (increase) from top to bottom.
    x2_ascend : bool
        Whether the x2 coordinates ascend (increase) from left to right.
    Returns
    -------
    tuple of tuples:
        Extents of x1 and x2 coordinates as ((min_x1, max_x1), (min_x2, max_x2)).
    """

    def _extent(coord_name, ascend):
        # For an ascending coordinate the extent runs (min, max); for a
        # descending coordinate the order is flipped so the first element
        # always corresponds to the top/left edge of the data.
        lo = ds.coords[coord_name].min().values
        hi = ds.coords[coord_name].max().values
        return (lo, hi) if ascend else (hi, lo)

    # orig_x1_name / orig_x2_name come from the enclosing method's scope.
    return _extent(orig_x1_name, x1_ascend), _extent(orig_x2_name, x2_ascend)


def get_index(*args, x1=True) -> Union[int, Tuple[List[int], List[int]]]:
"""Convert coordinates into pixel row/column (index).
Expand Down Expand Up @@ -886,74 +931,38 @@ def stitch_clipped_predictions(
combined: dict
Dictionary object containing the stitched model predictions.
"""
# Get row/col index values of X_t. Order depends on whether coordinate is ascending or descending.
if x1_ascend:
data_x1 = (
X_t.coords[orig_x1_name].min().values,
X_t.coords[orig_x1_name].max().values,
)
else:
data_x1 = (
X_t.coords[orig_x1_name].max().values,
X_t.coords[orig_x1_name].min().values,
)
if x2_ascend:
data_x2 = (
X_t.coords[orig_x2_name].min().values,
X_t.coords[orig_x2_name].max().values,
)
else:
data_x2 = (
X_t.coords[orig_x2_name].max().values,
X_t.coords[orig_x2_name].min().values,
)

data_x1_index, data_x2_index = get_index(data_x1, data_x2)
patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()}
# Get row/col index values of X_t.
data_x1_coords, data_x2_coords= get_coordinate_extent(X_t, x1_ascend, x2_ascend)
data_x1_index, data_x2_index = get_index(data_x1_coords, data_x2_coords)

# Iterate through patchwise predictions and slice edges prior to stitching.
patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()}
for i, patch_pred in enumerate(patch_preds):
for var_name, data_array in patch_pred.items():
if var_name in patch_pred:
# Get row/col index values of each patch. Order depends on whether coordinate is ascending or descending.
if x1_ascend:
patch_x1 = (
data_array.coords[orig_x1_name].min().values,
data_array.coords[orig_x1_name].max().values,
)
else:
patch_x1 = (
data_array.coords[orig_x1_name].max().values,
data_array.coords[orig_x1_name].min().values,
)
if x2_ascend:
patch_x2 = (
data_array.coords[orig_x2_name].min().values,
data_array.coords[orig_x2_name].max().values,
)
else:
patch_x2 = (
data_array.coords[orig_x2_name].max().values,
data_array.coords[orig_x2_name].min().values,
)
patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2)

# Get row/col index values of each patch.
patch_x1_coords, patch_x2_coords= get_coordinate_extent(data_array, x1_ascend, x2_ascend)
patch_x1_index, patch_x2_index = get_index(patch_x1_coords, patch_x2_coords)

# Calculate size of border to slice of each edge of patchwise predictions.
# Initially set the size of all borders to the size of the overlap.
b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0]
b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1]

# Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column.
if patch_x2_index[0] == data_x2_index[0]:
b_x2_min = 0
# TODO: Try to resolve this issue in data/loader.py by ensuring patches are perfectly square.
b_x2_max = b_x2_max

# At end of row (when patch_x2_index = data_x2_index), to calculate the number of pixels to remove from left hand side of patch:
# At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch.
elif patch_x2_index[1] == data_x2_index[1]:
b_x2_max = 0
patch_row_prev = preds[i - 1]

# If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels.
# To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
# to get the number of pixels to remove from left hand side of patch.
if x2_ascend:
prev_patch_x2_max = get_index(
patch_row_prev[var_name].coords[orig_x2_name].max(),
Expand All @@ -963,9 +972,8 @@ def stitch_clipped_predictions(
prev_patch_x2_max - patch_x2_index[0]
) - patch_overlap[1]

# If x2 is descending. Subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
# If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
# To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
# to get the number of pixels to remove from left hand side of patch.
else:
prev_patch_x2_min = get_index(
patch_row_prev[var_name].coords[orig_x2_name].min(),
Expand All @@ -977,9 +985,10 @@ def stitch_clipped_predictions(
else:
b_x2_max = b_x2_max

# Repeat process as above for x1 coordinates.
if patch_x1_index[0] == data_x1_index[0]:
b_x1_min = 0
# TODO: ensure this elif statement is robust to multiple patch sizes.

elif abs(patch_x1_index[1] - data_x1_index[1]) < 2:
b_x1_max = 0
b_x1_max = b_x1_max
Expand Down Expand Up @@ -1013,6 +1022,7 @@ def stitch_clipped_predictions(
data_array.sizes[orig_x2_name] - b_x2_max
)

# Slice patchwise predictions
patch_clip = data_array.isel(
**{
orig_x1_name: slice(
Expand All @@ -1025,51 +1035,47 @@ def stitch_clipped_predictions(
)

patches_clipped[var_name].append(patch_clip)

# Create blank prediction
combined_dataset = copy.deepcopy(patches_clipped)

# Generate new blank DeepSensor.prediction object with same extent and coordinate system as X_t.
for var, data_array_list in combined_dataset.items():
first_patchwise_pred = data_array_list[0]

# Define coordinate extent and time
blank_pred = xr.Dataset(

# Create blank prediction object to stitch prediction values onto.
stitched_prediction = copy.deepcopy(patch_preds[0])
# Set prediction object extent to the same as X_t.
for var_name, data_array in stitched_prediction.items():
blank_ds= xr.Dataset(
coords={
orig_x1_name: X_t[orig_x1_name],
orig_x2_name: X_t[orig_x2_name],
"time": first_patchwise_pred["time"],
"time": stitched_prediction[0]["time"],
}
)

# Set variable names to those in patched predictions, set values to Nan.
for param in first_patchwise_pred.data_vars:
blank_pred[param] = first_patchwise_pred[param]
blank_pred[param][:] = np.nan
combined_dataset[var] = blank_pred

# Merge patchwise predictions to create final combined dataset.
# Set data variable names e.g. mean, std to those in patched prediction. Make all values Nan.
for data_var in data_array.data_vars:
blank_ds[data_var] = data_array[data_var]
blank_ds[data_var][:] = np.nan
stitched_prediction[var_name] = blank_ds
# Merge patchwise predictions to create final stitched prediction.
# Iterate over each variable (key) in the prediction dictionary
for var_name, patches in patches_clipped.items():
# Retrieve the blank dataset for the current variable
combined_array = combined_dataset[var_name]
prediction_array = stitched_prediction[var_name]

# Merge each patch into the combined dataset
for patch in patches:
for var in patch.data_vars:
# Reindex the patch to catch any slight rounding errors and misalignment with the combined dataset
reindexed_patch = patch[var].reindex_like(
combined_array[var], method="nearest", tolerance=1e-6
prediction_array[var], method="nearest", tolerance=1e-6
)

# Combine data, prioritizing non-NaN values from patches
combined_array[var] = combined_array[var].where(
prediction_array[var] = prediction_array[var].where(
np.isnan(reindexed_patch), reindexed_patch
)

# Update the dictionary with the merged dataset
combined_dataset[var_name] = combined_array
return combined_dataset
stitched_prediction[var_name] = prediction_array
return stitched_prediction

# load patch_size and stride from task
patch_size = tasks[0]["patch_size"]
Expand Down Expand Up @@ -1164,34 +1170,10 @@ def stitch_clipped_predictions(
)

patches_per_row = get_patches_per_row(preds)
stitched_prediction = stitch_clipped_predictions(
prediction = stitch_clipped_predictions(
preds, patch_overlap_unnorm, patches_per_row, x1_ascending, x2_ascending
)

## Cast prediction into DeepSensor.Prediction object.
# TODO make this into a separate method.
prediction = copy.deepcopy(preds[0])

# Generate new blank DeepSensor.prediction object in original coordinate system.
for var_name_copy, data_array_copy in prediction.items():
# set x and y coords
stitched_preds = xr.Dataset(
coords={
orig_x1_name: X_t[orig_x1_name],
orig_x2_name: X_t[orig_x2_name],
}
)

# Set time to same as patched prediction
stitched_preds["time"] = data_array_copy["time"]

# set variable names to those in patched prediction, make values blank
for var_name_i in data_array_copy.data_vars:
stitched_preds[var_name_i] = data_array_copy[var_name_i]
stitched_preds[var_name_i][:] = np.nan
prediction[var_name_copy] = stitched_preds
prediction[var_name_copy] = stitched_prediction[var_name_copy]

return prediction


Expand Down

0 comments on commit 53ee50f

Please sign in to comment.