Skip to content

Commit

Permalink
allow to mark a member as read_only, constant, or coerced using the m…
Browse files Browse the repository at this point in the history
…ember class
  • Loading branch information
MatthieuDartiailh committed Apr 16, 2024
1 parent e26e815 commit 228bee9
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 31 deletions.
35 changes: 30 additions & 5 deletions atom/meta/annotation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,19 @@
from ..dict import DefaultDict, Dict as ADict
from ..instance import Instance
from ..list import List as AList
from ..scalars import Bool, Bytes, Callable as ACallable, Float, Int, Str, Value
from ..scalars import (
Bool,
Bytes,
Callable as ACallable,
Float,
Int,
ReadOnly,
Str,
Value,
Constant,
Range,
FloatRange,
)
from ..set import Set as ASet
from ..subclass import Subclass
from ..tuple import FixedTuple, Tuple as ATuple
Expand All @@ -38,6 +50,9 @@
collections.abc.Callable: ACallable,
}

# XXX handle member as annotation
# XXX handle str as annotation with default


def generate_member_from_type_or_generic(
type_generic: Any, default: Any, annotate_type_containers: int
Expand All @@ -56,10 +71,20 @@ def generate_member_from_type_or_generic(
"Member subclasses cannot be used as annotations without "
"specifying a default value for the attribute."
)
elif isinstance(default, member) and default.coercer is not None:
m_cls = Coerced
parameters = (types,)
m_kwargs["coercer"] = default.coercer
elif isinstance(default, member) and any(
(default._constant, default._coercer, default._read_only)
):
if default._coercer is not None:
m_cls = Coerced
parameters = (types,)
m_kwargs["coercer"] = default._coercer
if default._read_only:
m_cls = ReadOnly
parameters = (types,)
if default._constant:
m_cls = Constant
m_kwargs["kind"] = types

elif object in types or Any in types:
m_cls = Value
parameters = ()
Expand Down
39 changes: 32 additions & 7 deletions atom/meta/member_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class member(object):
"default_factory",
"default_args",
"default_kwargs",
"coercer",
"_coercer",
"_read_only",
"_constant",
)

#: Name of the member for which a new default value should be set. Used by
Expand All @@ -41,9 +43,6 @@ class member(object):
#: Keyword arguments to create a default value.
default_kwargs: Optional[dict]

#: Coercing function to use.
coercer: Optional[Callable[[Any], Any]]

#: Metadata to set on the member
metadata: dict[str, Any]

Expand All @@ -53,7 +52,6 @@ def __init__(
default_factory: Optional[Callable[[], Any]] = None,
default_args: Optional[tuple] = None,
default_kwargs: Optional[dict] = None,
coercer: Optional[Callable[[Any], Any]] = None,
) -> None:
if default_value is not _SENTINEL:
if (
Expand All @@ -72,7 +70,9 @@ def __init__(
self.default_factory = default_factory
self.default_args = default_args
self.default_kwargs = default_kwargs
self.coercer = coercer
self._coercer = None
self._read_only = False
self._constant = False
self.metadata = {}

def clone(self) -> Self:
Expand All @@ -82,16 +82,41 @@ def clone(self) -> Self:
self.default_factory,
self.default_args,
self.default_kwargs,
self.coercer,
)
new._coercer = self._coercer
new._read_only = self._read_only
new._constant = self._constant
new.metadata = self.metadata.copy()
return new

def coerce(self, coercer: Callable[[Any], Any]) -> Self:
self._coercer = coercer
return self

def read_only(self) -> Self:
self._read_only = True
return self

def constant(self) -> Self:
self._constant = True
return self

def tag(self, **meta: Any) -> Self:
"""Add new metadata to the member."""
self.metadata |= meta
return self

# --- Private API

#: Coercing function to use.
_coercer: Optional[Callable[[Any], Any]]

#: Should the member be read only.
_read_only: bool

#: Should the member be constant
_constant: bool


def set_default(value: Any) -> member:
return member(default_value=value)
2 changes: 2 additions & 0 deletions docs/source/substitutions.sub
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@

.. |observe| replace:: :py:class:`~atom.atom.observe`

.. |member| replace:: :py:class:`~atom.atom.member`

.. |set_default| replace:: :py:class:`~atom.atom.set_default`

.. |atomref| replace:: :py:class:`~atom.catom.atomref`
Expand Down
59 changes: 41 additions & 18 deletions tests/test_atom_from_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
Bool,
Bytes,
Callable,
Coerced,
Constant,
DefaultDict,
Dict,
FixedTuple,
Expand All @@ -41,14 +43,22 @@
Int,
List,
Member,
ReadOnly,
Set,
Str,
Subclass,
Tuple,
Typed,
Value,
member,
set_default,
)
from atom.atom import set_default


class Dummy:
def __init__(self, a=1, b=2) -> None:
self.a = a
self.b = b


def test_ignore_annotations():
Expand Down Expand Up @@ -129,16 +139,6 @@ class A(Atom, use_annotations=True):
a: TList[int] = List(int, default=[1, 2, 3])


def test_reject_non_member_annotated_set_default():
class A(Atom, use_annotations=True):
a = Value()

with pytest.raises(TypeError):

class B(A, use_annotations=True):
a: int = set_default(1)


@pytest.mark.parametrize(
"annotation, member",
[
Expand Down Expand Up @@ -275,7 +275,7 @@ class A(Atom, use_annotations=True, type_containers=depth):


@pytest.mark.parametrize(
"annotation, member, default",
"annotation, member_cls, default",
[
(bool, Bool, True),
(int, Int, 1),
Expand All @@ -293,6 +293,10 @@ class A(Atom, use_annotations=True, type_containers=depth):
(TDefaultDict, DefaultDict, defaultdict(int, {1: 2})),
(Optional[Iterable], Instance, None),
(Type[int], Subclass, int),
(Dummy, Typed, member(default_args=(5,), default_kwargs={"b": 1})),
(Dummy, Coerced, member().coerce(lambda v: Dummy(v, 5))),
(Dummy, ReadOnly, member(1).read_only()),
(Dummy, Constant, member(1).constant()),
]
+ (
[
Expand All @@ -306,15 +310,27 @@ class A(Atom, use_annotations=True, type_containers=depth):
else []
),
)
def test_annotations_with_default(annotation, member, default):
def test_annotations_with_default(annotation, member_cls, default):
class A(Atom, use_annotations=True):
a: annotation = default

assert isinstance(A.a, member)
if member is Subclass:
assert A.a.default_value_mode == member(int, default=default).default_value_mode
elif member is not Instance:
assert A.a.default_value_mode == member(default=default).default_value_mode
assert isinstance(A.a, member_cls)
if member_cls is Subclass:
assert (
A.a.default_value_mode
== member_cls(int, default=default).default_value_mode
)
elif member_cls not in (Instance, Typed, Coerced):
d = default.default_value if isinstance(default, member) else default
assert A.a.default_value_mode == member_cls(default=d).default_value_mode

if annotation is Dummy and default.default_args:
assert A().a.a == 5
assert A().a.b == 1
elif annotation is Dummy and default._coercer:
t = A()
t.a = 8
assert t.a.a == 8


def test_annotations_no_default_for_instance():
Expand All @@ -327,3 +343,10 @@ class A(Atom, use_annotations=True):

class B(Atom, use_annotations=True):
a: Optional[Iterable] = []


def test_setting_metadata():
class A(Atom, use_annotations=True):
a: Iterable = member().tag(a=1)

assert A.a.metadata == {"a": 1}
2 changes: 1 addition & 1 deletion tests/type_checking/test_annotations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
m: {{ annotation }} = {{ member_instance }}
reveal_type(A.m) # N: Revealed type is "{{ member_type }}"
reveal_type(A().m) # N: Revealed type is "{{ member_value_type }}"
reveal_type(A(m=[]).m) # N: Revealed type is "{{ member_value_type }}"

0 comments on commit 228bee9

Please sign in to comment.