Skip to content

Commit

Permalink
Reduce usage of big.Rat
Browse files Browse the repository at this point in the history
  • Loading branch information
BrendanChou committed May 20, 2024
1 parent 722ec24 commit f004af6
Show file tree
Hide file tree
Showing 16 changed files with 221 additions and 115 deletions.
65 changes: 31 additions & 34 deletions protocol/lib/big_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,9 @@ func BigIntMulPpm(input *big.Int, ppm uint32) *big.Int {

// BigIntMulSignedPpm takes a `big.Int` and returns the result of `input * ppm / 1_000_000`.
func BigIntMulSignedPpm(input *big.Int, ppm int32, roundUp bool) *big.Int {
result := new(big.Rat)
result.Mul(
new(big.Rat).SetInt(input),
new(big.Rat).SetInt64(int64(ppm)),
)
result.Quo(result, BigRatOneMillion())
return BigRatRound(result, roundUp)
r := big.NewInt(int64(ppm))
r.Mul(r, input)
return BigIntDivRound(r, BigIntOneMillion(), roundUp)
}

// BigMin takes two `big.Int` as parameters and returns the smaller one.
Expand Down Expand Up @@ -137,22 +133,17 @@ func BigIntClamp(n *big.Int, lowerBound *big.Int, upperBound *big.Int) *big.Int
return bigGenericClamp(n, lowerBound, upperBound)
}

// BigRatRound takes an input and a direction to round (true for up, false for down).
// It returns the result rounded to a `*big.Int` in the specified direction.
func BigRatRound(n *big.Rat, roundUp bool) *big.Int {
numeratorBig := n.Num()
denominatorBig := n.Denom()
resultBig, remainderBig := new(big.Int).DivMod(numeratorBig, denominatorBig, new(big.Int))
// If the remainder is non-zero, then round up by adding 1.
// Note this works for negative numbers due to the following reasons:
// - In euclidean division, the remainder is always positive so the resulting division rounds
// down instead of towards zero.
// - The denominator of `big.Rat` is always positive. Therefore if `n` is negative, that means
// the numerator is negative.
if remainderBig.Sign() > 0 && roundUp {
resultBig.Add(resultBig, big.NewInt(1))
// BigIntDivRound takes a numerator, denominator, and a direction to round (true for up, false for down).
func BigIntDivRound(
numerator *big.Int,
denominator *big.Int,
roundUp bool,
) *big.Int {
result, remainder := new(big.Int).DivMod(numerator, denominator, new(big.Int))
if remainder.Sign() > 0 && roundUp {
result.Add(result, big.NewInt(1))
}
return resultBig
return result
}

// BigIntRoundToMultiple takes an input, a multiple, and a direction to round (true for up,
Expand Down Expand Up @@ -262,26 +253,32 @@ func warmCache() map[uint64]*big.Int {
return bigExponentValues
}

// BigRatRoundToNearestMultiple rounds `value` up/down to the nearest multiple of `base`.
// BigRoundToNearestMultiple rounds `value` up/down to the nearest multiple of `base`.
// Bounds the result between 0 and `math.MaxUint64`.
// Returns 0 if `base` is 0.
func BigRatRoundToNearestMultiple(
value *big.Rat,
func BigRoundToNearestMultiple(
value *big.Int,
base uint32,
up bool,
) uint64 {
// Special-case for zero.
if base == 0 {
return 0
}

quotient := new(big.Rat).Quo(
value,
new(big.Rat).SetUint64(uint64(base)),
)
quotientFloored := new(big.Int).Div(quotient.Num(), quotient.Denom())

if up && quotientFloored.Cmp(quotient.Num()) != 0 {
return (quotientFloored.Uint64() + 1) * uint64(base)
// Set up variables.
baseBig := new(big.Int).SetUint64(uint64(base))
result := new(big.Int).Set(value)
if up {
result.Add(result, new(big.Int).Sub(baseBig, big.NewInt(1)))
}

return quotientFloored.Uint64() * uint64(base)
// Clamp result to prevent overflow.
result = BigIntClamp(result, big.NewInt(0), new(big.Int).SetUint64(math.MaxUint64))

// Do the division, rounding down (since we added `base - 1` if we wanted to round up).
result.Div(result, baseBig)

// Multiply back in.
return result.Mul(result, baseBig).Uint64()
}
111 changes: 76 additions & 35 deletions protocol/lib/big_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,76 +533,89 @@ func TestBigIntClamp(t *testing.T) {
}
}

func TestBigRatRound(t *testing.T) {
func TestBigIntDivRound(t *testing.T) {
tests := map[string]struct {
input *big.Rat
num *big.Int
den *big.Int
roundUp bool
expectedResult *big.Int
}{
"Input is unrounded if it is zero and we round up": {
input: big.NewRat(0, 1),
num: big.NewInt(0),
den: big.NewInt(1),
roundUp: true,
expectedResult: big.NewInt(0),
},
"Input is unrounded if it is zero and we round down": {
input: big.NewRat(0, 1),
num: big.NewInt(0),
den: big.NewInt(1),
roundUp: false,
expectedResult: big.NewInt(0),
},
"Input is unrounded if it is an int and we round up": {
input: big.NewRat(2, 1),
num: big.NewInt(2),
den: big.NewInt(1),
roundUp: true,
expectedResult: big.NewInt(2),
},
"Input is unrounded if it is an int and we round down": {
input: big.NewRat(2, 1),
num: big.NewInt(2),
den: big.NewInt(1),
roundUp: false,
expectedResult: big.NewInt(2),
},
"Input is unrounded if it isn't normalized, it is an int and we round up": {
input: big.NewRat(21, 3),
num: big.NewInt(21),
den: big.NewInt(3),
roundUp: true,
expectedResult: big.NewInt(7),
},
"Input is unrounded if it isn't normalized, it is an int and we round down": {
input: big.NewRat(21, 3),
num: big.NewInt(21),
den: big.NewInt(3),
roundUp: false,
expectedResult: big.NewInt(7),
},
"Input is rounded up if we round up": {
input: big.NewRat(5, 4),
num: big.NewInt(5),
den: big.NewInt(4),
roundUp: true,
expectedResult: big.NewInt(2),
},
"Input is rounded up if it isn't normalized and we round up": {
input: big.NewRat(10, 4),
num: big.NewInt(10),
den: big.NewInt(4),
roundUp: true,
expectedResult: big.NewInt(3),
},
"Input is rounded down if rational number isn't normalized and we round down": {
input: big.NewRat(10, 4),
num: big.NewInt(10),
den: big.NewInt(4),
roundUp: false,
expectedResult: big.NewInt(2),
},
"Input is rounded down if we round down": {
input: big.NewRat(5, 4),
num: big.NewInt(5),
den: big.NewInt(4),
roundUp: false,
expectedResult: big.NewInt(1),
},
"Input is rounded down if input is negative and we round down": {
input: big.NewRat(-22, 7),
num: big.NewInt(-22),
den: big.NewInt(7),
roundUp: false,
expectedResult: big.NewInt(-4),
},
"Input is rounded up if input is negative and we round up": {
input: big.NewRat(-22, 7),
num: big.NewInt(-22),
den: big.NewInt(7),
roundUp: true,
expectedResult: big.NewInt(-3),
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
result := lib.BigRatRound(tc.input, tc.roundUp)
result := lib.BigIntDivRound(tc.num, tc.den, tc.roundUp)
require.Equal(t, tc.expectedResult, result)
})
}
Expand Down Expand Up @@ -900,89 +913,117 @@ func TestMustConvertBigIntToInt32(t *testing.T) {
}
}

func TestBigRatRoundToNearestMultiple(t *testing.T) {
func TestBigRoundToNearestMultiple(t *testing.T) {
tests := map[string]struct {
value *big.Rat
value *big.Int
base uint32
up bool
expectedResult uint64
}{
"Round 5 down to a multiple of 2": {
value: big.NewRat(5, 1),
value: big.NewInt(5),
base: 2,
up: false,
expectedResult: 4,
},
"Round 5 up to a multiple of 2": {
value: big.NewRat(5, 1),
value: big.NewInt(5),
base: 2,
up: true,
expectedResult: 6,
},
"Round 7 down to a multiple of 14": {
value: big.NewRat(7, 1),
value: big.NewInt(7),
base: 14,
up: false,
expectedResult: 0,
},
"Round 7 up to a multiple of 14": {
value: big.NewRat(7, 1),
value: big.NewInt(7),
base: 14,
up: true,
expectedResult: 14,
},
"Round 123 down to a multiple of 123": {
value: big.NewRat(123, 1),
value: big.NewInt(123),
base: 123,
up: false,
expectedResult: 123,
},
"Round 123 up to a multiple of 123": {
value: big.NewRat(123, 1),
value: big.NewInt(123),
base: 123,
up: true,
expectedResult: 123,
},
"Round 100/6 down to a multiple of 3": {
value: big.NewRat(100, 6),
"Round 16 down to a multiple of 3": {
value: big.NewInt(16),
base: 3,
up: false,
expectedResult: 15,
},
"Round 100/6 up to a multiple of 3": {
value: big.NewRat(100, 6),
"Round 16 up to a multiple of 3": {
value: big.NewInt(16),
base: 3,
up: true,
expectedResult: 18,
},
"Round 7/2 down to a multiple of 1": {
value: big.NewRat(7, 2),
"Round -16 down to a multiple of 3, is clamped to zero": {
value: big.NewInt(-16),
base: 3,
up: false,
expectedResult: 0,
},
"Round -16 up to a multiple of 3, is clamped to zero": {
value: big.NewInt(-16),
base: 3,
up: true,
expectedResult: 0,
},
"Round 4 down to a multiple of 1": {
value: big.NewInt(4),
base: 1,
up: false,
expectedResult: 3,
expectedResult: 4,
},
"Round 7/2 up to a multiple of 1": {
value: big.NewRat(7, 2),
"Round 4 up to a multiple of 1": {
value: big.NewInt(4),
base: 1,
up: true,
expectedResult: 4,
},
"Round 10 down to a multiple of 0": {
value: big.NewRat(10, 1),
value: big.NewInt(10),
base: 0,
up: false,
expectedResult: 0,
},
"Round 10 up to a multiple of 0": {
value: big.NewRat(10, 1),
value: big.NewInt(10),
base: 0,
up: true,
expectedResult: 0,
},
"Check overflow is clamped when rounding down": {
value: big_testutil.MustFirst(
new(big.Int).SetString("99999999999999999999", 10),
),
base: 100,
up: false,
expectedResult: 18446744073709551600,
},
"Check overflow is clamped when rounding up": {
value: big_testutil.MustFirst(
new(big.Int).SetString("99999999999999999999", 10),
),
base: 100,
up: true,
expectedResult: 18446744073709551600,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
result := lib.BigRatRoundToNearestMultiple(
result := lib.BigRoundToNearestMultiple(
tc.value,
tc.base,
tc.up,
Expand Down
7 changes: 5 additions & 2 deletions protocol/lib/quantums.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,11 @@ func QuoteToBaseQuantums(
RatPow10(exponent),
)

// Round down.
bigBaseQuantums := BigRatRound(ratBaseQuantums, false)
// Round down using Euclidean division.
bigBaseQuantums := new(big.Int).Div(
ratBaseQuantums.Num(),
ratBaseQuantums.Denom(),
)

// Flip the sign of the return value if necessary.
if !isLong {
Expand Down
6 changes: 5 additions & 1 deletion protocol/x/assets/keeper/asset.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,11 @@ func (k Keeper) ConvertAssetToCoin(
)

// round down to get denom amount that was converted.
bigConvertedDenomAmount := lib.BigRatRound(bigRatDenomAmount, false)
bigConvertedDenomAmount := lib.BigIntDivRound(
bigRatDenomAmount.Num(),
bigRatDenomAmount.Denom(),
false,
)

bigRatConvertedQuantums := lib.BigMulPow10(
bigConvertedDenomAmount,
Expand Down
7 changes: 6 additions & 1 deletion protocol/x/clob/keeper/clob_pair.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,12 @@ func (k Keeper) validateOrderAgainstClobPairStatus(
// the oracle price. For instance without this check a user could place an ask far below the oracle
// price, thereby preventing any bids at or above the specified price of the ask.
currentOraclePriceSubticksRat := k.GetOraclePriceSubticksRat(ctx, clobPair)
currentOraclePriceSubticks := lib.BigRatRound(currentOraclePriceSubticksRat, false).Uint64()
currentOraclePriceSubticks := lib.BigIntDivRound(
currentOraclePriceSubticksRat.Num(),
currentOraclePriceSubticksRat.Denom(),
false,
).Uint64()

// Throw error if order is a buy and order subticks is greater than oracle price subticks
if order.IsBuy() && order.Subticks > currentOraclePriceSubticks {
return errorsmod.Wrapf(
Expand Down
6 changes: 5 additions & 1 deletion protocol/x/clob/keeper/liquidations.go
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,11 @@ func (k Keeper) ConvertFillablePriceToSubticks(
roundUp := isLiquidatingLong

// Round the subticks to the nearest `big.Int` in the correct direction.
roundedSubticksBig := lib.BigRatRound(subticksRat, roundUp)
roundedSubticksBig := lib.BigIntDivRound(
subticksRat.Num(),
subticksRat.Denom(),
roundUp,
)

// Ensure `roundedSubticksBig % clobPair.SubticksPerTick == 0`, rounding in the correct
// direction if necessary.
Expand Down
Loading

0 comments on commit f004af6

Please sign in to comment.