diff --git a/src/awkward/operations/str/__init__.py b/src/awkward/operations/str/__init__.py index 4204afb854..b07d252347 100644 --- a/src/awkward/operations/str/__init__.py +++ b/src/awkward/operations/str/__init__.py @@ -84,7 +84,7 @@ from awkward.operations.str.akstr_upper import * -def _drop_option_preserving_form(layout): +def _drop_option_preserving_form(layout, ensure_empty_mask: bool = False): from awkward._do import recursively_apply from awkward.contents import UnmaskedArray, IndexedOptionArray, IndexedArray @@ -98,7 +98,8 @@ def action(_, continuation, **kwargs): else: index_nplike = this.backend.index_nplike assert not ( - index_nplike.known_data + ensure_empty_mask + and index_nplike.known_data and index_nplike.any(this.mask_as_bool(valid_when=False)) ), "did not expect option type, but arrow returned a non-erasable option" # Re-write indexed options as indexed @@ -120,6 +121,7 @@ def _apply_through_arrow( expect_option_type=False, string_to32=False, bytestring_to32=False, + ensure_empty_mask=False, **kwargs, ): from awkward._backends.dispatch import backend_of @@ -170,7 +172,7 @@ def _apply_through_arrow( if expect_option_type: return out else: - return _drop_option_preserving_form(out) + return _drop_option_preserving_form(out, ensure_empty_mask=ensure_empty_mask) def _get_ufunc_action( diff --git a/src/awkward/operations/str/akstr_index_in.py b/src/awkward/operations/str/akstr_index_in.py index 61805e6ea0..cf514fb727 100644 --- a/src/awkward/operations/str/akstr_index_in.py +++ b/src/awkward/operations/str/akstr_index_in.py @@ -57,6 +57,7 @@ def _is_maybe_optional_list_of_string(layout): def _impl(array, value_set, skip_nones, highlevel, behavior): from awkward._connect.pyarrow import import_pyarrow_compute + from awkward.operations.str import _apply_through_arrow pc = import_pyarrow_compute("ak.str.index_in") @@ -71,32 +72,14 @@ def _impl(array, value_set, skip_nones, highlevel, behavior): def apply(layout, **kwargs): if _is_maybe_optional_list_of_string(layout): - if layout.backend is typetracer: - return ak.from_arrow( - pc.index_in( - ak.to_arrow( - layout.form.length_zero_array(highlevel=False), - extensionarray=False, - ), - ak.to_arrow( - value_set_layout.form.length_zero_array(highlevel=False), - extensionarray=False, - ), - skip_nulls=skip_nones, - ), - highlevel=False, - generate_bitmasks=True, - ).to_typetracer(forget_length=True) - else: - return ak.from_arrow( - pc.index_in( - ak.to_arrow(layout, extensionarray=False), - ak.to_arrow(value_set_layout, extensionarray=False), - skip_nulls=skip_nones, - ), - highlevel=False, - generate_bitmasks=True, - ) + return _apply_through_arrow( + pc.index_in, + layout, + value_set_layout, + skip_nulls=skip_nones, + expect_option_type=True, + generate_bitmasks=True, + ) out = ak._do.recursively_apply(layout, apply, behavior=behavior) diff --git a/src/awkward/operations/str/akstr_is_in.py b/src/awkward/operations/str/akstr_is_in.py index 5d8f50a4cf..315fb03124 100644 --- a/src/awkward/operations/str/akstr_is_in.py +++ b/src/awkward/operations/str/akstr_is_in.py @@ -56,6 +56,7 @@ def _is_maybe_optional_list_of_string(layout): def _impl(array, value_set, skip_nones, highlevel, behavior): from awkward._connect.pyarrow import import_pyarrow_compute + from awkward.operations.str import _apply_through_arrow pc = import_pyarrow_compute("ak.str.is_in") @@ -70,30 +71,9 @@ def _impl(array, value_set, skip_nones, highlevel, behavior): def apply(layout, **kwargs): if _is_maybe_optional_list_of_string(layout): - if layout.backend is typetracer: - return ak.from_arrow( - pc.is_in( - ak.to_arrow( - layout.form.length_zero_array(highlevel=False), - extensionarray=False, - ), - ak.to_arrow( - value_set_layout.form.length_zero_array(highlevel=False), - extensionarray=False, - ), - skip_nulls=skip_nones, - ), - highlevel=False, - ).to_typetracer(forget_length=True) - else: - return ak.from_arrow( - pc.is_in( - ak.to_arrow(layout, extensionarray=False), - ak.to_arrow(value_set_layout, extensionarray=False), - skip_nulls=skip_nones, - ), - highlevel=False, - ) + return _apply_through_arrow( + pc.is_in, layout, value_set_layout, skip_nulls=skip_nones + ) out = ak._do.recursively_apply(layout, apply, behavior=behavior)