Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add additional big_math helper functions #1563

Merged
merged 2 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions protocol/lib/big_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,27 @@ import (
"math/big"
)

// BigU returns a new big.Int from the input unsigned integer.
func BigU[T uint | uint32 | uint64](u T) *big.Int {
return new(big.Int).SetUint64(uint64(u))
}

// BigI returns a new big.Int from the input signed integer.
func BigI[T int | int32 | int64](i T) *big.Int {
return big.NewInt(int64(i))
}

// BigMulPpm returns the result of `val * ppm / 1_000_000`, rounding in the direction indicated.
func BigMulPpm(val *big.Int, ppm *big.Int, roundUp bool) *big.Int {
result := new(big.Int).Mul(val, ppm)
oneMillion := BigIntOneMillion()
if roundUp {
return BigDivCeil(result, oneMillion)
} else {
return result.Div(result, oneMillion)
}
}

// BigMulPow10 returns the result of `val * 10^exponent`, in *big.Rat.
func BigMulPow10(
val *big.Int,
Expand Down Expand Up @@ -137,6 +158,19 @@ func BigIntClamp(n *big.Int, lowerBound *big.Int, upperBound *big.Int) *big.Int
return bigGenericClamp(n, lowerBound, upperBound)
}

// BigDivCeil returns the ceiling of `a / b`.
func BigDivCeil(a *big.Int, b *big.Int) *big.Int {
result, remainder := new(big.Int).QuoRem(a, b, new(big.Int))

// If the value was rounded (i.e. there is a remainder), and the exact result would be positive,
// then add 1 to the result.
if remainder.Sign() != 0 && (a.Sign() == b.Sign()) {
result.Add(result, big.NewInt(1))
}

return result
}

// BigRatRound takes an input and a direction to round (true for up, false for down).
// It returns the result rounded to a `*big.Int` in the specified direction.
func BigRatRound(n *big.Rat, roundUp bool) *big.Int {
Expand Down
215 changes: 215 additions & 0 deletions protocol/lib/big_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,138 @@ import (
"github.com/stretchr/testify/require"
)

func BenchmarkBigI(b *testing.B) {
var result *big.Int
for i := 0; i < b.N; i++ {
result = lib.BigI(int64(i))
}
require.Equal(b, result, result)
}
BrendanChou marked this conversation as resolved.
Show resolved Hide resolved

func BenchmarkBigU(b *testing.B) {
var result *big.Int
for i := 0; i < b.N; i++ {
result = lib.BigU(uint32(i))
}
require.Equal(b, result, result)
}
BrendanChou marked this conversation as resolved.
Show resolved Hide resolved

func TestBigI(t *testing.T) {
require.Equal(t, big.NewInt(-123), lib.BigI(int(-123)))
require.Equal(t, big.NewInt(-123), lib.BigI(int32(-123)))
require.Equal(t, big.NewInt(-123), lib.BigI(int64(-123)))
require.Equal(t, big.NewInt(math.MaxInt64), lib.BigI(math.MaxInt64))
}

func TestBigU(t *testing.T) {
require.Equal(t, big.NewInt(123), lib.BigU(uint(123)))
require.Equal(t, big.NewInt(123), lib.BigU(uint32(123)))
require.Equal(t, big.NewInt(123), lib.BigU(uint64(123)))
require.Equal(t, new(big.Int).SetUint64(math.MaxUint64), lib.BigU(uint64(math.MaxUint64)))
}

func BenchmarkBigMulPpm_RoundDown(b *testing.B) {
val := big.NewInt(543_211)
ppm := big.NewInt(876_543)
var result *big.Int
for i := 0; i < b.N; i++ {
result = lib.BigMulPpm(val, ppm, false)
}
require.Equal(b, big.NewInt(476147), result)
}

func BenchmarkBigMulPpm_RoundUp(b *testing.B) {
val := big.NewInt(543_211)
ppm := big.NewInt(876_543)
var result *big.Int
for i := 0; i < b.N; i++ {
result = lib.BigMulPpm(val, ppm, true)
}
require.Equal(b, big.NewInt(476148), result)
}

func TestBigMulPpm(t *testing.T) {
tests := map[string]struct {
val *big.Int
ppm *big.Int
roundUp bool
expectedResult *big.Int
}{
"Positive round down": {
val: big.NewInt(543_211),
ppm: big.NewInt(876_543),
roundUp: false,
expectedResult: big.NewInt(476147),
},
"Negative round down": {
val: big.NewInt(-543_211),
ppm: big.NewInt(876_543),
roundUp: false,
expectedResult: big.NewInt(-476148),
},
"Positive round up": {
val: big.NewInt(543_211),
ppm: big.NewInt(876_543),
roundUp: true,
expectedResult: big.NewInt(476148),
},
"Negative round up": {
val: big.NewInt(-543_211),
ppm: big.NewInt(876_543),
roundUp: true,
expectedResult: big.NewInt(-476147),
},
"Zero val": {
val: big.NewInt(0),
ppm: big.NewInt(876_543),
roundUp: true,
expectedResult: big.NewInt(0),
},
"Zero ppm": {
val: big.NewInt(543_211),
ppm: big.NewInt(0),
roundUp: true,
expectedResult: big.NewInt(0),
},
"Zero val and ppm": {
val: big.NewInt(0),
ppm: big.NewInt(0),
roundUp: true,
expectedResult: big.NewInt(0),
},
"Negative val": {
val: big.NewInt(-543_211),
ppm: big.NewInt(876_543),
roundUp: true,
expectedResult: big.NewInt(-476147),
},
"Negative ppm": {
val: big.NewInt(543_211),
ppm: big.NewInt(-876_543),
roundUp: true,
expectedResult: big.NewInt(-476147),
},
"Negative val and ppm": {
val: big.NewInt(-543_211),
ppm: big.NewInt(-876_543),
roundUp: true,
expectedResult: big.NewInt(476148),
},
"Greater than max int64": {
val: big_testutil.MustFirst(new(big.Int).SetString("1000000000000000000000000", 10)),
ppm: big.NewInt(10_000),
roundUp: true,
expectedResult: big_testutil.MustFirst(new(big.Int).SetString("10000000000000000000000", 10)),
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
result := lib.BigMulPpm(tc.val, tc.ppm, tc.roundUp)
require.Equal(t, tc.expectedResult, result)
})
}
}

func TestBigPow10(t *testing.T) {
tests := map[string]struct {
exponent uint64
Expand Down Expand Up @@ -533,6 +665,89 @@ func TestBigIntClamp(t *testing.T) {
}
}

func BenchmarkBigDivCeil(b *testing.B) {
numerator := big.NewInt(10)
denominator := big.NewInt(3)
var result *big.Int
for i := 0; i < b.N; i++ {
result = lib.BigDivCeil(numerator, denominator)
}
require.Equal(b, big.NewInt(4), result)
}

func TestBigDivCeil(t *testing.T) {
tests := map[string]struct {
numerator *big.Int
denominator *big.Int
expectedResult *big.Int
}{
"Divides evenly": {
numerator: big.NewInt(10),
denominator: big.NewInt(5),
expectedResult: big.NewInt(2),
},
"Doesn't divide evenly": {
numerator: big.NewInt(10),
denominator: big.NewInt(3),
expectedResult: big.NewInt(4),
},
"Negative numerator": {
numerator: big.NewInt(-10),
denominator: big.NewInt(3),
expectedResult: big.NewInt(-3),
},
"Negative numerator 2": {
numerator: big.NewInt(-1),
denominator: big.NewInt(2),
expectedResult: big.NewInt(0),
},
"Negative denominator": {
numerator: big.NewInt(10),
denominator: big.NewInt(-3),
expectedResult: big.NewInt(-3),
},
"Negative denominator 2": {
numerator: big.NewInt(1),
denominator: big.NewInt(-2),
expectedResult: big.NewInt(0),
},
"Negative numerator and denominator": {
numerator: big.NewInt(-10),
denominator: big.NewInt(-3),
expectedResult: big.NewInt(4),
},
"Negative numerator and denominator 2": {
numerator: big.NewInt(-1),
denominator: big.NewInt(-2),
expectedResult: big.NewInt(1),
},
"Zero numerator": {
numerator: big.NewInt(0),
denominator: big.NewInt(3),
expectedResult: big.NewInt(0),
},
"Zero denominator": {
numerator: big.NewInt(10),
denominator: big.NewInt(0),
expectedResult: nil,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
// Panics if the expected result is nil
if tc.expectedResult == nil {
require.Panics(t, func() {
lib.BigDivCeil(tc.numerator, tc.denominator)
})
return
}
// Otherwise test the result
result := lib.BigDivCeil(tc.numerator, tc.denominator)
require.Equal(t, tc.expectedResult, result)
})
}
}

func TestBigRatRound(t *testing.T) {
tests := map[string]struct {
input *big.Rat
Expand Down
Loading