diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 8da51e1..e17c77e 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -54,3 +54,8 @@ jobs:
       - name: Run detectors tests
         run: |
           pytest tests/test_detectors.py
+
+      - name: Run dataflow analysis tests
+        run: |
+          pytest tests/transaction_context/test_group_sizes.py
+          pytest tests/transaction_context/test_group_indices.py
diff --git a/tealer/analyses/__init__.py b/tealer/analyses/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tealer/analyses/dataflow/__init__.py b/tealer/analyses/dataflow/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tealer/analyses/dataflow/all_constraints.py b/tealer/analyses/dataflow/all_constraints.py
new file mode 100644
index 0000000..f8b0fd7
--- /dev/null
+++ b/tealer/analyses/dataflow/all_constraints.py
@@ -0,0 +1,2 @@
+# pylint: disable=unused-import
+from tealer.analyses.dataflow.int_fields import GroupIndices
diff --git a/tealer/analyses/dataflow/generic.py b/tealer/analyses/dataflow/generic.py
new file mode 100644
index 0000000..98b2471
--- /dev/null
+++ b/tealer/analyses/dataflow/generic.py
@@ -0,0 +1,536 @@
+"""Defines a generic class for dataflow analysis to find constraints on transaction fields.
+
+The possible values for a field are considered to be a set, referred to as the universal set `U` for that field.
+If U is finite and small, the values are enumerated and stored in the context. However, when U is large,
+as for address type fields, enum values are used to represent the UniversalSet and the NullSet.
+
+For a given `key` and `basic_block`, block_contexts[key][basic_block] is the set of values (V) such that, if the
+transaction field represented by the `key` is set to one of the values present in V, then:
+    - The execution might reach the `basic_block`
+    - The execution might successfully reach the end of the `basic_block`
+    - The execution might reach a leaf basic block which results in successful completion of execution.
+
+block_contexts is computed in three steps, each step making the information more precise.
+1: In the first step, only local information is considered. For each block, the information inferred from the
+    instructions present in the block is computed.
+    - if the basic block contains the instruction sequence `assert(txn OnCompletion == int UpdateApplication)`,
+        then the block context of this block for the transaction type will be equal to `{ApplUpdateApplication}`.
+    - if the basic block errors, i.e. contains the `err` instruction or `return 0`, then the block context will be
+        the NullSet.
+    - if the instructions in the block do not provide any information related to the field, then the block context
+        will be equal to the UniversalSet (all possible values) for that key.
+2: In the second step, information from the predecessors is considered. For this, forward dataflow analysis is used.
+    This problem is analogous to the reaching definitions problem:
+    I   Each possible value is a definition that is defined at the start of execution, i.e. defined at the start of
+        the entry block.
+    II  The definition (value) will reach the start of a basic block if it reaches the end of one of its predecessors.
+    III The definition (value) will reach the end of a basic block if it is preserved by the block or is defined in
+        the block.
+    IV  No definition (value) is defined in a basic block.
+    V   The definition (value) will reach the start of a basic block only if the condition specific to reaching that
+        block is satisfied. The condition used to determine the branch taken can contain constraints related to the
+        analysis: developers can branch to an error block or a success block based on the transaction field value.
+        Path based context is used to combine this information in the forward analysis.
+
+    Equations:
+        initialization:
+            RCHin(entry) = UniversalSet() - from (I)
+            RCHout(b) = NullSet() for all basic blocks b.
+        fixed point iteration:
+            RCHin(b) = union(intersection(RCHout(prev_b), path_context[b][prev_b]) for prev_b in predecessors) - from (II), (V)
+            RCHout(b) = union(GEN(b), intersection(RCHin(b), PRSV(b))) - from (III)
+            GEN(b) = NullSet() for all basic blocks b. - from (IV), so
+            RCHout(b) = intersection(RCHin(b), PRSV(b))
+
+    `PRSV(b)` is `block_contexts` computed in the first step.
+    Reverse postorder ordering is used for the iteration algorithm.
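+
+    For example, with U = {1, ..., 16} (GroupSize) and a block b whose only predecessor p
+    has RCHout(p) = {2, 3} and path_context[b][p] = {3, 4}:
+        RCHin(b)  = intersection({2, 3}, {3, 4}) = {3}
+        RCHout(b) = intersection(RCHin(b), PRSV(b)) = intersection({3}, PRSV(b))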
+
+3: Finally, information from the successors is combined using backward dataflow analysis, similar to live-variable analysis.
+    - `block_contexts` is equal to the reach-out information computed in the second step.
+    - For leaf blocks, a value is live if the value is preserved by the block.
+    - For other blocks, a value is live if the value is used (preserved) by one of the successors.
+
+    Equations:
+        initialization:
+            LIVEout(b) = NullSet() for all non-leaf blocks.
+            LIVEout(b) = PRSV(b) for all leaf blocks.
+        fixed point iteration:
+            LIVEin(b) = union(LIVEout(succ_b) for succ_b in successors)
+            LIVEout(b) = intersection(LIVEin(b), PRSV(b))
+
+    `PRSV(b)` is `block_contexts` after the second step.
+    Postorder ordering is used for the iteration algorithm.
+
+Blocks containing a `callsub` instruction and blocks which are right after the `callsub` instruction are
+treated differently.
+e.g.
+    main:                       // Basic Block B1
+        int 2
+        retsub
+
+    path_1:                     // Basic Block B2
+        txn OnCompletion
+        int UpdateApplication
+        ==
+        assert
+        callsub main
+
+        byte "PATH_1"           // Basic Block B3
+        int 1
+        return
+
+    path_2:                     // Basic Block B4
+        txn OnCompletion
+        int DeleteApplication
+        ==
+        assert
+        callsub main
+
+        byte "PATH_2"           // Basic Block B5
+        int 1
+        return
+
+CFG:
+                 B2
+    +--------------------------+
+    | 4: path_1:               |
+    | 5: txn OnCompletion      |
+    | 6: int UpdateApplication |
+    | 7: ==                    |
+    | 8: assert                |
+    | 9: callsub main          |
+    +--------------------------+
+      |
+      |
+ B4   v                              B1                                B5
++----------------------------+      +--------------------------+      +-------------------+
+| 13: path_2:                |      |                          |      |                   |
+| 14: txn OnCompletion       |      | 1: main:                 |      | 19: byte "PATH_2" |
+| 15: int DeleteApplication  |      | 2: int 2                 |      | 20: int 1         |
+| 16: ==                     |      | 3: retsub                |      | 21: return        |
+| 17: assert                 |      |                          |      |                   |
+| 18: callsub main           | -->  |                          | -->  |                   |
++----------------------------+      +--------------------------+      +-------------------+
+                                      |
+                                      |
+                                      v  B3
+                            +--------------------------+
+                            | 10: byte "PATH_1"        |
+                            | 11: int 1                |
+                            | 12: return               |
+                            +--------------------------+
+
+The CFG will have the edges:
+    - B2 -> B1  # the callsub instruction transfers the execution to the called subroutine
+    - B4 -> B1
+    - B1 -> B3  # execution returns to the instruction present right after the callsub instruction
+    - B1 -> B5
+
+B3 and B5 are the return points of the subroutine.
+B3 is only executed if execution reaches B2 => block_context for txn type is `{UpdateApplication}`.
+Similarly, B5 is only executed if execution reaches B4 => block_context for txn type is `{DeleteApplication}`.
+
+block_contexts["TransactionType"][B1] = `{UpdateApplication, DeleteApplication}`  # from B2 -> B1 and B4 -> B1.
+
+B3 and B5 are return points and have only one predecessor (B1) in the CFG.
As a result, block_contexts would naively be
+block_contexts["TransactionType"][B3] = `{UpdateApplication, DeleteApplication}`
+block_contexts["TransactionType"][B5] = `{UpdateApplication, DeleteApplication}`.
+
+This is because, when traversing the CFG without differentiating subroutine blocks from the others, the possible
+execution paths are:
+1. B2 -> B1 -> B3
+2. B4 -> B1 -> B5
+3. B2 -> B1 -> B5  # won't be possible at runtime.
+4. B4 -> B1 -> B3  # won't be possible at runtime.
+
+However, at runtime, execution will reach B3 if and only if it reaches B2, and likewise for B5 and B4. Using this
+reasoning while combining information from predecessors and successors gives more accurate results.
+"""
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Set
+
+from tealer.teal.instructions.instructions import (
+    Assert,
+    Return,
+    BZ,
+    BNZ,
+    Err,
+)
+
+from tealer.utils.analyses import is_int_push_ins
+from tealer.utils.algorand_constants import MAX_GROUP_SIZE
+
+if TYPE_CHECKING:
+    from tealer.teal.teal import Teal
+    from tealer.teal.basic_blocks import BasicBlock
+    from tealer.teal.instructions.instructions import Instruction
+
+
+class IncorrectDataflowTransactionContextInitialization(Exception):
+    pass
+
+
+class DataflowTransactionContext(ABC):  # pylint: disable=too-few-public-methods
+
+    # List of keys; a unique and separate context is stored for each key.
+    # Each key represents a transaction field.
+    BASE_KEYS: List[str] = []
+    # BASE_KEYS for which transaction context information from `gtxn {i} {field}` is also stored.
+    KEYS_WITH_GTXN: List[str] = []  # every key in this list should also be present in BASE_KEYS.
+
+    def __init__(self, teal: "Teal"):
+        self._teal: "Teal" = teal
+        # entry block of the CFG
+        self._entry_block: "BasicBlock" = teal.bbs[
+            0
+        ]  # blocks are ordered by entry instruction in the parsing stage.
+        # self._block_contexts[KEY][B] -> block_context of KEY for block B
+        self._block_contexts: Dict[str, Dict["BasicBlock", Any]] = {}
+        # self._path_contexts[KEY][Bi][Bj] -> path_context of KEY for path Bj -> Bi
+        self._path_contexts: Dict[str, Dict["BasicBlock", Dict["BasicBlock", Any]]] = {}
+        if not self.BASE_KEYS:
+            raise IncorrectDataflowTransactionContextInitialization(
+                f"BASE_KEYS are not initialized {self.__class__.__name__}"
+            )
+
+    @staticmethod
+    def gtx_key(idx: int, key: str) -> str:
+        """return the key used for tracking the context of gtxn {idx} {field represented by key}"""
+        return f"GTXN_{idx:02d}_{key}"
+
+    @abstractmethod
+    def _universal_set(self, key: str) -> Any:
+        """Return the universal set for the field corresponding to the given key"""
+
+    @abstractmethod
+    def _null_set(self, key: str) -> Any:
+        """Return the null set for the field corresponding to the given key"""
+
+    @abstractmethod
+    def _union(self, key: str, a: Any, b: Any) -> Any:
+        """return the union of a and b, where a and b represent values for the given key"""
+
+    @abstractmethod
+    def _intersection(self, key: str, a: Any, b: Any) -> Any:
+        """return the intersection of a and b, where a and b represent values for the given key"""
+
+    @abstractmethod
+    def _get_asserted(self, key: str, ins_stack: List["Instruction"]) -> Tuple[Any, Any]:
+        """For the given key and ins_stack, return true_values and false_values
+
+        true_values for a key are the values which result in a non-zero value on
+        top of the stack.
+        false_values for a key are the values which result in a zero value on top
+        of the stack.
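+
+        e.g. for the key GroupSize, with an ins_stack ending in
+        [..., global GroupSize, int 2, ==], true_values = {2} and
+        false_values = U - {2} (see int_fields.GroupIndices for the concrete implementation).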
+        """
+
+    @abstractmethod
+    def _store_results(self) -> None:
+        """Store the collected information in the context object of each block"""
+
+    def _block_level_constraints(self, analysis_keys: List[str], block: "BasicBlock") -> None:
+        """Calculate and store constraints on keys applied within the block.
+
+        By default, no constraints are considered, i.e. values are assumed to be the universal_set.
+        If the block contains `Err` or `Return 0`, values are set to the null set.
+
+        If the block contains an assert instruction, values are further constrained using the
+        comparison being asserted. Values are stored in self._block_contexts
+        self._block_contexts[KEY][B] -> block_context of KEY for block B
+        """
+        for key in analysis_keys:
+            if key not in self._block_contexts:
+                self._block_contexts[key] = {}
+            self._block_contexts[key][block] = self._universal_set(key)
+
+        stack: List["Instruction"] = []
+        for ins in block.instructions:
+            if isinstance(ins, Assert):
+                for key in analysis_keys:
+                    asserted_values, _ = self._get_asserted(key, stack)
+                    present_values = self._block_contexts[key][block]
+                    self._block_contexts[key][block] = self._intersection(
+                        key, present_values, asserted_values
+                    )
+
+            # if return 0, set the possible values to NullSet()
+            if isinstance(ins, Return):
+                if len(ins.prev) == 1:
+                    is_int, value = is_int_push_ins(ins.prev[0])
+                    if is_int and value == 0:
+                        for key in analysis_keys:
+                            self._block_contexts[key][block] = self._null_set(key)
+
+            if isinstance(ins, Err):
+                for key in analysis_keys:
+                    self._block_contexts[key][block] = self._null_set(key)
+
+            stack.append(ins)
+
+    def _path_level_constraints(self, analysis_keys: List[str], block: "BasicBlock") -> None:
+        """Calculate and store constraints on keys applied along each path.
+
+        By default, no constraints are considered, i.e. values are assumed to be the universal_set.
+
+        If the block contains a bz/bnz instruction, the possible values are calculated for each
+        branch and are stored in self._path_contexts
+        self._path_contexts[KEY][Bi][Bj] -> path_context of KEY for path Bj -> Bi
+        """
+
+        for key in analysis_keys:
+            if key not in self._path_contexts:
+                self._path_contexts[key] = {}
+            path_context = self._path_contexts[key]
+            for b in block.next:
+                # path_context[bi][bj]: path context of path bj -> bi, bi is the successor
+                if b not in path_context:
+                    path_context[b] = {}
+                # if there are no constraints, set the possible values to the universal set
+                path_context[b][block] = self._universal_set(key)
+
+        if isinstance(block.exit_instr, (BZ, BNZ)):
+            for key in analysis_keys:
+                # true_values: possible values for {key} which result in a non-zero value on top of the stack
+                # false_values: possible values for {key} which result in a zero value on top of the stack
+                # if the check is not related to the field, true_values and false_values will be universal sets
+                true_values, false_values = self._get_asserted(key, block.instructions[:-1])
+
+                if len(block.next) == 1:
+                    # happens when bz/bnz is the last instruction in the contract and there is no default branch
+                    default_branch = None
+                    jump_branch = block.next[0]
+                else:
+                    default_branch = block.next[0]
+                    jump_branch = block.next[1]
+
+                if isinstance(block.exit_instr, BZ):
+                    # the jump branch is taken if the comparison is false, i.e. not in the asserted values
+                    self._path_contexts[key][jump_branch][block] = false_values
+                    # the default branch is taken if the comparison is true, i.e. in the asserted values
+                    if default_branch is not None:
+                        self._path_contexts[key][default_branch][block] = true_values
+                elif isinstance(block.exit_instr, BNZ):
+                    # the jump branch is taken if the comparison is true, i.e. in the asserted values
+                    self._path_contexts[key][jump_branch][block] = true_values
+                    # the default branch is taken if the comparison is false, i.e. not in the asserted values
+                    if default_branch is not None:
+                        self._path_contexts[key][default_branch][block] = false_values
+
+    def _update_gtxn_constraints(self, keys_with_gtxn: List[str], block: "BasicBlock") -> None:
+        """Use information from `txn {field}` constraints to update the `gtxn {i} {field}` constraints.
+
+        `block.transaction_context.group_indices` will contain the indices the `txn` can have.
+
+        The values of each key represent possible values for that field. Values of `GTXN_00_RekeyTo` are
+        possible values for `gtxn 0 RekeyTo`, i.e. possible `RekeyTo` field values of the transaction
+        present at index 0.
+
+        self._block_contexts[f"GTXN_{idx:02d}_{key}"] stores the information collected from the
+        instructions `gtxn {idx} {field}`. This information represents validations performed
+        on `txn {field}` by accessing it through `gtxn {idx} {field}`.
+
+        e.g. if the index of the `txn` must be 0, then `txn RekeyTo` is the same as `gtxn 0 RekeyTo`.
+        Similarly, if the index of `txn` can be `0` or `1`, then checking `txn RekeyTo` is equivalent to
+        checking both `gtxn 0 RekeyTo` and `gtxn 1 RekeyTo`.
+
+        If `i` is not a possible index of `txn` for basic block `B`, then the set of possible values for
+        `txn {field}` when accessed through `gtxn {i} {field}` is null, because `txn` can never have
+        index `i` and `gtxn {i} {field}` is a field of `txn` only when the index of `txn` is `i`.
+
+        This requires that the group_indices analysis is done before any other analysis.
+        """
+        for key in keys_with_gtxn:
+            for ind in range(MAX_GROUP_SIZE):
+                gtx_key = self.gtx_key(ind, key)
+                if ind in block.transaction_context.group_indices:
+                    # txn can have index {ind}
+                    # gtxn {ind} {field} can have a value if and only if {txn} {field} can also have that value
+                    self._block_contexts[gtx_key][block] = self._intersection(
+                        gtx_key,
+                        self._block_contexts[gtx_key][block],
+                        self._block_contexts[key][block],
+                    )
+                else:
+                    # txn cannot have index {ind}
+                    self._block_contexts[gtx_key][block] = self._null_set(gtx_key)
+
+    def _calculate_reachin(
+        self, key: str, block: "BasicBlock", reachout: Dict["BasicBlock", Any]
+    ) -> Any:
+        if block == self._entry_block:
+            # We consider each possible value as a definition defined at the start of the entry block.
+            reachin_information = self._universal_set(key)
+        else:
+            reachin_information = self._null_set(key)
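+
+        # RCHin(b) = union(intersection(RCHout(p), path_context[b][p]) for p in predecessors),
+        # the reaching-definitions style equation from the module docstring.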
+        path_context = self._path_contexts[key]
+        for prev_b in block.prev:
+            reachin_information = self._union(
+                key,
+                reachin_information,
+                self._intersection(key, reachout[prev_b], path_context[block][prev_b]),
+            )
+
+        if block.callsub_block is not None:
+            # this block is the return point for the callsub instruction present in `block.callsub_block`;
+            # execution will only reach this block if it reaches `block.callsub_block`
+            reachin_information = self._intersection(
+                key, reachin_information, reachout[block.callsub_block]
+            )
+
+        return reachin_information
+
+    def _merge_information_forward(
+        self,
+        analysis_keys: List[str],
+        block: "BasicBlock",
+        global_reachout: Dict[str, Dict["BasicBlock", Any]],
+    ) -> bool:
+        updated = False
+        for key in analysis_keys:
+            # RCHout(b) = intersection(RCHin(b), PRSV(b))
+            new_reachout = self._intersection(
+                key,
+                self._calculate_reachin(key, block, global_reachout[key]),
+                self._block_contexts[key][block],
+            )
+            if new_reachout != global_reachout[key][block]:
+                global_reachout[key][block] = new_reachout
+                updated = True
+        return updated
+
+    def forward_analysis(self, analysis_keys: List[str], worklist: List["BasicBlock"]) -> None:
+        """Perform forward analysis for analysis_keys and update self._block_contexts"""
+        # reachout for all analysis keys. global_reachout[key] -> reachout of key.
+        # global_reachout[key][block] -> reachout of block for key.
+        global_reachout: Dict[str, Dict["BasicBlock", Any]] = {}
+        for key in analysis_keys:
+            global_reachout[key] = {}
+            for b in self._teal.bbs:
+                global_reachout[key][b] = self._null_set(key)
+
+        while worklist:
+            b = worklist[0]
+            worklist = worklist[1:]
+            updated = self._merge_information_forward(analysis_keys, b, global_reachout)
+
+            if updated:
+                return_point_block = [b.sub_return_point] if b.sub_return_point is not None else []
+                for bi in b.next + return_point_block:
+                    if bi not in worklist:
+                        worklist.append(bi)
+
+        for key in analysis_keys:
+            self._block_contexts[key] = global_reachout[key]
+
+    def _calculate_livein(
+        self, key: str, block: "BasicBlock", liveout: Dict["BasicBlock", Any]
+    ) -> Any:
+        livein_information = self._null_set(key)
+
+        for next_b in block.next:
+            livein_information = self._union(key, livein_information, liveout[next_b])
+
+        if block.sub_return_point is not None:
+            # this block is the `callsub block`; `block.sub_return_point` is the block that will be
+            # executed after the subroutine returns.
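+            # i.e. a value is live before this callsub only if it is also live at the block
+            # executed after the subroutine returns (see the module docstring).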
+            livein_information = self._intersection(
+                key, livein_information, liveout[block.sub_return_point]
+            )
+        return livein_information
+
+    def _merge_information_backward(
+        self,
+        analysis_keys: List[str],
+        block: "BasicBlock",
+        global_liveout: Dict[str, Dict["BasicBlock", Any]],
+    ) -> bool:
+        if len(block.next) == 0:  # leaf block
+            return False
+
+        updated = False
+        for key in analysis_keys:
+            new_liveout = self._intersection(
+                key,
+                self._calculate_livein(key, block, global_liveout[key]),
+                self._block_contexts[key][block],
+            )
+            if new_liveout != global_liveout[key][block]:
+                global_liveout[key][block] = new_liveout
+                updated = True
+        return updated
+
+    def backward_analysis(self, analysis_keys: List[str], worklist: List["BasicBlock"]) -> None:
+        """Perform backward analysis for analysis_keys and update self._block_contexts"""
+        global_liveout: Dict[str, Dict["BasicBlock", Any]] = {}
+        for key in analysis_keys:
+            global_liveout[key] = {}
+            for b in self._teal.bbs:
+                if len(b.next) == 0:  # leaf block
+                    global_liveout[key][b] = self._block_contexts[key][b]
+                else:
+                    global_liveout[key][b] = self._null_set(key)
+
+        while worklist:
+            b = worklist[0]
+            worklist = worklist[1:]
+            updated = self._merge_information_backward(analysis_keys, b, global_liveout)
+
+            if updated:
+                callsub_block = [b.callsub_block] if b.callsub_block is not None else []
+                for bi in b.prev + callsub_block:
+                    if bi not in worklist:
+                        worklist.append(bi)
+
+        for key in analysis_keys:
+            self._block_contexts[key] = global_liveout[key]
+
+    @staticmethod
+    def _postorder(entry: "BasicBlock") -> List["BasicBlock"]:
+        visited: Set["BasicBlock"] = set()
+        order: List["BasicBlock"] = []
+
+        def dfs(block: "BasicBlock") -> None:
+            visited.add(block)
+            for successor in block.next:
+                if successor not in visited:
+                    dfs(successor)
+            order.append(block)
+
+        dfs(entry)
+        return order
+
+    def run_analysis(self) -> None:
+        """Run the analysis"""
+
+        gtx_keys = []
+        for key in self.KEYS_WITH_GTXN:
+            for ind in range(MAX_GROUP_SIZE):
+                gtx_keys.append(self.gtx_key(ind, key))
+
+        all_keys = self.BASE_KEYS + gtx_keys
+
+        # step 1: initialise the information
+        for block in self._teal.bbs:
+            self._block_level_constraints(all_keys, block)  # initialise information for all keys
+            self._path_level_constraints(all_keys, block)
+
+        postorder = self._postorder(self._entry_block)
+
+        # perform the analysis for the base keys first. The information of these base keys is used
+        # for the gtxn keys. see `self._update_gtxn_constraints`
+        analysis_keys = list(self.BASE_KEYS)
+
+        worklist = postorder[::-1]  # reverse postorder
+        self.forward_analysis(analysis_keys, worklist)
+
+        worklist = [b for b in postorder if len(b.next) != 0]  # postorder, excluding leaf blocks
+        self.backward_analysis(analysis_keys, worklist)
+
+        # update the gtxn constraints using the possible group indices and txn constraints.
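+        # e.g. if group_indices of a block is {0, 1}, the GTXN_00_* and GTXN_01_* contexts are
+        # intersected with the corresponding txn contexts and every other GTXN_* context becomes
+        # the null set for that block.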
+        for block in self._teal.bbs:
+            self._update_gtxn_constraints(self.KEYS_WITH_GTXN, block)
+
+        # now perform the analysis for the gtx_keys
+        analysis_keys = gtx_keys
+
+        worklist = postorder[::-1]  # reverse postorder
+        self.forward_analysis(analysis_keys, worklist)
+
+        worklist = [b for b in postorder if len(b.next) != 0]  # postorder, excluding leaf blocks
+        self.backward_analysis(analysis_keys, worklist)
+
+        self._store_results()
diff --git a/tealer/analyses/dataflow/int_fields.py b/tealer/analyses/dataflow/int_fields.py
new file mode 100644
index 0000000..13d175b
--- /dev/null
+++ b/tealer/analyses/dataflow/int_fields.py
@@ -0,0 +1,195 @@
+from typing import TYPE_CHECKING, List, Set, Tuple, Dict
+
+from tealer.analyses.dataflow.generic import DataflowTransactionContext
+from tealer.teal.instructions.instructions import (
+    Global,
+    Eq,
+    Neq,
+    Greater,
+    GreaterE,
+    Less,
+    LessE,
+    Txn,
+)
+from tealer.teal.global_field import GroupSize
+from tealer.teal.instructions.transaction_field import GroupIndex
+from tealer.utils.analyses import is_int_push_ins
+from tealer.utils.algorand_constants import MAX_GROUP_SIZE
+
+if TYPE_CHECKING:
+    from tealer.teal.instructions.instructions import Instruction
+
+group_size_key = "GroupSize"
+group_index_key = "GroupIndex"
+analysis_keys = [group_size_key, group_index_key]
+universal_sets = {
+    group_size_key: list(range(1, MAX_GROUP_SIZE + 1)),
+    group_index_key: list(range(0, MAX_GROUP_SIZE)),
+}
+
+
+class GroupIndices(DataflowTransactionContext):  # pylint: disable=too-few-public-methods
+
+    GROUP_SIZE_KEY = group_size_key
+    GROUP_INDEX_KEY = group_index_key
+    BASE_KEYS: List[str] = analysis_keys
+    KEYS_WITH_GTXN: List[str] = []  # gtxn information is not collected for any of the keys
+    UNIVERSAL_SETS: Dict[str, List] = universal_sets
+
+    def _universal_set(self, key: str) -> Set:
+        return set(self.UNIVERSAL_SETS[key])
+
+    def _null_set(self, key: str) -> Set:
+        return set()
+
+    def _union(self, key: str, a: Set, b: Set) -> Set:
+        return a | b
+
+    def _intersection(self, key: str, a: Set, b: Set) -> Set:
+        return a & b
+
+    @staticmethod
+    def _get_asserted_int_values(
+        comparison_ins: "Instruction", compared_int: int, universal_set: List[int]
+    ) -> List[int]:
+        """return the list of ints from the universal set (U) that satisfy the comparison.
+
+        if the condition uses ==, return [compared_int].
+        if the condition uses !=, return U - {compared_int}.
+        if the condition uses <,  return the values in U less than compared_int.
+        if the condition uses <=, return the values in U less than or equal to compared_int.
+        if the condition uses >,  return the values in U greater than compared_int.
+        if the condition uses >=, return the values in U greater than or equal to compared_int.
+
+        Args:
+            comparison_ins: the comparison instruction used. can be [==, !=, <, <=, >, >=]
+            compared_int: the integer value compared.
+            universal_set: list of all possible integer values for the field.
+
+        Returns:
+            list of ints that satisfy the comparison.
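+
+            e.g. comparison_ins = `<`, compared_int = 2, universal_set = [0, 1, ..., 15]
+            returns [0, 1].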
+        """
+        U = list(universal_set)
+
+        if isinstance(comparison_ins, Eq):  # pylint: disable=no-else-return
+            return [compared_int]
+        elif isinstance(comparison_ins, Neq):
+            if compared_int in U:
+                U.remove(compared_int)
+            return U
+        elif isinstance(comparison_ins, Less):
+            return [i for i in U if i < compared_int]
+        elif isinstance(comparison_ins, LessE):
+            return [i for i in U if i <= compared_int]
+        elif isinstance(comparison_ins, Greater):
+            return [i for i in U if i > compared_int]
+        elif isinstance(comparison_ins, GreaterE):
+            return [i for i in U if i >= compared_int]
+        else:
+            return U
+
+    def _get_asserted_groupsizes(
+        self, ins_stack: List["Instruction"]
+    ) -> Tuple[Set[int], Set[int]]:
+        """return the sets of GroupSize values that make the comparison true and false
+
+        checks for the instruction sequence below and returns the group size values that make the
+        comparison true and the values that make it false.
+
+        [ Global GroupSize | (int | pushint)]
+        [ (int | pushint)  | Global GroupSize]
+        [ == | != | < | <= | > | >=]
+
+        Args:
+            ins_stack: list of instructions executed up to and including the comparison instruction.
+
+        Returns:
+            set of GroupSize values that make the comparison true, set of GroupSize values that make
+            the comparison false.
+        """
+        U = list(self.UNIVERSAL_SETS[self.GROUP_SIZE_KEY])
+        if len(ins_stack) < 3:
+            return set(U), set(U)
+
+        if isinstance(ins_stack[-1], (Eq, Neq, Less, LessE, Greater, GreaterE)):
+            ins1 = ins_stack[-2]
+            ins2 = ins_stack[-3]
+            compared_value = None
+
+            if isinstance(ins1, Global) and isinstance(ins1.field, GroupSize):
+                is_int, value = is_int_push_ins(ins2)
+                if is_int:
+                    compared_value = value
+            elif isinstance(ins2, Global) and isinstance(ins2.field, GroupSize):
+                is_int, value = is_int_push_ins(ins1)
+                if is_int:
+                    compared_value = value
+
+            if compared_value is None or not isinstance(compared_value, int):
+                # the comparison does not check GroupSize against an int literal, so it gives no
+                # information; return U for both the true and the false branch
+                return set(U), set(U)
+
+            ins = ins_stack[-1]
+            asserted_values = self._get_asserted_int_values(ins, compared_value, U)
+            return set(asserted_values), set(U) - set(asserted_values)
+        return set(U), set(U)
+
+    def _get_asserted_groupindices(
+        self, ins_stack: List["Instruction"]
+    ) -> Tuple[Set[int], Set[int]]:
+        """return the sets of GroupIndex values that make the comparison true and false
+
+        checks for the instruction sequence below and returns the group index values that make the
+        comparison true and the values that make it false.
+
+        [ txn GroupIndex  | (int | pushint)]
+        [ (int | pushint) | txn GroupIndex]
+        [ == | != | < | <= | > | >=]
+
+        Args:
+            ins_stack: list of instructions executed up to and including the comparison instruction.
+
+        Returns:
+            set of GroupIndex values that make the comparison true, set of GroupIndex values that
+            make the comparison false.
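+
+            e.g. for a stack ending in [..., txn GroupIndex, int 0, ==] this returns
+            ({0}, U - {0}).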
+        """
+        U = list(self.UNIVERSAL_SETS[self.GROUP_INDEX_KEY])
+        if len(ins_stack) < 3:
+            return set(U), set(U)
+
+        if isinstance(ins_stack[-1], (Eq, Neq, Less, LessE, Greater, GreaterE)):
+            ins1 = ins_stack[-2]
+            ins2 = ins_stack[-3]
+            compared_value = None
+
+            if isinstance(ins1, Txn) and isinstance(ins1.field, GroupIndex):
+                is_int, value = is_int_push_ins(ins2)
+                if is_int:
+                    compared_value = value
+            elif isinstance(ins2, Txn) and isinstance(ins2.field, GroupIndex):
+                is_int, value = is_int_push_ins(ins1)
+                if is_int:
+                    compared_value = value
+
+            if compared_value is None or not isinstance(compared_value, int):
+                return set(U), set(U)
+
+            ins = ins_stack[-1]
+            asserted_values = self._get_asserted_int_values(ins, compared_value, U)
+            return set(asserted_values), set(U) - set(asserted_values)
+        return set(U), set(U)
+
+    def _get_asserted(self, key: str, ins_stack: List["Instruction"]) -> Tuple[Set, Set]:
+        if key == self.GROUP_SIZE_KEY:
+            return self._get_asserted_groupsizes(ins_stack)
+        return self._get_asserted_groupindices(ins_stack)
+
+    def _store_results(self) -> None:
+        # use group_sizes to update group_indices: a transaction's index is always strictly less
+        # than the group size
+        group_sizes_context = self._block_contexts[self.GROUP_SIZE_KEY]
+        group_indices_context = self._block_contexts[self.GROUP_INDEX_KEY]
+        for bi in self._teal.bbs:
+            group_indices_context[bi] = group_indices_context[bi] & set(
+                range(0, max(group_sizes_context[bi], default=0))
+            )
+
+        group_size_block_context = self._block_contexts[self.GROUP_SIZE_KEY]
+        for block in self._teal.bbs:
+            block.transaction_context.group_sizes = list(group_size_block_context[block])
+
+        group_index_block_context = self._block_contexts[self.GROUP_INDEX_KEY]
+        for block in self._teal.bbs:
+            block.transaction_context.group_indices = list(group_index_block_context[block])
diff --git a/tealer/teal/basic_blocks.py b/tealer/teal/basic_blocks.py
index d174d91..991c760 100644
--- a/tealer/teal/basic_blocks.py
+++ b/tealer/teal/basic_blocks.py
@@ -16,12 +16,14 @@ from typing import List, Optional, TYPE_CHECKING
 
 from tealer.teal.instructions.instructions import Instruction
 
+from tealer.teal.context.block_transaction_context import BlockTransactionContext
+
 if TYPE_CHECKING:
     from tealer.teal.teal import Teal
 
 
-class BasicBlock:
+class BasicBlock:  # pylint: disable=too-many-instance-attributes
     """Class to represent basic blocks of the teal contract.
 
     A basic block is a sequence of instructions with a single entry
@@ -37,6 +39,9 @@ def __init__(self) -> None:
         self._next: List[BasicBlock] = []
         self._idx: int = 0
         self._teal: Optional["Teal"] = None
+        self._transaction_context = BlockTransactionContext()
+        self._callsub_block: Optional[BasicBlock] = None
+        self._sub_return_point: Optional[BasicBlock] = None
 
     def add_instruction(self, instruction: Instruction) -> None:
         """Append instruction to this basic block.
@@ -113,6 +118,28 @@ def idx(self) -> int:
     def idx(self, i: int) -> None:
         self._idx = i
 
+    @property
+    def callsub_block(self) -> Optional["BasicBlock"]:
+        """If this block is the return point of a subroutine, `callsub_block` is the block
+        that called the subroutine.
+        """
+        return self._callsub_block
+
+    @callsub_block.setter
+    def callsub_block(self, b: "BasicBlock") -> None:
+        self._callsub_block = b
+
+    @property
+    def sub_return_point(self) -> Optional["BasicBlock"]:
+        """If a subroutine is executed after this block, i.e. the exit instruction is callsub,
+        then sub_return_point is the basic block that will be executed after the subroutine returns.
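+
+        e.g. if block B ends with `callsub main` and block R starts right after that callsub,
+        then B.sub_return_point is R and R.callsub_block is B.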
+        """
+        return self._sub_return_point
+
+    @sub_return_point.setter
+    def sub_return_point(self, b: "BasicBlock") -> None:
+        self._sub_return_point = b
+
     @property
     def cost(self) -> int:
         """cost of executing all instructions in this basic block"""
@@ -127,6 +154,10 @@ def teal(self) -> Optional["Teal"]:
     def teal(self, teal_instance: "Teal") -> None:
         self._teal = teal_instance
 
+    @property
+    def transaction_context(self) -> "BlockTransactionContext":
+        return self._transaction_context
+
     def __str__(self) -> str:
         ret = ""
         for ins in self._instructions:
diff --git a/tealer/teal/context/__init__.py b/tealer/teal/context/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tealer/teal/context/block_transaction_context.py b/tealer/teal/context/block_transaction_context.py
new file mode 100644
index 0000000..313f1a3
--- /dev/null
+++ b/tealer/teal/context/block_transaction_context.py
@@ -0,0 +1,30 @@
+from typing import List, Optional
+
+from tealer.exceptions import TealerException
+from tealer.utils.algorand_constants import MAX_GROUP_SIZE
+
+
+class BlockTransactionContext:  # pylint: disable=too-few-public-methods
+
+    _group_transactions_context: Optional[List["BlockTransactionContext"]] = None
+
+    def __init__(self, tail: bool = False) -> None:
+        if not tail:
+            self._group_transactions_context = [
+                BlockTransactionContext(True) for _ in range(MAX_GROUP_SIZE)
+            ]
+
+        # set default values
+        if tail:
+            # information from gtxn {i} instructions.
+            self.group_indices = []
+            self.group_sizes = []
+        else:
+            self.group_sizes = list(range(1, MAX_GROUP_SIZE + 1))
+            self.group_indices = list(range(0, MAX_GROUP_SIZE))
+
+    def gtxn_context(self, txn_index: int) -> "BlockTransactionContext":
+        """context information collected from gtxn {txn_index} field instructions"""
+        if self._group_transactions_context is None:
+            raise TealerException("gtxn_context is not available on a tail context")
+        if txn_index >= MAX_GROUP_SIZE:
+            raise TealerException(f"txn_index {txn_index} exceeds MAX_GROUP_SIZE")
+        return self._group_transactions_context[txn_index]
diff --git a/tealer/teal/instructions/instructions.py b/tealer/teal/instructions/instructions.py
index 4b40625..313f651 100644
--- a/tealer/teal/instructions/instructions.py
+++ b/tealer/teal/instructions/instructions.py
@@ -49,7 +49,7 @@ class ContractType(ComparableEnum):
 }
 
 
-class Instruction:
+class Instruction:  # pylint: disable=too-many-instance-attributes
     """Base class for Teal instructions.
 
     Any class that represents a teal instruction must inherit from
@@ -66,6 +66,7 @@ def __init__(self) -> None:
         self._bb: Optional["BasicBlock"] = None
         self._version: int = 1
         self._mode: ContractType = ContractType.ANY
+        self._callsub_ins: Optional["Instruction"] = None
 
     def add_prev(self, prev_ins: "Instruction") -> None:
         """Add instruction that may execute just before this instruction.
@@ -137,6 +138,25 @@ def bb(self) -> Optional["BasicBlock"]:
     def bb(self, b: "BasicBlock") -> None:
         self._bb = b
 
+    @property
+    def callsub_ins(self) -> Optional["Instruction"]:
+        """If this instruction is the return point of a callsub instruction, i.e. a callsub
+        instruction is present right before this instruction, then callsub_ins returns a
+        reference to that callsub instruction object.
+
+        e.g.
+            callsub main
+            int 1
+            return
+
+        callsub_ins of `int 1` is the instruction object of `callsub main`.
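+
+        This reference is set during parsing (see `_first_pass` in parse_teal.py) and is used
+        to link a subroutine's return point block to its calling block.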
+        """
+        return self._callsub_ins
+
+    @callsub_ins.setter
+    def callsub_ins(self, ins: "Instruction") -> None:
+        self._callsub_ins = ins
+
     @property
     def version(self) -> int:
         """Teal version this instruction is introduced in and supported from."""
diff --git a/tealer/teal/parse_teal.py b/tealer/teal/parse_teal.py
index edf99ba..8f1b9e0 100644
--- a/tealer/teal/parse_teal.py
+++ b/tealer/teal/parse_teal.py
@@ -25,6 +25,7 @@
 """
 
+import inspect
 import sys
 from typing import Optional, Dict, List
@@ -48,6 +49,8 @@ from tealer.teal.instructions.asset_params_field import AssetParamsField
 from tealer.teal.instructions.app_params_field import AppParamsField
 from tealer.teal.teal import Teal
+from tealer.analyses.dataflow import all_constraints
+from tealer.analyses.dataflow.generic import DataflowTransactionContext
 
 
 def _detect_contract_type(instructions: List[Instruction]) -> ContractType:
@@ -124,6 +127,14 @@ def create_bb(instructions: List[Instruction], all_bbs: List[BasicBlock]) -> None:
             bb.add_instruction(ins)
             ins.bb = bb
 
+            if ins.callsub_ins is not None and ins.bb is not None:
+                # the callsub instruction appears before this instruction in the code, so its
+                # basic block should have been assigned already
+                callsub_basic_block = ins.callsub_ins.bb
+                if callsub_basic_block is not None:
+                    ins.bb.callsub_block = callsub_basic_block
+                    callsub_basic_block.sub_return_point = ins.bb
+
         if len(ins.next) > 1 and not isinstance(ins, Retsub):
             if not isinstance(ins.next[0], Label):
                 next_bb = BasicBlock()
@@ -210,6 +221,7 @@ def _first_pass(
                 rets[call.label].append(ins)
             else:
                 rets[call.label] = [ins]
+            ins.callsub_ins = call  # ins is the return point when call is executed.
 
     # Now prepare for the next-line instruction
     # A flag that says that this was an unconditional jump
@@ -455,6 +467,22 @@ def _verify_version(ins_list: List[Instruction], program_version: int) -> bool:
     return error
 
 
+def _apply_transaction_context_analysis(teal: "Teal") -> None:
+    group_indices_cls = all_constraints.GroupIndices
+    analyses_classes = [getattr(all_constraints, name) for name in dir(all_constraints)]
+    analyses_classes = [
+        c
+        for c in analyses_classes
+        if inspect.isclass(c)
+        and issubclass(c, DataflowTransactionContext)
+        and c != group_indices_cls
+    ]
+    # Run the group indices analysis first as the other analyses use its results.
+    group_indices_cls(teal).run_analysis()
+    for cl in analyses_classes:
+        cl(teal).run_analysis()
+
+
 def parse_teal(source_code: str) -> Teal:
     """Parse algorand smart contracts written in teal.
@@ -515,4 +543,6 @@ def parse_teal(source_code: str) -> Teal: for bb in teal.bbs: bb.teal = teal + _apply_transaction_context_analysis(teal) + return teal diff --git a/tealer/utils/algorand_constants.py b/tealer/utils/algorand_constants.py new file mode 100644 index 0000000..78843b3 --- /dev/null +++ b/tealer/utils/algorand_constants.py @@ -0,0 +1,2 @@ +MAX_GROUP_SIZE = 16 +MIN_ALGORAND_FEE = 1000 # in micro algos diff --git a/tealer/utils/analyses.py b/tealer/utils/analyses.py index 3f17933..62043aa 100644 --- a/tealer/utils/analyses.py +++ b/tealer/utils/analyses.py @@ -19,7 +19,7 @@ from tealer.teal.instructions.transaction_field import TransactionField, OnCompletion, ApplicationID -def _is_int_push_ins(ins: Instruction) -> Tuple[bool, Optional[Union[int, str]]]: +def is_int_push_ins(ins: Instruction) -> Tuple[bool, Optional[Union[int, str]]]: if isinstance(ins, Int) or isinstance( # pylint: disable=consider-merging-isinstance ins, PushInt ): @@ -123,7 +123,7 @@ def detect_missing_txn_check( if isinstance(ins, Return): if len(ins.prev) == 1: prev = ins.prev[0] - is_int_push, value = _is_int_push_ins(prev) + is_int_push, value = is_int_push_ins(prev) if is_int_push and value == 0: return @@ -172,7 +172,7 @@ def is_oncompletion_check(ins1: Instruction, ins2: Instruction, checked_values: integer_checked_values.append(ENUM_NAMES_TO_INT[named_constant]) if isinstance(ins1, Txn) and isinstance(ins1.field, OnCompletion): - is_int_push, value = _is_int_push_ins(ins2) + is_int_push, value = is_int_push_ins(ins2) return is_int_push and (value in checked_values or value in integer_checked_values) return False @@ -200,7 +200,7 @@ def is_application_creation_check(ins1: Instruction, ins2: Instruction) -> bool: """ if isinstance(ins1, Txn) and isinstance(ins1.field, ApplicationID): - is_int_push, value = _is_int_push_ins(ins2) + is_int_push, value = is_int_push_ins(ins2) return is_int_push and value == 0 return False @@ -256,7 +256,7 @@ def detect_missing_on_completion( # pylint: disable=too-many-branches, too-many if isinstance(ins, Return): if len(ins.prev) == 1: prev = ins.prev[0] - is_int_push, value = _is_int_push_ins(prev) + is_int_push, value = is_int_push_ins(prev) if is_int_push and value == 0: return diff --git a/tests/transaction_context/__init__.py b/tests/transaction_context/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/transaction_context/test_group_indices.py b/tests/transaction_context/test_group_indices.py new file mode 100644 index 0000000..dbc5ac1 --- /dev/null +++ b/tests/transaction_context/test_group_indices.py @@ -0,0 +1,182 @@ +from typing import List, Tuple +import pytest + +from tealer.teal.parse_teal import parse_teal +from tests.utils import order_basic_blocks + + +MULTIPLE_RETSUB = """ +#pragma version 5 +b main +is_even: + int 2 + % + bz return_1 + int 0 + txn GroupIndex + int 2 + == + assert + retsub +return_1: + txn GroupIndex + int 3 + < + assert + int 1 + retsub +main: + txn GroupIndex + int 1 + != + assert + int 4 + callsub is_even + return +""" + +MULTIPLE_RETSUB_GROUP_INDICES = [[0, 2], [0, 2], [2], [0, 2], [0, 2], [0, 2]] + +SUBROUTINE_BACK_JUMP = """ +#pragma version 5 +b main +getmod: + % + retsub +is_odd: + txn GroupIndex + int 4 + < + assert + txn GroupIndex + int 2 + != + assert + int 2 + b getmod +main: + int 5 + callsub is_odd + return +""" + +SUBROUTINE_BACK_JUMP_GROUP_INDICES = [[0, 1, 3], [0, 1, 3], [0, 1, 3], [0, 1, 3], [0, 1, 3], [0, 1, 3]] + +BRANCHING = """ +#pragma version 4 +txn GroupIndex +int 2 +>= +assert +txn 
GroupIndex +int 4 +> +bz fin +txn GroupIndex +int 1 +== +bnz check_second_arg +int 0 +return +check_second_arg: +txn ApplicationArgs 1 +btoi +int 100 +> +bnz fin +int 0 +return +fin: +int 1 +return +""" + +BRANCHING_GROUP_INDICES = [[2, 3, 4], [], [], [], [], [2, 3, 4]] + +LOOPS = """ +#pragma version 5 +txn GroupIndex +int 4 +!= +assert +int 0 +loop: + dup + txn GroupIndex + int 3 + >= + bz end + int 1 + + + txn GroupIndex + int 3 + < + assert + b loop +end: + int 2 + txn GroupIndex + == + assert + int 1 + return +""" + +LOOPS_GROUP_INDICES = [[2], [2], [], [2]] + +LOOPS_GROUP_SIZES = """ +#pragma version 5 +txn GroupIndex +int 4 +!= +assert +global GroupSize +int 6 +<= +int 0 +loop: + dup + txn GroupIndex + int 3 + > + bz end + int 1 + + + txn GroupIndex + int 6 + < + assert + b loop +end: + int 2 + txn GroupIndex + > + assert + int 5 + global GroupSize + <= + assert + int 1 + return +""" + +LOOPS_GROUP_SIZES_GROUP_INDICES = [[3], [3], [], [3]] + +ALL_TESTS = [ + (MULTIPLE_RETSUB, MULTIPLE_RETSUB_GROUP_INDICES), + (SUBROUTINE_BACK_JUMP, SUBROUTINE_BACK_JUMP_GROUP_INDICES), + (BRANCHING, BRANCHING_GROUP_INDICES), + (LOOPS, LOOPS_GROUP_INDICES), + (LOOPS_GROUP_SIZES, LOOPS_GROUP_SIZES_GROUP_INDICES), +] + + +@pytest.mark.parametrize("test", ALL_TESTS) # type: ignore +def test_group_indices(test: Tuple[str, List[List[int]]]) -> None: + code, group_indices = test + teal = parse_teal(code.strip()) + + bbs = order_basic_blocks(teal.bbs) + for b, indices in zip(bbs, group_indices): + assert b.transaction_context.group_indices == indices + diff --git a/tests/transaction_context/test_group_sizes.py b/tests/transaction_context/test_group_sizes.py new file mode 100644 index 0000000..cdc3eaa --- /dev/null +++ b/tests/transaction_context/test_group_sizes.py @@ -0,0 +1,319 @@ +from typing import List, Tuple +import pytest + + +from tealer.teal.basic_blocks import BasicBlock +from tealer.teal.instructions import instructions +from tealer.teal.instructions import transaction_field +from tealer.teal import global_field +from tealer.teal.parse_teal import parse_teal + +from tests.utils import cmp_cfg, construct_cfg, order_basic_blocks + + +MULTIPLE_RETSUB = """ +#pragma version 5 +b main +is_even: + int 2 + % + bz return_1 + int 0 + global GroupSize + int 2 + == + assert + retsub +return_1: + global GroupSize + int 3 + < + assert + int 1 + retsub +main: + global GroupSize + int 1 + != + assert + int 4 + callsub is_even + return +""" + +ins_list = [ + instructions.Pragma(5), + instructions.B("main"), + instructions.Label("is_even"), + instructions.Int(2), + instructions.Modulo(), + instructions.BZ("return_1"), + instructions.Int(0), + instructions.Global(global_field.GroupSize()), + instructions.Int(2), + instructions.Eq(), + instructions.Assert(), + instructions.Retsub(), + instructions.Label("return_1"), + instructions.Global(global_field.GroupSize()), + instructions.Int(3), + instructions.Less(), + instructions.Assert(), + instructions.Int(1), + instructions.Retsub(), + instructions.Label("main"), + instructions.Global(global_field.GroupSize()), + instructions.Int(1), + instructions.Neq(), + instructions.Assert(), + instructions.Int(4), + instructions.Callsub("is_even"), + instructions.Return(), +] + +ins_partitions = [(0, 2), (2, 6), (6, 12), (12, 19), (19, 26), (26, 27)] +bbs_links = [(0, 4), (4, 1), (1, 2), (1, 3), (2, 5), (3, 5)] + +MULTIPLE_RETSUB_CFG_GROUP_SIZES = [[2], [2], [2], [2], [2], [2]] + +MULTIPLE_RETSUB_CFG = construct_cfg(ins_list, ins_partitions, bbs_links) + + +SUBROUTINE_BACK_JUMP = 
""" +#pragma version 5 +b main +getmod: + % + retsub +is_odd: + global GroupSize + int 4 + < + assert + global GroupSize + int 2 + != + assert + int 2 + b getmod +main: + int 5 + callsub is_odd + return +""" + +ins_list = [ + instructions.Pragma(5), + instructions.B("main"), + instructions.Label("getmod"), + instructions.Modulo(), + instructions.Retsub(), + instructions.Label("is_odd"), + instructions.Global(global_field.GroupSize()), + instructions.Int(4), + instructions.Less(), + instructions.Assert(), + instructions.Global(global_field.GroupSize()), + instructions.Int(2), + instructions.Neq(), + instructions.Assert(), + instructions.Int(2), + instructions.B("getmod"), + instructions.Label("main"), + instructions.Int(5), + instructions.Callsub("is_odd"), + instructions.Return(), +] + +ins_partitions = [(0, 2), (2, 5), (5, 16), (16, 19), (19, 20)] +bbs_links = [(0, 3), (3, 2), (2, 1), (1, 4)] + +SUBROUTINE_BACK_JUMP_CFG_GROUP_SIZES = [[1, 3], [1, 3], [1, 3], [1, 3], [1, 3], [1, 3]] +SUBROUTINE_BACK_JUMP_CFG = construct_cfg(ins_list, ins_partitions, bbs_links) + +BRANCHING = """ +#pragma version 4 +global GroupSize +int 2 +>= +assert +global GroupSize +int 4 +> +bz fin +global GroupSize +int 1 +== +bnz check_second_arg +int 0 +return +check_second_arg: +txn ApplicationArgs 1 +btoi +int 100 +> +bnz fin +int 0 +return +fin: +int 1 +return +""" + +ins_list = [ + instructions.Pragma(4), + instructions.Global(global_field.GroupSize()), + instructions.Int(2), + instructions.GreaterE(), + instructions.Assert(), + instructions.Global(global_field.GroupSize()), + instructions.Int(4), + instructions.Greater(), + instructions.BZ("fin"), + instructions.Global(global_field.GroupSize()), + instructions.Int(1), + instructions.Eq(), + instructions.BNZ("check_second_arg"), + instructions.Int(0), + instructions.Return(), + instructions.Label("check_second_arg"), + instructions.Txn(transaction_field.ApplicationArgs(1)), + instructions.Btoi(), + instructions.Int(100), + instructions.Greater(), + instructions.BNZ("fin"), + instructions.Int(0), + instructions.Return(), + instructions.Label("fin"), + instructions.Int(1), + instructions.Return(), +] + +ins_partitions = [(0, 9), (9, 13), (13, 15), (15, 21), (21, 23), (23, 26)] +bbs_links = [(0, 1), (0, 5), (1, 2), (1, 3), (3, 4), (3, 5)] + +BRANCHING_CFG_GROUP_SIZES = [[2, 3, 4], [], [], [], [], [2, 3, 4]] +BRANCHING_CFG = construct_cfg(ins_list, ins_partitions, bbs_links) + +LOOPS = """ +#pragma version 5 +global GroupSize +int 4 +!= +assert +int 0 +loop: + dup + global GroupSize + int 3 + >= + bz end + int 1 + + + global GroupSize + int 3 + < + assert + b loop +end: + int 2 + global GroupSize + == + assert + int 1 + return +""" + +ins_list = [ + instructions.Pragma(5), + instructions.Global(global_field.GroupSize()), + instructions.Int(4), + instructions.Neq(), + instructions.Assert(), + instructions.Int(0), + instructions.Label("loop"), + instructions.Dup(), + instructions.Global(global_field.GroupSize()), + instructions.Int(3), + instructions.GreaterE(), + instructions.BZ("end"), + instructions.Int(1), + instructions.Add(), + instructions.Global(global_field.GroupSize()), + instructions.Int(3), + instructions.Less(), + instructions.Assert(), + instructions.B("loop"), + instructions.Label("end"), + instructions.Int(2), + instructions.Global(global_field.GroupSize()), + instructions.Eq(), + instructions.Assert(), + instructions.Int(1), + instructions.Return(), +] + +ins_partitions = [(0, 6), (6, 12), (12, 19), (19, 26)] +bbs_links = [(0, 1), (1, 2), (1, 3), (2, 
1)] + +LOOPS_CFG_GROUP_SIZES = [[2], [2], [], [2]] +LOOPS_CFG = construct_cfg(ins_list, ins_partitions, bbs_links) + +cfg_group_sizes = [ + (MULTIPLE_RETSUB_CFG, MULTIPLE_RETSUB_CFG_GROUP_SIZES), + (SUBROUTINE_BACK_JUMP_CFG, SUBROUTINE_BACK_JUMP_CFG_GROUP_SIZES), + (BRANCHING_CFG, BRANCHING_CFG_GROUP_SIZES), + (LOOPS_CFG, LOOPS_CFG_GROUP_SIZES), +] + +for cfg, sizes in cfg_group_sizes: + bb = order_basic_blocks(cfg) + for b, group_sizes in zip(bb, sizes): + b.transaction_context.group_sizes = group_sizes + + +ALL_TESTS = [ + (MULTIPLE_RETSUB, MULTIPLE_RETSUB_CFG), + (SUBROUTINE_BACK_JUMP, SUBROUTINE_BACK_JUMP_CFG), + (BRANCHING, BRANCHING_CFG), + (LOOPS, LOOPS_CFG), +] + + +@pytest.mark.parametrize("test", ALL_TESTS) # type: ignore +def test_group_sizes(test: Tuple[str, List[BasicBlock]]) -> None: + code, cfg = test + teal = parse_teal(code.strip()) + for bb in cfg: + print(bb) + print("*" * 20) + assert cmp_cfg(teal.bbs, cfg) + + bbs = order_basic_blocks(teal.bbs) + cfg = order_basic_blocks(cfg) + for b1, b2 in zip(bbs, cfg): + print(b1.transaction_context.group_sizes, b2.transaction_context.group_sizes) + assert b1.transaction_context.group_sizes == b2.transaction_context.group_sizes + + +MULTIPLE_RETSUB_CFG_GROUP_INDICES = [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]] +SUBROUTINE_BACK_JUMP_GROUP_INDICES = [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]] +BRANCHING_GROUP_INDICES = [[0, 1, 2, 3], [], [], [], [], [0, 1, 2, 3]] +LOOPS_GROUP_INDICES = [[0, 1], [0, 1], [], [0, 1]] + +GROUP_INDICES_TESTS = [ + (MULTIPLE_RETSUB, MULTIPLE_RETSUB_CFG_GROUP_INDICES), + (SUBROUTINE_BACK_JUMP, SUBROUTINE_BACK_JUMP_GROUP_INDICES), + (BRANCHING, BRANCHING_GROUP_INDICES), + (LOOPS, LOOPS_GROUP_INDICES), +] + +@pytest.mark.parametrize("test", GROUP_INDICES_TESTS) # type: ignore +def test_group_indices(test: Tuple[str, List[List[int]]]) -> None: + code, group_indices_list = test + teal = parse_teal(code.strip()) + + bbs = order_basic_blocks(teal.bbs) + for b, group_indices in zip(bbs, group_indices_list): + print(b.transaction_context.group_indices, group_indices) + assert b.transaction_context.group_indices == group_indices
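
For reference, a minimal sketch of how the per-block context computed by this change can be inspected after parsing (illustrative only; `approval.teal` is a placeholder path):

    from tealer.teal.parse_teal import parse_teal

    with open("approval.teal") as f:
        teal = parse_teal(f.read())

    # parse_teal() runs _apply_transaction_context_analysis(), so every basic block
    # carries the possible GroupSize and GroupIndex values computed for it.
    for bb in teal.bbs:
        ctx = bb.transaction_context
        print(bb.idx, ctx.group_sizes, ctx.group_indices)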