Skip to content

Commit

Permalink
Self-review
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielSchiavini committed Jan 8, 2024
1 parent b99e305 commit 2ca7e30
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 52 deletions.
42 changes: 40 additions & 2 deletions boa/contracts/abi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,66 @@
from typing import Any, Union
from collections import defaultdict
from typing import Any, Callable, Iterable, TypeVar, Union

from boa.environment import Address

T = TypeVar("T")
K = TypeVar("K")


def _encode_addresses(values: list) -> list:
"""
Converts any object with an 'address' field into the address itself.
This is to allow `Address` objects to be used.
:param values: A list of values
:return: The same list of values, with addresses converted.
"""
return [getattr(arg, "address", arg) for arg in values]


def _decode_addresses(abi_type: Union[list, str], decoded: Any) -> Any:
"""
Converts addresses received from the EVM into `Address` objects, recursively.
:param abi_type: ABI type name. This should be a list if `decoded` is also a list.
:param decoded: The decoded value(s) from the EVM.
:return: The same value(s), with addresses converted.
"""
if abi_type == "address":
return Address(decoded)
if isinstance(abi_type, str) and abi_type.startswith("address["):
return [Address(i) for i in decoded]
return decoded


def _parse_abi_type(abi: dict) -> list:
def _parse_abi_type(abi: dict) -> Union[list, str]:
"""
Parses an ABI type into a list of types.
:param abi: The ABI type to parse.
:return: A list of types or a single type.
"""
if "components" in abi:
assert abi["type"] == "tuple" # sanity check
return [_parse_abi_type(item) for item in abi["components"]]
return abi["type"]


def _format_abi_type(types: list) -> str:
"""
Converts a list of ABI types into a comma-separated string.
"""
return ",".join(
item if isinstance(item, str) else f"({_format_abi_type(item)})"
for item in types
)


def group_by(sequence: Iterable[T], key: Callable[[T], K]) -> dict[K, list[T]]:
"""
Groups a sequence of items by a key function.
:param sequence: The sequence to group.
:param key: The key function.
:return: A dictionary mapping keys to a list of items with that key.
"""
result = defaultdict(list)
for item in sequence:
result[key(item)].append(item)
return result
40 changes: 19 additions & 21 deletions boa/contracts/abi/contract.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from functools import cached_property
from itertools import groupby
from operator import attrgetter
from os.path import basename
from typing import Any, Optional
from typing import Any, Optional, Union
from warnings import warn

from _operator import attrgetter
from eth.abc import ComputationAPI

from boa.contracts.abi import _decode_addresses, _format_abi_type
from boa.contracts.abi import _decode_addresses, _format_abi_type, group_by
from boa.contracts.abi.function import ABIFunction, ABIOverload
from boa.contracts.evm_contract import BaseEVMContract
from boa.environment import Address
Expand All @@ -28,9 +28,15 @@ def __init__(
super().__init__(env, filename=filename, address=address)
self._name = name
self._functions = functions
self._bytecode = self.env.vm.state.get_code(address.canonical_address)
if not self._bytecode:
warn(
f"Requested {self} but there is no bytecode at that address!",
stacklevel=2,
)

for name, group in groupby(self._functions, key=attrgetter("name")):
setattr(self, name, ABIOverload.create(list(group), self))
for name, group in group_by(self._functions, attrgetter("name")).items():
setattr(self, name, ABIOverload.create(group, self))

self._address = Address(address)

Expand All @@ -48,7 +54,7 @@ def marshal_to_python(self, computation, abi_type: list[str]) -> tuple[Any, ...]
:param computation: the computation object returned by `execute_code`
:param abi_type: the ABI type of the return value.
"""
if computation.is_error:
if computation.is_error or (abi_type and not computation.output):
return self.handle_error(computation)

schema = f"({_format_abi_type(abi_type)})"
Expand Down Expand Up @@ -79,7 +85,8 @@ def deployer(self) -> "ABIContractFactory":

def __repr__(self):
file_str = f" (file {self.filename})" if self.filename else ""
return f"<{self._name} interface at {self.address}>{file_str}"
warn_str = "" if self._bytecode else " (WARNING: no bytecode at this address!)"
return f"<{self._name} interface at {self.address}{warn_str}>{file_str}"


class ABIContractFactory:
Expand All @@ -103,20 +110,11 @@ def from_abi_dict(cls, abi, name="<anonymous contract>"):
]
return cls(basename(name), functions, filename=name)

def at(self, address) -> ABIContract:
def at(self, address: Union[Address, str]) -> ABIContract:
"""
Create an ABI contract object for a deployed contract at `address`.
"""
address = Address(address)

ret = ABIContract(self._name, self._functions, address, self._filename)

bytecode = ret.env.vm.state.get_code(address.canonical_address)
if not bytecode:
raise ValueError(
f"Requested {ret} but there is no bytecode at that address!"
)

ret.env.register_contract(address, ret)

return ret
contract = ABIContract(self._name, self._functions, address, self._filename)
contract.env.register_contract(address, contract)
return contract
11 changes: 2 additions & 9 deletions boa/contracts/abi/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vyper.utils import method_id

from boa.contracts.abi import _encode_addresses, _format_abi_type, _parse_abi_type
from boa.util.abi import abi_decode, abi_encode, is_abi_encodable
from boa.util.abi import abi_encode, is_abi_encodable

if TYPE_CHECKING:
from boa.contracts.abi.contract import ABIContract
Expand Down Expand Up @@ -73,13 +73,6 @@ def is_encodable(self, *args, **kwargs) -> bool:
for abi_type, arg in zip(self.argument_types, parsed_args)
)

def matches(self, *args, **kwargs) -> bool:
"""Check whether this function matches the given arguments exactly."""
parsed_args = self._merge_kwargs(*args, **kwargs)
encoded_args = abi_encode(self.signature, args)
decoded_args = abi_decode(self.signature, encoded_args)
return map(type, parsed_args) == map(type, decoded_args)

def _merge_kwargs(self, *args, **kwargs) -> list:
"""Merge positional and keyword arguments into a single list."""
if len(kwargs) + len(args) != self.argument_count:
Expand Down Expand Up @@ -154,7 +147,7 @@ def name(self):
def __call__(self, *args, **kwargs):
"""
Call the function that matches the given arguments.
:raises Exception: if not a single function is found
:raises Exception: if a single function is not found
"""
match [f for f in self.functions if f.is_encodable(*args, **kwargs)]:
case [function]:
Expand Down
6 changes: 3 additions & 3 deletions boa/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def by_line(self):
return ret


# line profile. mergeable across contract
# line profile. mergeable across contracts
class LineProfile:
def __init__(self):
self.profile = {}
Expand Down Expand Up @@ -313,7 +313,7 @@ def get_call_profile_table(env: Env) -> Table:
(cache[profile].net_gas_stats.avg_gas, profile.address)
)

# arrange from most to least expensive contract:
# arrange from most to least expensive contracts:
sort_gas = sorted(contract_vs_mean_gas, key=lambda x: x[0], reverse=True)
sorted_addresses = list(set([x[1] for x in sort_gas]))

Expand All @@ -324,7 +324,7 @@ def get_call_profile_table(env: Env) -> Table:
for profile in cached_contracts[address]:
fn_vs_mean_gas.append((cache[profile].net_gas_stats.avg_gas, profile))

# arrange from most to least expensive contract:
# arrange from most to least expensive contracts:
fn_vs_mean_gas = sorted(fn_vs_mean_gas, key=lambda x: x[0], reverse=True)

for c, (_, profile) in enumerate(fn_vs_mean_gas):
Expand Down
2 changes: 1 addition & 1 deletion boa/test/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def pytest_addoption(parser):
parser.addoption(
"--gas-profile",
action="store_true",
help="Profile gas used by contract called in tests",
help="Profile gas used by contracts called in tests",
)


Expand Down
13 changes: 8 additions & 5 deletions boa/util/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
# take an exception instance, and strip frames in the target module
# from the traceback
def strip_internal_frames(exc, module_name=None):
ei = sys.exc_info()
frame = ei[2].tb_frame
error_type, error, traceback = sys.exc_info()
frame = traceback.tb_frame

if module_name is None:
module_name = frame.f_globals["__name__"]
Expand All @@ -18,7 +18,7 @@ def strip_internal_frames(exc, module_name=None):
# kwargs incompatible with pypy here
# tb_next=None, tb_frame=frame, tb_lasti=frame.f_lasti, tb_lineno=frame.f_lineno
tb = types.TracebackType(None, frame, frame.f_lasti, frame.f_lineno)
return ei[1].with_traceback(tb)
return error.with_traceback(tb)


class StackTrace(list):
Expand Down Expand Up @@ -67,6 +67,9 @@ class BoaError(Exception):
# stack trace but does not require the actual stack trace itself.
def __str__(self):
frame = self.stack_trace.last_frame
err = frame.vm_error
err.args = (frame.pretty_vm_reason, *err.args[1:])
if hasattr(frame, "vm_error"):
err = frame.vm_error
err.args = (frame.pretty_vm_reason, *err.args[1:])
else:
err = frame
return f"{err}\n\n{self.stack_trace}"
21 changes: 13 additions & 8 deletions tests/integration/fork/test_abi_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from hypothesis import given

import boa
from boa.util.exceptions import BoaError

ZERO_ADDRESS = "0x0000000000000000000000000000000000000000"


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -45,6 +48,14 @@ def test_tricrypto(tricrypto):
# TODO: test the overloaded functions


def test_no_bytecode(get_filepath):
abi_path = get_filepath("crvusd_abi.json")
crvusd = boa.load_abi(abi_path).at(ZERO_ADDRESS)
with pytest.raises(BoaError) as exc_info:
crvusd.decimals()
assert "no bytecode at this address" in str(exc_info.value)


def test_invariants(crvusd):
assert crvusd.decimals() == 18
assert crvusd.version() == "v1.0.0"
Expand All @@ -57,7 +68,7 @@ def test_metaregistry_overloading(metaregistry):
pool = metaregistry.pool_list(0)
coin1, coin2 = metaregistry.get_coins(pool)[:2]
pools_found = metaregistry.find_pools_for_coins(coin1, coin2)
first_pools = [pool for pool in pools_found if not pool.startswith("0x0000")][:10]
first_pools = [pool for pool in pools_found if not pool.startswith("0x0000")][:2]
assert first_pools[0] == metaregistry.find_pool_for_coins(coin1, coin2)
assert first_pools == [
metaregistry.find_pool_for_coins(coin1, coin2, i)
Expand All @@ -80,13 +91,7 @@ def test_stableswap_factory_ng(stableswap_factory_ng):
3,
[0, 0, 0],
)
assert stableswap_factory_ng.base_pool_data(pool) == (
"0x0000000000000000000000000000000000000000",
[],
0,
0,
[],
)
assert stableswap_factory_ng.base_pool_data(pool) == (ZERO_ADDRESS, [], 0, 0, [])


# randomly grabbed from:
Expand Down
4 changes: 1 addition & 3 deletions tests/unitary/test_abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,8 @@ def test(a: uint128 = 0, b: uint128 = 0) -> uint128:


def test_bad_address():
with pytest.raises(ValueError) as exc_info:
with pytest.warns(UserWarning, match=r"there is no bytecode at that address!$"):
ABIContractFactory.from_abi_dict([]).at(boa.env.eoa)
(error,) = exc_info.value.args
assert "there is no bytecode at that address!" in error


def test_abi_reverts(load_via_abi):
Expand Down

0 comments on commit 2ca7e30

Please sign in to comment.