diff --git a/packages/price-oracle/test/PriceOracle.test.ts b/packages/price-oracle/test/PriceOracle.test.ts index 2f1e9daf..52105270 100644 --- a/packages/price-oracle/test/PriceOracle.test.ts +++ b/packages/price-oracle/test/PriceOracle.test.ts @@ -1551,10 +1551,8 @@ describe('PriceOracle', () => { }) describe('getPrice (off-chain)', () => { - let base: Contract, - quote: Contract, - feed: Contract, - data = '0x' + let data = '0x' + let base: Contract, quote: Contract, feed: Contract const OFF_CHAIN_ORACLE_PRICE = fp(5) const SMART_VAULT_ORACLE_PRICE = fp(10) diff --git a/packages/tasks/contracts/base/BaseTask.sol b/packages/tasks/contracts/base/BaseTask.sol index 6730a41b..566a8f62 100644 --- a/packages/tasks/contracts/base/BaseTask.sol +++ b/packages/tasks/contracts/base/BaseTask.sol @@ -147,8 +147,12 @@ abstract contract BaseTask is IBaseTask, Authorized { */ function _getPrice(address base, address quote) internal view virtual returns (uint256) { address priceOracle = ISmartVault(smartVault).priceOracle(); - if (priceOracle == address(0)) revert TaskSmartVaultPriceOracleNotSet(base, quote); - return IPriceOracle(priceOracle).getPrice(_wrappedIfNative(base), _wrappedIfNative(quote)); + if (priceOracle == address(0)) revert TaskSmartVaultPriceOracleNotSet(smartVault); + bytes memory extraCallData = _decodeExtraCallData(); + return + extraCallData.length == 0 + ? IPriceOracle(priceOracle).getPrice(_wrappedIfNative(base), _wrappedIfNative(quote)) + : IPriceOracle(priceOracle).getPrice(_wrappedIfNative(base), _wrappedIfNative(quote), extraCallData); } /** @@ -172,4 +176,29 @@ abstract contract BaseTask is IBaseTask, Authorized { function _wrappedNativeToken() internal view returns (address) { return ISmartVault(smartVault).wrappedNativeToken(); } + + /** + * @dev Decodes any potential extra calldata stored in the calldata space. Tasks relying on the extra calldata + * pattern, assume that the last word of the calldata stores the extra calldata length so it can be decoded. Note + * that tasks relying on this pattern must contemplate this function may return bogus data if no extra calldata + * was given. + */ + function _decodeExtraCallData() private pure returns (bytes memory data) { + uint256 length = uint256(_decodeLastCallDataWord()); + if (msg.data.length < length) return new bytes(0); + data = new bytes(length); + assembly { + calldatacopy(add(data, 0x20), sub(sub(calldatasize(), length), 0x20), length) + } + } + + /** + * @dev Returns the last calldata word. This function returns zero if the calldata is not long enough. + */ + function _decodeLastCallDataWord() private pure returns (bytes32 result) { + if (msg.data.length < 36) return bytes32(0); + assembly { + result := calldataload(sub(calldatasize(), 0x20)) + } + } } diff --git a/packages/tasks/contracts/interfaces/base/IBaseTask.sol b/packages/tasks/contracts/interfaces/base/IBaseTask.sol index 37b4f9be..ee7b0387 100644 --- a/packages/tasks/contracts/interfaces/base/IBaseTask.sol +++ b/packages/tasks/contracts/interfaces/base/IBaseTask.sol @@ -32,7 +32,7 @@ interface IBaseTask is IAuthorized { /** * @dev The smart vault's price oracle is not set */ - error TaskSmartVaultPriceOracleNotSet(address base, address quote); + error TaskSmartVaultPriceOracleNotSet(address smartVault); /** * @dev Emitted every time a task is executed diff --git a/packages/tasks/test/base/BaseTask.test.ts b/packages/tasks/test/base/BaseTask.test.ts index a24eff7f..2eaf09a2 100644 --- a/packages/tasks/test/base/BaseTask.test.ts +++ b/packages/tasks/test/base/BaseTask.test.ts @@ -221,6 +221,7 @@ describe('BaseTask', () => { context('when there is not enough balance in the connector', () => { it('reverts', async () => { + // TODO: Hardhat does not decode smart vault error properly await expect(task.call(token, amount)).to.be.reverted }) }) diff --git a/packages/tasks/test/swap/OneInchV5Swapper.test.ts b/packages/tasks/test/swap/OneInchV5Swapper.test.ts index c8464d1d..3160e7e9 100644 --- a/packages/tasks/test/swap/OneInchV5Swapper.test.ts +++ b/packages/tasks/test/swap/OneInchV5Swapper.test.ts @@ -1,19 +1,22 @@ import { OP } from '@mimic-fi/v3-authorizer' import { - assertEvent, assertIndirectEvent, + BigNumberish, deploy, deployFeedMock, deployProxy, deployTokenMock, fp, getSigners, + MAX_UINT256, ZERO_ADDRESS, ZERO_BYTES32, } from '@mimic-fi/v3-helpers' import { SignerWithAddress } from '@nomiclabs/hardhat-ethers/dist/src/signer-with-address' import { expect } from 'chai' -import { Contract } from 'ethers' +import { Contract, ContractTransaction } from 'ethers' +import { defaultAbiCoder } from 'ethers/lib/utils' +import { ethers } from 'hardhat' import { buildEmptyTaskConfig, deployEnvironment } from '../../src/setup' import { itBehavesLikeBaseSwapTask } from './BaseSwapTask.behavior' @@ -93,7 +96,8 @@ describe('OneInchV5Swapper', () => { context('when the token in is allowed', () => { context('when there is a token out set', () => { - let tokenOut: Contract + let tokenOut: Contract, + extraCallData = '' beforeEach('set default token out', async () => { tokenOut = await deployTokenMock('TKN') @@ -102,113 +106,246 @@ describe('OneInchV5Swapper', () => { await task.connect(owner).setDefaultTokenOut(tokenOut.address) }) - beforeEach('set price feed', async () => { - const feed = await deployFeedMock(fp(tokenRate), 18) - const setFeedRole = priceOracle.interface.getSighash('setFeed') - await authorizer.connect(owner).authorize(owner.address, priceOracle.address, setFeedRole, []) - await priceOracle.connect(owner).setFeed(tokenIn.address, tokenOut.address, feed.address) - }) - - beforeEach('set threshold', async () => { - const setDefaultTokenThresholdRole = task.interface.getSighash('setDefaultTokenThreshold') - await authorizer.connect(owner).authorize(owner.address, task.address, setDefaultTokenThresholdRole, []) - await task.connect(owner).setDefaultTokenThreshold(tokenOut.address, thresholdAmount, 0) - }) + context('when an off-chain oracle is given', () => { + beforeEach('sign off-chain oracle', async () => { + const setSignerRole = priceOracle.interface.getSighash('setSigner') + await authorizer.connect(owner).authorize(owner.address, priceOracle.address, setSignerRole, []) + await priceOracle.connect(owner).setSigner(owner.address, true) + + type PriceData = { base: string; quote: string; rate: BigNumberish; deadline: BigNumberish } + const pricesData: PriceData[] = [ + { + base: tokenIn.address, + quote: tokenOut.address, + rate: fp(tokenRate), + deadline: MAX_UINT256, + }, + { + base: tokenOut.address, + quote: tokenIn.address, + rate: fp(1).mul(fp(1)).div(fp(tokenRate)), + deadline: MAX_UINT256, + }, + ] + + const PricesDataType = 'PriceData(address base, address quote, uint256 rate, uint256 deadline)[]' + const encodedPrices = await defaultAbiCoder.encode([PricesDataType], [pricesData]) + const message = ethers.utils.solidityKeccak256(['bytes'], [encodedPrices]) + const signature = await owner.signMessage(ethers.utils.arrayify(message)) + const data = defaultAbiCoder.encode([PricesDataType, 'bytes'], [pricesData, signature]).slice(2) + const dataLength = defaultAbiCoder.encode(['uint256'], [data.length / 2]).slice(2) + extraCallData = `${data}${dataLength}` + }) - context('when the smart vault balance passes the threshold', () => { - beforeEach('fund smart vault', async () => { - await tokenIn.mint(smartVault.address, amountIn) + beforeEach('set threshold', async () => { + const setDefaultTokenThresholdRole = task.interface.getSighash('setDefaultTokenThreshold') + await authorizer + .connect(owner) + .authorize(owner.address, task.address, setDefaultTokenThresholdRole, []) + await task.connect(owner).setDefaultTokenThreshold(tokenOut.address, thresholdAmount, 0) }) - context('when the slippage is below the limit', () => { - const data = '0xaabb' - const slippage = fp(0.01) - const expectedAmountOut = amountIn.mul(tokenRate) - const minAmountOut = expectedAmountOut.mul(fp(1).sub(slippage)).div(fp(1)) + const executeTask = async (amountIn, slippage, data): Promise => { + const callTx = await task.populateTransaction.call(tokenIn.address, amountIn, slippage, data) + const callData = `${callTx.data}${extraCallData}` + return owner.sendTransaction({ to: task.address, data: callData }) + } - beforeEach('set max slippage', async () => { - const setDefaultMaxSlippageRole = task.interface.getSighash('setDefaultMaxSlippage') - await authorizer - .connect(owner) - .authorize(owner.address, task.address, setDefaultMaxSlippageRole, []) - await task.connect(owner).setDefaultMaxSlippage(slippage) + context('when the smart vault balance passes the threshold', () => { + beforeEach('fund smart vault', async () => { + await tokenIn.mint(smartVault.address, amountIn) }) - it('executes the expected connector', async () => { - const tx = await task.call(tokenIn.address, amountIn, slippage, data) + context('when the slippage is below the limit', () => { + const data = '0xaabb' + const slippage = fp(0.01) + const expectedAmountOut = amountIn.mul(tokenRate) + const minAmountOut = expectedAmountOut.mul(fp(1).sub(slippage)).div(fp(1)) + + beforeEach('set max slippage', async () => { + const setDefaultMaxSlippageRole = task.interface.getSighash('setDefaultMaxSlippage') + await authorizer + .connect(owner) + .authorize(owner.address, task.address, setDefaultMaxSlippageRole, []) + await task.connect(owner).setDefaultMaxSlippage(slippage) + }) + + it('executes the expected connector', async () => { + const tx = await executeTask(amountIn, slippage, data) + + const connectorData = connector.interface.encodeFunctionData('execute', [ + tokenIn.address, + tokenOut.address, + amountIn, + minAmountOut, + data, + ]) + + await assertIndirectEvent(tx, smartVault.interface, 'Executed', { + connector, + data: connectorData, + }) + + await assertIndirectEvent(tx, connector.interface, 'LogExecute', { + tokenIn, + tokenOut, + amountIn, + minAmountOut, + data, + }) + }) - const connectorData = connector.interface.encodeFunctionData('execute', [ - tokenIn.address, - tokenOut.address, - amountIn, - minAmountOut, - data, - ]) + it('emits an Executed event', async () => { + const tx = await executeTask(amountIn, slippage, data) - await assertIndirectEvent(tx, smartVault.interface, 'Executed', { - connector, - data: connectorData, + await assertIndirectEvent(tx, task.interface, 'Executed') }) - await assertIndirectEvent(tx, connector.interface, 'LogExecute', { - tokenIn, - tokenOut, - amountIn, - minAmountOut, - data, + it('updates the balance connectors properly', async () => { + const nextConnectorId = '0x0000000000000000000000000000000000000000000000000000000000000002' + const setBalanceConnectorsRole = task.interface.getSighash('setBalanceConnectors') + await authorizer + .connect(owner) + .authorize(owner.address, task.address, setBalanceConnectorsRole, []) + await task.connect(owner).setBalanceConnectors(ZERO_BYTES32, nextConnectorId) + + const updateBalanceConnectorRole = smartVault.interface.getSighash('updateBalanceConnector') + await authorizer + .connect(owner) + .authorize(task.address, smartVault.address, updateBalanceConnectorRole, []) + + const tx = await executeTask(amountIn, slippage, data) + + await assertIndirectEvent(tx, smartVault.interface, 'BalanceConnectorUpdated', { + id: nextConnectorId, + token: tokenOut.address, + amount: minAmountOut, + added: true, + }) }) }) - it('emits an Executed event', async () => { - const tx = await task.call(tokenIn.address, amountIn, slippage, data) + context('when the slippage is above the limit', () => { + const slippage = fp(0.01) - await assertEvent(tx, 'Executed') + it('reverts', async () => { + await expect(executeTask(amountIn, slippage, '0x')).to.be.revertedWith('TaskSlippageAboveMax') + }) + }) + }) + + context('when the smart vault balance does not pass the threshold', () => { + const amountIn = thresholdAmountInTokenIn.div(2) + + beforeEach('fund smart vault', async () => { + await tokenIn.mint(smartVault.address, amountIn) + }) + + it('reverts', async () => { + await expect(executeTask(amountIn, 0, '0x')).to.be.revertedWith('TaskTokenThresholdNotMet') }) + }) + }) - it('updates the balance connectors properly', async () => { - const nextConnectorId = '0x0000000000000000000000000000000000000000000000000000000000000002' - const setBalanceConnectorsRole = task.interface.getSighash('setBalanceConnectors') - await authorizer.connect(owner).authorize(owner.address, task.address, setBalanceConnectorsRole, []) - await task.connect(owner).setBalanceConnectors(ZERO_BYTES32, nextConnectorId) + context('when no off-chain oracle is given', () => { + context('when an on-chain oracle is given', () => { + beforeEach('set price feed', async () => { + const feed = await deployFeedMock(fp(tokenRate), 18) + const setFeedRole = priceOracle.interface.getSighash('setFeed') + await authorizer.connect(owner).authorize(owner.address, priceOracle.address, setFeedRole, []) + await priceOracle.connect(owner).setFeed(tokenIn.address, tokenOut.address, feed.address) + }) - const updateBalanceConnectorRole = smartVault.interface.getSighash('updateBalanceConnector') + beforeEach('set threshold', async () => { + const setDefaultTokenThresholdRole = task.interface.getSighash('setDefaultTokenThreshold') await authorizer .connect(owner) - .authorize(task.address, smartVault.address, updateBalanceConnectorRole, []) + .authorize(owner.address, task.address, setDefaultTokenThresholdRole, []) + await task.connect(owner).setDefaultTokenThreshold(tokenOut.address, thresholdAmount, 0) + }) - const tx = await task.call(tokenIn.address, amountIn, slippage, data) + context('when the smart vault balance passes the threshold', () => { + beforeEach('fund smart vault', async () => { + await tokenIn.mint(smartVault.address, amountIn) + }) - await assertIndirectEvent(tx, smartVault.interface, 'BalanceConnectorUpdated', { - id: nextConnectorId, - token: tokenOut.address, - amount: minAmountOut, - added: true, + context('when the slippage is below the limit', () => { + const data = '0xaabb' + const slippage = fp(0.01) + const expectedAmountOut = amountIn.mul(tokenRate) + const minAmountOut = expectedAmountOut.mul(fp(1).sub(slippage)).div(fp(1)) + + beforeEach('set max slippage', async () => { + const setDefaultMaxSlippageRole = task.interface.getSighash('setDefaultMaxSlippage') + await authorizer + .connect(owner) + .authorize(owner.address, task.address, setDefaultMaxSlippageRole, []) + await task.connect(owner).setDefaultMaxSlippage(slippage) + }) + + it('executes the expected connector', async () => { + const tx = await task.call(tokenIn.address, amountIn, slippage, data) + + const connectorData = connector.interface.encodeFunctionData('execute', [ + tokenIn.address, + tokenOut.address, + amountIn, + minAmountOut, + data, + ]) + + await assertIndirectEvent(tx, smartVault.interface, 'Executed', { + connector, + data: connectorData, + }) + + await assertIndirectEvent(tx, connector.interface, 'LogExecute', { + tokenIn, + tokenOut, + amountIn, + minAmountOut, + data, + }) + }) + + it('emits an Executed event', async () => { + const tx = await task.call(tokenIn.address, amountIn, slippage, data) + + await assertIndirectEvent(tx, task.interface, 'Executed') + }) }) - }) - }) - context('when the slippage is above the limit', () => { - const slippage = fp(0.01) + context('when the slippage is above the limit', () => { + const slippage = fp(0.01) - it('reverts', async () => { - await expect(task.call(tokenIn.address, amountIn, slippage, '0x')).to.be.revertedWith( - 'TaskSlippageAboveMax' - ) + it('reverts', async () => { + await expect(task.call(tokenIn.address, amountIn, slippage, '0x')).to.be.revertedWith( + 'TaskSlippageAboveMax' + ) + }) + }) }) - }) - }) - context('when the smart vault balance does not pass the threshold', () => { - const amountIn = thresholdAmountInTokenIn.div(2) + context('when the smart vault balance does not pass the threshold', () => { + const amountIn = thresholdAmountInTokenIn.div(2) + + beforeEach('fund smart vault', async () => { + await tokenIn.mint(smartVault.address, amountIn) + }) - beforeEach('fund smart vault', async () => { - await tokenIn.mint(smartVault.address, amountIn) + it('reverts', async () => { + await expect(task.call(tokenIn.address, amountIn, 0, '0x')).to.be.revertedWith( + 'TaskTokenThresholdNotMet' + ) + }) + }) }) - it('reverts', async () => { - await expect(task.call(tokenIn.address, amountIn, 0, '0x')).to.be.revertedWith( - 'TaskTokenThresholdNotMet' - ) + context('when no on-chain oracle is given', () => { + it('reverts', async () => { + // TODO: Hardhat does not decode price oracle error properly + await expect(task.call(tokenIn.address, amountIn, 0, '0x')).to.be.reverted + }) }) }) })