@@ -810,6 +810,51 @@ def overlap_index(
810810 ),
811811 )
812812
def get_coordinate_extent(ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool, x2_ascend: bool) -> tuple:
    """
    Return the start/end coordinate values of a data object along x1 and x2.

    Used for both `X_t` and the individual patchwise predictions.

    Parameters
    ----------
    ds : Data object
        The dataset or data array whose coordinate extent is required.
        Refer to `X_t` in `predict_patchwise()` for supported types.

    x1_ascend : bool
        Whether the x1 coordinates ascend (increase) from top to bottom.

    x2_ascend : bool
        Whether the x2 coordinates ascend (increase) from left to right.

    Returns
    -------
    tuple of tuples:
        ``(x1_extent, x2_extent)``. Each extent is ``(min, max)`` when the
        corresponding coordinate ascends and ``(max, min)`` when it descends,
        so the first element is always the value at the start of the axis.
    """

    # `orig_x1_name` / `orig_x2_name` are not parameters; they are read from
    # the surrounding scope.
    def _ordered_extent(coord_name, ascending):
        # Smallest and largest coordinate value along the named axis.
        axis = ds.coords[coord_name]
        lowest = axis.min().values
        highest = axis.max().values
        # Flip the pair for descending axes so it follows the axis direction.
        return (lowest, highest) if ascending else (highest, lowest)

    return (
        _ordered_extent(orig_x1_name, x1_ascend),
        _ordered_extent(orig_x2_name, x2_ascend),
    )
856+
857+
813858 def get_index (* args , x1 = True ) -> Union [int , Tuple [List [int ], List [int ]]]:
814859 """Convert coordinates into pixel row/column (index).
815860
@@ -886,74 +931,38 @@ def stitch_clipped_predictions(
886931 combined: dict
887932 Dictionary object containing the stitched model predictions.
888933 """
889- # Get row/col index values of X_t. Order depends on whether coordinate is ascending or descending.
890- if x1_ascend :
891- data_x1 = (
892- X_t .coords [orig_x1_name ].min ().values ,
893- X_t .coords [orig_x1_name ].max ().values ,
894- )
895- else :
896- data_x1 = (
897- X_t .coords [orig_x1_name ].max ().values ,
898- X_t .coords [orig_x1_name ].min ().values ,
899- )
900- if x2_ascend :
901- data_x2 = (
902- X_t .coords [orig_x2_name ].min ().values ,
903- X_t .coords [orig_x2_name ].max ().values ,
904- )
905- else :
906- data_x2 = (
907- X_t .coords [orig_x2_name ].max ().values ,
908- X_t .coords [orig_x2_name ].min ().values ,
909- )
910934
911- data_x1_index , data_x2_index = get_index (data_x1 , data_x2 )
912- patches_clipped = {var_name : [] for var_name in patch_preds [0 ].keys ()}
935+ # Get row/col index values of X_t.
936+ data_x1_coords , data_x2_coords = get_coordinate_extent (X_t , x1_ascend , x2_ascend )
937+ data_x1_index , data_x2_index = get_index (data_x1_coords , data_x2_coords )
913938
939+ # Iterate through patchwise predictions and slice edges prior to stitching.
940+ patches_clipped = {var_name : [] for var_name in patch_preds [0 ].keys ()}
914941 for i , patch_pred in enumerate (patch_preds ):
915942 for var_name , data_array in patch_pred .items ():
916943 if var_name in patch_pred :
917- # Get row/col index values of each patch. Order depends on whether coordinate is ascending or descending.
918- if x1_ascend :
919- patch_x1 = (
920- data_array .coords [orig_x1_name ].min ().values ,
921- data_array .coords [orig_x1_name ].max ().values ,
922- )
923- else :
924- patch_x1 = (
925- data_array .coords [orig_x1_name ].max ().values ,
926- data_array .coords [orig_x1_name ].min ().values ,
927- )
928- if x2_ascend :
929- patch_x2 = (
930- data_array .coords [orig_x2_name ].min ().values ,
931- data_array .coords [orig_x2_name ].max ().values ,
932- )
933- else :
934- patch_x2 = (
935- data_array .coords [orig_x2_name ].max ().values ,
936- data_array .coords [orig_x2_name ].min ().values ,
937- )
938- patch_x1_index , patch_x2_index = get_index (patch_x1 , patch_x2 )
944+
945+ # Get row/col index values of each patch.
946+ patch_x1_coords , patch_x2_coords = get_coordinate_extent (data_array , x1_ascend , x2_ascend )
947+ patch_x1_index , patch_x2_index = get_index (patch_x1_coords , patch_x2_coords )
939948
949+ # Calculate size of border to slice off each edge of patchwise predictions.
950+ # Initially set the size of all borders to the size of the overlap.
940951 b_x1_min , b_x1_max = patch_overlap [0 ], patch_overlap [0 ]
941952 b_x2_min , b_x2_max = patch_overlap [1 ], patch_overlap [1 ]
942953
943954 # Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column.
944955 if patch_x2_index [0 ] == data_x2_index [0 ]:
945956 b_x2_min = 0
946- # TODO: Try to resolve this issue in data/loader.py by ensuring patches are perfectly square.
947957 b_x2_max = b_x2_max
948958
949- # At end of row (when patch_x2_index = data_x2_index), to calculate the number of pixels to remove from left hand side of patch:
959+ # At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch.
950960 elif patch_x2_index [1 ] == data_x2_index [1 ]:
951961 b_x2_max = 0
952962 patch_row_prev = preds [i - 1 ]
953963
954964 # If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels.
955965 # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
956- # to get the number of pixels to remove from left hand side of patch.
957966 if x2_ascend :
958967 prev_patch_x2_max = get_index (
959968 patch_row_prev [var_name ].coords [orig_x2_name ].max (),
@@ -963,9 +972,8 @@ def stitch_clipped_predictions(
963972 prev_patch_x2_max - patch_x2_index [0 ]
964973 ) - patch_overlap [1 ]
965974
966- # If x2 is descending. Subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
975+ # If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
967976 # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
968- # to get the number of pixels to remove from left hand side of patch.
969977 else :
970978 prev_patch_x2_min = get_index (
971979 patch_row_prev [var_name ].coords [orig_x2_name ].min (),
@@ -977,9 +985,10 @@ def stitch_clipped_predictions(
977985 else :
978986 b_x2_max = b_x2_max
979987
988+ # Repeat process as above for x1 coordinates.
980989 if patch_x1_index [0 ] == data_x1_index [0 ]:
981990 b_x1_min = 0
982- # TODO: ensure this elif statement is robust to multiple patch sizes.
991+
983992 elif abs (patch_x1_index [1 ] - data_x1_index [1 ]) < 2 :
984993 b_x1_max = 0
985994 b_x1_max = b_x1_max
@@ -1013,6 +1022,7 @@ def stitch_clipped_predictions(
10131022 data_array .sizes [orig_x2_name ] - b_x2_max
10141023 )
10151024
1025+ # Slice patchwise predictions
10161026 patch_clip = data_array .isel (
10171027 ** {
10181028 orig_x1_name : slice (
@@ -1025,51 +1035,47 @@ def stitch_clipped_predictions(
10251035 )
10261036
10271037 patches_clipped [var_name ].append (patch_clip )
1028-
1029- # Create blank prediction
1030- combined_dataset = copy .deepcopy (patches_clipped )
1031-
1032- # Generate new blank DeepSensor.prediction object with same extent and coordinate system as X_t.
1033- for var , data_array_list in combined_dataset .items ():
1034- first_patchwise_pred = data_array_list [0 ]
1035-
1036- # Define coordinate extent and time
1037- blank_pred = xr .Dataset (
1038+
1039+ # Create blank prediction object to stitch prediction values onto.
1040+ stitched_prediction = copy .deepcopy (patch_preds [0 ])
1041+ # Set prediction object extent to the same as X_t.
1042+ for var_name , data_array in stitched_prediction .items ():
1043+ blank_ds = xr .Dataset (
10381044 coords = {
10391045 orig_x1_name : X_t [orig_x1_name ],
10401046 orig_x2_name : X_t [orig_x2_name ],
1041- "time" : first_patchwise_pred ["time" ],
1047+ "time" : stitched_prediction [ 0 ] ["time" ],
10421048 }
10431049 )
10441050
1045- # Set variable names to those in patched predictions, set values to Nan.
1046- for param in first_patchwise_pred .data_vars :
1047- blank_pred [ param ] = first_patchwise_pred [ param ]
1048- blank_pred [ param ][:] = np .nan
1049- combined_dataset [ var ] = blank_pred
1050-
1051- # Merge patchwise predictions to create final combined dataset .
1051+ # Set data variable names (e.g. mean, std) to those in the patched prediction. Set all values to NaN.
1052+ for data_var in data_array .data_vars :
1053+ blank_ds [ data_var ] = data_array [ data_var ]
1054+ blank_ds [ data_var ][:] = np .nan
1055+ stitched_prediction [ var_name ] = blank_ds
1056+
1057+ # Merge patchwise predictions to create final stitched prediction .
10521058 # Iterate over each variable (key) in the prediction dictionary
10531059 for var_name , patches in patches_clipped .items ():
10541060 # Retrieve the blank dataset for the current variable
1055- combined_array = combined_dataset [var_name ]
1061+ prediction_array = stitched_prediction [var_name ]
10561062
10571063 # Merge each patch into the combined dataset
10581064 for patch in patches :
10591065 for var in patch .data_vars :
10601066 # Reindex the patch to catch any slight rounding errors and misalignment with the combined dataset
10611067 reindexed_patch = patch [var ].reindex_like (
1062- combined_array [var ], method = "nearest" , tolerance = 1e-6
1068+ prediction_array [var ], method = "nearest" , tolerance = 1e-6
10631069 )
10641070
10651071 # Combine data, prioritizing non-NaN values from patches
1066- combined_array [var ] = combined_array [var ].where (
1072+ prediction_array [var ] = prediction_array [var ].where (
10671073 np .isnan (reindexed_patch ), reindexed_patch
10681074 )
10691075
10701076 # Update the dictionary with the merged dataset
1071- combined_dataset [var_name ] = combined_array
1072- return combined_dataset
1077+ stitched_prediction [var_name ] = prediction_array
1078+ return stitched_prediction
10731079
10741080 # load patch_size and stride from task
10751081 patch_size = tasks [0 ]["patch_size" ]
@@ -1164,34 +1170,10 @@ def stitch_clipped_predictions(
11641170 )
11651171
11661172 patches_per_row = get_patches_per_row (preds )
1167- stitched_prediction = stitch_clipped_predictions (
1173+ prediction = stitch_clipped_predictions (
11681174 preds , patch_overlap_unnorm , patches_per_row , x1_ascending , x2_ascending
11691175 )
11701176
1171- ## Cast prediction into DeepSensor.Prediction object.
1172- # TODO make this into seperate method.
1173- prediction = copy .deepcopy (preds [0 ])
1174-
1175- # Generate new blank DeepSensor.prediction object in original coordinate system.
1176- for var_name_copy , data_array_copy in prediction .items ():
1177- # set x and y coords
1178- stitched_preds = xr .Dataset (
1179- coords = {
1180- orig_x1_name : X_t [orig_x1_name ],
1181- orig_x2_name : X_t [orig_x2_name ],
1182- }
1183- )
1184-
1185- # Set time to same as patched prediction
1186- stitched_preds ["time" ] = data_array_copy ["time" ]
1187-
1188- # set variable names to those in patched prediction, make values blank
1189- for var_name_i in data_array_copy .data_vars :
1190- stitched_preds [var_name_i ] = data_array_copy [var_name_i ]
1191- stitched_preds [var_name_i ][:] = np .nan
1192- prediction [var_name_copy ] = stitched_preds
1193- prediction [var_name_copy ] = stitched_prediction [var_name_copy ]
1194-
11951177 return prediction
11961178
11971179
0 commit comments