diff --git a/.changeset/cool-mangos-compare.md b/.changeset/cool-mangos-compare.md new file mode 100644 index 00000000000..470ee089456 --- /dev/null +++ b/.changeset/cool-mangos-compare.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Math`: add an `invMod` function to get the modular multiplicative inverse of a number in Z/nZ. diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index 3a1d5a4b24d..a826dfd9656 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -121,9 +121,10 @@ library Math { } /** - * @notice Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or + * @dev Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or * denominator == 0. - * @dev Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) with further edits by + * + * Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) with further edits by * Uniswap Labs also under MIT license. */ function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) { @@ -208,7 +209,7 @@ library Math { } /** - * @notice Calculates x * y / denominator with full precision, following the selected rounding direction. + * @dev Calculates x * y / denominator with full precision, following the selected rounding direction. */ function mulDiv(uint256 x, uint256 y, uint256 denominator, Rounding rounding) internal pure returns (uint256) { uint256 result = mulDiv(x, y, denominator); @@ -218,6 +219,62 @@ library Math { return result; } + /** + * @dev Calculate the modular multiplicative inverse of a number in Z/nZ. + * + * If n is a prime, then Z/nZ is a field. In that case all elements are inversible, expect 0. + * If n is not a prime, then Z/nZ is not a field, and some elements might not be inversible. + * + * If the input value is not inversible, 0 is returned. + */ + function invMod(uint256 a, uint256 n) internal pure returns (uint256) { + unchecked { + if (n == 0) return 0; + + // The inverse modulo is calculated using the Extended Euclidean Algorithm (iterative version) + // Used to compute integers x and y such that: ax + ny = gcd(a, n). + // When the gcd is 1, then the inverse of a modulo n exists and it's x. + // ax + ny = 1 + // ax = 1 + (-y)n + // ax ≡ 1 (mod n) # x is the inverse of a modulo n + + // If the remainder is 0 the gcd is n right away. + uint256 remainder = a % n; + uint256 gcd = n; + + // Therefore the initial coefficients are: + // ax + ny = gcd(a, n) = n + // 0a + 1n = n + int256 x = 0; + int256 y = 1; + + while (remainder != 0) { + uint256 quotient = gcd / remainder; + + (gcd, remainder) = ( + // The old remainder is the next gcd to try. + remainder, + // Compute the next remainder. + // Can't overflow given that (a % gcd) * (gcd // (a % gcd)) <= gcd + // where gcd is at most n (capped to type(uint256).max) + gcd - remainder * quotient + ); + + (x, y) = ( + // Increment the coefficient of a. + y, + // Decrement the coefficient of n. + // Can overflow, but the result is casted to uint256 so that the + // next value of y is "wrapped around" to a value between 0 and n - 1. + x - y * int256(quotient) + ); + } + + if (gcd != 1) return 0; // No inverse exists. + return x < 0 ? (n - uint256(-x)) : uint256(x); // Wrap the result if it's negative. + } + } + /** * @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded * towards zero. @@ -258,7 +315,7 @@ library Math { } /** - * @notice Calculates sqrt(a), following the selected rounding direction. + * @dev Calculates sqrt(a), following the selected rounding direction. */ function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) { unchecked { diff --git a/test/utils/math/Math.t.sol b/test/utils/math/Math.t.sol index a757833796f..75d28041dc8 100644 --- a/test/utils/math/Math.t.sol +++ b/test/utils/math/Math.t.sol @@ -55,6 +55,41 @@ contract MathTest is Test { return value * value < ref; } + // INV + function testInvMod(uint256 value, uint256 p) public { + _testInvMod(value, p, true); + } + + function testInvMod2(uint256 seed) public { + uint256 p = 2; // prime + _testInvMod(bound(seed, 1, p - 1), p, false); + } + + function testInvMod17(uint256 seed) public { + uint256 p = 17; // prime + _testInvMod(bound(seed, 1, p - 1), p, false); + } + + function testInvMod65537(uint256 seed) public { + uint256 p = 65537; // prime + _testInvMod(bound(seed, 1, p - 1), p, false); + } + + function testInvModP256(uint256 seed) public { + uint256 p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff; // prime + _testInvMod(bound(seed, 1, p - 1), p, false); + } + + function _testInvMod(uint256 value, uint256 p, bool allowZero) private { + uint256 inverse = Math.invMod(value, p); + if (inverse != 0) { + assertEq(mulmod(value, inverse, p), 1); + assertLt(inverse, p); + } else { + assertTrue(allowZero); + } + } + // LOG2 function testLog2(uint256 input, uint8 r) public { Math.Rounding rounding = _asRounding(r); diff --git a/test/utils/math/Math.test.js b/test/utils/math/Math.test.js index cb25db67cdf..abf43f0738d 100644 --- a/test/utils/math/Math.test.js +++ b/test/utils/math/Math.test.js @@ -5,6 +5,7 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic'); const { Rounding } = require('../../helpers/enums'); const { min, max } = require('../../helpers/math'); +const { randomArray, generators } = require('../../helpers/random'); const RoundingDown = [Rounding.Floor, Rounding.Trunc]; const RoundingUp = [Rounding.Ceil, Rounding.Expand]; @@ -298,6 +299,43 @@ describe('Math', function () { }); }); + describe('invMod', function () { + for (const factors of [ + [0n], + [1n], + [2n], + [17n], + [65537n], + [0xffffffff00000001000000000000000000000000ffffffffffffffffffffffffn], + [3n, 5n], + [3n, 7n], + [47n, 53n], + ]) { + const p = factors.reduce((acc, f) => acc * f, 1n); + + describe(`using p=${p} which is ${p > 1 && factors.length > 1 ? 'not ' : ''}a prime`, function () { + it('trying to inverse 0 returns 0', async function () { + expect(await this.mock.$invMod(0, p)).to.equal(0n); + expect(await this.mock.$invMod(p, p)).to.equal(0n); // p is 0 mod p + }); + + if (p != 0) { + for (const value of randomArray(generators.uint256, 16)) { + const isInversible = factors.every(f => value % f); + it(`trying to inverse ${value}`, async function () { + const result = await this.mock.$invMod(value, p); + if (isInversible) { + expect((value * result) % p).to.equal(1n); + } else { + expect(result).to.equal(0n); + } + }); + } + } + }); + } + }); + describe('sqrt', function () { it('rounds down', async function () { for (const rounding of RoundingDown) {