Skip to content

Commit

Permalink
Merge pull request #602 from ameraner/fix_bucket_sum_for_fillvalue
Browse files Browse the repository at this point in the history
Add support for `fill_value` and `set_empty_bucket_to` in BucketResampler `get_sum`
  • Loading branch information
djhoese authored Jul 24, 2024
2 parents 3651ce9 + 05920a8 commit 60629b7
Show file tree
Hide file tree
Showing 2 changed files with 392 additions and 322 deletions.
53 changes: 37 additions & 16 deletions pyresample/bucket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _expand_bin_statistics(bins, unique_bin, unique_idx, weights_sorted):
# assign the valid index to array
weight_idx[unique_bin[~unique_bin.mask].data] = unique_idx[~unique_bin.mask]

return weights_sorted[weight_idx] # last value of weigths_sorted always nan
return weights_sorted[weight_idx] # last value of weights_sorted always nan


@dask.delayed(pure=True)
Expand Down Expand Up @@ -202,20 +202,29 @@ def _get_indices(self):
target_shape = self.target_area.shape
self.idxs = self.y_idxs * target_shape[1] + self.x_idxs

def get_sum(self, data, skipna=True):
def get_sum(self, data, fill_value=np.nan, skipna=True, empty_bucket_value=0):
"""Calculate sums for each bin with drop-in-a-bucket resampling.
Parameters
----------
data : Numpy or Dask array
Data to be binned and summed.
fill_value : float
Fill value of the input data marking missing/invalid values.
Default: np.nan
skipna : boolean (optional)
If True, skips NaN values for the sum calculation
(similarly to Numpy's `nansum`). Buckets containing only NaN are set to zero.
If False, sets the bucket to NaN if one or more NaN values are present in the bucket
(similarly to Numpy's `sum`).
In both cases, empty buckets are set to 0.
Default: True
If True, skips missing values (as marked by NaN or `fill_value`) for the sum calculation
(similarly to Numpy's `nansum`). Buckets containing only missing values are set to `empty_bucket_value`.
If False, sets the bucket to fill_value if one or more missing values are present in the bucket
(similarly to Numpy's `sum`).
In both cases, empty buckets are set to `empty_bucket_value`.
Default: True
empty_bucket_value : float
Value assigned to empty buckets. A bucket is treated as empty when its sum equals 0.
Note that a bucket could become 0 as the result of a sum
of positive and negative values. If the user needs to identify these zero-buckets reliably,
`get_count()` can be used for this purpose.
Default: 0
Returns
-------
Expand All @@ -228,8 +237,9 @@ def get_sum(self, data, skipna=True):
data = data.data
data = data.ravel()

# Remove NaN values from the data when used as weights
weights = da.where(np.isnan(data), 0, data)
# Replace values flagged as invalid (matching fill_value) with 0 so they do not contribute when the data is used as weights
invalid_mask = _get_invalid_mask(data, fill_value)
weights = da.where(invalid_mask, 0, data)

# Rechunk indices to match the data chunking
if weights.chunks != self.idxs.chunks:
Expand All @@ -241,16 +251,19 @@ def get_sum(self, data, skipna=True):
weights=weights, density=False)

# TODO remove following line in favour of weights = data when dask histogram bug (issue #6935) is fixed
sums = self._mask_bins_with_nan_if_not_skipna(skipna, data, out_size, sums)
sums = self._mask_bins_with_nan_if_not_skipna(skipna, data, out_size, sums, fill_value)

if empty_bucket_value != 0:
sums = da.where(sums == 0, empty_bucket_value, sums)

return sums.reshape(self.target_area.shape)

def _mask_bins_with_nan_if_not_skipna(self, skipna, data, out_size, statistic):
def _mask_bins_with_nan_if_not_skipna(self, skipna, data, out_size, statistic, fill_value):
if not skipna:
nans = np.isnan(data)
nan_bins, _ = da.histogram(self.idxs[nans], bins=out_size,
range=(0, out_size))
statistic = da.where(nan_bins > 0, np.nan, statistic)
missing_val = _get_invalid_mask(data, fill_value)
missing_val_bins, _ = da.histogram(self.idxs[missing_val], bins=out_size,
range=(0, out_size))
statistic = da.where(missing_val_bins > 0, fill_value, statistic)
return statistic

def _call_bin_statistic(self, statistic_method, data, fill_value=None, skipna=None):
Expand Down Expand Up @@ -456,6 +469,14 @@ def get_fractions(self, data, categories=None, fill_value=np.nan):
return results


def _get_invalid_mask(data, fill_value):
    """Build a boolean mask flagging the entries of ``data`` that match ``fill_value``.

    Parameters
    ----------
    data : array-like
        Values to test against ``fill_value``.
    fill_value : scalar
        Marker for missing/invalid values. May be NaN.

    Returns
    -------
    mask
        ``np.isnan(data)`` when ``fill_value`` is NaN, otherwise the result
        of the element-wise comparison ``data == fill_value``.
    """
    # NaN never compares equal to itself, so a NaN fill value has to be
    # detected with ``np.isnan`` instead of ``==``.
    fill_is_nan = np.isnan(fill_value)
    return np.isnan(data) if fill_is_nan else data == fill_value


def round_to_resolution(arr, resolution):
"""Round the values in *arr* to closest resolution element.
Expand Down
Loading

0 comments on commit 60629b7

Please sign in to comment.