diff --git a/src/spatialdata/_core/operations/_utils.py b/src/spatialdata/_core/operations/_utils.py index 23722278..2879af45 100644 --- a/src/spatialdata/_core/operations/_utils.py +++ b/src/spatialdata/_core/operations/_utils.py @@ -4,7 +4,7 @@ from xarray import DataArray, DataTree -from spatialdata.models import SpatialElement +from spatialdata.models import SpatialElement, get_axes_names, get_spatial_axes if TYPE_CHECKING: from spatialdata._core.spatialdata import SpatialData @@ -114,12 +114,13 @@ def transform_to_data_extent( } for _, element_name, element in sdata_raster.gen_spatial_elements(): + element_axes = get_spatial_axes(get_axes_names(element)) if isinstance(element, DataArray | DataTree): rasterized = rasterize( element, - axes=data_extent_axes, - min_coordinate=[data_extent[ax][0] for ax in data_extent_axes], - max_coordinate=[data_extent[ax][1] for ax in data_extent_axes], + axes=element_axes, + min_coordinate=[data_extent[ax][0] for ax in element_axes], + max_coordinate=[data_extent[ax][1] for ax in element_axes], target_coordinate_system=coordinate_system, target_unit_to_pixels=None, target_width=target_width, diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py index 969e685a..5acb5e4e 100644 --- a/src/spatialdata/_core/operations/rasterize.py +++ b/src/spatialdata/_core/operations/rasterize.py @@ -551,7 +551,10 @@ def rasterize_images_labels( target_coordinate_system=target_coordinate_system, ) - half_pixel_offset = Translation([0.5, 0.5, 0.5], axes=("z", "y", "x")) + if "z" in spatial_axes: + half_pixel_offset = Translation([0.5, 0.5, 0.5], axes=("z", "y", "x")) + else: + half_pixel_offset = Translation([0.5, 0.5], axes=("y", "x")) sequence = Sequence( [ # half_pixel_offset.inverse(), diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index a5949ab5..d4644d4a 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -10,6 +10,7 @@ from spatialdata._core.data_extent import are_extents_equal, get_extent from spatialdata._core.operations._utils import transform_to_data_extent from spatialdata._core.spatialdata import SpatialData +from spatialdata._types import ArrayLike from spatialdata.datasets import blobs from spatialdata.models import ( Image2DModel, @@ -17,7 +18,6 @@ PointsModel, ShapesModel, TableModel, - get_model, get_table_keys, ) from spatialdata.testing import assert_elements_dict_are_identical, assert_spatial_data_objects_are_identical @@ -490,32 +490,46 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning: "poly", ] full_sdata = full_sdata.subset(elements) + points = full_sdata["points_0"].compute() + points["z"] = points["x"] + points = PointsModel.parse(points) + full_sdata["points_0_3d"] = points sdata = transform_to_data_extent(full_sdata, "global", target_width=1000, maintain_positioning=maintain_positioning) - matrices = [] - for el in sdata._gen_spatial_element_values(): + first_a: ArrayLike | None = None + for _, name, el in sdata.gen_spatial_elements(): t = get_transformation(el, to_coordinate_system="global") assert isinstance(t, BaseTransformation) - a = t.to_affine_matrix(input_axes=("x", "y", "z"), output_axes=("x", "y", "z")) - matrices.append(a) + a = t.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")) + if first_a is None: + first_a = a + else: + # we are not pixel perfect because of this bug: https://github.com/scverse/spatialdata/issues/165 + if maintain_positioning and name in ["points_0_3d", "points_0", "poly", "circles", "multipoly"]: + # Again, due to the "pixel perfect" bug, the 0.5 translation forth and back in the z axis that is added + # by rasterize() (like the one in the example belows), amplifies the error also for x and y beyond the + # rtol threshold below. So, let's skip that check and to an absolute check up to 0.5 (due to the + # half-pixel offset). + # Sequence + # Translation (z, y, x) + # [-0.5 -0.5 -0.5] + # Scale (y, x) + # [0.17482681 0.17485125] + # Translation (y, x) + # [ -3.13652607 -164. ] + # Translation (z, y, x) + # [0.5 0.5 0.5] + assert np.allclose(a, first_a, atol=0.5) + else: + assert np.allclose(a, first_a, rtol=0.005) - first_a = matrices[0] - for a in matrices[1:]: - # we are not pixel perfect because of this bug: https://github.com/scverse/spatialdata/issues/165 - assert np.allclose(a, first_a, rtol=0.005) if not maintain_positioning: - assert np.allclose(first_a, np.eye(4)) + assert np.allclose(first_a, np.eye(3)) else: - for element in elements: - before = full_sdata[element] - after = sdata[element] - assert get_model(after) == get_model(before) - data_extent_before = get_extent(before, coordinate_system="global") - data_extent_after = get_extent(after, coordinate_system="global") - # huge tolerance because of the bug with pixel perfectness - assert are_extents_equal( - data_extent_before, data_extent_after, atol=4 - ), f"data_extent_before: {data_extent_before}, data_extent_after: {data_extent_after} for element {element}" + data_extent_before = get_extent(full_sdata, coordinate_system="global") + data_extent_after = get_extent(sdata, coordinate_system="global") + # again, due to the "pixel perfect" bug, we use an absolute tolerance of 0.5 + assert are_extents_equal(data_extent_before, data_extent_after, atol=0.5) def test_validate_table_in_spatialdata(full_sdata):