Skip to content

Commit

Permalink
Basic Zarr-python 2.x compatibility changes (#2098)
Browse files Browse the repository at this point in the history
* WIP - backwards compat

* fixup put

* rm consolidated

* typing fixup

* revert unneded change

* fixup

* deprecate positional args

* attribute

* Fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* ci

* fixup

* fixup

---------

Co-authored-by: Joe Hamman <[email protected]>
  • Loading branch information
TomAugspurger and jhamman authored Sep 20, 2024
1 parent c878da2 commit 6900754
Show file tree
Hide file tree
Showing 15 changed files with 303 additions and 61 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ test = [
"flask",
"requests",
"mypy",
"hypothesis"
"hypothesis",
"universal-pathlib",
]

jupyter = [
Expand Down
68 changes: 68 additions & 0 deletions src/zarr/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import warnings
from collections.abc import Callable
from functools import wraps
from inspect import Parameter, signature
from typing import Any, TypeVar

T = TypeVar("T")

# Based off https://github.com/scikit-learn/scikit-learn/blob/e87b32a81c70abed8f2e97483758eb64df8255e9/sklearn/utils/validation.py#L63


def _deprecate_positional_args(
func: Callable[..., T] | None = None, *, version: str = "3.1.0"
) -> Callable[..., T]:
"""Decorator for methods that issues warnings for positional arguments.
Using the keyword-only argument syntax in pep 3102, arguments after the
* will issue a warning when passed as a positional argument.
Parameters
----------
func : callable, default=None
Function to check arguments on.
version : callable, default="3.1.0"
The version when positional arguments will result in error.
"""

def _inner_deprecate_positional_args(f: Callable[..., T]) -> Callable[..., T]:
sig = signature(f)
kwonly_args = []
all_args = []

for name, param in sig.parameters.items():
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
all_args.append(name)
elif param.kind == Parameter.KEYWORD_ONLY:
kwonly_args.append(name)

@wraps(f)
def inner_f(*args: Any, **kwargs: Any) -> T:
extra_args = len(args) - len(all_args)
if extra_args <= 0:
return f(*args, **kwargs)

# extra_args > 0
args_msg = [
f"{name}={arg}"
for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:], strict=False)
]
formatted_args_msg = ", ".join(args_msg)
warnings.warn(
(
f"Pass {formatted_args_msg} as keyword args. From version "
f"{version} passing these as positional arguments "
"will result in an error"
),
FutureWarning,
stacklevel=2,
)
kwargs.update(zip(sig.parameters, args, strict=False))
return f(**kwargs)

return inner_f

if func is not None:
return _inner_deprecate_positional_args(func)

return _inner_deprecate_positional_args # type: ignore[return-value]
6 changes: 3 additions & 3 deletions src/zarr/api/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ async def group(
try:
return await AsyncGroup.open(store=store_path, zarr_format=zarr_format)
except (KeyError, FileNotFoundError):
return await AsyncGroup.create(
return await AsyncGroup.from_store(
store=store_path,
zarr_format=zarr_format or _default_zarr_version(),
exists_ok=overwrite,
Expand All @@ -512,8 +512,8 @@ async def group(


async def open_group(
*, # Note: this is a change from v2
store: StoreLike | None = None,
*, # Note: this is a change from v2
mode: AccessModeLiteral | None = None,
cache_attrs: bool | None = None, # not used, default changed
synchronizer: Any = None, # not used
Expand Down Expand Up @@ -590,7 +590,7 @@ async def open_group(
try:
return await AsyncGroup.open(store_path, zarr_format=zarr_format)
except (KeyError, FileNotFoundError):
return await AsyncGroup.create(
return await AsyncGroup.from_store(
store_path,
zarr_format=zarr_format or _default_zarr_version(),
exists_ok=True,
Expand Down
11 changes: 8 additions & 3 deletions src/zarr/api/synchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any

import zarr.api.asynchronous as async_api
from zarr._compat import _deprecate_positional_args
from zarr.core.array import Array, AsyncArray
from zarr.core.group import Group
from zarr.core.sync import sync
Expand Down Expand Up @@ -63,9 +64,10 @@ def load(
return sync(async_api.load(store=store, zarr_version=zarr_version, path=path))


@_deprecate_positional_args
def open(
*,
store: StoreLike | None = None,
*,
mode: AccessModeLiteral | None = None, # type and value changed
zarr_version: ZarrFormat | None = None, # deprecated
zarr_format: ZarrFormat | None = None,
Expand Down Expand Up @@ -107,6 +109,7 @@ def save(
)


@_deprecate_positional_args
def save_array(
store: StoreLike,
arr: NDArrayLike,
Expand Down Expand Up @@ -159,9 +162,10 @@ def array(data: NDArrayLike, **kwargs: Any) -> Array:
return Array(sync(async_api.array(data=data, **kwargs)))


@_deprecate_positional_args
def group(
*, # Note: this is a change from v2
store: StoreLike | None = None,
*, # Note: this is a change from v2
overwrite: bool = False,
chunk_store: StoreLike | None = None, # not used in async_api
cache_attrs: bool | None = None, # default changed, not used in async_api
Expand Down Expand Up @@ -190,9 +194,10 @@ def group(
)


@_deprecate_positional_args
def open_group(
*, # Note: this is a change from v2
store: StoreLike | None = None,
*, # Note: this is a change from v2
mode: AccessModeLiteral | None = None, # not used in async api
cache_attrs: bool | None = None, # default changed, not used in async api
synchronizer: Any = None, # not used in async api
Expand Down
13 changes: 13 additions & 0 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import numpy as np
import numpy.typing as npt

from zarr._compat import _deprecate_positional_args
from zarr.abc.codec import Codec, CodecPipeline
from zarr.abc.store import set_or_delete
from zarr.codecs import BytesCodec
from zarr.codecs._v2 import V2Compressor, V2Filters
Expand Down Expand Up @@ -621,6 +623,7 @@ class Array:
_async_array: AsyncArray

@classmethod
@_deprecate_positional_args
def create(
cls,
store: StoreLike,
Expand Down Expand Up @@ -1016,6 +1019,7 @@ def __setitem__(self, selection: Selection, value: npt.ArrayLike) -> None:
else:
self.set_basic_selection(cast(BasicSelection, pure_selection), value, fields=fields)

@_deprecate_positional_args
def get_basic_selection(
self,
selection: BasicSelection = Ellipsis,
Expand Down Expand Up @@ -1139,6 +1143,7 @@ def get_basic_selection(
)
)

@_deprecate_positional_args
def set_basic_selection(
self,
selection: BasicSelection,
Expand Down Expand Up @@ -1234,6 +1239,7 @@ def set_basic_selection(
indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid)
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

@_deprecate_positional_args
def get_orthogonal_selection(
self,
selection: OrthogonalSelection,
Expand Down Expand Up @@ -1358,6 +1364,7 @@ def get_orthogonal_selection(
)
)

@_deprecate_positional_args
def set_orthogonal_selection(
self,
selection: OrthogonalSelection,
Expand Down Expand Up @@ -1468,6 +1475,7 @@ def set_orthogonal_selection(
self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)
)

@_deprecate_positional_args
def get_mask_selection(
self,
mask: MaskSelection,
Expand Down Expand Up @@ -1550,6 +1558,7 @@ def get_mask_selection(
)
)

@_deprecate_positional_args
def set_mask_selection(
self,
mask: MaskSelection,
Expand Down Expand Up @@ -1628,6 +1637,7 @@ def set_mask_selection(
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

@_deprecate_positional_args
def get_coordinate_selection(
self,
selection: CoordinateSelection,
Expand Down Expand Up @@ -1717,6 +1727,7 @@ def get_coordinate_selection(
out_array = np.array(out_array).reshape(indexer.sel_shape)
return out_array

@_deprecate_positional_args
def set_coordinate_selection(
self,
selection: CoordinateSelection,
Expand Down Expand Up @@ -1806,6 +1817,7 @@ def set_coordinate_selection(

sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

@_deprecate_positional_args
def get_block_selection(
self,
selection: BasicSelection,
Expand Down Expand Up @@ -1904,6 +1916,7 @@ def get_block_selection(
)
)

@_deprecate_positional_args
def set_block_selection(
self,
selection: BasicSelection,
Expand Down
16 changes: 16 additions & 0 deletions src/zarr/core/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,19 @@ def __iter__(self) -> Iterator[str]:

def __len__(self) -> int:
return len(self._obj.metadata.attributes)

def put(self, d: dict[str, JSON]) -> None:
"""
Overwrite all attributes with the values from `d`.
Equivalent to the following pseudo-code, but performed atomically.
.. code-block:: python
>>> attrs = {"a": 1, "b": 2}
>>> attrs.clear()
>>> attrs.update({"a": 3", "c": 4})
>>> attrs
{'a': 3, 'c': 4}
"""
self._obj = self._obj.update_attributes(d)
45 changes: 40 additions & 5 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import zarr.api.asynchronous as async_api
from zarr.abc.metadata import Metadata
from zarr.abc.store import set_or_delete
from zarr.abc.store import Store, set_or_delete
from zarr.core.array import Array, AsyncArray
from zarr.core.attributes import Attributes
from zarr.core.buffer import default_buffer_prototype
Expand Down Expand Up @@ -126,7 +126,7 @@ class AsyncGroup:
store_path: StorePath

@classmethod
async def create(
async def from_store(
cls,
store: StoreLike,
*,
Expand Down Expand Up @@ -312,6 +312,21 @@ def attrs(self) -> dict[str, Any]:
def info(self) -> None:
raise NotImplementedError

@property
def store(self) -> Store:
return self.store_path.store

@property
def read_only(self) -> bool:
# Backwards compatibility for 2.x
return self.store_path.store.mode.readonly

@property
def synchronizer(self) -> None:
# Backwards compatibility for 2.x
# Not implemented in 3.x yet.
return None

async def create_group(
self,
name: str,
Expand All @@ -320,7 +335,7 @@ async def create_group(
attributes: dict[str, Any] | None = None,
) -> AsyncGroup:
attributes = attributes or {}
return await type(self).create(
return await type(self).from_store(
self.store_path / name,
attributes=attributes,
exists_ok=exists_ok,
Expand Down Expand Up @@ -752,7 +767,7 @@ class Group(SyncMixin):
_async_group: AsyncGroup

@classmethod
def create(
def from_store(
cls,
store: StoreLike,
*,
Expand All @@ -762,7 +777,7 @@ def create(
) -> Group:
attributes = attributes or {}
obj = sync(
AsyncGroup.create(
AsyncGroup.from_store(
store,
attributes=attributes,
exists_ok=exists_ok,
Expand Down Expand Up @@ -843,6 +858,22 @@ def attrs(self) -> Attributes:
def info(self) -> None:
raise NotImplementedError

@property
def store(self) -> Store:
# Backwards compatibility for 2.x
return self._async_group.store

@property
def read_only(self) -> bool:
# Backwards compatibility for 2.x
return self._async_group.read_only

@property
def synchronizer(self) -> None:
# Backwards compatibility for 2.x
# Not implemented in 3.x yet.
return self._async_group.synchronizer

def update_attributes(self, new_attributes: dict[str, Any]) -> Group:
self._sync(self._async_group.update_attributes(new_attributes))
return self
Expand Down Expand Up @@ -913,6 +944,10 @@ def require_groups(self, *names: str) -> tuple[Group, ...]:
"""Convenience method to require multiple groups in a single call."""
return tuple(map(Group, self._sync(self._async_group.require_groups(*names))))

def create(self, *args: Any, **kwargs: Any) -> Array:
# Backwards compatibility for 2.x
return self.create_array(*args, **kwargs)

def create_array(
self,
name: str,
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/testing/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def arrays(
expected_attrs = {} if attributes is None else attributes

array_path = path + ("/" if not path.endswith("/") else "") + name
root = Group.create(store)
root = Group.from_store(store)
fill_value_args: tuple[Any, ...] = tuple()
if nparray.dtype.kind == "M":
m = re.search(r"\[(.+)\]", nparray.dtype.str)
Expand Down
2 changes: 1 addition & 1 deletion tests/v3/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def async_group(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> As
param: AsyncGroupRequest = request.param

store = await parse_store(param.store, str(tmpdir))
agroup = await AsyncGroup.create(
agroup = await AsyncGroup.from_store(
store,
attributes=param.attributes,
zarr_format=param.zarr_format,
Expand Down
Loading

0 comments on commit 6900754

Please sign in to comment.