Skip to content

Commit

Permalink
feat: implement numeric operators for arc4.UIntN (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
achidlow committed Mar 8, 2024
1 parent 52c1666 commit 7c27c81
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 49 deletions.
49 changes: 48 additions & 1 deletion src/puya/awst_build/eb/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,18 @@
import structlog

from puya.awst import wtypes
from puya.awst.nodes import BoolConstant, Expression, IntrinsicCall, Literal
from puya.awst.nodes import (
BigUIntBinaryOperator,
BoolConstant,
Expression,
IntrinsicCall,
Literal,
UInt64BinaryOperator,
)
from puya.awst_build.eb.base import BuilderBinaryOp
from puya.awst_build.eb.var_factory import var_expression
from puya.awst_build.utils import expect_operand_wtype
from puya.errors import CodeError

if TYPE_CHECKING:
from puya.awst_build.eb.base import ExpressionBuilder
Expand Down Expand Up @@ -37,3 +46,41 @@ def uint64_to_biguint(
stack_args=[arg],
)
return itob_call


def translate_uint64_math_operator(
operator: BuilderBinaryOp, loc: SourceLocation
) -> UInt64BinaryOperator:
if operator is BuilderBinaryOp.div:
logger.error(
(
"To maintain semantic compatibility with Python, "
"only the truncating division operator (//) is supported "
),
location=loc,
)
# continue traversing code to generate any further errors
operator = BuilderBinaryOp.floor_div
try:
return UInt64BinaryOperator(operator.value)
except ValueError as ex:
raise CodeError(f"Unsupported UInt64 math operator {operator.value}", loc) from ex


def translate_biguint_math_operator(
operator: BuilderBinaryOp, loc: SourceLocation
) -> BigUIntBinaryOperator:
if operator is BuilderBinaryOp.div:
logger.error(
(
"To maintain semantic compatibility with Python, "
"only the truncating division operator (//) is supported "
),
location=loc,
)
# continue traversing code to generate any further errors
operator = BuilderBinaryOp.floor_div
try:
return BigUIntBinaryOperator(operator.value)
except ValueError as ex:
raise CodeError(f"Unsupported BigUInt math operator {operator.value}", loc) from ex
110 changes: 108 additions & 2 deletions src/puya/awst_build/eb/arc4/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,31 @@

from puya.awst import wtypes
from puya.awst.nodes import (
ARC4Decode,
ARC4Encode,
BigUIntBinaryOperation,
DecimalConstant,
Expression,
IntegerConstant,
Literal,
NumericComparison,
NumericComparisonExpression,
ReinterpretCast,
Statement,
UInt64BinaryOperation,
)
from puya.awst_build.eb._utils import (
translate_biguint_math_operator,
translate_uint64_math_operator,
uint64_to_biguint,
)
from puya.awst_build.eb._utils import uint64_to_biguint
from puya.awst_build.eb.arc4.base import (
ARC4ClassExpressionBuilder,
ARC4EncodedExpressionBuilder,
arc4_bool_bytes,
get_integer_literal_value,
)
from puya.awst_build.eb.base import BuilderComparisonOp, ExpressionBuilder
from puya.awst_build.eb.base import BuilderBinaryOp, BuilderComparisonOp, ExpressionBuilder
from puya.awst_build.eb.var_factory import var_expression
from puya.awst_build.utils import convert_literal_to_expr
from puya.errors import CodeError, InternalError, TodoError
Expand Down Expand Up @@ -185,6 +193,14 @@ def bool_eval(self, location: SourceLocation, *, negate: bool = False) -> Expres
negate=negate,
)

def unary_plus(self, location: SourceLocation) -> ExpressionBuilder:
# unary + is allowed, but for the current types it has no real impact
# so just expand the existing expression to include the unary operator
raise TodoError(location)

def bitwise_invert(self, location: SourceLocation) -> ExpressionBuilder:
raise TodoError(location)

def compare(
self, other: ExpressionBuilder | Literal, op: BuilderComparisonOp, location: SourceLocation
) -> ExpressionBuilder:
Expand Down Expand Up @@ -216,6 +232,96 @@ def compare(
)
return var_expression(cmp_expr)

def binary_op(
self,
other: ExpressionBuilder | Literal,
op: BuilderBinaryOp,
location: SourceLocation,
*,
reverse: bool,
) -> ExpressionBuilder:
other_expr = convert_literal_to_expr(other, self.wtype)
if self.wtype.n <= 64:
result_expr = self._uint64_binary_op(other_expr, op, location, reverse=reverse)
else:
result_expr = self._biguint_binary_op(other_expr, op, location, reverse=reverse)
encoded_result = ARC4Encode(value=result_expr, source_location=location, wtype=self.wtype)
return var_expression(encoded_result)

def _uint64_binary_op(
self, other: Expression, op: BuilderBinaryOp, location: SourceLocation, *, reverse: bool
) -> Expression:
if other.wtype == self.wtype:
other = ARC4Decode(
value=other,
wtype=wtypes.uint64_wtype,
source_location=location,
)
elif isinstance(other.wtype, wtypes.ARC4UIntN):
raise TodoError(location, "TODO: support mixed size operators with arc4 numerics")
elif other.wtype == wtypes.uint64_wtype:
pass
elif other.wtype == wtypes.bool_wtype:
raise TodoError(location, "TODO: support upcast from bool to arc4.UIntN")
else:
return NotImplemented
lhs: Expression = ARC4Decode(
value=self.expr,
wtype=wtypes.uint64_wtype,
source_location=self.source_location,
)
rhs = other
if reverse:
(lhs, rhs) = (rhs, lhs)
uint64_op = translate_uint64_math_operator(op, location)
bin_op_expr = UInt64BinaryOperation(
source_location=location, left=lhs, op=uint64_op, right=rhs
)
return bin_op_expr

def _biguint_binary_op(
self, other: Expression, op: BuilderBinaryOp, location: SourceLocation, *, reverse: bool
) -> Expression:
if other.wtype == self.wtype:
other = ReinterpretCast(
expr=other,
wtype=wtypes.biguint_wtype,
source_location=other.source_location,
)
elif isinstance(other.wtype, wtypes.ARC4UIntN):
raise TodoError(location, "TODO: support mixed size operators with arc4 numerics")
elif other.wtype == wtypes.uint64_wtype:
other = uint64_to_biguint(other, location)
elif other.wtype == wtypes.biguint_wtype:
pass
elif other.wtype == wtypes.bool_wtype:
raise TodoError(location, "TODO: support upcast from bool to arc4.UIntN")
else:
return NotImplemented
lhs: Expression = ReinterpretCast(
expr=self.expr,
wtype=wtypes.biguint_wtype,
source_location=self.source_location,
)
rhs = other
if reverse:
(lhs, rhs) = (rhs, lhs)
biguint_op = translate_biguint_math_operator(op, location)
bin_op_expr = BigUIntBinaryOperation(
source_location=location, left=lhs, op=biguint_op, right=rhs
)
return bin_op_expr

def augmented_assignment(
self, op: BuilderBinaryOp, rhs: ExpressionBuilder | Literal, location: SourceLocation
) -> Statement:
raise TodoError(location)
# rhs_expr = convert_literal_to_expr(rhs, self.wtype)
# if self.wtype.n <= 64:
# return self._uint64_augmented_assignment(rhs_expr, op, location)
# else:
# return self._biguint_augmented_assignment(rhs_expr, op, location)


class UFixedNxMExpressionBuilder(ARC4EncodedExpressionBuilder):
def __init__(self, expr: Expression):
Expand Down
28 changes: 4 additions & 24 deletions src/puya/awst_build/eb/biguint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@
from puya.awst.nodes import (
BigUIntAugmentedAssignment,
BigUIntBinaryOperation,
BigUIntBinaryOperator,
BigUIntConstant,
Literal,
NumericComparison,
NumericComparisonExpression,
ReinterpretCast,
Statement,
)
from puya.awst_build.eb._utils import uint64_to_biguint
from puya.awst_build.eb._utils import translate_biguint_math_operator, uint64_to_biguint
from puya.awst_build.eb.base import (
BuilderBinaryOp,
BuilderComparisonOp,
Expand Down Expand Up @@ -124,7 +123,7 @@ def binary_op(
if other_expr.wtype == self.wtype:
pass
elif other_expr.wtype == wtypes.uint64_wtype:
other_expr = uint64_to_biguint(other, location)
other_expr = uint64_to_biguint(other_expr, location)
elif other_expr.wtype == wtypes.bool_wtype:
raise TodoError(location, "TODO: support upcast from bool to biguint")
else:
Expand All @@ -133,7 +132,7 @@ def binary_op(
rhs = other_expr
if reverse:
(lhs, rhs) = (rhs, lhs)
biguint_op = _translate_biguint_math_operator(op, location)
biguint_op = translate_biguint_math_operator(op, location)
bin_op_expr = BigUIntBinaryOperation(
source_location=location, left=lhs, op=biguint_op, right=rhs
)
Expand All @@ -154,29 +153,10 @@ def augmented_assignment(
f"Invalid operand type {value.wtype} for {op.value}= with {self.wtype}", location
)
target = self.lvalue()
biguint_op = _translate_biguint_math_operator(op, location)
biguint_op = translate_biguint_math_operator(op, location)
return BigUIntAugmentedAssignment(
source_location=location,
target=target,
value=value,
op=biguint_op,
)


def _translate_biguint_math_operator(
operator: BuilderBinaryOp, loc: SourceLocation
) -> BigUIntBinaryOperator:
if operator is BuilderBinaryOp.div:
logger.error(
(
"To maintain semantic compatibility with Python, "
"only the truncating division operator (//) is supported "
),
location=loc,
)
# continue traversing code to generate any further errors
operator = BuilderBinaryOp.floor_div
try:
return BigUIntBinaryOperator(operator.value)
except ValueError as ex:
raise CodeError(f"Unsupported BigUInt math operator {operator.value}", loc) from ex
25 changes: 3 additions & 22 deletions src/puya/awst_build/eb/uint64.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
Statement,
UInt64AugmentedAssignment,
UInt64BinaryOperation,
UInt64BinaryOperator,
UInt64Constant,
UInt64UnaryOperation,
UInt64UnaryOperator,
)
from puya.awst_build.eb._utils import translate_uint64_math_operator
from puya.awst_build.eb.base import (
BuilderBinaryOp,
BuilderComparisonOp,
Expand Down Expand Up @@ -130,7 +130,7 @@ def binary_op(
rhs = other_expr
if reverse:
(lhs, rhs) = (rhs, lhs)
uint64_op = _translate_uint64_math_operator(op, location)
uint64_op = translate_uint64_math_operator(op, location)
bin_op_expr = UInt64BinaryOperation(
source_location=location, left=lhs, op=uint64_op, right=rhs
)
Expand All @@ -149,26 +149,7 @@ def augmented_assignment(
f"Invalid operand type {value.wtype} for {op.value}= with {self.wtype}", location
)
target = self.lvalue()
uint64_op = _translate_uint64_math_operator(op, location)
uint64_op = translate_uint64_math_operator(op, location)
return UInt64AugmentedAssignment(
source_location=location, target=target, value=value, op=uint64_op
)


def _translate_uint64_math_operator(
operator: BuilderBinaryOp, loc: SourceLocation
) -> UInt64BinaryOperator:
if operator is BuilderBinaryOp.div:
logger.error(
(
"To maintain semantic compatibility with Python, "
"only the truncating division operator (//) is supported "
),
location=loc,
)
# continue traversing code to generate any further errors
operator = BuilderBinaryOp.floor_div
try:
return UInt64BinaryOperator(operator.value)
except ValueError as ex:
raise CodeError(f"Unsupported UInt64 math operator {operator.value}", loc) from ex
Loading

0 comments on commit 7c27c81

Please sign in to comment.