diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py
index 5a8b0288b8..8e886d1ff1 100644
--- a/src/scanpy/_utils/__init__.py
+++ b/src/scanpy/_utils/__init__.py
@@ -15,11 +15,11 @@
 from collections import namedtuple
 from contextlib import contextmanager, suppress
 from enum import Enum
-from functools import partial, singledispatch, wraps
-from operator import mul, truediv
+from functools import partial, reduce, singledispatch, wraps
+from operator import mul, or_, truediv
 from textwrap import dedent
-from types import MethodType, ModuleType
-from typing import TYPE_CHECKING, overload
+from types import MethodType, ModuleType, UnionType
+from typing import TYPE_CHECKING, Literal, Union, get_args, get_origin, overload
 from weakref import WeakSet
 
 import h5py
@@ -42,9 +42,9 @@
 from anndata._core.sparse_dataset import SparseDataset
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Iterable, Mapping
+    from collections.abc import Callable, Iterable, KeysView, Mapping
     from pathlib import Path
-    from typing import Any, Literal, TypeVar
+    from typing import Any, TypeVar
 
     from anndata import AnnData
     from numpy.typing import DTypeLike, NDArray
@@ -55,6 +55,7 @@
 # e.g. https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
 # maybe in the future random.Generator
 AnyRandom = int | np.random.RandomState | None
+LegacyUnionType = type(Union[int, str])  # noqa: UP007
 
 
 class Empty(Enum):
@@ -532,6 +533,19 @@ def update_params(
     return updated_params
 
 
+# `get_args` returns `tuple[Any]` so I don’t think it’s possible to get the correct type here
+def get_literal_vals(typ: UnionType | Any) -> KeysView[Any]:
+    """Get all literal values from a Literal or Union of … of Literal type."""
+    if isinstance(typ, UnionType | LegacyUnionType):
+        return reduce(
+            or_, (dict.fromkeys(get_literal_vals(t)) for t in get_args(typ))
+        ).keys()
+    if get_origin(typ) is Literal:
+        return dict.fromkeys(get_args(typ)).keys()
+    msg = f"{typ} is not a valid Literal"
+    raise TypeError(msg)
+
+
 # --------------------------------------------------------------------------------
 # Others
 # --------------------------------------------------------------------------------
diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py
index e95fedf9dc..2d2739491e 100644
--- a/src/scanpy/get/_aggregated.py
+++ b/src/scanpy/get/_aggregated.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from functools import singledispatch
-from typing import TYPE_CHECKING, Literal, get_args
+from typing import TYPE_CHECKING, Literal
 
 import numpy as np
 import pandas as pd
@@ -9,7 +9,7 @@
 from scipy import sparse
 from sklearn.utils.sparsefuncs import csc_median_axis_0
 
-from .._utils import _resolve_axis
+from .._utils import _resolve_axis, get_literal_vals
 from .get import _check_mask
 
 if TYPE_CHECKING:
@@ -19,7 +19,7 @@
 
 Array = np.ndarray | sparse.csc_matrix | sparse.csr_matrix
 
-# Used with get_args
+# Used with get_literal_vals
 AggType = Literal["count_nonzero", "mean", "sum", "var", "median"]
 
 
@@ -347,8 +347,8 @@ def aggregate_array(
 
     result = {}
     funcs = set([func] if isinstance(func, str) else func)
-    if unknown := funcs - set(get_args(AggType)):
-        raise ValueError(f"func {unknown} is not one of {get_args(AggType)}")
+    if unknown := funcs - get_literal_vals(AggType):
+        raise ValueError(f"func {unknown} is not one of {get_literal_vals(AggType)}")
 
     if "sum" in funcs:  # sum is calculated separately from the rest
         agg = groupby.sum()
diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py
index 7b1c3f2506..379f34227b 100644
--- a/src/scanpy/neighbors/__init__.py
+++ b/src/scanpy/neighbors/__init__.py
@@ -4,7 +4,7 @@
 from collections.abc import Mapping
 from textwrap import indent
 from types import MappingProxyType
-from typing import TYPE_CHECKING, NamedTuple, TypedDict, get_args
+from typing import TYPE_CHECKING, NamedTuple, TypedDict
 from warnings import warn
 
 import numpy as np
@@ -16,7 +16,7 @@
 from .. import logging as logg
 from .._compat import old_positionals
 from .._settings import settings
-from .._utils import NeighborsView, _doc_params
+from .._utils import NeighborsView, _doc_params, get_literal_vals
 from . import _connectivity
 from ._common import (
     _get_indices_distances_from_sparse_matrix,
@@ -652,7 +652,9 @@ def _handle_transformer(
             raise ValueError(msg)
         method = "umap"
         transformer = "rapids"
-    elif method not in (methods := set(get_args(_Method))) and method is not None:
+    elif (
+        method not in (methods := get_literal_vals(_Method)) and method is not None
+    ):
         msg = f"`method` needs to be one of {methods}."
         raise ValueError(msg)
 
@@ -704,7 +706,7 @@ def _handle_transformer(
     elif isinstance(transformer, str):
         msg = (
             f"Unknown transformer: {transformer}. "
-            f"Try passing a class or one of {set(get_args(_KnownTransformer))}"
+            f"Try passing a class or one of {get_literal_vals(_KnownTransformer)}"
         )
         raise ValueError(msg)
     # else `transformer` is probably an instance
diff --git a/src/scanpy/neighbors/_types.py b/src/scanpy/neighbors/_types.py
index d98ec76af3..39f50284ec 100644
--- a/src/scanpy/neighbors/_types.py
+++ b/src/scanpy/neighbors/_types.py
@@ -11,7 +11,7 @@
     from scipy.sparse import spmatrix
 
 
-# These two are used with get_args elsewhere
+# These two are used with get_literal_vals elsewhere
 _Method = Literal["umap", "gauss"]
 _KnownTransformer = Literal["pynndescent", "sklearn", "rapids"]
 
diff --git a/src/scanpy/plotting/_anndata.py b/src/scanpy/plotting/_anndata.py
index c1918878c8..0ae810b2c7 100755
--- a/src/scanpy/plotting/_anndata.py
+++ b/src/scanpy/plotting/_anndata.py
@@ -6,7 +6,7 @@
 from collections.abc import Collection, Mapping, Sequence
 from itertools import product
 from types import NoneType
-from typing import TYPE_CHECKING, cast, get_args
+from typing import TYPE_CHECKING, cast
 
 import matplotlib as mpl
 import numpy as np
@@ -22,7 +22,13 @@
 from .. import logging as logg
 from .._compat import old_positionals
 from .._settings import settings
-from .._utils import _check_use_raw, _doc_params, _empty, sanitize_anndata
+from .._utils import (
+    _check_use_raw,
+    _doc_params,
+    _empty,
+    get_literal_vals,
+    sanitize_anndata,
+)
 from . import _utils
 from ._docs import (
     doc_common_plot_args,
@@ -65,9 +71,6 @@
 _VarNames = str | Sequence[str]
 
 
-VALID_LEGENDLOCS = frozenset(get_args(_utils._LegendLoc))
-
-
 @old_positionals(
     "color",
     "use_raw",
@@ -268,9 +271,9 @@ def _scatter_obs(
     if use_raw and layers not in [("X", "X", "X"), (None, None, None)]:
         ValueError("`use_raw` must be `False` if layers are used.")
 
-    if legend_loc not in VALID_LEGENDLOCS:
+    if legend_loc not in (valid_legend_locs := get_literal_vals(_utils._LegendLoc)):
         raise ValueError(
-            f"Invalid `legend_loc`, need to be one of: {VALID_LEGENDLOCS}."
+            f"Invalid `legend_loc`, need to be one of: {valid_legend_locs}."
         )
     if components is None:
         components = "1,2" if "2d" in projection else "1,2,3"
diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py
index 354848ea7d..dba47d821c 100644
--- a/src/scanpy/preprocessing/_pca/__init__.py
+++ b/src/scanpy/preprocessing/_pca/__init__.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import warnings
-from typing import TYPE_CHECKING, Literal, get_args, overload
+from typing import TYPE_CHECKING, Literal, overload
 from warnings import warn
 
 import anndata as ad
@@ -14,13 +14,14 @@
 from ... import logging as logg
 from ..._compat import DaskArray, pkg_version
 from ..._settings import settings
-from ..._utils import _doc_params, _empty, is_backed_type
+from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type
 from ...get import _check_mask, _get_obs_rep
 from .._docs import doc_mask_var_hvg
 from ._compat import _pca_compat_sparse
 
 if TYPE_CHECKING:
     from collections.abc import Container
+    from collections.abc import Set as AbstractSet
     from typing import LiteralString, TypeVar
 
     import dask_ml.decomposition as dmld
@@ -44,10 +45,11 @@
 SvdSolvTruncatedSVDDaskML = Literal["tsqr", "randomized"]
 SvdSolvDaskML = SvdSolvPCADaskML | SvdSolvTruncatedSVDDaskML
 
-SvdSolvPCADenseSklearn = Literal[
-    "auto", "full", "arpack", "covariance_eigh", "randomized"
-]
-SvdSolvPCASparseSklearn = Literal["arpack", "covariance_eigh"]
+if pkg_version("scikit-learn") >= Version("1.5") or TYPE_CHECKING:
+    SvdSolvPCASparseSklearn = Literal["arpack", "covariance_eigh"]
+else:
+    SvdSolvPCASparseSklearn = Literal["arpack"]
+SvdSolvPCADenseSklearn = Literal["auto", "full", "randomized"] | SvdSolvPCASparseSklearn
 SvdSolvTruncatedSVDSklearn = Literal["arpack", "randomized"]
 SvdSolvSkearn = (
     SvdSolvPCADenseSklearn | SvdSolvPCASparseSklearn | SvdSolvTruncatedSVDSklearn
@@ -299,7 +301,9 @@ def pca(
     if issparse(X) and (
         pkg_version("scikit-learn") < Version("1.4") or svd_solver == "lobpcg"
     ):
-        if svd_solver not in {"lobpcg", "arpack"}:
+        if svd_solver not in (
+            {"lobpcg"} | get_literal_vals(SvdSolvPCASparseSklearn)
+        ):
             if svd_solver is not None:
                 msg = (
                     f"Ignoring {svd_solver=} and using 'arpack', "
@@ -467,14 +471,14 @@ def _handle_dask_ml_args(
 def _handle_dask_ml_args(svd_solver: str | None, method: MethodDaskML) -> SvdSolvDaskML:
     import dask_ml.decomposition as dmld
 
-    args: tuple[SvdSolvDaskML, ...]
+    args: AbstractSet[SvdSolvDaskML]
     default: SvdSolvDaskML
     match method:
        case dmld.PCA | dmld.IncrementalPCA:
-            args = get_args(SvdSolvPCADaskML)
+            args = get_literal_vals(SvdSolvPCADaskML)
             default = "auto"
        case dmld.TruncatedSVD:
-            args = get_args(SvdSolvTruncatedSVDDaskML)
+            args = get_literal_vals(SvdSolvTruncatedSVDDaskML)
             default = "tsqr"
        case _:
             msg = f"Unknown {method=} in _handle_dask_ml_args"
@@ -499,18 +503,18 @@ def _handle_sklearn_args(
 ) -> SvdSolvSkearn:
     import sklearn.decomposition as skld
 
-    args: tuple[SvdSolvSkearn, ...]
+    args: AbstractSet[SvdSolvSkearn]
     default: SvdSolvSkearn
     suffix = ""
     match (method, sparse):
         case (skld.TruncatedSVD, None):
-            args = get_args(SvdSolvTruncatedSVDSklearn)
+            args = get_literal_vals(SvdSolvTruncatedSVDSklearn)
             default = "randomized"
         case (skld.PCA, False):
-            args = get_args(SvdSolvPCADenseSklearn)
+            args = get_literal_vals(SvdSolvPCADenseSklearn)
             default = "arpack"
         case (skld.PCA, True):
-            args = get_args(SvdSolvPCASparseSklearn)
+            args = get_literal_vals(SvdSolvPCASparseSklearn)
             default = "arpack"
             suffix = " (with sparse input)"
         case _:
diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py
index 4e8c91fb1f..3f0e65c061 100644
--- a/src/scanpy/tools/_draw_graph.py
+++ b/src/scanpy/tools/_draw_graph.py
@@ -2,14 +2,14 @@
 
 import random
 from importlib.util import find_spec
-from typing import TYPE_CHECKING, Literal, get_args
+from typing import TYPE_CHECKING, Literal
 
 import numpy as np
 
 from .. import _utils
 from .. import logging as logg
 from .._compat import old_positionals
-from .._utils import _choose_graph
+from .._utils import _choose_graph, get_literal_vals
 from ._utils import get_init_pos_from_paga
 
 if TYPE_CHECKING:
@@ -24,7 +24,6 @@
 
 
 _Layout = Literal["fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa"]
-_LAYOUTS = get_args(_Layout)
 
 
 @old_positionals(
@@ -124,8 +123,8 @@ def draw_graph(
         `draw_graph` parameters.
     """
     start = logg.info(f"drawing single-cell graph using layout {layout!r}")
-    if layout not in _LAYOUTS:
-        raise ValueError(f"Provide a valid layout, one of {_LAYOUTS}.")
+    if layout not in (layouts := get_literal_vals(_Layout)):
+        raise ValueError(f"Provide a valid layout, one of {layouts}.")
     adata = adata.copy() if copy else adata
     if adjacency is None:
         adjacency = _choose_graph(adata, obsp, neighbors_key)
diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py
index 3a737bb487..f8ab13e9fd 100644
--- a/src/scanpy/tools/_rank_genes_groups.py
+++ b/src/scanpy/tools/_rank_genes_groups.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from math import floor
-from typing import TYPE_CHECKING, Literal, get_args
+from typing import TYPE_CHECKING, Literal
 
 import numpy as np
 import pandas as pd
@@ -14,6 +14,7 @@
 from .._compat import old_positionals
 from .._utils import (
     check_nonnegative_integers,
+    get_literal_vals,
     raise_not_implemented_error_if_backed_type,
 )
 from ..get import _check_mask
@@ -28,7 +29,7 @@
 
 _CorrMethod = Literal["benjamini-hochberg", "bonferroni"]
 
-# Used with get_args
+# Used with get_literal_vals
 _Method = Literal["logreg", "t-test", "wilcoxon", "t-test_overestim_var"]
 
 
@@ -607,8 +608,7 @@ def rank_genes_groups(
     rankby_abs = not kwds.pop("only_positive")  # backwards compat
 
     start = logg.info("ranking genes")
-    avail_methods = set(get_args(_Method))
-    if method not in avail_methods:
+    if method not in (avail_methods := get_literal_vals(_Method)):
         raise ValueError(f"Method must be one of {avail_methods}.")
 
     avail_corr = {"benjamini-hochberg", "bonferroni"}
diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py
index ce680b8df5..5bd87e231d 100644
--- a/tests/test_aggregated.py
+++ b/tests/test_aggregated.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-from typing import get_args
-
 import anndata as ad
 import numpy as np
 import pandas as pd
@@ -10,14 +8,14 @@
 from scipy import sparse
 
 import scanpy as sc
-from scanpy._utils import _resolve_axis
+from scanpy._utils import _resolve_axis, get_literal_vals
 from scanpy.get._aggregated import AggType
 from testing.scanpy._helpers import assert_equal
 from testing.scanpy._helpers.data import pbmc3k_processed
 from testing.scanpy._pytest.params import ARRAY_TYPES_MEM
 
 
-@pytest.fixture(params=get_args(AggType))
+@pytest.fixture(params=get_literal_vals(AggType))
 def metric(request: pytest.FixtureRequest) -> AggType:
     return request.param
 
diff --git a/tests/test_pca.py b/tests/test_pca.py
index 6fc8eafd43..0130b6ac35 100644
--- a/tests/test_pca.py
+++ b/tests/test_pca.py
@@ -3,7 +3,7 @@
 import warnings
 from contextlib import nullcontext
 from functools import wraps
-from typing import TYPE_CHECKING, Literal, get_args
+from typing import TYPE_CHECKING, Literal
 
 import anndata as ad
 import numpy as np
@@ -17,6 +17,7 @@
 
 import scanpy as sc
 from scanpy._compat import DaskArray, pkg_version
+from scanpy._utils import get_literal_vals
 from scanpy.preprocessing._pca import SvdSolver as SvdSolverSupported
 from scanpy.preprocessing._pca._dask_sparse import _cov_sparse_dask
 from testing.scanpy import _helpers
@@ -125,6 +126,10 @@ def array_type(request: pytest.FixtureRequest) -> ArrayType:
 SVDSolverDeprecated = Literal["lobpcg"]
 SVDSolver = SvdSolverSupported | SVDSolverDeprecated
 
+SKLEARN_ADDITIONAL: frozenset[SvdSolverSupported] = frozenset(
+    {"covariance_eigh"} if pkg_version("scikit-learn") >= Version("1.5") else ()
+)
+
 
 def gen_pca_params(
     *,
@@ -140,7 +145,7 @@ def gen_pca_params(
         yield None, None, None
         return
 
-    all_svd_solvers = set(get_args(SVDSolver))
+    all_svd_solvers = get_literal_vals(SVDSolver)
     svd_solvers: set[SVDSolver]
     match array_type, zero_center:
         case (dc, True) if dc is DASK_CONVERTERS[_helpers.as_dense_dask_array]:
@@ -150,11 +155,11 @@ def gen_pca_params(
         case (dc, True) if dc is DASK_CONVERTERS[_helpers.as_sparse_dask_array]:
             svd_solvers = {"covariance_eigh"}
         case ((sparse.csr_matrix | sparse.csc_matrix), True):
-            svd_solvers = {"arpack"}
+            svd_solvers = {"arpack"} | SKLEARN_ADDITIONAL
         case ((sparse.csr_matrix | sparse.csc_matrix), False):
             svd_solvers = {"arpack", "randomized"}
         case (_helpers.asarray, True):
-            svd_solvers = {"auto", "full", "arpack", "randomized"}
+            svd_solvers = {"auto", "full", "arpack", "randomized"} | SKLEARN_ADDITIONAL
         case (_helpers.asarray, False):
             svd_solvers = {"arpack", "randomized"}
         case _:
@@ -168,7 +173,8 @@ def gen_pca_params(
     else:
         pytest.fail(f"Unknown {svd_solver_type=}")
 
-    for svd_solver in svd_solvers:
+    # sorted to prevent https://github.com/pytest-dev/pytest-xdist/issues/432
+    for svd_solver in sorted(svd_solvers):
         # explicit check for special case
         if (
             array_type in {sparse.csr_matrix, sparse.csc_matrix}
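
For reference, a minimal usage sketch (not part of the patch) of the new `get_literal_vals` helper added in `src/scanpy/_utils/__init__.py`, reusing the `AggType` alias from `_aggregated.py` above. It assumes Python >= 3.10, matching the `types.UnionType` usage in the patch.

# Illustrative sketch only — not part of the patch.
from typing import Literal

from scanpy._utils import get_literal_vals  # helper introduced in this diff

AggType = Literal["count_nonzero", "mean", "sum", "var", "median"]

vals = get_literal_vals(AggType)   # an ordered, set-like KeysView of the literal values
print(list(vals))                  # ['count_nonzero', 'mean', 'sum', 'var', 'median']

# Set operations work directly, as used in aggregate_array's validation:
print({"mean", "nonsense"} - vals)  # {'nonsense'}

# Unions of Literals (old- or new-style) are flattened recursively:
print(list(get_literal_vals(AggType | Literal["std"])))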