Skip to content

Commit

Permalink
account for non-gridded data correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
davidwilby committed Nov 29, 2024
1 parent 325de6d commit 9d79b34
Showing 1 changed file with 31 additions and 19 deletions.
50 changes: 31 additions & 19 deletions deepsensor/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,41 +821,53 @@ def _compute_global_coordinate_bounds(self) -> List[float]:

return [x1_min, x1_max, x2_min, x2_max]

def _compute_x1x2_direction(self) -> str:
def _compute_x1x2_direction(self) -> dict:
"""Compute whether the x1 and x2 coords are ascending or descending.
Returns:
-------
coord_directions: dict(str)
Dictionary containing two keys: x1 and x2, with boolean values
defining if these coordings increase or decrease from top left corner.
dict(bool)
Dictionary containing two keys: x1 and x2, with boolean values
defining if these coordings increase or decrease from top left corner.
Raises:
ValueError:
If all datasets are non-gridded or if direction of ascending
coordinates does not match across non-gridded datasets.
"""
non_gridded = {"x1": None, "x2": None} # value to use for non-gridded data
ascending = []
for var in itertools.chain(self.context, self.target):
if isinstance(var, (xr.Dataset, xr.DataArray)):
coord_x1_left = var.x1[0]
coord_x1_right = var.x1[-1]
coord_x2_top = var.x2[0]
coord_x2_bottom = var.x2[-1]

x1_ascend = True if coord_x1_left <= coord_x1_right else False
x2_ascend = True if coord_x2_top <= coord_x2_bottom else False

coord_directions = {
"x1": x1_ascend,
"x2": x2_ascend,
}
ascending.append(
{
"x1": True if coord_x1_left <= coord_x1_right else False,
"x2": True if coord_x2_top <= coord_x2_bottom else False,
}
)

# TODO- what to input for pd.dataframe
elif isinstance(var, (pd.DataFrame, pd.Series)):
# var_x1_min = var.index.get_level_values("x1").min()
# var_x1_max = var.index.get_level_values("x1").max()
# var_x2_min = var.index.get_level_values("x2").min()
# var_x2_max = var.index.get_level_values("x2").max()
ascending.append(non_gridded)

coord_directions = {"x1": None, "x2": None}
if len(list(filter(lambda x: x != non_gridded, ascending))) == 0:
raise ValueError(
"All data is non gridded, can not proceed with sliding window sampling."
)

# get the directions for only the gridded data
gridded = list(filter(lambda x: x != non_gridded, ascending))
# raise error if directions don't match across gridded data
if gridded.count(gridded[0]) != len(gridded):
raise ValueError(
"Direction of ascending coordinates does not match across all gridded datasets."
)

return coord_directions
return gridded[0]

def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]:
"""Sample random window uniformly from global coordinates to slice data.
Expand Down

0 comments on commit 9d79b34

Please sign in to comment.