Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More precisely type pipe methods #10038

Merged
merged 12 commits into from
Feb 19, 2025
16 changes: 15 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ jobs:
- env: "flaky"
python-version: "3.13"
os: ubuntu-latest
# The mypy tests must be executed using only 1 process in order to guarantee
# predictable mypy output messages for comparison to expectations.
- env: "mypy"
python-version: "3.10"
numprocesses: 1
os: ubuntu-latest
- env: "mypy"
python-version: "3.13"
numprocesses: 1
os: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
Expand All @@ -88,6 +98,10 @@ jobs:
then
echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV
echo "PYTEST_ADDOPTS=-m 'flaky or network' --run-flaky --run-network-tests -W default" >> $GITHUB_ENV
elif [[ "${{ matrix.env }}" == "mypy" ]] ;
then
echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV
echo "PYTEST_ADDOPTS=-n 1 -m 'mypy' --run-mypy -W default" >> $GITHUB_ENV
else
echo "CONDA_ENV_FILE=ci/requirements/${{ matrix.env }}.yml" >> $GITHUB_ENV
fi
Expand Down Expand Up @@ -144,7 +158,7 @@ jobs:
save-always: true

- name: Run tests
run: python -m pytest -n 4
run: python -m pytest -n ${{ matrix.numprocesses || 4 }}
--timeout 180
--cov=xarray
--cov-report=xml
Expand Down
3 changes: 2 additions & 1 deletion ci/minimum_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
"pytest",
"pytest-cov",
"pytest-env",
"pytest-xdist",
"pytest-mypy-plugins",
"pytest-timeout",
"pytest-xdist",
"hypothesis",
]

Expand Down
3 changes: 2 additions & 1 deletion ci/requirements/all-but-dask.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ dependencies:
- pytest
- pytest-cov
- pytest-env
- pytest-xdist
- pytest-mypy-plugins
- pytest-timeout
- pytest-xdist
- rasterio
- scipy
- seaborn
Expand Down
3 changes: 2 additions & 1 deletion ci/requirements/all-but-numba.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ dependencies:
- pytest
- pytest-cov
- pytest-env
- pytest-xdist
- pytest-mypy-plugins
- pytest-timeout
- pytest-xdist
- rasterio
- scipy
- seaborn
Expand Down
3 changes: 2 additions & 1 deletion ci/requirements/bare-minimum.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ dependencies:
- pytest
- pytest-cov
- pytest-env
- pytest-xdist
- pytest-mypy-plugins
- pytest-timeout
- pytest-xdist
- numpy=1.24
- packaging=23.1
- pandas=2.1
14 changes: 13 additions & 1 deletion ci/requirements/environment-3.14.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies:
- opt_einsum
- packaging
- pandas
- pandas-stubs
# - pint>=0.22
- pip
- pooch
Expand All @@ -38,14 +39,25 @@ dependencies:
- pytest
- pytest-cov
- pytest-env
- pytest-xdist
- pytest-mypy-plugins
- pytest-timeout
- pytest-xdist
- rasterio
- scipy
- seaborn
# - sparse
- toolz
- types-colorama
- types-docutils
- types-psutil
- types-Pygments
- types-python-dateutil
- types-pytz
- types-PyYAML
- types-setuptools
- typing_extensions
- zarr
- pip:
- jax # no way to get cpu-only jaxlib from conda if gpu is present
- types-defusedxml
- types-pexpect
15 changes: 14 additions & 1 deletion ci/requirements/environment-windows-3.14.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies:
- numpy
- packaging
- pandas
- pandas-stubs
# - pint>=0.22
- pip
- pre-commit
Expand All @@ -33,12 +34,24 @@ dependencies:
- pytest
- pytest-cov
- pytest-env
- pytest-xdist
- pytest-mypy-plugins
- pytest-timeout
- pytest-xdist
- rasterio
- scipy
- seaborn
# - sparse
- toolz
- types-colorama
- types-docutils
- types-psutil
- types-Pygments
- types-python-dateutil
- types-pytz
- types-PyYAML
- types-setuptools
- typing_extensions
- zarr
- pip:
- types-defusedxml
- types-pexpect
15 changes: 14 additions & 1 deletion ci/requirements/environment-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies:
- numpy
- packaging
- pandas
- pandas-stubs
# - pint>=0.22
- pip
- pre-commit
Expand All @@ -33,12 +34,24 @@ dependencies:
- pytest
- pytest-cov
- pytest-env
- pytest-xdist
- pytest-mypy-plugins
- pytest-timeout
- pytest-xdist
- rasterio
- scipy
- seaborn
- sparse
- toolz
- types-colorama
- types-docutils
- types-psutil
- types-Pygments
- types-python-dateutil
- types-pytz
- types-PyYAML
- types-setuptools
- typing_extensions
- zarr
- pip:
- types-defusedxml
- types-pexpect
14 changes: 13 additions & 1 deletion ci/requirements/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies:
- opt_einsum
- packaging
- pandas
- pandas-stubs
# - pint>=0.22
- pip
- pooch
Expand All @@ -39,14 +40,25 @@ dependencies:
- pytest
- pytest-cov
- pytest-env
- pytest-xdist
- pytest-mypy-plugins
- pytest-timeout
- pytest-xdist
- rasterio
- scipy
- seaborn
- sparse
- toolz
- types-colorama
- types-docutils
- types-psutil
- types-Pygments
- types-python-dateutil
- types-pytz
- types-PyYAML
- types-setuptools
- typing_extensions
- zarr
- pip:
- jax # no way to get cpu-only jaxlib from conda if gpu is present
- types-defusedxml
- types-pexpect
3 changes: 2 additions & 1 deletion ci/requirements/min-all-deps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ dependencies:
- pytest
- pytest-cov
- pytest-env
- pytest-xdist
- pytest-mypy-plugins
- pytest-timeout
- pytest-xdist
- rasterio=1.3
- scipy=1.11
- seaborn=0.13
Expand Down
18 changes: 17 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import pytest


def pytest_addoption(parser):
def pytest_addoption(parser: pytest.Parser):
"""Add command-line flags for pytest."""
parser.addoption("--run-flaky", action="store_true", help="runs flaky tests")
parser.addoption(
"--run-network-tests",
action="store_true",
help="runs tests requiring a network connection",
)
parser.addoption("--run-mypy", action="store_true", help="runs mypy tests")


def pytest_runtest_setup(item):
Expand All @@ -21,6 +22,21 @@ def pytest_runtest_setup(item):
pytest.skip(
"set --run-network-tests to run test requiring an internet connection"
)
if "mypy" in item.keywords and not item.config.getoption("--run-mypy"):
pytest.skip("set --run-mypy option to run mypy tests")


# See https://docs.pytest.org/en/stable/example/markers.html#automatically-adding-markers-based-on-test-names
def pytest_collection_modifyitems(items):
for item in items:
if "mypy" in item.nodeid:
# IMPORTANT: mypy type annotation tests leverage the pytest-mypy-plugins
# plugin, and are thus written in test_*.yml files. As such, there are
# no explicit test functions on which we can apply a pytest.mark.mypy
# decorator. Therefore, we mark them via this name-based, automatic
# marking approach, meaning that each test case must contain "mypy" in the
# name.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, not ideal to use the test name, but v nice comment

item.add_marker(pytest.mark.mypy)


@pytest.fixture(autouse=True)
Expand Down
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ dev = [
"pytest",
"pytest-cov",
"pytest-env",
"pytest-xdist",
"pytest-mypy-plugins",
"pytest-timeout",
"pytest-xdist",
"ruff>=0.8.0",
"sphinx",
"sphinx_autosummary_accessors",
Expand Down Expand Up @@ -304,7 +305,12 @@ known-first-party = ["xarray"]
ban-relative-imports = "all"

[tool.pytest.ini_options]
addopts = ["--strict-config", "--strict-markers"]
addopts = [
"--strict-config",
"--strict-markers",
"--mypy-only-local-stub",
"--mypy-pyproject-toml-file=pyproject.toml",
]

# We want to forbid warnings from within xarray in our tests — instead we should
# fix our own code, or mark the test itself as expecting a warning. So this:
Expand Down Expand Up @@ -361,6 +367,7 @@ filterwarnings = [
log_cli_level = "INFO"
markers = [
"flaky: flaky tests",
"mypy: type annotation tests",
"network: tests requiring a network connection",
"slow: slow tests",
"slow_hypothesis: slow hypothesis tests",
Expand Down
33 changes: 27 additions & 6 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from contextlib import suppress
from html import escape
from textwrap import dedent
from typing import TYPE_CHECKING, Any, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, Union, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -60,6 +60,7 @@
T_Resample = TypeVar("T_Resample", bound="Resample")
C = TypeVar("C")
T = TypeVar("T")
P = ParamSpec("P")


class ImplementsArrayReduce:
Expand Down Expand Up @@ -718,11 +719,27 @@ def assign_attrs(self, *args: Any, **kwargs: Any) -> Self:
out.attrs.update(*args, **kwargs)
return out

@overload
def pipe(
self,
func: Callable[Concatenate[Self, P], T],
*args: P.args,
**kwargs: P.kwargs,
) -> T: ...

@overload
def pipe(
self,
func: Callable[..., T] | tuple[Callable[..., T], str],
func: tuple[Callable[..., T], str],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not this as well?

Suggested change
func: tuple[Callable[..., T], str],
func: tuple[Callable[P, T], str]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot do that because when we pass the function as part of a tuple we don't have enough information to precisely type the function's parameters.

In your suggested change, the ParamSpec P represents all of the function's parameters, but we need to know all of the parameters excluding one (the one that takes the data value, as identified by the name given in the second value of the tuple).

If that's as clear as mud, let me try to clarify with more detail.

In the first form, where we pass only a function as the first argument to pipe, we expect the function to take the data value as the first argument, meaning that we know exactly where in the function's list of parameters the data parameter is: the first position.

This means we can type the function more precisely, like so, indicating that the data parameter is first, concatenated with zero or more positional and keyword parameters (and returning a value of some type T):

# Self is a DataWithCoords (DataArray, Dataset, others?), or DataTree
# P represents all parameters to func *excluding* Self
func: Callable[Concatenate[Self, P], T]

Therefore, after passing func to pipe, we must pass all arguments except for the data (self) argument, and this is represented by P, which excludes Self:

# pipe expects a function followed by all arguments to pass to the function,
# *except* for the data argument, which pipe will *implicitly* pass.
def pipe(f: Callable[Concatenate[Self, P], T], *args: P.args, **kwargs: P.kwargs) -> T: ...

However, when we pass a function/keyword 2-tuple as the first argument to pipe, we have no idea what position the keyword parameter is in func's type signature. We only know that it's not the first parameter.

This means, we cannot do the following, as you suggest, because in this case P includes the data parameter, but it must not, and there's no way of omitting it without knowing precisely what position it's in:

# We don't know where Self is within P, so we have no way of defining P
# as meaning all of func's parameters *except* for Self. Therefore, this
# signature indicates that we can *explicity* pass a data argument, but that's
# not correct.
def pipe(func: tuple[Callable[P, T], str], *args: P.args, **kwargs: P.kwargs) -> T: ...

To clarify, although Callable[P, T] is valid for the function itself within the tuple, using *args: P.args, **kwargs: P.kwargs for the rest of the parameters to pipe in this case is incorrect, because it means that we can explicitly pass the data argument in the mix there (because P includes Self), but the whole point of pipe, of course, is to implicitly pass the data argument, and thus not allow it to be passed explicitly.

Technically speaking, as far as mypy is concerned, your suggestion probably make no difference from what I propose, but in terms of the information it conveys to the reader, it is incorrect.

This is why I discourage the use of the tuple form, and instead recommend the use of a lambda (or another function def with args reordered such that the data arg is first). Even using a lambda to reorder things allows mypy to be more helpful than is possible with the tuple form.

For example (taken from the test cases in xarray/tests/test_dataset_typing.yml in this PR):

    from xarray import Dataset

    def f(arg: int, ds: Dataset) -> Dataset:
        return ds

    # Since we cannot provide a precise type annotation when passing a tuple to
    # pipe, there's not enough information for type analysis to indicate that
    # we are missing an argument for parameter `arg`, so we get no error here.

    ds = Dataset().pipe((f, "ds"))
    reveal_type(ds)  # N: Revealed type is "xarray.core.dataset.Dataset"

    # Rather than passing a tuple, passing a lambda that calls `f` with args in
    # the correct order allows for proper type analysis, indicating (perhaps
    # somewhat cryptically) that we failed to pass an argument for `arg`.

    ds = Dataset().pipe(lambda data, arg: f(arg, data))

    # mypy produces the following output for the line above, as it should,
    # indicating that we forgot to pass an argument to pipe, which pipe needs
    # to pass to `f` in the 2nd position.
    error: No overload variant of "pipe" of "DataWithCoords" matches argument type "Callable[[Any, Any], Dataset]"  [call-overload]
    note: Possible overload variants:
    note:     def [P`9, T] pipe(self, func: Callable[[Dataset, **P], T], *args: P.args, **kwargs: P.kwargs) -> T
    note:     def [T] pipe(self, func: tuple[Callable[..., T], str], *args: Any, **kwargs: Any) -> T

IMHO, the tuple form should not be supported, for this very reason, but I don't expect that deprecating that form would get much traction from anybody else.

*args: Any,
**kwargs: Any,
) -> T: ...

def pipe(
self,
func: Callable[Concatenate[Self, P], T] | tuple[Callable[P, T], str],
*args: P.args,
**kwargs: P.kwargs,
) -> T:
"""
Apply ``func(self, *args, **kwargs)``
Expand Down Expand Up @@ -840,15 +857,19 @@ def pipe(
pandas.DataFrame.pipe
"""
if isinstance(func, tuple):
func, target = func
# Use different var when unpacking function from tuple because the type
# signature of the unpacked function differs from the expected type
# signature in the case where only a function is given, rather than a tuple.
# This makes type checkers happy at both call sites below.
f, target = func
if target in kwargs:
raise ValueError(
f"{target} is both the pipe target and a keyword argument"
)
kwargs[target] = self
return func(*args, **kwargs)
else:
return func(self, *args, **kwargs)
return f(*args, **kwargs)

return func(self, *args, **kwargs)

def rolling_exp(
self: T_DataWithCoords,
Expand Down
Loading
Loading