Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: tokenization of ArgsKwargsPackedFunction #555

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
24 changes: 22 additions & 2 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1965,7 +1965,18 @@ def __init__(self, the_fn, arg_repackers, kwarg_repacker, arg_lens_for_repackers
self.kwarg_repacker = kwarg_repacker
self.arg_lens_for_repackers = arg_lens_for_repackers

def __call__(self, *args_deps_expanded):
def __repr__(self):
if hasattr(self.fn, "__qualname__"):
return f"{self.fn.__qualname__}-repacked"
return (
repr(self.fn)
.replace("<", "")
.replace(">", "")
.replace("function ", "")
.replace("built-in ", "")
lgray marked this conversation as resolved.
Show resolved Hide resolved
) + "-repacked"

def _repack(self, *args_deps_expanded):
"""This packing function receives a list of strictly
ordered arguments. The first range of arguments,
[0:sum(self.arg_lens_for_repackers)], corresponding to
Expand All @@ -1989,6 +2000,10 @@ def __call__(self, *args_deps_expanded):
)
len_args += n_args
kwargs = self.kwarg_repacker(args_deps_expanded[len_args:])[0]
return args, kwargs

def __call__(self, *args_deps_expanded):
args, kwargs = self._repack(*args_deps_expanded)
return self.fn(*args, **kwargs)


Expand All @@ -2010,7 +2025,12 @@ def _map_partitions(
will not be traversed to extract all dask collections, except those in
the first dimension of args or kwargs.
"""
token = token or tokenize(fn, *args, output_divisions, **kwargs)
if isinstance(fn, ArgsKwargsPackedFunction):
token_args, token_kwargs = fn._repack(*args)
token = token or tokenize(fn.fn, *token_args, output_divisions, **token_kwargs)
else:
token = token or tokenize(fn, *args, output_divisions, **kwargs)

label = hyphenize(label or funcname(fn))
name = f"{label}-{token}"
deps = [a for a in args if is_dask_collection(a)] + [
Expand Down
11 changes: 11 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,3 +980,14 @@ def test_array__bool_nonzero_long_int_float_complex_index():
match=r"A dask_awkward.Array is encountered in a computation where a concrete value is expected. If you intend to convert the dask_awkward.Array to a concrete value, use the `.compute\(\)` method. The .+ method was called on .+.",
):
fun(dask_arr)


def test_map_partitions_deterministic_token():
dask_arr = dak.from_awkward(ak.Array([1]), npartitions=1)

def f(x):
return x[0] + 1

assert (
map_partitions(f, {0: dask_arr}).name == map_partitions(f, {0: dask_arr}).name
)
Loading