Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: control attrs better as described in issue #3277 (#3344)

Merged
merged 6 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/awkward/_attrs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
from __future__ import annotations

import weakref
from collections.abc import Mapping
from types import MappingProxyType

from awkward._typing import Any, JSONMapping

Expand Down Expand Up @@ -42,3 +44,39 @@ def attrs_of(*arrays, attrs: Mapping | None = None) -> Mapping:

def without_transient_attrs(attrs: dict[str, Any]) -> JSONMapping:
    """Return a new dict with every ``"@"``-prefixed key of ``attrs`` removed.

    Args:
        attrs: Mapping of attribute names to values; it is not modified.

    Returns:
        A freshly built dict holding only the entries whose key does not
        start with ``"@"``, in the original iteration order.
    """
    persistent = {}
    for name, value in attrs.items():
        if name.startswith("@"):
            # Keys marked with a leading "@" are left out of the result.
            continue
        persistent[name] = value
    return persistent


class Attrs(Mapping):
    """Read-only mapping view of an object's attrs with write-back on assignment.

    Reading goes through a frozen (read-only) view of ``data``. Assigning an
    item never mutates ``data`` in place: a new dict containing the existing
    entries plus the assigned key is built and stored on the referenced
    object's ``_attrs`` attribute.

    Args:
        ref: The object that owns the attrs. Only a weak reference is kept.
        data: The attrs contents exposed by this view.
    """

    def __init__(self, ref, data: Mapping[str, Any]):
        # A weak reference, so holding on to an ``Attrs`` view does not keep
        # the owning object alive.
        self._ref = weakref.ref(ref)
        self._data = _freeze_attrs(data)

    def __getitem__(self, key: str):
        return self._data[key]

    def __setitem__(self, key: str, value: Any):
        """Store ``key: value`` on the owner by replacing its ``_attrs`` dict.

        Raises:
            ValueError: If the referenced owner has already been deleted.
        """
        ref = self._ref()
        if ref is None:
            msg = "The reference array has been deleted. If you still need to set attributes, convert this 'Attrs' instance to a dict with '.to_dict()'."
            raise ValueError(msg)
        updated = _unfreeze_attrs(self._data) | {key: value}
        ref._attrs = updated
        # Refresh the snapshot so this view reflects its own writes. Without
        # this, a second assignment through the same view would rebuild from
        # the stale data and silently drop the first assignment.
        self._data = _freeze_attrs(updated)

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __repr__(self):
        return f"Attrs({_unfreeze_attrs(self._data)!r})"

    def to_dict(self):
        """Return the current contents as a new, independent ``dict``."""
        return _unfreeze_attrs(self._data)


def _freeze_attrs(attrs: Mapping[str, Any]) -> Mapping[str, Any]:
    """Wrap ``attrs`` in a read-only ``MappingProxyType`` (a view, not a copy)."""
    return MappingProxyType(attrs)


def _unfreeze_attrs(attrs: Mapping[str, Any]) -> dict[str, Any]:
    """Return a new mutable ``dict`` with the (shallow-copied) items of ``attrs``."""
    return dict(attrs)
28 changes: 14 additions & 14 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import awkward as ak
import awkward._connect.hist
from awkward._attrs import attrs_of, without_transient_attrs
from awkward._attrs import Attrs, attrs_of, without_transient_attrs
from awkward._backends.dispatch import register_backend_lookup_factory
from awkward._backends.numpy import NumpyBackend
from awkward._behavior import behavior_of, get_array_class, get_record_class
Expand All @@ -42,7 +42,7 @@
unpickle_record_schema_1,
)
from awkward._regularize import is_non_string_like_iterable
from awkward._typing import Any, MutableMapping, TypeVar
from awkward._typing import Any, TypeVar
from awkward._util import STDOUT
from awkward.prettyprint import Formatter
from awkward.prettyprint import valuestr as prettyprint_valuestr
Expand Down Expand Up @@ -337,7 +337,7 @@ def __init__(
if behavior is not None and not isinstance(behavior, Mapping):
raise TypeError("behavior must be None or a mapping")

if attrs is not None and not isinstance(attrs, MutableMapping):
if attrs is not None and not isinstance(attrs, Mapping):
raise TypeError("attrs must be None or a mapping")

if named_axis:
Expand Down Expand Up @@ -379,7 +379,7 @@ def _update_class(self):
self.__class__ = get_array_class(self._layout, self._behavior)

@property
def attrs(self) -> Mapping:
def attrs(self) -> Attrs:
"""
The mutable mapping containing top-level metadata, which is serialised
pfackeldey marked this conversation as resolved.
Show resolved Hide resolved
with the array during pickling.
Expand All @@ -390,14 +390,14 @@ def attrs(self) -> Mapping:
"""
if self._attrs is None:
self._attrs = {}
return self._attrs
return Attrs(self, self._attrs)

@attrs.setter
def attrs(self, value: Mapping[str, Any]):
if isinstance(value, Mapping):
self._attrs = value
self._attrs = dict(value)
else:
raise TypeError("attrs must be a mapping")
raise TypeError("attrs must be a 'Attrs' mapping")

@property
def layout(self):
Expand Down Expand Up @@ -1846,7 +1846,7 @@ def __init__(
if behavior is not None and not isinstance(behavior, Mapping):
raise TypeError("behavior must be None or mapping")

if attrs is not None and not isinstance(attrs, MutableMapping):
if attrs is not None and not isinstance(attrs, Mapping):
raise TypeError("attrs must be None or a mapping")

if named_axis:
Expand Down Expand Up @@ -1883,7 +1883,7 @@ def _update_class(self):
self.__class__ = get_record_class(self._layout, self._behavior)

@property
def attrs(self) -> Mapping[str, Any]:
def attrs(self) -> Attrs:
"""
The mapping containing top-level metadata, which is serialised
with the record during pickling.
Expand All @@ -1894,12 +1894,12 @@ def attrs(self) -> Mapping[str, Any]:
"""
if self._attrs is None:
self._attrs = {}
return self._attrs
return Attrs(self, self._attrs)

@attrs.setter
def attrs(self, value: Mapping[str, Any]):
if isinstance(value, Mapping):
self._attrs = value
self._attrs = dict(value)
else:
raise TypeError("attrs must be a mapping")

Expand Down Expand Up @@ -2672,7 +2672,7 @@ def _wrap(cls, layout, behavior=None, attrs=None):
return out

@property
def attrs(self) -> Mapping[str, Any]:
def attrs(self) -> Attrs:
"""
The mapping containing top-level metadata, which is serialised
with the array during pickling.
Expand All @@ -2683,12 +2683,12 @@ def attrs(self) -> Mapping[str, Any]:
"""
if self._attrs is None:
self._attrs = {}
return self._attrs
return Attrs(self, self._attrs)

@attrs.setter
def attrs(self, value: Mapping[str, Any]):
if isinstance(value, Mapping):
self._attrs = value
self._attrs = dict(value)
else:
raise TypeError("attrs must be a mapping")

Expand Down
52 changes: 26 additions & 26 deletions tests/test_2757_attrs_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_set_attrs():
assert array.attrs == {}

array.attrs = OTHER_ATTRS
assert array.attrs is OTHER_ATTRS
assert array.attrs == OTHER_ATTRS

with pytest.raises(TypeError):
array.attrs = "Hello world!"
Expand All @@ -52,7 +52,7 @@ def test_transient_metadata_persists():
attrs = {**SOME_ATTRS, "@transient_key": lambda: None}
array = ak.Array([[1, 2, 3]], attrs=attrs)
num = ak.num(array)
assert num.attrs is attrs
assert num.attrs == attrs


@pytest.mark.parametrize(
Expand All @@ -79,13 +79,13 @@ def test_single_arg_ops(func):
# Carry from argument
assert (
func([[1, 2, 3, 4], [5], [10]], axis=-1, highlevel=True, attrs=SOME_ATTRS).attrs
is SOME_ATTRS
== SOME_ATTRS
)
# Carry from outer array
array = ak.Array([[1, 2, 3, 4], [5], [10]], attrs=SOME_ATTRS)
assert func(array, axis=-1, highlevel=True).attrs is SOME_ATTRS
assert func(array, axis=-1, highlevel=True).attrs == SOME_ATTRS
# Carry from argument exclusively
assert func(array, axis=-1, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS
assert func(array, axis=-1, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS


@pytest.mark.parametrize(
Expand Down Expand Up @@ -134,15 +134,15 @@ def test_string_operations_unary(func):
highlevel=True,
attrs=SOME_ATTRS,
).attrs
is SOME_ATTRS
== SOME_ATTRS
)
# Carry from outer array
array = ak.Array(
[["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS
)
assert func(array, highlevel=True).attrs is SOME_ATTRS
assert func(array, highlevel=True).attrs == SOME_ATTRS
# Carry from argument exclusively
assert func(array, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS
assert func(array, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS


@pytest.mark.parametrize(
Expand Down Expand Up @@ -188,15 +188,15 @@ def test_string_operations_unary_with_arg(func, arg):
highlevel=True,
attrs=SOME_ATTRS,
).attrs
is SOME_ATTRS
== SOME_ATTRS
)
# Carry from outer array
array = ak.Array(
[["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS
)
assert func(array, arg, highlevel=True).attrs is SOME_ATTRS
assert func(array, arg, highlevel=True).attrs == SOME_ATTRS
# Carry from argument exclusively
assert func(array, arg, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS
assert func(array, arg, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS


def test_string_operations_unary_with_arg_slice():
Expand All @@ -220,16 +220,16 @@ def test_string_operations_unary_with_arg_slice():
highlevel=True,
attrs=SOME_ATTRS,
).attrs
is SOME_ATTRS
== SOME_ATTRS
)
# Carry from outer array
array = ak.Array(
[["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS
)
assert ak.str.slice(array, 1, highlevel=True).attrs is SOME_ATTRS
assert ak.str.slice(array, 1, highlevel=True).attrs == SOME_ATTRS
# Carry from argument exclusively
assert (
ak.str.slice(array, 1, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS
ak.str.slice(array, 1, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS
)


Expand Down Expand Up @@ -262,13 +262,13 @@ def test_string_operations_binary(func):
highlevel=True,
attrs=SOME_ATTRS,
).attrs
is SOME_ATTRS
== SOME_ATTRS
)
# Carry from first array
array = ak.Array(
[["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS
)
assert func(array, ["hello"], highlevel=True).attrs is SOME_ATTRS
assert func(array, ["hello"], highlevel=True).attrs == SOME_ATTRS

# Carry from second array
value_array = ak.Array(["hello"], attrs=OTHER_ATTRS)
Expand All @@ -278,7 +278,7 @@ def test_string_operations_binary(func):
value_array,
highlevel=True,
).attrs
is OTHER_ATTRS
== OTHER_ATTRS
)
# Carry from both arrays
assert func(
Expand All @@ -289,7 +289,7 @@ def test_string_operations_binary(func):

# Carry from argument
assert (
func(array, value_array, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS
func(array, value_array, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS
)


Expand All @@ -298,38 +298,38 @@ def test_broadcasting_arrays():
right = ak.Array([1], attrs=OTHER_ATTRS)

left_result, right_result = ak.broadcast_arrays(left, right)
assert left_result.attrs is SOME_ATTRS
assert right_result.attrs is OTHER_ATTRS
assert left_result.attrs == SOME_ATTRS
assert right_result.attrs == OTHER_ATTRS


def test_broadcasting_fields():
left = ak.Array([{"x": 1}, {"x": 2}], attrs=SOME_ATTRS)
right = ak.Array([{"y": 1}, {"y": 2}], attrs=OTHER_ATTRS)

left_result, right_result = ak.broadcast_fields(left, right)
assert left_result.attrs is SOME_ATTRS
assert right_result.attrs is OTHER_ATTRS
assert left_result.attrs == SOME_ATTRS
assert right_result.attrs == OTHER_ATTRS


def test_numba_arraybuilder():
numba = pytest.importorskip("numba")
builder = ak.ArrayBuilder(attrs=SOME_ATTRS)
assert builder.attrs is SOME_ATTRS
assert builder.attrs == SOME_ATTRS

@numba.njit
def func(array):
return array

assert func(builder).attrs is SOME_ATTRS
assert func(builder).attrs == SOME_ATTRS


def test_numba_array():
numba = pytest.importorskip("numba")
array = ak.Array([1, 2, 3], attrs=SOME_ATTRS)
assert array.attrs is SOME_ATTRS
assert array.attrs == SOME_ATTRS

@numba.njit
def func(array):
return array

assert func(array).attrs is SOME_ATTRS
assert func(array).attrs == SOME_ATTRS
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_ArrayBuilder_behavior():
SOME_ATTRS = {"FOO": "BAR"}
builder = ak.ArrayBuilder(behavior=SOME_ATTRS)

assert builder.behavior is SOME_ATTRS
assert builder.behavior == SOME_ATTRS
assert func(builder).behavior == SOME_ATTRS


Expand Down
4 changes: 2 additions & 2 deletions tests/test_2806_attrs_typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_typetracer_with_report():
form = layout.form_with_key("node{id}")

meta, report = typetracer_with_report(form, highlevel=True, attrs=SOME_ATTRS)
assert meta.attrs is SOME_ATTRS
assert meta.attrs == SOME_ATTRS

meta, report = typetracer_with_report(form, highlevel=True, attrs=None)
assert meta._attrs is None
Expand All @@ -44,5 +44,5 @@ def test_function(function):
"z": [[0.1, 0.1, 0.2], [3, 1, 2], [2, 1, 2]],
}
)
assert function(array, attrs=SOME_ATTRS).attrs is SOME_ATTRS
assert function(array, attrs=SOME_ATTRS).attrs == SOME_ATTRS
assert function(array)._attrs is None
6 changes: 3 additions & 3 deletions tests/test_2837_ufunc_attrs_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ def test():
def test_unary():
x = ak.Array([1, 2, 3], behavior={"foo": "BAR"}, attrs={"hello": "world"})
y = -x
assert y.attrs is x.attrs
assert y.attrs == x.attrs
assert x.behavior is y.behavior


def test_two_return():
x = ak.Array([1, 2, 3], behavior={"foo": "BAR"}, attrs={"hello": "world"})
y, y_ret = divmod(x, 2)
assert y.attrs is y_ret.attrs
assert y.attrs is x.attrs
assert y.attrs == y_ret.attrs
assert y.attrs == x.attrs

assert y.behavior is y_ret.behavior
assert y.behavior is x.behavior
Loading
Loading