From ef27db0de5f948df25fecc39e16058ad949cf70f Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 21 Nov 2024 13:41:00 +0100 Subject: [PATCH] Fix util --- src/scanpy/_compat.py | 13 +++++++------ tests/test_utils.py | 14 +++++++++----- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/scanpy/_compat.py b/src/scanpy/_compat.py index 9c89f94f2..bf13cd232 100644 --- a/src/scanpy/_compat.py +++ b/src/scanpy/_compat.py @@ -204,12 +204,13 @@ def _legacy_numpy_gen( random_state: _LegacyRandom | None = None, ) -> np.random.Generator: """Return a random generator that behaves like the legacy one.""" - if random_state is None: - return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator())) - if isinstance(random_state, np.random.RandomState): - np.random.set_state(random_state.get_state(legacy=False)) - return _FakeRandomGen(random_state) - np.random.seed(random_state) + + if random_state is not None: + if isinstance(random_state, np.random.RandomState): + np.random.set_state(random_state.get_state(legacy=False)) + return _FakeRandomGen(random_state) + np.random.seed(random_state) + return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator())) class _FakeRandomGen(np.random.Generator): diff --git a/tests/test_utils.py b/tests/test_utils.py index aba645608..2cee7a3e3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -250,16 +250,18 @@ def test_is_constant_dask(request: pytest.FixtureRequest, axis, expected, block_ @pytest.mark.parametrize("seed", [0, 1, 1256712675]) +@pytest.mark.parametrize("pass_seed", [True, False], ids=["pass_seed", "set_seed"]) @pytest.mark.parametrize("func", ["choice"]) -def test_legacy_numpy_gen(seed: int, func: str): +def test_legacy_numpy_gen(*, seed: int, pass_seed: bool, func: str): np.random.seed(seed) state_before = np.random.get_state(legacy=False) arrs = {} states_after = {} for direct in [True, False]: - np.random.seed(seed) - arrs[direct] = _mk_random(func, direct=direct) + if not pass_seed: + np.random.seed(seed) + arrs[direct] = _mk_random(func, direct=direct, seed=seed if pass_seed else None) states_after[direct] = np.random.get_state(legacy=False) np.testing.assert_array_equal(arrs[True], arrs[False]) @@ -271,8 +273,10 @@ def test_legacy_numpy_gen(seed: int, func: str): np.testing.assert_equal(states_after[True], state_before) -def _mk_random(func: str, *, direct: bool) -> np.ndarray: - gen = np.random if direct else _legacy_numpy_gen() +def _mk_random(func: str, *, direct: bool, seed: int | None) -> np.ndarray: + if direct and seed is not None: + np.random.seed(seed) + gen = np.random if direct else _legacy_numpy_gen(seed) match func: case "choice": arr = np.arange(1000)