Skip to content

Commit

Permalink
Allow mode casting for Stores (#2249)
Browse files Browse the repository at this point in the history
* Allow mode casting

* fixup

* fixup

* fixup

* fixup

* match message

* Update src/zarr/testing/store.py

Co-authored-by: Davis Bennett <[email protected]>

* fixup

* fixup

* fixup

* fixup

* pre-commit

* log methods

* style: pre-commit fixes

---------

Co-authored-by: Davis Bennett <[email protected]>
Co-authored-by: Joe Hamman <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Sep 27, 2024
1 parent 2edc548 commit 73b884b
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 19 deletions.
25 changes: 25 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,31 @@ async def empty(self) -> bool: ...
@abstractmethod
async def clear(self) -> None: ...

@abstractmethod
def with_mode(self, mode: AccessModeLiteral) -> Self:
    """
    Return a new store of the same type pointing to the same location with a new mode.

    The returned Store is not automatically opened. Call :meth:`Store.open` before
    using.

    Parameters
    ----------
    mode: AccessModeLiteral
        The new mode to use.

    Returns
    -------
    store:
        A new store of the same type with the new mode.

    Raises
    ------
    NotImplementedError
        Implementations that cannot be reopened with a different mode may
        raise this instead of returning a store (``ZipStore`` does).

    Examples
    --------
    >>> writer = zarr.store.MemoryStore(mode="w")
    >>> reader = writer.with_mode("r")
    """
    ...

@property
def mode(self) -> AccessMode:
"""Access mode of the store."""
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/store/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ async def make_store_path(
assert AccessMode.from_literal(mode) == store_like.store.mode
result = store_like
elif isinstance(store_like, Store):
if mode is not None:
assert AccessMode.from_literal(mode) == store_like.mode
if mode is not None and mode != store_like.mode.str:
store_like = store_like.with_mode(mode)
await store_like._ensure_open()
result = StorePath(store_like)
elif store_like is None:
Expand Down
5 changes: 4 additions & 1 deletion src/zarr/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import shutil
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Self

from zarr.abc.store import ByteRangeRequest, Store
from zarr.core.buffer import Buffer
Expand Down Expand Up @@ -110,6 +110,9 @@ async def empty(self) -> bool:
else:
return True

def with_mode(self, mode: AccessModeLiteral) -> Self:
    """Create a new instance of this store's class for the same ``root`` using ``mode``."""
    store_cls = type(self)
    return store_cls(root=self.root, mode=mode)

def __str__(self) -> str:
return f"file://{self.root}"

Expand Down
21 changes: 20 additions & 1 deletion src/zarr/store/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Self

from zarr.abc.store import AccessMode, ByteRangeRequest, Store
from zarr.core.buffer import Buffer
Expand All @@ -14,6 +14,7 @@
from collections.abc import AsyncGenerator, Generator, Iterable

from zarr.core.buffer import Buffer, BufferPrototype
from zarr.core.common import AccessModeLiteral


class LoggingStore(Store):
Expand All @@ -28,6 +29,8 @@ def __init__(
) -> None:
self._store = store
self.counter = defaultdict(int)
self.log_level = log_level
self.log_handler = log_handler

self._configure_logger(log_level, log_handler)

Expand Down Expand Up @@ -96,6 +99,14 @@ def _is_open(self) -> bool: # type: ignore[override]
with self.log():
return self._store._is_open

async def _open(self) -> None:
    """Delegate ``_open`` to the wrapped store, running the call inside ``self.log()``."""
    # NOTE(review): ``self.log()`` is defined outside this view; keep the ``with``
    # directly in this method in case it inspects the caller — confirm.
    with self.log():
        return await self._store._open()

async def _ensure_open(self) -> None:
    """Delegate ``_ensure_open`` to the wrapped store, running the call inside ``self.log()``."""
    # NOTE(review): ``self.log()`` is defined outside this view; keep the ``with``
    # directly in this method in case it inspects the caller — confirm.
    with self.log():
        return await self._store._ensure_open()

async def empty(self) -> bool:
with self.log():
return await self._store.empty()
Expand Down Expand Up @@ -167,3 +178,11 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
with self.log():
async for key in self._store.list_dir(prefix=prefix):
yield key

def with_mode(self, mode: AccessModeLiteral) -> Self:
    """
    Wrap a mode-changed copy of the inner store in a new logging store.

    The inner store's own ``with_mode`` produces the new backing store; the
    wrapper is rebuilt with this instance's ``log_level`` and ``log_handler``.
    """
    with self.log():
        wrapper_cls = type(self)
        recast_store = self._store.with_mode(mode)
        return wrapper_cls(
            recast_store,
            log_level=self.log_level,
            log_handler=self.log_handler,
        )
50 changes: 41 additions & 9 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Self

from zarr.abc.store import ByteRangeRequest, Store
from zarr.core.buffer import Buffer, gpu
Expand Down Expand Up @@ -41,6 +41,9 @@ async def empty(self) -> bool:
async def clear(self) -> None:
    """Empty the backing ``_store_dict`` in place."""
    backing = self._store_dict
    backing.clear()

def with_mode(self, mode: AccessModeLiteral) -> Self:
    """Create a new instance of this store's class that reuses the same ``_store_dict`` with ``mode``."""
    store_cls = type(self)
    return store_cls(store_dict=self._store_dict, mode=mode)

def __str__(self) -> str:
return f"memory://{id(self._store_dict)}"

Expand Down Expand Up @@ -156,29 +159,58 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:

class GpuMemoryStore(MemoryStore):
"""A GPU only memory store that stores every chunk in GPU memory irrespective
of the original location. This guarantees that chunks will always be in GPU
memory for downstream processing. For location agnostic use cases, it would
be better to use `MemoryStore` instead.
of the original location.
The dictionary of buffers to initialize this memory store with *must* be
GPU Buffers.
Writing data to this store through ``.set`` will move the buffer to the GPU
if necessary.
Parameters
----------
store_dict: MutableMapping, optional
A mutable mapping with string keys and :class:`zarr.core.buffer.gpu.Buffer`
values.
"""

_store_dict: MutableMapping[str, Buffer]
_store_dict: MutableMapping[str, gpu.Buffer] # type: ignore[assignment]

def __init__(
self,
store_dict: MutableMapping[str, Buffer] | None = None,
store_dict: MutableMapping[str, gpu.Buffer] | None = None,
*,
mode: AccessModeLiteral = "r",
) -> None:
super().__init__(mode=mode)
if store_dict:
self._store_dict = {k: gpu.Buffer.from_buffer(store_dict[k]) for k in iter(store_dict)}
super().__init__(store_dict=store_dict, mode=mode) # type: ignore[arg-type]

def __str__(self) -> str:
    """Return a ``gpumemory://`` label built from the ``id`` of the backing dict."""
    dict_id = id(self._store_dict)
    return f"gpumemory://{dict_id}"

def __repr__(self) -> str:
    """Return ``GpuMemoryStore(...)`` wrapping the quoted ``str()`` form of this store."""
    label = str(self)
    return f"GpuMemoryStore({label!r})"

@classmethod
def from_dict(cls, store_dict: MutableMapping[str, Buffer]) -> Self:
    """
    Build a GpuMemoryStore from a mapping of buffers that may live anywhere.

    Every value is passed through :meth:`gpu.Buffer.from_buffer` and collected
    into a fresh dictionary, so the new store is not backed by ``store_dict``
    itself.

    Parameters
    ----------
    store_dict: mapping
        A mapping of string keys to arbitrary Buffers.

    Returns
    -------
    GpuMemoryStore
    """
    converted = {}
    for key, buf in store_dict.items():
        converted[key] = gpu.Buffer.from_buffer(buf)
    return cls(converted)

async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
self._check_writable()
assert isinstance(key, str)
Expand Down
10 changes: 9 additions & 1 deletion src/zarr/store/remote.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Self

import fsspec

Expand Down Expand Up @@ -96,6 +96,14 @@ async def clear(self) -> None:
async def empty(self) -> bool:
    """Return True when ``fs._find`` (with ``withdirs=True``) finds nothing under ``self.path``."""
    found = await self.fs._find(self.path, withdirs=True)
    return not found

def with_mode(self, mode: AccessModeLiteral) -> Self:
    """Create a new instance of this store's class with the same ``fs``, ``path`` and ``allowed_exceptions`` but ``mode``."""
    init_kwargs = {
        "fs": self.fs,
        "mode": mode,
        "path": self.path,
        "allowed_exceptions": self.allowed_exceptions,
    }
    return type(self)(**init_kwargs)

def __repr__(self) -> str:
return f"<RemoteStore({type(self.fs).__name__}, {self.path})>"

Expand Down
5 changes: 4 additions & 1 deletion src/zarr/store/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
import zipfile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, Self

from zarr.abc.store import ByteRangeRequest, Store
from zarr.core.buffer import Buffer, BufferPrototype
Expand Down Expand Up @@ -112,6 +112,9 @@ async def empty(self) -> bool:
with self._lock:
return not self._zf.namelist()

def with_mode(self, mode: ZipStoreAccessModeLiteral) -> Self:  # type: ignore[override]
    """Always raise ``NotImplementedError``; no new store is ever returned."""
    message = "ZipStore cannot be reopened with a new mode."
    raise NotImplementedError(message)

def __str__(self) -> str:
return f"zip://{self.path}"

Expand Down
38 changes: 37 additions & 1 deletion src/zarr/testing/store.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pickle
from typing import Any, Generic, TypeVar
from typing import Any, Generic, TypeVar, cast

import pytest

from zarr.abc.store import AccessMode, Store
from zarr.core.buffer import Buffer, default_buffer_prototype
from zarr.core.common import AccessModeLiteral
from zarr.core.sync import _collect_aiterator, collect_aiterator
from zarr.store._utils import _normalize_interval_index
from zarr.testing.utils import assert_bytes_equal
Expand Down Expand Up @@ -274,6 +275,41 @@ async def test_list_dir(self, store: S) -> None:
keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
assert sorted(keys_expected) == sorted(keys_observed)

async def test_with_mode(self, store: S) -> None:
    """
    Check that ``store.with_mode`` yields a same-typed store sharing the original's data.

    For each of the modes ``"r"`` and ``"a"`` the clone must report the
    requested mode, see keys written to the original both before and after the
    clone was created, and either propagate its own writes back to the original
    (``"a"``) or reject writes with a ``ValueError`` (``"r"``).
    """
    data = b"0000"
    self.set(store, "key", self.buffer_cls.from_bytes(data))
    assert self.get(store, "key").to_bytes() == data

    for mode_literal in ["r", "a"]:
        # Separate name so the loop variable is not rebound with a different static type.
        mode = cast(AccessModeLiteral, mode_literal)
        clone = store.with_mode(mode)
        await clone._ensure_open()
        assert clone.mode == AccessMode.from_literal(mode)
        assert isinstance(clone, type(store))

        # writes made to the original before with_mode are visible in the clone
        result = await clone.get("key", default_buffer_prototype())
        assert result is not None
        assert result.to_bytes() == data

        # writes made to the original after with_mode are visible in the clone
        self.set(store, "key-2", self.buffer_cls.from_bytes(data))
        result = await clone.get("key-2", default_buffer_prototype())
        assert result is not None
        assert result.to_bytes() == data

        if mode == "a":
            # writes to the clone can be read back through the clone ...
            await clone.set("key-3", self.buffer_cls.from_bytes(data))
            result = await clone.get("key-3", default_buffer_prototype())
            assert result is not None
            assert result.to_bytes() == data
            # ... and are visible in the original store as well
            assert self.get(store, "key-3").to_bytes() == data

        else:
            with pytest.raises(ValueError, match="store mode"):
                await clone.set("key-3", self.buffer_cls.from_bytes(data))

async def test_set_if_not_exists(self, store: S) -> None:
key = "k"
data_buf = self.buffer_cls.from_bytes(b"0000")
Expand Down
8 changes: 8 additions & 0 deletions tests/v3/test_store/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import zarr
import zarr.store
from zarr.core.buffer import default_buffer_prototype
from zarr.store.logging import LoggingStore

Expand Down Expand Up @@ -48,3 +49,10 @@ async def test_logging_store_counter(store: Store) -> None:
assert wrapped.counter["list"] == 0
assert wrapped.counter["list_dir"] == 0
assert wrapped.counter["list_prefix"] == 0


async def test_with_mode():
    """``LoggingStore.with_mode`` changes the mode and keeps the configured log level."""
    inner = zarr.store.MemoryStore(mode="w")
    logging_store = LoggingStore(store=inner, log_level="INFO")
    recast = logging_store.with_mode(mode="r")
    assert recast.mode.str == "r"
    assert recast.log_level == "INFO"
25 changes: 22 additions & 3 deletions tests/v3/test_store/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,14 @@ def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None:
def get(self, store: MemoryStore, key: str) -> Buffer:
    """Return the value stored under ``key`` in ``store._store_dict``."""
    backing = store._store_dict
    return backing[key]

@pytest.fixture(params=[None, {}])
def store_kwargs(self, request) -> dict[str, str | None | dict[str, Buffer]]:
return {"store_dict": request.param, "mode": "r+"}
@pytest.fixture(params=[None, True])
def store_kwargs(
    self, request: pytest.FixtureRequest
) -> dict[str, str | None | dict[str, Buffer]]:
    """
    Constructor kwargs for the store under test, always with mode ``"r+"``.

    ``store_dict`` is a fresh empty dict when the fixture parameter is ``True``
    and ``None`` otherwise.
    """
    initial_dict = {} if request.param is True else None
    return {"store_dict": initial_dict, "mode": "r+"}

@pytest.fixture
def store(self, store_kwargs: str | None | dict[str, gpu.Buffer]) -> GpuMemoryStore:
Expand All @@ -80,3 +85,17 @@ def test_store_supports_partial_writes(self, store: GpuMemoryStore) -> None:

def test_list_prefix(self, store: GpuMemoryStore) -> None:
    # Placeholder: checks nothing about ``store`` and always passes.
    # NOTE(review): presumably overrides an inherited test to disable it — confirm.
    assert True

def test_dict_reference(self, store: GpuMemoryStore) -> None:
    """A dict passed as ``store_dict`` is kept by reference, not copied."""
    backing = {}
    gpu_store = GpuMemoryStore(store_dict=backing)
    assert gpu_store._store_dict is backing

def test_from_dict(self):
    """``from_dict`` stores every value as a ``gpu.Buffer``, whatever buffer type was passed in."""
    mixed_buffers = {
        "a": gpu.Buffer.from_bytes(b"aaaa"),
        "b": cpu.Buffer.from_bytes(b"bbbb"),
    }
    gpu_store = GpuMemoryStore.from_dict(mixed_buffers)
    for stored in gpu_store._store_dict.values():
        assert type(stored) is gpu.Buffer
4 changes: 4 additions & 0 deletions tests/v3/test_store/test_zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,7 @@ def test_api_integration(self, store: ZipStore) -> None:
del root["bar"]

store.close()

async def test_with_mode(self, store: ZipStore) -> None:
    """Expect the parent class's ``test_with_mode`` to raise ``NotImplementedError`` for a ``ZipStore``."""
    # The exception message must contain "new mode".
    with pytest.raises(NotImplementedError, match="new mode"):
        await super().test_with_mode(store)

0 comments on commit 73b884b

Please sign in to comment.