Skip to content

Commit

Permalink
handle array case in test
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Nov 14, 2024
1 parent 8528f2d commit 06d4280
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_normalize_per_cell():


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("copy", [True, False], ids=["copy", "inplace"])
@pytest.mark.parametrize("which", ["copy", "inplace", "array"])
@pytest.mark.parametrize(
("axis", "fraction", "n", "replace", "expected"),
[
Expand All @@ -160,7 +160,7 @@ def test_normalize_per_cell():
def test_sample(
*,
array_type: Callable[[np.ndarray], np.ndarray | CSMatrix],
copy: bool,
which: Literal["copy", "inplace", "array"],
axis: Literal[0, 1],
fraction: float | None,
n: int | None,
Expand All @@ -174,14 +174,30 @@ def test_sample(
warnings.filterwarnings(
"ignore" if replace else "error", r".*names are not unique", UserWarning
)
rv = sc.pp.sample(adata, fraction, n=n, replace=replace, axis=axis, copy=copy)
rv = sc.pp.sample(
adata.X if which == "array" else adata,
fraction,
n=n,
replace=replace,
axis=axis,
# `copy` only effects AnnData inputs
copy=dict(copy=True, inplace=False, array=False)[which],
)

if copy:
assert adata.shape == (200, 10)
subset = rv
else:
assert rv is None
subset = adata
match which:
case "copy":
subset = rv
assert rv is not adata
assert adata.shape == (200, 10)
case "inplace":
subset = adata
assert rv is None
case "array":
subset, indices = rv
assert len(indices) == expected
assert adata.shape == (200, 10)
case _:
pytest.fail(f"Unknown `{which=}`")

assert subset.shape == ((expected, 10) if axis == 0 else (200, expected))

Expand Down

0 comments on commit 06d4280

Please sign in to comment.