Skip to content

Commit 120db93

Browse files
gokceneraslan authored and Ubuntu committed
Add replace option to subsample and rename function to sample (scverse#943)
1 parent b02b1ce commit 120db93

File tree

11 files changed

+391
-74
lines changed

11 files changed

+391
-74
lines changed

docs/api/deprecated.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
1212
pp.filter_genes_dispersion
1313
pp.normalize_per_cell
14+
pp.subsample
1415
```

docs/api/preprocessing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ For visual quality control, see {func}`~scanpy.pl.highest_expr_genes` and
3131
pp.normalize_total
3232
pp.regress_out
3333
pp.scale
34-
pp.subsample
34+
pp.sample
3535
pp.downsample_counts
3636
```
3737

docs/release-notes/943.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{func}`~scanpy.pp.sample` supports both upsampling and downsampling of observations and variables. {func}`~scanpy.pp.subsample` is now deprecated. {smaller}`G Eraslan` & {smaller}`P Angerer`

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ classifiers = [
4747
]
4848
dependencies = [
4949
"anndata>=0.8",
50-
"numpy>=1.23",
50+
"numpy>=1.24",
5151
"matplotlib>=3.6",
5252
"pandas >=1.5",
5353
"scipy>=1.8",
@@ -60,7 +60,7 @@ dependencies = [
6060
"networkx>=2.7",
6161
"natsort",
6262
"joblib",
63-
"numba>=0.56",
63+
"numba>=0.57",
6464
"umap-learn>=0.5,!=0.5.0",
6565
"pynndescent>=0.5",
6666
"packaging>=21.3",

src/scanpy/_compat.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sys
55
import warnings
66
from dataclasses import dataclass, field
7-
from functools import cache, partial, wraps
7+
from functools import WRAPPER_ASSIGNMENTS, cache, partial, wraps
88
from importlib.util import find_spec
99
from pathlib import Path
1010
from typing import TYPE_CHECKING, Literal, ParamSpec, TypeVar, cast, overload
@@ -224,3 +224,42 @@ def _numba_threading_layer() -> Layer:
224224
f" ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})"
225225
)
226226
raise ValueError(msg)
227+
228+
229+
def _legacy_numpy_gen(
    random_state: _LegacyRandom | None = None,
) -> np.random.Generator:
    """Build a ``Generator``-like object backed by NumPy's legacy RNG.

    Parameters
    ----------
    random_state
        ``None`` wraps the global legacy RNG as it currently is.
        A :class:`~numpy.random.RandomState` has its state copied into the
        global legacy RNG and is then wrapped directly.
        Any other value is handed to :func:`numpy.random.seed`.

    Returns
    -------
    A ``_FakeRandomGen`` wrapping either the passed ``RandomState`` or a
    ``RandomState`` built on the global bit generator.
    """
    if isinstance(random_state, np.random.RandomState):
        # Mirror the instance's state into the global legacy RNG as well.
        np.random.set_state(random_state.get_state(legacy=False))
        return _FakeRandomGen(random_state)

    if random_state is not None:
        np.random.seed(random_state)

    # Wrap the bit generator of the global legacy RNG rather than a new one.
    global_state = np.random.RandomState(np.random.get_bit_generator())
    return _FakeRandomGen(global_state)
240+
241+
242+
class _FakeRandomGen(np.random.Generator):
    """``Generator`` subclass whose public methods forward to a ``RandomState``.

    Instances pass ``isinstance(..., np.random.Generator)`` checks, but every
    public callable found on :class:`numpy.random.Generator` is replaced
    (see :meth:`_delegate`) by a wrapper that calls the method of the same
    name on the wrapped legacy ``RandomState``.
    """

    # Legacy RNG that every delegated public method is forwarded to.
    _state: np.random.RandomState

    def __init__(self, random_state: np.random.RandomState) -> None:
        self._state = random_state
        # Initialise the base class on the very same bit generator, so that
        # inherited, non-delegated members (e.g. ``bit_generator``) refer to
        # the wrapped state instead of an uninitialised ``None``.
        super().__init__(random_state._bit_generator)

    @classmethod
    def _delegate(cls) -> None:
        """Install a forwarding wrapper for each public ``Generator`` method."""
        for name, meth in np.random.Generator.__dict__.items():
            if name.startswith("_") or not callable(meth):
                continue

            # A factory function gives each wrapper its own ``name`` binding.
            def mk_wrapper(name: str):
                # Old pytest versions try to run the doctests
                @wraps(meth, assigned=set(WRAPPER_ASSIGNMENTS) - {"__doc__"})
                def wrapper(self: _FakeRandomGen, *args, **kwargs):
                    return getattr(self._state, name)(*args, **kwargs)

                return wrapper

            setattr(cls, name, mk_wrapper(name))


# Patch the class once, at import time.
_FakeRandomGen._delegate()

src/scanpy/preprocessing/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ..neighbors import neighbors
44
from ._combat import combat
55
from ._deprecated.highly_variable_genes import filter_genes_dispersion
6+
from ._deprecated.sampling import subsample
67
from ._highly_variable_genes import highly_variable_genes
78
from ._normalization import normalize_total
89
from ._pca import pca
@@ -17,8 +18,8 @@
1718
log1p,
1819
normalize_per_cell,
1920
regress_out,
21+
sample,
2022
sqrt,
21-
subsample,
2223
)
2324

2425
__all__ = [
@@ -40,6 +41,7 @@
4041
"log1p",
4142
"normalize_per_cell",
4243
"regress_out",
44+
"sample",
4345
"scale",
4446
"sqrt",
4547
"subsample",

src/scanpy/preprocessing/_deprecated/sampling.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from ..._compat import _legacy_numpy_gen, old_positionals
6+
from .._simple import sample
7+
8+
if TYPE_CHECKING:
9+
import numpy as np
10+
from anndata import AnnData
11+
from numpy.typing import NDArray
12+
from scipy.sparse import csc_matrix, csr_matrix
13+
14+
from ..._compat import _LegacyRandom
15+
16+
CSMatrix = csr_matrix | csc_matrix
17+
18+
19+
@old_positionals("n_obs", "random_state", "copy")
def subsample(
    data: AnnData | np.ndarray | CSMatrix,
    fraction: float | None = None,
    *,
    n_obs: int | None = None,
    random_state: _LegacyRandom = 0,
    copy: bool = False,
) -> AnnData | tuple[np.ndarray | CSMatrix, NDArray[np.int64]] | None:
    """\
    Subsample observations, by fraction or by absolute number.

    .. deprecated:: 1.11.0

       Use :func:`~scanpy.pp.sample` instead.

    Parameters
    ----------
    data
        The (annotated) data matrix of shape `n_obs` × `n_vars`.
        Rows correspond to cells and columns to genes.
    fraction
        Fraction of the observations to keep.
    n_obs
        Number of observations to keep.
    random_state
        Seed for the legacy random number generator; change it to obtain
        a different subsample.
    copy
        If an :class:`~anndata.AnnData` is passed,
        determines whether a copy is returned.

    Returns
    -------
    For array-like `data`, returns `X[obs_indices], obs_indices`.
    For an :class:`~anndata.AnnData`, subsamples the passed object
    (`copy == False`) or returns a subsampled copy of it (`copy == True`).
    """
    # Reproduce the legacy seeding behaviour, then forward to `sample`
    # along the observation axis and without replacement.
    legacy_rng = _legacy_numpy_gen(random_state)
    return sample(
        data=data,
        fraction=fraction,
        n=n_obs,
        rng=legacy_rng,
        copy=copy,
        replace=False,
        axis=0,
    )

0 commit comments

Comments
 (0)