Skip to content

Commit

Permalink
Fix util
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Nov 21, 2024
1 parent 59adc76 commit ef27db0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
13 changes: 7 additions & 6 deletions src/scanpy/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,13 @@ def _legacy_numpy_gen(
random_state: _LegacyRandom | None = None,
) -> np.random.Generator:
"""Return a random generator that behaves like the legacy one."""
if random_state is None:
return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator()))
if isinstance(random_state, np.random.RandomState):
np.random.set_state(random_state.get_state(legacy=False))
return _FakeRandomGen(random_state)
np.random.seed(random_state)

if random_state is not None:
if isinstance(random_state, np.random.RandomState):
np.random.set_state(random_state.get_state(legacy=False))
return _FakeRandomGen(random_state)
np.random.seed(random_state)
return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator()))


class _FakeRandomGen(np.random.Generator):
Expand Down
14 changes: 9 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,18 @@ def test_is_constant_dask(request: pytest.FixtureRequest, axis, expected, block_


@pytest.mark.parametrize("seed", [0, 1, 1256712675])
@pytest.mark.parametrize("pass_seed", [True, False], ids=["pass_seed", "set_seed"])
@pytest.mark.parametrize("func", ["choice"])
def test_legacy_numpy_gen(seed: int, func: str):
def test_legacy_numpy_gen(*, seed: int, pass_seed: bool, func: str):
np.random.seed(seed)
state_before = np.random.get_state(legacy=False)

arrs = {}
states_after = {}
for direct in [True, False]:
np.random.seed(seed)
arrs[direct] = _mk_random(func, direct=direct)
if not pass_seed:
np.random.seed(seed)
arrs[direct] = _mk_random(func, direct=direct, seed=seed if pass_seed else None)
states_after[direct] = np.random.get_state(legacy=False)

np.testing.assert_array_equal(arrs[True], arrs[False])
Expand All @@ -271,8 +273,10 @@ def test_legacy_numpy_gen(seed: int, func: str):
np.testing.assert_equal(states_after[True], state_before)


def _mk_random(func: str, *, direct: bool) -> np.ndarray:
gen = np.random if direct else _legacy_numpy_gen()
def _mk_random(func: str, *, direct: bool, seed: int | None) -> np.ndarray:
if direct and seed is not None:
np.random.seed(seed)
gen = np.random if direct else _legacy_numpy_gen(seed)
match func:
case "choice":
arr = np.arange(1000)
Expand Down

0 comments on commit ef27db0

Please sign in to comment.