Skip to content

Commit

Permalink
refactor: drop use of behavior in recursively_apply (#2805)
Browse files Browse the repository at this point in the history
* refactor: drop use of `behavior` in `recursively_apply`

* refactor: remove remaining `behavior` arguments

* refactor: imports
  • Loading branch information
agoose77 authored Nov 7, 2023
1 parent aafd2b4 commit a4ebc3b
Show file tree
Hide file tree
Showing 83 changed files with 238 additions and 212 deletions.
1 change: 0 additions & 1 deletion src/awkward/_do.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def recursively_apply(
if isinstance(layout, Content):
return layout._recursively_apply(
action,
behavior,
1,
copy.copy(depth_context),
lateral_context,
Expand Down
24 changes: 14 additions & 10 deletions src/awkward/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
import json
import math
from collections.abc import MutableMapping, Sequence
from collections.abc import Mapping, MutableMapping, Sequence

import awkward as ak
from awkward._backends.backend import Backend
Expand All @@ -31,9 +31,11 @@
from awkward._util import UNSET
from awkward.contents.bytemaskedarray import ByteMaskedArray
from awkward.contents.content import (
ApplyActionOptions,
Content,
RemoveStructureOptionsType,
ToArrowOptionsType,
ImplementsApplyAction,
RemoveStructureOptions,
ToArrowOptions,
)
from awkward.forms.bitmaskedform import BitMaskedForm
from awkward.forms.form import Form
Expand Down Expand Up @@ -695,7 +697,7 @@ def _to_arrow(
mask_node: Content | None,
validbytes: Content | None,
length: int,
options: ToArrowOptionsType,
options: ToArrowOptions,
):
return self.to_ByteMaskedArray()._to_arrow(
pyarrow, mask_node, validbytes, length, options
Expand All @@ -705,7 +707,7 @@ def _to_backend_array(self, allow_missing, backend):
return self.to_ByteMaskedArray()._to_backend_array(allow_missing, backend)

def _remove_structure(
self, backend: Backend, options: RemoveStructureOptionsType
self, backend: Backend, options: RemoveStructureOptions
) -> list[Content]:
branch, depth = self.branch_depth
if branch or options["drop_nones"] or depth > 1:
Expand All @@ -717,8 +719,13 @@ def _drop_none(self) -> Content:
return self.to_ByteMaskedArray()._drop_none()

def _recursively_apply(
self, action, behavior, depth, depth_context, lateral_context, options
):
self,
action: ImplementsApplyAction,
depth: int,
depth_context: Mapping[str, Any] | None,
lateral_context: Mapping[str, Any] | None,
options: ApplyActionOptions,
) -> Content | None:
if self._backend.nplike.known_data:
content = self._content[0 : self._length]
else:
Expand All @@ -735,7 +742,6 @@ def continuation():
self._mask,
content._recursively_apply(
action,
behavior,
depth,
copy.copy(depth_context),
lateral_context,
Expand All @@ -752,7 +758,6 @@ def continuation():
def continuation():
content._recursively_apply(
action,
behavior,
depth,
copy.copy(depth_context),
lateral_context,
Expand All @@ -765,7 +770,6 @@ def continuation():
depth_context=depth_context,
lateral_context=lateral_context,
continuation=continuation,
behavior=behavior,
backend=self._backend,
options=options,
)
Expand Down
24 changes: 14 additions & 10 deletions src/awkward/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
import json
import math
from collections.abc import MutableMapping, Sequence
from collections.abc import Mapping, MutableMapping, Sequence

import awkward as ak
from awkward._backends.backend import Backend
Expand All @@ -32,9 +32,11 @@
)
from awkward._util import UNSET
from awkward.contents.content import (
ApplyActionOptions,
Content,
RemoveStructureOptionsType,
ToArrowOptionsType,
ImplementsApplyAction,
RemoveStructureOptions,
ToArrowOptions,
)
from awkward.errors import AxisError
from awkward.forms.bytemaskedform import ByteMaskedForm
Expand Down Expand Up @@ -1051,7 +1053,7 @@ def _to_arrow(
mask_node: Content | None,
validbytes: Content | None,
length: int,
options: ToArrowOptionsType,
options: ToArrowOptions,
):
this_validbytes = self.mask_as_bool(valid_when=True)

Expand All @@ -1067,7 +1069,7 @@ def _to_backend_array(self, allow_missing, backend):
return self.to_IndexedOptionArray64()._to_backend_array(allow_missing, backend)

def _remove_structure(
self, backend: Backend, options: RemoveStructureOptionsType
self, backend: Backend, options: RemoveStructureOptions
) -> list[Content]:
branch, depth = self.branch_depth
if branch or options["drop_nones"] or depth > 1:
Expand All @@ -1079,8 +1081,13 @@ def _drop_none(self) -> Content:
return self.project()

def _recursively_apply(
self, action, behavior, depth, depth_context, lateral_context, options
):
self,
action: ImplementsApplyAction,
depth: int,
depth_context: Mapping[str, Any] | None,
lateral_context: Mapping[str, Any] | None,
options: ApplyActionOptions,
) -> Content | None:
if self._backend.nplike.known_data:
content = self._content[0 : self._mask.length]
else:
Expand All @@ -1097,7 +1104,6 @@ def continuation():
self._mask,
content._recursively_apply(
action,
behavior,
depth,
copy.copy(depth_context),
lateral_context,
Expand All @@ -1112,7 +1118,6 @@ def continuation():
def continuation():
content._recursively_apply(
action,
behavior,
depth,
copy.copy(depth_context),
lateral_context,
Expand All @@ -1125,7 +1130,6 @@ def continuation():
depth_context=depth_context,
lateral_context=lateral_context,
continuation=continuation,
behavior=behavior,
backend=self._backend,
options=options,
)
Expand Down
36 changes: 26 additions & 10 deletions src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
AxisMaybeNone,
JSONMapping,
Literal,
Protocol,
Self,
SupportsIndex,
TypeAlias,
Expand Down Expand Up @@ -72,7 +73,23 @@
"""


class RecursivelyApplyOptionsType(TypedDict):
class ImplementsApplyAction(Protocol):
def __call__(
self,
layout: Content,
*,
depth: int,
depth_context: Mapping[str, Any] | None,
lateral_context: Mapping[str, Any] | None,
continuation: Callable[[], Content],
behavior: Mapping | None,
backend: Backend,
options: ApplyActionOptions,
) -> Content | None:
...


class ApplyActionOptions(TypedDict):
allow_records: bool
keep_parameters: bool
numpy_to_regular: bool
Expand All @@ -82,7 +99,7 @@ class RecursivelyApplyOptionsType(TypedDict):
function_name: str | None


class RemoveStructureOptionsType(TypedDict):
class RemoveStructureOptions(TypedDict):
flatten_records: bool
function_name: str
drop_nones: bool
Expand All @@ -91,7 +108,7 @@ class RemoveStructureOptionsType(TypedDict):
list_to_regular: bool


class ToArrowOptionsType(TypedDict):
class ToArrowOptions(TypedDict):
list_to32: bool
string_to32: bool
bytestring_to32: bool
Expand Down Expand Up @@ -1090,7 +1107,7 @@ def _to_arrow(
mask_node: Content | None,
validbytes: Content | None,
length: int,
options: ToArrowOptionsType,
options: ToArrowOptions,
):
raise NotImplementedError

Expand All @@ -1113,18 +1130,17 @@ def _drop_none(self) -> Content:
raise NotImplementedError

def _remove_structure(
self, backend: Backend, options: RemoveStructureOptionsType
self, backend: Backend, options: RemoveStructureOptions
) -> list[Content]:
raise NotImplementedError

def _recursively_apply(
self,
action: ActionType,
behavior: dict | None,
action: ImplementsApplyAction,
depth: int,
depth_context: dict[str, Any] | None,
lateral_context: dict[str, Any] | None,
options: RecursivelyApplyOptionsType,
depth_context: Mapping[str, Any] | None,
lateral_context: Mapping[str, Any] | None,
options: ApplyActionOptions,
) -> Content | None:
raise NotImplementedError

Expand Down
22 changes: 14 additions & 8 deletions src/awkward/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from collections.abc import MutableMapping, Sequence
from collections.abc import Mapping, MutableMapping, Sequence

import awkward as ak
from awkward._backends.backend import Backend
Expand All @@ -27,9 +27,11 @@
)
from awkward._util import UNSET
from awkward.contents.content import (
ApplyActionOptions,
Content,
RemoveStructureOptionsType,
ToArrowOptionsType,
ImplementsApplyAction,
RemoveStructureOptions,
ToArrowOptions,
)
from awkward.errors import AxisError
from awkward.forms.emptyform import EmptyForm
Expand Down Expand Up @@ -364,7 +366,7 @@ def _to_arrow(
mask_node: Content | None,
validbytes: Content | None,
length: int,
options: ToArrowOptionsType,
options: ToArrowOptions,
):
if options["emptyarray_to"] is None:
return pyarrow.Array.from_buffers(
Expand Down Expand Up @@ -395,13 +397,18 @@ def _to_backend_array(self, allow_missing, backend):
return backend.nplike.empty(0, dtype=np.float64)

def _remove_structure(
self, backend: Backend, options: RemoveStructureOptionsType
self, backend: Backend, options: RemoveStructureOptions
) -> list[Content]:
return [self]

def _recursively_apply(
self, action, behavior, depth, depth_context, lateral_context, options
):
self,
action: ImplementsApplyAction,
depth: int,
depth_context: Mapping[str, Any] | None,
lateral_context: Mapping[str, Any] | None,
options: ApplyActionOptions,
) -> Content | None:
if options["return_array"]:

def continuation():
Expand All @@ -421,7 +428,6 @@ def continuation():
depth_context=depth_context,
lateral_context=lateral_context,
continuation=continuation,
behavior=behavior,
backend=self._backend,
options=options,
)
Expand Down
24 changes: 14 additions & 10 deletions src/awkward/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import copy
from collections.abc import MutableMapping, Sequence
from collections.abc import Mapping, MutableMapping, Sequence

import awkward as ak
from awkward._backends.backend import Backend
Expand Down Expand Up @@ -31,9 +31,11 @@
)
from awkward._util import UNSET
from awkward.contents.content import (
ApplyActionOptions,
Content,
RemoveStructureOptionsType,
ToArrowOptionsType,
ImplementsApplyAction,
RemoveStructureOptions,
ToArrowOptions,
)
from awkward.errors import AxisError
from awkward.forms.form import Form
Expand Down Expand Up @@ -1004,7 +1006,7 @@ def _to_arrow(
mask_node: Content | None,
validbytes: Content | None,
length: int,
options: ToArrowOptionsType,
options: ToArrowOptions,
):
if (
not options["categorical_as_dictionary"]
Expand Down Expand Up @@ -1058,13 +1060,18 @@ def _to_backend_array(self, allow_missing, backend):
return self.project()._to_backend_array(allow_missing, backend)

def _remove_structure(
self, backend: Backend, options: RemoveStructureOptionsType
self, backend: Backend, options: RemoveStructureOptions
) -> list[Content]:
return self.project()._remove_structure(backend, options)

def _recursively_apply(
self, action, behavior, depth, depth_context, lateral_context, options
):
self,
action: ImplementsApplyAction,
depth: int,
depth_context: Mapping[str, Any] | None,
lateral_context: Mapping[str, Any] | None,
options: ApplyActionOptions,
) -> Content | None:
if (
self._backend.nplike.known_data
and self._backend.nplike.known_data
Expand Down Expand Up @@ -1092,7 +1099,6 @@ def continuation():
index,
content._recursively_apply(
action,
behavior,
depth,
copy.copy(depth_context),
lateral_context,
Expand All @@ -1106,7 +1112,6 @@ def continuation():
def continuation():
content._recursively_apply(
action,
behavior,
depth,
copy.copy(depth_context),
lateral_context,
Expand All @@ -1119,7 +1124,6 @@ def continuation():
depth_context=depth_context,
lateral_context=lateral_context,
continuation=continuation,
behavior=behavior,
backend=self._backend,
options=options,
)
Expand Down
Loading

0 comments on commit a4ebc3b

Please sign in to comment.