Skip to content

Commit 53ee50f

Browse files
authored
Merge pull request #19 from davidwilby/simplify_stitching
Simplify stitching process
2 parents b4e9ff5 + 9943e99 commit 53ee50f

File tree

1 file changed

+81
-99
lines changed

1 file changed

+81
-99
lines changed

deepsensor/model/model.py

Lines changed: 81 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,51 @@ def overlap_index(
810810
),
811811
)
812812

813+
def get_coordinate_extent(
    ds: Union[xr.DataArray, xr.Dataset],
    x1_ascend: bool,
    x2_ascend: bool,
) -> tuple:
    """
    Return the start and end coordinate values of ``ds`` along x1 and x2.

    Used for both ``X_t`` and the individual patchwise predictions. The
    coordinate names are read from ``orig_x1_name`` and ``orig_x2_name``,
    which are resolved from the enclosing scope rather than passed in.

    Parameters
    ----------
    ds : Data object
        The dataset or data array to determine coordinate extent for.
        Refer to `X_t` in `predict_patchwise()` for supported types.

    x1_ascend : bool
        Whether the x1 coordinates ascend (increase) from top to bottom.

    x2_ascend : bool
        Whether the x2 coordinates ascend (increase) from left to right.

    Returns
    -------
    tuple of tuples:
        ``(x1_extent, x2_extent)``. Each extent is ``(min, max)`` when the
        corresponding ``*_ascend`` flag is true and ``(max, min)`` when it
        is false, i.e. the pair follows the direction of the axis.
    """

    def _axis_extent(coord_name, ascending):
        # Order the pair to follow the axis direction so that the first
        # element is always the value at the start of the axis.
        axis = ds.coords[coord_name]
        lowest = axis.min().values
        highest = axis.max().values
        return (lowest, highest) if ascending else (highest, lowest)

    x1_extent = _axis_extent(orig_x1_name, x1_ascend)
    x2_extent = _axis_extent(orig_x2_name, x2_ascend)
    return x1_extent, x2_extent
856+
857+
813858
def get_index(*args, x1=True) -> Union[int, Tuple[List[int], List[int]]]:
814859
"""Convert coordinates into pixel row/column (index).
815860
@@ -886,74 +931,38 @@ def stitch_clipped_predictions(
886931
combined: dict
887932
Dictionary object containing the stitched model predictions.
888933
"""
889-
# Get row/col index values of X_t. Order depends on whether coordinate is ascending or descending.
890-
if x1_ascend:
891-
data_x1 = (
892-
X_t.coords[orig_x1_name].min().values,
893-
X_t.coords[orig_x1_name].max().values,
894-
)
895-
else:
896-
data_x1 = (
897-
X_t.coords[orig_x1_name].max().values,
898-
X_t.coords[orig_x1_name].min().values,
899-
)
900-
if x2_ascend:
901-
data_x2 = (
902-
X_t.coords[orig_x2_name].min().values,
903-
X_t.coords[orig_x2_name].max().values,
904-
)
905-
else:
906-
data_x2 = (
907-
X_t.coords[orig_x2_name].max().values,
908-
X_t.coords[orig_x2_name].min().values,
909-
)
910934

911-
data_x1_index, data_x2_index = get_index(data_x1, data_x2)
912-
patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()}
935+
# Get row/col index values of X_t.
936+
data_x1_coords, data_x2_coords= get_coordinate_extent(X_t, x1_ascend, x2_ascend)
937+
data_x1_index, data_x2_index = get_index(data_x1_coords, data_x2_coords)
913938

939+
# Iterate through patchwise predictions and slice edges prior to stitching.
940+
patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()}
914941
for i, patch_pred in enumerate(patch_preds):
915942
for var_name, data_array in patch_pred.items():
916943
if var_name in patch_pred:
917-
# Get row/col index values of each patch. Order depends on whether coordinate is ascending or descending.
918-
if x1_ascend:
919-
patch_x1 = (
920-
data_array.coords[orig_x1_name].min().values,
921-
data_array.coords[orig_x1_name].max().values,
922-
)
923-
else:
924-
patch_x1 = (
925-
data_array.coords[orig_x1_name].max().values,
926-
data_array.coords[orig_x1_name].min().values,
927-
)
928-
if x2_ascend:
929-
patch_x2 = (
930-
data_array.coords[orig_x2_name].min().values,
931-
data_array.coords[orig_x2_name].max().values,
932-
)
933-
else:
934-
patch_x2 = (
935-
data_array.coords[orig_x2_name].max().values,
936-
data_array.coords[orig_x2_name].min().values,
937-
)
938-
patch_x1_index, patch_x2_index = get_index(patch_x1, patch_x2)
944+
945+
# Get row/col index values of each patch.
946+
patch_x1_coords, patch_x2_coords= get_coordinate_extent(data_array, x1_ascend, x2_ascend)
947+
patch_x1_index, patch_x2_index = get_index(patch_x1_coords, patch_x2_coords)
939948

949+
# Calculate size of border to slice off each edge of patchwise predictions.
950+
# Initially set the size of all borders to the size of the overlap.
940951
b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0]
941952
b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1]
942953

943954
# Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column.
944955
if patch_x2_index[0] == data_x2_index[0]:
945956
b_x2_min = 0
946-
# TODO: Try to resolve this issue in data/loader.py by ensuring patches are perfectly square.
947957
b_x2_max = b_x2_max
948958

949-
# At end of row (when patch_x2_index = data_x2_index), to calculate the number of pixels to remove from left hand side of patch:
959+
# At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch.
950960
elif patch_x2_index[1] == data_x2_index[1]:
951961
b_x2_max = 0
952962
patch_row_prev = preds[i - 1]
953963

954964
# If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels.
955965
# To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
956-
# to get the number of pixels to remove from left hand side of patch.
957966
if x2_ascend:
958967
prev_patch_x2_max = get_index(
959968
patch_row_prev[var_name].coords[orig_x2_name].max(),
@@ -963,9 +972,8 @@ def stitch_clipped_predictions(
963972
prev_patch_x2_max - patch_x2_index[0]
964973
) - patch_overlap[1]
965974

966-
# If x2 is descending. Subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
975+
# If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
967976
# To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
968-
# to get the number of pixels to remove from left hand side of patch.
969977
else:
970978
prev_patch_x2_min = get_index(
971979
patch_row_prev[var_name].coords[orig_x2_name].min(),
@@ -977,9 +985,10 @@ def stitch_clipped_predictions(
977985
else:
978986
b_x2_max = b_x2_max
979987

988+
# Repeat process as above for x1 coordinates.
980989
if patch_x1_index[0] == data_x1_index[0]:
981990
b_x1_min = 0
982-
# TODO: ensure this elif statement is robust to multiple patch sizes.
991+
983992
elif abs(patch_x1_index[1] - data_x1_index[1]) < 2:
984993
b_x1_max = 0
985994
b_x1_max = b_x1_max
@@ -1013,6 +1022,7 @@ def stitch_clipped_predictions(
10131022
data_array.sizes[orig_x2_name] - b_x2_max
10141023
)
10151024

1025+
# Slice patchwise predictions
10161026
patch_clip = data_array.isel(
10171027
**{
10181028
orig_x1_name: slice(
@@ -1025,51 +1035,47 @@ def stitch_clipped_predictions(
10251035
)
10261036

10271037
patches_clipped[var_name].append(patch_clip)
1028-
1029-
# Create blank prediction
1030-
combined_dataset = copy.deepcopy(patches_clipped)
1031-
1032-
# Generate new blank DeepSensor.prediction object with same extent and coordinate system as X_t.
1033-
for var, data_array_list in combined_dataset.items():
1034-
first_patchwise_pred = data_array_list[0]
1035-
1036-
# Define coordinate extent and time
1037-
blank_pred = xr.Dataset(
1038+
1039+
# Create blank prediction object to stitch prediction values onto.
1040+
stitched_prediction = copy.deepcopy(patch_preds[0])
1041+
# Set prediction object extent to the same as X_t.
1042+
for var_name, data_array in stitched_prediction.items():
1043+
blank_ds= xr.Dataset(
10381044
coords={
10391045
orig_x1_name: X_t[orig_x1_name],
10401046
orig_x2_name: X_t[orig_x2_name],
1041-
"time": first_patchwise_pred["time"],
1047+
"time": stitched_prediction[0]["time"],
10421048
}
10431049
)
10441050

1045-
# Set variable names to those in patched predictions, set values to Nan.
1046-
for param in first_patchwise_pred.data_vars:
1047-
blank_pred[param] = first_patchwise_pred[param]
1048-
blank_pred[param][:] = np.nan
1049-
combined_dataset[var] = blank_pred
1050-
1051-
# Merge patchwise predictions to create final combined dataset.
1051+
# Set data variable names (e.g. mean, std) to those in the patched prediction and set all values to NaN.
1052+
for data_var in data_array.data_vars:
1053+
blank_ds[data_var] = data_array[data_var]
1054+
blank_ds[data_var][:] = np.nan
1055+
stitched_prediction[var_name] = blank_ds
1056+
1057+
# Merge patchwise predictions to create final stitched prediction.
10521058
# Iterate over each variable (key) in the prediction dictionary
10531059
for var_name, patches in patches_clipped.items():
10541060
# Retrieve the blank dataset for the current variable
1055-
combined_array = combined_dataset[var_name]
1061+
prediction_array = stitched_prediction[var_name]
10561062

10571063
# Merge each patch into the combined dataset
10581064
for patch in patches:
10591065
for var in patch.data_vars:
10601066
# Reindex the patch to catch any slight rounding errors and misalignment with the combined dataset
10611067
reindexed_patch = patch[var].reindex_like(
1062-
combined_array[var], method="nearest", tolerance=1e-6
1068+
prediction_array[var], method="nearest", tolerance=1e-6
10631069
)
10641070

10651071
# Combine data, prioritizing non-NaN values from patches
1066-
combined_array[var] = combined_array[var].where(
1072+
prediction_array[var] = prediction_array[var].where(
10671073
np.isnan(reindexed_patch), reindexed_patch
10681074
)
10691075

10701076
# Update the dictionary with the merged dataset
1071-
combined_dataset[var_name] = combined_array
1072-
return combined_dataset
1077+
stitched_prediction[var_name] = prediction_array
1078+
return stitched_prediction
10731079

10741080
# load patch_size and stride from task
10751081
patch_size = tasks[0]["patch_size"]
@@ -1164,34 +1170,10 @@ def stitch_clipped_predictions(
11641170
)
11651171

11661172
patches_per_row = get_patches_per_row(preds)
1167-
stitched_prediction = stitch_clipped_predictions(
1173+
prediction = stitch_clipped_predictions(
11681174
preds, patch_overlap_unnorm, patches_per_row, x1_ascending, x2_ascending
11691175
)
11701176

1171-
## Cast prediction into DeepSensor.Prediction object.
1172-
# TODO make this into seperate method.
1173-
prediction = copy.deepcopy(preds[0])
1174-
1175-
# Generate new blank DeepSensor.prediction object in original coordinate system.
1176-
for var_name_copy, data_array_copy in prediction.items():
1177-
# set x and y coords
1178-
stitched_preds = xr.Dataset(
1179-
coords={
1180-
orig_x1_name: X_t[orig_x1_name],
1181-
orig_x2_name: X_t[orig_x2_name],
1182-
}
1183-
)
1184-
1185-
# Set time to same as patched prediction
1186-
stitched_preds["time"] = data_array_copy["time"]
1187-
1188-
# set variable names to those in patched prediction, make values blank
1189-
for var_name_i in data_array_copy.data_vars:
1190-
stitched_preds[var_name_i] = data_array_copy[var_name_i]
1191-
stitched_preds[var_name_i][:] = np.nan
1192-
prediction[var_name_copy] = stitched_preds
1193-
prediction[var_name_copy] = stitched_prediction[var_name_copy]
1194-
11951177
return prediction
11961178

11971179

0 commit comments

Comments
 (0)