Skip to content

Commit

Permalink
Add type annotations to DataTree.pipe tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chuckwondo committed Feb 7, 2025
1 parent 1541378 commit 56f9e4c
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import sys
import typing
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from copy import copy, deepcopy
from textwrap import dedent

Expand Down Expand Up @@ -1589,27 +1589,53 @@ def test_assign(self) -> None:


class TestPipe:
def test_noop(self, create_test_datatree) -> None:
def test_noop(self, create_test_datatree: Callable[[], DataTree]) -> None:
dt = create_test_datatree()

actual = dt.pipe(lambda tree: tree)
assert actual.identical(dt)

def test_params(self, create_test_datatree) -> None:
def test_args(self, create_test_datatree: Callable[[], DataTree]) -> None:
dt = create_test_datatree()

def f(tree, **attrs):
return tree.assign(arr_with_attrs=xr.Variable("dim0", [], attrs=attrs))
def f(tree: DataTree, x: int, y: int) -> DataTree:
return tree.assign(
arr_with_attrs=xr.Variable("dim0", [], attrs=dict(x=x, y=y))
)

actual = dt.pipe(f, 1, 2)
assert actual["arr_with_attrs"].attrs == dict(x=1, y=2)

def test_kwargs(self, create_test_datatree: Callable[[], DataTree]) -> None:
dt = create_test_datatree()

def f(tree: DataTree, *, x: int, y: int, z: int) -> DataTree:
return tree.assign(
arr_with_attrs=xr.Variable("dim0", [], attrs=dict(x=x, y=y, z=z))
)

attrs = {"x": 1, "y": 2, "z": 3}

actual = dt.pipe(f, **attrs)
assert actual["arr_with_attrs"].attrs == attrs

def test_named_self(self, create_test_datatree) -> None:
def test_args_kwargs(self, create_test_datatree: Callable[[], DataTree]) -> None:
dt = create_test_datatree()

def f(tree: DataTree, x: int, *, y: int, z: int) -> DataTree:
return tree.assign(
arr_with_attrs=xr.Variable("dim0", [], attrs=dict(x=x, y=y, z=z))
)

attrs = {"x": 1, "y": 2, "z": 3}

actual = dt.pipe(f, attrs["x"], y=attrs["y"], z=attrs["z"])
assert actual["arr_with_attrs"].attrs == attrs

def test_named_self(self, create_test_datatree: Callable[[], DataTree]) -> None:
dt = create_test_datatree()

def f(x, tree, y):
def f(x: int, tree: DataTree, y: int):
tree.attrs.update({"x": x, "y": y})
return tree

Expand Down

0 comments on commit 56f9e4c

Please sign in to comment.