@@ -810,6 +810,51 @@ def overlap_index(
810
810
),
811
811
)
812
812
813
+ def get_coordinate_extent (ds : Union [xr .DataArray , xr .Dataset ], x1_ascend : bool , x2_ascend : bool ) -> tuple :
814
+ """
815
+ Get coordinate extent of dataset. This method is applied to either X_t or patchwise predictions.
816
+
817
+ Parameters
818
+ ----------
819
+ ds : Data object
820
+ The dataset or data array to determine coordinate extent for.
821
+ Refer to `X_t` in `predict_patchwise()` for supported types.
822
+
823
+ x1_ascend : bool
824
+ Whether the x1 coordinates ascend (increase) from top to bottom.
825
+
826
+ x2_ascend : bool
827
+ Whether the x2 coordinates ascend (increase) from left to right.
828
+
829
+ Returns
830
+ -------
831
+ tuple of tuples:
832
+ Extents of x1 and x2 coordinates as ((min_x1, max_x1), (min_x2, max_x2)).
833
+ """
834
+
835
+ if x1_ascend :
836
+ ds_x1_coords = (
837
+ ds .coords [orig_x1_name ].min ().values ,
838
+ ds .coords [orig_x1_name ].max ().values ,
839
+ )
840
+ else :
841
+ ds_x1_coords = (
842
+ ds .coords [orig_x1_name ].max ().values ,
843
+ ds .coords [orig_x1_name ].min ().values ,
844
+ )
845
+ if x2_ascend :
846
+ ds_x2_coords = (
847
+ ds .coords [orig_x2_name ].min ().values ,
848
+ ds .coords [orig_x2_name ].max ().values ,
849
+ )
850
+ else :
851
+ ds_x2_coords = (
852
+ ds .coords [orig_x2_name ].max ().values ,
853
+ ds .coords [orig_x2_name ].min ().values ,
854
+ )
855
+ return ds_x1_coords , ds_x2_coords
856
+
857
+
813
858
def get_index (* args , x1 = True ) -> Union [int , Tuple [List [int ], List [int ]]]:
814
859
"""Convert coordinates into pixel row/column (index).
815
860
@@ -886,74 +931,38 @@ def stitch_clipped_predictions(
886
931
combined: dict
887
932
Dictionary object containing the stitched model predictions.
888
933
"""
889
- # Get row/col index values of X_t. Order depends on whether coordinate is ascending or descending.
890
- if x1_ascend :
891
- data_x1 = (
892
- X_t .coords [orig_x1_name ].min ().values ,
893
- X_t .coords [orig_x1_name ].max ().values ,
894
- )
895
- else :
896
- data_x1 = (
897
- X_t .coords [orig_x1_name ].max ().values ,
898
- X_t .coords [orig_x1_name ].min ().values ,
899
- )
900
- if x2_ascend :
901
- data_x2 = (
902
- X_t .coords [orig_x2_name ].min ().values ,
903
- X_t .coords [orig_x2_name ].max ().values ,
904
- )
905
- else :
906
- data_x2 = (
907
- X_t .coords [orig_x2_name ].max ().values ,
908
- X_t .coords [orig_x2_name ].min ().values ,
909
- )
910
934
911
- data_x1_index , data_x2_index = get_index (data_x1 , data_x2 )
912
- patches_clipped = {var_name : [] for var_name in patch_preds [0 ].keys ()}
935
+ # Get row/col index values of X_t.
936
+ data_x1_coords , data_x2_coords = get_coordinate_extent (X_t , x1_ascend , x2_ascend )
937
+ data_x1_index , data_x2_index = get_index (data_x1_coords , data_x2_coords )
913
938
939
+ # Iterate through patchwise predictions and slice edges prior to stitchin.
940
+ patches_clipped = {var_name : [] for var_name in patch_preds [0 ].keys ()}
914
941
for i , patch_pred in enumerate (patch_preds ):
915
942
for var_name , data_array in patch_pred .items ():
916
943
if var_name in patch_pred :
917
- # Get row/col index values of each patch. Order depends on whether coordinate is ascending or descending.
918
- if x1_ascend :
919
- patch_x1 = (
920
- data_array .coords [orig_x1_name ].min ().values ,
921
- data_array .coords [orig_x1_name ].max ().values ,
922
- )
923
- else :
924
- patch_x1 = (
925
- data_array .coords [orig_x1_name ].max ().values ,
926
- data_array .coords [orig_x1_name ].min ().values ,
927
- )
928
- if x2_ascend :
929
- patch_x2 = (
930
- data_array .coords [orig_x2_name ].min ().values ,
931
- data_array .coords [orig_x2_name ].max ().values ,
932
- )
933
- else :
934
- patch_x2 = (
935
- data_array .coords [orig_x2_name ].max ().values ,
936
- data_array .coords [orig_x2_name ].min ().values ,
937
- )
938
- patch_x1_index , patch_x2_index = get_index (patch_x1 , patch_x2 )
944
+
945
+ # Get row/col index values of each patch.
946
+ patch_x1_coords , patch_x2_coords = get_coordinate_extent (data_array , x1_ascend , x2_ascend )
947
+ patch_x1_index , patch_x2_index = get_index (patch_x1_coords , patch_x2_coords )
939
948
949
+ # Calculate size of border to slice of each edge of patchwise predictions.
950
+ # Initially set the size of all borders to the size of the overlap.
940
951
b_x1_min , b_x1_max = patch_overlap [0 ], patch_overlap [0 ]
941
952
b_x2_min , b_x2_max = patch_overlap [1 ], patch_overlap [1 ]
942
953
943
954
# Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column.
944
955
if patch_x2_index [0 ] == data_x2_index [0 ]:
945
956
b_x2_min = 0
946
- # TODO: Try to resolve this issue in data/loader.py by ensuring patches are perfectly square.
947
957
b_x2_max = b_x2_max
948
958
949
- # At end of row (when patch_x2_index = data_x2_index), to calculate the number of pixels to remove from left hand side of patch:
959
+ # At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch.
950
960
elif patch_x2_index [1 ] == data_x2_index [1 ]:
951
961
b_x2_max = 0
952
962
patch_row_prev = preds [i - 1 ]
953
963
954
964
# If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels.
955
965
# To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
956
- # to get the number of pixels to remove from left hand side of patch.
957
966
if x2_ascend :
958
967
prev_patch_x2_max = get_index (
959
968
patch_row_prev [var_name ].coords [orig_x2_name ].max (),
@@ -963,9 +972,8 @@ def stitch_clipped_predictions(
963
972
prev_patch_x2_max - patch_x2_index [0 ]
964
973
) - patch_overlap [1 ]
965
974
966
- # If x2 is descending. Subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
975
+ # If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
967
976
# To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
968
- # to get the number of pixels to remove from left hand side of patch.
969
977
else :
970
978
prev_patch_x2_min = get_index (
971
979
patch_row_prev [var_name ].coords [orig_x2_name ].min (),
@@ -977,9 +985,10 @@ def stitch_clipped_predictions(
977
985
else :
978
986
b_x2_max = b_x2_max
979
987
988
+ # Repeat process as above for x1 coordinates.
980
989
if patch_x1_index [0 ] == data_x1_index [0 ]:
981
990
b_x1_min = 0
982
- # TODO: ensure this elif statement is robust to multiple patch sizes.
991
+
983
992
elif abs (patch_x1_index [1 ] - data_x1_index [1 ]) < 2 :
984
993
b_x1_max = 0
985
994
b_x1_max = b_x1_max
@@ -1013,6 +1022,7 @@ def stitch_clipped_predictions(
1013
1022
data_array .sizes [orig_x2_name ] - b_x2_max
1014
1023
)
1015
1024
1025
+ # Slice patchwise predictions
1016
1026
patch_clip = data_array .isel (
1017
1027
** {
1018
1028
orig_x1_name : slice (
@@ -1025,51 +1035,47 @@ def stitch_clipped_predictions(
1025
1035
)
1026
1036
1027
1037
patches_clipped [var_name ].append (patch_clip )
1028
-
1029
- # Create blank prediction
1030
- combined_dataset = copy .deepcopy (patches_clipped )
1031
-
1032
- # Generate new blank DeepSensor.prediction object with same extent and coordinate system as X_t.
1033
- for var , data_array_list in combined_dataset .items ():
1034
- first_patchwise_pred = data_array_list [0 ]
1035
-
1036
- # Define coordinate extent and time
1037
- blank_pred = xr .Dataset (
1038
+
1039
+ # Create blank prediction object to stitch prediction values onto.
1040
+ stitched_prediction = copy .deepcopy (patch_preds [0 ])
1041
+ # Set prediction object extent to the same as X_t.
1042
+ for var_name , data_array in stitched_prediction .items ():
1043
+ blank_ds = xr .Dataset (
1038
1044
coords = {
1039
1045
orig_x1_name : X_t [orig_x1_name ],
1040
1046
orig_x2_name : X_t [orig_x2_name ],
1041
- "time" : first_patchwise_pred ["time" ],
1047
+ "time" : stitched_prediction [ 0 ] ["time" ],
1042
1048
}
1043
1049
)
1044
1050
1045
- # Set variable names to those in patched predictions, set values to Nan.
1046
- for param in first_patchwise_pred .data_vars :
1047
- blank_pred [ param ] = first_patchwise_pred [ param ]
1048
- blank_pred [ param ][:] = np .nan
1049
- combined_dataset [ var ] = blank_pred
1050
-
1051
- # Merge patchwise predictions to create final combined dataset .
1051
+ # Set data variable names e.g. mean, std to those in patched prediction. Make all values Nan.
1052
+ for data_var in data_array .data_vars :
1053
+ blank_ds [ data_var ] = data_array [ data_var ]
1054
+ blank_ds [ data_var ][:] = np .nan
1055
+ stitched_prediction [ var_name ] = blank_ds
1056
+
1057
+ # Merge patchwise predictions to create final stiched prediction .
1052
1058
# Iterate over each variable (key) in the prediction dictionary
1053
1059
for var_name , patches in patches_clipped .items ():
1054
1060
# Retrieve the blank dataset for the current variable
1055
- combined_array = combined_dataset [var_name ]
1061
+ prediction_array = stitched_prediction [var_name ]
1056
1062
1057
1063
# Merge each patch into the combined dataset
1058
1064
for patch in patches :
1059
1065
for var in patch .data_vars :
1060
1066
# Reindex the patch to catch any slight rounding errors and misalignment with the combined dataset
1061
1067
reindexed_patch = patch [var ].reindex_like (
1062
- combined_array [var ], method = "nearest" , tolerance = 1e-6
1068
+ prediction_array [var ], method = "nearest" , tolerance = 1e-6
1063
1069
)
1064
1070
1065
1071
# Combine data, prioritizing non-NaN values from patches
1066
- combined_array [var ] = combined_array [var ].where (
1072
+ prediction_array [var ] = prediction_array [var ].where (
1067
1073
np .isnan (reindexed_patch ), reindexed_patch
1068
1074
)
1069
1075
1070
1076
# Update the dictionary with the merged dataset
1071
- combined_dataset [var_name ] = combined_array
1072
- return combined_dataset
1077
+ stitched_prediction [var_name ] = prediction_array
1078
+ return stitched_prediction
1073
1079
1074
1080
# load patch_size and stride from task
1075
1081
patch_size = tasks [0 ]["patch_size" ]
@@ -1164,34 +1170,10 @@ def stitch_clipped_predictions(
1164
1170
)
1165
1171
1166
1172
patches_per_row = get_patches_per_row (preds )
1167
- stitched_prediction = stitch_clipped_predictions (
1173
+ prediction = stitch_clipped_predictions (
1168
1174
preds , patch_overlap_unnorm , patches_per_row , x1_ascending , x2_ascending
1169
1175
)
1170
1176
1171
- ## Cast prediction into DeepSensor.Prediction object.
1172
- # TODO make this into seperate method.
1173
- prediction = copy .deepcopy (preds [0 ])
1174
-
1175
- # Generate new blank DeepSensor.prediction object in original coordinate system.
1176
- for var_name_copy , data_array_copy in prediction .items ():
1177
- # set x and y coords
1178
- stitched_preds = xr .Dataset (
1179
- coords = {
1180
- orig_x1_name : X_t [orig_x1_name ],
1181
- orig_x2_name : X_t [orig_x2_name ],
1182
- }
1183
- )
1184
-
1185
- # Set time to same as patched prediction
1186
- stitched_preds ["time" ] = data_array_copy ["time" ]
1187
-
1188
- # set variable names to those in patched prediction, make values blank
1189
- for var_name_i in data_array_copy .data_vars :
1190
- stitched_preds [var_name_i ] = data_array_copy [var_name_i ]
1191
- stitched_preds [var_name_i ][:] = np .nan
1192
- prediction [var_name_copy ] = stitched_preds
1193
- prediction [var_name_copy ] = stitched_prediction [var_name_copy ]
1194
-
1195
1177
return prediction
1196
1178
1197
1179
0 commit comments