Skip to content

Commit

Permalink
fix: don't preserve unexpected option for is_in (#2792)
Browse files Browse the repository at this point in the history
* fix: don't preserve unexpected option for is_in

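PyArrow 14 changed this behavior, so `is_in` is now routed through `_apply_through_arrow`, which drops the unexpected option type from the result.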

* feat: don't check masks by default (the assertion that the option mask is empty now runs only when `ensure_empty_mask=True`)
  • Loading branch information
agoose77 authored Nov 1, 2023
1 parent 2fbaa2c commit 9dc10a0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 53 deletions.
8 changes: 5 additions & 3 deletions src/awkward/operations/str/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
from awkward.operations.str.akstr_upper import *


def _drop_option_preserving_form(layout):
def _drop_option_preserving_form(layout, ensure_empty_mask: bool = False):
from awkward._do import recursively_apply
from awkward.contents import UnmaskedArray, IndexedOptionArray, IndexedArray

Expand All @@ -98,7 +98,8 @@ def action(_, continuation, **kwargs):
else:
index_nplike = this.backend.index_nplike
assert not (
index_nplike.known_data
ensure_empty_mask
and index_nplike.known_data
and index_nplike.any(this.mask_as_bool(valid_when=False))
), "did not expect option type, but arrow returned a non-erasable option"
# Re-write indexed options as indexed
Expand All @@ -120,6 +121,7 @@ def _apply_through_arrow(
expect_option_type=False,
string_to32=False,
bytestring_to32=False,
ensure_empty_mask=False,
**kwargs,
):
from awkward._backends.dispatch import backend_of
Expand Down Expand Up @@ -170,7 +172,7 @@ def _apply_through_arrow(
if expect_option_type:
return out
else:
return _drop_option_preserving_form(out)
return _drop_option_preserving_form(out, ensure_empty_mask=ensure_empty_mask)


def _get_ufunc_action(
Expand Down
35 changes: 9 additions & 26 deletions src/awkward/operations/str/akstr_index_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def _is_maybe_optional_list_of_string(layout):

def _impl(array, value_set, skip_nones, highlevel, behavior):
from awkward._connect.pyarrow import import_pyarrow_compute
from awkward.operations.str import _apply_through_arrow

pc = import_pyarrow_compute("ak.str.index_in")

Expand All @@ -71,32 +72,14 @@ def _impl(array, value_set, skip_nones, highlevel, behavior):

def apply(layout, **kwargs):
if _is_maybe_optional_list_of_string(layout):
if layout.backend is typetracer:
return ak.from_arrow(
pc.index_in(
ak.to_arrow(
layout.form.length_zero_array(highlevel=False),
extensionarray=False,
),
ak.to_arrow(
value_set_layout.form.length_zero_array(highlevel=False),
extensionarray=False,
),
skip_nulls=skip_nones,
),
highlevel=False,
generate_bitmasks=True,
).to_typetracer(forget_length=True)
else:
return ak.from_arrow(
pc.index_in(
ak.to_arrow(layout, extensionarray=False),
ak.to_arrow(value_set_layout, extensionarray=False),
skip_nulls=skip_nones,
),
highlevel=False,
generate_bitmasks=True,
)
return _apply_through_arrow(
pc.index_in,
layout,
value_set_layout,
skip_nulls=skip_nones,
expect_option_type=True,
generate_bitmasks=True,
)

out = ak._do.recursively_apply(layout, apply, behavior=behavior)

Expand Down
28 changes: 4 additions & 24 deletions src/awkward/operations/str/akstr_is_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def _is_maybe_optional_list_of_string(layout):

def _impl(array, value_set, skip_nones, highlevel, behavior):
from awkward._connect.pyarrow import import_pyarrow_compute
from awkward.operations.str import _apply_through_arrow

pc = import_pyarrow_compute("ak.str.is_in")

Expand All @@ -70,30 +71,9 @@ def _impl(array, value_set, skip_nones, highlevel, behavior):

def apply(layout, **kwargs):
if _is_maybe_optional_list_of_string(layout):
if layout.backend is typetracer:
return ak.from_arrow(
pc.is_in(
ak.to_arrow(
layout.form.length_zero_array(highlevel=False),
extensionarray=False,
),
ak.to_arrow(
value_set_layout.form.length_zero_array(highlevel=False),
extensionarray=False,
),
skip_nulls=skip_nones,
),
highlevel=False,
).to_typetracer(forget_length=True)
else:
return ak.from_arrow(
pc.is_in(
ak.to_arrow(layout, extensionarray=False),
ak.to_arrow(value_set_layout, extensionarray=False),
skip_nulls=skip_nones,
),
highlevel=False,
)
return _apply_through_arrow(
pc.is_in, layout, value_set_layout, skip_nulls=skip_nones
)

out = ak._do.recursively_apply(layout, apply, behavior=behavior)

Expand Down

0 comments on commit 9dc10a0

Please sign in to comment.