Skip to content
This repository was archived by the owner on May 23, 2023. It is now read-only.

Negative test for a reentrant attack on the core relayer forward mechanism #83

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions ethereum/contracts/mock/AttackForwardIntegration.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.17;

import "@openzeppelin/contracts/token/ERC20/ERC20.sol";

import "../interfaces/IWormhole.sol";
import "../interfaces/IWormholeReceiver.sol";
import "../interfaces/ICoreRelayer.sol";

/**
* This contract is a malicious "integration" that attempts to attack the forward mechanism.
*/
contract AttackForwardIntegration is IWormholeReceiver {
mapping(bytes32 => bool) consumedMessages;
address attackerReward;
IWormhole wormhole;
ICoreRelayer core_relayer;
uint32 nonce = 1;
uint16 targetChainId;

// Capture 30k gas for fees
// This just needs to be enough to pay for the call to the destination address.
uint32 SAFE_DELIVERY_GAS_CAPTURE = 30000;

constructor(IWormhole initWormhole, ICoreRelayer initCoreRelayer, uint16 chainId, address initAttackerReward) {
attackerReward = initAttackerReward;
wormhole = initWormhole;
core_relayer = initCoreRelayer;
targetChainId = chainId;
}

// This is the function which receives all messages from the remote contracts.
function receiveWormholeMessages(bytes[] memory vaas, bytes[] memory additionalData) public payable override {
// Do nothing. The attacker doesn't care about this message; he sends it himself.
}

receive() external payable {
// Request forward from the relayer network
// The core relayer could in principle accept the request due to this being the target of the message at the same time as being the refund address.
// Note that, if succesful, this forward request would be processed after the time for processing forwards is past.
// Thus, the request would "linger" in the forward request cache and be attended to in the next delivery.
requestForward(targetChainId, toWormholeFormat(attackerReward));
}

function requestForward(uint16 targetChain, bytes32 attackerRewardAddress) internal {
uint256 computeBudget = core_relayer.quoteGasDeliveryFee(
targetChain, SAFE_DELIVERY_GAS_CAPTURE, core_relayer.getDefaultRelayProvider()
);

ICoreRelayer.DeliveryRequest memory request = ICoreRelayer.DeliveryRequest({
targetChain: targetChain,
targetAddress: attackerRewardAddress,
// All remaining funds will be returned to the attacker
refundAddress: attackerRewardAddress,
computeBudget: computeBudget,
applicationBudget: 0,
relayParameters: core_relayer.getDefaultRelayParams()
});

core_relayer.requestForward{value: computeBudget}(request, nonce, core_relayer.getDefaultRelayProvider());
}

function toWormholeFormat(address addr) public pure returns (bytes32 whFormat) {
return bytes32(uint256(uint160(addr)));
}
}
137 changes: 127 additions & 10 deletions ethereum/forge-test/CoreRelayer.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {Wormhole} from "../wormhole/ethereum/contracts/Wormhole.sol";
import {IWormhole} from "../contracts/interfaces/IWormhole.sol";
import {WormholeSimulator} from "./WormholeSimulator.sol";
import {IWormholeReceiver} from "../contracts/interfaces/IWormholeReceiver.sol";
import {AttackForwardIntegration} from "../contracts/mock/AttackForwardIntegration.sol";
import {MockRelayerIntegration} from "../contracts/mock/MockRelayerIntegration.sol";
import "../contracts/libraries/external/BytesLib.sol";

Expand Down Expand Up @@ -495,6 +496,106 @@ contract TestCoreRelayer is Test {
assertTrue(keccak256(setup.source.integration.getMessage()) == keccak256(bytes("received!")));
}

function testAttackForwardRequestCache(GasParameters memory gasParams, FeeParameters memory feeParams) public {
// General idea:
// 1. Attacker sets up a malicious integration contract in the target chain.
// 2. Attacker requests a message send to `target` chain.
// The message destination and the refund address are both the malicious integration contract in the target chain.
// 3. The delivery of the message triggers a refund to the malicious integration contract.
// 4. During the refund, the integration contract activates the forwarding mechanism.
// This is allowed due to the integration contract also being the target of the delivery.
// 5. The forward request is left as is in the `CoreRelayer` state.
// 6. The next message (i.e. the victim's message) delivery on `target` chain, from any relayer, using any `RelayProvider` and any integration contract,
// will see the forward request placed by the malicious integration contract and act on it.
// Caveat: the delivery of the victim's message must not invoke the forwarding mechanism for the attack test to be meaningful.
//
// In essence, this tries to attack the shared forwarding request cache present in the contract state.
// This attack doesn't work thanks to the check inside the `requestForward` function that only allows requesting a forward when there is a delivery being processed.

StandardSetupTwoChains memory setup = standardAssumeAndSetupTwoChains(gasParams, feeParams, 1000000);

// Collected funds from the attack are meant to be sent here.
address attackerSourceAddress =
address(uint160(uint256(keccak256(abi.encodePacked(bytes("attackerAddress"), setup.sourceChainId)))));
assertTrue(attackerSourceAddress.balance == 0);

// Borrowed assumes from testForward. They should help since this test is similar.
vm.assume(
uint256(1) * gasParams.targetGasPrice * feeParams.targetNativePrice
> uint256(1) * gasParams.sourceGasPrice * feeParams.sourceNativePrice
);

vm.assume(
setup.source.coreRelayer.quoteGasDeliveryFee(
setup.targetChainId, gasParams.targetGasLimit, setup.source.relayProvider
) < uint256(2) ** 222
);
vm.assume(
setup.target.coreRelayer.quoteGasDeliveryFee(setup.sourceChainId, 500000, setup.target.relayProvider)
< uint256(2) ** 222 / feeParams.targetNativePrice
);

// Estimate the cost based on the initialized values
uint256 computeBudget = setup.source.coreRelayer.quoteGasDeliveryFee(
setup.targetChainId, gasParams.targetGasLimit, setup.source.relayProvider
);

{
AttackForwardIntegration attackerContract =
new AttackForwardIntegration(setup.target.wormhole, setup.target.coreRelayer, setup.targetChainId, attackerSourceAddress);
bytes memory attackMsg = "attack";

vm.recordLogs();

// The attacker requests the message to be sent to the malicious contract.
// It is critical that the refund and destination (aka integrator) addresses are the same.
setup.source.integration.sendMessage{value: computeBudget + 2 * setup.source.wormhole.messageFee()}(
attackMsg, setup.targetChainId, address(attackerContract), address(attackerContract)
);

// The relayer triggers the call to the malicious contract.
genericRelayer(setup.sourceChainId, 2);

// The message delivery should fail
assertTrue(keccak256(setup.target.integration.getMessage()) != keccak256(attackMsg));
}

{
// Now one victim sends their message. It doesn't need to be from the same source chain.
// What's necessary is that a message is delivered to the chain targeted by the attacker.
bytes memory victimMsg = "relay my message";

uint256 victimBalancePreDelivery = setup.target.refundAddress.balance;

// We will reutilize the compute budget estimated for the attacker to simplify the code here.
// The victim requests their message to be sent.
setup.source.integration.sendMessage{value: computeBudget + 2 * setup.source.wormhole.messageFee()}(
victimMsg, setup.targetChainId, address(setup.target.integration), address(setup.target.refundAddress)
);

// The relayer delivers the victim's message.
// During the delivery process, the forward request injected by the malicious contract is acknowledged.
// The victim's refund address is not called due to this.
genericRelayer(setup.sourceChainId, 2);

// Ensures the message was received.
assertTrue(keccak256(setup.target.integration.getMessage()) == keccak256(victimMsg));
// Here we assert that the victim's refund is safe.
assertTrue(victimBalancePreDelivery < setup.target.refundAddress.balance);
}

Vm.Log[] memory entries = relayerWormholeSimulator.fetchWormholeMessageFromLog(vm.getRecordedLogs());
if (entries.length > 0) {
// There was a wormhole message produced.
// If the attack is successful this is a forward.
// We'll invoke the relay simulation here and later assert that the attack wasn't successful.
// Relay from target chain to source chain.
genericRelayerProcessLogs(setup.targetChainId, entries);
}
// Assert that the attack wasn't successful.
assertTrue(attackerSourceAddress.balance == 0);
}

function testRedelivery(GasParameters memory gasParams, FeeParameters memory feeParams, bytes memory message)
public
{
Expand Down Expand Up @@ -1219,18 +1320,34 @@ contract TestCoreRelayer is Test {
mapping(bytes32 => ICoreRelayer.TargetDeliveryParametersSingle) pastDeliveries;

function genericRelayer(uint16 chainId, uint8 num) internal {
bytes[] memory encodedVMs = new bytes[](num);
{
// Filters all events to just the wormhole messages.
Vm.Log[] memory entries = relayerWormholeSimulator.fetchWormholeMessageFromLog(vm.getRecordedLogs());
assertTrue(entries.length >= num);
for (uint256 i = 0; i < num; i++) {
encodedVMs[i] = relayerWormholeSimulator.fetchSignedMessageFromLogs(
entries[i], chainId, address(uint160(uint256(bytes32(entries[i].topics[1]))))
);
}
Vm.Log[] memory entries = truncateRecordedLogs(chainId, num);
genericRelayerProcessLogs(chainId, entries);
}

/**
* Discards wormhole events beyond `num` events.
* Expects at least `num` wormhole events.
*/
function truncateRecordedLogs(uint16 chainId, uint8 num) internal returns (Vm.Log[] memory) {
// Filters all events to just the wormhole messages.
Vm.Log[] memory entries = relayerWormholeSimulator.fetchWormholeMessageFromLog(vm.getRecordedLogs());
// We expect at least `num` events.
assertTrue(entries.length >= num);

Vm.Log[] memory firstEntries = new Vm.Log[](num);
for (uint256 i = 0; i < num; i++) {
firstEntries[i] = entries[i];
}
return firstEntries;
}

function genericRelayerProcessLogs(uint16 chainId, Vm.Log[] memory entries) internal {
bytes[] memory encodedVMs = new bytes[](entries.length);
for (uint256 i = 0; i < encodedVMs.length; i++) {
encodedVMs[i] = relayerWormholeSimulator.fetchSignedMessageFromLogs(
entries[i], chainId, address(uint160(uint256(bytes32(entries[i].topics[1]))))
);
}
IWormhole.VM[] memory parsed = new IWormhole.VM[](encodedVMs.length);
for (uint16 i = 0; i < encodedVMs.length; i++) {
parsed[i] = relayerWormhole.parseVM(encodedVMs[i]);
Expand Down