Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: namedtuple decoding for vvmcontract structs #356

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
45bb007
set base_path in loads_partial_vvm
charles-cooper Dec 26, 2024
c6389f2
add namedtuple parsing for structs
charles-cooper Dec 26, 2024
84fbf5a
small refactor for get_logs()
charles-cooper Dec 26, 2024
fe7cbdb
add decode_log for abi contracts
charles-cooper Dec 26, 2024
9da2a89
fix error handling in vvm deployer
charles-cooper Dec 26, 2024
655850d
fix lint
charles-cooper Dec 27, 2024
49dfd2c
add a note
charles-cooper Dec 27, 2024
3f241fe
rename a variable
charles-cooper Dec 27, 2024
6d53c63
fix: marshal output for tuple return
charles-cooper Dec 31, 2024
bf6c380
thread name= to VVMDeployer
charles-cooper Jan 2, 2025
d1d10b7
update _loads_partial_vvm to not trample VVMDeployer.name
charles-cooper Jan 2, 2025
c2ed9fa
handle tuples inside lists
charles-cooper Jan 3, 2025
594ed67
lint
charles-cooper Jan 3, 2025
970894f
handle namedtuples with `from` field
charles-cooper Jan 4, 2025
a67d641
fail more gracefully in decode_log when event abi not found
charles-cooper Jan 4, 2025
926c6c3
use namedtuple(rename=True)
charles-cooper Jan 4, 2025
2a3e564
add strict=True param to get_logs
charles-cooper Jan 6, 2025
f91d853
test VVMDeployer does not stomp cache
charles-cooper Jan 9, 2025
60dc4e4
add backwards compatibility
charles-cooper Jan 9, 2025
025c94f
add test for logs
charles-cooper Jan 9, 2025
56842d1
add address field to the log
charles-cooper Jan 9, 2025
a30932c
add tests for namedtuple structs
charles-cooper Jan 9, 2025
04c6113
add out-of-order indexed field
charles-cooper Jan 9, 2025
5273787
add tests for log address in subcall
charles-cooper Jan 9, 2025
417d0f0
fix lint
charles-cooper Jan 9, 2025
ef44c4d
add test for proper BoaError in VVMDeployer.deploy
charles-cooper Jan 9, 2025
06074bf
add tests, forward contract_name properly
charles-cooper Jan 9, 2025
16a8083
add a note
charles-cooper Jan 9, 2025
bf6a75f
fix lint
charles-cooper Jan 9, 2025
fc31b21
update existing tests
charles-cooper Jan 9, 2025
83459ba
update another test
charles-cooper Jan 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 116 additions & 12 deletions boa/contracts/abi/abi_contract.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from collections import defaultdict
from collections import defaultdict, namedtuple
from functools import cached_property
from typing import Any, Optional, Union
from warnings import warn

from eth.abc import ComputationAPI
from vyper.semantics.analysis.base import FunctionVisibility, StateMutability
from vyper.utils import method_id
from vyper.utils import keccak256, method_id

from boa.contracts.base_evm_contract import (
BoaError,
Expand Down Expand Up @@ -47,7 +47,7 @@ def argument_count(self) -> int:

@property
def signature(self) -> str:
return f"({_format_abi_type(self.argument_types)})"
return _format_abi_type(self.argument_types)

@cached_property
def return_type(self) -> list:
Expand Down Expand Up @@ -138,9 +138,9 @@ def __call__(self, *args, value=0, gas=None, sender=None, **kwargs):
case ():
return None
case (single,):
return single
return _parse_complex(self._abi["outputs"][0], single, name=self.name)
case multiple:
return tuple(multiple)
return _parse_complex(self._abi["outputs"], multiple, name=self.name)


class ABIOverload:
Expand Down Expand Up @@ -234,13 +234,15 @@ def __init__(
name: str,
abi: list[dict],
functions: list[ABIFunction],
events: list[dict],
address: Address,
filename: Optional[str] = None,
env=None,
):
super().__init__(name, env, filename=filename, address=address)
self._abi = abi
self._functions = functions
self._events = events

self._bytecode = self.env.get_code(address)
if not self._bytecode:
Expand Down Expand Up @@ -276,6 +278,76 @@ def method_id_map(self):
if not function.is_constructor
}

@cached_property
def event_for(self):
# [{"name": "Bar", "inputs":
# [{"name": "x", "type": "uint256", "indexed": false},
# {"name": "y", "type": "tuple", "components": [{"name": "x", "type": "uint256"}], "indexed": false}],
# "anonymous": false, "type": "event"},
# }]
ret = {}
for event_abi in self._events:
event_signature = ",".join(
_abi_from_json(item) for item in event_abi["inputs"]
)
event_name = event_abi["name"]
event_signature = f"{event_name}({event_signature})"
event_id = int(keccak256(event_signature.encode()).hex(), 16)
ret[event_id] = event_abi
return ret

def decode_log(self, log_entry):
# low level log id
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
_log_id, address, topics, data = log_entry
assert self._address.canonical_address == address
event_hash = topics[0]
event_abi = self.event_for[event_hash]

topic_abis = []
arg_abis = []
tuple_names = []
for item_abi in event_abi["inputs"]:
is_topic = item_abi["indexed"]
assert isinstance(is_topic, bool)
if not is_topic:
arg_abis.append(item_abi)
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
else:
topic_abis.append(item_abi)

tuple_names.append(item_abi["name"])

tuple_typ = namedtuple(event_abi["name"], tuple_names)

decoded_topics = []
for topic_abi, t in zip(topic_abis, topics[1:]):
# convert to bytes for abi decoder
encoded_topic = t.to_bytes(32, "big")
decoded_topics.append(abi_decode(_abi_from_json(topic_abi), encoded_topic))

args_selector = _format_abi_type(
[_abi_from_json(arg_abi) for arg_abi in arg_abis]
)

decoded_args = abi_decode(args_selector, data)

t_i = 0
a_i = 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

topic_index and arg_index would be much easier to read

xs = []
# re-align the evm topic + args lists with the way they appear in the abi
# ex. Transfer(indexed address, address, indexed address)
for item_abi in event_abi["inputs"]:
is_topic = item_abi["indexed"]
if is_topic:
# topic abi is currently never complex, but use _parse_complex as
# future-proofing mechanism
xs.append(_parse_complex(topic_abis[t_i], decoded_topics[t_i]))
t_i += 1
else:
xs.append(_parse_complex(arg_abis[a_i], decoded_args[a_i]))
a_i += 1

return tuple_typ(*xs)

def marshal_to_python(self, computation, abi_type: list[str]) -> tuple[Any, ...]:
"""
Convert the output of a contract call to a Python object.
Expand All @@ -286,7 +358,7 @@ def marshal_to_python(self, computation, abi_type: list[str]) -> tuple[Any, ...]
if computation.is_error:
return self.handle_error(computation)

schema = f"({_format_abi_type(abi_type)})"
schema = _format_abi_type(abi_type)
try:
return abi_decode(schema, computation.output)
except ABIError as e:
Expand Down Expand Up @@ -360,6 +432,10 @@ def functions(self):
if item.get("type") == "function"
]

@property
def events(self):
return [item for item in self.abi if item.get("type") == "event"]

@classmethod
def from_abi_dict(cls, abi, name="<anonymous contract>", filename=None):
return cls(name, abi, filename)
Expand All @@ -370,7 +446,7 @@ def at(self, address: Address | str) -> ABIContract:
"""
address = Address(address)
contract = ABIContract(
self._name, self._abi, self.functions, address, self.filename
self._name, self._abi, self.functions, self.events, address, self.filename
)

contract.env.register_contract(address, contract)
Expand All @@ -390,15 +466,15 @@ def __repr__(self):

@cached_property
def args_abi_type(self):
return f"({_format_abi_type(self.function.argument_types)})"
return _format_abi_type(self.function.argument_types)

@cached_property
def _argument_names(self) -> list[str]:
return [arg["name"] for arg in self.function._abi["inputs"]]

@cached_property
def return_abi_type(self):
return f"({_format_abi_type(self.function.return_type)})"
return _format_abi_type(self.function.return_type)


def _abi_from_json(abi: dict) -> str:
Expand All @@ -407,6 +483,12 @@ def _abi_from_json(abi: dict) -> str:
:param abi: The ABI type to parse.
:return: The schema string for the given abi type.
"""
# {"stateMutability": "view", "type": "function", "name": "foo",
# "inputs": [],
# "outputs": [{"name": "", "type": "tuple",
# "components": [{"name": "x", "type": "uint256"}]}]
# }

if "components" in abi:
components = ",".join([_abi_from_json(item) for item in abi["components"]])
if abi["type"].startswith("tuple"):
Expand All @@ -416,11 +498,33 @@ def _abi_from_json(abi: dict) -> str:
return abi["type"]


def _parse_complex(abi: dict, value: Any, name=None) -> str:
"""
Parses an ABI type into its schema string.
:param abi: The ABI type to parse.
:return: The schema string for the given abi type.
"""
# simple case
if "components" not in abi:
return value

# complex case
# construct a namedtuple type on the fly
components = abi["components"]
typname = name or abi["name"] or "user_struct"
component_names = [item["name"] for item in components]
typ = namedtuple(typname, component_names)
components_parsed = [
_parse_complex(item_abi, item) for (item_abi, item) in zip(components, value)
]
return typ(*components_parsed)


def _format_abi_type(types: list) -> str:
"""
Converts a list of ABI types into a comma-separated string.
"""
return ",".join(
item if isinstance(item, str) else f"({_format_abi_type(item)})"
for item in types
ret = ",".join(
item if isinstance(item, str) else _format_abi_type(item) for item in types
)
return f"({ret})"
38 changes: 37 additions & 1 deletion boa/contracts/base_evm_contract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional

from eth.abc import ComputationAPI

Expand All @@ -13,6 +13,11 @@
from boa.vm.py_evm import titanoboa_computation


@dataclass
class RawEvent:
event_data: Any


class _BaseEVMContract:
"""
Base class for EVM (Ethereum Virtual Machine) contract:
Expand Down Expand Up @@ -57,6 +62,37 @@ def address(self) -> Address:
raise RuntimeError("Contract address is not set")
return self._address

# ## handling events
def _get_logs(self, computation, include_child_logs):
if computation is None:
return []

if include_child_logs:
return list(computation.get_raw_log_entries())

return computation._log_entries

def get_logs(self, computation=None, include_child_logs=True):
if computation is None:
computation = self._computation

entries = self._get_logs(computation, include_child_logs)

# py-evm log format is (log_id, topics, data)
# sort on log_id
entries = sorted(entries)

ret = []
for e in entries:
logger_address = e[1]
c = self.env.lookup_contract(logger_address)
if c is not None:
ret.append(c.decode_log(e))
else:
ret.append(RawEvent(e))

return ret


class StackTrace(list): # list[str|ErrorDetail]
def __str__(self):
Expand Down
18 changes: 15 additions & 3 deletions boa/contracts/vvm/vvm_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,19 @@ def deploy(self, *args, contract_name=None, env=None, **kwargs):
if env is None:
env = Env.get_singleton()

address, _ = env.deploy_code(bytecode=self.bytecode + encoded_args, **kwargs)
address, computation = env.deploy(
bytecode=self.bytecode + encoded_args, **kwargs
)

# TODO: pass thru contract_name
return self.at(address)
# NOTE: if computation.is_error, `self.at()` will raise a warning
# in the future we should refactor so that the warning is silenced
ret = self.at(address)

if computation.is_error:
ret.handle_error(computation)

return ret

@cached_property
def _blueprint_deployer(self):
Expand All @@ -75,10 +84,13 @@ def deploy_as_blueprint(self, env=None, blueprint_preamble=None, **kwargs):
blueprint_bytecode = generate_blueprint_bytecode(
self.bytecode, blueprint_preamble
)
address, _ = env.deploy_code(bytecode=blueprint_bytecode, **kwargs)
address, computation = env.deploy(bytecode=blueprint_bytecode, **kwargs)

ret = self._blueprint_deployer.at(address)

if computation.is_error:
ret.handle_error(computation)

env.register_blueprint(self.bytecode, ret)
return ret

Expand Down
5 changes: 0 additions & 5 deletions boa/contracts/vyper/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,3 @@ def __repr__(self):

args = ", ".join(f"{k}={v}" for k, v in b)
return f"{self.event_type.name}({args})"


@dataclass
class RawEvent:
event_data: Any
33 changes: 1 addition & 32 deletions boa/contracts/vyper/vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
ByteAddressableStorage,
decode_vyper_object,
)
from boa.contracts.vyper.event import Event, RawEvent
from boa.contracts.vyper.event import Event
from boa.contracts.vyper.ir_executor import executor_from_ir
from boa.environment import Env
from boa.profiling import cache_gas_used_for_computation
Expand Down Expand Up @@ -711,37 +711,6 @@ def trace_source(self, computation) -> Optional["VyperTraceSource"]:
return None
return VyperTraceSource(self, node, method_id=computation.msg.data[:4])

# ## handling events
def _get_logs(self, computation, include_child_logs):
if computation is None:
return []

if include_child_logs:
return list(computation.get_raw_log_entries())

return computation._log_entries

def get_logs(self, computation=None, include_child_logs=True):
if computation is None:
computation = self._computation

entries = self._get_logs(computation, include_child_logs)

# py-evm log format is (log_id, topics, data)
# sort on log_id
entries = sorted(entries)

ret = []
for e in entries:
logger_address = e[1]
c = self.env.lookup_contract(logger_address)
if c is not None:
ret.append(c.decode_log(e))
else:
ret.append(RawEvent(e))

return ret

@cached_property
def event_for(self):
module_t = self.compiler_data.global_ctx
Expand Down
11 changes: 9 additions & 2 deletions boa/interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,21 @@ def load_partial(filename: str, compiler_args=None):
)


def _loads_partial_vvm(source_code: str, version: Version, filename: str):
def _loads_partial_vvm(
source_code: str, version: Version, filename: str, base_path=None
):
global _disk_cache

if base_path is None:
base_path = Path(".")

# install the requested version if not already installed
vvm.install_vyper(version=version)

def _compile():
compiled_src = vvm.compile_source(source_code, vyper_version=version)
compiled_src = vvm.compile_source(
source_code, vyper_version=version, base_path=base_path
)
compiler_output = compiled_src["<stdin>"]
return VVMDeployer.from_compiler_output(compiler_output, filename=filename)

Expand Down
Loading