relabel block (#664)
* relabel block

* dtype check array

* fix mypy

* type IntArray

* add sequential relabeling helper function

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reintroduce DataArray

* revert import removal

* Update src/spatialdata/_core/operations/map.py

Co-authored-by: Wouter-Michiel Vierdag <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adjusted error message

* fix test

* make relabel_sequential public

* fix test

* adjust docstring

* adjust docstring, add comment

* IntArray fix type

* add comment

* add edge cases to test

* remove default arg

* change docstring

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor renaming

---------

Co-authored-by: Giovanni Palla <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Wouter-Michiel Vierdag <[email protected]>
Co-authored-by: Wouter-Michiel Vierdag <[email protected]>
Co-authored-by: LucaMarconato <[email protected]>
Co-authored-by: Luca Marconato <[email protected]>
7 people authored Nov 27, 2024
1 parent c332b18 commit 72dbffd
Showing 5 changed files with 244 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/api.md
@@ -41,6 +41,7 @@ Operations on `SpatialData` objects.
to_polygons
aggregate
map_raster
relabel_sequential
```

### Operations Utilities
3 changes: 2 additions & 1 deletion src/spatialdata/__init__.py
@@ -48,6 +48,7 @@
"save_transformations",
"get_dask_backing_files",
"are_extents_equal",
"relabel_sequential",
"map_raster",
"deepcopy",
]
@@ -58,7 +59,7 @@
from spatialdata._core.concatenate import concatenate
from spatialdata._core.data_extent import are_extents_equal, get_extent
from spatialdata._core.operations.aggregate import aggregate
from spatialdata._core.operations.map import map_raster
from spatialdata._core.operations.map import map_raster, relabel_sequential
from spatialdata._core.operations.rasterize import rasterize
from spatialdata._core.operations.rasterize_bins import rasterize_bins
from spatialdata._core.operations.transform import transform
115 changes: 114 additions & 1 deletion src/spatialdata/_core/operations/map.py
@@ -1,17 +1,22 @@
from __future__ import annotations

import math
import operator
from collections.abc import Callable, Iterable, Mapping
from functools import reduce
from types import MappingProxyType
from typing import TYPE_CHECKING, Any

import dask.array as da
import numpy as np
from dask.array.overlap import coerce_depth
from xarray import DataArray, DataTree

from spatialdata._types import IntArrayLike
from spatialdata.models._utils import get_axes_names, get_channel_names, get_raster_model_from_data_dims
from spatialdata.transformations import get_transformation

__all__ = ["map_raster"]
__all__ = ["map_raster", "relabel_sequential"]


def map_raster(
@@ -24,6 +29,7 @@ def map_raster(
c_coords: Iterable[int] | Iterable[str] | None = None,
dims: tuple[str, ...] | None = None,
transformations: dict[str, Any] | None = None,
relabel: bool = True,
**kwargs: Any,
) -> DataArray:
"""
@@ -68,6 +74,13 @@
transformations
The transformations of the output data. If not provided, the transformations of the input data are copied to the
output data. It should be specified if the callable changes the data transformations.
relabel
Whether to relabel the blocks of the output data.
This option is ignored when the output data is not a labels layer (i.e., when `dims` contains `c`).
It is recommended to enable relabeling if `func` returns labels that are not unique across chunks.
Relabeling is done by bit-shifting the labels of each block and encoding the block number in the lower bits.
When a cell or entity to be labeled is split between two adjacent chunks, the current implementation does not
assign it the same label in both blocks.
See https://github.com/scverse/spatialdata/pull/664 for discussion.
kwargs
Additional keyword arguments to pass to :func:`dask.array.map_overlap` or :func:`dask.array.map_blocks`.
Ignored if `blockwise` is set to `False`.
@@ -130,6 +143,9 @@
assert isinstance(d, dict)
transformations = d

if "c" not in dims and relabel:
arr = _relabel(arr)

model_kwargs = {
"chunks": arr.chunksize,
"c_coords": c_coords,
@@ -138,3 +154,100 @@
}
model = get_raster_model_from_data_dims(dims)
return model.parse(arr, **model_kwargs)


def _relabel(arr: da.Array) -> da.Array:
if not np.issubdtype(arr.dtype, np.integer):
raise ValueError(f"Relabeling is only supported for arrays of type {np.integer}.")
num_blocks = arr.numblocks

shift = (math.prod(num_blocks) - 1).bit_length()

meta = np.empty((0,) * arr.ndim, dtype=arr.dtype)

def _relabel_block(
block: IntArrayLike, block_id: tuple[int, ...], num_blocks: tuple[int, ...], shift: int
) -> IntArrayLike:
def _calculate_block_num(block_id: tuple[int, ...], num_blocks: tuple[int, ...]) -> int:
if len(num_blocks) != len(block_id):
raise ValueError("num_blocks and block_id must have the same length")
block_num = 0
for i in range(len(num_blocks)):
multiplier = reduce(operator.mul, num_blocks[i + 1 :], 1)
block_num += block_id[i] * multiplier
return block_num

available_bits = np.iinfo(block.dtype).max.bit_length()
max_bits_block = int(block.max()).bit_length()

if max_bits_block + shift > available_bits:
# Note: because labels are not harmonized across blocks, lowering the number of chunks reduces the required shift.
raise ValueError(
f"Relabel was set to True, but "
f"the number of bits required to represent the labels in the block ({max_bits_block}) "
f"+ the required shift ({shift}) exceeds the available bits ({available_bits}). In other words, "
f"the number of labels exceeds the number of integers that can be represented by the dtype "
"of the individual blocks. "
"To solve this issue, please consider the following solutions:"
" 1. Rechunking using a larger chunk size, lowering the number of blocks and thereby"
" lowering the value of required shift."
" 2. Cast to a data type with a higher maximum value "
" 3. Perform sequential relabeling of the dask array using `relabel_sequential` in `spatialdata`,"
" potentially lowering the maximum value of a label (though number of distinct labels values "
" stays the same). For example if the unique labels values are `[0, 1, 1000]`, after the "
" sequential relabeling the unique labels value will be `[0, 1, 2]`, thus requiring less bits "
" to store the labels."
)

block_num = _calculate_block_num(block_id=block_id, num_blocks=num_blocks)

mask = block > 0
block[mask] = (block[mask] << shift) | block_num

return block

return da.map_blocks(
_relabel_block,
arr,
dtype=arr.dtype,
num_blocks=num_blocks,
shift=shift,
meta=meta,
)


def relabel_sequential(arr: da.Array) -> da.Array:
"""
Relabels integers in a Dask array sequentially.
This function assigns sequential labels to the integers in a Dask array starting from 1.
For example, if the unique values in the input array are [0, 9, 5],
they will be relabeled to [0, 1, 2] respectively.
Note that currently if a cell or entity to be labeled is split across adjacent chunks the same label is not
assigned to the cell across blocks. See discussion https://github.com/scverse/spatialdata/pull/664.
Parameters
----------
arr
input array.
Returns
-------
The relabeled array.
"""
if not np.issubdtype(arr.dtype, np.integer):
raise ValueError(f"Sequential relabeling is only supported for arrays of type {np.integer}.")

unique_labels = da.unique(arr).compute()
if 0 not in unique_labels:
# otherwise the first non-zero label would be relabeled to 0
unique_labels = np.insert(unique_labels, 0, 0)

max_label = unique_labels[-1]

new_labeling = da.full(max_label + 1, -1, dtype=arr.dtype)

# Note that both sides are sorted, as da.unique returns a sorted array.
new_labeling[unique_labels] = da.arange(len(unique_labels), dtype=arr.dtype)

return da.map_blocks(operator.getitem, new_labeling, arr, dtype=arr.dtype, chunks=arr.chunks)
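
For context, a minimal sketch of the per-block bit-shift encoding applied by the new `relabel` option (illustrative only; the toy array, the chunking, and the `encode` helper below are made up and are not part of this change):

import math

import dask.array as da
import numpy as np

# Toy labels array split into 2 x 2 = 4 chunks; every chunk reuses the labels 1 and 2.
labels = da.from_array(np.tile(np.array([[1, 2], [2, 1]], dtype=np.int32), (2, 2)), chunks=(2, 2))

num_blocks = math.prod(labels.numblocks)  # 4 blocks
shift = (num_blocks - 1).bit_length()  # 2 bits suffice to encode the block ids 0..3


def encode(block, block_id=None):
    # Mimic the per-block encoding: shift each non-zero label left and OR in the block number.
    block_num = block_id[0] * labels.numblocks[1] + block_id[1]
    out = block.copy()
    mask = out > 0
    out[mask] = (out[mask] << shift) | block_num
    return out


relabeled = labels.map_blocks(encode, dtype=labels.dtype)
print(np.unique(relabeled.compute()))  # [ 4  5  6  7  8  9 10 11] -- unique across all blocks
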
3 changes: 3 additions & 0 deletions src/spatialdata/_types.py
@@ -7,8 +7,11 @@
from numpy.typing import DTypeLike, NDArray

ArrayLike = NDArray[np.float64]
IntArrayLike = NDArray[np.int64] # or any np.integer

except (ImportError, TypeError):
ArrayLike = np.ndarray # type: ignore[misc]
IntArrayLike = np.ndarray # type: ignore[misc]
DTypeLike = np.dtype # type: ignore[misc, assignment]

Raster_T = DataArray | DataTree
125 changes: 124 additions & 1 deletion tests/core/operations/test_map.py
@@ -1,10 +1,12 @@
import math
import re

import dask.array as da
import numpy as np
import pytest
from xarray import DataArray

from spatialdata._core.operations.map import map_raster
from spatialdata._core.operations.map import map_raster, relabel_sequential
from spatialdata.transformations import Translation, get_transformation, set_transformation


@@ -28,6 +30,11 @@ def _multiply_to_labels(arr, parameter=10):
return arr[0].astype(np.int32)


def _to_constant(arr, constant):
arr[arr > 0] = constant
return arr


@pytest.mark.parametrize(
"depth",
[
@@ -47,6 +54,7 @@ def test_map_raster(sdata_blobs, depth, element_name):
func_kwargs=func_kwargs,
c_coords=None,
depth=depth,
relabel=False,
)

assert isinstance(se, DataArray)
@@ -162,6 +170,7 @@ def test_map_to_labels_(sdata_blobs, blockwise, chunks, drop_axis):
chunks=chunks,
drop_axis=drop_axis,
dims=("y", "x"),
relabel=False,
)

data = sdata_blobs[img_layer].data.compute()
@@ -249,3 +258,117 @@ def test_invalid_map_raster(sdata_blobs):
c_coords=["c"],
depth=(0, 60, 60),
)


def test_map_raster_relabel(sdata_blobs):
constant = 2047
func_kwargs = {"constant": constant}

element_name = "blobs_labels"
se = map_raster(
sdata_blobs[element_name].chunk((100, 100)),
func=_to_constant,
func_kwargs=func_kwargs,
c_coords=None,
depth=None,
relabel=True,
)

# check if labels in different blocks are all mapped to a different value
assert isinstance(se, DataArray)
se.data.compute()
a = set()
for chunk in se.data.to_delayed().flatten():
chunk = chunk.compute()
b = set(np.unique(chunk))
b.remove(0)
assert not b.intersection(a)
a.update(b)
# 9 blocks: in each block, the non-zero labels equal 'constant' left-shifted by (9-1).bit_length() bits, OR'ed with the block number.
shift = (math.prod(se.data.numblocks) - 1).bit_length()
assert a == set(range(constant << shift, (constant << shift) + math.prod(se.data.numblocks)))


def test_map_raster_relabel_fail(sdata_blobs):
constant = 2048
func_kwargs = {"constant": constant}

element_name = "blobs_labels"

# Test the case where the number of available bits is insufficient.
with pytest.raises(
ValueError,
match=re.escape("Relabel was set to True, but"),
):
se = map_raster(
sdata_blobs[element_name].chunk((100, 100)),
func=_to_constant,
func_kwargs=func_kwargs,
c_coords=None,
depth=None,
relabel=True,
)

se.data.compute()

constant = 2047
func_kwargs = {"constant": constant}

element_name = "blobs_labels"
with pytest.raises(
ValueError,
match=re.escape(f"Relabeling is only supported for arrays of type {np.integer}."),
):
se = map_raster(
sdata_blobs[element_name].astype(float).chunk((100, 100)),
func=_to_constant,
func_kwargs=func_kwargs,
c_coords=None,
depth=None,
relabel=True,
)


def test_relabel_sequential(sdata_blobs):
def _is_sequential(arr):
if arr.ndim != 1:
raise ValueError("Input array must be one-dimensional")
sorted_arr = np.sort(arr)
expected_sequence = np.arange(sorted_arr[0], sorted_arr[0] + len(sorted_arr))
return np.array_equal(sorted_arr, expected_sequence)

arr = sdata_blobs["blobs_labels"].data.rechunk(100)

arr_relabeled = relabel_sequential(arr)

labels_relabeled = da.unique(arr_relabeled).compute()
labels_original = da.unique(arr).compute()

assert labels_relabeled.shape == labels_original.shape
assert _is_sequential(labels_relabeled)

# test some edge cases
arr = da.asarray(np.array([0]))
assert np.array_equal(relabel_sequential(arr).compute(), np.array([0]))

arr = da.asarray(np.array([1]))
assert np.array_equal(relabel_sequential(arr).compute(), np.array([1]))

arr = da.asarray(np.array([2]))
assert np.array_equal(relabel_sequential(arr).compute(), np.array([1]))

arr = da.asarray(np.array([2, 0]))
assert np.array_equal(relabel_sequential(arr).compute(), np.array([1, 0]))

arr = da.asarray(np.array([0, 9, 5]))
assert np.array_equal(relabel_sequential(arr).compute(), np.array([0, 2, 1]))

arr = da.asarray(np.array([4, 1, 3]))
assert np.array_equal(relabel_sequential(arr).compute(), np.array([3, 1, 2]))


def test_relabel_sequential_fails(sdata_blobs):
with pytest.raises(
ValueError, match=re.escape(f"Sequential relabeling is only supported for arrays of type {np.integer}.")
):
relabel_sequential(sdata_blobs["blobs_labels"].data.astype(float))
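
Finally, a small usage sketch of the new public `relabel_sequential` helper (the toy array below is made up; the expected result mirrors the edge cases exercised in the tests above):

import dask.array as da
import numpy as np

from spatialdata import relabel_sequential

# Labels with gaps: the unique values are [0, 5, 9].
arr = da.from_array(np.array([[0, 9], [5, 9]], dtype=np.int64), chunks=(1, 2))

relabeled = relabel_sequential(arr).compute()
# 5 -> 1 and 9 -> 2 (labels are assigned in sorted order); the background value 0 is preserved.
assert np.array_equal(relabeled, np.array([[0, 2], [1, 2]]))
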
