Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adding the variables parameter to the View.update function. #14

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion zcollection/collection/__init__.py
Original file line number Diff line number Diff line change
@@ -517,7 +517,7 @@ def update(
the variables are inferred by calling the function on the first
partition. In this case, it is important to ensure that the
function can be called twice on the same partition without
side-effects. Default is None.
side effects. Default is None.
**kwargs: The keyword arguments to pass to the function.

Raises:
13 changes: 7 additions & 6 deletions zcollection/merging/tests/test_merging.py
Original file line number Diff line number Diff line change
@@ -45,12 +45,13 @@ def test_update_fs(
"""Test the _update_fs function."""
generator = data.create_test_dataset(delayed=False)
zds = next(generator)
zds_sc = dask_client.scatter(zds)

partition_folder = local_fs.root.joinpath('variable=1')

zattrs = str(partition_folder.joinpath('.zattrs'))
future = dask_client.submit(_update_fs, str(partition_folder),
dask_client.scatter(zds), local_fs.fs)
future = dask_client.submit(_update_fs, str(partition_folder), zds_sc,
local_fs.fs)
dask_client.gather(future)
assert local_fs.exists(zattrs)

@@ -60,7 +61,7 @@ def test_update_fs(
try:
future = dask_client.submit(_update_fs,
str(partition_folder),
dask_client.scatter(zds),
zds_sc,
local_fs.fs,
synchronizer=ThrowError())
dask_client.gather(future)
@@ -83,13 +84,13 @@ def test_perform(
zds = next(generator)

path = str(local_fs.root.joinpath('variable=1'))
zds_sc = dask_client.scatter(zds)

future = dask_client.submit(_update_fs, path, dask_client.scatter(zds),
local_fs.fs)
future = dask_client.submit(_update_fs, path, zds_sc, local_fs.fs)
dask_client.gather(future)

future = dask_client.submit(perform,
dask_client.scatter(zds),
zds_sc,
path,
'time',
local_fs.fs,
20 changes: 13 additions & 7 deletions zcollection/view/__init__.py
Original file line number Diff line number Diff line change
@@ -26,7 +26,6 @@
from ..collection.callable_objects import MapCallable, PartitionCallable
from ..collection.detail import _try_infer_callable
from ..convenience import collection as convenience
from ..type_hints import ArrayLike
from .detail import (
ViewReference,
ViewUpdateCallable,
@@ -418,6 +417,7 @@ def update(
npartitions: int | None = None,
selected_variables: Iterable[str] | None = None,
trim: bool = True,
variables: Sequence[str] | None = None,
**kwargs,
) -> None:
"""Update a variable stored in the view.
@@ -446,6 +446,11 @@ def update(
trim: Whether to trim ``depth`` items from each partition after
calling ``func``. Set it to ``False`` if your function does
this for you.
variables: The list of variables updated by the function. If None,
the variables are inferred by calling the function on the first
partition. In this case, it is important to ensure that the
function can be called twice on the same partition without
side effects. Default is None.
args: The positional arguments to pass to the function.
kwargs: The keyword arguments to pass to the function.

@@ -485,16 +490,17 @@ def update(
'data is selected with the given filters.')
return

func_result: dict[str, ArrayLike] = _try_infer_callable(
func, datasets_list[0][0], self.view_ref.partition_properties.dim,
*args, **kwargs)
variables = variables or tuple(
_try_infer_callable(func, datasets_list[0][0],
self.view_ref.partition_properties.dim, *args,
**kwargs))
tuple(
map(
lambda varname: _assert_variable_handled(
self.view_ref.metadata, self.metadata, varname),
func_result))
variables))
_LOGGER.info('Updating variable %s',
', '.join(repr(item) for item in func_result))
', '.join(repr(item) for item in variables))

# Function to apply to each partition.
wrap_function: ViewUpdateCallable
@@ -509,7 +515,7 @@ def update(
)
else:
if selected_variables is not None and len(
set(func_result) & set(selected_variables)) == 0:
set(variables) & set(selected_variables)) == 0:
raise ValueError(
'If the depth is greater than 0, the selected variables '
'must contain the variables updated by the function.')
69 changes: 67 additions & 2 deletions zcollection/view/tests/test_view.py
Original file line number Diff line number Diff line change
@@ -8,12 +8,13 @@
"""
from __future__ import annotations

import logging
import pathlib

import numpy
import pytest

from ... import collection, convenience, meta, partitioning, view
from ... import collection, convenience, dataset, meta, partitioning, view
# pylint: disable=unused-import # Need to import for fixtures
from ...tests.cluster import dask_client, dask_cluster
from ...tests.data import (
@@ -23,6 +24,7 @@
)
from ...tests.fixture import dask_arrays, numpy_arrays
from ...tests.fs import local_fs, s3, s3_base, s3_fs
from ...type_hints import ArrayLike
from ...view.detail import _calculate_axis_reference

# pylint: enable=unused-import
@@ -136,7 +138,7 @@ def update(zds, varname):

zds = instance.load(delayed=delayed)
assert zds is not None
numpy.all(zds.variables['var3'].values == 5)
assert numpy.all(zds.variables['var3'].values == 5)

indexers = instance.map(
lambda x: slice(0, x.dimensions['num_lines']) # type: ignore
@@ -161,6 +163,69 @@ def update(zds, varname):
filesystem=tested_fs.fs)


@pytest.mark.parametrize('fs', ['local_fs', 's3_fs'])
def test_view_update(
        dask_client,  # pylint: disable=redefined-outer-name,unused-argument
        fs,
        request,
        caplog):
    """Test ``View.update`` with and without the ``variables`` argument.

    The update callables log a message each time they run, so the number of
    captured messages tells how many times the callable was invoked:

    * without ``variables``, the callable runs once per partition plus one
      extra time for the initial call made before the update;
    * with ``variables`` given, the callable runs exactly once per partition.

    Args:
        dask_client: Fixture requested for its side effect only.
        fs: Name of the file system fixture to test against.
        request: Pytest request object, used to resolve ``fs``.
        caplog: Pytest fixture capturing the log records.
    """
    tested_fs = request.getfixturevalue(fs)

    create_test_collection(tested_fs, delayed=False)
    instance = convenience.create_view(path=str(tested_fs.view),
                                       view_ref=view.ViewReference(
                                           str(tested_fs.collection),
                                           tested_fs.fs),
                                       filesystem=tested_fs.fs)

    var_name = 'var3'
    log_msg = 'Update called'

    var = meta.Variable(name=var_name,
                        dtype=numpy.float64,
                        dimensions=('num_lines', 'num_pixels'))

    instance.add_variable(var)

    def to_zero(zds: dataset.Dataset, varname):
        """Update function used to set a variable to 0."""
        logging.info(log_msg)
        return {varname: zds.variables['var1'].values * 0}

    instance.update(to_zero, var_name)  # type: ignore

    data = instance.load(delayed=False)
    assert numpy.all(data.variables[var_name].values == 0)

    def plus_one_with_log(zds: dataset.Dataset, varname):
        """Update function increasing a variable by 1."""
        logging.info(log_msg)
        # Read the variable through the parameter rather than the enclosing
        # scope so the callable is self-contained, like ``to_zero``.
        return {varname: zds.variables[varname].values + 1}

    caplog.set_level(logging.INFO)

    # Computed once: the same count is used by both assertions below.
    n_partitions = len(list(instance.partitions()))

    caplog.clear()
    instance.update(plus_one_with_log, var_name)  # type: ignore

    # One log per partition + 1 log for the initial call
    assert caplog.text.count(log_msg) == n_partitions + 1

    data = instance.load(delayed=False)
    assert numpy.all(data.variables[var_name].values == 1)

    caplog.clear()
    instance.update(
        plus_one_with_log,  # type: ignore
        var_name,
        variables=[var_name])

    # The updated variables are given explicitly: one log per partition only.
    assert caplog.text.count(log_msg) == n_partitions

    data = instance.load(delayed=False)
    assert numpy.all(data.variables[var_name].values == 2)


@pytest.mark.parametrize('arg', ['local_fs', 's3_fs'])
def test_view_overlap(
dask_client, # pylint: disable=redefined-outer-name,unused-argument