Skip to content

Commit

Permalink
Add PYI lints (#3339)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
flying-sheep and pre-commit-ci[bot] authored Nov 5, 2024
1 parent 0d04447 commit 5c0e89e
Show file tree
Hide file tree
Showing 16 changed files with 53 additions and 39 deletions.
7 changes: 4 additions & 3 deletions benchmarks/benchmarks/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
import scanpy as sc

if TYPE_CHECKING:
from collections.abc import Callable, Sequence, Set
from collections.abc import Callable, Sequence
from collections.abc import Set as AbstractSet
from typing import Literal, Protocol, TypeVar

from anndata import AnnData

C = TypeVar("C", bound=Callable)

class ParamSkipper(Protocol):
def __call__(self, **skipped: Set) -> Callable[[C], C]: ...
def __call__(self, **skipped: AbstractSet) -> Callable[[C], C]: ...

Dataset = Literal["pbmc68k_reduced", "pbmc3k", "bmmc", "lung93k"]
KeyX = Literal[None, "off-axis"]
Expand Down Expand Up @@ -195,7 +196,7 @@ def param_skipper(
b 5
"""

def skip(**skipped: Set) -> Callable[[C], C]:
def skip(**skipped: AbstractSet) -> Callable[[C], C]:
skipped_combs = [
tuple(record.values())
for record in (
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ select = [
"TID251", # Banned imports
"ICN", # Follow import conventions
"PTH", # Pathlib instead of os.path
"PYI", # Typing
"PLR0917", # Ban APIs with too many positional parameters
"FBT", # No positional boolean parameters
"PT", # Pytest style
Expand All @@ -246,6 +247,8 @@ ignore = [
"E262",
# allow I, O, l as variable names -> I is the identity matrix, i, j, k, l is reasonable indexing notation
"E741",
# `Literal["..."] | str` is useful for autocompletion
"PYI051",
]
[tool.ruff.lint.per-file-ignores]
# Do not assign a lambda expression, use a def
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# Collected from the print_* functions in matplotlib.backends
_Format = (
Literal["png", "jpg", "tif", "tiff"]
Literal["png", "jpg", "tif", "tiff"] # noqa: PYI030
| Literal["pdf", "ps", "eps", "svg", "svgz", "pgf"]
| Literal["raw", "rgba"]
)
Expand Down Expand Up @@ -340,7 +340,7 @@ def max_memory(self) -> int | float:
return self._max_memory

@max_memory.setter
def max_memory(self, max_memory: int | float):
def max_memory(self, max_memory: float):
_type_check(max_memory, "max_memory", (int, float))
self._max_memory = max_memory

Expand Down
31 changes: 21 additions & 10 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@
import re
import sys
import warnings
from collections import namedtuple
from contextlib import contextmanager, suppress
from enum import Enum
from functools import partial, reduce, singledispatch, wraps
from operator import mul, or_, truediv
from textwrap import dedent
from types import MethodType, ModuleType, UnionType
from typing import TYPE_CHECKING, Literal, Union, get_args, get_origin, overload
from typing import (
TYPE_CHECKING,
Literal,
NamedTuple,
Union,
get_args,
get_origin,
overload,
)
from weakref import WeakSet

import h5py
Expand Down Expand Up @@ -297,6 +304,11 @@ def get_igraph_from_adjacency(adjacency, directed=None):
# --------------------------------------------------------------------------------


class AssoResult(NamedTuple):
asso_names: list[str]
asso_matrix: NDArray[np.floating]


def compute_association_matrix_of_groups(
adata: AnnData,
prediction: str,
Expand All @@ -305,7 +317,7 @@ def compute_association_matrix_of_groups(
normalization: Literal["prediction", "reference"] = "prediction",
threshold: float = 0.01,
max_n_names: int | None = 2,
):
) -> AssoResult:
"""Compute overlaps between groups.
See ``identify_groups`` for identifying the groups.
Expand Down Expand Up @@ -347,8 +359,8 @@ def compute_association_matrix_of_groups(
f"Ignoring category {cat!r} "
"as it’s in `settings.categories_to_ignore`."
)
asso_names = []
asso_matrix = []
asso_names: list[str] = []
asso_matrix: list[list[float]] = []
for ipred_group, pred_group in enumerate(adata.obs[prediction].cat.categories):
if "?" in pred_group:
pred_group = str(ipred_group)
Expand Down Expand Up @@ -381,13 +393,12 @@ def compute_association_matrix_of_groups(
if asso_matrix[-1][i] > threshold
]
asso_names += ["\n".join(name_list_pred[:max_n_names])]
Result = namedtuple(
"compute_association_matrix_of_groups", ["asso_names", "asso_matrix"]
)
return Result(asso_names=asso_names, asso_matrix=np.array(asso_matrix))
return AssoResult(asso_names=asso_names, asso_matrix=np.array(asso_matrix))


def get_associated_colors_of_groups(reference_colors, asso_matrix):
def get_associated_colors_of_groups(
reference_colors: Mapping[int, str], asso_matrix: NDArray[np.floating]
) -> list[dict[str, float]]:
return [
{
reference_colors[i_ref]: asso_matrix[i_pred, i_ref]
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Iterator, Mapping, Sequence
from subprocess import CompletedProcess
from typing import Any

Expand Down Expand Up @@ -64,7 +64,7 @@ def __delitem__(self, k: str) -> None:

# These methods retrieve the command list or help with doing it

def __iter__(self) -> Generator[str, None, None]:
def __iter__(self) -> Iterator[str]:
yield from self.parser_map
yield from self.commands

Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def median(self) -> Array:
return np.array(medians)


def _power(X: Array, power: float | int) -> Array:
def _power(X: Array, power: float) -> Array:
"""\
Generate elementwise power of a matrix.
Expand Down
8 changes: 4 additions & 4 deletions src/scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ def scatter(
components: str | Collection[str] | None = None,
projection: Literal["2d", "3d"] = "2d",
legend_loc: _LegendLoc | None = "right margin",
legend_fontsize: int | float | _FontSize | None = None,
legend_fontsize: float | _FontSize | None = None,
legend_fontweight: int | _FontWeight | None = None,
legend_fontoutline: float | None = None,
color_map: str | Colormap | None = None,
palette: Cycler | ListedColormap | ColorLike | Sequence[ColorLike] | None = None,
frameon: bool | None = None,
right_margin: float | None = None,
left_margin: float | None = None,
size: int | float | None = None,
size: float | None = None,
marker: str | Sequence[str] = ".",
title: str | Collection[str] | None = None,
show: bool | None = None,
Expand Down Expand Up @@ -232,15 +232,15 @@ def _scatter_obs(
components: str | Collection[str] | None = None,
projection: Literal["2d", "3d"] = "2d",
legend_loc: _LegendLoc | None = "right margin",
legend_fontsize: int | float | _FontSize | None = None,
legend_fontsize: float | _FontSize | None = None,
legend_fontweight: int | _FontWeight | None = None,
legend_fontoutline: float | None = None,
color_map: str | Colormap | None = None,
palette: Cycler | ListedColormap | ColorLike | Sequence[ColorLike] | None = None,
frameon: bool | None = None,
right_margin: float | None = None,
left_margin: float | None = None,
size: int | float | None = None,
size: float | None = None,
marker: str | Sequence[str] = ".",
title: str | Collection[str] | None = None,
show: bool | None = None,
Expand Down
8 changes: 4 additions & 4 deletions src/scanpy/plotting/_stacked_violin.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def style(
cmap: Colormap | str | None | Empty = _empty,
stripplot: bool | Empty = _empty,
jitter: float | bool | Empty = _empty,
jitter_size: int | float | Empty = _empty,
jitter_size: float | Empty = _empty,
linewidth: float | None | Empty = _empty,
row_palette: str | None | Empty = _empty,
density_norm: DensityNorm | Empty = _empty,
Expand Down Expand Up @@ -470,8 +470,8 @@ def _make_rows_of_violinplots(
_matrix,
colormap_array,
_color_df,
x_spacer_size: float | int,
y_spacer_size: float | int,
x_spacer_size: float,
y_spacer_size: float,
x_axis_order,
):
import seaborn as sns # Slow import, only import if called
Expand Down Expand Up @@ -699,7 +699,7 @@ def stacked_violin(
cmap: Colormap | str | None = StackedViolin.DEFAULT_COLORMAP,
stripplot: bool = StackedViolin.DEFAULT_STRIPPLOT,
jitter: float | bool = StackedViolin.DEFAULT_JITTER,
size: int | float = StackedViolin.DEFAULT_JITTER_SIZE,
size: float = StackedViolin.DEFAULT_JITTER_SIZE,
row_palette: str | None = StackedViolin.DEFAULT_ROW_PALETTE,
density_norm: DensityNorm | Empty = _empty,
yticklabels: bool = StackedViolin.DEFAULT_PLOT_YTICKLABELS,
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/plotting/_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,7 @@ def rank_genes_groups_violin(
split: bool = True,
density_norm: DensityNorm = "width",
strip: bool = True,
jitter: int | float | bool = True,
jitter: float | bool = True,
size: int = 1,
ax: Axes | None = None,
show: bool | None = None,
Expand Down Expand Up @@ -1428,7 +1428,7 @@ def embedding_density(
*,
key: str | None = None,
groupby: str | None = None,
group: str | Sequence[str] | None | None = "all",
group: str | Sequence[str] | None = "all",
color_map: Colormap | str = "YlOrRd",
bg_dotsize: int | None = 80,
fg_dotsize: int | None = 180,
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/plotting/_tools/paga.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def paga_compare(
components=None,
projection: Literal["2d", "3d"] = "2d",
legend_loc: _LegendLoc | None = "on data",
legend_fontsize: int | float | _FontSize | None = None,
legend_fontsize: float | _FontSize | None = None,
legend_fontweight: int | _FontWeight = "bold",
legend_fontoutline=None,
color_map=None,
Expand Down Expand Up @@ -1053,7 +1053,7 @@ def paga_path(
show_node_names: bool = True,
show_yticks: bool = True,
show_colorbar: bool = True,
legend_fontsize: int | float | _FontSize | None = None,
legend_fontsize: float | _FontSize | None = None,
legend_fontweight: int | _FontWeight | None = None,
normalize_to_zero_one: bool = False,
as_heatmap: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/plotting/_tools/scatterplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def embedding(
na_in_legend: bool = True,
size: float | Sequence[float] | None = None,
frameon: bool | None = None,
legend_fontsize: int | float | _FontSize | None = None,
legend_fontsize: float | _FontSize | None = None,
legend_fontweight: int | _FontWeight = "bold",
legend_loc: _LegendLoc | None = "right margin",
legend_fontoutline: int | None = None,
Expand Down
1 change: 0 additions & 1 deletion src/scanpy/preprocessing/_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def scale_array(
| tuple[
np.ndarray | DaskArray, NDArray[np.float64] | DaskArray, NDArray[np.float64]
]
| DaskArray
):
if copy:
X = X.copy()
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/preprocessing/_scrublet/sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

def sparse_multiply(
E: sparse.csr_matrix | sparse.csc_matrix | NDArray[np.float64],
a: float | int | NDArray[np.float64],
a: float | NDArray[np.float64],
) -> sparse.csr_matrix | sparse.csc_matrix:
"""multiply each row of E by a scalar"""

Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/tools/_marker_gene_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from collections.abc import Set
from collections.abc import Set as AbstractSet
from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -187,7 +187,7 @@ def marker_gene_overlap(
if normalize is not None and method != "overlap_count":
raise ValueError("Can only normalize with method=`overlap_count`.")

if not all(isinstance(val, Set) for val in reference_markers.values()):
if not all(isinstance(val, AbstractSet) for val in reference_markers.values()):
try:
reference_markers = {
key: set(val) for key, val in reference_markers.items()
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/tools/_rank_genes_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def filter_rank_genes_groups(
use_raw: bool | None = None,
key_added: str = "rank_genes_groups_filtered",
min_in_group_fraction: float = 0.25,
min_fold_change: int | float = 1,
min_fold_change: float = 1,
max_out_group_fraction: float = 0.5,
compare_abs: bool = False,
) -> None:
Expand Down
6 changes: 3 additions & 3 deletions src/scanpy/tools/_tsne.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ def tsne(
n_pcs: int | None = None,
*,
use_rep: str | None = None,
perplexity: float | int = 30,
perplexity: float = 30,
metric: str = "euclidean",
early_exaggeration: float | int = 12,
learning_rate: float | int = 1000,
early_exaggeration: float = 12,
learning_rate: float = 1000,
random_state: AnyRandom = 0,
use_fast_tsne: bool = False,
n_jobs: int | None = None,
Expand Down

0 comments on commit 5c0e89e

Please sign in to comment.