From dc4a7d6718a0d7e84d60a6d6b91ec10e810ab4f0 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Mon, 16 Dec 2024 12:47:51 -0500 Subject: [PATCH 1/5] control attrs better as described in issue #3277 --- src/awkward/_attrs.py | 30 +++++++++++ src/awkward/highlevel.py | 28 +++++----- tests/test_2757_attrs_metadata.py | 52 +++++++++---------- ...ize_and_deserialize_behaviour_for_numba.py | 2 +- tests/test_2806_attrs_typetracer.py | 4 +- tests/test_2837_ufunc_attrs_behavior.py | 6 +-- tests/test_2866_getitem_attrs.py | 18 +++---- ...est_3277_attrs_behavior_on_array_copies.py | 17 ++++++ 8 files changed, 102 insertions(+), 55 deletions(-) create mode 100644 tests/test_3277_attrs_behavior_on_array_copies.py diff --git a/src/awkward/_attrs.py b/src/awkward/_attrs.py index 14a42549d2..ac4f72b7f8 100644 --- a/src/awkward/_attrs.py +++ b/src/awkward/_attrs.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import Mapping +from types import MappingProxyType from awkward._typing import Any, JSONMapping @@ -42,3 +43,32 @@ def attrs_of(*arrays, attrs: Mapping | None = None) -> Mapping: def without_transient_attrs(attrs: dict[str, Any]) -> JSONMapping: return {k: v for k, v in attrs.items() if not k.startswith("@")} + + +class Attrs(Mapping): + def __init__(self, ref, data: Mapping[str, Any]): + self._ref = ref + self._data = _freeze_attrs(data) + + def __getitem__(self, key: str): + return self._data[key] + + def __setitem__(self, key: str, value: Any): + self._ref._attrs = _unfreeze_attrs(self._data) | {key: value} + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def __repr__(self): + return repr(_unfreeze_attrs(self._data)) + + +def _freeze_attrs(attrs: Mapping[str, Any]) -> Mapping[str, Any]: + return MappingProxyType(attrs) + + +def _unfreeze_attrs(attrs: Mapping[str, Any]) -> dict[str, Any]: + return dict(attrs) diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index 6d1d6649aa..6273bb0525 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -18,7 +18,7 @@ import awkward as ak import awkward._connect.hist -from awkward._attrs import attrs_of, without_transient_attrs +from awkward._attrs import Attrs, attrs_of, without_transient_attrs from awkward._backends.dispatch import register_backend_lookup_factory from awkward._backends.numpy import NumpyBackend from awkward._behavior import behavior_of, get_array_class, get_record_class @@ -42,7 +42,7 @@ unpickle_record_schema_1, ) from awkward._regularize import is_non_string_like_iterable -from awkward._typing import Any, MutableMapping, TypeVar +from awkward._typing import Any, TypeVar from awkward._util import STDOUT from awkward.prettyprint import Formatter from awkward.prettyprint import valuestr as prettyprint_valuestr @@ -337,7 +337,7 @@ def __init__( if behavior is not None and not isinstance(behavior, Mapping): raise TypeError("behavior must be None or a mapping") - if attrs is not None and not isinstance(attrs, MutableMapping): + if attrs is not None and not isinstance(attrs, Mapping): raise TypeError("attrs must be None or a mapping") if named_axis: @@ -379,7 +379,7 @@ def _update_class(self): self.__class__ = get_array_class(self._layout, self._behavior) @property - def attrs(self) -> Mapping: + def attrs(self) -> Attrs: """ The mutable mapping containing top-level metadata, which is serialised with the array during pickling. @@ -390,14 +390,14 @@ def attrs(self) -> Mapping: """ if self._attrs is None: self._attrs = {} - return self._attrs + return Attrs(self, self._attrs) @attrs.setter def attrs(self, value: Mapping[str, Any]): if isinstance(value, Mapping): - self._attrs = value + self._attrs = dict(value) else: - raise TypeError("attrs must be a mapping") + raise TypeError("attrs must be a 'Attrs' mapping") @property def layout(self): @@ -1846,7 +1846,7 @@ def __init__( if behavior is not None and not isinstance(behavior, Mapping): raise TypeError("behavior must be None or mapping") - if attrs is not None and not isinstance(attrs, MutableMapping): + if attrs is not None and not isinstance(attrs, Mapping): raise TypeError("attrs must be None or a mapping") if named_axis: @@ -1883,7 +1883,7 @@ def _update_class(self): self.__class__ = get_record_class(self._layout, self._behavior) @property - def attrs(self) -> Mapping[str, Any]: + def attrs(self) -> Attrs: """ The mapping containing top-level metadata, which is serialised with the record during pickling. @@ -1894,12 +1894,12 @@ def attrs(self) -> Mapping[str, Any]: """ if self._attrs is None: self._attrs = {} - return self._attrs + return Attrs(self, self._attrs) @attrs.setter def attrs(self, value: Mapping[str, Any]): if isinstance(value, Mapping): - self._attrs = value + self._attrs = dict(value) else: raise TypeError("attrs must be a mapping") @@ -2672,7 +2672,7 @@ def _wrap(cls, layout, behavior=None, attrs=None): return out @property - def attrs(self) -> Mapping[str, Any]: + def attrs(self) -> Attrs: """ The mapping containing top-level metadata, which is serialised with the array during pickling. @@ -2683,12 +2683,12 @@ def attrs(self) -> Mapping[str, Any]: """ if self._attrs is None: self._attrs = {} - return self._attrs + return Attrs(self, self._attrs) @attrs.setter def attrs(self, value: Mapping[str, Any]): if isinstance(value, Mapping): - self._attrs = value + self._attrs = dict(value) else: raise TypeError("attrs must be a mapping") diff --git a/tests/test_2757_attrs_metadata.py b/tests/test_2757_attrs_metadata.py index de04074860..e4356d659f 100644 --- a/tests/test_2757_attrs_metadata.py +++ b/tests/test_2757_attrs_metadata.py @@ -25,7 +25,7 @@ def test_set_attrs(): assert array.attrs == {} array.attrs = OTHER_ATTRS - assert array.attrs is OTHER_ATTRS + assert array.attrs == OTHER_ATTRS with pytest.raises(TypeError): array.attrs = "Hello world!" @@ -52,7 +52,7 @@ def test_transient_metadata_persists(): attrs = {**SOME_ATTRS, "@transient_key": lambda: None} array = ak.Array([[1, 2, 3]], attrs=attrs) num = ak.num(array) - assert num.attrs is attrs + assert num.attrs == attrs @pytest.mark.parametrize( @@ -79,13 +79,13 @@ def test_single_arg_ops(func): # Carry from argument assert ( func([[1, 2, 3, 4], [5], [10]], axis=-1, highlevel=True, attrs=SOME_ATTRS).attrs - is SOME_ATTRS + == SOME_ATTRS ) # Carry from outer array array = ak.Array([[1, 2, 3, 4], [5], [10]], attrs=SOME_ATTRS) - assert func(array, axis=-1, highlevel=True).attrs is SOME_ATTRS + assert func(array, axis=-1, highlevel=True).attrs == SOME_ATTRS # Carry from argument exclusively - assert func(array, axis=-1, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS + assert func(array, axis=-1, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS @pytest.mark.parametrize( @@ -134,15 +134,15 @@ def test_string_operations_unary(func): highlevel=True, attrs=SOME_ATTRS, ).attrs - is SOME_ATTRS + == SOME_ATTRS ) # Carry from outer array array = ak.Array( [["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS ) - assert func(array, highlevel=True).attrs is SOME_ATTRS + assert func(array, highlevel=True).attrs == SOME_ATTRS # Carry from argument exclusively - assert func(array, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS + assert func(array, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS @pytest.mark.parametrize( @@ -188,15 +188,15 @@ def test_string_operations_unary_with_arg(func, arg): highlevel=True, attrs=SOME_ATTRS, ).attrs - is SOME_ATTRS + == SOME_ATTRS ) # Carry from outer array array = ak.Array( [["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS ) - assert func(array, arg, highlevel=True).attrs is SOME_ATTRS + assert func(array, arg, highlevel=True).attrs == SOME_ATTRS # Carry from argument exclusively - assert func(array, arg, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS + assert func(array, arg, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS def test_string_operations_unary_with_arg_slice(): @@ -220,16 +220,16 @@ def test_string_operations_unary_with_arg_slice(): highlevel=True, attrs=SOME_ATTRS, ).attrs - is SOME_ATTRS + == SOME_ATTRS ) # Carry from outer array array = ak.Array( [["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS ) - assert ak.str.slice(array, 1, highlevel=True).attrs is SOME_ATTRS + assert ak.str.slice(array, 1, highlevel=True).attrs == SOME_ATTRS # Carry from argument exclusively assert ( - ak.str.slice(array, 1, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS + ak.str.slice(array, 1, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS ) @@ -262,13 +262,13 @@ def test_string_operations_binary(func): highlevel=True, attrs=SOME_ATTRS, ).attrs - is SOME_ATTRS + == SOME_ATTRS ) # Carry from first array array = ak.Array( [["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS ) - assert func(array, ["hello"], highlevel=True).attrs is SOME_ATTRS + assert func(array, ["hello"], highlevel=True).attrs == SOME_ATTRS # Carry from second array value_array = ak.Array(["hello"], attrs=OTHER_ATTRS) @@ -278,7 +278,7 @@ def test_string_operations_binary(func): value_array, highlevel=True, ).attrs - is OTHER_ATTRS + == OTHER_ATTRS ) # Carry from both arrays assert func( @@ -289,7 +289,7 @@ def test_string_operations_binary(func): # Carry from argument assert ( - func(array, value_array, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS + func(array, value_array, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS ) @@ -298,8 +298,8 @@ def test_broadcasting_arrays(): right = ak.Array([1], attrs=OTHER_ATTRS) left_result, right_result = ak.broadcast_arrays(left, right) - assert left_result.attrs is SOME_ATTRS - assert right_result.attrs is OTHER_ATTRS + assert left_result.attrs == SOME_ATTRS + assert right_result.attrs == OTHER_ATTRS def test_broadcasting_fields(): @@ -307,29 +307,29 @@ def test_broadcasting_fields(): right = ak.Array([{"y": 1}, {"y": 2}], attrs=OTHER_ATTRS) left_result, right_result = ak.broadcast_fields(left, right) - assert left_result.attrs is SOME_ATTRS - assert right_result.attrs is OTHER_ATTRS + assert left_result.attrs == SOME_ATTRS + assert right_result.attrs == OTHER_ATTRS def test_numba_arraybuilder(): numba = pytest.importorskip("numba") builder = ak.ArrayBuilder(attrs=SOME_ATTRS) - assert builder.attrs is SOME_ATTRS + assert builder.attrs == SOME_ATTRS @numba.njit def func(array): return array - assert func(builder).attrs is SOME_ATTRS + assert func(builder).attrs == SOME_ATTRS def test_numba_array(): numba = pytest.importorskip("numba") array = ak.Array([1, 2, 3], attrs=SOME_ATTRS) - assert array.attrs is SOME_ATTRS + assert array.attrs == SOME_ATTRS @numba.njit def func(array): return array - assert func(array).attrs is SOME_ATTRS + assert func(array).attrs == SOME_ATTRS diff --git a/tests/test_2770_serialize_and_deserialize_behaviour_for_numba.py b/tests/test_2770_serialize_and_deserialize_behaviour_for_numba.py index f44c513cc2..028b23cb90 100644 --- a/tests/test_2770_serialize_and_deserialize_behaviour_for_numba.py +++ b/tests/test_2770_serialize_and_deserialize_behaviour_for_numba.py @@ -19,7 +19,7 @@ def test_ArrayBuilder_behavior(): SOME_ATTRS = {"FOO": "BAR"} builder = ak.ArrayBuilder(behavior=SOME_ATTRS) - assert builder.behavior is SOME_ATTRS + assert builder.behavior == SOME_ATTRS assert func(builder).behavior == SOME_ATTRS diff --git a/tests/test_2806_attrs_typetracer.py b/tests/test_2806_attrs_typetracer.py index ea466eff04..980728c076 100644 --- a/tests/test_2806_attrs_typetracer.py +++ b/tests/test_2806_attrs_typetracer.py @@ -22,7 +22,7 @@ def test_typetracer_with_report(): form = layout.form_with_key("node{id}") meta, report = typetracer_with_report(form, highlevel=True, attrs=SOME_ATTRS) - assert meta.attrs is SOME_ATTRS + assert meta.attrs == SOME_ATTRS meta, report = typetracer_with_report(form, highlevel=True, attrs=None) assert meta._attrs is None @@ -44,5 +44,5 @@ def test_function(function): "z": [[0.1, 0.1, 0.2], [3, 1, 2], [2, 1, 2]], } ) - assert function(array, attrs=SOME_ATTRS).attrs is SOME_ATTRS + assert function(array, attrs=SOME_ATTRS).attrs == SOME_ATTRS assert function(array)._attrs is None diff --git a/tests/test_2837_ufunc_attrs_behavior.py b/tests/test_2837_ufunc_attrs_behavior.py index 86f2dcedec..0be740fbf1 100644 --- a/tests/test_2837_ufunc_attrs_behavior.py +++ b/tests/test_2837_ufunc_attrs_behavior.py @@ -14,15 +14,15 @@ def test(): def test_unary(): x = ak.Array([1, 2, 3], behavior={"foo": "BAR"}, attrs={"hello": "world"}) y = -x - assert y.attrs is x.attrs + assert y.attrs == x.attrs assert x.behavior is y.behavior def test_two_return(): x = ak.Array([1, 2, 3], behavior={"foo": "BAR"}, attrs={"hello": "world"}) y, y_ret = divmod(x, 2) - assert y.attrs is y_ret.attrs - assert y.attrs is x.attrs + assert y.attrs == y_ret.attrs + assert y.attrs == x.attrs assert y.behavior is y_ret.behavior assert y.behavior is x.behavior diff --git a/tests/test_2866_getitem_attrs.py b/tests/test_2866_getitem_attrs.py index 727edfe214..0524d6dc0e 100644 --- a/tests/test_2866_getitem_attrs.py +++ b/tests/test_2866_getitem_attrs.py @@ -11,24 +11,24 @@ def test_array_slice(): array = ak.Array([[0, 1, 2], [4]], attrs=ATTRS) - assert array.attrs is ATTRS + assert array.attrs == ATTRS - assert array[0].attrs is ATTRS - assert array[1:].attrs is ATTRS + assert array[0].attrs == ATTRS + assert array[1:].attrs == ATTRS def test_array_field(): array = ak.Array([[{"x": 1}, {"x": 2}], [{"x": 10}]], attrs=ATTRS) - assert array.attrs is ATTRS + assert array.attrs == ATTRS - assert array.x.attrs is ATTRS - assert array.x[1:].attrs is ATTRS + assert array.x.attrs == ATTRS + assert array.x[1:].attrs == ATTRS def test_record_field(): array = ak.Array([{"x": [1, 2, 3]}], attrs=ATTRS) - assert array.attrs is ATTRS + assert array.attrs == ATTRS record = array[0] - assert record.attrs is ATTRS - assert record.x.attrs is ATTRS + assert record.attrs == ATTRS + assert record.x.attrs == ATTRS diff --git a/tests/test_3277_attrs_behavior_on_array_copies.py b/tests/test_3277_attrs_behavior_on_array_copies.py new file mode 100644 index 0000000000..988f976adb --- /dev/null +++ b/tests/test_3277_attrs_behavior_on_array_copies.py @@ -0,0 +1,17 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE +# ruff: noqa: E402 + +from __future__ import annotations + +import awkward as ak + + +def test(): + arr = ak.Array([1]) + arr.attrs["foo"] = "bar" + + arr2 = ak.copy(arr) + assert arr2.attrs == arr.attrs + + arr2.attrs["foo"] = "baz" + assert arr2.attrs != arr.attrs From d4893d3814b36e5caac4d8f67c7a4c9e0f0062be Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Mon, 16 Dec 2024 13:10:37 -0500 Subject: [PATCH 2/5] break cyclic ref with weakref --- src/awkward/_attrs.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/awkward/_attrs.py b/src/awkward/_attrs.py index ac4f72b7f8..03626c7f9c 100644 --- a/src/awkward/_attrs.py +++ b/src/awkward/_attrs.py @@ -1,6 +1,7 @@ # BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE from __future__ import annotations +import weakref from collections.abc import Mapping from types import MappingProxyType @@ -47,14 +48,18 @@ def without_transient_attrs(attrs: dict[str, Any]) -> JSONMapping: class Attrs(Mapping): def __init__(self, ref, data: Mapping[str, Any]): - self._ref = ref + self._ref = weakref.ref(ref) self._data = _freeze_attrs(data) def __getitem__(self, key: str): return self._data[key] def __setitem__(self, key: str, value: Any): - self._ref._attrs = _unfreeze_attrs(self._data) | {key: value} + ref = self._ref() + if ref is None: + msg = "The reference array has been deleted. If you still need to set attributes, convert this 'Attrs' instance to a dict with '.to_dict()'." + raise ValueError(msg) + ref._attrs = _unfreeze_attrs(self._data) | {key: value} def __iter__(self): return iter(self._data) @@ -63,7 +68,10 @@ def __len__(self): return len(self._data) def __repr__(self): - return repr(_unfreeze_attrs(self._data)) + return f"Attrs({_unfreeze_attrs(self._data)!r})" + + def to_dict(self): + return _unfreeze_attrs(self._data) def _freeze_attrs(attrs: Mapping[str, Any]) -> Mapping[str, Any]: From 1b94a8fefef82b23cb02e2eb45eb9661b61c7c6a Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Mon, 16 Dec 2024 14:45:33 -0500 Subject: [PATCH 3/5] ensure transients are strings --- src/awkward/_attrs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/awkward/_attrs.py b/src/awkward/_attrs.py index 03626c7f9c..468e6345f8 100644 --- a/src/awkward/_attrs.py +++ b/src/awkward/_attrs.py @@ -43,7 +43,7 @@ def attrs_of(*arrays, attrs: Mapping | None = None) -> Mapping: def without_transient_attrs(attrs: dict[str, Any]) -> JSONMapping: - return {k: v for k, v in attrs.items() if not k.startswith("@")} + return {k: v for k, v in attrs.items() if not (isinstance(k, str) and k.startswith("@"))} class Attrs(Mapping): From eb3e3e9ccc01c2cbe12114fb385259adec737a71 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Dec 2024 19:48:58 +0000 Subject: [PATCH 4/5] style: pre-commit fixes --- src/awkward/_attrs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/awkward/_attrs.py b/src/awkward/_attrs.py index 468e6345f8..cf7a84f9fc 100644 --- a/src/awkward/_attrs.py +++ b/src/awkward/_attrs.py @@ -43,7 +43,9 @@ def attrs_of(*arrays, attrs: Mapping | None = None) -> Mapping: def without_transient_attrs(attrs: dict[str, Any]) -> JSONMapping: - return {k: v for k, v in attrs.items() if not (isinstance(k, str) and k.startswith("@"))} + return { + k: v for k, v in attrs.items() if not (isinstance(k, str) and k.startswith("@")) + } class Attrs(Mapping): From 2db1252ccf4d5d6742a2f8009439f36bfc68e082 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 18 Dec 2024 08:50:51 -0500 Subject: [PATCH 5/5] fix doc string Co-authored-by: Angus Hollands --- src/awkward/highlevel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index 6273bb0525..f9473bdf71 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -381,7 +381,7 @@ def _update_class(self): @property def attrs(self) -> Attrs: """ - The mutable mapping containing top-level metadata, which is serialised + The mapping containing top-level metadata, which is serialised with the array during pickling. Keys prefixed with `@` are identified as "transient" attributes