From 20b5f14c0f34c5c1202241b322874fb2d608d298 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Thu, 24 Oct 2024 17:28:43 +0200 Subject: [PATCH 1/6] Avoid parallel numba within dask --- src/scanpy/_utils/__init__.py | 13 ++++++ src/scanpy/preprocessing/_scale.py | 63 +++++++++++++++++++----------- 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index 883f8b97b9..f262f17aeb 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -582,6 +582,19 @@ def check_op(op): @singledispatch def axis_mul_or_truediv( + X: Any, + scaling_array: Any, + axis: Literal[0, 1], + op: Callable[[Any, Any], Any], + *, + allow_divide_by_zero: bool = True, + out: Any | None = None, +) -> Any: + raise NotImplementedError + + +@axis_mul_or_truediv.register(np.ndarray) +def _( X: np.ndarray, scaling_array: np.ndarray, axis: Literal[0, 1], diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py index be452c356d..7ca1f27943 100644 --- a/src/scanpy/preprocessing/_scale.py +++ b/src/scanpy/preprocessing/_scale.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from functools import singledispatch +from functools import partial, singledispatch from operator import truediv from typing import TYPE_CHECKING @@ -30,6 +30,9 @@ if TYPE_CHECKING: from numpy.typing import NDArray + from scipy import sparse as sp + + CSMatrix = sp.csr_matrix | sp.csc_matrix @numba.njit(cache=True, parallel=True) @@ -43,8 +46,9 @@ def _scale_sparse_numba(indptr, indices, data, *, std, mask_obs, clip): data[j] /= std[indices[j]] -@numba.njit(parallel=True, cache=True) -def clip_array(X: np.ndarray, *, max_value: float = 10, zero_center: bool = True): +def _clip_array( + X: NDArray[np.floating], *, max_value: float, zero_center: bool +) -> NDArray[np.floating]: a_min, a_max = -max_value, max_value if X.ndim > 1: for r, c in numba.pndindex(X.shape): @@ -61,6 +65,30 @@ def clip_array(X: np.ndarray, *, max_value: float = 10, zero_center: bool = True return X +_clip_array_fns = { + parallel: numba.njit(_clip_array, parallel=parallel, cache=True) + for parallel in (True, False) +} + + +def clip_array( + x: NDArray[np.floating], + *, + max_value: float = 10, + zero_center: bool = True, + parallel: bool = True, +) -> NDArray[np.floating]: + return _clip_array_fns[parallel](x, max_value=max_value, zero_center=zero_center) + + +def clip_set(x: CSMatrix, *, max_value: float, zero_center: bool = True) -> CSMatrix: + x = x.copy() + x[x > max_value] = max_value + if zero_center: + x[x < -max_value] = -max_value + return x + + @renamed_arg("X", "data", pos_0=True) @old_positionals("zero_center", "max_value", "copy", "layer", "obsm") @singledispatch @@ -188,7 +216,8 @@ def scale_array( if zero_center: if isinstance(X, DaskArray) and issparse(X._meta): warnings.warn( - "zero-center being used with `DaskArray` sparse chunks. This can be bad if you have large chunks or intend to eventually read the whole data into memory.", + "zero-center being used with `DaskArray` sparse chunks. " + "This can be bad if you have large chunks or intend to eventually read the whole data into memory.", UserWarning, ) X -= mean @@ -204,25 +233,15 @@ def scale_array( # do the clipping if max_value is not None: logg.debug(f"... clipping at max_value {max_value}") - if isinstance(X, DaskArray) and issparse(X._meta): - - def clip_set(x): - x = x.copy() - x[x > max_value] = max_value - if zero_center: - x[x < -max_value] = -max_value - return x - - X = da.map_blocks(clip_set, X) + if isinstance(X, DaskArray): + clip = ( + clip_set if issparse(X._meta) else partial(clip_array, parallel=False) + ) + X = X.map_blocks(clip, max_value=max_value, zero_center=zero_center) + elif issparse(X): + X.data = clip_array(X.data, max_value=max_value, zero_center=False) else: - if isinstance(X, DaskArray): - X = X.map_blocks( - clip_array, max_value=max_value, zero_center=zero_center - ) - elif issparse(X): - X.data = clip_array(X.data, max_value=max_value, zero_center=False) - else: - X = clip_array(X, max_value=max_value, zero_center=zero_center) + X = clip_array(X, max_value=max_value, zero_center=zero_center) if return_mean_std: return X, mean, std else: From 4ba3b21f5495102303182f9e40fadb1b0a09c24a Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 24 Oct 2024 17:58:03 +0200 Subject: [PATCH 2/6] restore zappy compat --- src/scanpy/_utils/__init__.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index f262f17aeb..12b6fe454d 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -47,7 +47,7 @@ from typing import Any, Literal, TypeVar from anndata import AnnData - from numpy.typing import DTypeLike, NDArray + from numpy.typing import ArrayLike, DTypeLike, NDArray from ..neighbors import NeighborsParams, RPForestDict @@ -582,26 +582,13 @@ def check_op(op): @singledispatch def axis_mul_or_truediv( - X: Any, - scaling_array: Any, - axis: Literal[0, 1], - op: Callable[[Any, Any], Any], - *, - allow_divide_by_zero: bool = True, - out: Any | None = None, -) -> Any: - raise NotImplementedError - - -@axis_mul_or_truediv.register(np.ndarray) -def _( - X: np.ndarray, + X: ArrayLike, scaling_array: np.ndarray, axis: Literal[0, 1], op: Callable[[Any, Any], Any], *, allow_divide_by_zero: bool = True, - out: np.ndarray | None = None, + out: ArrayLike | None = None, ) -> np.ndarray: check_op(op) scaling_array = broadcast_axis(scaling_array, axis) From 7fdeda1f70297e30c1990232a35a53bbfd33940c Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 25 Oct 2024 10:19:20 +0200 Subject: [PATCH 3/6] only do it in tests --- src/scanpy/preprocessing/_scale.py | 25 ++++--------------------- src/testing/scanpy/_helpers/__init__.py | 24 +++++++++++++++++++++++- tests/test_preprocessing.py | 4 +++- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py index 7ca1f27943..10c791af3f 100644 --- a/src/scanpy/preprocessing/_scale.py +++ b/src/scanpy/preprocessing/_scale.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from functools import partial, singledispatch +from functools import singledispatch from operator import truediv from typing import TYPE_CHECKING @@ -46,7 +46,8 @@ def _scale_sparse_numba(indptr, indices, data, *, std, mask_obs, clip): data[j] /= std[indices[j]] -def _clip_array( +@numba.njit(cache=True, parallel=True) +def clip_array( X: NDArray[np.floating], *, max_value: float, zero_center: bool ) -> NDArray[np.floating]: a_min, a_max = -max_value, max_value @@ -65,22 +66,6 @@ def _clip_array( return X -_clip_array_fns = { - parallel: numba.njit(_clip_array, parallel=parallel, cache=True) - for parallel in (True, False) -} - - -def clip_array( - x: NDArray[np.floating], - *, - max_value: float = 10, - zero_center: bool = True, - parallel: bool = True, -) -> NDArray[np.floating]: - return _clip_array_fns[parallel](x, max_value=max_value, zero_center=zero_center) - - def clip_set(x: CSMatrix, *, max_value: float, zero_center: bool = True) -> CSMatrix: x = x.copy() x[x > max_value] = max_value @@ -234,9 +219,7 @@ def scale_array( if max_value is not None: logg.debug(f"... clipping at max_value {max_value}") if isinstance(X, DaskArray): - clip = ( - clip_set if issparse(X._meta) else partial(clip_array, parallel=False) - ) + clip = clip_set if issparse(X._meta) else clip_array X = X.map_blocks(clip, max_value=max_value, zero_center=zero_center) elif issparse(X): X.data = clip_array(X.data, max_value=max_value, zero_center=False) diff --git a/src/testing/scanpy/_helpers/__init__.py b/src/testing/scanpy/_helpers/__init__.py index 0c59eb592f..3cff738132 100644 --- a/src/testing/scanpy/_helpers/__init__.py +++ b/src/testing/scanpy/_helpers/__init__.py @@ -5,8 +5,9 @@ from __future__ import annotations import warnings -from contextlib import AbstractContextManager +from contextlib import AbstractContextManager, contextmanager from dataclasses import dataclass +from importlib.util import find_spec from itertools import permutations from typing import TYPE_CHECKING @@ -158,3 +159,24 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): for ctx in reversed(self.contexts): ctx.__exit__(exc_type, exc_value, traceback) + + +@contextmanager +def maybe_dask_process_context(): + """ + Running numba with dask's threaded scheduler causes crashes, + so we need to switch to single-threaded (or processes, which is slower) + scheduler for tests that use numba. + """ + if not find_spec("dask"): + yield + return + + import dask.config + + prev_scheduler = dask.config.get("scheduler", "threads") + dask.config.set(scheduler="single-threaded") + try: + yield + finally: + dask.config.set(scheduler=prev_scheduler) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index a63011bff4..85bbedfe10 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -16,6 +16,7 @@ anndata_v0_8_constructor_compat, check_rep_mutation, check_rep_results, + maybe_dask_process_context, ) from testing.scanpy._helpers.data import pbmc3k, pbmc68k_reduced from testing.scanpy._pytest.params import ARRAY_TYPES @@ -168,7 +169,8 @@ def test_scale_matrix_types(array_type, zero_center, max_value): adata_casted = adata.copy() adata_casted.X = array_type(adata_casted.raw.X) sc.pp.scale(adata, zero_center=zero_center, max_value=max_value) - sc.pp.scale(adata_casted, zero_center=zero_center, max_value=max_value) + with maybe_dask_process_context(): + sc.pp.scale(adata_casted, zero_center=zero_center, max_value=max_value) X = adata_casted.X if "dask" in array_type.__name__: X = X.compute() From f5eaa1266c52350d3a5f843b8d06b47a2f1b7256 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 7 Nov 2024 13:17:10 +0100 Subject: [PATCH 4/6] relnote --- docs/release-notes/3317.bugfix.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/release-notes/3317.bugfix.md diff --git a/docs/release-notes/3317.bugfix.md b/docs/release-notes/3317.bugfix.md new file mode 100644 index 0000000000..ba7d435bff --- /dev/null +++ b/docs/release-notes/3317.bugfix.md @@ -0,0 +1 @@ +Fix zappy compatibility for clip_array {smaller}`P Angerer` From bbbf3f443f011c79a61ec5094cca9c6ccf4d9d84 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 11 Nov 2024 14:37:08 +0100 Subject: [PATCH 5/6] Simplify scale implementation --- src/scanpy/preprocessing/_scale.py | 129 +++++++---------------------- 1 file changed, 29 insertions(+), 100 deletions(-) diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py index d7123d5f65..07f55089c3 100644 --- a/src/scanpy/preprocessing/_scale.py +++ b/src/scanpy/preprocessing/_scale.py @@ -8,7 +8,7 @@ import numba import numpy as np from anndata import AnnData -from scipy.sparse import issparse, isspmatrix_csc, spmatrix +from scipy.sparse import csc_matrix, csr_matrix, issparse from .. import logging as logg from .._compat import DaskArray, njit, old_positionals @@ -29,21 +29,30 @@ da = None if TYPE_CHECKING: - from numpy.typing import NDArray - from scipy import sparse as sp + from numpy.typing import ArrayLike, NDArray - CSMatrix = sp.csr_matrix | sp.csc_matrix + CSMatrix = csr_matrix | csc_matrix -@njit -def _scale_sparse_numba(indptr, indices, data, *, std, mask_obs, clip): - for i in numba.prange(len(indptr) - 1): - if mask_obs[i]: - for j in range(indptr[i], indptr[i + 1]): - if clip: - data[j] = min(clip, data[j] / std[indices[j]]) - else: - data[j] /= std[indices[j]] +@singledispatch +def clip( + x: ArrayLike, *, max_value: float, zero_center: bool = True +) -> NDArray[np.floating]: + return clip_array(x, max_value=max_value, zero_center=zero_center) + + +@clip.register(csr_matrix) +@clip.register(csc_matrix) +def _(x: CSMatrix, *, max_value: float, zero_center: bool = True) -> CSMatrix: + x.data = clip(x.data, max_value=max_value, zero_center=zero_center) + return x + + +@clip.register(DaskArray) +def _(x: DaskArray, *, max_value: float, zero_center: bool = True): + return x.map_blocks( + clip, max_value=max_value, zero_center=zero_center, dtype=x.dtype, meta=x._meta + ) @njit @@ -66,19 +75,11 @@ def clip_array( return X -def clip_set(x: CSMatrix, *, max_value: float, zero_center: bool = True) -> CSMatrix: - x = x.copy() - x[x > max_value] = max_value - if zero_center: - x[x < -max_value] = -max_value - return x - - @renamed_arg("X", "data", pos_0=True) @old_positionals("zero_center", "max_value", "copy", "layer", "obsm") @singledispatch def scale( - data: AnnData | spmatrix | np.ndarray | DaskArray, + data: AnnData | CSMatrix | np.ndarray | DaskArray, *, zero_center: bool = True, max_value: float | None = None, @@ -86,7 +87,7 @@ def scale( layer: str | None = None, obsm: str | None = None, mask_obs: NDArray[np.bool_] | str | None = None, -) -> AnnData | spmatrix | np.ndarray | DaskArray | None: +) -> AnnData | CSMatrix | np.ndarray | DaskArray | None: """\ Scale data to unit variance and zero mean. @@ -147,8 +148,10 @@ def scale( @scale.register(np.ndarray) @scale.register(DaskArray) +@scale.register(csc_matrix) +@scale.register(csr_matrix) def scale_array( - X: np.ndarray | DaskArray, + X: np.ndarray | DaskArray | CSMatrix, *, zero_center: bool = True, max_value: float | None = None, @@ -210,87 +213,13 @@ def scale_array( X, std, op=truediv, - out=X if isinstance(X, np.ndarray) or issparse(X) else None, + out=X if isinstance(X, np.ndarray | csr_matrix | csc_matrix) else None, axis=1, ) # do the clipping if max_value is not None: - logg.debug(f"... clipping at max_value {max_value}") - if isinstance(X, DaskArray): - clip = clip_set if issparse(X._meta) else clip_array - X = X.map_blocks(clip, max_value=max_value, zero_center=zero_center) - elif issparse(X): - X.data = clip_array(X.data, max_value=max_value, zero_center=False) - else: - X = clip_array(X, max_value=max_value, zero_center=zero_center) - if return_mean_std: - return X, mean, std - else: - return X - - -@scale.register(spmatrix) -def scale_sparse( - X: spmatrix, - *, - zero_center: bool = True, - max_value: float | None = None, - copy: bool = False, - return_mean_std: bool = False, - mask_obs: NDArray[np.bool_] | None = None, -) -> np.ndarray | tuple[np.ndarray, NDArray[np.float64], NDArray[np.float64]]: - # need to add the following here to make inplace logic work - if zero_center: - logg.info( - "... as `zero_center=True`, sparse input is " - "densified and may lead to large memory consumption" - ) - X = X.toarray() - copy = False # Since the data has been copied - return scale_array( - X, - zero_center=zero_center, - copy=copy, - max_value=max_value, - return_mean_std=return_mean_std, - mask_obs=mask_obs, - ) - elif mask_obs is None: - return scale_array( - X, - zero_center=zero_center, - copy=copy, - max_value=max_value, - return_mean_std=return_mean_std, - mask_obs=mask_obs, - ) - else: - if isspmatrix_csc(X): - X = X.tocsr() - elif copy: - X = X.copy() - - if mask_obs is not None: - mask_obs = _check_mask(X, mask_obs, "obs") - - mean, var = _get_mean_var(X[mask_obs, :]) - - std = np.sqrt(var) - std[std == 0] = 1 - - if max_value is None: - max_value = 0 - - _scale_sparse_numba( - X.indptr, - X.indices, - X.data, - std=std.astype(X.dtype), - mask_obs=mask_obs, - clip=max_value, - ) - + X = clip(X, max_value=max_value, zero_center=zero_center) if return_mean_std: return X, mean, std else: From f57a9b0b4c3a60876ca2682f89729063ab0ca18b Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 11 Nov 2024 14:38:41 +0100 Subject: [PATCH 6/6] Fix merge --- src/scanpy/preprocessing/_scale.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py index 99ef971eb2..07f55089c3 100644 --- a/src/scanpy/preprocessing/_scale.py +++ b/src/scanpy/preprocessing/_scale.py @@ -30,10 +30,6 @@ if TYPE_CHECKING: from numpy.typing import ArrayLike, NDArray - from numpy.typing import NDArray - from scipy import sparse as sp - - CSMatrix = sp.csr_matrix | sp.csc_matrix CSMatrix = csr_matrix | csc_matrix @@ -79,14 +75,6 @@ def clip_array( return X -def clip_set(x: CSMatrix, *, max_value: float, zero_center: bool = True) -> CSMatrix: - x = x.copy() - x[x > max_value] = max_value - if zero_center: - x[x < -max_value] = -max_value - return x - - @renamed_arg("X", "data", pos_0=True) @old_positionals("zero_center", "max_value", "copy", "layer", "obsm") @singledispatch