Skip to content

Commit

Permalink
Move fork variables to py-evm
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielSchiavini committed Apr 15, 2024
1 parent 59a78dd commit ba65161
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 37 deletions.
22 changes: 3 additions & 19 deletions boa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@ class Env:
_singleton = None
_random = random.Random("titanoboa") # something reproducible
_coverage_enabled = False
_fast_mode_enabled = False
_fork_try_prefetch_state = False

def __init__(self):
def __init__(self, fork_try_prefetch_state=False, fast_mode_enabled=False):
self._gas_price = None

self._aliases = {}
Expand All @@ -49,7 +47,7 @@ def __init__(self):

self._gas_tracker = 0

self.evm = PyEVM(self, self._fast_mode_enabled)
self.evm = PyEVM(self, fast_mode_enabled, fork_try_prefetch_state)

def set_random_seed(self, seed=None):
self._random = random.Random(seed)
Expand All @@ -58,7 +56,6 @@ def get_gas_price(self):
return self._gas_price or 0

def enable_fast_mode(self, flag: bool = True):
self._fast_mode_enabled = flag
self.evm.enable_fast_mode(flag)

def fork(self, url: str, reset_traces=True, block_identifier="safe", **kwargs):
Expand All @@ -78,20 +75,11 @@ def fork_rpc(self, rpc: RPC, reset_traces=True, block_identifier="safe", **kwarg
self.sha3_trace = {}
self.sstore_trace = {}

self.evm.fork_rpc(
rpc,
fast_mode_enabled=self._fast_mode_enabled,
block_identifier=block_identifier,
**kwargs,
)
self.evm.fork_rpc(rpc, block_identifier=block_identifier, **kwargs)

def get_gas_meter_class(self):
return self.evm.get_gas_meter_class()

@property
def _fork_mode(self):
return self.evm.is_forked

def set_gas_meter_class(self, cls: type) -> None:
self.evm.set_gas_meter_class(cls)

Expand Down Expand Up @@ -218,13 +206,11 @@ def deploy_code(
else:
target_address = Address(override_address)

prefetch_state = self._fork_mode and self._fork_try_prefetch_state
origin = sender # XXX: consider making this parameterizable
c = self.evm.deploy_code(
sender=sender,
origin=origin,
target_address=target_address,
prefetch_state=prefetch_state,
gas=gas,
gas_price=self.get_gas_price(),
value=value,
Expand Down Expand Up @@ -286,11 +272,9 @@ def execute_code(
bytecode = self.evm.get_code(to)

is_static = not is_modifying
prefetch_state = self._fork_mode and self._fork_try_prefetch_state
ret = self.evm.execute_code(
sender=sender,
to=to,
prefetch_state=prefetch_state,
gas=gas,
gas_price=self.get_gas_price(),
value=value,
Expand Down
4 changes: 2 additions & 2 deletions boa/integrations/jupyter/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ class BrowserEnv(NetworkEnv):
A NetworkEnv object that uses the BrowserSigner and BrowserRPC classes.
"""

def __init__(self, address=None):
super().__init__(rpc=BrowserRPC())
def __init__(self, address=None, **kwargs):
super().__init__(rpc=BrowserRPC(), **kwargs)
self.signer = BrowserSigner(address)
self.set_eoa(self.signer)

Expand Down
10 changes: 8 additions & 2 deletions boa/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,14 @@ class NetworkEnv(Env):
# always prefetch state in network mode
_fork_try_prefetch_state = True

def __init__(self, rpc: str | RPC, accounts: dict[str, Account] = None):
super().__init__()
def __init__(
self,
rpc: str | RPC,
accounts: dict[str, Account] = None,
fork_try_prefetch_state=True,
**kwargs,
):
super().__init__(fork_try_prefetch_state, **kwargs)

if isinstance(rpc, str):
warnings.warn(
Expand Down
27 changes: 13 additions & 14 deletions boa/vm/py_evm.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def apply_create_message(cls, state, msg, tx_ctx, **kwargs):

if is_eip1167_contract(bytecode):
contract_address = extract_eip1167_address(bytecode)
bytecode = cls.env.vm.state.get_code(contract_address)
bytecode = cls.env.evm.vm.state.get_code(contract_address)

if bytecode in cls.env._code_registry:
target = cls.env._code_registry[bytecode].deployer.at(contract_address)
Expand All @@ -324,7 +324,7 @@ def finalize(c):
c._contract_repr_before_revert = repr(contract)
return c

if contract is None or not cls.env._fast_mode_enabled:
if contract is None or not cls.env.evm._fast_mode_enabled:
# print("SLOW MODE")
computation = super().apply_computation(state, msg, tx_ctx, **kwargs)
return finalize(computation)
Expand Down Expand Up @@ -369,12 +369,14 @@ def __init__(


class PyEVM:
def __init__(self, env, fast_mode_enabled: bool):
def __init__(self, env, fast_mode_enabled=False, fork_try_prefetch_state=False):
self.chain = _make_chain()
self.env = env
self._init_vm(env, AccountDB, fast_mode_enabled=fast_mode_enabled)
self._fast_mode_enabled = fast_mode_enabled
self._fork_try_prefetch_state = fork_try_prefetch_state
self._init_vm(env, AccountDB)

def _init_vm(self, env, account_db_class: Type[AccountDB], fast_mode_enabled: bool):
def _init_vm(self, env, account_db_class: Type[AccountDB]):
self.vm = self.chain.get_vm()
self.vm.__class__._state_class.account_db_class = account_db_class

Expand All @@ -386,7 +388,7 @@ def _init_vm(self, env, account_db_class: Type[AccountDB], fast_mode_enabled: bo
{"env": env},
)

if fast_mode_enabled:
if self._fast_mode_enabled:
patch_pyevm_state_object(self.vm.state)

self.vm.state.computation_class = c
Expand All @@ -408,11 +410,9 @@ def enable_fast_mode(self, flag: bool = True):
else:
unpatch_pyevm_state_object(self.vm.state)

def fork_rpc(
self, rpc: RPC, fast_mode_enabled: bool, block_identifier: str, **kwargs
):
def fork_rpc(self, rpc: RPC, block_identifier: str, **kwargs):
account_db_class = AccountDBFork.class_from_rpc(rpc, block_identifier, **kwargs)
self._init_vm(self.env, account_db_class, fast_mode_enabled)
self._init_vm(self.env, account_db_class)
block_info = self.vm.state._account_db._block_info

self.vm.patch.timestamp = int(block_info["timestamp"], 16)
Expand Down Expand Up @@ -474,7 +474,6 @@ def deploy_code(
sender: Address,
origin: Address,
target_address: Address,
prefetch_state: bool,
gas: Optional[int],
gas_price: int,
value: int,
Expand All @@ -493,7 +492,7 @@ def deploy_code(
data=b"",
)

if prefetch_state:
if self.is_forked and self._fork_try_prefetch_state:
self.vm.state._account_db.try_prefetch_state(msg)

tx_ctx = BaseTransactionContext(
Expand All @@ -507,7 +506,6 @@ def execute_code(
self,
sender: Address,
to: Address,
prefetch_state: bool,
gas: int,
gas_price: int,
value: int,
Expand All @@ -532,7 +530,8 @@ def execute_code(
ir_executor=ir_executor,
contract=contract,
)
if prefetch_state:

if self.is_forked and self._fork_try_prefetch_state:
self.vm.state._account_db.try_prefetch_state(msg)

origin = sender.canonical_address # XXX: consider making this parameterizable
Expand Down

0 comments on commit ba65161

Please sign in to comment.