Skip to content

Commit

Permalink
ValidationModule&AuthorizationModule: support of custom errors with t…
Browse files Browse the repository at this point in the history
…est through Hardhat + update SnapshotModule test to support Hardhat and upgrade to the latest OpenZeppelin version
  • Loading branch information
rya-sge committed Aug 16, 2023
1 parent d90a906 commit 125bb89
Show file tree
Hide file tree
Showing 24 changed files with 532 additions and 395 deletions.
22 changes: 12 additions & 10 deletions contracts/libraries/Errors.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
pragma solidity ^0.8.17;

library Errors {
error InvalidTransfer(address from, address to, uint256 amount);
// CMTAT
error CMTAT_InvalidTransfer(address from, address to, uint256 amount);

// SnapshotModule
error SnapshotScheduledInThePast(uint256 time, uint256 timestamp);
error SnapshotTimestampBeforeLastSnapshot(uint256 time, uint256 lastSnapshotTimestamp);
error SnapshotTimestampAfterNextSnapshot(uint256 time, uint256 nextSnapshotTimestamp);
error SnapshotTimestampBeforePreviousSnapshot(uint256 time, uint256 previousSnapshotTimestamp);
error SnapshotAlreadyExists();
error SnapshotAlreadyDone();
error SnapshotNotScheduled();
error SnapshotNotFound();
error SnapshotNeverScheduled();
error CMTAT_SnapshotModule_SnapshotScheduledInThePast(uint256 time, uint256 timestamp);
error CMTAT_SnapshotModule_SnapshotTimestampBeforeLastSnapshot(uint256 time, uint256 lastSnapshotTimestamp);
error CMTAT_SnapshotModule_SnapshotTimestampAfterNextSnapshot(uint256 time, uint256 nextSnapshotTimestamp);
error CMTAT_SnapshotModule_SnapshotTimestampBeforePreviousSnapshot(uint256 time, uint256 previousSnapshotTimestamp);
error CMTAT_SnapshotModule_SnapshotAlreadyExists();
error CMTAT_SnapshotModule_SnapshotAlreadyDone();
error CMTAT_SnapshotModule_SnapshotNotScheduled();
error CMTAT_SnapshotModule_SnapshotNotFound();
error CMTAT_SnapshotModule_SnapshotNeverScheduled();

// Generic
error AddressZeroNotAllowed();
Expand Down
2 changes: 1 addition & 1 deletion contracts/modules/CMTAT_BASE.sol
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ abstract contract CMTAT_BASE is
uint256 amount
) internal override(ERC20Upgradeable) {
if(!ValidationModule.validateTransfer(from, to, amount)) {
revert Errors.InvalidTransfer(from, to, amount);
revert Errors.CMTAT_InvalidTransfer(from, to, amount);
}
ERC20Upgradeable._update(from, to, amount);
// We call the SnapshotModule only if the transfer is valid
Expand Down
59 changes: 40 additions & 19 deletions contracts/modules/internal/SnapshotModuleInternal.sol
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,14 @@ abstract contract SnapshotModuleInternal is ERC20Upgradeable {
*/
function _scheduleSnapshot(uint256 time) internal {
// Check the time firstly to avoid an useless read of storage
if(time <= block.timestamp) revert Errors.SnapshotScheduledInThePast(time, block.timestamp);
if(time <= block.timestamp){
revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast(time, block.timestamp);
}

if (_scheduledSnapshots.length > 0) {
// We check the last snapshot on the list
if(time <= _scheduledSnapshots[_scheduledSnapshots.length - 1]) {
revert Errors.SnapshotTimestampBeforeLastSnapshot(time, _scheduledSnapshots[_scheduledSnapshots.length - 1]);
revert Errors.CMTAT_SnapshotModule_SnapshotTimestampBeforeLastSnapshot(time, _scheduledSnapshots[_scheduledSnapshots.length - 1]);
}
}
_scheduledSnapshots.push(time);
Expand All @@ -100,10 +102,14 @@ abstract contract SnapshotModuleInternal is ERC20Upgradeable {
@dev schedule a snapshot at the specified time
*/
function _scheduleSnapshotNotOptimized(uint256 time) internal {
if(time <= block.timestamp) revert Errors.SnapshotScheduledInThePast(time, block.timestamp);
if(time <= block.timestamp) {
revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast(time, block.timestamp);
}
(bool isFound, uint256 index) = _findScheduledSnapshotIndex(time);
// Perfect match
if(isFound) revert Errors.SnapshotAlreadyExists();
if(isFound){
revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists();
}
// if no upper bound match found, we push the snapshot at the end of the list
if (index == _scheduledSnapshots.length) {
_scheduledSnapshots.push(time);
Expand All @@ -127,19 +133,25 @@ abstract contract SnapshotModuleInternal is ERC20Upgradeable {
*/
function _rescheduleSnapshot(uint256 oldTime, uint256 newTime) internal {
// Check the time firstly to avoid an useless read of storage
if(oldTime <= block.timestamp) revert Errors.SnapshotAlreadyDone();
if(newTime <= block.timestamp) revert Errors.SnapshotScheduledInThePast(newTime, block.timestamp);
if(_scheduledSnapshots.length == 0) revert Errors.SnapshotNotScheduled();

if(oldTime <= block.timestamp){
revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone();
}
if(newTime <= block.timestamp){
revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast(newTime, block.timestamp);
}
if(_scheduledSnapshots.length == 0){
revert Errors.CMTAT_SnapshotModule_SnapshotNotScheduled();
}
(bool foundOld, uint256 index) = _findScheduledSnapshotIndex(oldTime);
if(!foundOld) revert Errors.SnapshotNotFound();

if(!foundOld){
revert Errors.CMTAT_SnapshotModule_SnapshotNotFound();
}
if (index + 1 < _scheduledSnapshots.length) {
if(newTime >= _scheduledSnapshots[index + 1]) revert Errors.SnapshotTimestampAfterNextSnapshot(newTime, _scheduledSnapshots[index + 1]);
if(newTime >= _scheduledSnapshots[index + 1]) revert Errors.CMTAT_SnapshotModule_SnapshotTimestampAfterNextSnapshot(newTime, _scheduledSnapshots[index + 1]);
}

if (index > 0) {
if(newTime <= _scheduledSnapshots[index - 1]) revert Errors.SnapshotTimestampBeforePreviousSnapshot(newTime, _scheduledSnapshots[index - 1]);
if(newTime <= _scheduledSnapshots[index - 1]) revert Errors.CMTAT_SnapshotModule_SnapshotTimestampBeforePreviousSnapshot(newTime, _scheduledSnapshots[index - 1]);
}

_scheduledSnapshots[index] = newTime;
Expand All @@ -152,10 +164,16 @@ abstract contract SnapshotModuleInternal is ERC20Upgradeable {
*/
function _unscheduleLastSnapshot(uint256 time) internal {
// Check the time firstly to avoid an useless read of storage
if(time <= block.timestamp) revert Errors.SnapshotAlreadyDone();
if(_scheduledSnapshots.length == 0) revert Errors.SnapshotNotScheduled();
if(time <= block.timestamp){
revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone();
}
if(_scheduledSnapshots.length == 0){
revert Errors.CMTAT_SnapshotModule_SnapshotNotScheduled();
}
// All snapshot time are unique, so we do not check the indice
if(time != _scheduledSnapshots[_scheduledSnapshots.length - 1]) revert Errors.SnapshotNeverScheduled();
if(time != _scheduledSnapshots[_scheduledSnapshots.length - 1]){
revert Errors.CMTAT_SnapshotModule_SnapshotNeverScheduled();
}
_scheduledSnapshots.pop();
emit SnapshotUnschedule(time);
}
Expand All @@ -167,9 +185,13 @@ abstract contract SnapshotModuleInternal is ERC20Upgradeable {
- Reduce the array size by deleting the last snapshot
*/
function _unscheduleSnapshotNotOptimized(uint256 time) internal {
if(time <= block.timestamp) revert Errors.SnapshotAlreadyDone();
if(time <= block.timestamp){
revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone();
}
(bool isFound, uint256 index) = _findScheduledSnapshotIndex(time);
if(!isFound) revert Errors.SnapshotNotFound();
if(!isFound){
revert Errors.CMTAT_SnapshotModule_SnapshotNotFound();
}
for (uint256 i = index; i + 1 < _scheduledSnapshots.length;) {
_scheduledSnapshots[i] = _scheduledSnapshots[i + 1];
unchecked {
Expand Down Expand Up @@ -260,8 +282,6 @@ abstract contract SnapshotModuleInternal is ERC20Upgradeable {
address to,
uint256 amount
) internal virtual override {
super._update(from, to, amount);

_setCurrentSnapshot();
if (from != address(0)) {
// for both burn and transfer
Expand All @@ -278,6 +298,7 @@ abstract contract SnapshotModuleInternal is ERC20Upgradeable {
_updateAccountSnapshot(to);
_updateTotalSupplySnapshot();
}
ERC20Upgradeable._update(from, to, amount);
}

/**
Expand Down
5 changes: 4 additions & 1 deletion contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,11 @@ abstract contract CMTAT_BASE_SnapshotTest is
address to,
uint256 amount
) internal override(SnapshotModuleInternal, ERC20Upgradeable) {
if(!ValidationModule.validateTransfer(from, to, amount)) revert Errors.InvalidTransfer(from, to, amount);
// We call the SnapshotModule only if the transfer is valid
if(!ValidationModule.validateTransfer(from, to, amount)) revert Errors.CMTAT_InvalidTransfer(from, to, amount);
/*
We do not call ERC20Upgradeable._update(from, to, amount) here because it is called inside the SnapshotModule
*/
/*
SnapshotModule:
Add this call in case you add the SnapshotModule
Expand Down
2 changes: 1 addition & 1 deletion contracts/test/killTest/CMTATKillTest.sol
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ contract CMTAT_KILL_TEST is
address to,
uint256 amount
) internal view override(ERC20Upgradeable) {
if(!ValidationModule.validateTransfer(from, to, amount)) revert Errors.InvalidTransfer(from, to, amount);
if(!ValidationModule.validateTransfer(from, to, amount)) revert Errors.CMTAT_InvalidTransfer(from, to, amount);
// We call the SnapshotModule only if the transfer is valid
/*
SnapshotModule:
Expand Down
8 changes: 6 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"test:burn": "npx truffle test test/standard/modules/BurnModule.test.js test/proxy/modules/BurnModule.test.js",
"test:validation": "npx truffle test test/standard/modules/ValidationModule/ValidationModule.test.js test/proxy/modules/ValidationModule/ValidationModule.test.js test/standard/modules/ValidationModule/ValidationModuleConstructor.test.js test/proxy/modules/ValidationModule/ValidationModuleConstructor.test.js test/standard/modules/ValidationModule/ValidationModuleSetRuleEngine.test.js test/proxy/modules/ValidationModule/ValidationModuleSetRuleEngine.test.js",
"test:enforcement": "npx truffle test test/standard/modules/EnforcementModule.test.js test/proxy/modules/EnforcementModule.test.js",
"test:authorization": "npx truffle test test/standard/modules/AuthorizationModule/AuthorizationModule.test.js test/proxy/modules/AuthorizationModule/AuthorizationModule.test.js test/standard/modules/AuthorizationModule/TransferAdminship.test.js test/proxy/modules/AuthorizationModule/TransferAdminship.test.js",
"test:authorization": "npx truffle test test/standard/modules/AuthorizationModule/AuthorizationModule.test.js test/proxy/modules/AuthorizationModule/AuthorizationModule.test.js",
"test:base": "npx truffle test test/standard/modules/BaseModule.test.js test/proxy/modules/BaseModule.test.js",
"test:erc20Base": "npx truffle test test/standard/modules/ERC20BaseModule.test.js test/proxy/modules/ERC20BaseModule.test.js",
"test:debt": "npx truffle test test/standard/modules/DebtModule.test.js test/proxy/modules/DebtModule.test.js",
Expand All @@ -47,7 +47,10 @@
"test:hardhat:debt": "npx hardhat test test/standard/modules/DebtModule.test.js test/proxy/modules/DebtModule.test.js",
"test:hardhat:base": "npx hardhat test test/standard/modules/BaseModule.test.js test/proxy/modules/BaseModule.test.js",
"test:hardhat:pause": "npx hardhat test test/standard/modules/PauseModule.test.js test/proxy/modules/PauseModule.test.js",
"test:hardhat:creditEvents": "npx hardhat test test/standard/modules/CreditEventsModule.test.js test/proxy/modules/CreditEventsModule.test.js"
"test:hardhat:creditEvents": "npx hardhat test test/standard/modules/CreditEventsModule.test.js test/proxy/modules/CreditEventsModule.test.js",
"test:hardhat:validation": "npx hardhat test test/standard/modules/ValidationModule/ValidationModule.test.js test/proxy/modules/ValidationModule/ValidationModule.test.js test/standard/modules/ValidationModule/ValidationModuleConstructor.test.js test/proxy/modules/ValidationModule/ValidationModuleConstructor.test.js test/standard/modules/ValidationModule/ValidationModuleSetRuleEngine.test.js test/proxy/modules/ValidationModule/ValidationModuleSetRuleEngine.test.js",
"test:hardhat:authorization": "npx hardhat test test/standard/modules/AuthorizationModule/AuthorizationModule.test.js test/proxy/modules/AuthorizationModule/AuthorizationModule.test.js",
"test:hardhat:snapshot": "npx hardhat test test/standard/modules/SnapshotModule.test.js test/proxy/modules/SnapshotModule.test.js"
},
"repository": {
"type": "git",
Expand All @@ -69,6 +72,7 @@
"homepage": "https://github.com/CMTA/CMTAT",
"devDependencies": {
"@nomicfoundation/hardhat-ethers": "^3.0.4",
"@nomicfoundation/hardhat-network-helpers": "^1.0.8",
"@nomiclabs/hardhat-truffle5": "^2.0.7",
"@nomiclabs/hardhat-web3": "^2.0.0",
"@openzeppelin/hardhat-upgrades": "^2.1.1",
Expand Down
29 changes: 14 additions & 15 deletions test/common/AuthorizationModule/AuthorizationModuleCommon.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
const { expectEvent, expectRevert } = require('@openzeppelin/test-helpers')
const { PAUSER_ROLE } = require('../../utils')
const { expectRevertCustomError } = require('../../../openzeppelin-contracts-upgradeable/test/helpers/customError')
const { PAUSER_ROLE, DEFAULT_ADMIN_ROLE } = require('../../utils')
const chai = require('chai')
const expect = chai.expect
const should = chai.should()
Expand All @@ -8,15 +9,15 @@ function AuthorizationModuleCommon (owner, address1, address2) {
context('Authorization', function () {
it('testAdminCanGrantRole', async function () {
// Act
({ logs: this.logs } = await this.cmtat.grantRole(
this.logs = await this.cmtat.grantRole(
PAUSER_ROLE,
address1,
{ from: owner }
));
);
// Assert
(await this.cmtat.hasRole(PAUSER_ROLE, address1)).should.equal(true)
// emits a RoleGranted event
expectEvent.inLogs(this.logs, 'RoleGranted', {
expectEvent(this.logs, 'RoleGranted', {
role: PAUSER_ROLE,
account: address1,
sender: owner
Expand All @@ -29,15 +30,15 @@ function AuthorizationModuleCommon (owner, address1, address2) {
// Arrange - Assert
(await this.cmtat.hasRole(PAUSER_ROLE, address1)).should.equal(true);
// Act
({ logs: this.logs } = await this.cmtat.revokeRole(
this.logs = await this.cmtat.revokeRole(
PAUSER_ROLE,
address1,
{ from: owner }
));
);
// Assert
(await this.cmtat.hasRole(PAUSER_ROLE, address1)).should.equal(false)
// emits a RoleRevoked event
expectEvent.inLogs(this.logs, 'RoleRevoked', {
expectEvent(this.logs, 'RoleRevoked', {
role: PAUSER_ROLE,
account: address1,
sender: owner
Expand All @@ -48,11 +49,10 @@ function AuthorizationModuleCommon (owner, address1, address2) {
// Arrange - Assert
(await this.cmtat.hasRole(PAUSER_ROLE, address1)).should.equal(false)
// Act
await expectRevert(
await expectRevertCustomError(
this.cmtat.grantRole(PAUSER_ROLE, address1, { from: address2 }),
'AccessControl: account ' +
address2.toLowerCase() +
' is missing role 0x0000000000000000000000000000000000000000000000000000000000000000'
'AccessControlUnauthorizedAccount',
[address2, DEFAULT_ADMIN_ROLE]
);
// Assert
(await this.cmtat.hasRole(PAUSER_ROLE, address1)).should.equal(false)
Expand All @@ -65,11 +65,10 @@ function AuthorizationModuleCommon (owner, address1, address2) {
// Arrange - Assert
(await this.cmtat.hasRole(PAUSER_ROLE, address1)).should.equal(true)
// Act
await expectRevert(
await expectRevertCustomError(
this.cmtat.revokeRole(PAUSER_ROLE, address1, { from: address2 }),
'AccessControl: account ' +
address2.toLowerCase() +
' is missing role 0x0000000000000000000000000000000000000000000000000000000000000000'
'AccessControlUnauthorizedAccount',
[address2, DEFAULT_ADMIN_ROLE]
);
// Assert
(await this.cmtat.hasRole(PAUSER_ROLE, address1)).should.equal(true)
Expand Down
Loading

0 comments on commit 125bb89

Please sign in to comment.