Skip to content

Commit

Permalink
fix: Address unexpected Python behaviour in slicing, where end > star…
Browse files Browse the repository at this point in the history
…t would panic instead of returning an empty byte slice
  • Loading branch information
achidlow committed Mar 6, 2024
1 parent 46bf800 commit 52c1666
Show file tree
Hide file tree
Showing 102 changed files with 4,127 additions and 1,556 deletions.
2 changes: 1 addition & 1 deletion examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ ssa2 95 87 87
state_proxies/StateProxy 87 83 83
string_ops 157 152 152
stubs/BigUInt 172 112 112
stubs/Bytes 928 164 164
stubs/Bytes 1769 258 258
stubs/Uint64 371 8 8
too_many_permutations 108 107 107
transaction/Transaction 914 864 864
Expand Down
2 changes: 1 addition & 1 deletion examples/voting/out/VotingRoundApp.destructured.ir

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/voting/out/VotingRoundApp.ssa.ir

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/voting/out/VotingRoundApp.ssa.opt_pass_1.ir

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/voting/out/VotingRoundApp.ssa.opt_pass_2.ir

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/voting/out_O2/VotingRoundApp.destructured.ir

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion src/puya/awst/function_traverser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import puya.awst.visitors
from puya.awst import nodes as awst_nodes
from puya.awst.visitors import T


class FunctionTraverser(
Expand Down
32 changes: 29 additions & 3 deletions src/puya/awst_build/eb/bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,36 @@ def slice_index(
# since we evaluate self both as base and to get its length,
# we need to create a temporary assignment in case it has side effects
base = SingleEvaluation(self.expr)
slice_expr = SliceExpression(
begin_index_expr = _eval_slice_component(base, begin_index, location)
end_index_expr = _eval_slice_component(base, end_index, location)
if begin_index_expr is not None and end_index_expr is not None:
# special handling for if begin > end, will devolve into begin == end,
# which already returns the correct result of an empty bytes
# TODO: maybe we could improve the generated code if the above conversions weren't
# isolated - ie, if we move this sort of checks to before the length
# truncating checks
end_index_expr = IntrinsicCall(
op_code="select",
stack_args=[
# false: end = end
end_index_expr,
# true: end = begin
begin_index_expr,
# condition: begin > end
NumericComparisonExpression(
lhs=begin_index_expr,
operator=NumericComparison.gt,
rhs=end_index_expr,
source_location=location,
),
],
wtype=wtypes.uint64_wtype,
source_location=end_index_expr.source_location,
)
slice_expr: Expression = SliceExpression(
base=base,
begin_index=_eval_slice_component(base, begin_index, location),
end_index=_eval_slice_component(base, end_index, location),
begin_index=begin_index_expr,
end_index=end_index_expr,
wtype=self.wtype,
source_location=location,
)
Expand Down
25 changes: 16 additions & 9 deletions src/puya/ir/arc4_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,24 @@ def call(
)


def load_abi_arg(
def app_arg(
index: int,
wtype: wtypes.WType,
location: SourceLocation,
) -> awst_nodes.IntrinsicCall:
return awst_nodes.IntrinsicCall(
) -> awst_nodes.Expression:
value = awst_nodes.IntrinsicCall(
source_location=location,
wtype=wtype,
wtype=wtypes.bytes_wtype,
op_code="txna",
immediates=["ApplicationArgs", index],
)
if wtype == wtypes.bytes_wtype:
return value
return awst_nodes.ReinterpretCast(
source_location=location,
expr=value,
wtype=wtype,
)


def btoi(bytes_arg: awst_nodes.Expression, location: SourceLocation) -> awst_nodes.Expression:
Expand Down Expand Up @@ -572,11 +579,11 @@ def map_param_wtype_to_arc4_tuple_type(wtype: wtypes.WType) -> wtypes.WType:
args_overflow_wtype = wtypes.ARC4Tuple.from_types(
[map_param_wtype_to_arc4_tuple_type(a.wtype) for a in non_transaction_args[14:]]
)
last_arg = load_abi_arg(15, args_overflow_wtype, location)
last_arg = app_arg(15, args_overflow_wtype, location)

def get_arg(index: int, arg_wtype: wtypes.WType) -> awst_nodes.Expression:
if index < 15:
return load_abi_arg(index, arg_wtype, location)
return app_arg(index, arg_wtype, location)
else:
if last_arg is None:
raise InternalError("last_arg should not be None if there are more than 15 args")
Expand All @@ -585,20 +592,20 @@ def get_arg(index: int, arg_wtype: wtypes.WType) -> awst_nodes.Expression:
for arg in args:
match arg.wtype:
case wtypes.asset_wtype:
bytes_arg = get_arg(abi_arg_index, arg.wtype)
bytes_arg = get_arg(abi_arg_index, wtypes.bytes_wtype)
asset_index = btoi(bytes_arg, location)
asset_id = asset_id_at(asset_index, location)
yield asset_id
abi_arg_index += 1

case wtypes.account_wtype:
bytes_arg = get_arg(abi_arg_index, arg.wtype)
bytes_arg = get_arg(abi_arg_index, wtypes.bytes_wtype)
account_index = btoi(bytes_arg, location)
account = account_at(account_index, location)
yield account
abi_arg_index += 1
case wtypes.application_wtype:
bytes_arg = get_arg(abi_arg_index, arg.wtype)
bytes_arg = get_arg(abi_arg_index, wtypes.bytes_wtype)
application_index = btoi(bytes_arg, location)
application = application_at(application_index, location)
yield application
Expand Down
2 changes: 2 additions & 0 deletions src/puya/ir/builder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
AVMBytesEncoding,
bytes_enc_to_avm_bytes_enc,
wtype_to_avm_type,
wtype_to_avm_types,
)
from puya.ir.utils import format_tuple_index
from puya.parse import SourceLocation
Expand Down Expand Up @@ -306,6 +307,7 @@ def visit_intrinsic_call(self, call: awst_nodes.IntrinsicCall) -> TExpression:
source_location=call.source_location,
args=args,
immediates=list(call.immediates),
types=wtype_to_avm_types(call.wtype),
)

def visit_create_inner_transaction(self, call: awst_nodes.CreateInnerTransaction) -> None:
Expand Down
25 changes: 22 additions & 3 deletions src/puya/ir/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

def _check_stack_types(
error_format: str,
target_types: list[AVMType],
source_types: list[AVMType],
target_types: Sequence[AVMType],
source_types: Sequence[AVMType],
source_location: SourceLocation | None,
) -> None:
if len(target_types) != len(source_types) or not all(
Expand Down Expand Up @@ -276,6 +276,25 @@ class Intrinsic(Op, ValueProvider):
immediates: list[str | int] = attrs.field(factory=list)
args: list[Value] = attrs.field(factory=list)
comment: str | None = None # used e.g. for asserts
_types: Sequence[AVMType] = attrs.field(converter=tuple[AVMType, ...])

@_types.default
def _default_types(self) -> tuple[AVMType, ...]:
return tuple(map(stack_type_to_avm_type, self.op_signature.returns))

@_types.validator
def _validate_types(self, _attribute: object, types: Sequence[AVMType]) -> None:
expected_types = self._default_types()
received_types = tuple(types)
desc = f"({self.op} {' '.join(map(str, self.immediates))}): "
_check_stack_types(
"Incompatible return types on Intrinsic"
+ desc
+ " received = {source_types}, expected = {target_types}",
expected_types,
received_types,
self.source_location,
)

def _frozen_data(self) -> object:
return self.op, tuple(self.immediates), tuple(self.args), self.comment
Expand All @@ -285,7 +304,7 @@ def accept(self, visitor: IRVisitor[T]) -> T:

@property
def types(self) -> Sequence[AVMType]:
return tuple(map(stack_type_to_avm_type, self.op_signature.returns))
return self._types

@property
def op_signature(self) -> OpSignature:
Expand Down
16 changes: 16 additions & 0 deletions src/puya/ir/types_.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ def wtype_to_avm_type(
)


def wtype_to_avm_types(
expr_or_wtype: wtypes.WType | awst_nodes.Expression,
source_location: SourceLocation | None = None,
) -> list[AVMType]:
if isinstance(expr_or_wtype, awst_nodes.Expression):
wtype = expr_or_wtype.wtype
else:
wtype = expr_or_wtype
if wtype == wtypes.void_wtype:
return []
elif isinstance(wtype, wtypes.WTuple):
return [wtype_to_avm_type(t, source_location) for t in wtype.types]
else:
return [wtype_to_avm_type(wtype, source_location)]


def stack_type_to_avm_type(stack_type: StackType) -> AVMType:
match stack_type:
case StackType.uint64 | StackType.bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ contract test_cases.arc4_types.mutation.Arc4MutationContract:
let result#0: bytes = ((replace2 0) source#0 tmp%1#0)
let popped_location#0: uint64 = (+ length_minus_1#0 16u)
let tmp%2#0: uint64 = (getbit result#0 popped_location#0)
let popped#0: any = (setbit "\x00" 0u tmp%2#0)
let popped#0: bytes = (setbit "\x00" 0u tmp%2#0)
let result#0: bytes = (setbit result#0 popped_location#0 0u)
let tmp%3#0: uint64 = (+ length_minus_1#0 7u)
let tmp%4#0: uint64 = (/ tmp%3#0 8u)
Expand Down
6 changes: 3 additions & 3 deletions test_cases/arc4_types/out/Arc4MutationContract.ssa.ir
Original file line number Diff line number Diff line change
Expand Up @@ -796,8 +796,8 @@ contract test_cases.arc4_types.mutation.Arc4MutationContract:
let result#0: bytes = ((replace2 0) source#0 tmp%1#0)
let popped_location#0: uint64 = (+ length_minus_1#0 16u)
let tmp%2#0: uint64 = (getbit result#0 popped_location#0)
let popped#0: any = (setbit "\x00" 0u tmp%2#0)
let result#1: any = (setbit result#0 popped_location#0 0u)
let popped#0: bytes = (setbit "\x00" 0u tmp%2#0)
let result#1: bytes = (setbit result#0 popped_location#0 0u)
let tmp%3#0: uint64 = (+ length_minus_1#0 7u)
let tmp%4#0: uint64 = (/ tmp%3#0 8u)
let tmp%5#0: uint64 = (+ 2u tmp%4#0)
Expand Down Expand Up @@ -878,7 +878,7 @@ contract test_cases.arc4_types.mutation.Arc4MutationContract:
block@7: // for_body_L110
let i#0: uint64 = range_item%7#1
let tmp%9#0: uint64 = (getbit new_items_bytes#0 i#0)
let result#3: any = (setbit result#4 write_offset#3 tmp%9#0)
let result#3: bytes = (setbit result#4 write_offset#3 tmp%9#0)
let write_offset#2: uint64 = (+ write_offset#3 1u)
goto block@8
block@8: // for_footer_L110
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,8 @@ contract test_cases.arc4_types.mutation.Arc4MutationContract:
let result#0: bytes = ((replace2 0) source#0 tmp%1#0)
let popped_location#0: uint64 = (+ length_minus_1#0 16u)
let tmp%2#0: uint64 = (getbit result#0 popped_location#0)
let popped#0: any = (setbit "\x00" 0u tmp%2#0)
let result#1: any = (setbit result#0 popped_location#0 0u)
let popped#0: bytes = (setbit "\x00" 0u tmp%2#0)
let result#1: bytes = (setbit result#0 popped_location#0 0u)
let tmp%3#0: uint64 = (+ length_minus_1#0 7u)
let tmp%4#0: uint64 = (/ tmp%3#0 8u)
let tmp%5#0: uint64 = (+ 2u tmp%4#0)
Expand Down Expand Up @@ -736,7 +736,7 @@ contract test_cases.arc4_types.mutation.Arc4MutationContract:
goto continue_looping%8#0 ? block@7 : block@10
block@7: // for_body_L110
let tmp%9#0: uint64 = (getbit new_items_bytes#0 i#0)
let result#3: any = (setbit result#4 write_offset#3 tmp%9#0)
let result#3: bytes = (setbit result#4 write_offset#3 tmp%9#0)
let write_offset#2: uint64 = (+ write_offset#3 1u)
let range_item%7#3: uint64 = (+ i#0 ternary_result%6#2)
goto block@6
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ contract test_cases.arc4_types.mutation.Arc4MutationContract:
let result#0: bytes = ((replace2 0) source#0 tmp%1#0)
let popped_location#0: uint64 = (+ length_minus_1#0 16u)
let tmp%2#0: uint64 = (getbit result#0 popped_location#0)
let popped#0: any = (setbit "\x00" 0u tmp%2#0)
let result#1: any = (setbit result#0 popped_location#0 0u)
let popped#0: bytes = (setbit "\x00" 0u tmp%2#0)
let result#1: bytes = (setbit result#0 popped_location#0 0u)
let tmp%3#0: uint64 = (+ length_minus_1#0 7u)
let tmp%4#0: uint64 = (/ tmp%3#0 8u)
let tmp%5#0: uint64 = (+ 2u tmp%4#0)
Expand Down Expand Up @@ -308,7 +308,7 @@ contract test_cases.arc4_types.mutation.Arc4MutationContract:
goto continue_looping%8#0 ? block@7 : block@10
block@7: // for_body_L110
let tmp%9#0: uint64 = (getbit new_items_bytes#0 i#0)
let result#3: any = (setbit result#4 write_offset#3 tmp%9#0)
let result#3: bytes = (setbit result#4 write_offset#3 tmp%9#0)
let write_offset#2: uint64 = (+ write_offset#3 1u)
let range_item%7#3: uint64 = (+ i#0 ternary_result%6#2)
goto block@6
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ contract test_cases.arc4_types.mutation.Arc4MutationContract:
let result#0: bytes = ((replace2 0) source#0 tmp%1#0)
let popped_location#0: uint64 = (+ length_minus_1#0 16u)
let tmp%2#0: uint64 = (getbit result#0 popped_location#0)
let popped#0: any = (setbit "\x00" 0u tmp%2#0)
let result#1: any = (setbit result#0 popped_location#0 0u)
let popped#0: bytes = (setbit "\x00" 0u tmp%2#0)
let result#1: bytes = (setbit result#0 popped_location#0 0u)
let tmp%3#0: uint64 = (+ length_minus_1#0 7u)
let tmp%4#0: uint64 = (/ tmp%3#0 8u)
let tmp%5#0: uint64 = (+ 2u tmp%4#0)
Expand Down Expand Up @@ -299,7 +299,7 @@ contract test_cases.arc4_types.mutation.Arc4MutationContract:
goto continue_looping%8#0 ? block@7 : block@10
block@7: // for_body_L110
let tmp%9#0: uint64 = (getbit new_items_bytes#0 i#0)
let result#3: any = (setbit result#4 write_offset#3 tmp%9#0)
let result#3: bytes = (setbit result#4 write_offset#3 tmp%9#0)
let write_offset#2: uint64 = (+ write_offset#3 1u)
let range_item%7#3: uint64 = (+ i#0 ternary_result%6#2)
goto block@6
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ contract test_cases.arc4_types.mutation.Arc4MutationContract:
let result#0: bytes = ((replace2 0) source#0 tmp%1#0)
let popped_location#0: uint64 = (+ length_minus_1#0 16u)
let tmp%2#0: uint64 = (getbit result#0 popped_location#0)
let popped#0: any = (setbit "\x00" 0u tmp%2#0)
let result#1: any = (setbit result#0 popped_location#0 0u)
let popped#0: bytes = (setbit "\x00" 0u tmp%2#0)
let result#1: bytes = (setbit result#0 popped_location#0 0u)
let tmp%3#0: uint64 = (+ length_minus_1#0 7u)
let tmp%4#0: uint64 = (/ tmp%3#0 8u)
let tmp%5#0: uint64 = (+ 2u tmp%4#0)
Expand Down Expand Up @@ -293,7 +293,7 @@ contract test_cases.arc4_types.mutation.Arc4MutationContract:
goto continue_looping%8#0 ? block@7 : block@10
block@7: // for_body_L110
let tmp%9#0: uint64 = (getbit new_items_bytes#0 i#0)
let result#3: any = (setbit result#4 write_offset#3 tmp%9#0)
let result#3: bytes = (setbit result#4 write_offset#3 tmp%9#0)
let write_offset#2: uint64 = (+ write_offset#3 1u)
let range_item%7#3: uint64 = (+ i#0 ternary_result%6#2)
goto block@6
Expand Down
Loading

0 comments on commit 52c1666

Please sign in to comment.