Skip to content

Commit

Permalink
fix: always return input for non-promoted scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Oct 24, 2023
1 parent 409df30 commit aa20214
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions src/awkward/operations/ak_to_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,13 @@ def maybe_merge_mappings(primary, secondary):
return {**primary, **secondary}


def _handle_as_scalar(obj, layout, *, scalar_policy):
def _handle_as_scalar(obj, *, scalar_policy):
assert scalar_policy in ("allow", "promote", "error")

if scalar_policy == "allow":
return layout[0]
return obj
elif scalar_policy == "promote":
layout = ak.operations.from_iter([obj], highlevel=False)
return layout
else:
assert scalar_policy == "error"
Expand All @@ -110,8 +111,17 @@ def _handle_as_scalar(obj, layout, *, scalar_policy):


def _handle_array_like(obj, layout, *, scalar_policy):
assert scalar_policy in ("allow", "promote", "error")
if obj.ndim == 0:
return _handle_as_scalar(obj, layout, scalar_policy=scalar_policy)
if scalar_policy == "allow":
return obj
elif scalar_policy == "promote":
return layout
else:
assert scalar_policy == "error"
raise TypeError(
f"Encountered a scalar ({type(obj).__name__}), but scalars conversion/promotion is disabled"
)
else:
return layout

Expand Down Expand Up @@ -204,12 +214,10 @@ def _impl(
f"Encountered a scalar ({type(obj).__name__}), but scalars conversion/promotion is disabled"
)
elif isinstance(obj, (datetime, date, time, Number, bool)):
layout = ak.operations.from_iter([obj], highlevel=False)
return _handle_as_scalar(obj, layout, scalar_policy=scalar_policy)
return _handle_as_scalar(obj, scalar_policy=scalar_policy)
elif obj is None:
if allow_none:
layout = ak.operations.from_iter([obj], highlevel=False)
return _handle_as_scalar(obj, layout, scalar_policy=scalar_policy)
return _handle_as_scalar(obj, scalar_policy=scalar_policy)
else:
raise TypeError("Encountered None value, and `allow_none` is `False`")
# Iterables
Expand Down

0 comments on commit aa20214

Please sign in to comment.