Skip to content

Commit

Permalink
correctly deal with non-collection kwargs for already flat deps
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Dec 17, 2024
1 parent 8a77db7 commit be4e7ba
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
17 changes: 13 additions & 4 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2171,7 +2171,7 @@ def map_partitions(
message += f"- {type(arg)}"
raise TypeError(message)

if len(kwargs) == 0:
if len(kwarg_flat_deps) == 0:
non_traversed_deps, _ = unpack_collections(*args, traverse=False)
if len(flat_deps) == len(non_traversed_deps) and all(
id(traversed_dep) == id(non_traversed_dep)
Expand All @@ -2184,6 +2184,7 @@ def map_partitions(
token=token,
meta=meta,
output_divisions=output_divisions,
**kwargs,
)

arg_flat_deps_expanded = []
Expand Down Expand Up @@ -2556,11 +2557,19 @@ def to_length_zero_arrays(objects: Sequence[Any]) -> tuple[Any, ...]:
return tuple(map(length_zero_array_or_identity, objects))


def map_meta(fn: Callable | ArgsKwargsPackedFunction, *deps: Any) -> ak.Array | None:
# NOTE: fn is assumed to be a *packed* function
def map_meta(
fn: Callable | ArgsKwargsPackedFunction, *deps: Any, **kwargs: Any
) -> ak.Array | None:
# NOTE: fn is assumed to be a *packed* function (so flat deps or ArgsKwargsPackedFunction).
# If fn is an ArgsKwargsPackedFunction we do not allow additional kwargs,
# as defined up in map_partitions. Be careful!
if isinstance(fn, ArgsKwargsPackedFunction) and len(kwargs) > 0:
raise ValueError("ArgsKwargsPackedFunctions may not have additional kwargs!")
try:
meta = fn(*to_meta(deps))
if isinstance(fn, ArgsKwargsPackedFunction):
meta = fn(*to_meta(deps))
else:
meta = fn(*to_meta(deps), **kwargs)
return meta
except Exception as err:
# if compute-unknown-meta is False then we don't care about
Expand Down
13 changes: 12 additions & 1 deletion src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,20 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
arg.key if isinstance(arg, GraphNode) else arg
for arg in layer.task.args
]
kwargs = {
k: v.key if isinstance(v, GraphNode) else v
for k, v in layer.task.kwargs.items()
}
# how to do this with `.substitute(...)`?
args2 = _recursive_replace(args, layer, parent, indices)
all_tasks.append(Task(chain_member, func, *args2))
kwargs2 = {
k: v
for k, v in zip(
kwargs.keys(),
_recursive_replace(kwargs.values(), layer, parent, indices),
)
}
all_tasks.append(Task(chain_member, func, *args2, **kwargs2))
else:
func, *args = layer.dsk[chain_member] # mypy: ignore
args2 = _recursive_replace(args, layer, parent, indices)
Expand Down
9 changes: 9 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,11 @@ def test_map_partitions_args_and_kwargs_have_collection():
zc = my_power(xc, kwarg_y=yc)
zl = dak.map_partitions(my_power, xl, kwarg_y=yl)

# kwargs that contain collections should be wrapped
assert isinstance(
zl.dask.layers[zl.name].task.func, dak.lib.core.ArgsKwargsPackedFunction
)

assert_eq(zc, zl)

zd = structured_function(inputs={"x": xc, "y": xc, "z": yc})
Expand All @@ -830,6 +835,9 @@ def test_map_partitions_args_and_kwargs_have_collection():
zg = my_power(xc, kwarg_y=2.0)
zp = dak.map_partitions(my_power, xl, kwarg_y=2.0)

# this invocation of my_power shouldn't be wrapped: its kwargs contain no collections
assert zp.dask.layers[zp.name].task.func is my_power

assert_eq(zg, zp)

a = ak.Array(
Expand Down Expand Up @@ -860,6 +868,7 @@ def test_map_partitions_args_and_kwargs_have_collection():
ccc=cc,
ddd=dd,
)

assert_eq(res1, res2)


Expand Down

0 comments on commit be4e7ba

Please sign in to comment.