Skip to content

Commit 7c4f77b

Browse files
authored
Merge branch 'main' into cftime_deprecation
2 parents 2fa10f6 + 1189240 commit 7c4f77b

File tree

13 files changed

+265
-156
lines changed

13 files changed

+265
-156
lines changed

.github/workflows/ci-additional.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ jobs:
116116
python xarray/util/print_versions.py
117117
- name: Install mypy
118118
run: |
119-
python -m pip install "mypy==1.13" --force-reinstall
119+
python -m pip install "mypy==1.15" --force-reinstall
120120
121121
- name: Run mypy
122122
run: |
@@ -167,7 +167,7 @@ jobs:
167167
python xarray/util/print_versions.py
168168
- name: Install mypy
169169
run: |
170-
python -m pip install "mypy==1.13" --force-reinstall
170+
python -m pip install "mypy==1.15" --force-reinstall
171171
172172
- name: Run mypy
173173
run: |

doc/whats-new.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ v2025.02.0 (unreleased)
2121

2222
New Features
2323
~~~~~~~~~~~~
24-
24+
- Allow kwargs in :py:meth:`DataTree.map_over_datasets` and :py:func:`map_over_datasets` (:issue:`10009`, :pull:`10012`).
25+
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
2526

2627
Breaking changes
2728
~~~~~~~~~~~~~~~~
@@ -45,6 +46,9 @@ Bug fixes
4546
- Deprecate xr.cftime_range() in favor of xr.date_range(use_cftime=True)
4647
(:issue:`9886`, :pull:`10024`).
4748
By `Josh Kihm <https://github.com/maddogghoek>`_.
49+
- Fix ``DataArray.drop_attrs(deep=False)`` and add support for ``attrs`` to
50+
``DataArray._replace()`` (:issue:`10027`, :pull:`10030`).
51+
By `Jan Haacker <https://github.com/j-haacker>`_.
4852

4953

5054
Documentation

xarray/core/dataarray.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import copy
34
import datetime
45
import warnings
56
from collections.abc import (
@@ -522,6 +523,7 @@ def _replace(
522523
variable: Variable | None = None,
523524
coords=None,
524525
name: Hashable | None | Default = _default,
526+
attrs=_default,
525527
indexes=None,
526528
) -> Self:
527529
if variable is None:
@@ -532,6 +534,11 @@ def _replace(
532534
indexes = self._indexes
533535
if name is _default:
534536
name = self.name
537+
if attrs is _default:
538+
attrs = copy.copy(self.attrs)
539+
else:
540+
variable = variable.copy()
541+
variable.attrs = attrs
535542
return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True)
536543

537544
def _replace_maybe_drop_dims(
@@ -886,7 +893,7 @@ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]:
886893
return dict(zip(self.dims, key, strict=True))
887894

888895
def _getitem_coord(self, key: Any) -> Self:
889-
from xarray.core.dataset import _get_virtual_variable
896+
from xarray.core.dataset_utils import _get_virtual_variable
890897

891898
try:
892899
var = self._coords[key]
@@ -7575,6 +7582,11 @@ def drop_attrs(self, *, deep: bool = True) -> Self:
75757582
-------
75767583
DataArray
75777584
"""
7578-
return (
7579-
self._to_temp_dataset().drop_attrs(deep=deep).pipe(self._from_temp_dataset)
7580-
)
7585+
if not deep:
7586+
return self._replace(attrs={})
7587+
else:
7588+
return (
7589+
self._to_temp_dataset()
7590+
.drop_attrs(deep=deep)
7591+
.pipe(self._from_temp_dataset)
7592+
)

xarray/core/dataset.py

Lines changed: 4 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@
2424
from operator import methodcaller
2525
from os import PathLike
2626
from types import EllipsisType
27-
from typing import IO, TYPE_CHECKING, Any, Generic, Literal, cast, overload
27+
from typing import IO, TYPE_CHECKING, Any, Literal, cast, overload
2828

2929
import numpy as np
3030
from pandas.api.types import is_extension_array_dtype
3131

32+
from xarray.core.dataset_utils import _get_virtual_variable, _LocIndexer
33+
from xarray.core.dataset_variables import DataVariables
34+
3235
# remove once numpy 2.0 is the oldest supported version
3336
try:
3437
from numpy.exceptions import RankWarning
@@ -98,7 +101,6 @@
98101
T_ChunksFreq,
99102
T_DataArray,
100103
T_DataArrayOrSet,
101-
T_Dataset,
102104
ZarrWriteModes,
103105
)
104106
from xarray.core.utils import (
@@ -196,43 +198,6 @@
196198
]
197199

198200

199-
def _get_virtual_variable(
200-
variables, key: Hashable, dim_sizes: Mapping | None = None
201-
) -> tuple[Hashable, Hashable, Variable]:
202-
"""Get a virtual variable (e.g., 'time.year') from a dict of xarray.Variable
203-
objects (if possible)
204-
205-
"""
206-
from xarray.core.dataarray import DataArray
207-
208-
if dim_sizes is None:
209-
dim_sizes = {}
210-
211-
if key in dim_sizes:
212-
data = pd.Index(range(dim_sizes[key]), name=key)
213-
variable = IndexVariable((key,), data)
214-
return key, key, variable
215-
216-
if not isinstance(key, str):
217-
raise KeyError(key)
218-
219-
split_key = key.split(".", 1)
220-
if len(split_key) != 2:
221-
raise KeyError(key)
222-
223-
ref_name, var_name = split_key
224-
ref_var = variables[ref_name]
225-
226-
if _contains_datetime_like_objects(ref_var):
227-
ref_var = DataArray(ref_var)
228-
data = getattr(ref_var.dt, var_name).data
229-
else:
230-
data = getattr(ref_var, var_name).data
231-
virtual_var = Variable(ref_var.dims, data)
232-
233-
return ref_name, var_name, virtual_var
234-
235-
236201
def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint):
237202
"""
238203
Return map from each dim to chunk sizes, accounting for backend's preferred chunks.
@@ -367,19 +332,6 @@ def _maybe_chunk(
367332
return var
368333

369334

370-
def as_dataset(obj: Any) -> Dataset:
371-
"""Cast the given object to a Dataset.
372-
373-
Handles Datasets, DataArrays and dictionaries of variables. A new Dataset
374-
object is only created if the provided object is not already one.
375-
"""
376-
if hasattr(obj, "to_dataset"):
377-
obj = obj.to_dataset()
378-
if not isinstance(obj, Dataset):
379-
obj = Dataset(obj)
380-
return obj
381-
382-
383335
def _get_func_args(func, param_names):
384336
"""Use `inspect.signature` to try accessing `func` args. Otherwise, ensure
385337
they are provided by user.
@@ -468,84 +420,6 @@ def merge_data_and_coords(data_vars: DataVars, coords) -> _MergeResult:
468420
)
469421

470422

471-
class DataVariables(Mapping[Any, "DataArray"]):
472-
__slots__ = ("_dataset",)
473-
474-
def __init__(self, dataset: Dataset):
475-
self._dataset = dataset
476-
477-
def __iter__(self) -> Iterator[Hashable]:
478-
return (
479-
key
480-
for key in self._dataset._variables
481-
if key not in self._dataset._coord_names
482-
)
483-
484-
def __len__(self) -> int:
485-
length = len(self._dataset._variables) - len(self._dataset._coord_names)
486-
assert length >= 0, "something is wrong with Dataset._coord_names"
487-
return length
488-
489-
def __contains__(self, key: Hashable) -> bool:
490-
return key in self._dataset._variables and key not in self._dataset._coord_names
491-
492-
def __getitem__(self, key: Hashable) -> DataArray:
493-
if key not in self._dataset._coord_names:
494-
return self._dataset[key]
495-
raise KeyError(key)
496-
497-
def __repr__(self) -> str:
498-
return formatting.data_vars_repr(self)
499-
500-
@property
501-
def variables(self) -> Mapping[Hashable, Variable]:
502-
all_variables = self._dataset.variables
503-
return Frozen({k: all_variables[k] for k in self})
504-
505-
@property
506-
def dtypes(self) -> Frozen[Hashable, np.dtype]:
507-
"""Mapping from data variable names to dtypes.
508-
509-
Cannot be modified directly, but is updated when adding new variables.
510-
511-
See Also
512-
--------
513-
Dataset.dtype
514-
"""
515-
return self._dataset.dtypes
516-
517-
def _ipython_key_completions_(self):
518-
"""Provide method for the key-autocompletions in IPython."""
519-
return [
520-
key
521-
for key in self._dataset._ipython_key_completions_()
522-
if key not in self._dataset._coord_names
523-
]
524-
525-
526-
class _LocIndexer(Generic[T_Dataset]):
527-
__slots__ = ("dataset",)
528-
529-
def __init__(self, dataset: T_Dataset):
530-
self.dataset = dataset
531-
532-
def __getitem__(self, key: Mapping[Any, Any]) -> T_Dataset:
533-
if not utils.is_dict_like(key):
534-
raise TypeError("can only lookup dictionaries from Dataset.loc")
535-
return self.dataset.sel(key)
536-
537-
def __setitem__(self, key, value) -> None:
538-
if not utils.is_dict_like(key):
539-
raise TypeError(
540-
"can only set locations defined by dictionaries from Dataset.loc."
541-
f" Got: {key}"
542-
)
543-
544-
# set new values
545-
dim_indexers = map_index_queries(self.dataset, key).dim_indexers
546-
self.dataset[dim_indexers] = value
547-
548-
549423
class Dataset(
550424
DataWithCoords,
551425
DatasetAggregations,

xarray/core/dataset_utils.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from __future__ import annotations
2+
3+
import typing
4+
from collections.abc import Hashable, Mapping
5+
from typing import Any, Generic
6+
7+
import pandas as pd
8+
9+
from xarray.core import utils
10+
from xarray.core.common import _contains_datetime_like_objects
11+
from xarray.core.indexing import map_index_queries
12+
from xarray.core.types import T_Dataset
13+
from xarray.core.variable import IndexVariable, Variable
14+
15+
if typing.TYPE_CHECKING:
16+
from xarray.core.dataset import Dataset
17+
18+
19+
class _LocIndexer(Generic[T_Dataset]):
    """Dictionary-keyed indexer that wraps a dataset-like object.

    Reads are forwarded to ``self.dataset.sel``. Writes are translated by
    ``map_index_queries`` into dimension indexers and then assigned through
    ``self.dataset[...] = value``. Both operations reject keys that are not
    dict-like with a ``TypeError``.
    """

    __slots__ = ("dataset",)

    def __init__(self, dataset: T_Dataset):
        self.dataset = dataset

    def __getitem__(self, key: Mapping[Any, Any]) -> T_Dataset:
        """Return ``self.dataset.sel(key)`` for a dict-like ``key``."""
        if utils.is_dict_like(key):
            return self.dataset.sel(key)
        raise TypeError("can only lookup dictionaries from Dataset.loc")

    def __setitem__(self, key, value) -> None:
        """Assign ``value`` at the locations described by a dict-like ``key``."""
        if not utils.is_dict_like(key):
            message = (
                "can only set locations defined by dictionaries from Dataset.loc."
                f" Got: {key}"
            )
            raise TypeError(message)

        # set new values
        query = map_index_queries(self.dataset, key)
        self.dataset[query.dim_indexers] = value
40+
41+
42+
def as_dataset(obj: Any) -> Dataset:
    """Cast the given object to a Dataset.

    Objects exposing a ``to_dataset`` method are converted by calling it;
    anything that is still not a ``Dataset`` afterwards is passed to the
    ``Dataset`` constructor. An object that already is a ``Dataset`` (and has
    no effect from ``to_dataset``) is returned without building a new one.
    """
    # Imported here rather than at module level; the module-level import of
    # Dataset is only performed under TYPE_CHECKING.
    from xarray.core.dataset import Dataset

    candidate = obj.to_dataset() if hasattr(obj, "to_dataset") else obj
    return candidate if isinstance(candidate, Dataset) else Dataset(candidate)
55+
56+
57+
def _get_virtual_variable(
    variables, key: Hashable, dim_sizes: Mapping | None = None
) -> tuple[Hashable, Hashable, Variable]:
    """Get a virtual variable (e.g., 'time.year') from a dict of
    xarray.Variable objects (if possible).

    Parameters
    ----------
    variables
        Mapping indexed by variable name; looked up with the part of ``key``
        before the first ``"."``.
    key : Hashable
        Either a name present in ``dim_sizes`` or a string of the form
        ``"<ref_name>.<var_name>"``.
    dim_sizes : Mapping or None
        Sizes keyed by name. When ``key`` is found here, a default index
        ``range(dim_sizes[key])`` is returned instead of an attribute lookup.

    Returns
    -------
    tuple
        ``(key, key, IndexVariable)`` for a ``dim_sizes`` hit, otherwise
        ``(ref_name, var_name, Variable)``.

    Raises
    ------
    KeyError
        If ``key`` is not in ``dim_sizes`` and is not a string containing a
        ``"."``, or if ``ref_name`` is missing from ``variables``.
    """
    from xarray.core.dataarray import DataArray

    sizes = {} if dim_sizes is None else dim_sizes

    if key in sizes:
        range_index = pd.Index(range(sizes[key]), name=key)
        return key, key, IndexVariable((key,), range_index)

    if not isinstance(key, str):
        raise KeyError(key)

    # Split on the first dot only; a key without any dot is not virtual.
    ref_name, dot, var_name = key.partition(".")
    if not dot:
        raise KeyError(key)

    ref_var = variables[ref_name]

    if _contains_datetime_like_objects(ref_var):
        # Datetime-like data: resolve the attribute through the ``.dt``
        # accessor of a DataArray wrapper.
        ref_var = DataArray(ref_var)
        attr_owner = ref_var.dt
    else:
        attr_owner = ref_var

    data = getattr(attr_owner, var_name).data
    return ref_name, var_name, Variable(ref_var.dims, data)

0 commit comments

Comments
 (0)