Skip to content

Commit

Permalink
Merge pull request #14 from ucsb-seclab/sanity_check_callprivate
Browse files Browse the repository at this point in the history
Add a sanity check for CALLPRIVATE argument count match
  • Loading branch information
ruaronicola authored Jan 24, 2024
2 parents 91e3cdc + 1f17d01 commit 8dab556
Showing 1 changed file with 73 additions and 5 deletions.
78 changes: 73 additions & 5 deletions greed/project.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import List, Tuple
import logging
import time
from collections import defaultdict

import networkx as nx
import web3

from collections import defaultdict

from greed.TAC.TAC_parser import TAC_parser
from greed.TAC.gigahorse_ops import TAC_Callprivateargs
from greed.factory import Factory
from greed import options as opt
from greed.factory import Factory
from greed.function import TAC_Function
from greed.TAC.gigahorse_ops import TAC_Callprivate, TAC_Callprivateargs
from greed.TAC.TAC_parser import TAC_parser

log = logging.getLogger(__name__)

Expand All @@ -17,6 +19,11 @@ class Project(object):
"""
This is the main class for creating a greed Project!
"""
code: str
factory: Factory
tac_parser: TAC_parser


def __init__(self, target_dir: str):
"""
Args:
Expand Down Expand Up @@ -63,7 +70,68 @@ def __init__(self, target_dir: str):
self.w3 = None
except Exception as e:
self.w3 = None

self.sanity_check()

def sanity_check(self, raise_on_failure=False) -> bool:
"""
Perform standard sanity checks after loading a project.
If any checks fail, this will return False. If raise_on_failure is True,
it will also raise an exception on first failure.
"""
t_start = time.time()
# Tracks whether all checks passed. In order to gather more information,
# we do not immediately return False when a check fails. Instead, we
# continue checking, and only return False (the variable `checks_passed`) at the end.
checks_passed = True

#
# Check: All CALLPRIVATE statements have the correct number of arguments
#

# First, gather all CALLPRIVATE statements
callprivate_statements: List[TAC_Callprivate] = []
for block in self.block_at.values():
for statement in block.statements:
if isinstance(statement, TAC_Callprivate):
callprivate_statements.append(statement)

# Find the target function for each CALLPRIVATE statement
callprivate_statements_with_target: List[Tuple[TAC_Callprivate, TAC_Function]] = []
for statement in callprivate_statements:
if not hasattr(statement, "arg1_val") or statement.arg1_val is None:
log.debug(f"CALLPRIVATE statement {statement.id} has no known target function")
continue

target_block = self.factory.block(hex(statement.arg1_val.value))
if target_block is None:
log.debug(f"CALLPRIVATE statement {statement.id} has no known target function")
continue

target_function = target_block.function
assert target_function is not None, f"Target block {target_block.id} of CALLPRIVATE statement {statement.id} has no function"

callprivate_statements_with_target.append((statement, target_function))

# Check that the number of arguments is correct
for statement, target_function in callprivate_statements_with_target:
if len(statement.arg_vars) - 1 != len(target_function.arguments): # NOTE: -1 because the first argument is the target block
err_msg = f"CALLPRIVATE statement {statement.id} has {len(statement.arg_vars)} arguments, " \
f"but target function {target_function.id} expects {len(target_function.arguments)}"
log.warning(err_msg)
if raise_on_failure:
raise ValueError(err_msg)
checks_passed = False

elapsed = time.time() - t_start
if checks_passed:
log.debug(f"All sanity checks passed in {elapsed:.4f}s")
else:
log.warning(f"Some sanity checks failed in {elapsed:.4f}s")

return checks_passed

def dump_callgraph(self, filename):
"""
Dump the callgraph in a dot file.
Expand Down

0 comments on commit 8dab556

Please sign in to comment.