diff --git a/boa3/internal/analyser/moduleanalyser.py b/boa3/internal/analyser/moduleanalyser.py index 9eac98e62..242ec1d7a 100644 --- a/boa3/internal/analyser/moduleanalyser.py +++ b/boa3/internal/analyser/moduleanalyser.py @@ -209,7 +209,7 @@ def __include_variable(self, if isinstance(source_node, ast.Global): var = outer_symbol else: - if isinstance(var_type, SequenceType): + if isinstance(var_type, SequenceType) and not Type.tuple.is_type_of(var_type): var_type = var_type.build_collection(var_enumerate_type) var = Variable(var_type, origin_node=source_node) @@ -1208,8 +1208,7 @@ def visit_Subscript(self, subscript: ast.Subscript) -> str | IType: if (isinstance(symbol, (Collection, MetaType)) and isinstance(subscript.value, (ast.Name, ast.NameConstant, ast.Attribute))): # for evaluating names like list[str], dict[int, bool], etc - value = subscript.slice.value if isinstance(subscript.slice, ast.Index) else subscript.slice - values_type: Iterable[IType] = self.get_values_type(value) + values_type: Iterable[IType] = self.get_values_type(subscript.slice) if isinstance(symbol, Collection): return symbol.build_collection(*values_type) else: @@ -1222,11 +1221,10 @@ def visit_Subscript(self, subscript: ast.Subscript) -> str | IType: if isinstance(symbol, UnionType) or isinstance(symbol_type, UnionType): if not isinstance(symbol_type, UnionType): symbol_type = symbol - index = subscript.slice.value if isinstance(subscript.slice, ast.Index) else subscript.slice - if isinstance(index, ast.Tuple): - union_types = [self.get_type(value) for value in index.elts] + if isinstance(subscript.slice, ast.Tuple): + union_types = [self.get_type(value) for value in subscript.slice.elts] else: - union_types = self.get_type(index) + union_types = self.get_type(subscript.slice) return symbol_type.build(union_types) if isinstance(symbol_type, Collection): diff --git a/boa3/internal/analyser/typeanalyser.py b/boa3/internal/analyser/typeanalyser.py index 1a5074918..471cb66c1 100644 --- a/boa3/internal/analyser/typeanalyser.py +++ b/boa3/internal/analyser/typeanalyser.py @@ -583,7 +583,7 @@ def validate_get_or_set(self, subscript: ast.Subscript, index_node: ast.AST) -> type_id=symbol_type.identifier, operation_id=Operator.Subscript) ) - return symbol_type.item_type + return symbol_type.get_item_type(index) def validate_slice(self, subscript: ast.Subscript, slice_node: ast.Slice) -> IType: """ @@ -1932,7 +1932,7 @@ def visit_Starred(self, node: ast.Starred) -> ast.AST: actual_type_id=value_type.identifier) ) - return Type.tuple.build_collection(value_type.value_type) + return Type.tuple.build_any_length(value_type.value_type) def visit_Index(self, index: ast.Index) -> Any: """ diff --git a/boa3/internal/compiler/codegenerator/codegeneratorvisitor.py b/boa3/internal/compiler/codegenerator/codegeneratorvisitor.py index 6c9c68455..1f3b4287d 100644 --- a/boa3/internal/compiler/codegenerator/codegeneratorvisitor.py +++ b/boa3/internal/compiler/codegenerator/codegeneratorvisitor.py @@ -518,7 +518,7 @@ def visit_Subscript_Index(self, subscript: ast.Subscript) -> GeneratorData: if isinstance(subscript.ctx, ast.Load): # get item value_data = self.visit_to_generate(subscript.value) - slice = subscript.slice.value if isinstance(subscript.slice, ast.Index) else subscript.slice + slice = subscript.slice self.visit_to_generate(slice) index_is_constant_number = isinstance(slice, ast.Num) and isinstance(slice.n, int) @@ -530,7 +530,7 @@ def visit_Subscript_Index(self, subscript: ast.Subscript) -> GeneratorData: # set item var_data = self.visit(subscript.value) - index = subscript.slice.value if isinstance(subscript.slice, ast.Index) else subscript.slice + index = subscript.slice symbol_id = var_data.symbol_id value_type = var_data.type diff --git a/boa3/internal/model/callable.py b/boa3/internal/model/callable.py index 41e4bf3e0..69001f68a 100644 --- a/boa3/internal/model/callable.py +++ b/boa3/internal/model/callable.py @@ -53,7 +53,7 @@ def __init__(self, args: dict[str, Variable] = None, default_value = set_internal_call(ast.parse(default_code).body[0].value) - self.args[vararg_id] = Variable(Type.tuple.build_collection([vararg_var.type])) + self.args[vararg_id] = Variable(Type.tuple.build_any_length(vararg_var.type)) self.defaults.append(default_value) self._vararg = vararg diff --git a/boa3/internal/model/type/annotation/ellipsistype.py b/boa3/internal/model/type/annotation/ellipsistype.py new file mode 100644 index 000000000..c4f007b38 --- /dev/null +++ b/boa3/internal/model/type/annotation/ellipsistype.py @@ -0,0 +1,31 @@ +from typing import Any + +from boa3.internal.model.type.itype import IType + + +class EllipsisType(IType): + """ + A class used to represent Python Ellipsis (...) annotation + """ + + def __init__(self): + identifier = 'Ellipsis' + super().__init__(identifier) + + @classmethod + def build(cls, value: Any) -> IType: + return ellipsisType + + @classmethod + def _is_type_of(cls, value: Any): + return value is Ellipsis or value is ellipsisType + + def union_type(self, other_type: IType) -> IType: + return other_type + + def intersect_type(self, other_type: IType) -> IType: + from boa3.internal.model.type.type import Type + return Type.none + + +ellipsisType: IType = EllipsisType() diff --git a/boa3/internal/model/type/collection/icollection.py b/boa3/internal/model/type/collection/icollection.py index 5985db4c4..2112d8187 100644 --- a/boa3/internal/model/type/collection/icollection.py +++ b/boa3/internal/model/type/collection/icollection.py @@ -78,6 +78,9 @@ def get_types(cls, value: Any) -> set[IType]: types: set[IType] = {val if isinstance(val, IType) else Type.get_type(val) for val in value} return cls.filter_types(types) + def get_item_type(self, index: tuple): + return self.item_type + @classmethod def filter_types(cls, values_type) -> set[IType]: if values_type is None: @@ -93,6 +96,11 @@ def filter_types(cls, values_type) -> set[IType]: if any(t is Type.any or t is Type.none for t in values_type): return {Type.any} + if Type.ellipsis in values_type: + values_type.remove(Type.ellipsis) + if len(values_type) == 1: + return values_type + actual_types = list(values_type)[:1] for value in list(values_type)[1:]: other = next((x for x in actual_types diff --git a/boa3/internal/model/type/collection/sequence/tupletype.py b/boa3/internal/model/type/collection/sequence/tupletype.py index 712002c6c..1113cf25e 100644 --- a/boa3/internal/model/type/collection/sequence/tupletype.py +++ b/boa3/internal/model/type/collection/sequence/tupletype.py @@ -1,4 +1,6 @@ -from typing import Any +from __future__ import annotations + +from typing import Any, Iterable from boa3.internal.model.type.collection.sequence.sequencetype import SequenceType from boa3.internal.model.type.itype import IType @@ -9,11 +11,33 @@ class TupleType(SequenceType): A class used to represent Python tuple type """ - def __init__(self, values_type: set[IType] = None): + def __init__(self, values_type: list[IType] = None, any_length: bool = False): identifier = 'tuple' + if values_type is None: + values_type = [] + any_length = True + + self._tuple_types = values_type + self._is_any_length = any_length + values_type = self.filter_types(values_type) super().__init__(identifier, values_type) + @property + def identifier(self) -> str: + from boa3.internal.model.type.type import Type + if self.item_type == Type.any and self._is_any_length: + return self._identifier + + if len(self._tuple_types) == 0: + tuple_types = [self.item_type.identifier] + else: + tuple_types = [type_.identifier for type_ in self._tuple_types] + + if self._is_any_length: + tuple_types.append('...') + return f'{self._identifier}[{", ".join(tuple_types)}]' + @property def default_value(self) -> Any: return tuple() @@ -29,8 +53,85 @@ def valid_key(self) -> IType: @classmethod def build(cls, value: Any) -> IType: if cls._is_type_of(value): - values_types: set[IType] = cls.get_types(value) - return cls(values_types) + values_types: list[IType] = cls.get_types(value) + from boa3.internal.model.type.type import Type + if len(values_types) == 2 and values_types[-1] is Type.ellipsis: + has_ellipsis = True + values_types.pop() + else: + has_ellipsis = False + if Type.ellipsis in values_types: + # only tuple[, ...] is accepted as tuple of any size typed as + # all other cases where ... is used it has the same behavior as any + for index, value in enumerate(values_types): + if value is Type.ellipsis: + values_types[index] = Type.any + + return cls(values_types, any_length=has_ellipsis) + + def build_any_length(self, value: Any) -> IType: + result: TupleType = self.build((value,)) + if len(result._tuple_types) == 1: + result._is_any_length = True + return result + + @classmethod + def build_collection(cls, *value_type: IType | Iterable) -> IType: + params = [] + for arg in value_type: + if isinstance(arg, Iterable): + argument = list(arg) + else: + argument = [arg] + params.extend(argument) + return cls.build(tuple(params)) + + @classmethod + def get_types(cls, value: Any) -> list[IType]: + from boa3.internal.model.type.type import Type + return [val if isinstance(val, IType) else Type.get_type(val) for val in value] + + def get_item_type(self, index: tuple): + if len(index) > 0 and isinstance(index[0], int): + target_index = index[0] + if len(self._tuple_types) > target_index: + return self._tuple_types[target_index] + + return super().get_item_type(index) + + def is_type_of(self, value: Any) -> bool: + if self._is_type_of(value): + min_size = len(self._tuple_types) + if isinstance(value, TupleType): + types_to_check = value._tuple_types + any_length = value._is_any_length + else: + types_to_check = value + any_length = False + + len_types_to_check = len(types_to_check) + if self._is_any_length and len_types_to_check == 0 and not any_length: + # tuples of any length are always type of empty tuple + return True + if len_types_to_check < min_size: + return False + if not self._is_any_length: + if len_types_to_check > min_size: + return False + elif len_types_to_check == min_size and any_length: + return False + + for index in range(min_size): + if not self._tuple_types[index].is_type_of(types_to_check[index]): + return False + if len_types_to_check > min_size: + last_tuple_type = self._tuple_types[-1] if len(self._tuple_types) else self.value_type + for index in range(min_size, len_types_to_check): + if not last_tuple_type.is_type_of(types_to_check[index]): + return False + + return True + return False @classmethod def _is_type_of(cls, value: Any): diff --git a/boa3/internal/model/type/type.py b/boa3/internal/model/type/type.py index eeac27918..3b234eab0 100644 --- a/boa3/internal/model/type/type.py +++ b/boa3/internal/model/type/type.py @@ -1,5 +1,6 @@ from typing import Any +from boa3.internal.model.type.annotation.ellipsistype import ellipsisType from boa3.internal.model.type.annotation.optionaltype import OptionalType from boa3.internal.model.type.annotation.uniontype import UnionType from boa3.internal.model.type.anytype import anyType @@ -122,4 +123,5 @@ def get_generic_type(cls, *types: IType) -> IType: # Annotation types union = UnionType() optional = OptionalType() + ellipsis = ellipsisType any = anyType diff --git a/boa3_test/test_sc/any_test/AnyTuple.py b/boa3_test/test_sc/any_test/AnyTuple.py index 369ddcadf..caecba08a 100644 --- a/boa3_test/test_sc/any_test/AnyTuple.py +++ b/boa3_test/test_sc/any_test/AnyTuple.py @@ -5,4 +5,4 @@ @public def Main(): - a: Tuple[Any] = (True, 1, 'ok') + a: Tuple[Any, Any, Any] = (True, 1, 'ok') diff --git a/boa3_test/test_sc/class_test/NotificationSetVariables.py b/boa3_test/test_sc/class_test/NotificationSetVariables.py index 8132f90f9..771342292 100644 --- a/boa3_test/test_sc/class_test/NotificationSetVariables.py +++ b/boa3_test/test_sc/class_test/NotificationSetVariables.py @@ -20,7 +20,7 @@ def event_name(event: str) -> str: @public -def state(obj: Tuple[Any]) -> Any: +def state(obj: Tuple[Any, ...]) -> Any: x = Notification() x.state = obj return x.state diff --git a/boa3_test/test_sc/dict_test/MismatchedTypeKeysDict.py b/boa3_test/test_sc/dict_test/MismatchedTypeKeysDict.py index b239a0051..8831529d6 100644 --- a/boa3_test/test_sc/dict_test/MismatchedTypeKeysDict.py +++ b/boa3_test/test_sc/dict_test/MismatchedTypeKeysDict.py @@ -6,5 +6,5 @@ @public def Main() -> Sequence[str]: a: Dict[str, int] = {'one': 1, 'two': 2, 'three': 3} - b: Tuple[str] = a.keys() + b: Tuple[str, ...] = a.keys() return b diff --git a/boa3_test/test_sc/dict_test/MismatchedTypeValuesDict.py b/boa3_test/test_sc/dict_test/MismatchedTypeValuesDict.py index 9d4819748..b2f9ab5bb 100644 --- a/boa3_test/test_sc/dict_test/MismatchedTypeValuesDict.py +++ b/boa3_test/test_sc/dict_test/MismatchedTypeValuesDict.py @@ -6,5 +6,5 @@ @public def Main() -> Sequence[int]: a: Dict[str, int] = {'one': 1, 'two': 2, 'three': 3} - b: Tuple[int] = a.values() + b: Tuple[int, ...] = a.values() return b diff --git a/boa3_test/test_sc/for_test/ForElse.py b/boa3_test/test_sc/for_test/ForElse.py index 031d2fa99..96b5b20e3 100644 --- a/boa3_test/test_sc/for_test/ForElse.py +++ b/boa3_test/test_sc/for_test/ForElse.py @@ -6,7 +6,7 @@ @public def Main() -> int: a: int = 0 - sequence: Tuple[int] = (3, 5, 15) + sequence: Tuple[int, int, int] = (3, 5, 15) for x in sequence: a = a + x diff --git a/boa3_test/test_sc/for_test/NestedFor.py b/boa3_test/test_sc/for_test/NestedFor.py index 9056ebc66..c71bf6e8d 100644 --- a/boa3_test/test_sc/for_test/NestedFor.py +++ b/boa3_test/test_sc/for_test/NestedFor.py @@ -6,7 +6,7 @@ @public def Main() -> int: a: int = 0 - sequence: Tuple[int] = (3, 5, 15) + sequence: Tuple[int, int, int] = (3, 5, 15) for x in sequence: for y in sequence: diff --git a/boa3_test/test_sc/logical_test/LogicMismatchedOperandLogicOr.py b/boa3_test/test_sc/logical_test/LogicMismatchedOperandLogicOr.py index ed1823af4..20b96f418 100644 --- a/boa3_test/test_sc/logical_test/LogicMismatchedOperandLogicOr.py +++ b/boa3_test/test_sc/logical_test/LogicMismatchedOperandLogicOr.py @@ -1,5 +1,5 @@ from typing import Tuple -def Main(a: bool, b: Tuple[str]) -> bool: +def Main(a: bool, b: Tuple[str, ...]) -> bool: return a | b diff --git a/boa3_test/test_sc/native_test/neo/GetCandidates.py b/boa3_test/test_sc/native_test/neo/GetCandidates.py index 27defe83d..482105488 100644 --- a/boa3_test/test_sc/native_test/neo/GetCandidates.py +++ b/boa3_test/test_sc/native_test/neo/GetCandidates.py @@ -1,9 +1,10 @@ -from typing import Any, List, Tuple +from typing import List, Tuple from boa3.builtin.compile_time import public from boa3.builtin.nativecontract.neo import NEO +from boa3.builtin.type import ECPoint @public -def main() -> List[Tuple[Any, Any]]: +def main() -> List[Tuple[ECPoint, int]]: return NEO.get_candidates() diff --git a/boa3_test/test_sc/tuple_test/EmptyTupleAssignment.py b/boa3_test/test_sc/tuple_test/EmptyTupleAssignment.py index abc90992c..87e356c3e 100644 --- a/boa3_test/test_sc/tuple_test/EmptyTupleAssignment.py +++ b/boa3_test/test_sc/tuple_test/EmptyTupleAssignment.py @@ -1,8 +1,6 @@ -from typing import Tuple - from boa3.builtin.compile_time import public @public def Main(): - a: Tuple[int] = () + a: tuple = () diff --git a/boa3_test/test_sc/tuple_test/IndexTuple.py b/boa3_test/test_sc/tuple_test/IndexTuple.py index 92963b423..6fee5d79a 100644 --- a/boa3_test/test_sc/tuple_test/IndexTuple.py +++ b/boa3_test/test_sc/tuple_test/IndexTuple.py @@ -1,8 +1,8 @@ -from typing import Any, Tuple +from typing import Any from boa3.builtin.compile_time import public @public -def main(a: Tuple[Any], value: Any, start: int, end: int) -> int: +def main(a: tuple, value: Any, start: int, end: int) -> int: return a.index(value, start, end) diff --git a/boa3_test/test_sc/tuple_test/IndexTupleDefaults.py b/boa3_test/test_sc/tuple_test/IndexTupleDefaults.py index 684676b44..eb164fc89 100644 --- a/boa3_test/test_sc/tuple_test/IndexTupleDefaults.py +++ b/boa3_test/test_sc/tuple_test/IndexTupleDefaults.py @@ -1,8 +1,8 @@ -from typing import Any, Tuple +from typing import Any from boa3.builtin.compile_time import public @public -def main(a: Tuple[Any], value: Any) -> int: +def main(a: tuple, value: Any) -> int: return a.index(value) diff --git a/boa3_test/test_sc/tuple_test/IndexTupleEndDefault.py b/boa3_test/test_sc/tuple_test/IndexTupleEndDefault.py index 9edab7160..05e0f70ec 100644 --- a/boa3_test/test_sc/tuple_test/IndexTupleEndDefault.py +++ b/boa3_test/test_sc/tuple_test/IndexTupleEndDefault.py @@ -1,8 +1,8 @@ -from typing import Any, Tuple +from typing import Any from boa3.builtin.compile_time import public @public -def main(a: Tuple[Any], value: Any, start: int) -> int: +def main(a: tuple, value: Any, start: int) -> int: return a.index(value, start) diff --git a/boa3_test/test_sc/tuple_test/MultipleExpressionsInLine.py b/boa3_test/test_sc/tuple_test/MultipleExpressionsInLine.py index a640d8a9d..0be8422fc 100644 --- a/boa3_test/test_sc/tuple_test/MultipleExpressionsInLine.py +++ b/boa3_test/test_sc/tuple_test/MultipleExpressionsInLine.py @@ -4,6 +4,6 @@ @public -def Main(items1: Tuple[int]) -> int: +def Main(items1: Tuple[int, ...]) -> int: items2 = ('a', 'b', 'c', 'd'); value = items1[0]; count = value + len(items2) return count diff --git a/boa3_test/test_sc/tuple_test/TupleGetValue.py b/boa3_test/test_sc/tuple_test/TupleGetValue.py index 0ec5907dd..c369067f1 100644 --- a/boa3_test/test_sc/tuple_test/TupleGetValue.py +++ b/boa3_test/test_sc/tuple_test/TupleGetValue.py @@ -4,5 +4,5 @@ @public -def Main(a: Tuple[int]) -> int: +def Main(a: Tuple[int, ...]) -> int: return a[0] diff --git a/boa3_test/test_sc/tuple_test/TupleGetValueMismatchedType.py b/boa3_test/test_sc/tuple_test/TupleGetValueMismatchedType.py index a32c94143..854e65fc2 100644 --- a/boa3_test/test_sc/tuple_test/TupleGetValueMismatchedType.py +++ b/boa3_test/test_sc/tuple_test/TupleGetValueMismatchedType.py @@ -1,5 +1,5 @@ from typing import Tuple -def Main(a: Tuple[int]) -> int: +def Main(a: Tuple[int, ...]) -> int: return a[0][0] diff --git a/boa3_test/test_sc/tuple_test/TupleGetValueTypedTuple.py b/boa3_test/test_sc/tuple_test/TupleGetValueTypedTuple.py new file mode 100644 index 000000000..6c7be35c5 --- /dev/null +++ b/boa3_test/test_sc/tuple_test/TupleGetValueTypedTuple.py @@ -0,0 +1,10 @@ +from typing import Tuple + +from boa3.builtin.compile_time import public + + +@public +def Main() -> int: + x = (True, 1, 'ok') + return x[1] + diff --git a/boa3_test/test_sc/tuple_test/TupleIndexMismatchedType.py b/boa3_test/test_sc/tuple_test/TupleIndexMismatchedType.py index 9a24aa01a..adda8d033 100644 --- a/boa3_test/test_sc/tuple_test/TupleIndexMismatchedType.py +++ b/boa3_test/test_sc/tuple_test/TupleIndexMismatchedType.py @@ -1,5 +1,5 @@ from typing import Tuple -def Main(a: Tuple[int]) -> int: +def Main(a: Tuple[int, ...]) -> int: return a['0'] diff --git a/boa3_test/test_sc/tuple_test/TupleOfTuple.py b/boa3_test/test_sc/tuple_test/TupleOfTuple.py index 35ed30847..091604373 100644 --- a/boa3_test/test_sc/tuple_test/TupleOfTuple.py +++ b/boa3_test/test_sc/tuple_test/TupleOfTuple.py @@ -4,5 +4,5 @@ @public -def Main(a: Tuple[Tuple[int]]) -> int: +def Main(a: Tuple[Tuple[int, ...], ...]) -> int: return a[0][0] diff --git a/boa3_test/test_sc/tuple_test/TupleSetValue.py b/boa3_test/test_sc/tuple_test/TupleSetValue.py index ef5d7b559..e1c5e22bc 100644 --- a/boa3_test/test_sc/tuple_test/TupleSetValue.py +++ b/boa3_test/test_sc/tuple_test/TupleSetValue.py @@ -1,6 +1,6 @@ from typing import Tuple -def Main(a: Tuple[int]) -> int: +def Main(a: Tuple[int, ...]) -> int: a[0] = 1 return 1 diff --git a/boa3_test/tests/compiler_tests/test_tuple.py b/boa3_test/tests/compiler_tests/test_tuple.py index cc10b20fa..4706ff1a0 100644 --- a/boa3_test/tests/compiler_tests/test_tuple.py +++ b/boa3_test/tests/compiler_tests/test_tuple.py @@ -148,6 +148,34 @@ def test_non_sequence_set_value(self): path = self.get_contract_path('SetValueMismatchedType.py') self.assertCompilerLogs(CompilerError.UnresolvedOperation, path) + def test_tuple_get_value_typed_tuple_compile(self): + ok = String('ok').to_bytes() + expected_output = ( + Opcode.INITSLOT # function signature + + b'\x01' + + b'\x00' + + Opcode.PUSHDATA1 # x = [True, 1, 'ok'] + + Integer(len(ok)).to_byte_array() + ok + + Opcode.PUSH1 + + Opcode.PUSHT + + Opcode.PUSH3 + + Opcode.PACK + + Opcode.STLOC0 + + Opcode.LDLOC0 # x[1] + + Opcode.PUSH1 + + Opcode.PICKITEM + + Opcode.RET # return + ) + + output, _ = self.assertCompile('TupleGetValueTypedTuple.py') + self.assertEqual(expected_output, output) + + async def test_tuple_get_value_typed_tuple_run(self): + await self.set_up_contract('TupleGetValueTypedTuple.py') + + result, _ = await self.call('Main', [], return_type=int) + self.assertEqual(1, result) + def test_tuple_index_mismatched_type(self): path = self.get_contract_path('TupleIndexMismatchedType.py') self.assertCompilerLogs(CompilerError.MismatchedTypes, path) diff --git a/boa3_test/tests/compiler_tests/test_types.py b/boa3_test/tests/compiler_tests/test_types.py index 027c504cb..76320aa4e 100644 --- a/boa3_test/tests/compiler_tests/test_types.py +++ b/boa3_test/tests/compiler_tests/test_types.py @@ -183,6 +183,11 @@ def test_tuple_int_is_type_of_tuple_any(self): tuple_any_type = Type.tuple self.assertFalse(tuple_type.is_type_of(tuple_any_type)) + def test_typed_tuple_is_type_of_empty_tuple(self): + tuple_type = Type.tuple.build_any_length(Type.int) + empty_tuple = tuple() + self.assertTrue(tuple_type.is_type_of(empty_tuple)) + def test_list_any_is_type_of_sequence(self): list_type = Type.list sequence_type = Type.sequence