diff --git a/awkward-cpp/src/cpu-kernels/awkward_UnionArray_nestedfill_tags_index.cpp b/awkward-cpp/src/cpu-kernels/awkward_UnionArray_nestedfill_tags_index.cpp index 39152d87c0..77b38eca8b 100644 --- a/awkward-cpp/src/cpu-kernels/awkward_UnionArray_nestedfill_tags_index.cpp +++ b/awkward-cpp/src/cpu-kernels/awkward_UnionArray_nestedfill_tags_index.cpp @@ -55,3 +55,13 @@ ERROR awkward_UnionArray8_64_nestedfill_tags_index_64( return awkward_UnionArray_nestedfill_tags_index( totags, toindex, tmpstarts, tag, fromcounts, length); } +ERROR awkward_UnionArray64_64_nestedfill_tags_index_64( + int64_t* totags, + int64_t* toindex, + int64_t* tmpstarts, + int64_t tag, + const int64_t* fromcounts, + int64_t length) { + return awkward_UnionArray_nestedfill_tags_index( + totags, toindex, tmpstarts, tag, fromcounts, length); +} diff --git a/awkward-cpp/src/cpu-kernels/awkward_UnionArray_simplify.cpp b/awkward-cpp/src/cpu-kernels/awkward_UnionArray_simplify.cpp index 8dd8f86c75..184e9483d7 100644 --- a/awkward-cpp/src/cpu-kernels/awkward_UnionArray_simplify.cpp +++ b/awkward-cpp/src/cpu-kernels/awkward_UnionArray_simplify.cpp @@ -258,3 +258,28 @@ ERROR awkward_UnionArray8_64_simplify8_64_to8_64( length, base); } +ERROR awkward_UnionArray64_64_simplify8_64_to8_64( + int8_t* totags, + int64_t* toindex, + const int64_t* outertags, + const int64_t* outerindex, + const int8_t* innertags, + const int64_t* innerindex, + int64_t towhich, + int64_t innerwhich, + int64_t outerwhich, + int64_t length, + int64_t base) { + return awkward_UnionArray_simplify( + totags, + toindex, + outertags, + outerindex, + innertags, + innerindex, + towhich, + innerwhich, + outerwhich, + length, + base); +} diff --git a/awkward-cpp/src/cpu-kernels/awkward_UnionArray_simplify_one.cpp b/awkward-cpp/src/cpu-kernels/awkward_UnionArray_simplify_one.cpp index 227f9614f5..be29d352d0 100644 --- a/awkward-cpp/src/cpu-kernels/awkward_UnionArray_simplify_one.cpp +++ b/awkward-cpp/src/cpu-kernels/awkward_UnionArray_simplify_one.cpp @@ -82,3 +82,22 @@ ERROR awkward_UnionArray8_64_simplify_one_to8_64( length, base); } +ERROR awkward_UnionArray64_64_simplify_one_to8_64( + int8_t* totags, + int64_t* toindex, + const int64_t* fromtags, + const int64_t* fromindex, + int64_t towhich, + int64_t fromwhich, + int64_t length, + int64_t base) { + return awkward_UnionArray_simplify_one( + totags, + toindex, + fromtags, + fromindex, + towhich, + fromwhich, + length, + base); +} diff --git a/kernel-specification.yml b/kernel-specification.yml index 9252f1fd1f..2838b8db5c 100644 --- a/kernel-specification.yml +++ b/kernel-specification.yml @@ -3372,6 +3372,14 @@ kernels: - name: awkward_UnionArray_nestedfill_tags_index specializations: + - name: awkward_UnionArray64_64_nestedfill_tags_index_64 + args: + - {name: totags, type: "List[int64_t]", dir: out} + - {name: toindex, type: "List[int64_t]", dir: out} + - {name: tmpstarts, type: "List[int64_t]", dir: out} + - {name: tag, type: "int64_t", dir: in, role: default} + - {name: fromcounts, type: "Const[List[int64_t]]", dir: in, role: default} + - {name: length, type: "int64_t", dir: in, role: default} - name: awkward_UnionArray8_32_nestedfill_tags_index_64 args: - {name: totags, type: "List[int8_t]", dir: out} @@ -3503,6 +3511,19 @@ kernels: - name: awkward_UnionArray_simplify specializations: + - name: awkward_UnionArray64_64_simplify8_64_to8_64 + args: + - {name: totags, type: "List[int8_t]", dir: out} + - {name: toindex, type: "List[int64_t]", dir: out} + - {name: outertags, type: "Const[List[int64_t]]", dir: in, role: UnionArray-tags} + - {name: outerindex, type: "Const[List[int64_t]]", dir: in, role: IndexedArray-index} + - {name: innertags, type: "Const[List[int8_t]]", dir: in, role: UnionArray2-tags} + - {name: innerindex, type: "Const[List[int64_t]]", dir: in, role: IndexedArray2-index} + - {name: towhich, type: "int64_t", dir: in, role: default} + - {name: innerwhich, type: "int64_t", dir: in, role: UnionArray1-which} + - {name: outerwhich, type: "int64_t", dir: in, role: UnionArray2-which} + - {name: length, type: "int64_t", dir: in, role: default} + - {name: base, type: "int64_t", dir: in, role: default} - name: awkward_UnionArray8_32_simplify8_32_to8_64 args: - {name: totags, type: "List[int8_t]", dir: out} @@ -3645,6 +3666,16 @@ kernels: - name: awkward_UnionArray_simplify_one specializations: + - name: awkward_UnionArray64_64_simplify_one_to8_64 + args: + - {name: totags, type: "List[int8_t]", dir: out} + - {name: toindex, type: "List[int64_t]", dir: out} + - {name: fromtags, type: "Const[List[int64_t]]", dir: in, role: UnionArray-tags} + - {name: fromindex, type: "Const[List[int64_t]]", dir: in, role: IndexedArray-index} + - {name: towhich, type: "int64_t", dir: in, role: default} + - {name: fromwhich, type: "int64_t", dir: in, role: UnionArray-which} + - {name: length, type: "int64_t", dir: in, role: default} + - {name: base, type: "int64_t", dir: in, role: default} - name: awkward_UnionArray8_32_simplify_one_to8_64 args: - {name: totags, type: "List[int8_t]", dir: out} diff --git a/src/awkward/contents/unionarray.py b/src/awkward/contents/unionarray.py index bb3b6ca929..5bac5fda40 100644 --- a/src/awkward/contents/unionarray.py +++ b/src/awkward/contents/unionarray.py @@ -48,6 +48,7 @@ np = NumpyMetadata.instance() numpy = Numpy.instance() +MAX_UNION_CONTENTS = 2**7 # We use int8 tags, 0-127 @final @@ -230,6 +231,10 @@ def simplified( parameters=None, mergebool=False, ): + # Note: to help merge more than 128 arrays, tags *can* have type ak.index.Index64. + # This is only supported when index is also Index64, + # and all indexed contents are also Index64. + # We still require that this reduces to no more than 128 variants. self_index = index self_tags = tags self_contents = contents @@ -299,6 +304,10 @@ def simplified( # Did we fail to merge any of the final outer contents with this inner union content? if unmerged: + if len(contents) >= MAX_UNION_CONTENTS: + raise ValueError( + "UnionArray does not support more than 128 content types" + ) backend.maybe_kernel_error( backend[ "awkward_UnionArray_simplify", @@ -373,6 +382,10 @@ def simplified( break if unmerged: + if len(contents) >= MAX_UNION_CONTENTS: + raise ValueError( + "UnionArray does not support more than 128 content types" + ) backend.maybe_kernel_error( backend[ "awkward_UnionArray_simplify_one", @@ -393,11 +406,6 @@ def simplified( ) contents.append(self_cont) - if len(contents) > 2**7: - raise NotImplementedError( - "FIXME: handle UnionArray with more than 127 contents" - ) - # If any contents are non-categorical index types, we can merge them into the union # This is safe, because any remaining index types at this point in the routine are not considered # mergeable with the other contents. This means none of the other contents are option or index types, @@ -1107,8 +1115,8 @@ def _reverse_merge(self, other): ) ) - if len(contents) > 2**7: - raise AssertionError("FIXME: handle UnionArray with more than 127 contents") + if len(contents) > MAX_UNION_CONTENTS: + raise ValueError("UnionArray cannot have more than 128 content types") return ak.contents.UnionArray.simplified( tags, index, contents, parameters=self._parameters @@ -1236,8 +1244,8 @@ def _mergemany(self, others: Sequence[Content]) -> Content: nextcontents.append(array) - if len(nextcontents) > 127: - raise ValueError("FIXME: handle UnionArray with more than 127 contents") + if len(nextcontents) > MAX_UNION_CONTENTS: + raise ValueError("UnionArray cannot have more than 128 content types") next = ak.contents.UnionArray.simplified( nexttags, diff --git a/src/awkward/operations/ak_concatenate.py b/src/awkward/operations/ak_concatenate.py index f35f1baf22..fb8fcf94ae 100644 --- a/src/awkward/operations/ak_concatenate.py +++ b/src/awkward/operations/ak_concatenate.py @@ -246,7 +246,12 @@ def action(inputs, depth, backend, **kwargs): prototype[start : start + size] = tag start += size - tags = ak.index.Index8( + if len(regulararrays) < 2**7: + tags_cls = ak.index.Index8 + else: + tags_cls = ak.index.Index64 + + tags = tags_cls( backend.index_nplike.reshape( backend.index_nplike.broadcast_to( prototype, (length, prototype.size) @@ -265,10 +270,9 @@ def action(inputs, depth, backend, **kwargs): return (ak.contents.RegularArray(inner, prototype.size),) elif all( - isinstance(x, ak.contents.Content) - and x.is_list - or (isinstance(x, ak.contents.NumpyArray) and x.data.ndim > 1) - or not isinstance(x, ak.contents.Content) + (isinstance(x, ak.contents.Content) and x.is_list) # Case 1 + or (isinstance(x, ak.contents.NumpyArray) and x.data.ndim > 1) # Case 2 + or not isinstance(x, ak.contents.Content) # Case 3: scalar value for x in inputs ): nextinputs = [] @@ -276,6 +280,8 @@ def action(inputs, depth, backend, **kwargs): if isinstance(x, ak.contents.Content): nextinputs.append(x) else: + # Treat non-content inputs as scalars. + # These become arrays of matching length. nextinputs.append( ak.contents.ListOffsetArray( ak.index.Index64( @@ -302,7 +308,7 @@ def action(inputs, depth, backend, **kwargs): all_flatten = [] for x in nextinputs: - o, f = x._offsets_and_flattened(1, 1) + o, f = x._offsets_and_flattened(axis=1, depth=1) c = o.data[1:] - o.data[:-1] backend.index_nplike.add(counts, c, maybe_out=counts) all_counts.append(c) @@ -316,10 +322,15 @@ def action(inputs, depth, backend, **kwargs): offsets = ak.index.Index64(offsets, nplike=backend.index_nplike) + if len(nextinputs) < 2**7: + tags_cls = ak.index.Index8 + else: + tags_cls = ak.index.Index64 tags, index = ak.contents.UnionArray.nested_tags_index( offsets, [ak.index.Index64(x) for x in all_counts], backend=backend, + tags_cls=tags_cls, ) inner = ak.contents.UnionArray.simplified( diff --git a/tests/test_2881_ak_concatenate_fails_on_too_many_nested_arrays.py b/tests/test_2881_ak_concatenate_fails_on_too_many_nested_arrays.py new file mode 100644 index 0000000000..0dcce21a5b --- /dev/null +++ b/tests/test_2881_ak_concatenate_fails_on_too_many_nested_arrays.py @@ -0,0 +1,34 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import awkward as ak + +# from awkward.operations import to_list + + +def test_concatenate_as_reported(): + a = ak.Array([[1]]) + a_concat_128 = ak.concatenate([a for i in range(128)], axis=1) + assert a_concat_128.to_list() == [[1] * 128] + + a_concat_129 = ak.concatenate([a for i in range(129)], axis=1) + assert a_concat_129.to_list() == [[1] * 129] + + +def test_concatenate_inner_union_simplify_one(): + a = ak.Array([[99]]) + astr = ak.Array(["a b c d".split()]) + aa = [a for i in range(129)] + [astr] + + cu = ak.concatenate(aa, axis=1) + assert cu.to_list() == [[99] * 129 + ["a", "b", "c", "d"]] + + +def test_concatenate_inner_union_simplify(): + a = ak.Array([[99]]) + amulti = ak.Array([[1, 2, "a", "b"]]) + aa = [a for i in range(129)] + [amulti] + + cu = ak.concatenate(aa, axis=1) + assert cu.to_list() == [[99] * 129 + [1, 2, "a", "b"]]