From f13730589142bacf528f2f9aefd1d3548540fdca Mon Sep 17 00:00:00 2001 From: ioanaif Date: Thu, 28 Mar 2024 15:48:37 +0200 Subject: [PATCH] fix: dask failing for TTrees with duplicate TBranch names (#1189) * fix: dask failing for TTrees with duplicate TBranch names * style: pre-commit fixes * Update _dask.py * Test file name update * preserve order in `common_keys` while dropping duplicates --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jim Pivarski --- src/uproot/_dask.py | 6 ++++- src/uproot/behaviors/TBranch.py | 24 +++++++++++++++++-- ...est_1189_dask_failing_on_duplicate_keys.py | 21 ++++++++++++++++ 3 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 tests/test_1189_dask_failing_on_duplicate_keys.py diff --git a/src/uproot/_dask.py b/src/uproot/_dask.py index a08df8a95..c1c0100be 100644 --- a/src/uproot/_dask.py +++ b/src/uproot/_dask.py @@ -606,6 +606,7 @@ def real_filter_branch(branch): filter_typename=filter_typename, filter_branch=real_filter_branch, full_paths=full_paths, + ignore_duplicates=True, ) if common_keys is None: @@ -747,6 +748,7 @@ def _get_dask_array_delay_open( filter_typename=filter_typename, filter_branch=filter_branch, full_paths=full_paths, + ignore_duplicates=True, ) dask_dict = {} @@ -1441,6 +1443,7 @@ def real_filter_branch(branch): filter_typename=filter_typename, filter_branch=real_filter_branch, full_paths=full_paths, + ignore_duplicates=True, ) if common_keys is None: @@ -1586,7 +1589,7 @@ def _get_dak_array_delay_open( ffile_path, fobject_path = files[0][0:2] if known_base_form is not None: - common_keys = list(known_base_form.fields) + common_keys = list(dict.fromkeys(known_base_form.fields)) base_form = known_base_form else: obj = uproot._util.regularize_object_path( @@ -1598,6 +1601,7 @@ def _get_dak_array_delay_open( filter_typename=filter_typename, filter_branch=filter_branch, full_paths=full_paths, + ignore_duplicates=True, ) base_form = _get_ttree_form( awkward, obj, common_keys, interp_options.get("ak_add_doc") diff --git a/src/uproot/behaviors/TBranch.py b/src/uproot/behaviors/TBranch.py index 2e5d0d38f..83b3c8fc4 100644 --- a/src/uproot/behaviors/TBranch.py +++ b/src/uproot/behaviors/TBranch.py @@ -1126,6 +1126,7 @@ def keys( filter_branch=no_filter, recursive=True, full_paths=True, + ignore_duplicates=False, ): """ Args: @@ -1143,6 +1144,7 @@ def keys( full_paths (bool): If True, include the full path to each subbranch with slashes (``/``); otherwise, use the descendant's name as the output name. + ignore_duplicates (bool): If True, return a set of the keys; otherwise, return the full list of keys. Returns the names of the subbranches as a list of strings. """ @@ -1153,6 +1155,7 @@ def keys( filter_branch=filter_branch, recursive=recursive, full_paths=full_paths, + ignore_duplicates=ignore_duplicates, ) ) @@ -1279,6 +1282,7 @@ def iterkeys( filter_branch=no_filter, recursive=True, full_paths=True, + ignore_duplicates=False, ): """ Args: @@ -1296,6 +1300,8 @@ def iterkeys( full_paths (bool): If True, include the full path to each subbranch with slashes (``/``); otherwise, use the descendant's name as the output name. + ignore_duplicates (bool): If True, return a set of the keys; otherwise, return the full list of keys. + Returns the names of the subbranches as an iterator over strings. """ @@ -1305,6 +1311,7 @@ def iterkeys( filter_branch=filter_branch, recursive=recursive, full_paths=full_paths, + ignore_duplicates=ignore_duplicates, ): yield k @@ -1353,6 +1360,7 @@ def iteritems( filter_branch=no_filter, recursive=True, full_paths=True, + ignore_duplicates=False, ): """ Args: @@ -1370,6 +1378,8 @@ def iteritems( full_paths (bool): If True, include the full path to each subbranch with slashes (``/``) in the name; otherwise, use the descendant's name as the name without modification. + ignore_duplicates (bool): If True, return a set of the keys; otherwise, return the full list of keys. + Returns (name, branch) pairs of the subbranches as an iterator over 2-tuples of (str, :doc:`uproot.behaviors.TBranch.TBranch`). @@ -1385,6 +1395,8 @@ def iteritems( f"filter_branch must be None or a function: TBranch -> bool, not {filter_branch!r}" ) + keys_set = set() + for branch in self.branches: if ( ( @@ -1394,7 +1406,11 @@ def iteritems( and (filter_typename is no_filter or filter_typename(branch.typename)) and (filter_branch is no_filter or filter_branch(branch)) ): - yield branch.name, branch + if ignore_duplicates and branch.name in keys_set: + pass + else: + keys_set.add(branch.name) + yield branch.name, branch if recursive: for k1, v in branch.iteritems( @@ -1408,7 +1424,11 @@ def iteritems( if filter_name is no_filter or _filter_name_deep( filter_name, self, v ): - yield k2, v + if ignore_duplicates and branch.name in keys_set: + pass + else: + keys_set.add(k2) + yield k2, v def itertypenames( self, diff --git a/tests/test_1189_dask_failing_on_duplicate_keys.py b/tests/test_1189_dask_failing_on_duplicate_keys.py new file mode 100644 index 000000000..81c90ad74 --- /dev/null +++ b/tests/test_1189_dask_failing_on_duplicate_keys.py @@ -0,0 +1,21 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE + +import pytest +import uproot +import skhep_testdata + + +def test_dask_duplicated_keys(): + + lazy = uproot.dask( + skhep_testdata.data_path("uproot-metadata-performance.root") + ":Events" + ) + materialized = lazy.FatJet_btagDDBvLV2.compute() + + lazy = uproot.dask(skhep_testdata.data_path("uproot-issue513.root") + ":Delphes") + materialized = lazy.Particle.compute() + + lazy = uproot.dask( + skhep_testdata.data_path("uproot-issue443.root") + ":muonDataTree" + ) + materialized = lazy.hitEnd.compute()