diff --git a/src/awkward/_attrs.py b/src/awkward/_attrs.py index 14a42549d2..cf7a84f9fc 100644 --- a/src/awkward/_attrs.py +++ b/src/awkward/_attrs.py @@ -1,7 +1,9 @@ # BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE from __future__ import annotations +import weakref from collections.abc import Mapping +from types import MappingProxyType from awkward._typing import Any, JSONMapping @@ -41,4 +43,42 @@ def attrs_of(*arrays, attrs: Mapping | None = None) -> Mapping: def without_transient_attrs(attrs: dict[str, Any]) -> JSONMapping: - return {k: v for k, v in attrs.items() if not k.startswith("@")} + return { + k: v for k, v in attrs.items() if not (isinstance(k, str) and k.startswith("@")) + } + + +class Attrs(Mapping): + def __init__(self, ref, data: Mapping[str, Any]): + self._ref = weakref.ref(ref) + self._data = _freeze_attrs(data) + + def __getitem__(self, key: str): + return self._data[key] + + def __setitem__(self, key: str, value: Any): + ref = self._ref() + if ref is None: + msg = "The reference array has been deleted. If you still need to set attributes, convert this 'Attrs' instance to a dict with '.to_dict()'." + raise ValueError(msg) + ref._attrs = _unfreeze_attrs(self._data) | {key: value} + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def __repr__(self): + return f"Attrs({_unfreeze_attrs(self._data)!r})" + + def to_dict(self): + return _unfreeze_attrs(self._data) + + +def _freeze_attrs(attrs: Mapping[str, Any]) -> Mapping[str, Any]: + return MappingProxyType(attrs) + + +def _unfreeze_attrs(attrs: Mapping[str, Any]) -> dict[str, Any]: + return dict(attrs) diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index d878d9929a..513d44206d 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -19,7 +19,7 @@ import awkward as ak import awkward._connect.hist -from awkward._attrs import attrs_of, without_transient_attrs +from awkward._attrs import Attrs, attrs_of, without_transient_attrs from awkward._backends.dispatch import register_backend_lookup_factory from awkward._backends.numpy import NumpyBackend from awkward._behavior import behavior_of, get_array_class, get_record_class @@ -43,7 +43,7 @@ unpickle_record_schema_1, ) from awkward._regularize import is_non_string_like_iterable -from awkward._typing import Any, MutableMapping, TypeVar +from awkward._typing import Any, TypeVar from awkward._util import STDOUT from awkward.prettyprint import Formatter from awkward.prettyprint import valuestr as prettyprint_valuestr @@ -338,7 +338,7 @@ def __init__( if behavior is not None and not isinstance(behavior, Mapping): raise TypeError("behavior must be None or a mapping") - if attrs is not None and not isinstance(attrs, MutableMapping): + if attrs is not None and not isinstance(attrs, Mapping): raise TypeError("attrs must be None or a mapping") if named_axis: @@ -380,9 +380,9 @@ def _update_class(self): self.__class__ = get_array_class(self._layout, self._behavior) @property - def attrs(self) -> Mapping: + def attrs(self) -> Attrs: """ - The mutable mapping containing top-level metadata, which is serialised + The mapping containing top-level metadata, which is serialised with the array during pickling. Keys prefixed with `@` are identified as "transient" attributes @@ -391,14 +391,14 @@ def attrs(self) -> Mapping: """ if self._attrs is None: self._attrs = {} - return self._attrs + return Attrs(self, self._attrs) @attrs.setter def attrs(self, value: Mapping[str, Any]): if isinstance(value, Mapping): - self._attrs = value + self._attrs = dict(value) else: - raise TypeError("attrs must be a mapping") + raise TypeError("attrs must be a 'Attrs' mapping") @property def layout(self): @@ -1853,7 +1853,7 @@ def __init__( if behavior is not None and not isinstance(behavior, Mapping): raise TypeError("behavior must be None or mapping") - if attrs is not None and not isinstance(attrs, MutableMapping): + if attrs is not None and not isinstance(attrs, Mapping): raise TypeError("attrs must be None or a mapping") if named_axis: @@ -1890,7 +1890,7 @@ def _update_class(self): self.__class__ = get_record_class(self._layout, self._behavior) @property - def attrs(self) -> Mapping[str, Any]: + def attrs(self) -> Attrs: """ The mapping containing top-level metadata, which is serialised with the record during pickling. @@ -1901,12 +1901,12 @@ def attrs(self) -> Mapping[str, Any]: """ if self._attrs is None: self._attrs = {} - return self._attrs + return Attrs(self, self._attrs) @attrs.setter def attrs(self, value: Mapping[str, Any]): if isinstance(value, Mapping): - self._attrs = value + self._attrs = dict(value) else: raise TypeError("attrs must be a mapping") @@ -2679,7 +2679,7 @@ def _wrap(cls, layout, behavior=None, attrs=None): return out @property - def attrs(self) -> Mapping[str, Any]: + def attrs(self) -> Attrs: """ The mapping containing top-level metadata, which is serialised with the array during pickling. @@ -2690,12 +2690,12 @@ def attrs(self) -> Mapping[str, Any]: """ if self._attrs is None: self._attrs = {} - return self._attrs + return Attrs(self, self._attrs) @attrs.setter def attrs(self, value: Mapping[str, Any]): if isinstance(value, Mapping): - self._attrs = value + self._attrs = dict(value) else: raise TypeError("attrs must be a mapping") diff --git a/tests/test_2757_attrs_metadata.py b/tests/test_2757_attrs_metadata.py index de04074860..e4356d659f 100644 --- a/tests/test_2757_attrs_metadata.py +++ b/tests/test_2757_attrs_metadata.py @@ -25,7 +25,7 @@ def test_set_attrs(): assert array.attrs == {} array.attrs = OTHER_ATTRS - assert array.attrs is OTHER_ATTRS + assert array.attrs == OTHER_ATTRS with pytest.raises(TypeError): array.attrs = "Hello world!" @@ -52,7 +52,7 @@ def test_transient_metadata_persists(): attrs = {**SOME_ATTRS, "@transient_key": lambda: None} array = ak.Array([[1, 2, 3]], attrs=attrs) num = ak.num(array) - assert num.attrs is attrs + assert num.attrs == attrs @pytest.mark.parametrize( @@ -79,13 +79,13 @@ def test_single_arg_ops(func): # Carry from argument assert ( func([[1, 2, 3, 4], [5], [10]], axis=-1, highlevel=True, attrs=SOME_ATTRS).attrs - is SOME_ATTRS + == SOME_ATTRS ) # Carry from outer array array = ak.Array([[1, 2, 3, 4], [5], [10]], attrs=SOME_ATTRS) - assert func(array, axis=-1, highlevel=True).attrs is SOME_ATTRS + assert func(array, axis=-1, highlevel=True).attrs == SOME_ATTRS # Carry from argument exclusively - assert func(array, axis=-1, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS + assert func(array, axis=-1, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS @pytest.mark.parametrize( @@ -134,15 +134,15 @@ def test_string_operations_unary(func): highlevel=True, attrs=SOME_ATTRS, ).attrs - is SOME_ATTRS + == SOME_ATTRS ) # Carry from outer array array = ak.Array( [["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS ) - assert func(array, highlevel=True).attrs is SOME_ATTRS + assert func(array, highlevel=True).attrs == SOME_ATTRS # Carry from argument exclusively - assert func(array, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS + assert func(array, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS @pytest.mark.parametrize( @@ -188,15 +188,15 @@ def test_string_operations_unary_with_arg(func, arg): highlevel=True, attrs=SOME_ATTRS, ).attrs - is SOME_ATTRS + == SOME_ATTRS ) # Carry from outer array array = ak.Array( [["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS ) - assert func(array, arg, highlevel=True).attrs is SOME_ATTRS + assert func(array, arg, highlevel=True).attrs == SOME_ATTRS # Carry from argument exclusively - assert func(array, arg, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS + assert func(array, arg, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS def test_string_operations_unary_with_arg_slice(): @@ -220,16 +220,16 @@ def test_string_operations_unary_with_arg_slice(): highlevel=True, attrs=SOME_ATTRS, ).attrs - is SOME_ATTRS + == SOME_ATTRS ) # Carry from outer array array = ak.Array( [["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS ) - assert ak.str.slice(array, 1, highlevel=True).attrs is SOME_ATTRS + assert ak.str.slice(array, 1, highlevel=True).attrs == SOME_ATTRS # Carry from argument exclusively assert ( - ak.str.slice(array, 1, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS + ak.str.slice(array, 1, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS ) @@ -262,13 +262,13 @@ def test_string_operations_binary(func): highlevel=True, attrs=SOME_ATTRS, ).attrs - is SOME_ATTRS + == SOME_ATTRS ) # Carry from first array array = ak.Array( [["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS ) - assert func(array, ["hello"], highlevel=True).attrs is SOME_ATTRS + assert func(array, ["hello"], highlevel=True).attrs == SOME_ATTRS # Carry from second array value_array = ak.Array(["hello"], attrs=OTHER_ATTRS) @@ -278,7 +278,7 @@ def test_string_operations_binary(func): value_array, highlevel=True, ).attrs - is OTHER_ATTRS + == OTHER_ATTRS ) # Carry from both arrays assert func( @@ -289,7 +289,7 @@ def test_string_operations_binary(func): # Carry from argument assert ( - func(array, value_array, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS + func(array, value_array, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS ) @@ -298,8 +298,8 @@ def test_broadcasting_arrays(): right = ak.Array([1], attrs=OTHER_ATTRS) left_result, right_result = ak.broadcast_arrays(left, right) - assert left_result.attrs is SOME_ATTRS - assert right_result.attrs is OTHER_ATTRS + assert left_result.attrs == SOME_ATTRS + assert right_result.attrs == OTHER_ATTRS def test_broadcasting_fields(): @@ -307,29 +307,29 @@ def test_broadcasting_fields(): right = ak.Array([{"y": 1}, {"y": 2}], attrs=OTHER_ATTRS) left_result, right_result = ak.broadcast_fields(left, right) - assert left_result.attrs is SOME_ATTRS - assert right_result.attrs is OTHER_ATTRS + assert left_result.attrs == SOME_ATTRS + assert right_result.attrs == OTHER_ATTRS def test_numba_arraybuilder(): numba = pytest.importorskip("numba") builder = ak.ArrayBuilder(attrs=SOME_ATTRS) - assert builder.attrs is SOME_ATTRS + assert builder.attrs == SOME_ATTRS @numba.njit def func(array): return array - assert func(builder).attrs is SOME_ATTRS + assert func(builder).attrs == SOME_ATTRS def test_numba_array(): numba = pytest.importorskip("numba") array = ak.Array([1, 2, 3], attrs=SOME_ATTRS) - assert array.attrs is SOME_ATTRS + assert array.attrs == SOME_ATTRS @numba.njit def func(array): return array - assert func(array).attrs is SOME_ATTRS + assert func(array).attrs == SOME_ATTRS diff --git a/tests/test_2770_serialize_and_deserialize_behaviour_for_numba.py b/tests/test_2770_serialize_and_deserialize_behaviour_for_numba.py index f44c513cc2..028b23cb90 100644 --- a/tests/test_2770_serialize_and_deserialize_behaviour_for_numba.py +++ b/tests/test_2770_serialize_and_deserialize_behaviour_for_numba.py @@ -19,7 +19,7 @@ def test_ArrayBuilder_behavior(): SOME_ATTRS = {"FOO": "BAR"} builder = ak.ArrayBuilder(behavior=SOME_ATTRS) - assert builder.behavior is SOME_ATTRS + assert builder.behavior == SOME_ATTRS assert func(builder).behavior == SOME_ATTRS diff --git a/tests/test_2806_attrs_typetracer.py b/tests/test_2806_attrs_typetracer.py index ea466eff04..980728c076 100644 --- a/tests/test_2806_attrs_typetracer.py +++ b/tests/test_2806_attrs_typetracer.py @@ -22,7 +22,7 @@ def test_typetracer_with_report(): form = layout.form_with_key("node{id}") meta, report = typetracer_with_report(form, highlevel=True, attrs=SOME_ATTRS) - assert meta.attrs is SOME_ATTRS + assert meta.attrs == SOME_ATTRS meta, report = typetracer_with_report(form, highlevel=True, attrs=None) assert meta._attrs is None @@ -44,5 +44,5 @@ def test_function(function): "z": [[0.1, 0.1, 0.2], [3, 1, 2], [2, 1, 2]], } ) - assert function(array, attrs=SOME_ATTRS).attrs is SOME_ATTRS + assert function(array, attrs=SOME_ATTRS).attrs == SOME_ATTRS assert function(array)._attrs is None diff --git a/tests/test_2837_ufunc_attrs_behavior.py b/tests/test_2837_ufunc_attrs_behavior.py index 86f2dcedec..0be740fbf1 100644 --- a/tests/test_2837_ufunc_attrs_behavior.py +++ b/tests/test_2837_ufunc_attrs_behavior.py @@ -14,15 +14,15 @@ def test(): def test_unary(): x = ak.Array([1, 2, 3], behavior={"foo": "BAR"}, attrs={"hello": "world"}) y = -x - assert y.attrs is x.attrs + assert y.attrs == x.attrs assert x.behavior is y.behavior def test_two_return(): x = ak.Array([1, 2, 3], behavior={"foo": "BAR"}, attrs={"hello": "world"}) y, y_ret = divmod(x, 2) - assert y.attrs is y_ret.attrs - assert y.attrs is x.attrs + assert y.attrs == y_ret.attrs + assert y.attrs == x.attrs assert y.behavior is y_ret.behavior assert y.behavior is x.behavior diff --git a/tests/test_2866_getitem_attrs.py b/tests/test_2866_getitem_attrs.py index 727edfe214..0524d6dc0e 100644 --- a/tests/test_2866_getitem_attrs.py +++ b/tests/test_2866_getitem_attrs.py @@ -11,24 +11,24 @@ def test_array_slice(): array = ak.Array([[0, 1, 2], [4]], attrs=ATTRS) - assert array.attrs is ATTRS + assert array.attrs == ATTRS - assert array[0].attrs is ATTRS - assert array[1:].attrs is ATTRS + assert array[0].attrs == ATTRS + assert array[1:].attrs == ATTRS def test_array_field(): array = ak.Array([[{"x": 1}, {"x": 2}], [{"x": 10}]], attrs=ATTRS) - assert array.attrs is ATTRS + assert array.attrs == ATTRS - assert array.x.attrs is ATTRS - assert array.x[1:].attrs is ATTRS + assert array.x.attrs == ATTRS + assert array.x[1:].attrs == ATTRS def test_record_field(): array = ak.Array([{"x": [1, 2, 3]}], attrs=ATTRS) - assert array.attrs is ATTRS + assert array.attrs == ATTRS record = array[0] - assert record.attrs is ATTRS - assert record.x.attrs is ATTRS + assert record.attrs == ATTRS + assert record.x.attrs == ATTRS diff --git a/tests/test_3277_attrs_behavior_on_array_copies.py b/tests/test_3277_attrs_behavior_on_array_copies.py new file mode 100644 index 0000000000..988f976adb --- /dev/null +++ b/tests/test_3277_attrs_behavior_on_array_copies.py @@ -0,0 +1,17 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE +# ruff: noqa: E402 + +from __future__ import annotations + +import awkward as ak + + +def test(): + arr = ak.Array([1]) + arr.attrs["foo"] = "bar" + + arr2 = ak.copy(arr) + assert arr2.attrs == arr.attrs + + arr2.attrs["foo"] = "baz" + assert arr2.attrs != arr.attrs