Skip to content

Commit

Permalink
fix: update __class__ for both layout and behavior consistently (
Browse files Browse the repository at this point in the history
…#2759)

* fix: set `__class__` if `layout` OR `behavior` change

* fix: change `__class__` under various circumstances

* fix: set private attributes first

* refactor: just use `._update_class`

* test: add trivial test

* fix: correct pickling
  • Loading branch information
agoose77 authored Oct 16, 2023
1 parent c10b381 commit 43adb65
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 16 deletions.
55 changes: 39 additions & 16 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,18 @@ def __init__(
if backend is not None and backend != ak.operations.backend(layout):
layout = ak.operations.to_backend(layout, backend, highlevel=False)

self.layout = layout
self.behavior = behavior
if behavior is not None and not isinstance(behavior, Mapping):
raise TypeError("behavior must be None or mapping")

self._layout = layout
self._behavior = behavior

docstr = layout.purelist_parameter("__doc__")
if isinstance(docstr, str):
self.__doc__ = docstr

self._update_class()

if check_valid:
ak.operations.validity_error(self, exception=True)

Expand All @@ -330,6 +335,10 @@ def __init_subclass__(cls, **kwargs):

_histogram_module_ = awkward._connect.hist

def _update_class(self):
self._numbaview = None
self.__class__ = get_array_class(self._layout, self._behavior)

@property
def layout(self):
"""
Expand Down Expand Up @@ -377,7 +386,7 @@ def layout(self):
def layout(self, layout):
if isinstance(layout, ak.contents.Content):
self._layout = layout
self._numbaview = None
self._update_class()
else:
raise TypeError("layout must be a subclass of ak.contents.Content")

Expand All @@ -403,8 +412,8 @@ def behavior(self):
@behavior.setter
def behavior(self, behavior):
if behavior is None or isinstance(behavior, Mapping):
self.__class__ = get_array_class(self._layout, behavior)
self._behavior = behavior
self._update_class()
else:
raise TypeError("behavior must be None or a dict")

Expand Down Expand Up @@ -1516,8 +1525,9 @@ def __setstate__(self, state):
buffer_key="{form_key}-{attribute}",
byteorder="<",
)
self.layout = layout
self.behavior = behavior
self._layout = layout
self._behavior = behavior
self._update_class()

def __copy__(self):
return Array(self._layout, behavior=self._behavior)
Expand Down Expand Up @@ -1556,9 +1566,9 @@ def cpp_type(self):

if self._cpp_type is None:
self._generator = ak._connect.cling.togenerator(
self.layout.form, flatlist_as_rvec=False
self._layout.form, flatlist_as_rvec=False
)
self._lookup = ak._lookup.Lookup(self.layout)
self._lookup = ak._lookup.Lookup(self._layout)
self._generator.generate(cppyy.cppdef)
self._cpp_type = f"awkward::{self._generator.class_type()}"

Expand Down Expand Up @@ -1659,13 +1669,18 @@ def __init__(
if library is not None and library != ak.operations.library(layout):
layout = ak.operations.to_library(layout, library, highlevel=False)

self.layout = layout
self.behavior = behavior
if behavior is not None and not isinstance(behavior, Mapping):
raise TypeError("behavior must be None or mapping")

self._layout = layout
self._behavior = behavior

docstr = layout.purelist_parameter("__doc__")
if isinstance(docstr, str):
self.__doc__ = docstr

self._update_class()

if check_valid:
ak.operations.validity_error(self, exception=True)

Expand All @@ -1674,6 +1689,10 @@ def __init_subclass__(cls, **kwargs):

ak.jax.register_behavior_class(cls)

def _update_class(self):
self._numbaview = None
self.__class__ = get_record_class(self._layout, self._behavior)

@property
def layout(self):
"""
Expand Down Expand Up @@ -1715,7 +1734,7 @@ def layout(self):
def layout(self, layout):
if isinstance(layout, ak.record.Record):
self._layout = layout
self._numbaview = None
self._update_class()
else:
raise TypeError("layout must be a subclass of ak.record.Record")

Expand All @@ -1741,8 +1760,8 @@ def behavior(self):
@behavior.setter
def behavior(self, behavior):
if behavior is None or isinstance(behavior, Mapping):
self.__class__ = get_record_class(self._layout, behavior)
self._behavior = behavior
self._update_class()
else:
raise TypeError("behavior must be None or a dict")

Expand Down Expand Up @@ -2177,8 +2196,9 @@ def __setstate__(self, state):
byteorder="<",
)
layout = ak.record.Record(layout, at)
self.layout = layout
self.behavior = behavior
self._layout = layout
self._behavior = behavior
self._update_class()

def __copy__(self):
return Record(self._layout, behavior=self._behavior)
Expand Down Expand Up @@ -2329,8 +2349,11 @@ class ArrayBuilder(Sized):
"""

def __init__(self, *, behavior=None, initial=1024, resize=8):
if behavior is not None and not isinstance(behavior, Mapping):
raise TypeError("behavior must be None or mapping")

self._layout = _ext.ArrayBuilder(initial=initial, resize=resize)
self.behavior = behavior
self._behavior = behavior

@classmethod
def _wrap(cls, layout, behavior=None):
Expand All @@ -2350,7 +2373,7 @@ def _wrap(cls, layout, behavior=None):
assert isinstance(layout, _ext.ArrayBuilder)
out = cls.__new__(cls)
out._layout = layout
out.behavior = behavior
out._behavior = behavior
return out

@property
Expand Down
53 changes: 53 additions & 0 deletions tests/test_2759_update_class_consistently.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE


import awkward as ak


class ArrayBehavior(ak.Array):
def impl(self):
return True


class RecordBehavior(ak.Record):
def impl(self):
return True


BEHAVIOR = {("*", "impl"): ArrayBehavior, "impl": RecordBehavior}


def test_array_layout():
array = ak.Array([{"x": 1}, {"y": 3}], behavior=BEHAVIOR)
assert not isinstance(array, ArrayBehavior)

array.layout = ak.with_name([{"x": 1}, {"y": 3}], "impl", highlevel=False)
assert isinstance(array, ArrayBehavior)
assert array.impl()


def test_array_behavior():
array = ak.Array([{"x": 1}, {"y": 3}], with_name="impl")
assert not isinstance(array, ArrayBehavior)

array.behavior = BEHAVIOR
assert isinstance(array, ArrayBehavior)
assert array.impl()


def test_record_layout():
record = ak.Record({"x": 1}, behavior=BEHAVIOR)
assert not isinstance(record, RecordBehavior)

record.layout = ak.with_name({"x": 1}, "impl", highlevel=False)
assert isinstance(record, RecordBehavior)
assert record.impl()


def test_record_behavior():
record = ak.Record({"x": 1}, with_name="impl")
assert not isinstance(record, RecordBehavior)

record.behavior = BEHAVIOR
assert isinstance(record, RecordBehavior)
assert record.impl()

0 comments on commit 43adb65

Please sign in to comment.