diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 1354777ed5..36c1726333 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -144,7 +144,7 @@ def test_normalize_per_cell(): @pytest.mark.parametrize("array_type", ARRAY_TYPES) -@pytest.mark.parametrize("copy", [True, False], ids=["copy", "inplace"]) +@pytest.mark.parametrize("which", ["copy", "inplace", "array"]) @pytest.mark.parametrize( ("axis", "fraction", "n", "replace", "expected"), [ @@ -160,7 +160,7 @@ def test_normalize_per_cell(): def test_sample( *, array_type: Callable[[np.ndarray], np.ndarray | CSMatrix], - copy: bool, + which: Literal["copy", "inplace", "array"], axis: Literal[0, 1], fraction: float | None, n: int | None, @@ -174,14 +174,30 @@ def test_sample( warnings.filterwarnings( "ignore" if replace else "error", r".*names are not unique", UserWarning ) - rv = sc.pp.sample(adata, fraction, n=n, replace=replace, axis=axis, copy=copy) + rv = sc.pp.sample( + adata.X if which == "array" else adata, + fraction, + n=n, + replace=replace, + axis=axis, + # `copy` only effects AnnData inputs + copy=dict(copy=True, inplace=False, array=False)[which], + ) - if copy: - assert adata.shape == (200, 10) - subset = rv - else: - assert rv is None - subset = adata + match which: + case "copy": + subset = rv + assert rv is not adata + assert adata.shape == (200, 10) + case "inplace": + subset = adata + assert rv is None + case "array": + subset, indices = rv + assert len(indices) == expected + assert adata.shape == (200, 10) + case _: + pytest.fail(f"Unknown `{which=}`") assert subset.shape == ((expected, 10) if axis == 0 else (200, expected))