diff --git a/benchmarks/benchmarks/_utils.py b/benchmarks/benchmarks/_utils.py index 810ace74fd..93bb4623f9 100644 --- a/benchmarks/benchmarks/_utils.py +++ b/benchmarks/benchmarks/_utils.py @@ -14,7 +14,8 @@ import scanpy as sc if TYPE_CHECKING: - from collections.abc import Callable, Sequence, Set + from collections.abc import Callable, Sequence + from collections.abc import Set as AbstractSet from typing import Literal, Protocol, TypeVar from anndata import AnnData @@ -22,7 +23,7 @@ C = TypeVar("C", bound=Callable) class ParamSkipper(Protocol): - def __call__(self, **skipped: Set) -> Callable[[C], C]: ... + def __call__(self, **skipped: AbstractSet) -> Callable[[C], C]: ... Dataset = Literal["pbmc68k_reduced", "pbmc3k", "bmmc", "lung93k"] KeyX = Literal[None, "off-axis"] @@ -195,7 +196,7 @@ def param_skipper( b 5 """ - def skip(**skipped: Set) -> Callable[[C], C]: + def skip(**skipped: AbstractSet) -> Callable[[C], C]: skipped_combs = [ tuple(record.values()) for record in ( diff --git a/pyproject.toml b/pyproject.toml index dda000d790..e983d04a97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -232,6 +232,7 @@ select = [ "TID251", # Banned imports "ICN", # Follow import conventions "PTH", # Pathlib instead of os.path + "PYI", # Typing "PLR0917", # Ban APIs with too many positional parameters "FBT", # No positional boolean parameters "PT", # Pytest style @@ -246,6 +247,8 @@ ignore = [ "E262", # allow I, O, l as variable names -> I is the identity matrix, i, j, k, l is reasonable indexing notation "E741", + # `Literal["..."] | str` is useful for autocompletion + "PYI051", ] [tool.ruff.lint.per-file-ignores] # Do not assign a lambda expression, use a def diff --git a/src/scanpy/_settings.py b/src/scanpy/_settings.py index fa44fc8492..54b51b6420 100644 --- a/src/scanpy/_settings.py +++ b/src/scanpy/_settings.py @@ -19,7 +19,7 @@ # Collected from the print_* functions in matplotlib.backends _Format = ( - Literal["png", "jpg", "tif", "tiff"] + Literal["png", "jpg", "tif", "tiff"] # noqa: PYI030 | Literal["pdf", "ps", "eps", "svg", "svgz", "pgf"] | Literal["raw", "rgba"] ) @@ -340,7 +340,7 @@ def max_memory(self) -> int | float: return self._max_memory @max_memory.setter - def max_memory(self, max_memory: int | float): + def max_memory(self, max_memory: float): _type_check(max_memory, "max_memory", (int, float)) self._max_memory = max_memory diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index 8e886d1ff1..066e23f667 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -12,14 +12,21 @@ import re import sys import warnings -from collections import namedtuple from contextlib import contextmanager, suppress from enum import Enum from functools import partial, reduce, singledispatch, wraps from operator import mul, or_, truediv from textwrap import dedent from types import MethodType, ModuleType, UnionType -from typing import TYPE_CHECKING, Literal, Union, get_args, get_origin, overload +from typing import ( + TYPE_CHECKING, + Literal, + NamedTuple, + Union, + get_args, + get_origin, + overload, +) from weakref import WeakSet import h5py @@ -297,6 +304,11 @@ def get_igraph_from_adjacency(adjacency, directed=None): # -------------------------------------------------------------------------------- +class AssoResult(NamedTuple): + asso_names: list[str] + asso_matrix: NDArray[np.floating] + + def compute_association_matrix_of_groups( adata: AnnData, prediction: str, @@ -305,7 +317,7 @@ def compute_association_matrix_of_groups( normalization: Literal["prediction", "reference"] = "prediction", threshold: float = 0.01, max_n_names: int | None = 2, -): +) -> AssoResult: """Compute overlaps between groups. See ``identify_groups`` for identifying the groups. @@ -347,8 +359,8 @@ def compute_association_matrix_of_groups( f"Ignoring category {cat!r} " "as it’s in `settings.categories_to_ignore`." ) - asso_names = [] - asso_matrix = [] + asso_names: list[str] = [] + asso_matrix: list[list[float]] = [] for ipred_group, pred_group in enumerate(adata.obs[prediction].cat.categories): if "?" in pred_group: pred_group = str(ipred_group) @@ -381,13 +393,12 @@ def compute_association_matrix_of_groups( if asso_matrix[-1][i] > threshold ] asso_names += ["\n".join(name_list_pred[:max_n_names])] - Result = namedtuple( - "compute_association_matrix_of_groups", ["asso_names", "asso_matrix"] - ) - return Result(asso_names=asso_names, asso_matrix=np.array(asso_matrix)) + return AssoResult(asso_names=asso_names, asso_matrix=np.array(asso_matrix)) -def get_associated_colors_of_groups(reference_colors, asso_matrix): +def get_associated_colors_of_groups( + reference_colors: Mapping[int, str], asso_matrix: NDArray[np.floating] +) -> list[dict[str, float]]: return [ { reference_colors[i_ref]: asso_matrix[i_pred, i_ref] diff --git a/src/scanpy/cli.py b/src/scanpy/cli.py index 04b75c8b74..c934292dba 100644 --- a/src/scanpy/cli.py +++ b/src/scanpy/cli.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from collections.abc import Generator, Mapping, Sequence + from collections.abc import Iterator, Mapping, Sequence from subprocess import CompletedProcess from typing import Any @@ -64,7 +64,7 @@ def __delitem__(self, k: str) -> None: # These methods retrieve the command list or help with doing it - def __iter__(self) -> Generator[str, None, None]: + def __iter__(self) -> Iterator[str]: yield from self.parser_map yield from self.commands diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 2d2739491e..13ca54b5c4 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -160,7 +160,7 @@ def median(self) -> Array: return np.array(medians) -def _power(X: Array, power: float | int) -> Array: +def _power(X: Array, power: float) -> Array: """\ Generate elementwise power of a matrix. diff --git a/src/scanpy/plotting/_anndata.py b/src/scanpy/plotting/_anndata.py index 0ae810b2c7..a93d55699b 100755 --- a/src/scanpy/plotting/_anndata.py +++ b/src/scanpy/plotting/_anndata.py @@ -104,7 +104,7 @@ def scatter( components: str | Collection[str] | None = None, projection: Literal["2d", "3d"] = "2d", legend_loc: _LegendLoc | None = "right margin", - legend_fontsize: int | float | _FontSize | None = None, + legend_fontsize: float | _FontSize | None = None, legend_fontweight: int | _FontWeight | None = None, legend_fontoutline: float | None = None, color_map: str | Colormap | None = None, @@ -112,7 +112,7 @@ def scatter( frameon: bool | None = None, right_margin: float | None = None, left_margin: float | None = None, - size: int | float | None = None, + size: float | None = None, marker: str | Sequence[str] = ".", title: str | Collection[str] | None = None, show: bool | None = None, @@ -232,7 +232,7 @@ def _scatter_obs( components: str | Collection[str] | None = None, projection: Literal["2d", "3d"] = "2d", legend_loc: _LegendLoc | None = "right margin", - legend_fontsize: int | float | _FontSize | None = None, + legend_fontsize: float | _FontSize | None = None, legend_fontweight: int | _FontWeight | None = None, legend_fontoutline: float | None = None, color_map: str | Colormap | None = None, @@ -240,7 +240,7 @@ def _scatter_obs( frameon: bool | None = None, right_margin: float | None = None, left_margin: float | None = None, - size: int | float | None = None, + size: float | None = None, marker: str | Sequence[str] = ".", title: str | Collection[str] | None = None, show: bool | None = None, diff --git a/src/scanpy/plotting/_stacked_violin.py b/src/scanpy/plotting/_stacked_violin.py index 691dd863d0..e47680facc 100644 --- a/src/scanpy/plotting/_stacked_violin.py +++ b/src/scanpy/plotting/_stacked_violin.py @@ -273,7 +273,7 @@ def style( cmap: Colormap | str | None | Empty = _empty, stripplot: bool | Empty = _empty, jitter: float | bool | Empty = _empty, - jitter_size: int | float | Empty = _empty, + jitter_size: float | Empty = _empty, linewidth: float | None | Empty = _empty, row_palette: str | None | Empty = _empty, density_norm: DensityNorm | Empty = _empty, @@ -470,8 +470,8 @@ def _make_rows_of_violinplots( _matrix, colormap_array, _color_df, - x_spacer_size: float | int, - y_spacer_size: float | int, + x_spacer_size: float, + y_spacer_size: float, x_axis_order, ): import seaborn as sns # Slow import, only import if called @@ -699,7 +699,7 @@ def stacked_violin( cmap: Colormap | str | None = StackedViolin.DEFAULT_COLORMAP, stripplot: bool = StackedViolin.DEFAULT_STRIPPLOT, jitter: float | bool = StackedViolin.DEFAULT_JITTER, - size: int | float = StackedViolin.DEFAULT_JITTER_SIZE, + size: float = StackedViolin.DEFAULT_JITTER_SIZE, row_palette: str | None = StackedViolin.DEFAULT_ROW_PALETTE, density_norm: DensityNorm | Empty = _empty, yticklabels: bool = StackedViolin.DEFAULT_PLOT_YTICKLABELS, diff --git a/src/scanpy/plotting/_tools/__init__.py b/src/scanpy/plotting/_tools/__init__.py index 837d3791e8..a421f6b94a 100644 --- a/src/scanpy/plotting/_tools/__init__.py +++ b/src/scanpy/plotting/_tools/__init__.py @@ -1209,7 +1209,7 @@ def rank_genes_groups_violin( split: bool = True, density_norm: DensityNorm = "width", strip: bool = True, - jitter: int | float | bool = True, + jitter: float | bool = True, size: int = 1, ax: Axes | None = None, show: bool | None = None, @@ -1428,7 +1428,7 @@ def embedding_density( *, key: str | None = None, groupby: str | None = None, - group: str | Sequence[str] | None | None = "all", + group: str | Sequence[str] | None = "all", color_map: Colormap | str = "YlOrRd", bg_dotsize: int | None = 80, fg_dotsize: int | None = 180, diff --git a/src/scanpy/plotting/_tools/paga.py b/src/scanpy/plotting/_tools/paga.py index 29408735b6..7e62d46eac 100644 --- a/src/scanpy/plotting/_tools/paga.py +++ b/src/scanpy/plotting/_tools/paga.py @@ -73,7 +73,7 @@ def paga_compare( components=None, projection: Literal["2d", "3d"] = "2d", legend_loc: _LegendLoc | None = "on data", - legend_fontsize: int | float | _FontSize | None = None, + legend_fontsize: float | _FontSize | None = None, legend_fontweight: int | _FontWeight = "bold", legend_fontoutline=None, color_map=None, @@ -1053,7 +1053,7 @@ def paga_path( show_node_names: bool = True, show_yticks: bool = True, show_colorbar: bool = True, - legend_fontsize: int | float | _FontSize | None = None, + legend_fontsize: float | _FontSize | None = None, legend_fontweight: int | _FontWeight | None = None, normalize_to_zero_one: bool = False, as_heatmap: bool = True, diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py index 7f69a76025..4ce39f7211 100644 --- a/src/scanpy/plotting/_tools/scatterplots.py +++ b/src/scanpy/plotting/_tools/scatterplots.py @@ -94,7 +94,7 @@ def embedding( na_in_legend: bool = True, size: float | Sequence[float] | None = None, frameon: bool | None = None, - legend_fontsize: int | float | _FontSize | None = None, + legend_fontsize: float | _FontSize | None = None, legend_fontweight: int | _FontWeight = "bold", legend_loc: _LegendLoc | None = "right margin", legend_fontoutline: int | None = None, diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py index be452c356d..a7a16bbcc4 100644 --- a/src/scanpy/preprocessing/_scale.py +++ b/src/scanpy/preprocessing/_scale.py @@ -148,7 +148,6 @@ def scale_array( | tuple[ np.ndarray | DaskArray, NDArray[np.float64] | DaskArray, NDArray[np.float64] ] - | DaskArray ): if copy: X = X.copy() diff --git a/src/scanpy/preprocessing/_scrublet/sparse_utils.py b/src/scanpy/preprocessing/_scrublet/sparse_utils.py index b4ff1a36b0..cc0b1bc815 100644 --- a/src/scanpy/preprocessing/_scrublet/sparse_utils.py +++ b/src/scanpy/preprocessing/_scrublet/sparse_utils.py @@ -17,7 +17,7 @@ def sparse_multiply( E: sparse.csr_matrix | sparse.csc_matrix | NDArray[np.float64], - a: float | int | NDArray[np.float64], + a: float | NDArray[np.float64], ) -> sparse.csr_matrix | sparse.csc_matrix: """multiply each row of E by a scalar""" diff --git a/src/scanpy/tools/_marker_gene_overlap.py b/src/scanpy/tools/_marker_gene_overlap.py index 83a19c86a4..eb07b84885 100644 --- a/src/scanpy/tools/_marker_gene_overlap.py +++ b/src/scanpy/tools/_marker_gene_overlap.py @@ -4,7 +4,7 @@ from __future__ import annotations -from collections.abc import Set +from collections.abc import Set as AbstractSet from typing import TYPE_CHECKING import numpy as np @@ -187,7 +187,7 @@ def marker_gene_overlap( if normalize is not None and method != "overlap_count": raise ValueError("Can only normalize with method=`overlap_count`.") - if not all(isinstance(val, Set) for val in reference_markers.values()): + if not all(isinstance(val, AbstractSet) for val in reference_markers.values()): try: reference_markers = { key: set(val) for key, val in reference_markers.items() diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index f8ab13e9fd..9a2896196a 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -749,7 +749,7 @@ def filter_rank_genes_groups( use_raw: bool | None = None, key_added: str = "rank_genes_groups_filtered", min_in_group_fraction: float = 0.25, - min_fold_change: int | float = 1, + min_fold_change: float = 1, max_out_group_fraction: float = 0.5, compare_abs: bool = False, ) -> None: diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index 23d490218b..ac0e6a6317 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -34,10 +34,10 @@ def tsne( n_pcs: int | None = None, *, use_rep: str | None = None, - perplexity: float | int = 30, + perplexity: float = 30, metric: str = "euclidean", - early_exaggeration: float | int = 12, - learning_rate: float | int = 1000, + early_exaggeration: float = 12, + learning_rate: float = 1000, random_state: AnyRandom = 0, use_fast_tsne: bool = False, n_jobs: int | None = None,