diff --git a/docs/control-structures.rst b/docs/control-structures.rst index a0aa927261..4e18a21bd8 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -100,22 +100,24 @@ Functions marked with ``@pure`` cannot call non-``pure`` functions. Re-entrancy Locks ----------------- -The ``@nonreentrant()`` decorator places a lock on a function, and all functions with the same ```` value. An attempt by an external contract to call back into any of these functions causes the transaction to revert. +The ``@nonreentrant`` decorator places a global nonreentrancy lock on a function. An attempt by an external contract to call back into any other ``@nonreentrant`` function causes the transaction to revert. .. code-block:: vyper @external - @nonreentrant("lock") + @nonreentrant def make_a_call(_addr: address): # this function is protected from re-entrancy ... -You can put the ``@nonreentrant()`` decorator on a ``__default__`` function but we recommend against it because in most circumstances it will not work in a meaningful way. +You can put the ``@nonreentrant`` decorator on a ``__default__`` function but we recommend against it because in most circumstances it will not work in a meaningful way. Nonreentrancy locks work by setting a specially allocated storage slot to a ```` value on function entrance, and setting it to an ```` value on function exit. On function entrance, if the storage slot is detected to be the ```` value, execution reverts. You cannot put the ``@nonreentrant`` decorator on a ``pure`` function. You can put it on a ``view`` function, but it only checks that the function is not in a callback (the storage slot is not in the ```` state), as ``view`` functions can only read the state, not change it. +You can view where the nonreentrant key is physically laid out in storage by using ``vyper`` with the ``-f layout`` option (e.g., ``vyper -f layout foo.vy``). Unless it is overriden, the compiler will allocate it at slot ``0``. + .. note:: A mutable function can protect a ``view`` function from being called back into (which is useful for instance, if a ``view`` function would return inconsistent state during a mutable function), but a ``view`` function cannot protect itself from being called back into. Note that mutable functions can never be called from a ``view`` function because all external calls out from a ``view`` function are protected by the use of the ``STATICCALL`` opcode. @@ -123,6 +125,8 @@ You cannot put the ``@nonreentrant`` decorator on a ``pure`` function. You can p A nonreentrant lock has an ```` value of 3, and a ```` value of 2. Nonzero values are used to take advantage of net gas metering - as of the Berlin hard fork, the net cost for utilizing a nonreentrant lock is 2300 gas. Prior to v0.3.4, the ```` and ```` values were 0 and 1, respectively. +.. note:: + Prior to 0.4.0, nonreentrancy keys took a "key" argument for fine-grained nonreentrancy control. As of 0.4.0, only a global nonreentrancy lock is available. The ``__default__`` Function ---------------------------- @@ -194,7 +198,7 @@ Decorator Description ``@pure`` Function does not read contract state or environment variables ``@view`` Function does not alter contract state ``@payable`` Function is able to receive Ether -``@nonreentrant()`` Function cannot be called back into during an external call +``@nonreentrant`` Function cannot be called back into during an external call =============================== =========================================================== ``if`` statements diff --git a/tests/functional/codegen/features/decorators/test_nonreentrant.py b/tests/functional/codegen/features/decorators/test_nonreentrant.py index 9329605678..92a21cd302 100644 --- a/tests/functional/codegen/features/decorators/test_nonreentrant.py +++ b/tests/functional/codegen/features/decorators/test_nonreentrant.py @@ -2,30 +2,103 @@ from vyper.exceptions import FunctionDeclarationException - # TODO test functions in this module across all evm versions # once we have cancun support. + + def test_nonreentrant_decorator(get_contract, tx_failed): - calling_contract_code = """ -interface SpecialContract: + malicious_code = """ +interface ProtectedContract: + def protected_function(callback_address: address): nonpayable + +@external +def do_callback(): + ProtectedContract(msg.sender).protected_function(self) + """ + + protected_code = """ +interface Callbackable: + def do_callback(): nonpayable + +@external +@nonreentrant +def protected_function(c: Callbackable): + c.do_callback() + +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) + + with tx_failed(): + contract.protected_function(malicious.address) + + +def test_nonreentrant_view_function(get_contract, tx_failed): + malicious_code = """ +interface ProtectedContract: + def protected_function(): nonpayable + def protected_view_fn() -> uint256: view + +@external +def do_callback() -> uint256: + return ProtectedContract(msg.sender).protected_view_fn() + """ + + protected_code = """ +interface Callbackable: + def do_callback(): nonpayable + +@external +@nonreentrant +def protected_function(c: Callbackable): + c.do_callback() + +@external +@nonreentrant +@view +def protected_view_fn() -> uint256: + return 10 + +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) + + with tx_failed(): + contract.protected_function(malicious.address) + + +def test_multi_function_nonreentrant(get_contract, tx_failed): + malicious_code = """ +interface ProtectedContract: def unprotected_function(val: String[100], do_callback: bool): nonpayable def protected_function(val: String[100], do_callback: bool): nonpayable def special_value() -> String[100]: nonpayable @external def updated(): - SpecialContract(msg.sender).unprotected_function('surprise!', False) + ProtectedContract(msg.sender).unprotected_function('surprise!', False) @external def updated_protected(): # This should fail. - SpecialContract(msg.sender).protected_function('surprise protected!', False) + ProtectedContract(msg.sender).protected_function('surprise protected!', False) """ - reentrant_code = """ + protected_code = """ interface Callback: def updated(): nonpayable def updated_protected(): nonpayable + interface Self: def protected_function(val: String[100], do_callback: bool) -> uint256: nonpayable def protected_function2(val: String[100], do_callback: bool) -> uint256: nonpayable @@ -39,7 +112,7 @@ def set_callback(c: address): self.callback = Callback(c) @external -@nonreentrant('protect_special_value') +@nonreentrant def protected_function(val: String[100], do_callback: bool) -> uint256: self.special_value = val @@ -50,7 +123,7 @@ def protected_function(val: String[100], do_callback: bool) -> uint256: return 2 @external -@nonreentrant('protect_special_value') +@nonreentrant def protected_function2(val: String[100], do_callback: bool) -> uint256: self.special_value = val if do_callback: @@ -60,7 +133,7 @@ def protected_function2(val: String[100], do_callback: bool) -> uint256: return 2 @external -@nonreentrant('protect_special_value') +@nonreentrant def protected_function3(val: String[100], do_callback: bool) -> uint256: self.special_value = val if do_callback: @@ -71,7 +144,8 @@ def protected_function3(val: String[100], do_callback: bool) -> uint256: @external -@nonreentrant('protect_special_value') +@nonreentrant +@view def protected_view_fn() -> String[100]: return self.special_value @@ -81,37 +155,42 @@ def unprotected_function(val: String[100], do_callback: bool): if do_callback: self.callback.updated() - """ - reentrant_contract = get_contract(reentrant_code) - calling_contract = get_contract(calling_contract_code) +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) - reentrant_contract.set_callback(calling_contract.address, transact={}) - assert reentrant_contract.callback() == calling_contract.address + contract.set_callback(malicious.address, transact={}) + assert contract.callback() == malicious.address # Test unprotected function. - reentrant_contract.unprotected_function("some value", True, transact={}) - assert reentrant_contract.special_value() == "surprise!" + contract.unprotected_function("some value", True, transact={}) + assert contract.special_value() == "surprise!" # Test protected function. - reentrant_contract.protected_function("some value", False, transact={}) - assert reentrant_contract.special_value() == "some value" - assert reentrant_contract.protected_view_fn() == "some value" + contract.protected_function("some value", False, transact={}) + assert contract.special_value() == "some value" + assert contract.protected_view_fn() == "some value" with tx_failed(): - reentrant_contract.protected_function("zzz value", True, transact={}) + contract.protected_function("zzz value", True, transact={}) - reentrant_contract.protected_function2("another value", False, transact={}) - assert reentrant_contract.special_value() == "another value" + contract.protected_function2("another value", False, transact={}) + assert contract.special_value() == "another value" with tx_failed(): - reentrant_contract.protected_function2("zzz value", True, transact={}) + contract.protected_function2("zzz value", True, transact={}) - reentrant_contract.protected_function3("another value", False, transact={}) - assert reentrant_contract.special_value() == "another value" + contract.protected_function3("another value", False, transact={}) + assert contract.special_value() == "another value" with tx_failed(): - reentrant_contract.protected_function3("zzz value", True, transact={}) + contract.protected_function3("zzz value", True, transact={}) def test_nonreentrant_decorator_for_default(w3, get_contract, tx_failed): @@ -145,7 +224,7 @@ def set_callback(c: address): @external @payable -@nonreentrant("lock") +@nonreentrant def protected_function(val: String[100], do_callback: bool) -> uint256: self.special_value = val _amount: uint256 = msg.value @@ -169,7 +248,7 @@ def unprotected_function(val: String[100], do_callback: bool): @external @payable -@nonreentrant("lock") +@nonreentrant def __default__(): pass """ @@ -209,7 +288,7 @@ def test_disallow_on_init_function(get_contract): code = """ @external -@nonreentrant("lock") +@nonreentrant def __init__(): foo: uint256 = 0 """ diff --git a/tests/functional/syntax/exceptions/test_structure_exception.py b/tests/functional/syntax/exceptions/test_structure_exception.py index afc7a35012..e530487fea 100644 --- a/tests/functional/syntax/exceptions/test_structure_exception.py +++ b/tests/functional/syntax/exceptions/test_structure_exception.py @@ -44,42 +44,11 @@ def foo() -> int128: return x.codesize() """, """ -@external -@nonreentrant("B") -@nonreentrant("C") -def double_nonreentrant(): - pass - """, - """ struct X: int128[5]: int128[7] """, """ @external -@nonreentrant(" ") -def invalid_nonreentrant_key(): - pass - """, - """ -@external -@nonreentrant("") -def invalid_nonreentrant_key(): - pass - """, - """ -@external -@nonreentrant("123") -def invalid_nonreentrant_key(): - pass - """, - """ -@external -@nonreentrant("!123abcd") -def invalid_nonreentrant_key(): - pass - """, - """ -@external def foo(): true: int128 = 3 """, diff --git a/tests/functional/syntax/signatures/test_invalid_function_decorators.py b/tests/functional/syntax/signatures/test_invalid_function_decorators.py index b3d4219a2d..a7a500efc7 100644 --- a/tests/functional/syntax/signatures/test_invalid_function_decorators.py +++ b/tests/functional/syntax/signatures/test_invalid_function_decorators.py @@ -7,10 +7,23 @@ """ @external @pure -@nonreentrant('lock') +@nonreentrant def nonreentrant_foo() -> uint256: return 1 + """, """ +@external +@nonreentrant +@nonreentrant +def nonreentrant_foo() -> uint256: + return 1 + """, + """ +@external +@nonreentrant("foo") +def nonreentrant_foo() -> uint256: + return 1 + """, ] diff --git a/tests/unit/cli/storage_layout/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py index f0ee25f747..9724dd723c 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout.py +++ b/tests/unit/cli/storage_layout/test_storage_layout.py @@ -6,18 +6,18 @@ def test_storage_layout(): foo: HashMap[address, uint256] @external -@nonreentrant("foo") +@nonreentrant def public_foo1(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo2(): pass @internal -@nonreentrant("bar") +@nonreentrant def _bar(): pass @@ -28,12 +28,12 @@ def _bar(): bar: uint256 @external -@nonreentrant("bar") +@nonreentrant def public_bar(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo3(): pass """ @@ -41,12 +41,11 @@ def public_foo3(): out = compile_code(code, output_formats=["layout"]) assert out["layout"]["storage_layout"] == { - "nonreentrant.foo": {"type": "nonreentrant lock", "slot": 0}, - "nonreentrant.bar": {"type": "nonreentrant lock", "slot": 1}, - "foo": {"type": "HashMap[address, uint256]", "slot": 2}, - "arr": {"type": "DynArray[uint256, 3]", "slot": 3}, - "baz": {"type": "Bytes[65]", "slot": 7}, - "bar": {"type": "uint256", "slot": 11}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "foo": {"slot": 1, "type": "HashMap[address, uint256]"}, + "arr": {"slot": 2, "type": "DynArray[uint256, 3]"}, + "baz": {"slot": 6, "type": "Bytes[65]"}, + "bar": {"slot": 10, "type": "uint256"}, } @@ -64,10 +63,13 @@ def __init__(): expected_layout = { "code_layout": { - "DECIMALS": {"length": 32, "offset": 64, "type": "uint8"}, "SYMBOL": {"length": 64, "offset": 0, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 64, "type": "uint8"}, + }, + "storage_layout": { + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "name": {"slot": 1, "type": "String[32]"}, }, - "storage_layout": {"name": {"slot": 0, "type": "String[32]"}}, } out = compile_code(code, output_formats=["layout"]) @@ -107,14 +109,15 @@ def __init__(): "code_layout": { "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, "a_library": { - "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, - "counter2": {"slot": 1, "type": "uint256"}, - "a_library": {"supply": {"slot": 2, "type": "uint256"}}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, + "counter2": {"slot": 2, "type": "uint256"}, + "a_library": {"supply": {"slot": 3, "type": "uint256"}}, }, } @@ -160,9 +163,10 @@ def __init__(): }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, - "a_library": {"supply": {"slot": 1, "type": "uint256"}}, - "counter2": {"slot": 2, "type": "uint256"}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, + "a_library": {"supply": {"slot": 2, "type": "uint256"}}, + "counter2": {"slot": 3, "type": "uint256"}, }, } @@ -171,7 +175,8 @@ def __init__(): def test_storage_layout_module_uses(make_input_bundle): - # test module storage layout, with initializes/uses + # test module storage layout, with initializes/uses and a nonreentrant + # lock lib1 = """ supply: uint256 SYMBOL: immutable(String[32]) @@ -197,6 +202,11 @@ def __init__(s: uint256): @internal def decimals() -> uint8: return lib1.DECIMALS + +@external +@nonreentrant +def foo(): + pass """ code = """ import lib1 as a_library @@ -218,6 +228,11 @@ def __init__(): some_immutable = [1, 2, 3] lib2.__init__(17) + +@external +@nonreentrant +def bar(): + pass """ input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) @@ -231,10 +246,11 @@ def __init__(): }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, - "lib2": {"storage_variable": {"slot": 1, "type": "uint256"}}, - "counter2": {"slot": 2, "type": "uint256"}, - "a_library": {"supply": {"slot": 3, "type": "uint256"}}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, + "lib2": {"storage_variable": {"slot": 2, "type": "uint256"}}, + "counter2": {"slot": 3, "type": "uint256"}, + "a_library": {"supply": {"slot": 4, "type": "uint256"}}, }, } @@ -309,12 +325,13 @@ def foo() -> uint256: }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, "lib2": { - "lib1": {"supply": {"slot": 1, "type": "uint256"}}, - "storage_variable": {"slot": 2, "type": "uint256"}, + "lib1": {"supply": {"slot": 2, "type": "uint256"}}, + "storage_variable": {"slot": 3, "type": "uint256"}, }, - "counter2": {"slot": 3, "type": "uint256"}, + "counter2": {"slot": 4, "type": "uint256"}, }, } diff --git a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py index f4c11b7ae6..707c94c3fc 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py @@ -1,3 +1,5 @@ +import re + import pytest from vyper.compiler import compile_code @@ -28,18 +30,18 @@ def test_storage_layout_for_more_complex(): foo: HashMap[address, uint256] @external -@nonreentrant("foo") +@nonreentrant def public_foo1(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo2(): pass @internal -@nonreentrant("bar") +@nonreentrant def _bar(): pass @@ -48,19 +50,18 @@ def _bar(): bar: uint256 @external -@nonreentrant("bar") +@nonreentrant def public_bar(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo3(): pass """ storage_layout_override = { - "nonreentrant.foo": {"type": "nonreentrant lock", "slot": 8}, - "nonreentrant.bar": {"type": "nonreentrant lock", "slot": 7}, + "$.nonreentrant_key": {"type": "nonreentrant lock", "slot": 8}, "foo": {"type": "HashMap[address, uint256]", "slot": 1}, "baz": {"type": "Bytes[65]", "slot": 2}, "bar": {"type": "uint256", "slot": 6}, @@ -110,6 +111,25 @@ def test_overflow(): ) +def test_override_nonreentrant_slot(): + code = """ +@nonreentrant +@external +def foo(): + pass + """ + + storage_layout_override = {"$.nonreentrant_key": {"slot": 2**256, "type": "nonreentrant key"}} + + exception_regex = re.escape( + f"Invalid storage slot for var $.nonreentrant_key, out of bounds: {2**256}" + ) + with pytest.raises(StorageLayoutException, match=exception_regex): + compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) + + def test_incomplete_overrides(): code = """ name: public(String[64]) diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index 3620ef64b9..1dc70fd1ba 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -47,15 +47,9 @@ def __init__(): self.foo[1] = [123, 456, 789] @external -@nonreentrant('lock') +@nonreentrant def with_lock(): pass - - -@external -@nonreentrant('otherlock') -def with_other_lock(): - pass """ @@ -84,7 +78,6 @@ def test_reentrancy_lock(get_contract): # if re-entrancy locks are incorrectly placed within storage, these # calls will either revert or correupt the data that we read later c.with_lock() - c.with_other_lock() assert c.a() == ("ok", [4, 5, 6]) assert [c.b(i) for i in range(2)] == [7, 8] @@ -105,7 +98,7 @@ def test_reentrancy_lock(get_contract): def test_allocator_overflow(get_contract): code = """ -x: uint256 +# --> global nonreentrancy slot allocated here <-- y: uint256[max_value(uint256)] """ with pytest.raises( diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 604bc6b594..bb4322c7b2 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -43,10 +43,15 @@ def __setitem__(self, k, v): super().__setitem__(k, v) +# some name that the user cannot assign to a variable +GLOBAL_NONREENTRANT_KEY = "$.nonreentrant_key" + + class SimpleAllocator: def __init__(self, max_slot: int = 2**256, starting_slot: int = 0): # Allocate storage slots from 0 # note storage is word-addressable, not byte-addressable + self._starting_slot = starting_slot self._slot = starting_slot self._max_slot = max_slot @@ -61,12 +66,19 @@ def allocate_slot(self, n, var_name, node=None): self._slot += n return ret + def allocate_global_nonreentrancy_slot(self): + slot = self.allocate_slot(1, GLOBAL_NONREENTRANT_KEY) + assert slot == self._starting_slot + return slot + class Allocators: storage_allocator: SimpleAllocator transient_storage_allocator: SimpleAllocator immutables_allocator: SimpleAllocator + _global_nonreentrancy_key_slot: int + def __init__(self): self.storage_allocator = SimpleAllocator(max_slot=2**256) self.transient_storage_allocator = SimpleAllocator(max_slot=2**256) @@ -82,6 +94,16 @@ def get_allocator(self, location: DataLocation): raise CompilerPanic("unreachable") # pragma: nocover + def allocate_global_nonreentrancy_slot(self): + location = get_reentrancy_key_location() + + allocator = self.get_allocator(location) + slot = allocator.allocate_global_nonreentrancy_slot() + self._global_nonreentrancy_key_slot = slot + + def get_global_nonreentrant_key_slot(self): + return self._global_nonreentrancy_key_slot + class OverridingStorageAllocator: """ @@ -127,7 +149,6 @@ def set_storage_slots_with_overrides( Returns the layout as a dict of variable name -> variable info (Doesn't handle modules, or transient storage) """ - ret: InsertableOnceDict[str, dict] = InsertableOnceDict() reserved_slots = OverridingStorageAllocator() @@ -136,15 +157,13 @@ def set_storage_slots_with_overrides( type_ = node._metadata["func_type"] # Ignore functions without non-reentrant - if type_.nonreentrant is None: + if not type_.nonreentrant: continue - variable_name = f"nonreentrant.{type_.nonreentrant}" + variable_name = GLOBAL_NONREENTRANT_KEY # re-entrant key was already identified if variable_name in ret: - _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(VarOffset(_slot)) continue # Expect to find this variable within the storage layout override @@ -210,6 +229,20 @@ def get_reentrancy_key_location() -> DataLocation: } +def _allocate_nonreentrant_keys(vyper_module, allocators): + SLOT = allocators.get_global_nonreentrant_key_slot() + + for node in vyper_module.get_children(vy_ast.FunctionDef): + type_ = node._metadata["func_type"] + if not type_.nonreentrant: + continue + + # a nonreentrant key can appear many times in a module but it + # only takes one slot. after the first time we see it, do not + # increment the storage slot. + type_.set_reentrancy_key_position(VarOffset(SLOT)) + + def _allocate_layout_r( vyper_module: vy_ast.Module, allocators: Allocators = None, immutables_only=False ) -> StorageLayout: @@ -217,42 +250,26 @@ def _allocate_layout_r( Parse module-level Vyper AST to calculate the layout of storage variables. Returns the layout as a dict of variable name -> variable info """ + global_ = False if allocators is None: + global_ = True allocators = Allocators() + # always allocate nonreentrancy slot, so that adding or removing + # reentrancy protection from a contract does not change its layout + allocators.allocate_global_nonreentrancy_slot() ret: defaultdict[str, InsertableOnceDict[str, dict]] = defaultdict(InsertableOnceDict) - for node in vyper_module.get_children(vy_ast.FunctionDef): - if immutables_only: - break - - type_ = node._metadata["func_type"] - if type_.nonreentrant is None: - continue - - variable_name = f"nonreentrant.{type_.nonreentrant}" - reentrancy_key_location = get_reentrancy_key_location() - layout_key = _LAYOUT_KEYS[reentrancy_key_location] - - # a nonreentrant key can appear many times in a module but it - # only takes one slot. after the first time we see it, do not - # increment the storage slot. - if variable_name in ret[layout_key]: - _slot = ret[layout_key][variable_name]["slot"] - type_.set_reentrancy_key_position(VarOffset(_slot)) - continue - - # TODO use one byte - or bit - per reentrancy key - # requires either an extra SLOAD or caching the value of the - # location in memory at entrance - allocator = allocators.get_allocator(reentrancy_key_location) - slot = allocator.allocate_slot(1, variable_name, node) - - type_.set_reentrancy_key_position(VarOffset(slot)) + # tag functions with the global nonreentrant key + if not immutables_only: + _allocate_nonreentrant_keys(vyper_module, allocators) + layout_key = _LAYOUT_KEYS[get_reentrancy_key_location()] # TODO this could have better typing but leave it untyped until # we nail down the format better - ret[layout_key][variable_name] = {"type": "nonreentrant lock", "slot": slot} + if global_ and GLOBAL_NONREENTRANT_KEY not in ret[layout_key]: + slot = allocators.get_global_nonreentrant_key_slot() + ret[layout_key][GLOBAL_NONREENTRANT_KEY] = {"type": "nonreentrant lock", "slot": slot} for node in _get_allocatable(vyper_module): if isinstance(node, vy_ast.InitializesDecl): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 705470a798..43d553288e 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List, Optional, Tuple from vyper import ast as vy_ast -from vyper.ast.identifiers import validate_identifier from vyper.ast.validation import validate_call_args from vyper.exceptions import ( ArgumentException, @@ -78,8 +77,8 @@ class ContractFunctionT(VyperType): enum indicating the external visibility of a function. state_mutability : StateMutability enum indicating the authority a function has to mutate it's own state. - nonreentrant : Optional[str] - Re-entrancy lock name. + nonreentrant : bool + Whether this function is marked `@nonreentrant` or not """ _is_callable = True @@ -93,7 +92,7 @@ def __init__( function_visibility: FunctionVisibility, state_mutability: StateMutability, from_interface: bool = False, - nonreentrant: Optional[str] = None, + nonreentrant: bool = False, ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: super().__init__() @@ -107,6 +106,9 @@ def __init__( self.nonreentrant = nonreentrant self.from_interface = from_interface + # sanity check, nonreentrant used to be Optional[str] + assert isinstance(self.nonreentrant, bool) + self.ast_def = ast_def self._analysed = False @@ -279,7 +281,7 @@ def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, from_interface=True, - nonreentrant=None, + nonreentrant=False, ast_def=funcdef, ) @@ -298,12 +300,10 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ------- ContractFunctionT """ - function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef) - if nonreentrant_key is not None: - raise FunctionDeclarationException( - "nonreentrant key not allowed in interfaces", funcdef - ) + if nonreentrant: + raise FunctionDeclarationException("`@nonreentrant` not allowed in interfaces", funcdef) if funcdef.name == "__init__": raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef) @@ -332,7 +332,7 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, from_interface=True, - nonreentrant=nonreentrant_key, + nonreentrant=nonreentrant, ast_def=funcdef, ) @@ -350,7 +350,7 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ------- ContractFunctionT """ - function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef) positional_args, keyword_args = _parse_args(funcdef) @@ -403,15 +403,16 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, from_interface=False, - nonreentrant=nonreentrant_key, + nonreentrant=nonreentrant, ast_def=funcdef, ) def set_reentrancy_key_position(self, position: VarOffset) -> None: if hasattr(self, "reentrancy_key_position"): raise CompilerPanic("Position was already assigned") - if self.nonreentrant is None: - raise CompilerPanic(f"No reentrant key {self}") + if not self.nonreentrant: + raise CompilerPanic(f"Not nonreentrant {self}", self.ast_def) + self.reentrancy_key_position = position @classmethod @@ -660,32 +661,30 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: def _parse_decorators( funcdef: vy_ast.FunctionDef, -) -> tuple[Optional[FunctionVisibility], StateMutability, Optional[str]]: +) -> tuple[Optional[FunctionVisibility], StateMutability, bool]: function_visibility = None state_mutability = None - nonreentrant_key = None + nonreentrant_node = None for decorator in funcdef.decorator_list: if isinstance(decorator, vy_ast.Call): - if nonreentrant_key is not None: - raise StructureException( - "nonreentrant decorator is already set with key: " f"{nonreentrant_key}", - funcdef, - ) - - if decorator.get("func.id") != "nonreentrant": - raise StructureException("Decorator is not callable", decorator) - if len(decorator.args) != 1 or not isinstance(decorator.args[0], vy_ast.Str): - raise StructureException( - "@nonreentrant name must be given as a single string literal", decorator - ) + msg = "Decorator is not callable" + hint = None + if decorator.get("func.id") == "nonreentrant": + hint = "use `@nonreentrant` with no arguments. the " + hint += "`@nonreentrant` decorator does not accept any " + hint += "arguments since vyper 0.4.0." + raise StructureException(msg, decorator, hint=hint) + + if decorator.get("id") == "nonreentrant": + if nonreentrant_node is not None: + raise StructureException("nonreentrant decorator is already set", nonreentrant_node) if funcdef.name == "__init__": - msg = "Nonreentrant decorator disallowed on `__init__`" + msg = "`@nonreentrant` decorator disallowed on `__init__`" raise FunctionDeclarationException(msg, decorator) - nonreentrant_key = decorator.args[0].value - validate_identifier(nonreentrant_key, decorator.args[0]) + nonreentrant_node = decorator elif isinstance(decorator, vy_ast.Name): if FunctionVisibility.is_valid_value(decorator.id): @@ -726,12 +725,13 @@ def _parse_decorators( # default to nonpayable state_mutability = StateMutability.NONPAYABLE - if state_mutability == StateMutability.PURE and nonreentrant_key is not None: - raise StructureException("Cannot use reentrancy guard on pure functions", funcdef) + if state_mutability == StateMutability.PURE and nonreentrant_node is not None: + raise StructureException("Cannot use reentrancy guard on pure functions", nonreentrant_node) # assert function_visibility is not None # mypy # assert state_mutability is not None # mypy - return function_visibility, state_mutability, nonreentrant_key + nonreentrant = nonreentrant_node is not None + return function_visibility, state_mutability, nonreentrant def _parse_args(