Skip to content

Commit

Permalink
fix: uproot was exposed in one place to dask's _task_spec overhaul (#…
Browse files Browse the repository at this point in the history
…1352)

* fix: uproot was exposed in one plat to dask's _task_spec overhaul

* style: pre-commit fixes

* require fixed dask-awkward

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lgray and pre-commit-ci[bot] authored Dec 16, 2024
1 parent 8a71b73 commit 2ba58f2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ requires-python = ">=3.9"
[project.optional-dependencies]
dev = [
"boost_histogram>=0.13",
"dask-awkward>=2023.12.1",
"dask-awkward>=2024.12.1",
"dask[array,distributed]",
"hist>=1.2",
"pandas",
Expand Down
26 changes: 18 additions & 8 deletions src/uproot/_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,8 @@ def _dask_array_from_map(
**kwargs,
):
dask = uproot.extras.dask()
_dask_uses_tasks = hasattr(dask, "_task_spec")

da = uproot.extras.dask_array()
if not callable(func):
raise ValueError("`func` argument must be `callable`")
Expand Down Expand Up @@ -446,14 +448,22 @@ def _dask_array_from_map(
produces_tasks=produces_tasks,
)

dsk = dask.blockwise.Blockwise(
output=name,
output_indices="i",
dsk={name: (io_func, dask.blockwise.blockwise_token(0))},
indices=[(io_arg_map, "i")],
numblocks={},
annotations=None,
)
blockwise_kwargs = {
"output": name,
"output_indices": "i",
"indices": [(io_arg_map, "i")],
"numblocks": {},
"annotations": None,
}

if _dask_uses_tasks:
blockwise_kwargs["task"] = dask._task_spec.Task(
name, io_func, dask._task_spec.TaskRef(dask.blockwise.blockwise_token(0))
)
else:
blockwise_kwargs["dsk"] = {name: (io_func, dask.blockwise.blockwise_token(0))}

dsk = dask.blockwise.Blockwise(**blockwise_kwargs)

hlg = dask.highlevelgraph.HighLevelGraph.from_collections(name, dsk)
return da.core.Array(hlg, name, chunks, dtype=dtype)
Expand Down

0 comments on commit 2ba58f2

Please sign in to comment.