Skip to content

Commit

Permalink
adds init from scalar to Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
itcarroll committed Jan 23, 2025
1 parent 70997ef commit ee9f35c
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 8 deletions.
8 changes: 7 additions & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,12 @@ def collect_variables_and_indexes(
indexes = {}

grouped: dict[Hashable, list[MergeElement]] = defaultdict(list)
sizes: dict[Hashable, int] = {
k: v
for i in list_of_mappings
for j in i.values()
for k, v in getattr(j, "sizes", {}).items()
}

def append(name, variable, index):
grouped[name].append((variable, index))
Expand All @@ -355,7 +361,7 @@ def append_all(variables, indexes):
indexes_.pop(name, None)
append_all(coords_, indexes_)

variable = as_variable(variable, name=name, auto_convert=False)
variable = as_variable(variable, name=name, auto_convert=False, sizes=sizes)
if name in indexes:
append(name, variable, indexes[name])
elif variable.dims == (name,):
Expand Down
33 changes: 28 additions & 5 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ class MissingDimensionsError(ValueError):


def as_variable(
obj: T_DuckArray | Any, name=None, auto_convert: bool = True
obj: T_DuckArray | Any,
name=None,
auto_convert: bool = True,
sizes: Mapping | None = None,
) -> Variable | IndexVariable:
"""Convert an object into a Variable.
Expand Down Expand Up @@ -127,24 +130,44 @@ def as_variable(
if isinstance(obj, Variable):
obj = obj.copy(deep=False)
elif isinstance(obj, tuple):
if len(obj) < 2:
obj += (np.nan,)
try:
dims_, data_, *attrs = obj
dims_, data_, *attrs_ = obj
except ValueError as err:
raise ValueError(
f"Tuple {obj} is not in the form (dims, data[, attrs])"
f"Tuple {obj} is not in the form (dims, [data[, attrs[, encoding]]])"
) from err

if isinstance(data_, DataArray):
raise TypeError(
f"Variable {name!r}: Using a DataArray object to construct a variable is"
" ambiguous, please extract the data using the .data property."
)

if utils.is_scalar(data_, include_0d=True):
try:
shape_ = tuple(sizes[i] for i in dims_)
except TypeError as err:
message = (
f"Variable {name!r}: Could not convert tuple of form "
f"(dims, [data, [attrs, [encoding]]]): {obj} to Variable."
)
raise ValueError(message) from err
except KeyError as err:
message = (
f"Variable {name!r}: Provide `coords` with dimension(s) {dims_} to "
f"initialize with `np.full({dims_}, {data_!r})`."
)
raise ValueError(message) from err
data_ = np.full(shape_, data_)

try:
obj = Variable(dims_, data_, *attrs)
obj = Variable(dims_, data_, *attrs_)
except (TypeError, ValueError) as error:
raise error.__class__(
f"Variable {name!r}: Could not convert tuple of form "
f"(dims, data[, attrs, encoding]): {obj} to Variable."
f"(dims, [data, [attrs, [encoding]]]): {obj} to Variable."
) from error
elif utils.is_scalar(obj):
obj = Variable([], obj)
Expand Down
47 changes: 46 additions & 1 deletion xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def test_constructor(self) -> None:

with pytest.raises(ValueError, match=r"conflicting sizes"):
Dataset({"a": x1, "b": x2})
with pytest.raises(TypeError, match=r"tuple of form"):
with pytest.raises(ValueError, match=r"tuple of form"):
Dataset({"x": (1, 2, 3, 4, 5, 6, 7)})
with pytest.raises(ValueError, match=r"already exists as a scalar"):
Dataset({"x": 0, "y": ("x", [1, 2, 3])})
Expand Down Expand Up @@ -527,6 +527,51 @@ class Arbitrary:
actual = Dataset({"x": arg})
assert_identical(expected, actual)

def test_constructor_scalar(self) -> None:
    """Check that a ``(dims[, scalar[, attrs]])`` tuple is filled out to full size.

    The dimension sizes are taken from the ``coords`` passed to ``Dataset``;
    a one-element ``(dims,)`` tuple is expected to behave like a NaN fill.
    """
    fill = np.nan
    x_values = np.arange(2)
    extra_attrs = {"foo": "bar"}

    # a suitable `coords` argument is required
    with pytest.raises(ValueError):
        Dataset({"f": (["x"], fill), "x": x_values})

    # 1d coordinates
    expected = Dataset(
        {"f": DataArray(fill, dims=["x"], coords={"x": x_values})},
    )
    # An explicit fill value and an omitted one must give the same result,
    # whether coords come from an existing Dataset or from a plain dict.
    for trailing in ((fill,), ()):
        for coords in (expected.coords, {"x": x_values}):
            actual = Dataset({"f": (["x"], *trailing)}, coords=coords)
            assert_identical(expected, actual)

    # attrs given in the third tuple slot must be carried onto the variable
    expected["f"].attrs.update(extra_attrs)
    with_attrs = Dataset(
        {"f": (["x"], fill, extra_attrs)},
        coords={"x": x_values},
    )
    assert_identical(expected, with_attrs)

    # 2d coordinates
    dims_2d = ["y", "x"]
    grid = np.arange(6).reshape(2, -1)
    try:
        # TODO(itcarroll): aux coords broken in DataArray from scalar
        scalar_array = DataArray(
            fill, dims=["y", "x"], coords={"lat": (["y", "x"], grid)}
        )
        expected = Dataset({"f": scalar_array})
    except ValueError:
        # Fall back to spelling out the filled array explicitly.
        expected = Dataset(
            data_vars={"f": (["y", "x"], np.full(grid.shape, fill))},
            coords={"lat": (["y", "x"], grid)},
        )
    actual = Dataset({"f": (dims_2d, fill)}, coords=expected.coords)
    assert_identical(expected, actual)

def test_constructor_auto_align(self) -> None:
a = DataArray([1, 2], [("x", [0, 1])])
b = DataArray([3, 4], [("x", [1, 2])])
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,7 @@ def test_as_variable(self):
)
assert_identical(expected_extra, as_variable(xarray_tuple))

with pytest.raises(TypeError, match=r"tuple of form"):
with pytest.raises(ValueError, match=r"tuple of form"):
as_variable(tuple(data))
with pytest.raises(ValueError, match=r"tuple of form"): # GH1016
as_variable(("five", "six", "seven"))
Expand Down

0 comments on commit ee9f35c

Please sign in to comment.