Skip to content

Commit

Permalink
fix[codegen]: fix some hardcoded references to STORAGE location (vy…
Browse files Browse the repository at this point in the history
…perlang#4015)

fix some hardcoded references to `STORAGE` when `TRANSIENT` is also
possible. these are not breaking changes semantically, but should enable
the correct optimization codepath for transient storage. this commit
also results in better codepaths taken for the IMMUTABLES location
(`iload`/`istore`), resulting in smaller bytecode sometimes for
constructors.

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
cyberthirst and charles-cooper authored May 21, 2024
1 parent 0ba1c62 commit a0d9b1f
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 30 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ jobs:
- evm-version: paris
- evm-version: shanghai

# test pre-cancun with opt-codesize and opt-none
- evm-version: shanghai
opt-mode: none
- evm-version: shanghai
opt-mode: codesize

# test py-evm
- evm-backend: py-evm
evm-version: shanghai
Expand Down
1 change: 1 addition & 0 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ class VariableDecl(VyperNode):
is_constant: bool = ...
is_public: bool = ...
is_immutable: bool = ...
is_transient: bool = ...
_expanded_getter: FunctionDef = ...

class AugAssign(VyperNode):
Expand Down
33 changes: 25 additions & 8 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
STORAGE,
TRANSIENT,
AddrSpace,
legal_in_staticcall,
)
from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch
Expand Down Expand Up @@ -136,6 +137,14 @@ def address_space_to_data_location(s: AddrSpace) -> DataLocation:
raise CompilerPanic("unreachable!") # pragma: nocover


def writeable(context, ir_node):
assert ir_node.is_pointer # sanity check

if context.is_constant() and not legal_in_staticcall(ir_node.location):
return False
return ir_node.mutable


# Copy byte array word-for-word (including layout)
# TODO make this a private function
def make_byte_array_copier(dst, src):
Expand All @@ -150,12 +159,9 @@ def make_byte_array_copier(dst, src):
return STORE(dst, 0)

with src.cache_when_complex("src") as (b1, src):
has_storage = STORAGE in (src.location, dst.location)
is_memory_copy = dst.location == src.location == MEMORY
batch_uses_identity = is_memory_copy and not version_check(begin="cancun")
if src.typ.maxlen <= 32 and (has_storage or batch_uses_identity):
if src.typ.maxlen <= 32 and not copy_opcode_available(dst, src):
# if there is no batch copy opcode available,
# it's cheaper to run two load/stores instead of copy_bytes

ret = ["seq"]
# store length word
len_ = get_bytearray_length(src)
Expand Down Expand Up @@ -914,6 +920,15 @@ def make_setter(left, right):
return _complex_make_setter(left, right)


# locations with no dedicated copy opcode
# (i.e. storage and transient storage)
def copy_opcode_available(left, right):
if left.location == MEMORY and right.location == MEMORY:
return version_check(begin="cancun")

return left.location == MEMORY and right.location.has_copy_opcode


def _complex_make_setter(left, right):
if right.value == "~empty" and left.location == MEMORY:
# optimized memzero
Expand All @@ -935,8 +950,10 @@ def _complex_make_setter(left, right):
assert left.encoding == Encoding.VYPER
len_ = left.typ.memory_bytes_required

has_storage = STORAGE in (left.location, right.location)
if has_storage:
# special logic for identity precompile (pre-cancun) in the else branch
mem2mem = left.location == right.location == MEMORY

if not copy_opcode_available(left, right) and not mem2mem:
if _opt_codesize():
# assuming PUSH2, a single sstore(dst (sload src)) is 8 bytes,
# sstore(add (dst ofst), (sload (add (src ofst)))) is 16 bytes,
Expand Down Expand Up @@ -983,7 +1000,7 @@ def _complex_make_setter(left, right):
base_unroll_cost + (nth_word_cost * (n_words - 1)) >= identity_base_cost
)

# calldata to memory, code to memory, cancun, or codesize -
# calldata to memory, code to memory, cancun, or opt-codesize -
# batch copy is always better.
else:
should_batch_copy = True
Expand Down
13 changes: 7 additions & 6 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
get_type_for_exact_size,
make_setter,
wrap_value_for_external_return,
writeable,
)
from vyper.codegen.expr import Expr
from vyper.codegen.return_ import make_return_stmt
from vyper.evm.address_space import MEMORY, STORAGE
from vyper.evm.address_space import MEMORY
from vyper.exceptions import CodegenPanic, StructureException, TypeCheckFailure, tag_exceptions
from vyper.semantics.types import DArrayT
from vyper.semantics.types.shortcuts import UINT256_T
Expand Down Expand Up @@ -312,18 +313,18 @@ def parse_Return(self):
def _get_target(self, target):
_dbg_expr = target

if isinstance(target, vy_ast.Name) and target.id in self.context.forvars:
if isinstance(target, vy_ast.Name) and target.id in self.context.forvars: # pragma: nocover
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")

if isinstance(target, vy_ast.Tuple):
target = Expr(target, self.context).ir_node
for node in target.args:
if (node.location == STORAGE and self.context.is_constant()) or not node.mutable:
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
items = target.args
if any(not writeable(self.context, item) for item in items): # pragma: nocover
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
return target

target = Expr.parse_pointer_expr(target, self.context)
if (target.location == STORAGE and self.context.is_constant()) or not target.mutable:
if not writeable(self.context, target): # pragma: nocover
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
return target

Expand Down
17 changes: 14 additions & 3 deletions vyper/evm/address_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,25 @@ class AddrSpace:
load_op: the opcode for loading a word from this address space
store_op: the opcode for storing a word to this address space
(an address space is read-only if store_op is None)
copy_op: the opcode for batch-copying from this address space
to memory
"""

name: str
word_scale: int
load_op: str
# TODO maybe make positional instead of defaulting to None
store_op: Optional[str] = None
copy_op: Optional[str] = None

@property
def word_addressable(self) -> bool:
return self.word_scale == 1

@property
def has_copy_opcode(self):
return self.copy_op is not None


# alternative:
# class Memory(AddrSpace):
Expand All @@ -42,13 +49,17 @@ def word_addressable(self) -> bool:
#
# MEMORY = Memory()

MEMORY = AddrSpace("memory", 32, "mload", "mstore")
MEMORY = AddrSpace("memory", 32, "mload", "mstore", "mcopy")
STORAGE = AddrSpace("storage", 1, "sload", "sstore")
TRANSIENT = AddrSpace("transient", 1, "tload", "tstore")
CALLDATA = AddrSpace("calldata", 32, "calldataload")
CALLDATA = AddrSpace("calldata", 32, "calldataload", None, "calldatacopy")
# immutables address space: "immutables" section of memory
# which is read-write in deploy code but then gets turned into
# the "data" section of the runtime code
IMMUTABLES = AddrSpace("immutables", 32, "iload", "istore")
# data addrspace: "data" section of runtime code, read-only.
DATA = AddrSpace("data", 32, "dload")
DATA = AddrSpace("data", 32, "dload", None, "dloadbytes")


def legal_in_staticcall(location: AddrSpace):
return location not in (STORAGE, TRANSIENT)
20 changes: 10 additions & 10 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,13 +621,6 @@ def visit_VariableDecl(self, node):
assert isinstance(node.target, vy_ast.Name)
name = node.target.id

if node.is_public:
# generate function type and add to metadata
# we need this when building the public getter
func_t = ContractFunctionT.getter_from_VariableDecl(node)
node._metadata["getter_type"] = func_t
self._add_exposed_function(func_t, node)

# TODO: move this check to local analysis
if node.is_immutable:
# mutability is checked automatically preventing assignment
Expand All @@ -648,7 +641,7 @@ def visit_VariableDecl(self, node):
)
raise ImmutableViolation(message, node)

data_loc = (
location = (
DataLocation.CODE
if node.is_immutable
else DataLocation.UNSET
Expand All @@ -666,21 +659,28 @@ def visit_VariableDecl(self, node):
else Modifiability.MODIFIABLE
)

type_ = type_from_annotation(node.annotation, data_loc)
type_ = type_from_annotation(node.annotation, location)

if node.is_transient and not version_check(begin="cancun"):
raise EvmVersionException("`transient` is not available pre-cancun", node.annotation)

var_info = VarInfo(
type_,
decl_node=node,
location=data_loc,
location=location,
modifiability=modifiability,
is_public=node.is_public,
)
node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace
node._metadata["type"] = type_

if node.is_public:
# generate function type and add to metadata
# we need this when building the public getter
func_t = ContractFunctionT.getter_from_VariableDecl(node)
node._metadata["getter_type"] = func_t
self._add_exposed_function(func_t, node)

def _finalize():
# add the variable name to `self` namespace if the variable is either
# 1. a public constant or immutable; or
Expand Down
5 changes: 4 additions & 1 deletion vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,10 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio
"""
if not node.is_public:
raise CompilerPanic("getter generated for non-public function")
type_ = type_from_annotation(node.annotation, DataLocation.STORAGE)

# calculated by caller (ModuleAnalyzer.visit_VariableDecl)
type_ = node.target._metadata["varinfo"].typ

arguments, return_type = type_.getter_signature
args = []
for i, item in enumerate(arguments):
Expand Down
5 changes: 3 additions & 2 deletions vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class HashMapT(_SubscriptableT):

_equality_attrs = ("key_type", "value_type")

# disallow everything but storage
# disallow everything but storage or transient
_invalid_locations = (
DataLocation.UNSET,
DataLocation.CALLDATA,
Expand Down Expand Up @@ -84,10 +84,11 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "HashMapT":
)

k_ast, v_ast = node.slice.elements
key_type = type_from_annotation(k_ast, DataLocation.STORAGE)
key_type = type_from_annotation(k_ast)
if not key_type._as_hashmap_key:
raise InvalidType("can only use primitive types as HashMap key!", k_ast)

# TODO: thread through actual location - might also be TRANSIENT
value_type = type_from_annotation(v_ast, DataLocation.STORAGE)

return cls(key_type, value_type)
Expand Down

0 comments on commit a0d9b1f

Please sign in to comment.