Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Named tuple support #331

Merged
merged 8 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
match/Counter 67 59 - | 31 28 -
merkle/MerkleTree 227 212 - | 114 105 -
module_consts 52 50 - | 12 11 -
named_tuples/NamedTuples 423 339 337 | 227 184 183
nested_loops/Nested 217 200 - | 133 119 -
regression_tests/Issue118 172 111 - | 93 55 -
regression_tests/Issue194 35 22 - | 21 10 -
Expand Down Expand Up @@ -118,8 +119,8 @@
tuple_support/NestedTuples 972 620 - | 563 358 -
tuple_support/TupleComparisons 136 68 - | 85 35 -
tuple_support/TupleSupport 696 409 - | 381 180 -
typed_abi_call/Greeter 4998 3806 - | 2599 1801 -
typed_abi_call/Logger 1344 1047 1046 | 754 575 574
typed_abi_call/Greeter 5239 3956 - | 2737 1866 -
typed_abi_call/Logger 1456 1139 1138 | 823 631 630
typed_abi_call_txn/Caller 468 406 - | 240 205 -
typed_abi_call_txn/Txn 312 248 - | 170 134 -
unary/Unary 130 67 - | 62 27 -
Expand Down
4 changes: 1 addition & 3 deletions scripts/generate_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,9 +844,7 @@ def pytype_repr(typ: pytypes.PyType) -> str:
except KeyError:
pass
match typ:
case pytypes.TupleType(generic=pytypes.GenericTupleType, items=tuple_items) if len(
tuple_items
) > 1:
case pytypes.TupleType(items=tuple_items) if len(tuple_items) > 1:
item_strs = [pytype_repr(item) for item in tuple_items]
return (
f"pytypes.GenericTupleType.parameterise("
Expand Down
11 changes: 5 additions & 6 deletions src/puya/awst/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,20 +720,19 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
@attrs.frozen
class FieldExpression(Expression):
base: Expression = attrs.field(
validator=expression_has_wtype(wtypes.WStructType, wtypes.ARC4Struct)
validator=expression_has_wtype(wtypes.WStructType, wtypes.ARC4Struct, wtypes.WTuple)
)
name: str
wtype: wtypes.WType = attrs.field(init=False)

@wtype.default
def _wtype_factory(self) -> wtypes.WType:
struct_wtype = self.base.wtype
if not isinstance(struct_wtype, wtypes.WStructType | wtypes.ARC4Struct):
raise InternalError("invalid struct wtype")
dataclass_type = self.base.wtype
assert isinstance(dataclass_type, wtypes.WStructType | wtypes.ARC4Struct | wtypes.WTuple)
try:
return struct_wtype.fields[self.name]
return dataclass_type.fields[self.name]
except KeyError:
raise CodeError(f"invalid field for {struct_wtype}", self.source_location) from None
raise CodeError(f"invalid field for {dataclass_type}", self.source_location) from None

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_field_expression(self)
Expand Down
82 changes: 61 additions & 21 deletions src/puya/awst/wtypes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Iterable, Mapping
from functools import cached_property

import attrs
Expand All @@ -10,6 +10,7 @@
from puya.errors import CodeError, InternalError
from puya.models import TransactionType
from puya.parse import SourceLocation
from puya.utils import unique

logger = log.get_logger(__name__)

Expand Down Expand Up @@ -160,12 +161,12 @@ def from_type(cls, transaction_type: TransactionType | None) -> "WInnerTransacti
@typing.final
@attrs.frozen
class WStructType(WType):
fields: Mapping[str, WType] = attrs.field(converter=immutabledict)
fields: immutabledict[str, WType] = attrs.field(converter=immutabledict)
scalar_type: None = attrs.field(default=None, init=False)
source_location: SourceLocation | None = attrs.field(eq=False)

@fields.validator
def _fields_validator(self, _: object, fields: Mapping[str, WType]) -> None:
def _fields_validator(self, _: object, fields: immutabledict[str, WType]) -> None:
if not fields:
raise CodeError("struct needs fields", self.source_location)
if void_wtype in fields.values():
Expand Down Expand Up @@ -194,14 +195,15 @@ def _name(self) -> str:
@typing.final
@attrs.frozen
class WTuple(WType):
types: Sequence[WType] = attrs.field(converter=tuple[WType, ...])
types: tuple[WType, ...] = attrs.field(converter=tuple[WType, ...])
source_location: SourceLocation | None = attrs.field(default=None, eq=False)
scalar_type: None = attrs.field(default=None, init=False)
immutable: bool = attrs.field(default=True, init=False)
name: str = attrs.field(init=False)
source_location: SourceLocation | None = attrs.field(default=None, eq=False)
name: str = attrs.field(eq=False, kw_only=True)
names: tuple[str, ...] | None = attrs.field(default=None)

@types.validator
def _types_validator(self, _attribute: object, types: Sequence[WType]) -> None:
def _types_validator(self, _attribute: object, types: tuple[WType, ...]) -> None:
if not types:
raise CodeError("empty tuples are not supported", self.source_location)
if void_wtype in types:
Expand All @@ -211,6 +213,32 @@ def _types_validator(self, _attribute: object, types: Sequence[WType]) -> None:
def _name(self) -> str:
return f"tuple<{','.join([t.name for t in self.types])}>"

@names.validator
def _names_validator(self, _attribute: object, names: tuple[str, ...] | None) -> None:
if names is None:
return
if len(names) != len(self.types):
raise InternalError("mismatch between tuple item names length and types")
if len(names) != len(unique(names)):
raise CodeError("tuple item names are not unique", self.source_location)

@cached_property
def fields(self) -> Mapping[str, WType]:
"""Mapping of item names to types if `names` is defined, otherwise empty."""
if self.names is None:
return {}
return dict(zip(self.names, self.types, strict=True))

def name_to_index(self, name: str, source_location: SourceLocation) -> int:
if self.names is None:
raise CodeError(
"cannot access tuple item by name of an unnamed tuple", source_location
)
try:
return self.names.index(name)
except ValueError:
raise CodeError(f"{name} is not a member of {self.name}") from None


@attrs.frozen(kw_only=True)
class ARC4Type(WType):
Expand Down Expand Up @@ -307,7 +335,7 @@ def _m_validator(self, _attribute: object, m: int) -> None:
raise CodeError("Precision must be between 1 and 160 inclusive", self.source_location)


def _required_arc4_wtypes(wtypes: Iterable[WType]) -> Sequence[ARC4Type]:
def _required_arc4_wtypes(wtypes: Iterable[WType]) -> tuple[ARC4Type, ...]:
result = []
for wtype in wtypes:
if not isinstance(wtype, ARC4Type):
Expand All @@ -320,7 +348,7 @@ def _required_arc4_wtypes(wtypes: Iterable[WType]) -> Sequence[ARC4Type]:
@attrs.frozen(kw_only=True)
class ARC4Tuple(ARC4Type):
source_location: SourceLocation | None = attrs.field(default=None, eq=False)
types: Sequence[ARC4Type] = attrs.field(converter=_required_arc4_wtypes)
types: tuple[ARC4Type, ...] = attrs.field(converter=_required_arc4_wtypes)
name: str = attrs.field(init=False)
arc4_name: str = attrs.field(init=False, eq=False)
immutable: bool = attrs.field(init=False)
Expand All @@ -343,14 +371,20 @@ def _decode_type(self) -> WTuple:
return WTuple(self.types, self.source_location)

def can_encode_type(self, wtype: WType) -> bool:
if wtype == self.decode_type:
return True
elif not isinstance(wtype, WTuple) or len(wtype.types) != len(self.types):
return False
return all(
return super().can_encode_type(wtype) or _is_arc4_encodeable_tuple(wtype, self.types)


def _is_arc4_encodeable_tuple(
wtype: WType, target_types: tuple[ARC4Type, ...]
) -> typing.TypeGuard[WTuple]:
return (
isinstance(wtype, WTuple)
and len(wtype.types) == len(target_types)
and all(
arc4_wtype == encode_wtype or arc4_wtype.can_encode_type(encode_wtype)
for arc4_wtype, encode_wtype in zip(self.types, wtype.types, strict=True)
for arc4_wtype, encode_wtype in zip(target_types, wtype.types, strict=True)
)
)


def _expect_arc4_type(wtype: WType) -> ARC4Type:
Expand Down Expand Up @@ -409,7 +443,7 @@ def _require_arc4_fields(fields: Mapping[str, WType]) -> immutabledict[str, ARC4
]
if non_arc4_fields:
raise CodeError(
"Invalid ARC4 Struct declaration,"
"invalid ARC4 Struct declaration,"
f" the following fields are not ARC4 encoded types: {', '.join(non_arc4_fields)}",
)
return immutabledict(fields)
Expand All @@ -418,7 +452,7 @@ def _require_arc4_fields(fields: Mapping[str, WType]) -> immutabledict[str, ARC4
@typing.final
@attrs.frozen(kw_only=True)
class ARC4Struct(ARC4Type):
fields: Mapping[str, ARC4Type] = attrs.field(converter=_require_arc4_fields)
fields: immutabledict[str, ARC4Type] = attrs.field(converter=_require_arc4_fields)
immutable: bool = attrs.field()
source_location: SourceLocation | None = attrs.field(default=None, eq=False)
arc4_name: str = attrs.field(init=False, eq=False)
Expand All @@ -433,12 +467,18 @@ def _arc4_name(self) -> str:
return f"({','.join(item.arc4_name for item in self.types)})"

@cached_property
def names(self) -> Sequence[str]:
return list(self.fields.keys())
def names(self) -> tuple[str, ...]:
return tuple(self.fields.keys())

@cached_property
def types(self) -> Sequence[ARC4Type]:
return list(self.fields.values())
def types(self) -> tuple[ARC4Type, ...]:
return tuple(self.fields.values())

def can_encode_type(self, wtype: WType) -> bool:
return super().can_encode_type(wtype) or (
_is_arc4_encodeable_tuple(wtype, self.types)
and (wtype.names is None or wtype.names == self.names)
)


arc4_byte_alias: typing.Final = ARC4UIntN(
Expand Down
16 changes: 16 additions & 0 deletions src/puya/ir/_puya_lib.awst.json
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,8 @@
}
],
"source_location": null,
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"body": {
Expand Down Expand Up @@ -2270,6 +2272,8 @@
"column": 11,
"end_column": 25
},
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"_type": "TupleExpression"
Expand Down Expand Up @@ -2356,6 +2360,8 @@
}
],
"source_location": null,
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"body": {
Expand Down Expand Up @@ -3091,6 +3097,8 @@
"column": 11,
"end_column": 25
},
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"_type": "TupleExpression"
Expand Down Expand Up @@ -3159,6 +3167,8 @@
}
],
"source_location": null,
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"body": {
Expand Down Expand Up @@ -4244,6 +4254,8 @@
"column": 11,
"end_column": 26
},
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"_type": "TupleExpression"
Expand Down Expand Up @@ -4312,6 +4324,8 @@
}
],
"source_location": null,
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"body": {
Expand Down Expand Up @@ -5652,6 +5666,8 @@
"column": 11,
"end_column": 26
},
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"_type": "TupleExpression"
Expand Down
2 changes: 1 addition & 1 deletion src/puya/ir/builder/_tuple_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ def build_tuple_item_names(
reg
for idx, item_type in enumerate(wtype.types)
for reg in build_tuple_item_names(
format_tuple_index(base_name, idx), item_type, source_location
format_tuple_index(wtype, base_name, idx), item_type, source_location
)
]
11 changes: 6 additions & 5 deletions src/puya/ir/builder/arc4.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,11 +441,12 @@ def handle_arc4_assign(


def _get_tuple_var_name(expr: awst_nodes.TupleItemExpression) -> str:
if isinstance(expr.base, awst_nodes.TupleItemExpression):
return format_tuple_index(_get_tuple_var_name(expr.base), expr.index)
if isinstance(expr.base, awst_nodes.VarExpression):
return format_tuple_index(expr.base.name, expr.index)
raise CodeError("Invalid assignment target", expr.base.source_location)
if isinstance(expr.base.wtype, wtypes.WTuple):
if isinstance(expr.base, awst_nodes.TupleItemExpression):
return format_tuple_index(expr.base.wtype, _get_tuple_var_name(expr.base), expr.index)
if isinstance(expr.base, awst_nodes.VarExpression):
return format_tuple_index(expr.base.wtype, expr.base.name, expr.index)
raise CodeError("invalid assignment target", expr.base.source_location)


def concat_values(
Expand Down
25 changes: 19 additions & 6 deletions src/puya/ir/builder/itxn.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,10 @@ def _resolve_inner_txn_params_var_name(self, params: awst_nodes.Expression) -> s
case awst_nodes.VarExpression(name=var_name):
pass
case awst_nodes.TupleItemExpression(
base=awst_nodes.VarExpression(name=name), index=index
base=awst_nodes.VarExpression(name=name, wtype=wtypes.WTuple() as base_wtype),
index=index,
):
return format_tuple_index(name, index)
return format_tuple_index(base_wtype, name, index)
case awst_nodes.Copy(value=value):
return self._resolve_inner_txn_params_var_name(value)
case _:
Expand All @@ -570,9 +571,9 @@ def _get_assignment_target_local_names(
match target:
case awst_nodes.VarExpression(name=var_name) if expected_number == 1:
return [(var_name, target.source_location)]
case awst_nodes.VarExpression(name=var_name):
case awst_nodes.VarExpression(name=var_name, wtype=wtypes.WTuple() as var_wtype):
return [
(format_tuple_index(var_name, idx), target.source_location)
(format_tuple_index(var_wtype, var_name, idx), target.source_location)
for idx in range(expected_number)
]
case awst_nodes.TupleExpression(items=items) if expected_number == len(items) and all(
Expand All @@ -585,6 +586,14 @@ def _get_assignment_target_local_names(
):
tuple_names = _get_assignment_target_local_names(base, len(tuple_wtype.types))
return [tuple_names[index]]
case awst_nodes.FieldExpression(
base=awst_nodes.TupleExpression(wtype=tuple_wtype) as base,
name=name,
source_location=name_loc,
):
tuple_names = _get_assignment_target_local_names(base, len(tuple_wtype.types))
index = tuple_wtype.name_to_index(name, name_loc)
return [tuple_names[index]]
raise CodeError(
"Inner Transactions can only be assigned to local variables",
target.source_location,
Expand Down Expand Up @@ -677,15 +686,19 @@ def _get_uint64_const(expr: awst_nodes.Expression) -> int | None:

def _is_last_itxn(expr: awst_nodes.Expression) -> bool:
# is last itxn if expr is a submit expr of size 1 OR
if not isinstance(expr, awst_nodes.TupleItemExpression):
if not isinstance(expr, awst_nodes.TupleItemExpression | awst_nodes.FieldExpression):
return _is_submit_expr_of_size(expr, 1)

# if expr is a tuple item expression with an index into the last item of a submit expr
base = expr.base
if not isinstance(base.wtype, wtypes.WTuple):
return False

index = expr.index
index = (
expr.index
if isinstance(expr, awst_nodes.TupleItemExpression)
else base.wtype.name_to_index(expr.name, expr.source_location)
)
tuple_size = len(base.wtype.types)
if index == -1 or (index + 1) == tuple_size:
return _is_submit_expr_of_size(base, tuple_size)
Expand Down
Loading
Loading