From 933c391ac4fd698749ab71489e224ea6e7370d3a Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 5 Nov 2024 10:52:23 +0100 Subject: [PATCH 1/2] Add PYI lints --- benchmarks/benchmarks/_utils.py | 7 +++--- pyproject.toml | 3 +++ src/scanpy/_settings.py | 4 ++-- src/scanpy/_utils/__init__.py | 23 +++++++++++-------- src/scanpy/cli.py | 4 ++-- src/scanpy/get/_aggregated.py | 2 +- src/scanpy/plotting/_anndata.py | 8 +++---- src/scanpy/plotting/_stacked_violin.py | 8 +++---- src/scanpy/plotting/_tools/__init__.py | 4 ++-- src/scanpy/plotting/_tools/paga.py | 4 ++-- src/scanpy/plotting/_tools/scatterplots.py | 2 +- src/scanpy/preprocessing/_scale.py | 1 - .../preprocessing/_scrublet/sparse_utils.py | 2 +- src/scanpy/tools/_marker_gene_overlap.py | 4 ++-- src/scanpy/tools/_rank_genes_groups.py | 2 +- src/scanpy/tools/_tsne.py | 6 ++--- 16 files changed, 45 insertions(+), 39 deletions(-) diff --git a/benchmarks/benchmarks/_utils.py b/benchmarks/benchmarks/_utils.py index 810ace74fd..93bb4623f9 100644 --- a/benchmarks/benchmarks/_utils.py +++ b/benchmarks/benchmarks/_utils.py @@ -14,7 +14,8 @@ import scanpy as sc if TYPE_CHECKING: - from collections.abc import Callable, Sequence, Set + from collections.abc import Callable, Sequence + from collections.abc import Set as AbstractSet from typing import Literal, Protocol, TypeVar from anndata import AnnData @@ -22,7 +23,7 @@ C = TypeVar("C", bound=Callable) class ParamSkipper(Protocol): - def __call__(self, **skipped: Set) -> Callable[[C], C]: ... + def __call__(self, **skipped: AbstractSet) -> Callable[[C], C]: ... Dataset = Literal["pbmc68k_reduced", "pbmc3k", "bmmc", "lung93k"] KeyX = Literal[None, "off-axis"] @@ -195,7 +196,7 @@ def param_skipper( b 5 """ - def skip(**skipped: Set) -> Callable[[C], C]: + def skip(**skipped: AbstractSet) -> Callable[[C], C]: skipped_combs = [ tuple(record.values()) for record in ( diff --git a/pyproject.toml b/pyproject.toml index dda000d790..e983d04a97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -232,6 +232,7 @@ select = [ "TID251", # Banned imports "ICN", # Follow import conventions "PTH", # Pathlib instead of os.path + "PYI", # Typing "PLR0917", # Ban APIs with too many positional parameters "FBT", # No positional boolean parameters "PT", # Pytest style @@ -246,6 +247,8 @@ ignore = [ "E262", # allow I, O, l as variable names -> I is the identity matrix, i, j, k, l is reasonable indexing notation "E741", + # `Literal["..."] | str` is useful for autocompletion + "PYI051", ] [tool.ruff.lint.per-file-ignores] # Do not assign a lambda expression, use a def diff --git a/src/scanpy/_settings.py b/src/scanpy/_settings.py index fa44fc8492..54b51b6420 100644 --- a/src/scanpy/_settings.py +++ b/src/scanpy/_settings.py @@ -19,7 +19,7 @@ # Collected from the print_* functions in matplotlib.backends _Format = ( - Literal["png", "jpg", "tif", "tiff"] + Literal["png", "jpg", "tif", "tiff"] # noqa: PYI030 | Literal["pdf", "ps", "eps", "svg", "svgz", "pgf"] | Literal["raw", "rgba"] ) @@ -340,7 +340,7 @@ def max_memory(self) -> int | float: return self._max_memory @max_memory.setter - def max_memory(self, max_memory: int | float): + def max_memory(self, max_memory: float): _type_check(max_memory, "max_memory", (int, float)) self._max_memory = max_memory diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index 5a8b0288b8..2b1b3efcc5 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -12,14 +12,13 @@ import re import sys import warnings -from collections import namedtuple from contextlib import contextmanager, suppress from enum import Enum from functools import partial, singledispatch, wraps from operator import mul, truediv from textwrap import dedent from types import MethodType, ModuleType -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Literal, NamedTuple, overload from weakref import WeakSet import h5py @@ -296,6 +295,11 @@ def get_igraph_from_adjacency(adjacency, directed=None): # -------------------------------------------------------------------------------- +class AssoResult(NamedTuple): + asso_names: list[str] + asso_matrix: NDArray[np.floating] + + def compute_association_matrix_of_groups( adata: AnnData, prediction: str, @@ -304,7 +308,7 @@ def compute_association_matrix_of_groups( normalization: Literal["prediction", "reference"] = "prediction", threshold: float = 0.01, max_n_names: int | None = 2, -): +) -> AssoResult: """Compute overlaps between groups. See ``identify_groups`` for identifying the groups. @@ -346,8 +350,8 @@ def compute_association_matrix_of_groups( f"Ignoring category {cat!r} " "as it’s in `settings.categories_to_ignore`." ) - asso_names = [] - asso_matrix = [] + asso_names: list[str] = [] + asso_matrix: list[list[float]] = [] for ipred_group, pred_group in enumerate(adata.obs[prediction].cat.categories): if "?" in pred_group: pred_group = str(ipred_group) @@ -380,13 +384,12 @@ def compute_association_matrix_of_groups( if asso_matrix[-1][i] > threshold ] asso_names += ["\n".join(name_list_pred[:max_n_names])] - Result = namedtuple( - "compute_association_matrix_of_groups", ["asso_names", "asso_matrix"] - ) - return Result(asso_names=asso_names, asso_matrix=np.array(asso_matrix)) + return AssoResult(asso_names=asso_names, asso_matrix=np.array(asso_matrix)) -def get_associated_colors_of_groups(reference_colors, asso_matrix): +def get_associated_colors_of_groups( + reference_colors: Mapping[int, str], asso_matrix: NDArray[np.floating] +) -> list[dict[str, float]]: return [ { reference_colors[i_ref]: asso_matrix[i_pred, i_ref] diff --git a/src/scanpy/cli.py b/src/scanpy/cli.py index 04b75c8b74..c934292dba 100644 --- a/src/scanpy/cli.py +++ b/src/scanpy/cli.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from collections.abc import Generator, Mapping, Sequence + from collections.abc import Iterator, Mapping, Sequence from subprocess import CompletedProcess from typing import Any @@ -64,7 +64,7 @@ def __delitem__(self, k: str) -> None: # These methods retrieve the command list or help with doing it - def __iter__(self) -> Generator[str, None, None]: + def __iter__(self) -> Iterator[str]: yield from self.parser_map yield from self.commands diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index e95fedf9dc..dc1d5343f8 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -160,7 +160,7 @@ def median(self) -> Array: return np.array(medians) -def _power(X: Array, power: float | int) -> Array: +def _power(X: Array, power: float) -> Array: """\ Generate elementwise power of a matrix. diff --git a/src/scanpy/plotting/_anndata.py b/src/scanpy/plotting/_anndata.py index c1918878c8..b26715040a 100755 --- a/src/scanpy/plotting/_anndata.py +++ b/src/scanpy/plotting/_anndata.py @@ -101,7 +101,7 @@ def scatter( components: str | Collection[str] | None = None, projection: Literal["2d", "3d"] = "2d", legend_loc: _LegendLoc | None = "right margin", - legend_fontsize: int | float | _FontSize | None = None, + legend_fontsize: float | _FontSize | None = None, legend_fontweight: int | _FontWeight | None = None, legend_fontoutline: float | None = None, color_map: str | Colormap | None = None, @@ -109,7 +109,7 @@ def scatter( frameon: bool | None = None, right_margin: float | None = None, left_margin: float | None = None, - size: int | float | None = None, + size: float | None = None, marker: str | Sequence[str] = ".", title: str | Collection[str] | None = None, show: bool | None = None, @@ -229,7 +229,7 @@ def _scatter_obs( components: str | Collection[str] | None = None, projection: Literal["2d", "3d"] = "2d", legend_loc: _LegendLoc | None = "right margin", - legend_fontsize: int | float | _FontSize | None = None, + legend_fontsize: float | _FontSize | None = None, legend_fontweight: int | _FontWeight | None = None, legend_fontoutline: float | None = None, color_map: str | Colormap | None = None, @@ -237,7 +237,7 @@ def _scatter_obs( frameon: bool | None = None, right_margin: float | None = None, left_margin: float | None = None, - size: int | float | None = None, + size: float | None = None, marker: str | Sequence[str] = ".", title: str | Collection[str] | None = None, show: bool | None = None, diff --git a/src/scanpy/plotting/_stacked_violin.py b/src/scanpy/plotting/_stacked_violin.py index 691dd863d0..e47680facc 100644 --- a/src/scanpy/plotting/_stacked_violin.py +++ b/src/scanpy/plotting/_stacked_violin.py @@ -273,7 +273,7 @@ def style( cmap: Colormap | str | None | Empty = _empty, stripplot: bool | Empty = _empty, jitter: float | bool | Empty = _empty, - jitter_size: int | float | Empty = _empty, + jitter_size: float | Empty = _empty, linewidth: float | None | Empty = _empty, row_palette: str | None | Empty = _empty, density_norm: DensityNorm | Empty = _empty, @@ -470,8 +470,8 @@ def _make_rows_of_violinplots( _matrix, colormap_array, _color_df, - x_spacer_size: float | int, - y_spacer_size: float | int, + x_spacer_size: float, + y_spacer_size: float, x_axis_order, ): import seaborn as sns # Slow import, only import if called @@ -699,7 +699,7 @@ def stacked_violin( cmap: Colormap | str | None = StackedViolin.DEFAULT_COLORMAP, stripplot: bool = StackedViolin.DEFAULT_STRIPPLOT, jitter: float | bool = StackedViolin.DEFAULT_JITTER, - size: int | float = StackedViolin.DEFAULT_JITTER_SIZE, + size: float = StackedViolin.DEFAULT_JITTER_SIZE, row_palette: str | None = StackedViolin.DEFAULT_ROW_PALETTE, density_norm: DensityNorm | Empty = _empty, yticklabels: bool = StackedViolin.DEFAULT_PLOT_YTICKLABELS, diff --git a/src/scanpy/plotting/_tools/__init__.py b/src/scanpy/plotting/_tools/__init__.py index 837d3791e8..a421f6b94a 100644 --- a/src/scanpy/plotting/_tools/__init__.py +++ b/src/scanpy/plotting/_tools/__init__.py @@ -1209,7 +1209,7 @@ def rank_genes_groups_violin( split: bool = True, density_norm: DensityNorm = "width", strip: bool = True, - jitter: int | float | bool = True, + jitter: float | bool = True, size: int = 1, ax: Axes | None = None, show: bool | None = None, @@ -1428,7 +1428,7 @@ def embedding_density( *, key: str | None = None, groupby: str | None = None, - group: str | Sequence[str] | None | None = "all", + group: str | Sequence[str] | None = "all", color_map: Colormap | str = "YlOrRd", bg_dotsize: int | None = 80, fg_dotsize: int | None = 180, diff --git a/src/scanpy/plotting/_tools/paga.py b/src/scanpy/plotting/_tools/paga.py index 29408735b6..7e62d46eac 100644 --- a/src/scanpy/plotting/_tools/paga.py +++ b/src/scanpy/plotting/_tools/paga.py @@ -73,7 +73,7 @@ def paga_compare( components=None, projection: Literal["2d", "3d"] = "2d", legend_loc: _LegendLoc | None = "on data", - legend_fontsize: int | float | _FontSize | None = None, + legend_fontsize: float | _FontSize | None = None, legend_fontweight: int | _FontWeight = "bold", legend_fontoutline=None, color_map=None, @@ -1053,7 +1053,7 @@ def paga_path( show_node_names: bool = True, show_yticks: bool = True, show_colorbar: bool = True, - legend_fontsize: int | float | _FontSize | None = None, + legend_fontsize: float | _FontSize | None = None, legend_fontweight: int | _FontWeight | None = None, normalize_to_zero_one: bool = False, as_heatmap: bool = True, diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py index 7f69a76025..4ce39f7211 100644 --- a/src/scanpy/plotting/_tools/scatterplots.py +++ b/src/scanpy/plotting/_tools/scatterplots.py @@ -94,7 +94,7 @@ def embedding( na_in_legend: bool = True, size: float | Sequence[float] | None = None, frameon: bool | None = None, - legend_fontsize: int | float | _FontSize | None = None, + legend_fontsize: float | _FontSize | None = None, legend_fontweight: int | _FontWeight = "bold", legend_loc: _LegendLoc | None = "right margin", legend_fontoutline: int | None = None, diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py index be452c356d..a7a16bbcc4 100644 --- a/src/scanpy/preprocessing/_scale.py +++ b/src/scanpy/preprocessing/_scale.py @@ -148,7 +148,6 @@ def scale_array( | tuple[ np.ndarray | DaskArray, NDArray[np.float64] | DaskArray, NDArray[np.float64] ] - | DaskArray ): if copy: X = X.copy() diff --git a/src/scanpy/preprocessing/_scrublet/sparse_utils.py b/src/scanpy/preprocessing/_scrublet/sparse_utils.py index b4ff1a36b0..cc0b1bc815 100644 --- a/src/scanpy/preprocessing/_scrublet/sparse_utils.py +++ b/src/scanpy/preprocessing/_scrublet/sparse_utils.py @@ -17,7 +17,7 @@ def sparse_multiply( E: sparse.csr_matrix | sparse.csc_matrix | NDArray[np.float64], - a: float | int | NDArray[np.float64], + a: float | NDArray[np.float64], ) -> sparse.csr_matrix | sparse.csc_matrix: """multiply each row of E by a scalar""" diff --git a/src/scanpy/tools/_marker_gene_overlap.py b/src/scanpy/tools/_marker_gene_overlap.py index 83a19c86a4..eb07b84885 100644 --- a/src/scanpy/tools/_marker_gene_overlap.py +++ b/src/scanpy/tools/_marker_gene_overlap.py @@ -4,7 +4,7 @@ from __future__ import annotations -from collections.abc import Set +from collections.abc import Set as AbstractSet from typing import TYPE_CHECKING import numpy as np @@ -187,7 +187,7 @@ def marker_gene_overlap( if normalize is not None and method != "overlap_count": raise ValueError("Can only normalize with method=`overlap_count`.") - if not all(isinstance(val, Set) for val in reference_markers.values()): + if not all(isinstance(val, AbstractSet) for val in reference_markers.values()): try: reference_markers = { key: set(val) for key, val in reference_markers.items() diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 3a737bb487..7773148702 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -749,7 +749,7 @@ def filter_rank_genes_groups( use_raw: bool | None = None, key_added: str = "rank_genes_groups_filtered", min_in_group_fraction: float = 0.25, - min_fold_change: int | float = 1, + min_fold_change: float = 1, max_out_group_fraction: float = 0.5, compare_abs: bool = False, ) -> None: diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index 23d490218b..ac0e6a6317 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -34,10 +34,10 @@ def tsne( n_pcs: int | None = None, *, use_rep: str | None = None, - perplexity: float | int = 30, + perplexity: float = 30, metric: str = "euclidean", - early_exaggeration: float | int = 12, - learning_rate: float | int = 1000, + early_exaggeration: float = 12, + learning_rate: float = 1000, random_state: AnyRandom = 0, use_fast_tsne: bool = False, n_jobs: int | None = None, From 41666cffcc228a2ebe0c1837e87c074c5d097367 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Nov 2024 16:32:01 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scanpy/_utils/__init__.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index f0fb3aa89e..066e23f667 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -17,10 +17,16 @@ from functools import partial, reduce, singledispatch, wraps from operator import mul, or_, truediv from textwrap import dedent -from types import MethodType, ModuleType -from typing import TYPE_CHECKING, Literal, NamedTuple, overload from types import MethodType, ModuleType, UnionType -from typing import TYPE_CHECKING, Literal, Union, get_args, get_origin, overload +from typing import ( + TYPE_CHECKING, + Literal, + NamedTuple, + Union, + get_args, + get_origin, + overload, +) from weakref import WeakSet import h5py