Change ArrayV3Metadata.data_type to DataType (#2278)
* change v3.metadata.data_type type

* implement suggestions
rabernat authored Oct 4, 2024
1 parent 8dd1f24 commit 01346dd
Showing 2 changed files with 42 additions and 31 deletions.
62 changes: 37 additions & 25 deletions src/zarr/core/metadata/v3.py
@@ -6,8 +6,6 @@
 if TYPE_CHECKING:
     from typing import Self

-    import numpy.typing as npt
-
     from zarr.core.buffer import Buffer, BufferPrototype
     from zarr.core.chunk_grids import ChunkGrid
     from zarr.core.common import JSON, ChunkCoords
@@ -20,6 +18,7 @@

 import numcodecs.abc
 import numpy as np
+import numpy.typing as npt

 from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
 from zarr.core.array_spec import ArraySpec
@@ -31,6 +30,8 @@
 from zarr.core.metadata.common import ArrayMetadata, parse_attributes
 from zarr.registry import get_codec_class

+DEFAULT_DTYPE = "float64"
+

 def parse_zarr_format(data: object) -> Literal[3]:
     if data == 3:
@@ -152,7 +153,7 @@ def _replace_special_floats(obj: object) -> Any:
 @dataclass(frozen=True, kw_only=True)
 class ArrayV3Metadata(ArrayMetadata):
     shape: ChunkCoords
-    data_type: np.dtype[Any]
+    data_type: DataType
     chunk_grid: ChunkGrid
     chunk_key_encoding: ChunkKeyEncoding
     fill_value: Any
@@ -167,7 +168,7 @@ def __init__(
         self,
         *,
         shape: Iterable[int],
-        data_type: npt.DTypeLike,
+        data_type: npt.DTypeLike | DataType,
         chunk_grid: dict[str, JSON] | ChunkGrid,
         chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding,
         fill_value: Any,
@@ -180,18 +181,18 @@ def __init__(
         Because the class is a frozen dataclass, we set attributes using object.__setattr__
         """
         shape_parsed = parse_shapelike(shape)
-        data_type_parsed = parse_dtype(data_type)
+        data_type_parsed = DataType.parse(data_type)
         chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
         chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
         dimension_names_parsed = parse_dimension_names(dimension_names)
-        fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed)
+        fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed.to_numpy())
         attributes_parsed = parse_attributes(attributes)
         codecs_parsed_partial = parse_codecs(codecs)
         storage_transformers_parsed = parse_storage_transformers(storage_transformers)

         array_spec = ArraySpec(
             shape=shape_parsed,
-            dtype=data_type_parsed,
+            dtype=data_type_parsed.to_numpy(),
             fill_value=fill_value_parsed,
             order="C",  # TODO: order is not needed here.
             prototype=default_buffer_prototype(),  # TODO: prototype is not needed here.
@@ -224,11 +225,14 @@ def _validate_metadata(self) -> None:
         if self.fill_value is None:
             raise ValueError("`fill_value` is required.")
         for codec in self.codecs:
-            codec.validate(shape=self.shape, dtype=self.data_type, chunk_grid=self.chunk_grid)
+            codec.validate(
+                shape=self.shape, dtype=self.data_type.to_numpy(), chunk_grid=self.chunk_grid
+            )

     @property
     def dtype(self) -> np.dtype[Any]:
-        return self.data_type
+        """Interpret Zarr dtype as NumPy dtype"""
+        return self.data_type.to_numpy()

     @property
     def ndim(self) -> int:
@@ -266,13 +270,13 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
         _ = parse_node_type_array(_data.pop("node_type"))

         # check that the data_type attribute is valid
-        _ = DataType(_data["data_type"])
+        data_type = DataType.parse(_data.pop("data_type"))

         # dimension_names key is optional, normalize missing to `None`
         _data["dimension_names"] = _data.pop("dimension_names", None)
         # attributes key is optional, normalize missing to `None`
         _data["attributes"] = _data.pop("attributes", None)
-        return cls(**_data)  # type: ignore[arg-type]
+        return cls(**_data, data_type=data_type)  # type: ignore[arg-type]

     def to_dict(self) -> dict[str, JSON]:
         out_dict = super().to_dict()
@@ -490,8 +494,11 @@ def to_numpy_shortname(self) -> str:
         }
         return data_type_to_numpy[self]

+    def to_numpy(self) -> np.dtype[Any]:
+        return np.dtype(self.to_numpy_shortname())
+
     @classmethod
-    def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
+    def from_numpy(cls, dtype: np.dtype[Any]) -> DataType:
         dtype_to_data_type = {
             "|b1": "bool",
             "bool": "bool",
@@ -511,16 +518,21 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
         }
         return DataType[dtype_to_data_type[dtype.str]]

-
-def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
-    try:
-        dtype = np.dtype(data)
-    except (ValueError, TypeError) as e:
-        raise ValueError(f"Invalid V3 data_type: {data}") from e
-    # check that this is a valid v3 data_type
-    try:
-        _ = DataType.from_dtype(dtype)
-    except KeyError as e:
-        raise ValueError(f"Invalid V3 data_type: {dtype}") from e
-
-    return dtype
+    @classmethod
+    def parse(cls, dtype: None | DataType | Any) -> DataType:
+        if dtype is None:
+            # the default dtype
+            return DataType[DEFAULT_DTYPE]
+        if isinstance(dtype, DataType):
+            return dtype
+        else:
+            try:
+                dtype = np.dtype(dtype)
+            except (ValueError, TypeError) as e:
+                raise ValueError(f"Invalid V3 data_type: {dtype}") from e
+            # check that this is a valid v3 data_type
+            try:
+                data_type = DataType.from_numpy(dtype)
+            except KeyError as e:
+                raise ValueError(f"Invalid V3 data_type: {dtype}") from e
+            return data_type
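
For reference, a minimal usage sketch of the new DataType helpers (parse, to_numpy, from_numpy) as they appear in this diff. It is illustrative only, not part of the commit; the uint8 and float64 enum members and their NumPy mappings are assumed from the rest of the enum, which the diff shows only in part.

from zarr.core.metadata.v3 import DataType

dt = DataType.parse("uint8")                     # numpy-style input -> DataType.uint8
assert DataType.parse(dt) is dt                  # a DataType input passes through unchanged
assert DataType.parse(None) is DataType.float64  # None falls back to DEFAULT_DTYPE ("float64")

np_dtype = dt.to_numpy()                         # DataType -> np.dtype("uint8")
assert DataType.from_numpy(np_dtype) is dt       # round-trips back to the same enum member

try:
    DataType.parse("datetime64[s]")              # valid NumPy dtype, but not a Zarr v3 data type
except ValueError as err:
    print(err)                                   # Invalid V3 data_type: datetime64[s]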
11 changes: 5 additions & 6 deletions tests/v3/test_metadata/test_v3.py
@@ -7,7 +7,7 @@
 from zarr.codecs.bytes import BytesCodec
 from zarr.core.buffer import default_buffer_prototype
 from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding
-from zarr.core.metadata.v3 import ArrayV3Metadata
+from zarr.core.metadata.v3 import ArrayV3Metadata, DataType

 if TYPE_CHECKING:
     from collections.abc import Sequence
@@ -22,7 +22,6 @@

 from zarr.core.metadata.v3 import (
     parse_dimension_names,
-    parse_dtype,
     parse_fill_value,
     parse_zarr_format,
 )
@@ -209,7 +208,7 @@ def test_metadata_to_dict(
     storage_transformers: None | tuple[dict[str, JSON]],
 ) -> None:
     shape = (1, 2, 3)
-    data_type = "uint8"
+    data_type = DataType.uint8
     if chunk_grid == "regular":
         cgrid = {"name": "regular", "configuration": {"chunk_shape": (1, 1, 1)}}

@@ -290,7 +289,7 @@ def test_metadata_to_dict(
     # assert result["fill_value"] == fill_value


-async def test_invalid_dtype_raises() -> None:
+def test_invalid_dtype_raises() -> None:
     metadata_dict = {
         "zarr_format": 3,
         "node_type": "array",
@@ -301,14 +300,14 @@ async def test_invalid_dtype_raises() -> None:
         "codecs": (),
         "fill_value": np.datetime64(0, "ns"),
     }
-    with pytest.raises(ValueError, match=r".* is not a valid DataType"):
+    with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"):
         ArrayV3Metadata.from_dict(metadata_dict)


 @pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()])
 def test_parse_invalid_dtype_raises(data):
     with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"):
-        parse_dtype(data)
+        DataType.parse(data)


 @pytest.mark.parametrize(
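
A companion sketch of the effect on ArrayV3Metadata itself: data_type is now stored as the DataType enum while the dtype property still reports a NumPy dtype. The metadata dict below (shape, chunk grid, codec, fill value) is a hypothetical minimal example patterned on these tests, not taken from the commit.

import numpy as np

from zarr.core.metadata.v3 import ArrayV3Metadata, DataType

metadata = ArrayV3Metadata.from_dict(
    {
        "zarr_format": 3,
        "node_type": "array",
        "shape": (8,),
        "data_type": "uint8",  # plain dtype string, normalized via DataType.parse
        "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (4,)}},
        "chunk_key_encoding": {"name": "default", "configuration": {"separator": "/"}},
        "codecs": [{"name": "bytes", "configuration": {"endian": "little"}}],
        "fill_value": 0,
    }
)

assert metadata.data_type is DataType.uint8  # stored as the enum member
assert metadata.dtype == np.dtype("uint8")   # .dtype still returns a NumPy dtype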
