Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PYI lints #3339

Merged
merged 3 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions benchmarks/benchmarks/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
import scanpy as sc

if TYPE_CHECKING:
from collections.abc import Callable, Sequence, Set
from collections.abc import Callable, Sequence
from collections.abc import Set as AbstractSet
from typing import Literal, Protocol, TypeVar

from anndata import AnnData

C = TypeVar("C", bound=Callable)

class ParamSkipper(Protocol):
def __call__(self, **skipped: Set) -> Callable[[C], C]: ...
def __call__(self, **skipped: AbstractSet) -> Callable[[C], C]: ...

Dataset = Literal["pbmc68k_reduced", "pbmc3k", "bmmc", "lung93k"]
KeyX = Literal[None, "off-axis"]
Expand Down Expand Up @@ -195,7 +196,7 @@ def param_skipper(
b 5
"""

def skip(**skipped: Set) -> Callable[[C], C]:
def skip(**skipped: AbstractSet) -> Callable[[C], C]:
skipped_combs = [
tuple(record.values())
for record in (
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ select = [
"TID251", # Banned imports
"ICN", # Follow import conventions
"PTH", # Pathlib instead of os.path
"PYI", # Typing
"PLR0917", # Ban APIs with too many positional parameters
"FBT", # No positional boolean parameters
"PT", # Pytest style
Expand All @@ -246,6 +247,8 @@ ignore = [
"E262",
# allow I, O, l as variable names -> I is the identity matrix, i, j, k, l is reasonable indexing notation
"E741",
# `Literal["..."] | str` is useful for autocompletion
"PYI051",
]
[tool.ruff.lint.per-file-ignores]
# Do not assign a lambda expression, use a def
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# Collected from the print_* functions in matplotlib.backends
_Format = (
Literal["png", "jpg", "tif", "tiff"]
Literal["png", "jpg", "tif", "tiff"] # noqa: PYI030
| Literal["pdf", "ps", "eps", "svg", "svgz", "pgf"]
| Literal["raw", "rgba"]
)
Expand Down Expand Up @@ -340,7 +340,7 @@ def max_memory(self) -> int | float:
return self._max_memory

@max_memory.setter
def max_memory(self, max_memory: int | float):
def max_memory(self, max_memory: float):
_type_check(max_memory, "max_memory", (int, float))
self._max_memory = max_memory

Expand Down
23 changes: 13 additions & 10 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
import re
import sys
import warnings
from collections import namedtuple
from contextlib import contextmanager, suppress
from enum import Enum
from functools import partial, singledispatch, wraps
from operator import mul, truediv
from textwrap import dedent
from types import MethodType, ModuleType
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, Literal, NamedTuple, overload
from weakref import WeakSet

import h5py
Expand Down Expand Up @@ -296,6 +295,11 @@
# --------------------------------------------------------------------------------


class AssoResult(NamedTuple):
asso_names: list[str]
asso_matrix: NDArray[np.floating]


def compute_association_matrix_of_groups(
adata: AnnData,
prediction: str,
Expand All @@ -304,7 +308,7 @@
normalization: Literal["prediction", "reference"] = "prediction",
threshold: float = 0.01,
max_n_names: int | None = 2,
):
) -> AssoResult:
"""Compute overlaps between groups.

See ``identify_groups`` for identifying the groups.
Expand Down Expand Up @@ -346,8 +350,8 @@
f"Ignoring category {cat!r} "
"as it’s in `settings.categories_to_ignore`."
)
asso_names = []
asso_matrix = []
asso_names: list[str] = []
asso_matrix: list[list[float]] = []

Check warning on line 354 in src/scanpy/_utils/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/_utils/__init__.py#L353-L354

Added lines #L353 - L354 were not covered by tests
for ipred_group, pred_group in enumerate(adata.obs[prediction].cat.categories):
if "?" in pred_group:
pred_group = str(ipred_group)
Expand Down Expand Up @@ -380,13 +384,12 @@
if asso_matrix[-1][i] > threshold
]
asso_names += ["\n".join(name_list_pred[:max_n_names])]
Result = namedtuple(
"compute_association_matrix_of_groups", ["asso_names", "asso_matrix"]
)
return Result(asso_names=asso_names, asso_matrix=np.array(asso_matrix))
return AssoResult(asso_names=asso_names, asso_matrix=np.array(asso_matrix))

Check warning on line 387 in src/scanpy/_utils/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/_utils/__init__.py#L387

Added line #L387 was not covered by tests


def get_associated_colors_of_groups(reference_colors, asso_matrix):
def get_associated_colors_of_groups(
reference_colors: Mapping[int, str], asso_matrix: NDArray[np.floating]
) -> list[dict[str, float]]:
return [
{
reference_colors[i_ref]: asso_matrix[i_pred, i_ref]
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Iterator, Mapping, Sequence
from subprocess import CompletedProcess
from typing import Any

Expand Down Expand Up @@ -64,7 +64,7 @@ def __delitem__(self, k: str) -> None:

# These methods retrieve the command list or help with doing it

def __iter__(self) -> Generator[str, None, None]:
def __iter__(self) -> Iterator[str]:
yield from self.parser_map
yield from self.commands

Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def median(self) -> Array:
return np.array(medians)


def _power(X: Array, power: float | int) -> Array:
def _power(X: Array, power: float) -> Array:
"""\
Generate elementwise power of a matrix.

Expand Down
8 changes: 4 additions & 4 deletions src/scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ def scatter(
components: str | Collection[str] | None = None,
projection: Literal["2d", "3d"] = "2d",
legend_loc: _LegendLoc | None = "right margin",
legend_fontsize: int | float | _FontSize | None = None,
legend_fontsize: float | _FontSize | None = None,
legend_fontweight: int | _FontWeight | None = None,
legend_fontoutline: float | None = None,
color_map: str | Colormap | None = None,
palette: Cycler | ListedColormap | ColorLike | Sequence[ColorLike] | None = None,
frameon: bool | None = None,
right_margin: float | None = None,
left_margin: float | None = None,
size: int | float | None = None,
size: float | None = None,
marker: str | Sequence[str] = ".",
title: str | Collection[str] | None = None,
show: bool | None = None,
Expand Down Expand Up @@ -229,15 +229,15 @@ def _scatter_obs(
components: str | Collection[str] | None = None,
projection: Literal["2d", "3d"] = "2d",
legend_loc: _LegendLoc | None = "right margin",
legend_fontsize: int | float | _FontSize | None = None,
legend_fontsize: float | _FontSize | None = None,
legend_fontweight: int | _FontWeight | None = None,
legend_fontoutline: float | None = None,
color_map: str | Colormap | None = None,
palette: Cycler | ListedColormap | ColorLike | Sequence[ColorLike] | None = None,
frameon: bool | None = None,
right_margin: float | None = None,
left_margin: float | None = None,
size: int | float | None = None,
size: float | None = None,
marker: str | Sequence[str] = ".",
title: str | Collection[str] | None = None,
show: bool | None = None,
Expand Down
8 changes: 4 additions & 4 deletions src/scanpy/plotting/_stacked_violin.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def style(
cmap: Colormap | str | None | Empty = _empty,
stripplot: bool | Empty = _empty,
jitter: float | bool | Empty = _empty,
jitter_size: int | float | Empty = _empty,
jitter_size: float | Empty = _empty,
linewidth: float | None | Empty = _empty,
row_palette: str | None | Empty = _empty,
density_norm: DensityNorm | Empty = _empty,
Expand Down Expand Up @@ -470,8 +470,8 @@ def _make_rows_of_violinplots(
_matrix,
colormap_array,
_color_df,
x_spacer_size: float | int,
y_spacer_size: float | int,
x_spacer_size: float,
y_spacer_size: float,
x_axis_order,
):
import seaborn as sns # Slow import, only import if called
Expand Down Expand Up @@ -699,7 +699,7 @@ def stacked_violin(
cmap: Colormap | str | None = StackedViolin.DEFAULT_COLORMAP,
stripplot: bool = StackedViolin.DEFAULT_STRIPPLOT,
jitter: float | bool = StackedViolin.DEFAULT_JITTER,
size: int | float = StackedViolin.DEFAULT_JITTER_SIZE,
size: float = StackedViolin.DEFAULT_JITTER_SIZE,
row_palette: str | None = StackedViolin.DEFAULT_ROW_PALETTE,
density_norm: DensityNorm | Empty = _empty,
yticklabels: bool = StackedViolin.DEFAULT_PLOT_YTICKLABELS,
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/plotting/_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,7 @@ def rank_genes_groups_violin(
split: bool = True,
density_norm: DensityNorm = "width",
strip: bool = True,
jitter: int | float | bool = True,
jitter: float | bool = True,
size: int = 1,
ax: Axes | None = None,
show: bool | None = None,
Expand Down Expand Up @@ -1428,7 +1428,7 @@ def embedding_density(
*,
key: str | None = None,
groupby: str | None = None,
group: str | Sequence[str] | None | None = "all",
group: str | Sequence[str] | None = "all",
color_map: Colormap | str = "YlOrRd",
bg_dotsize: int | None = 80,
fg_dotsize: int | None = 180,
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/plotting/_tools/paga.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def paga_compare(
components=None,
projection: Literal["2d", "3d"] = "2d",
legend_loc: _LegendLoc | None = "on data",
legend_fontsize: int | float | _FontSize | None = None,
legend_fontsize: float | _FontSize | None = None,
legend_fontweight: int | _FontWeight = "bold",
legend_fontoutline=None,
color_map=None,
Expand Down Expand Up @@ -1053,7 +1053,7 @@ def paga_path(
show_node_names: bool = True,
show_yticks: bool = True,
show_colorbar: bool = True,
legend_fontsize: int | float | _FontSize | None = None,
legend_fontsize: float | _FontSize | None = None,
legend_fontweight: int | _FontWeight | None = None,
normalize_to_zero_one: bool = False,
as_heatmap: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/plotting/_tools/scatterplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def embedding(
na_in_legend: bool = True,
size: float | Sequence[float] | None = None,
frameon: bool | None = None,
legend_fontsize: int | float | _FontSize | None = None,
legend_fontsize: float | _FontSize | None = None,
legend_fontweight: int | _FontWeight = "bold",
legend_loc: _LegendLoc | None = "right margin",
legend_fontoutline: int | None = None,
Expand Down
1 change: 0 additions & 1 deletion src/scanpy/preprocessing/_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def scale_array(
| tuple[
np.ndarray | DaskArray, NDArray[np.float64] | DaskArray, NDArray[np.float64]
]
| DaskArray
):
if copy:
X = X.copy()
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/preprocessing/_scrublet/sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

def sparse_multiply(
E: sparse.csr_matrix | sparse.csc_matrix | NDArray[np.float64],
a: float | int | NDArray[np.float64],
a: float | NDArray[np.float64],
) -> sparse.csr_matrix | sparse.csc_matrix:
"""multiply each row of E by a scalar"""

Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/tools/_marker_gene_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from collections.abc import Set
from collections.abc import Set as AbstractSet
from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -187,7 +187,7 @@ def marker_gene_overlap(
if normalize is not None and method != "overlap_count":
raise ValueError("Can only normalize with method=`overlap_count`.")

if not all(isinstance(val, Set) for val in reference_markers.values()):
if not all(isinstance(val, AbstractSet) for val in reference_markers.values()):
try:
reference_markers = {
key: set(val) for key, val in reference_markers.items()
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/tools/_rank_genes_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def filter_rank_genes_groups(
use_raw: bool | None = None,
key_added: str = "rank_genes_groups_filtered",
min_in_group_fraction: float = 0.25,
min_fold_change: int | float = 1,
min_fold_change: float = 1,
max_out_group_fraction: float = 0.5,
compare_abs: bool = False,
) -> None:
Expand Down
6 changes: 3 additions & 3 deletions src/scanpy/tools/_tsne.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ def tsne(
n_pcs: int | None = None,
*,
use_rep: str | None = None,
perplexity: float | int = 30,
perplexity: float = 30,
metric: str = "euclidean",
early_exaggeration: float | int = 12,
learning_rate: float | int = 1000,
early_exaggeration: float = 12,
learning_rate: float = 1000,
random_state: AnyRandom = 0,
use_fast_tsne: bool = False,
n_jobs: int | None = None,
Expand Down
Loading