Merge branch 'main' into create_cat_regressor
Intron7 authored Nov 11, 2024
2 parents 36858d9 + 6dd0a7a commit be1bccc
Showing 3 changed files with 27 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/scanpy/_utils/__init__.py
@@ -607,13 +607,13 @@ def check_op(op):
 
 @singledispatch
 def axis_mul_or_truediv(
-    X: np.ndarray,
+    X: ArrayLike,
     scaling_array: np.ndarray,
     axis: Literal[0, 1],
     op: Callable[[Any, Any], Any],
     *,
     allow_divide_by_zero: bool = True,
-    out: np.ndarray | None = None,
+    out: ArrayLike | None = None,
 ) -> np.ndarray:
     check_op(op)
     scaling_array = broadcast_axis(scaling_array, axis)
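
For context, a rough sketch of how the relaxed `ArrayLike` signature might be exercised. The import path, the use of `operator.truediv` as the op, and the axis-0 row semantics are assumptions made for illustration, not taken from the diff.

import operator

import numpy as np

from scanpy._utils import axis_mul_or_truediv  # assumed import path

# Divide a small dense matrix by per-row factors (axis assumed to be 0 for rows).
X = np.arange(6, dtype=np.float64).reshape(2, 3)
row_factors = np.array([1.0, 2.0])
scaled = axis_mul_or_truediv(X, row_factors, 0, operator.truediv)
print(scaled)
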
42 changes: 22 additions & 20 deletions src/scanpy/preprocessing/_scale.py
@@ -30,6 +30,9 @@
 
 if TYPE_CHECKING:
     from numpy.typing import NDArray
+    from scipy import sparse as sp
+
+    CSMatrix = sp.csr_matrix | sp.csc_matrix
 
 
 @njit
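
As an aside, a minimal standalone sketch of the type-checking-only alias pattern introduced here; the `count_stored` function below is a made-up placeholder, not part of the diff.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; no runtime scipy import is needed.
    from scipy import sparse as sp

    CSMatrix = sp.csr_matrix | sp.csc_matrix


def count_stored(x: CSMatrix) -> int:
    # With `from __future__ import annotations`, the annotation stays a string
    # at runtime, so CSMatrix need not exist outside type checking.
    return x.nnz
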
@@ -44,7 +47,9 @@ def _scale_sparse_numba(indptr, indices, data, *, std, mask_obs, clip):
 
 
 @njit
-def clip_array(X: np.ndarray, *, max_value: float = 10, zero_center: bool = True):
+def clip_array(
+    X: NDArray[np.floating], *, max_value: float, zero_center: bool
+) -> NDArray[np.floating]:
     a_min, a_max = -max_value, max_value
     if X.ndim > 1:
         for r, c in numba.pndindex(X.shape):
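
The kernel's body is only partially visible in this hunk; below is a hedged reconstruction of what an in-place clipping loop of this shape could look like, using `numba.njit(parallel=True)` directly instead of scanpy's own `njit` wrapper.

import numba
import numpy as np


@numba.njit(parallel=True)
def clip_array_sketch(X, max_value, zero_center):
    # Clip to [-max_value, max_value] when zero-centered, otherwise only cap the top.
    a_min, a_max = -max_value, max_value
    for r, c in numba.pndindex(X.shape):
        if X[r, c] > a_max:
            X[r, c] = a_max
        elif zero_center and X[r, c] < a_min:
            X[r, c] = a_min
    return X


clipped = clip_array_sketch(np.random.standard_normal((100, 10)) * 5, 3.0, True)
assert np.abs(clipped).max() <= 3.0
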
@@ -61,6 +66,14 @@ def clip_array(X: np.ndarray, *, max_value: float = 10, zero_center: bool = True
     return X
 
 
+def clip_set(x: CSMatrix, *, max_value: float, zero_center: bool = True) -> CSMatrix:
+    x = x.copy()
+    x[x > max_value] = max_value
+    if zero_center:
+        x[x < -max_value] = -max_value
+    return x
+
+
 @renamed_arg("X", "data", pos_0=True)
 @old_positionals("zero_center", "max_value", "copy", "layer", "obsm")
 @singledispatch
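
To see what the new sparse helper does in practice, here is a small usage sketch against a local copy of the function body added above; the input values are illustrative, and with `zero_center=True` the result matches `np.clip` on the dense equivalent.

import numpy as np
from scipy import sparse


def clip_set(x, *, max_value, zero_center=True):
    # Same body as the helper added above: work on a copy, cap stored values at ±max_value.
    x = x.copy()
    x[x > max_value] = max_value
    if zero_center:
        x[x < -max_value] = -max_value
    return x


dense = np.array([[0.0, 12.0], [-15.0, 3.0]])
clipped = clip_set(sparse.csr_matrix(dense), max_value=10)
assert np.allclose(clipped.toarray(), np.clip(dense, -10, 10))
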
@@ -187,7 +200,8 @@ def scale_array(
     if zero_center:
         if isinstance(X, DaskArray) and issparse(X._meta):
             warnings.warn(
-                "zero-center being used with `DaskArray` sparse chunks. This can be bad if you have large chunks or intend to eventually read the whole data into memory.",
+                "zero-center being used with `DaskArray` sparse chunks. "
+                "This can be bad if you have large chunks or intend to eventually read the whole data into memory.",
                 UserWarning,
             )
         X -= mean
@@ -203,25 +217,13 @@
     # do the clipping
     if max_value is not None:
         logg.debug(f"... clipping at max_value {max_value}")
-        if isinstance(X, DaskArray) and issparse(X._meta):
-
-            def clip_set(x):
-                x = x.copy()
-                x[x > max_value] = max_value
-                if zero_center:
-                    x[x < -max_value] = -max_value
-                return x
-
-            X = da.map_blocks(clip_set, X)
+        if isinstance(X, DaskArray):
+            clip = clip_set if issparse(X._meta) else clip_array
+            X = X.map_blocks(clip, max_value=max_value, zero_center=zero_center)
+        elif issparse(X):
+            X.data = clip_array(X.data, max_value=max_value, zero_center=False)
         else:
-            if isinstance(X, DaskArray):
-                X = X.map_blocks(
-                    clip_array, max_value=max_value, zero_center=zero_center
-                )
-            elif issparse(X):
-                X.data = clip_array(X.data, max_value=max_value, zero_center=False)
-            else:
-                X = clip_array(X, max_value=max_value, zero_center=zero_center)
+            X = clip_array(X, max_value=max_value, zero_center=zero_center)
     if return_mean_std:
         return X, mean, std
     else:
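
A rough illustration of the dispatch this refactor enables: the per-block clip function is chosen once from the chunk meta and its keyword arguments are simply forwarded by `map_blocks`. The dense stand-in below uses `np.clip` instead of the numba kernel, and the data are illustrative.

import dask.array as da
import numpy as np


def clip_dense(block, *, max_value, zero_center):
    # Dense stand-in for the clip function chosen when X._meta is not sparse.
    return np.clip(block, -max_value if zero_center else None, max_value)


X = da.from_array(np.random.standard_normal((100, 10)) * 5, chunks=(25, 10))
clipped = X.map_blocks(clip_dense, max_value=3, zero_center=True)
assert float(np.abs(clipped).max().compute()) <= 3
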
4 changes: 3 additions & 1 deletion tests/test_preprocessing.py
@@ -19,6 +19,7 @@
     anndata_v0_8_constructor_compat,
     check_rep_mutation,
     check_rep_results,
+    maybe_dask_process_context,
 )
 from testing.scanpy._helpers.data import pbmc3k, pbmc68k_reduced
 from testing.scanpy._pytest.params import ARRAY_TYPES
@@ -174,7 +175,8 @@ def test_scale_matrix_types(array_type, zero_center, max_value):
     adata_casted = adata.copy()
     adata_casted.X = array_type(adata_casted.raw.X)
     sc.pp.scale(adata, zero_center=zero_center, max_value=max_value)
-    sc.pp.scale(adata_casted, zero_center=zero_center, max_value=max_value)
+    with maybe_dask_process_context():
+        sc.pp.scale(adata_casted, zero_center=zero_center, max_value=max_value)
     X = adata_casted.X
     if "dask" in array_type.__name__:
         X = X.compute()
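
For reference, the user-facing call this test ultimately exercises, shown on a small synthetic AnnData; the toy data and the clipping bound of 10 are illustrative.

import anndata as ad
import numpy as np
import scanpy as sc

adata = ad.AnnData(np.random.standard_normal((50, 20)).astype(np.float32) * 5)
sc.pp.scale(adata, zero_center=True, max_value=10)
assert np.abs(adata.X).max() <= 10
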
