Skip to content

Commit

Permalink
fix: dask failing for TTrees with duplicate TBranch names (#1189)
Browse files Browse the repository at this point in the history
* fix: dask failing for TTrees with duplicate TBranch names

* style: pre-commit fixes

* Update _dask.py

* Test file name update

* preserve order in `common_keys` while dropping duplicates

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jim Pivarski <[email protected]>
  • Loading branch information
3 people authored Mar 28, 2024
1 parent e47badf commit f137305
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
6 changes: 5 additions & 1 deletion src/uproot/_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def real_filter_branch(branch):
filter_typename=filter_typename,
filter_branch=real_filter_branch,
full_paths=full_paths,
ignore_duplicates=True,
)

if common_keys is None:
Expand Down Expand Up @@ -747,6 +748,7 @@ def _get_dask_array_delay_open(
filter_typename=filter_typename,
filter_branch=filter_branch,
full_paths=full_paths,
ignore_duplicates=True,
)

dask_dict = {}
Expand Down Expand Up @@ -1441,6 +1443,7 @@ def real_filter_branch(branch):
filter_typename=filter_typename,
filter_branch=real_filter_branch,
full_paths=full_paths,
ignore_duplicates=True,
)

if common_keys is None:
Expand Down Expand Up @@ -1586,7 +1589,7 @@ def _get_dak_array_delay_open(
ffile_path, fobject_path = files[0][0:2]

if known_base_form is not None:
common_keys = list(known_base_form.fields)
common_keys = list(dict.fromkeys(known_base_form.fields))
base_form = known_base_form
else:
obj = uproot._util.regularize_object_path(
Expand All @@ -1598,6 +1601,7 @@ def _get_dak_array_delay_open(
filter_typename=filter_typename,
filter_branch=filter_branch,
full_paths=full_paths,
ignore_duplicates=True,
)
base_form = _get_ttree_form(
awkward, obj, common_keys, interp_options.get("ak_add_doc")
Expand Down
24 changes: 22 additions & 2 deletions src/uproot/behaviors/TBranch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,7 @@ def keys(
filter_branch=no_filter,
recursive=True,
full_paths=True,
ignore_duplicates=False,
):
"""
Args:
Expand All @@ -1143,6 +1144,7 @@ def keys(
full_paths (bool): If True, include the full path to each subbranch
with slashes (``/``); otherwise, use the descendant's name as
the output name.
ignore_duplicates (bool): If True, return a set of the keys; otherwise, return the full list of keys.
Returns the names of the subbranches as a list of strings.
"""
Expand All @@ -1153,6 +1155,7 @@ def keys(
filter_branch=filter_branch,
recursive=recursive,
full_paths=full_paths,
ignore_duplicates=ignore_duplicates,
)
)

Expand Down Expand Up @@ -1279,6 +1282,7 @@ def iterkeys(
filter_branch=no_filter,
recursive=True,
full_paths=True,
ignore_duplicates=False,
):
"""
Args:
Expand All @@ -1296,6 +1300,8 @@ def iterkeys(
full_paths (bool): If True, include the full path to each subbranch
with slashes (``/``); otherwise, use the descendant's name as
the output name.
ignore_duplicates (bool): If True, return a set of the keys; otherwise, return the full list of keys.
Returns the names of the subbranches as an iterator over strings.
"""
Expand All @@ -1305,6 +1311,7 @@ def iterkeys(
filter_branch=filter_branch,
recursive=recursive,
full_paths=full_paths,
ignore_duplicates=ignore_duplicates,
):
yield k

Expand Down Expand Up @@ -1353,6 +1360,7 @@ def iteritems(
filter_branch=no_filter,
recursive=True,
full_paths=True,
ignore_duplicates=False,
):
"""
Args:
Expand All @@ -1370,6 +1378,8 @@ def iteritems(
full_paths (bool): If True, include the full path to each subbranch
with slashes (``/``) in the name; otherwise, use the descendant's
name as the name without modification.
ignore_duplicates (bool): If True, return a set of the keys; otherwise, return the full list of keys.
Returns (name, branch) pairs of the subbranches as an iterator over
2-tuples of (str, :doc:`uproot.behaviors.TBranch.TBranch`).
Expand All @@ -1385,6 +1395,8 @@ def iteritems(
f"filter_branch must be None or a function: TBranch -> bool, not {filter_branch!r}"
)

keys_set = set()

for branch in self.branches:
if (
(
Expand All @@ -1394,7 +1406,11 @@ def iteritems(
and (filter_typename is no_filter or filter_typename(branch.typename))
and (filter_branch is no_filter or filter_branch(branch))
):
yield branch.name, branch
if ignore_duplicates and branch.name in keys_set:
pass
else:
keys_set.add(branch.name)
yield branch.name, branch

if recursive:
for k1, v in branch.iteritems(
Expand All @@ -1408,7 +1424,11 @@ def iteritems(
if filter_name is no_filter or _filter_name_deep(
filter_name, self, v
):
yield k2, v
if ignore_duplicates and branch.name in keys_set:
pass
else:
keys_set.add(k2)
yield k2, v

def itertypenames(
self,
Expand Down
21 changes: 21 additions & 0 deletions tests/test_1189_dask_failing_on_duplicate_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE

import pytest
import uproot
import skhep_testdata


def test_dask_duplicated_keys():

lazy = uproot.dask(
skhep_testdata.data_path("uproot-metadata-performance.root") + ":Events"
)
materialized = lazy.FatJet_btagDDBvLV2.compute()

lazy = uproot.dask(skhep_testdata.data_path("uproot-issue513.root") + ":Delphes")
materialized = lazy.Particle.compute()

lazy = uproot.dask(
skhep_testdata.data_path("uproot-issue443.root") + ":muonDataTree"
)
materialized = lazy.hitEnd.compute()

0 comments on commit f137305

Please sign in to comment.