Skip to content

Commit

Permalink
(fix): sort pca test args (#3333)
Browse files Browse the repository at this point in the history
Co-authored-by: Philipp A. <[email protected]>
  • Loading branch information
ilan-gold and flying-sheep authored Nov 5, 2024
1 parent 6440515 commit 0d04447
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 55 deletions.
26 changes: 20 additions & 6 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from collections import namedtuple
from contextlib import contextmanager, suppress
from enum import Enum
from functools import partial, singledispatch, wraps
from operator import mul, truediv
from functools import partial, reduce, singledispatch, wraps
from operator import mul, or_, truediv
from textwrap import dedent
from types import MethodType, ModuleType
from typing import TYPE_CHECKING, overload
from types import MethodType, ModuleType, UnionType
from typing import TYPE_CHECKING, Literal, Union, get_args, get_origin, overload
from weakref import WeakSet

import h5py
Expand All @@ -42,9 +42,9 @@
from anndata._core.sparse_dataset import SparseDataset

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Mapping
from collections.abc import Callable, Iterable, KeysView, Mapping
from pathlib import Path
from typing import Any, Literal, TypeVar
from typing import Any, TypeVar

from anndata import AnnData
from numpy.typing import DTypeLike, NDArray
Expand All @@ -55,6 +55,7 @@
# e.g. https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
# maybe in the future random.Generator
AnyRandom = int | np.random.RandomState | None
LegacyUnionType = type(Union[int, str]) # noqa: UP007


class Empty(Enum):
Expand Down Expand Up @@ -532,6 +533,19 @@ def update_params(
return updated_params


# `get_args` returns `tuple[Any]` so I don’t think it’s possible to get the correct type here
def get_literal_vals(typ: UnionType | Any) -> KeysView[Any]:
"""Get all literal values from a Literal or Union of … of Literal type."""
if isinstance(typ, UnionType | LegacyUnionType):
return reduce(
or_, (dict.fromkeys(get_literal_vals(t)) for t in get_args(typ))
).keys()
if get_origin(typ) is Literal:
return dict.fromkeys(get_args(typ)).keys()
msg = f"{typ} is not a valid Literal"
raise TypeError(msg)


# --------------------------------------------------------------------------------
# Others
# --------------------------------------------------------------------------------
Expand Down
10 changes: 5 additions & 5 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from __future__ import annotations

from functools import singledispatch
from typing import TYPE_CHECKING, Literal, get_args
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd
from anndata import AnnData, utils
from scipy import sparse
from sklearn.utils.sparsefuncs import csc_median_axis_0

from .._utils import _resolve_axis
from .._utils import _resolve_axis, get_literal_vals
from .get import _check_mask

if TYPE_CHECKING:
Expand All @@ -19,7 +19,7 @@

Array = np.ndarray | sparse.csc_matrix | sparse.csr_matrix

# Used with get_args
# Used with get_literal_vals
AggType = Literal["count_nonzero", "mean", "sum", "var", "median"]


Expand Down Expand Up @@ -347,8 +347,8 @@ def aggregate_array(
result = {}

funcs = set([func] if isinstance(func, str) else func)
if unknown := funcs - set(get_args(AggType)):
raise ValueError(f"func {unknown} is not one of {get_args(AggType)}")
if unknown := funcs - get_literal_vals(AggType):
raise ValueError(f"func {unknown} is not one of {get_literal_vals(AggType)}")

if "sum" in funcs: # sum is calculated separately from the rest
agg = groupby.sum()
Expand Down
10 changes: 6 additions & 4 deletions src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Mapping
from textwrap import indent
from types import MappingProxyType
from typing import TYPE_CHECKING, NamedTuple, TypedDict, get_args
from typing import TYPE_CHECKING, NamedTuple, TypedDict
from warnings import warn

import numpy as np
Expand All @@ -16,7 +16,7 @@
from .. import logging as logg
from .._compat import old_positionals
from .._settings import settings
from .._utils import NeighborsView, _doc_params
from .._utils import NeighborsView, _doc_params, get_literal_vals
from . import _connectivity
from ._common import (
_get_indices_distances_from_sparse_matrix,
Expand Down Expand Up @@ -652,7 +652,9 @@ def _handle_transformer(
raise ValueError(msg)
method = "umap"
transformer = "rapids"
elif method not in (methods := set(get_args(_Method))) and method is not None:
elif (
method not in (methods := get_literal_vals(_Method)) and method is not None
):
msg = f"`method` needs to be one of {methods}."
raise ValueError(msg)

Expand Down Expand Up @@ -704,7 +706,7 @@ def _handle_transformer(
elif isinstance(transformer, str):
msg = (
f"Unknown transformer: {transformer}. "
f"Try passing a class or one of {set(get_args(_KnownTransformer))}"
f"Try passing a class or one of {get_literal_vals(_KnownTransformer)}"
)
raise ValueError(msg)
# else `transformer` is probably an instance
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/neighbors/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from scipy.sparse import spmatrix


# These two are used with get_args elsewhere
# These two are used with get_literal_vals elsewhere
_Method = Literal["umap", "gauss"]
_KnownTransformer = Literal["pynndescent", "sklearn", "rapids"]

Expand Down
17 changes: 10 additions & 7 deletions src/scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Collection, Mapping, Sequence
from itertools import product
from types import NoneType
from typing import TYPE_CHECKING, cast, get_args
from typing import TYPE_CHECKING, cast

import matplotlib as mpl
import numpy as np
Expand All @@ -22,7 +22,13 @@
from .. import logging as logg
from .._compat import old_positionals
from .._settings import settings
from .._utils import _check_use_raw, _doc_params, _empty, sanitize_anndata
from .._utils import (
_check_use_raw,
_doc_params,
_empty,
get_literal_vals,
sanitize_anndata,
)
from . import _utils
from ._docs import (
doc_common_plot_args,
Expand Down Expand Up @@ -65,9 +71,6 @@
_VarNames = str | Sequence[str]


VALID_LEGENDLOCS = frozenset(get_args(_utils._LegendLoc))


@old_positionals(
"color",
"use_raw",
Expand Down Expand Up @@ -268,9 +271,9 @@ def _scatter_obs(
if use_raw and layers not in [("X", "X", "X"), (None, None, None)]:
ValueError("`use_raw` must be `False` if layers are used.")

if legend_loc not in VALID_LEGENDLOCS:
if legend_loc not in (valid_legend_locs := get_literal_vals(_utils._LegendLoc)):
raise ValueError(
f"Invalid `legend_loc`, need to be one of: {VALID_LEGENDLOCS}."
f"Invalid `legend_loc`, need to be one of: {valid_legend_locs}."
)
if components is None:
components = "1,2" if "2d" in projection else "1,2,3"
Expand Down
32 changes: 18 additions & 14 deletions src/scanpy/preprocessing/_pca/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Literal, get_args, overload
from typing import TYPE_CHECKING, Literal, overload
from warnings import warn

import anndata as ad
Expand All @@ -14,13 +14,14 @@
from ... import logging as logg
from ..._compat import DaskArray, pkg_version
from ..._settings import settings
from ..._utils import _doc_params, _empty, is_backed_type
from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type
from ...get import _check_mask, _get_obs_rep
from .._docs import doc_mask_var_hvg
from ._compat import _pca_compat_sparse

if TYPE_CHECKING:
from collections.abc import Container
from collections.abc import Set as AbstractSet
from typing import LiteralString, TypeVar

import dask_ml.decomposition as dmld
Expand All @@ -44,10 +45,11 @@
SvdSolvTruncatedSVDDaskML = Literal["tsqr", "randomized"]
SvdSolvDaskML = SvdSolvPCADaskML | SvdSolvTruncatedSVDDaskML

SvdSolvPCADenseSklearn = Literal[
"auto", "full", "arpack", "covariance_eigh", "randomized"
]
SvdSolvPCASparseSklearn = Literal["arpack", "covariance_eigh"]
if pkg_version("scikit-learn") >= Version("1.5") or TYPE_CHECKING:
SvdSolvPCASparseSklearn = Literal["arpack", "covariance_eigh"]
else:
SvdSolvPCASparseSklearn = Literal["arpack"]
SvdSolvPCADenseSklearn = Literal["auto", "full", "randomized"] | SvdSolvPCASparseSklearn
SvdSolvTruncatedSVDSklearn = Literal["arpack", "randomized"]
SvdSolvSkearn = (
SvdSolvPCADenseSklearn | SvdSolvPCASparseSklearn | SvdSolvTruncatedSVDSklearn
Expand Down Expand Up @@ -299,7 +301,9 @@ def pca(
if issparse(X) and (
pkg_version("scikit-learn") < Version("1.4") or svd_solver == "lobpcg"
):
if svd_solver not in {"lobpcg", "arpack"}:
if svd_solver not in (
{"lobpcg"} | get_literal_vals(SvdSolvPCASparseSklearn)
):
if svd_solver is not None:
msg = (
f"Ignoring {svd_solver=} and using 'arpack', "
Expand Down Expand Up @@ -467,14 +471,14 @@ def _handle_dask_ml_args(
def _handle_dask_ml_args(svd_solver: str | None, method: MethodDaskML) -> SvdSolvDaskML:
import dask_ml.decomposition as dmld

args: tuple[SvdSolvDaskML, ...]
args: AbstractSet[SvdSolvDaskML]
default: SvdSolvDaskML
match method:
case dmld.PCA | dmld.IncrementalPCA:
args = get_args(SvdSolvPCADaskML)
args = get_literal_vals(SvdSolvPCADaskML)
default = "auto"
case dmld.TruncatedSVD:
args = get_args(SvdSolvTruncatedSVDDaskML)
args = get_literal_vals(SvdSolvTruncatedSVDDaskML)
default = "tsqr"
case _:
msg = f"Unknown {method=} in _handle_dask_ml_args"
Expand All @@ -499,18 +503,18 @@ def _handle_sklearn_args(
) -> SvdSolvSkearn:
import sklearn.decomposition as skld

args: tuple[SvdSolvSkearn, ...]
args: AbstractSet[SvdSolvSkearn]
default: SvdSolvSkearn
suffix = ""
match (method, sparse):
case (skld.TruncatedSVD, None):
args = get_args(SvdSolvTruncatedSVDSklearn)
args = get_literal_vals(SvdSolvTruncatedSVDSklearn)
default = "randomized"
case (skld.PCA, False):
args = get_args(SvdSolvPCADenseSklearn)
args = get_literal_vals(SvdSolvPCADenseSklearn)
default = "arpack"
case (skld.PCA, True):
args = get_args(SvdSolvPCASparseSklearn)
args = get_literal_vals(SvdSolvPCASparseSklearn)
default = "arpack"
suffix = " (with sparse input)"
case _:
Expand Down
9 changes: 4 additions & 5 deletions src/scanpy/tools/_draw_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import random
from importlib.util import find_spec
from typing import TYPE_CHECKING, Literal, get_args
from typing import TYPE_CHECKING, Literal

import numpy as np

from .. import _utils
from .. import logging as logg
from .._compat import old_positionals
from .._utils import _choose_graph
from .._utils import _choose_graph, get_literal_vals
from ._utils import get_init_pos_from_paga

if TYPE_CHECKING:
Expand All @@ -24,7 +24,6 @@


_Layout = Literal["fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa"]
_LAYOUTS = get_args(_Layout)


@old_positionals(
Expand Down Expand Up @@ -124,8 +123,8 @@ def draw_graph(
`draw_graph` parameters.
"""
start = logg.info(f"drawing single-cell graph using layout {layout!r}")
if layout not in _LAYOUTS:
raise ValueError(f"Provide a valid layout, one of {_LAYOUTS}.")
if layout not in (layouts := get_literal_vals(_Layout)):
raise ValueError(f"Provide a valid layout, one of {layouts}.")
adata = adata.copy() if copy else adata
if adjacency is None:
adjacency = _choose_graph(adata, obsp, neighbors_key)
Expand Down
8 changes: 4 additions & 4 deletions src/scanpy/tools/_rank_genes_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from math import floor
from typing import TYPE_CHECKING, Literal, get_args
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd
Expand All @@ -14,6 +14,7 @@
from .._compat import old_positionals
from .._utils import (
check_nonnegative_integers,
get_literal_vals,
raise_not_implemented_error_if_backed_type,
)
from ..get import _check_mask
Expand All @@ -28,7 +29,7 @@

_CorrMethod = Literal["benjamini-hochberg", "bonferroni"]

# Used with get_args
# Used with get_literal_vals
_Method = Literal["logreg", "t-test", "wilcoxon", "t-test_overestim_var"]


Expand Down Expand Up @@ -607,8 +608,7 @@ def rank_genes_groups(
rankby_abs = not kwds.pop("only_positive") # backwards compat

start = logg.info("ranking genes")
avail_methods = set(get_args(_Method))
if method not in avail_methods:
if method not in (avail_methods := get_literal_vals(_Method)):
raise ValueError(f"Method must be one of {avail_methods}.")

avail_corr = {"benjamini-hochberg", "bonferroni"}
Expand Down
6 changes: 2 additions & 4 deletions tests/test_aggregated.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import get_args

import anndata as ad
import numpy as np
import pandas as pd
Expand All @@ -10,14 +8,14 @@
from scipy import sparse

import scanpy as sc
from scanpy._utils import _resolve_axis
from scanpy._utils import _resolve_axis, get_literal_vals
from scanpy.get._aggregated import AggType
from testing.scanpy._helpers import assert_equal
from testing.scanpy._helpers.data import pbmc3k_processed
from testing.scanpy._pytest.params import ARRAY_TYPES_MEM


@pytest.fixture(params=get_args(AggType))
@pytest.fixture(params=get_literal_vals(AggType))
def metric(request: pytest.FixtureRequest) -> AggType:
return request.param

Expand Down
Loading

0 comments on commit 0d04447

Please sign in to comment.