diff --git a/src/ape/api/accounts.py b/src/ape/api/accounts.py index c5e3ba1dc0..6400a865b7 100644 --- a/src/ape/api/accounts.py +++ b/src/ape/api/accounts.py @@ -416,6 +416,28 @@ def prepare_transaction(self, txn: TransactionAPI) -> TransactionAPI: return txn + def get_deployment_address(self, nonce: Optional[int] = None) -> AddressType: + """ + Get a contract address before it is deployed. This is useful + when you need to pass the contract address to another contract + before deploying it. + + Args: + nonce (int | None): Optionally provide a nonce. Defaults + the account's current nonce. + + Returns: + AddressType: The contract address. + """ + # Use the connected network, if available. Else, default to Ethereum. + ecosystem = ( + self.network_manager.active_provider.network.ecosystem + if self.network_manager.active_provider + else self.network_manager.ethereum + ) + nonce = self.nonce if nonce is None else nonce + return ecosystem.get_deployment_address(self.address, nonce) + class AccountContainerAPI(BaseInterfaceModel): """ diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index e90604abd3..9f0cec1895 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -579,6 +579,18 @@ def decode_returndata(self, abi: "MethodABI", raw_data: bytes) -> Any: Any: All of the values returned from the contract function. """ + @raises_not_implemented + def get_deployment_address( # type: ignore[empty-body] + self, + address: AddressType, + nonce: int, + ) -> AddressType: + """ + Calculate the deployment address of a contract before it is deployed. + This is useful if the address is an argument to another contract's deployment + and you have not yet deployed the first contract yet. + """ + def get_network(self, network_name: str) -> "NetworkAPI": """ Get the network for the given name. diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index ac29bfaf4c..677942f019 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -4,6 +4,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast +import rlp # type: ignore from cchecksum import to_checksum_address from eth_abi import decode, encode from eth_abi.exceptions import InsufficientDataBytes, NonEmptyPaddingBytes @@ -1476,6 +1477,17 @@ def decode_custom_error( # error never found. return None + def get_deployment_address(self, address: AddressType, nonce: int) -> AddressType: + """ + Calculate the deployment address of a contract before it is deployed. + This is useful if the address is an argument to another contract's deployment + and you have not yet deployed the first contract yet. + """ + sender_bytes = to_bytes(hexstr=address) + encoded = rlp.encode([sender_bytes, nonce]) + address_bytes = keccak(encoded)[12:] + return self.decode_address(address_bytes) + def parse_type(type_: dict[str, Any]) -> Union[str, tuple, list]: if "tuple" not in type_["type"]: diff --git a/tests/functional/test_accounts.py b/tests/functional/test_accounts.py index e7f1422428..7fed67ca04 100644 --- a/tests/functional/test_accounts.py +++ b/tests/functional/test_accounts.py @@ -921,3 +921,12 @@ def test_import_account_from_private_key_insecure_passphrase(delete_account_afte def test_load(account_manager, keyfile_account): account = account_manager.load(keyfile_account.alias) assert account == keyfile_account + + +def test_get_deployment_address(owner, vyper_contract_container): + deployment_address_1 = owner.get_deployment_address() + deployment_address_2 = owner.get_deployment_address(nonce=owner.nonce + 1) + instance_1 = owner.deploy(vyper_contract_container, 490) + assert instance_1.address == deployment_address_1 + instance_2 = owner.deploy(vyper_contract_container, 490) + assert instance_2.address == deployment_address_2 diff --git a/tests/functional/test_ecosystem.py b/tests/functional/test_ecosystem.py index 00c31d59c9..19466242a3 100644 --- a/tests/functional/test_ecosystem.py +++ b/tests/functional/test_ecosystem.py @@ -1203,3 +1203,9 @@ def get_calltree(self) -> CallTreeNode: } ] assert events == expected + + +def test_get_deployment_address(ethereum, owner, vyper_contract_container): + actual = ethereum.get_deployment_address(owner.address, owner.nonce) + expected = owner.deploy(vyper_contract_container, 490) + assert actual == expected