diff --git a/.gitignore b/.gitignore index 46ab9b74..819a1c34 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,4 @@ dist/* .tox/ _build *.png -deepsensor.egg-info/ +deepsensor.egg-info/ \ No newline at end of file diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index a0c566b4..a65b85cf 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -1,15 +1,16 @@ -from deepsensor.data.task import Task, flatten_X - -import os -import json import copy +import itertools +import json +import operator +import os +import random +from typing import List, Optional, Sequence, Tuple, Union import numpy as np -import xarray as xr import pandas as pd +import xarray as xr -from typing import List, Tuple, Union, Optional - +from deepsensor.data.task import Task, flatten_X from deepsensor.errors import InvalidSamplingStrategyError @@ -189,6 +190,8 @@ def __init__( self.aux_at_target_var_IDs, ) = self.infer_context_and_target_var_IDs() + self.coord_bounds = self._compute_global_coordinate_bounds() + def _set_config(self): """Instantiate a config dictionary for the TaskLoader object.""" # Take deepcopy to avoid modifying the original config @@ -770,6 +773,7 @@ def sample_offgrid_aux( xt2 = xt2.ravel() else: xt1, xt2 = xr.DataArray(X_t[0]), xr.DataArray(X_t[1]) + Y_t_aux = offgrid_aux.sel(x1=xt1, x2=xt2, method="nearest") if isinstance(Y_t_aux, xr.Dataset): Y_t_aux = Y_t_aux.to_array() @@ -781,6 +785,124 @@ def sample_offgrid_aux( Y_t_aux = Y_t_aux.reshape(1, *Y_t_aux.shape) return Y_t_aux + def _compute_global_coordinate_bounds(self) -> List[float]: + """Compute global coordinate bounds in order to sample spatial bounds if desired. 
+ + Returns: + ------- + bbox: List[float] + sequence of global spatial extent as [x1_min, x1_max, x2_min, x2_max] + """ + x1_min, x1_max, x2_min, x2_max = np.inf, -np.inf, np.inf, -np.inf + + for var in itertools.chain(self.context, self.target): + if isinstance(var, (xr.Dataset, xr.DataArray)): + var_x1_min = var.x1.min().item() + var_x1_max = var.x1.max().item() + var_x2_min = var.x2.min().item() + var_x2_max = var.x2.max().item() + elif isinstance(var, (pd.DataFrame, pd.Series)): + var_x1_min = var.index.get_level_values("x1").min() + var_x1_max = var.index.get_level_values("x1").max() + var_x2_min = var.index.get_level_values("x2").min() + var_x2_max = var.index.get_level_values("x2").max() + + if var_x1_min < x1_min: + x1_min = var_x1_min + + if var_x1_max > x1_max: + x1_max = var_x1_max + + if var_x2_min < x2_min: + x2_min = var_x2_min + + if var_x2_max > x2_max: + x2_max = var_x2_max + + return [x1_min, x1_max, x2_min, x2_max] + + def _compute_x1x2_direction(self) -> dict: + """Compute whether the x1 and x2 coords are ascending or descending. + + Returns: + dict(bool) + Dictionary containing two keys: x1 and x2, with boolean values + defining if these coordings increase or decrease from top left corner. + + Raises: + ValueError: + If all datasets are non-gridded or if direction of ascending + coordinates does not match across non-gridded datasets. 
+ + """ + non_gridded = {"x1": None, "x2": None} # value to use for non-gridded data + ascending = [] + for var in itertools.chain(self.context, self.target): + if isinstance(var, (xr.Dataset, xr.DataArray)): + coord_x1_left = var.x1[0] + coord_x1_right = var.x1[-1] + coord_x2_top = var.x2[0] + coord_x2_bottom = var.x2[-1] + + ascending.append( + { + "x1": True if coord_x1_left <= coord_x1_right else False, + "x2": True if coord_x2_top <= coord_x2_bottom else False, + } + ) + + elif isinstance(var, (pd.DataFrame, pd.Series)): + ascending.append(non_gridded) + + if len(list(filter(lambda x: x != non_gridded, ascending))) == 0: + raise ValueError( + "All data is non gridded, can not proceed with sliding window sampling." + ) + + # get the directions for only the gridded data + gridded = list(filter(lambda x: x != non_gridded, ascending)) + # raise error if directions don't match across gridded data + if gridded.count(gridded[0]) != len(gridded): + raise ValueError( + "Direction of ascending coordinates does not match across all gridded datasets." + ) + + return gridded[0] + + def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]: + """Sample random window uniformly from global coordinates to slice data. 
+ + Parameters + ---------- + patch_size : Tuple[float] + Tuple of window extent + + Returns: + ------- + bbox: List[float] + sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max] + """ + x1_extend, x2_extend = patch_size + + x1_side = x1_extend / 2 + x2_side = x2_extend / 2 + + # sample a point that satisfies the context and target global bounds + x1_min, x1_max, x2_min, x2_max = self.coord_bounds + + x1_point = random.uniform(x1_min + x1_side, x1_max - x1_side) + x2_point = random.uniform(x2_min + x2_side, x2_max - x2_side) + + # bbox of x1_min, x1_max, x2_min, x2_max + bbox = [ + x1_point - x1_side, + x1_point + x1_side, + x2_point - x2_side, + x2_point + x2_side, + ] + + return bbox + def time_slice_variable(self, var, date, delta_t=0): """Slice a variable by a given time delta. @@ -810,6 +932,48 @@ def time_slice_variable(self, var, date, delta_t=0): raise ValueError(f"Unknown variable type {type(var)}") return var + def spatial_slice_variable(self, var, window: List[float]): + """Slice a variable by a given window size. + + Args: + var (...): + Variable to slice. + window (List[float]): + List of coordinates specifying the window [x1_min, x1_max, x2_min, x2_max]. + + Returns: + var (...) + Sliced variable. + + Raises: + ValueError + If the variable is of an unknown type. 
+ """ + x1_min, x1_max, x2_min, x2_max = window + if isinstance(var, (xr.Dataset, xr.DataArray)): + # we cannot assume that the coordinates are sorted from small to large + if var.x1[0] > var.x1[-1]: + x1_slice = slice(x1_max, x1_min) + else: + x1_slice = slice(x1_min, x1_max) + if var.x2[0] > var.x2[-1]: + x2_slice = slice(x2_max, x2_min) + else: + x2_slice = slice(x2_min, x2_max) + var = var.sel(x1=x1_slice, x2=x2_slice) + elif isinstance(var, (pd.DataFrame, pd.Series)): + # retrieve desired patch size + var = var[ + (var.index.get_level_values("x1") >= x1_min) + & (var.index.get_level_values("x1") <= x1_max) + & (var.index.get_level_values("x2") >= x2_min) + & (var.index.get_level_values("x2") <= x2_max) + ] + else: + raise ValueError(f"Unknown variable type {type(var)}") + + return var + def task_generation( # noqa: D102 self, date: pd.Timestamp, @@ -830,9 +994,61 @@ def task_generation( # noqa: D102 ] ] = None, split_frac: float = 0.5, + bbox: Sequence[float] = None, + patch_size: Union[float, Tuple[float]] = None, + stride: Union[float, Tuple[float]] = None, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Task: + """Generate a task for a given date. + + There are several sampling strategies available for the context and + target data: + + - "all": Sample all observations. + - int: Sample N observations uniformly at random. + - float: Sample a fraction of observations uniformly at random. + - :class:`numpy:numpy.ndarray`, shape (2, N): Sample N observations + at the given x1, x2 coordinates. Coords are assumed to be + unnormalised. + + Parameters + ---------- + date : :class:`pandas.Timestamp` + Date for which to generate the task. + context_sampling : str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`] + Sampling strategy for the context data, either a list of sampling + strategies for each context set, or a single strategy applied to + all context sets. 
Default is ``"all"``. + target_sampling : str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`] + Sampling strategy for the target data, either a list of sampling + strategies for each target set, or a single strategy applied to all + target sets. Default is ``"all"``. + split_frac : float + The fraction of observations to use for the context set with the + "split" sampling strategy for linked context and target set pairs. + The remaining observations are used for the target set. Default is + 0.5. + bbox : Sequence[float], optional + Bounding box to spatially slice the data, should be of the form [x1_min, x1_max, x2_min, x2_max]. + Useful when considering the entire available region is computationally prohibitive for model forward pass. + patch_size : Union(Tuple|float), optional + Only used by patchwise inference. Height and width of patch in x1/x2 normalised coordinates. + stride: Union(Tuple|float), optional + Only used by patchwise inference. Length of stride between adjacent patches in x1/x2 normalised coordinates. + datewise_deterministic : bool + Whether random sampling is datewise_deterministic based on the + date. Default is ``False``. + seed_override : Optional[int] + Override the seed for random sampling. This can be used to use the + same random sampling at different ``date``. Default is None. + + Returns: + ------- + task : :class:`~.data.task.Task` + Task object containing the context and target data. + """ + def check_sampling_strat(sampling_strat, set): """Check the sampling strategy. 
@@ -877,6 +1093,13 @@ def check_sampling_strat(sampling_strat, set): raise InvalidSamplingStrategyError( f"Unknown sampling strategy {strat} of type {type(strat)}" ) + elif isinstance(strat, str) and strat == "gapfill": + assert all( + isinstance(item, (xr.Dataset, xr.DataArray)) for item in set + ), ( + "Gapfill sampling strategy can only be used with xarray " + "datasets or data arrays" + ) elif isinstance(strat, str) and strat not in [ "all", "split", @@ -1001,6 +1224,11 @@ def sample_variable(var, sampling_strat, seed): task["time"] = date task["ops"] = [] + task["bbox"] = bbox + task["patch_size"] = ( + patch_size # store patch_size and stride in task for use in stitching in prediction + ) + task["stride"] = stride task["X_c"] = [] task["Y_c"] = [] if target_sampling is not None: @@ -1010,6 +1238,7 @@ def sample_variable(var, sampling_strat, seed): task["X_t"] = None task["Y_t"] = None + # temporal slices context_slices = [ self.time_slice_variable(var, date, delta_t) for var, delta_t in zip(self.context, self.context_delta_t) @@ -1139,6 +1368,20 @@ def sample_variable(var, sampling_strat, seed): context_slices[context_idx] = context_var target_slices[target_idx] = target_var + # check bbox size + if bbox is not None: + assert ( + len(bbox) == 4 + ), "bbox must be a list of length 4 with [x1_min, x1_max, x2_min, x2_max]" + + # spatial slices + context_slices = [ + self.spatial_slice_variable(var, bbox) for var in context_slices + ] + target_slices = [ + self.spatial_slice_variable(var, bbox) for var in target_slices + ] + for i, (var, sampling_strat) in enumerate( zip(context_slices, context_sampling) ): @@ -1188,9 +1431,100 @@ def sample_variable(var, sampling_strat, seed): return Task(task) + def sample_sliding_window( + self, patch_size: Tuple[float], stride: Tuple[int] + ) -> Sequence[float]: + """Sample data using sliding window from global coordinates to slice data. + Parameters. 
+ ---------- + patch_size : Tuple[float] + Tuple of window extent + + stride : Tuple[float] + Tuple of step size between each patch along x1 and x2 axis. + + Returns: + ------- + List[float] + Sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max]. + """ + self.coord_directions = self._compute_x1x2_direction() + # define patch size in x1/x2 + size = {} + size["x1"], size["x2"] = patch_size + + # define stride length in x1/x2 or set to patch_size if undefined + if stride is None: + stride = patch_size + + step = {} + step["x1"], step["x2"] = stride + + # Calculate the global bounds of context and target set. + coord_min = {} + coord_max = {} + coord_min["x1"], coord_max["x1"], coord_min["x2"], coord_max["x2"] = ( + self.coord_bounds + ) + + ## start with first patch top left hand corner at coord_min["x1"], coord_min["x2"] + patch_list = [] + + # define some lambda functions for use below + # round to 12 figures to avoid floating point error but reduce likelihood of unintentional rounding + r = lambda x: round(x, 12) + bbox_coords_ascend = lambda a, b: [r(a), r(a + b)] + bbox_coords_descend = lambda a, b: bbox_coords_ascend(a, b)[::-1] + + compare = {} + bbox_coords = {} + # for each coordinate direction specify the correct operations for patching + for c in ("x1", "x2"): + if self.coord_directions[c]: + compare[c] = operator.gt + bbox_coords[c] = bbox_coords_ascend + else: + step[c] = -step[c] + coord_min[c], coord_max[c] = coord_max[c], coord_min[c] + size[c] = -size[c] + compare[c] = operator.lt + bbox_coords[c] = bbox_coords_descend + + # Define the bounding boxes for all patches, starting in top left corner of dataArray + for y, x in itertools.product( + np.arange(coord_min["x1"], coord_max["x1"], step["x1"]), + np.arange(coord_min["x2"], coord_max["x2"], step["x2"]), + ): + y0 = ( + coord_max["x1"] - size["x1"] + if compare["x1"](y + size["x1"], coord_max["x1"]) + else y + ) + x0 = ( + coord_max["x2"] - size["x2"] + if compare["x2"](x + 
size["x2"], coord_max["x2"]) + else x + ) + + # bbox of x1_min, x1_max, x2_min, x2_max per patch + bbox = bbox_coords["x1"](y0, size["x1"]) + bbox_coords["x2"](x0, size["x2"]) + patch_list.append(bbox) + + # Remove duplicate patches while preserving order + seen = set() + unique_patch_list = [] + for lst in patch_list: + # Convert list to tuple for immutability + tuple_lst = tuple(lst) + if tuple_lst not in seen: + seen.add(tuple_lst) + unique_patch_list.append(lst) + + return unique_patch_list + def __call__( self, - date: pd.Timestamp, + date: Union[pd.Timestamp, Sequence[pd.Timestamp]], context_sampling: Union[ str, int, @@ -1208,6 +1542,10 @@ def __call__( ] ] = None, split_frac: float = 0.5, + patch_size: Union[float, Tuple[float]] = None, + patch_strategy: Optional[str] = None, + stride: Union[float, Tuple[float]] = None, + num_samples_per_date: int = 1, datewise_deterministic: bool = False, seed_override: Optional[int] = None, ) -> Union[Task, List[Task]]: @@ -1253,6 +1591,16 @@ def __call__( the "split" sampling strategy for linked context and target set pairs. The remaining observations are used for the target set. Default is 0.5. + patch_size : Union[float, tuple[float]], optional + Desired patch size in x1/x2 used for patchwise task generation. Useful when considering + the entire available region is computationally prohibitive for model forward pass. + If passed a single float, will use value for both x1 & x2. + patch_strategy: + Patch strategy to use for patchwise task generation. Default is None. + Possible options are 'random' or 'sliding'. + stride: Union[float, tuple[float]], optional + Step size between each sliding window patch along x1 and x2 axis. Default is None. + If passed a single float, will use value for both x1 & x2. datewise_deterministic (bool, optional): Whether random sampling is datewise deterministic based on the date. Default is ``False``. 
@@ -1266,24 +1614,155 @@ def __call__( Task object or list of task objects for each date containing the context and target data. """ - if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): - return [ - self.task_generation( - d, - context_sampling, - target_sampling, - split_frac, - datewise_deterministic, - seed_override, + if patch_strategy not in [None, "random", "sliding"]: + raise ValueError( + f"Invalid patch strategy {patch_strategy}. " + f"Must be one of [None, 'random', 'sliding']." + ) + + if isinstance(patch_size, float) and patch_size is not None: + patch_size = (patch_size, patch_size) + + if isinstance(stride, float) and stride is not None: + stride = (stride, stride) + + if patch_strategy is None: + if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): + tasks = [ + self.task_generation( + d, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + ) + for d in date + ] + else: + tasks = self.task_generation( + date=date, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, ) - for d in date - ] + + elif patch_strategy == "random": + if patch_size is None: + raise ValueError( + "Patch size must be specified for random patch sampling" + ) + + coord_bounds = [self.coord_bounds[0:2], self.coord_bounds[2:]] + for i, val in enumerate(patch_size): + if val < coord_bounds[i][0] or val > coord_bounds[i][1]: + raise ValueError( + f"Values of stride must be between the normalised coordinate bounds of: {self.coord_bounds}. \ + Got: patch_size: {patch_size}." 
+ ) + + if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): + for d in date: + bboxes = [ + self.sample_random_window(patch_size) + for _ in range(num_samples_per_date) + ] + tasks = [ + self.task_generation( + d, + bbox=bbox, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + ) + for bbox in bboxes + ] + + else: + bboxes = [ + self.sample_random_window(patch_size) + for _ in range(num_samples_per_date) + ] + tasks = [ + self.task_generation( + date, + bbox=bbox, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + ) + for bbox in bboxes + ] + + elif patch_strategy == "sliding": + # sliding window sampling of patch + + for val in (patch_size, stride): + if val is None: + raise ValueError( + f"patch_size and stride must be specified for sliding window sampling, got patch_size: {patch_size} and stride: {stride}." + ) + + if stride[0] > patch_size[0] or stride[1] > patch_size[1]: + raise Warning( + f"stride should generally be smaller than patch_size in the corresponding dimensions. Got: patch_size: {patch_size}, stride: {stride}" + ) + + coord_bounds = [self.coord_bounds[0:2], self.coord_bounds[2:]] + for i in (0, 1): + for val in (patch_size[i], stride[i]): + if val < coord_bounds[i][0] or val > coord_bounds[i][1]: + raise ValueError( + f"Values of stride and patch_size must be between the normalised coordinate bounds of: {self.coord_bounds}. 
\ + Got: patch_size: {patch_size}, stride: {stride}" + ) + + if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): + tasks = [] + for d in date: + bboxes = self.sample_sliding_window(patch_size, stride) + tasks.extend( + [ + self.task_generation( + d, + bbox=bbox, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + patch_size=patch_size, + stride=stride, + ) + for bbox in bboxes + ] + ) + else: + bboxes = self.sample_sliding_window(patch_size, stride) + tasks = [ + self.task_generation( + date, + bbox=bbox, + context_sampling=context_sampling, + target_sampling=target_sampling, + split_frac=split_frac, + datewise_deterministic=datewise_deterministic, + seed_override=seed_override, + patch_size=patch_size, + stride=stride, + ) + for bbox in bboxes + ] else: - return self.task_generation( - date, - context_sampling, - target_sampling, - split_frac, - datewise_deterministic, - seed_override, + raise ValueError( + f"Invalid patch strategy {patch_strategy}. " + f"Must be one of [None, 'random', 'sliding']." 
) + + return tasks diff --git a/deepsensor/data/task.py b/deepsensor/data/task.py index 353df4db..33710c89 100644 --- a/deepsensor/data/task.py +++ b/deepsensor/data/task.py @@ -31,7 +31,9 @@ def __init__(self, task_dict: dict) -> None: @classmethod def summarise_str(cls, k, v): """Return string summaries for the _str__ method.""" - if plum.isinstance(v, B.Numeric): + if isinstance(v, float): + return v + elif plum.isinstance(v, B.Numeric): return v.shape elif plum.isinstance(v, tuple): return tuple(vi.shape for vi in v) @@ -57,6 +59,8 @@ def summarise_repr(cls, k, v) -> str: """ if v is None: return "None" + elif isinstance(v, float): + return f"{type(v).__name__}" elif plum.isinstance(v, B.Numeric): return f"{type(v).__name__}/{v.dtype}/{v.shape}" if plum.isinstance(v, deepsensor.backend.nps.mask.Masked): diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 74a0dd58..07426968 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -648,6 +648,542 @@ def unnormalise_pred_array(arr, **kwargs): return pred + def predict_patchwise( + self, + tasks: Union[List[Task], Task], + X_t: Union[ + xr.Dataset, + xr.DataArray, + pd.DataFrame, + pd.Series, + pd.Index, + np.ndarray, + ], + X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None, + **kwargs, + ) -> Prediction: + """Predict using tasks loaded using a sliding window patching strategy. Uses the `predict` method. + + .. versionadded:: 0.4.3 + :py:func:`predict_patchwise()` method. + + Args: + tasks (List[Task] | Task): + List of tasks containing context data. Tasks for patchwise prediction must be generated by a task loader using the "sliding" patching strategy. + X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): + Target locations to predict at. Can be an xarray object + containingon-grid locations or a pandas object containing off-grid locations. 
+ X_t_mask: :class:`xarray.Dataset` | :class:`xarray.DataArray`, optional + 2D mask to apply to gridded ``X_t`` (zero/False will be NaNs). Will be interpolated + to the same grid as ``X_t`` and patched in the same way. Default None (no mask). + **kwargs: + Keyword arguments as per ``predict``. + + Returns: + :class:`~.model.pred.Prediction`): + A `dict`-like object mapping from target variable IDs to xarray or pandas objects + containing model predictions. + - If ``X_t`` is a pandas object, returns pandas objects + containing off-grid predictions. + - If ``X_t`` is an xarray object, returns xarray object + containing on-grid predictions. + - If ``n_samples`` == 0, returns only mean and std predictions. + - If ``n_samples`` > 0, returns mean, std and samples + predictions. + + Raises: + AttributeError + If ``tasks`` are not generated using the "sliding" patching strategy of TaskLoader, + i.e. if they do not have a ``bbox`` attribute. + Errors + See `~.model.model.DeepSensorModel.predict` + """ + # Get coordinate names of original unnormalised dataset. + orig_x1_name = self.data_processor.x1_name + orig_x2_name = self.data_processor.x2_name + + def get_patches_per_row(preds) -> int: + """Calculate number of patches per row. + Required to stitch patches back together. + + Args: + preds (List[class:`~.model.pred.Prediction`]): + A list of `dict`-like objects containing patchwise predictions. + + Returns: + patches_per_row: int + Number of patches per row. + """ + patches_per_row = 0 + vars = list(preds[0][0].data_vars) + var = vars[0] + x1_val = preds[0][0][var].coords[orig_x1_name].min() + + for pred in preds: + if pred[0][var].coords[orig_x1_name].min() == x1_val: + patches_per_row = patches_per_row + 1 + + return patches_per_row + + def get_patch_overlap( + overlap_norm, data_processor, X_t_ds, x1_ascend, x2_ascend + ) -> int: + """Calculate overlap between adjacent patches in pixels. + + Parameters + ---------- + overlap_norm : tuple[float]. 
+ Normalised size of overlap in x1/x2. + + data_processor (:class:`~.data.processor.DataProcessor`): + Used for unnormalising the coordinates of the bounding boxes of patches. + + X_t_ds (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): + Data array containing target locations to predict at. + + x1_ascend : str: + Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. + + x2_ascend : str: + Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. + + Returns: + ------- + patch_overlap : tuple (int) + Unnormalised size of overlap between adjacent patches. + """ + # Todo- check if there is simplier and more robust way to convert overlap into pixels. + # Place x1/x2 overlap values in Xarray to pass into unnormalise() + overlap_list = [0, overlap_norm[0], 0, overlap_norm[1]] + x1 = xr.DataArray([overlap_list[0], overlap_list[1]], dims="x1", name="x1") + x2 = xr.DataArray([overlap_list[2], overlap_list[3]], dims="x2", name="x2") + overlap_norm_xr = xr.Dataset(coords={"x1": x1, "x2": x2}) + + # Unnormalise coordinates of bounding boxes + overlap_unnorm_xr = data_processor.unnormalise(overlap_norm_xr) + + unnorm_overlap_x1 = overlap_unnorm_xr.coords[orig_x1_name].values[1] + unnorm_overlap_x2 = overlap_unnorm_xr.coords[orig_x2_name].values[1] + + def overlap_index( + coords: np.ndarray, ascend: bool, unnorm_overlap: float + ) -> int: + """Find size of overlap in a single coordinate direction, in units of pixels. + + Parameters + ---------- + coords : np.ndarray + + ascend : bool + Boolean defining whether coords ascend (increase) from top to bottom or left to right. + + unnorm_overlap : float + The patch overlap in unnormalised coordinates. + + Returns: + ------- + int : The number of pixels in the overlap. 
+ """ + pixel_coords_overlap_diffs = np.abs(coords - unnorm_overlap) + if ascend: + trim_size = np.argmin(pixel_coords_overlap_diffs) / 2 + trim_size_rounded = int( + np.floor(trim_size) + ) # Always round down trim slide as stitching method can handle slight overlaps + return trim_size_rounded + + else: + overlap_pixel_size = np.argmin(pixel_coords_overlap_diffs) + overlap_pixel_size_rounded = np.ceil(overlap_pixel_size) + trim_size = ( + (coords.size - int(overlap_pixel_size_rounded)) / 2 + ) # this extra step is so we get the overlap with respect to the largest value (i.e. is the number of pixels = 360, coords.size = 360) + trim_size_rounded = int(np.floor(trim_size)) + return trim_size_rounded + + return ( + overlap_index( + X_t_ds.coords[orig_x1_name].values, x1_ascend, unnorm_overlap_x1 + ), + overlap_index( + X_t_ds.coords[orig_x2_name].values, x2_ascend, unnorm_overlap_x2 + ), + ) + + def get_coordinate_extent( + ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool, x2_ascend: bool + ) -> tuple: + """Get coordinate extent of dataset. This method is applied to either X_t or patchwise predictions. + + Parameters + ---------- + ds : Data object + The dataset or data array to determine coordinate extent for. + + x1_ascend : bool + Whether the x1 coordinates ascend (increase) from top to bottom. + + x2_ascend : bool + Whether the x2 coordinates ascend (increase) from left to right. + + Returns: + ------- + tuple of tuples: + Extents of x1 and x2 coordinates as ((min_x1, max_x1), (min_x2, max_x2)). 
+ """ + if x1_ascend: + ds_x1_coords = ( + ds.coords[orig_x1_name].min().values, + ds.coords[orig_x1_name].max().values, + ) + else: + ds_x1_coords = ( + ds.coords[orig_x1_name].max().values, + ds.coords[orig_x1_name].min().values, + ) + if x2_ascend: + ds_x2_coords = ( + ds.coords[orig_x2_name].min().values, + ds.coords[orig_x2_name].max().values, + ) + else: + ds_x2_coords = ( + ds.coords[orig_x2_name].max().values, + ds.coords[orig_x2_name].min().values, + ) + return ds_x1_coords, ds_x2_coords + + def get_index(*args, x1=True) -> Union[int, Tuple[List[int], List[int]]]: + """Convert coordinates into pixel row/column (index). + + Parameters + ---------- + args : tuple + If one argument (numeric), it represents the coordinate value. + If two arguments (lists), they represent lists of coordinate values. + + x1 : bool, optional + If True, compute index for x1 (default is True). + + Returns: + ------- + Union[int, Tuple[List[int], List[int]]] + If one argument is provided and x1 is True or False, returns the index position. + If two arguments are provided, returns a tuple containing two lists: + - First list: indices corresponding to x1 coordinates. + - Second list: indices corresponding to x2 coordinates. 
+ + """ + if len(args) == 1: + patch_coord = args + if x1: + coord_index = np.argmin( + np.abs(X_t.coords[orig_x1_name].values - patch_coord) + ) + else: + coord_index = np.argmin( + np.abs(X_t.coords[orig_x2_name].values - patch_coord) + ) + return coord_index + + elif len(args) == 2: + patch_x1, patch_x2 = args + x1_index = [ + np.argmin(np.abs(X_t.coords[orig_x1_name].values - target_x1)) + for target_x1 in patch_x1 + ] + x2_index = [ + np.argmin(np.abs(X_t.coords[orig_x2_name].values - target_x2)) + for target_x2 in patch_x2 + ] + return (x1_index, x2_index) + + def stitch_clipped_predictions( + patch_preds: List[Prediction], + patch_overlap: int, + patches_per_row: int, + x1_ascend: bool = True, + x2_ascend: bool = True, + ) -> Prediction: + """Stitch patchwise predictions to form prediction at original extent. + + Parameters + ---------- + patch_preds : list (class:`~.model.pred.Prediction`) + List of patchwise predictions + + patch_overlap: int + Overlap between adjacent patches in pixels. + + patches_per_row: int + Number of patchwise predictions in each row. + + x1_ascend : bool + Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True. + + x2_ascend : bool + Boolean defining whether the x2 coords ascend (increase) from left to right, default = True. + + Returns: + ------- + combined: dict + Dictionary object containing the stitched model predictions. + """ + # Get row/col index values of X_t. + data_x1_coords, data_x2_coords = get_coordinate_extent( + X_t, x1_ascend, x2_ascend + ) + data_x1_index, data_x2_index = get_index(data_x1_coords, data_x2_coords) + + # Iterate through patchwise predictions and slice edges prior to stitchin. + patches_clipped = [] + for i, patch_pred in enumerate(patch_preds): + # get one variable name to use for coordinates and extent + first_key = list(patch_pred.keys())[0] + # Get row/col index values of each patch. 
+ patch_x1_coords, patch_x2_coords = get_coordinate_extent( + patch_pred[first_key], x1_ascend, x2_ascend + ) + patch_x1_index, patch_x2_index = get_index( + patch_x1_coords, patch_x2_coords + ) + + # Calculate size of border to slice of each edge of patchwise predictions. + # Initially set the size of all borders to the size of the overlap. + b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0] + b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1] + + # Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column. + if patch_x2_index[0] == data_x2_index[0]: + b_x2_min = 0 + b_x2_max = b_x2_max + + # At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch. + elif patch_x2_index[1] == data_x2_index[1]: + b_x2_max = 0 + patch_row_prev = patch_preds[i - 1] + + # If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels. + # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels + if x2_ascend: + prev_patch_x2_max = get_index( + patch_row_prev[first_key].coords[orig_x2_name].max(), + x1=False, + ) + b_x2_min = ( + prev_patch_x2_max - patch_x2_index[0] + ) - patch_overlap[1] + + # If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels. + # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels + else: + prev_patch_x2_min = get_index( + patch_row_prev[first_key].coords[orig_x2_name].min(), + x1=False, + ) + b_x2_min = ( + patch_x2_index[0] - prev_patch_x2_min + ) - patch_overlap[1] + else: + b_x2_max = b_x2_max + + # Repeat process as above for x1 coordinates. 
+ if patch_x1_index[0] == data_x1_index[0]: + b_x1_min = 0 + + elif abs(patch_x1_index[1] - data_x1_index[1]) < 2: + b_x1_max = 0 + b_x1_max = b_x1_max + patch_prev = patch_preds[i - patches_per_row] + if x1_ascend: + prev_patch_x1_max = get_index( + patch_prev[first_key].coords[orig_x1_name].max(), + x1=True, + ) + b_x1_min = ( + prev_patch_x1_max - patch_x1_index[0] + ) - patch_overlap[0] + else: + prev_patch_x1_min = get_index( + patch_prev[first_key].coords[orig_x1_name].min(), + x1=True, + ) + + b_x1_min = ( + prev_patch_x1_min - patch_x1_index[0] + ) - patch_overlap[0] + else: + b_x1_max = b_x1_max + + patch_clip_x1_min = int(b_x1_min) + patch_clip_x1_max = int( + patch_pred[first_key].sizes[orig_x1_name] - b_x1_max + ) + patch_clip_x2_min = int(b_x2_min) + patch_clip_x2_max = int( + patch_pred[first_key].sizes[orig_x2_name] - b_x2_max + ) + + # Define slicing parameters + slicing_params = { + orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max), + orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max), + } + + # Slice patchwise predictions + patch_clip = { + key: dataset.isel(**slicing_params) + for key, dataset in patch_pred.items() + } + + patches_clipped.append(patch_clip) + + # Create blank prediction object to stitch prediction values onto. + stitched_prediction = copy.deepcopy(patch_preds[0]) + # Set prediction object extent to the same as X_t. + for var_name, data_array in stitched_prediction.items(): + blank_ds = xr.Dataset( + coords={ + orig_x1_name: X_t[orig_x1_name], + orig_x2_name: X_t[orig_x2_name], + "time": stitched_prediction[0]["time"], + } + ) + + # Set data variable names e.g. mean, std to those in patched prediction. Make all values Nan. 
+ for data_var in data_array.data_vars: + blank_ds[data_var] = data_array[data_var] + blank_ds[data_var][:] = np.nan + stitched_prediction[var_name] = blank_ds + + # Restructure prediction objects for merging + restructured_patches = { + key: [item[key] for item in patches_clipped] + for key in patches_clipped[0].keys() + } + + # Merge patchwise predictions to create final stiched prediction. + # Iterate over each variable (key) in the prediction dictionary + for var_name, patches in restructured_patches.items(): + # Retrieve the blank dataset for the current variable + prediction_array = stitched_prediction[var_name] + + # Merge each patch into the combined dataset + for patch in patches: + for var in patch.data_vars: + # Reindex the patch to catch any slight rounding errors and misalignment with the combined dataset + reindexed_patch = patch[var].reindex_like( + prediction_array[var], method="nearest", tolerance=1e-6 + ) + + # Combine data, prioritizing non-NaN values from patches + prediction_array[var] = prediction_array[var].where( + np.isnan(reindexed_patch), reindexed_patch + ) + + # Update the dictionary with the merged dataset + stitched_prediction[var_name] = prediction_array + return stitched_prediction + + # load patch_size and stride from task + patch_size = tasks[0]["patch_size"] + stride = tasks[0]["stride"] + + # sanitise patch_size and stride arguments + if isinstance(patch_size, float) and patch_size is not None: + patch_size = (patch_size, patch_size) + + if isinstance(stride, float) and stride is not None: + stride = (stride, stride) + + if stride[0] > patch_size[0] or stride[1] > patch_size[1]: + raise ValueError( + f"stride must be smaller than patch_size in the corresponding dimensions for patchwise prediction. 
Got: patch_size: {patch_size}, stride: {stride}" + ) + + # patchwise prediction does not yet support more than a single date + num_task_dates = len(set([t["time"] for t in tasks])) + if num_task_dates > 1: + raise NotImplementedError( + f"Patchwise prediction does not yet support more than a single date at a time, got {num_task_dates}." + ) + + # tasks should be iterable, if only one is provided, make it a list + if type(tasks) is Task: + tasks = [tasks] + + # Perform patchwise predictions + preds = [] + for task in tasks: + bbox = task["bbox"] + + if bbox is None: + raise AttributeError( + "For patchwise prediction, only tasks generated using a patch_strategy of 'sliding' are valid. \ + This task has a bbox value of None, indicating that it was generated with a patch_strategy of \ + 'random' or None." + ) + + # Unnormalise coordinates of bounding box of patch + x1 = xr.DataArray([bbox[0], bbox[1]], dims="x1", name="x1") + x2 = xr.DataArray([bbox[2], bbox[3]], dims="x2", name="x2") + bbox_norm = xr.Dataset(coords={"x1": x1, "x2": x2}) + bbox_unnorm = self.data_processor.unnormalise(bbox_norm) + unnorm_bbox_x1 = ( + bbox_unnorm[orig_x1_name].values.min(), + bbox_unnorm[orig_x1_name].values.max(), + ) + unnorm_bbox_x2 = ( + bbox_unnorm[orig_x2_name].values.min(), + bbox_unnorm[orig_x2_name].values.max(), + ) + + # Determine X_t for patch, however, cannot assume min/max ordering of slice coordinates + # Check the order of coordinates in X_t, sometimes they are increasing or decreasing in order. 
+ x1_coords = X_t.coords[orig_x1_name].values + x2_coords = X_t.coords[orig_x2_name].values + + if x1_coords[0] < x1_coords[-1]: + x1_slice = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1]) + x1_ascending = True + else: + x1_slice = slice(unnorm_bbox_x1[1], unnorm_bbox_x1[0]) + x1_ascending = False + + if x2_coords[0] < x2_coords[-1]: + x2_slice = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1]) + x2_ascending = True + else: + x2_slice = slice(unnorm_bbox_x2[1], unnorm_bbox_x2[0]) + x2_ascending = False + + # Determine X_t for patch with correct slice direction + task_X_t = X_t.sel(**{orig_x1_name: x1_slice, orig_x2_name: x2_slice}) + task_X_t_mask = ( + X_t_mask.sel(**{orig_x1_name: x1_slice, orig_x2_name: x2_slice}) + if X_t_mask + else None + ) + + # Patchwise prediction + pred = self.predict(task, task_X_t, task_X_t_mask, **kwargs) + # Append patchwise DeepSensor prediction object to list + preds.append(pred) + + overlap_norm = tuple( + patch - stride for patch, stride in zip(patch_size, stride) + ) + patch_overlap_unnorm = get_patch_overlap( + overlap_norm, self.data_processor, X_t, x1_ascending, x2_ascending + ) + + patches_per_row = get_patches_per_row(preds) + prediction = stitch_clipped_predictions( + preds, patch_overlap_unnorm, patches_per_row, x1_ascending, x2_ascending + ) + + return prediction + def add_valid_time_coord_to_pred_and_move_time_dims(pred: Prediction) -> Prediction: """Add a valid time coordinate "time" to a Prediction object based on the diff --git a/docs/user-guide/patchwise_training_and_prediction.ipynb b/docs/user-guide/patchwise_training_and_prediction.ipynb new file mode 100644 index 00000000..d0345566 --- /dev/null +++ b/docs/user-guide/patchwise_training_and_prediction.ipynb @@ -0,0 +1,588 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Patchwise Training & Prediction\n", + "\n", + "Environmental data can sometimes span large spatial areas. 
For example:\n", + "\n", + "- Modelling tasks based on data that span the entire globe\n", + "- Modelling tasks with high-resolution data\n", + "\n", + "In such cases, training and inference with a ConvNP over the entire region of data may be computationally prohibitive. However, we can resort to patchwise training, where the `TaskLoader` does not provide data of the entire region but instead creates smaller patches that are computationally feasible.\n", + "\n", + "The goal of the notebook is to demonstrate patchwise training and inference." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "logging.captureWarnings(True)\n", + "\n", + "import deepsensor.torch\n", + "from deepsensor.model import ConvNP\n", + "from deepsensor.train import Trainer, set_gpu_default_device\n", + "from deepsensor.data import DataProcessor, TaskLoader, construct_circ_time_ds\n", + "from deepsensor.data.sources import (\n", + " get_era5_reanalysis_data,\n", + " get_earthenv_auxiliary_data,\n", + " get_gldas_land_mask,\n", + ")\n", + "\n", + "import xarray as xr\n", + "import cartopy.crs as ccrs\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import numpy as np\n", + "from tqdm.notebook import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Training/data config\n", + "data_range = (\"2010-01-01\", \"2019-12-31\")\n", + "train_range = (\"2010-01-01\", \"2018-12-31\")\n", + "val_range = (\"2019-01-01\", \"2019-12-31\")\n", + "date_subsample_factor = 2\n", + "extent = \"north_america\"\n", + "era5_var_IDs = [\"2m_temperature\"]\n", + "lowres_auxiliary_var_IDs = [\"elevation\"]\n", + "cache_dir = \"../../.datacache\"\n", + "deepsensor_folder = \"../deepsensor_config/\"\n", + "verbose_download = True" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "Downloading ERA5 data from Google Cloud Storage... Using 8 CPUs out of 48... \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 120/120 [00:05<00:00, 21.64it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.41 GB loaded in 8.46 s\n" + ] + } + ], + "source": [ + "era5_raw_ds = get_era5_reanalysis_data(\n", + " era5_var_IDs,\n", + " extent,\n", + " date_range=data_range,\n", + " cache=True,\n", + " cache_dir=cache_dir,\n", + " verbose=verbose_download,\n", + " num_processes=8,\n", + ")\n", + "lowres_aux_raw_ds = get_earthenv_auxiliary_data(\n", + " lowres_auxiliary_var_IDs,\n", + " extent,\n", + " \"100KM\",\n", + " cache=True,\n", + " cache_dir=cache_dir,\n", + " verbose=verbose_download,\n", + ")\n", + "land_mask_raw_ds = get_gldas_land_mask(\n", + " extent, cache=True, cache_dir=cache_dir, verbose=verbose_download\n", + ")\n", + "\n", + "data_processor = DataProcessor(x1_name=\"lat\", x2_name=\"lon\")\n", + "era5_ds = data_processor(era5_raw_ds)\n", + "lowres_aux_ds, land_mask_ds = data_processor(\n", + " [lowres_aux_raw_ds, land_mask_raw_ds], method=\"min_max\"\n", + ")\n", + "\n", + "dates = pd.date_range(era5_ds.time.values.min(), era5_ds.time.values.max(), freq=\"D\")\n", + "doy_ds = construct_circ_time_ds(dates, freq=\"D\")\n", + "lowres_aux_ds[\"cos_D\"] = doy_ds[\"cos_D\"]\n", + "lowres_aux_ds[\"sin_D\"] = doy_ds[\"sin_D\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "set_gpu_default_device()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialise TaskLoader and ConvNP model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TaskLoader(3 context sets, 1 target sets)\n", + "Context variable IDs: (('2m_temperature',), 
('GLDAS_mask',), ('elevation', 'cos_D', 'sin_D'))\n", + "Target variable IDs: (('2m_temperature',),)\n" + ] + } + ], + "source": [ + "task_loader = TaskLoader(\n", + "    context=[era5_ds, land_mask_ds, lowres_aux_ds],\n", + "    target=era5_ds,\n", + ")\n", + "task_loader.load_dask()\n", + "print(task_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dim_yc inferred from TaskLoader: (1, 1, 3)\n", + "dim_yt inferred from TaskLoader: 1\n", + "dim_aux_t inferred from TaskLoader: 0\n", + "internal_density inferred from TaskLoader: 400\n", + "encoder_scales inferred from TaskLoader: [0.0012499999720603228, 0.0012499999720603228, 0.00416666641831398]\n", + "decoder_scale inferred from TaskLoader: 0.0025\n" + ] + } + ], + "source": [ + "# Set up model\n", + "model = ConvNP(data_processor, task_loader, unet_channels=(32, 32, 32, 32, 32))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define how Tasks are generated\n", + "\n", + "For the purpose of this notebook, we will use a random patchwise training strategy for our training tasks and a sliding window patch strategy for validation and testing to make sure we cover the entire region of interest.\n", + "\n", + "There are two possible arguments for patch_strategy: \n", + "- `random`: where the centroids of the patches are randomly selected;\n", + "- `sliding`: where the first patch is produced in the top-left corner, and the patch window is then moved from left to right and top to bottom over the whole image. \n", + "\n", + "If no patching strategy is defined, the default is for no patching to take place during training or inference. \n", + "\n", + "Additional arguments to define when running patchwise training: \n", + "- `patch_size`: the size of each patch in the x1 and x2 coordinates. This is required for both patching strategies.\n", + "- `stride`: the distance in x1 and x2 between each patch. 
It is commonplace to use a stride size equal to half the patch size. This is only required when using `sliding`.\n", + "- `num_samples_per_date`: the number of patches to generate when using the random patching strategy. " + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def gen_training_tasks(dates, progress=True):\n", + "    tasks = []\n", + "    for date in tqdm(dates, disable=not progress):\n", + "        tasks_per_date = task_loader(\n", + "            date,\n", + "            context_sampling=[\"all\", \"all\", \"all\"],\n", + "            target_sampling=\"all\",\n", + "            patch_strategy=\"random\",\n", + "            patch_size=(0.4, 0.4),\n", + "            num_samples_per_date=2,\n", + "        )\n", + "        tasks.extend(tasks_per_date)\n", + "    return tasks\n", + "\n", + "\n", + "def gen_validation_tasks(dates, progress=True):\n", + "    tasks = []\n", + "    for date in tqdm(dates, disable=not progress):\n", + "        tasks_per_date = task_loader(\n", + "            date,\n", + "            context_sampling=[\"all\", \"all\", \"all\"],\n", + "            target_sampling=\"all\",\n", + "            patch_strategy=\"sliding\",\n", + "            patch_size=(0.5, 0.5),\n", + "            stride=(0.25,0.25)\n", + "        )\n", + "        tasks.extend(tasks_per_date)\n", + "    return tasks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate validation tasks for testing generalisation" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "68e80805a6a94960a101bd7b39b05e4f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "  0%|          | 0/183 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n", + "axes[0].plot(losses)\n", + "axes[1].plot(val_rmses)\n", + "_ = axes[0].set_xlabel(\"Epoch\")\n", + "_ = axes[1].set_xlabel(\"Epoch\")\n", + "_ = axes[0].set_title(\"Training loss\")\n", + "_ = 
axes[1].set_title(\"Validation RMSE\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Patching during inference" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In many circumstances, patching is only required during training. If required during inference, use the `model.predict_patchwise()` function rather than `model.predict()`. \n", + "\n", + "Firstly, make the test tasks, defining the patch and stride size. The `sliding` strategy is the only strategy that can be used during inference. \n", + "You must pass in the `test_task` list and `X_t` when calling `model.predict_patchwise()`; the model's own `data_processor` is used to unnormalise the patch coordinates.\n", + "\n", + "The `predict_patchwise()` function stitches the patchwise predictions together, to generate a prediction with the same original extent as X_t. Currently patches are stitched together by clipping the overlapping edges of the patches and concatenating them. We welcome contributions to add additional stitching strategies into the DeepSensor package. \n", + "\n", + "The output prediction object is identical to the object generated when running `model.predict()`. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "### Make prediction\n", + "test_date =\"2019-01-01\"\n", + "test_task = task_loader(test_date, context_sampling=\"all\", target_sampling=\"all\",\n", + "                        patch_strategy=\"sliding\", patch_size=(0.5, 0.5), stride=(0.25, 0.25))\n", + "prediction = model.predict_patchwise(test_task, X_t=era5_raw_ds)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plotting is similar to the usual case, but since `task_loader` returns a list when patching, we need to select a single task from the list to pass into `deepsensor.plot.prediction()` as below." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig = deepsensor.plot.prediction(prediction, test_date, data_processor, task_loader, test_task[0], crs=ccrs.PlateCarree())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "deepsensor", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_model.py b/tests/test_model.py index 3e1d1343..d755a786 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -704,6 +704,54 @@ def test_forecasting_model_predict_return_valid_times(self): np.testing.assert_array_equal(pred_var.time.values, expected_valid_times) +def test_patchwise_prediction(): + """Test that ``.predict_patchwise`` runs correctly.""" + + patch_size = 0.5 + stride = 0.15 + + da = _gen_data_xr(dict( + time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), + x1=np.linspace(0, 1, 30), + x2=np.linspace(0, 1, 60), + ), + data_vars=["var"]) + + dp = DataProcessor() + ds = dp(da) # Compute normalisation parameters + + tl = TaskLoader(context=da, target=da) + + tasks = tl( + "2020-01-01", + context_sampling="all", + target_sampling="all", + patch_strategy="sliding", + patch_size=patch_size, + stride=stride, + ) + + model = ConvNP(dp, tl) + + pred = model.predict_patchwise( + tasks=tasks, + X_t=da, + ) + + # gridded predictions + assert [isinstance(ds, xr.Dataset) for ds in pred.values()] + for var_ID in pred: + assert_shape( + pred[var_ID]["mean"], + (1, da.x1.size, da.x2.size), + ) + assert_shape( + pred[var_ID]["std"], + (1, da.x1.size, da.x2.size), + ) + assert da.x1.size == pred[var_ID].x1.size + assert da.x2.size == pred[var_ID].x2.size + def assert_shape(x, shape: tuple): """Assert that 
the shape of ``x`` matches ``shape``.""" # TODO put this in a utils module? diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py index fbb532d2..c8b9c6a6 100644 --- a/tests/test_task_loader.py +++ b/tests/test_task_loader.py @@ -1,28 +1,29 @@ +import copy import itertools +import math +import os +import shutil +import tempfile +import unittest +from typing import Sequence -from parameterized import parameterized - -import xarray as xr import dask.array import numpy as np import pandas as pd -import unittest - -import os -import shutil -import tempfile -import copy +import pytest +import xarray as xr +from _pytest.fixtures import SubRequest +from parameterized import parameterized +from deepsensor.data.loader import TaskLoader from deepsensor.errors import InvalidSamplingStrategyError from tests.utils import ( - gen_random_data_xr, - gen_random_data_pandas, assert_allclose_pd, assert_allclose_xr, + gen_random_data_pandas, + gen_random_data_xr, ) -from deepsensor.data.loader import TaskLoader - def _gen_data_xr(coords=None, dims=None, data_vars=None, use_dask=False): """Gen random normalised data""" @@ -287,6 +288,135 @@ def test_links(self) -> None: tl = TaskLoader(context=self.df, target=self.df, links=[(0, 0)]) task = tl("2020-01-01", "gapfill", "gapfill") + @parameterized.expand([[(0.3, 0.3)], [(0.6, 0.4)]]) + def test_patch_size(self, patch_size) -> None: + """Test patch size sampling.""" + # need to redefine the data generators because the patch size samplin + # where we want to test that context and or target have different + # spatial extents + da_data_0_1 = self.da + + # smaller normalized coord + da_data_smaller = _gen_data_xr( + coords=dict( + time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), + x1=np.linspace(0.1, 0.9, 25), + x2=np.linspace(0.1, 0.9, 10), + ) + ) + # larger normalized coord + da_data_larger = _gen_data_xr( + coords=dict( + time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), + x1=np.linspace(-0.1, 1.1, 50), + 
x2=np.linspace(-0.1, 1.1, 50), + ) + ) + + context = [da_data_0_1, da_data_smaller, da_data_larger] + tl = TaskLoader( + context=context, # gridded xarray and off-grid pandas contexts + target=self.df, # off-grid pandas targets + ) + + # TODO it would be better to do this with pytest.fixtures + # but could not get to work so far + task = tl( + "2020-01-01", "all", "all", patch_size=patch_size, patch_strategy="random" + ) + + # test date range + tasks = tl( + ["2020-01-01", "2020-01-02"], + "all", + "all", + patch_size=patch_size, + patch_strategy="random", + ) + + # test date range with num_samples per date + tasks = tl( + ["2020-01-01", "2020-01-02"], + context_sampling="all", + target_sampling="all", + patch_size=patch_size, + patch_strategy="random", + num_samples_per_date=2, + ) + + @parameterized.expand([[0.5, 0.45], [(0.3, 0.4), (0.3, 0.35)]]) + def test_sliding_window(self, patch_size, stride) -> None: + """Test sliding window sampling.""" + # need to redefine the data generators because the patch size samplin + # where we want to test that context and or target have different + # spatial extents + da_data_0_1 = self.da + + # smaller normalized coord + da_data_smaller = _gen_data_xr( + coords=dict( + time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), + x1=np.linspace(0.1, 0.9, 25), + x2=np.linspace(0.1, 0.9, 10), + ) + ) + # larger normalized coord + da_data_larger = _gen_data_xr( + coords=dict( + time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), + x1=np.linspace(-0.1, 1.1, 50), + x2=np.linspace(-0.1, 1.1, 50), + ) + ) + + context = [da_data_0_1, da_data_smaller, da_data_larger] + tl = TaskLoader( + context=context, # gridded xarray and off-grid pandas contexts + target=self.df, # off-grid pandas targets + ) + + # test date range + tasks = tl( + ["2020-01-01", "2020-01-02"], + "all", + "all", + patch_size=patch_size, + patch_strategy="sliding", + stride=stride, + ) + + # test patch sizes are correct + for task in tasks: + assert 
math.isclose(task['bbox'][1] - task['bbox'][0], task['patch_size'][0]) + assert math.isclose(task['bbox'][3] - task['bbox'][2], task['patch_size'][1]) + + # test stride sizes are correct + assert math.isclose(abs(tasks[0]['bbox'][2] - tasks[1]['bbox'][2]), tasks[0]['stride'][1]) + + @parameterized.expand( + [ + ("sliding", (0.5, 0.5), (0.6, 0.6), Warning), # patch_size and stride as tuples + ("sliding", 0.5, 0.6, Warning), # as floats + ("sliding", 1.0, 1.2, Warning), # one argument above allowed range + ("sliding", -0.1, 0.6, Warning), # and below allowed range + ("random", 1.1, None, ValueError) # for sliding window as well + ] + ) + def test_patchwise_task_loader_parameter_handling(self, patch_strategy, patch_size, stride, raised): + """Test that correct errors and warnings are raised""" + + tl = TaskLoader(context=self.da, target=self.da) + + with self.assertRaises(raised): + tl( + "2020-01-01", + context_sampling="all", + target_sampling="all", + patch_strategy=patch_strategy, + patch_size=patch_size, + stride=stride, + ) + def test_saving_and_loading(self): """Test saving and loading TaskLoader""" with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/tests/test_training.py b/tests/test_training.py index 351b7d9f..b408b62d 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -6,7 +6,8 @@ from tqdm import tqdm -import deepsensor.tensorflow as deepsensor +# import deepsensor.tensorflow as deepsensor +import deepsensor.torch from deepsensor.train.train import Trainer from deepsensor.data.processor import DataProcessor @@ -115,6 +116,72 @@ def test_training(self): loss = np.mean(epoch_losses) self.assertFalse(np.isnan(loss)) + def test_patchwise_training(self): + """ + Test model training with patchwise tasks. 
+ """ + tl = TaskLoader(context=self.da, target=self.da) + model = ConvNP(self.data_processor, tl, unet_channels=(5, 5, 5), verbose=False) + + # generate training tasks + n_train_dates = 10 + dates = [np.random.choice(self.da.time.values) for i in range(n_train_dates)] + train_tasks = tl( + dates, + context_sampling="all", + target_sampling="all", + patch_strategy="random", + patch_size=(0.4, 0.8), + ) + + # TODO pytest can also be more succinct with pytest.fixtures + # Train + trainer = Trainer(model, lr=5e-5) + batch_size = None + # TODO check with batch_size > 1 + # batch_size = 5 + n_epochs = 5 + epoch_losses = [] + for epoch in tqdm(range(n_epochs)): + batch_losses = trainer(train_tasks, batch_size=batch_size) + epoch_losses.append(np.mean(batch_losses)) + + # Check for NaNs in the loss + loss = np.mean(epoch_losses) + self.assertFalse(np.isnan(loss)) + + def test_sliding_window_training(self): + """ + Test model training with sliding window tasks. + """ + tl = TaskLoader(context=self.da, target=self.da) + model = ConvNP(self.data_processor, tl, unet_channels=(5, 5, 5), verbose=False) + + # generate training tasks + n_train_dates = 3 + dates = [np.random.choice(self.da.time.values) for i in range(n_train_dates)] + train_tasks = tl( + dates, + context_sampling="all", + target_sampling="all", + patch_strategy="sliding", + patch_size=(0.4, 0.4), + stride=(0.1, 0.1), + ) + + # Train + trainer = Trainer(model, lr=5e-5) + batch_size = None + n_epochs = 2 + epoch_losses = [] + for epoch in tqdm(range(n_epochs)): + batch_losses = trainer(train_tasks, batch_size=batch_size) + epoch_losses.append(np.mean(batch_losses)) + + # Check for NaNs in the loss + loss = np.mean(epoch_losses) + self.assertFalse(np.isnan(loss)) + def test_training_multidim(self): """A basic test of the training loop with multidimensional context sets""" # Load raw data @@ -149,6 +216,7 @@ def test_training_multidim(self): # batch_size = None batch_size = 5 n_epochs = 10 + epoch_losses = [] for 
epoch in tqdm(range(n_epochs)): batch_losses = trainer(train_tasks, batch_size=batch_size)