diff --git a/greed/project.py b/greed/project.py index a1d51cc..016e760 100644 --- a/greed/project.py +++ b/greed/project.py @@ -1,14 +1,16 @@ +from typing import List, Tuple import logging +import time +from collections import defaultdict import networkx as nx import web3 -from collections import defaultdict - -from greed.TAC.TAC_parser import TAC_parser -from greed.TAC.gigahorse_ops import TAC_Callprivateargs -from greed.factory import Factory from greed import options as opt +from greed.factory import Factory +from greed.function import TAC_Function +from greed.TAC.gigahorse_ops import TAC_Callprivate, TAC_Callprivateargs +from greed.TAC.TAC_parser import TAC_parser log = logging.getLogger(__name__) @@ -17,6 +19,11 @@ class Project(object): """ This is the main class for creating a greed Project! """ + code: str + factory: Factory + tac_parser: TAC_parser + + def __init__(self, target_dir: str): """ Args: @@ -63,7 +70,68 @@ def __init__(self, target_dir: str): self.w3 = None except Exception as e: self.w3 = None + + self.sanity_check() + + def sanity_check(self, raise_on_failure=False) -> bool: + """ + Perform standard sanity checks after loading a project. + + If any checks fail, this will return False. If raise_on_failure is True, + it will also raise an exception on first failure. + """ + t_start = time.time() + # Tracks whether all checks passed. In order to gather more information, + # we do not immediately return False when a check fails. Instead, we + # continue checking, and only return False (the variable `checks_passed`) at the end. + checks_passed = True + + # + # Check: All CALLPRIVATE statements have the correct number of arguments + # + # First, gather all CALLPRIVATE statements + callprivate_statements: List[TAC_Callprivate] = [] + for block in self.block_at.values(): + for statement in block.statements: + if isinstance(statement, TAC_Callprivate): + callprivate_statements.append(statement) + + # Find the target function for each CALLPRIVATE statement + callprivate_statements_with_target: List[Tuple[TAC_Callprivate, TAC_Function]] = [] + for statement in callprivate_statements: + if not hasattr(statement, "arg1_val") or statement.arg1_val is None: + log.debug(f"CALLPRIVATE statement {statement.id} has no known target function") + continue + + target_block = self.factory.block(hex(statement.arg1_val.value)) + if target_block is None: + log.debug(f"CALLPRIVATE statement {statement.id} has no known target function") + continue + + target_function = target_block.function + assert target_function is not None, f"Target block {target_block.id} of CALLPRIVATE statement {statement.id} has no function" + + callprivate_statements_with_target.append((statement, target_function)) + + # Check that the number of arguments is correct + for statement, target_function in callprivate_statements_with_target: + if len(statement.arg_vars) - 1 != len(target_function.arguments): # NOTE: -1 because the first argument is the target block + err_msg = f"CALLPRIVATE statement {statement.id} has {len(statement.arg_vars)} arguments, " \ + f"but target function {target_function.id} expects {len(target_function.arguments)}" + log.warning(err_msg) + if raise_on_failure: + raise ValueError(err_msg) + checks_passed = False + + elapsed = time.time() - t_start + if checks_passed: + log.debug(f"All sanity checks passed in {elapsed:.4f}s") + else: + log.warning(f"Some sanity checks failed in {elapsed:.4f}s") + + return checks_passed + def dump_callgraph(self, filename): """ Dump the callgraph in a dot file.