Skip to content

Commit

Permalink
Exposed the random sources, so the user can customize according to th…
Browse files Browse the repository at this point in the history
…eir needs.

Signed-off-by: Guilherme Balena Versiani <[email protected]>
  • Loading branch information
Guilherme Versiani authored and balena committed Oct 2, 2023
1 parent b8d526d commit 8abf1d5
Show file tree
Hide file tree
Showing 49 changed files with 368 additions and 306 deletions.
7 changes: 4 additions & 3 deletions common/hash_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package common_test

import (
"crypto/rand"
"math/big"
"reflect"
"testing"
Expand All @@ -15,12 +16,12 @@ import (
)

func TestRejectionSample(t *testing.T) {
curveQ := common.GetRandomPrimeInt(256)
randomQ := common.MustGetRandomInt(64)
curveQ := common.GetRandomPrimeInt(rand.Reader, 256)
randomQ := common.MustGetRandomInt(rand.Reader, 64)
hash := common.SHA512_256iOne(big.NewInt(123))
rs1 := common.RejectionSample(curveQ, hash)
rs2 := common.RejectionSample(randomQ, hash)
rs3 := common.RejectionSample(common.MustGetRandomInt(64), hash)
rs3 := common.RejectionSample(common.MustGetRandomInt(rand.Reader, 64), hash)
type args struct {
q *big.Int
eHash *big.Int
Expand Down
33 changes: 17 additions & 16 deletions common/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
package common

import (
"crypto/rand"
cryptorand "crypto/rand"
"fmt"
"io"
"math/big"

"github.com/pkg/errors"
Expand All @@ -18,8 +19,8 @@ const (
mustGetRandomIntMaxBits = 5000
)

// MustGetRandomInt panics if it is unable to gather entropy from `rand.Reader` or when `bits` is <= 0
func MustGetRandomInt(bits int) *big.Int {
// MustGetRandomInt panics if it is unable to gather entropy from `io.Reader` or when `bits` is <= 0
func MustGetRandomInt(rand io.Reader, bits int) *big.Int {
if bits <= 0 || mustGetRandomIntMaxBits < bits {
panic(fmt.Errorf("MustGetRandomInt: bits should be positive, non-zero and less than %d", mustGetRandomIntMaxBits))
}
Expand All @@ -28,37 +29,37 @@ func MustGetRandomInt(bits int) *big.Int {
max = max.Exp(two, big.NewInt(int64(bits)), nil).Sub(max, one)

// Generate cryptographically strong pseudo-random int between 0 - max
n, err := rand.Int(rand.Reader, max)
n, err := cryptorand.Int(rand, max)
if err != nil {
panic(errors.Wrap(err, "rand.Int failure in MustGetRandomInt!"))
}
return n
}

func GetRandomPositiveInt(lessThan *big.Int) *big.Int {
func GetRandomPositiveInt(rand io.Reader, lessThan *big.Int) *big.Int {
if lessThan == nil || zero.Cmp(lessThan) != -1 {
return nil
}
var try *big.Int
for {
try = MustGetRandomInt(lessThan.BitLen())
try = MustGetRandomInt(rand, lessThan.BitLen())
if try.Cmp(lessThan) < 0 && try.Cmp(zero) >= 0 {
break
}
}
return try
}

func GetRandomPrimeInt(bits int) *big.Int {
func GetRandomPrimeInt(rand io.Reader, bits int) *big.Int {
if bits <= 0 {
return nil
}
try, err := rand.Prime(rand.Reader, bits)
try, err := cryptorand.Prime(rand, bits)
if err != nil ||
try.Cmp(zero) == 0 {
// fallback to older method
for {
try = MustGetRandomInt(bits)
try = MustGetRandomInt(rand, bits)
if probablyPrime(try) {
break
}
Expand All @@ -69,13 +70,13 @@ func GetRandomPrimeInt(bits int) *big.Int {

// Generate a random element in the group of all the elements in Z/nZ that
// has a multiplicative inverse.
func GetRandomPositiveRelativelyPrimeInt(n *big.Int) *big.Int {
func GetRandomPositiveRelativelyPrimeInt(rand io.Reader, n *big.Int) *big.Int {
if n == nil || zero.Cmp(n) != -1 {
return nil
}
var try *big.Int
for {
try = MustGetRandomInt(n.BitLen())
try = MustGetRandomInt(rand, n.BitLen())
if IsNumberInMultiplicativeGroup(n, try) {
break
}
Expand All @@ -96,24 +97,24 @@ func IsNumberInMultiplicativeGroup(n, v *big.Int) bool {
// THIS METHOD ONLY WORKS IF N IS THE PRODUCT OF TWO SAFE PRIMES!
//
// https://github.com/didiercrunch/paillier/blob/d03e8850a8e4c53d04e8016a2ce8762af3278b71/utils.go#L39
func GetRandomGeneratorOfTheQuadraticResidue(n *big.Int) *big.Int {
f := GetRandomPositiveRelativelyPrimeInt(n)
func GetRandomGeneratorOfTheQuadraticResidue(rand io.Reader, n *big.Int) *big.Int {
f := GetRandomPositiveRelativelyPrimeInt(rand, n)
fSq := new(big.Int).Mul(f, f)
return fSq.Mod(fSq, n)
}

// GetRandomQuadraticNonResidue returns a quadratic non residue of odd n.
func GetRandomQuadraticNonResidue(n *big.Int) *big.Int {
func GetRandomQuadraticNonResidue(rand io.Reader, n *big.Int) *big.Int {
for {
w := GetRandomPositiveInt(n)
w := GetRandomPositiveInt(rand, n)
if big.Jacobi(w, n) == -1 {
return w
}
}
}

// GetRandomBytes returns random bytes of length.
func GetRandomBytes(length int) ([]byte, error) {
func GetRandomBytes(rand io.Reader, length int) ([]byte, error) {
// Per [BIP32], the seed must be in range [MinSeedBytes, MaxSeedBytes].
if length <= 0 {
return nil, errors.New("invalid length")
Expand Down
13 changes: 7 additions & 6 deletions common/random_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package common_test

import (
"crypto/rand"
"math/big"
"testing"

Expand All @@ -20,28 +21,28 @@ const (
)

func TestGetRandomInt(t *testing.T) {
rnd := common.MustGetRandomInt(randomIntBitLen)
rnd := common.MustGetRandomInt(rand.Reader, randomIntBitLen)
assert.NotZero(t, rnd, "rand int should not be zero")
}

func TestGetRandomPositiveInt(t *testing.T) {
rnd := common.MustGetRandomInt(randomIntBitLen)
rndPos := common.GetRandomPositiveInt(rnd)
rnd := common.MustGetRandomInt(rand.Reader, randomIntBitLen)
rndPos := common.GetRandomPositiveInt(rand.Reader, rnd)
assert.NotZero(t, rndPos, "rand int should not be zero")
assert.True(t, rndPos.Cmp(big.NewInt(0)) == 1, "rand int should be positive")
}

func TestGetRandomPositiveRelativelyPrimeInt(t *testing.T) {
rnd := common.MustGetRandomInt(randomIntBitLen)
rndPosRP := common.GetRandomPositiveRelativelyPrimeInt(rnd)
rnd := common.MustGetRandomInt(rand.Reader, randomIntBitLen)
rndPosRP := common.GetRandomPositiveRelativelyPrimeInt(rand.Reader, rnd)
assert.NotZero(t, rndPosRP, "rand int should not be zero")
assert.True(t, common.IsNumberInMultiplicativeGroup(rnd, rndPosRP))
assert.True(t, rndPosRP.Cmp(big.NewInt(0)) == 1, "rand int should be positive")
// TODO test for relative primeness
}

func TestGetRandomPrimeInt(t *testing.T) {
prime := common.GetRandomPrimeInt(randomIntBitLen)
prime := common.GetRandomPrimeInt(rand.Reader, randomIntBitLen)
assert.NotZero(t, prime, "rand prime should not be zero")
assert.True(t, prime.ProbablyPrime(50), "rand prime should be prime")
}
63 changes: 31 additions & 32 deletions common/safe_prime.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package common

import (
"context"
"crypto/rand"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -125,7 +124,7 @@ var ErrGeneratorCancelled = fmt.Errorf("generator work cancelled")
// This function generates safe primes of at least 6 `bitLen`. For every
// generated safe prime, the two most significant bits are always set to `1`
// - we don't want the generated number to be too small.
func GetRandomSafePrimesConcurrent(ctx context.Context, bitLen, numPrimes int, concurrency int) ([]*GermainSafePrime, error) {
func GetRandomSafePrimesConcurrent(ctx context.Context, bitLen, numPrimes int, concurrency int, rand io.Reader) ([]*GermainSafePrime, error) {
if bitLen < 6 {
return nil, errors.New("safe prime size must be at least 6 bits")
}
Expand All @@ -149,7 +148,7 @@ func GetRandomSafePrimesConcurrent(ctx context.Context, bitLen, numPrimes int, c
for i := 0; i < concurrency; i++ {
waitGroup.Add(1)
runGenPrimeRoutine(
generatorCtx, primeCh, errCh, waitGroup, rand.Reader, bitLen,
generatorCtx, primeCh, errCh, waitGroup, rand, bitLen,
)
}

Expand All @@ -175,35 +174,35 @@ func GetRandomSafePrimesConcurrent(ctx context.Context, bitLen, numPrimes int, c
// a bit length equal to `pBitLen-1`.
//
// The algorithm is as follows:
// 1. Generate a random odd number `q` of length `pBitLen-1` with two the most
// significant bits set to `1`.
// 2. Execute preliminary primality test on `q` checking whether it is coprime
// to all the elements of `smallPrimes`. It allows to eliminate trivial
// cases quickly, when `q` is obviously no prime, without running an
// expensive final primality tests.
// If `q` is coprime to all of the `smallPrimes`, then go to the point 3.
// If not, add `2` and try again. Do it at most 10 times.
// 3. Check the potentially prime `q`, whether `q = 1 (mod 3)`. This will
// happen for 50% of cases.
// If it is, then `p = 2q+1` will be a multiple of 3, so it will be obviously
// not a prime number. In this case, add `2` and try again. Do it at most 10
// times. If `q != 1 (mod 3)`, go to the point 4.
// 4. Now we know `q` is potentially prime and `p = 2q+1` is not a multiple of
// 3. We execute a preliminary primality test on `p`, checking whether
// it is coprime to all the elements of `smallPrimes` just like we did for
// `q` in point 2. If `p` is not coprime to at least one element of the
// `smallPrimes`, then go back to point 1.
// If `p` is coprime to all the elements of `smallPrimes`, go to point 5.
// 5. At this point, we know `q` is potentially prime, and `p=q+1` is also
// potentially prime. We need to execute a final primality test for `q`.
// We apply Miller-Rabin and Baillie-PSW tests. If they succeed, it means
// that `q` is prime with a very high probability. Knowing `q` is prime,
// we use Pocklington's criterion to prove the primality of `p=2q+1`, that
// is, we execute Fermat primality test to base 2 checking whether
// `2^{p-1} = 1 (mod p)`. It's significantly faster than running full
// Miller-Rabin and Baillie-PSW for `p`.
// If `q` and `p` are found to be prime, return them as a result. If not, go
// back to the point 1.
// 1. Generate a random odd number `q` of length `pBitLen-1` with two the most
// significant bits set to `1`.
// 2. Execute preliminary primality test on `q` checking whether it is coprime
// to all the elements of `smallPrimes`. It allows to eliminate trivial
// cases quickly, when `q` is obviously no prime, without running an
// expensive final primality tests.
// If `q` is coprime to all of the `smallPrimes`, then go to the point 3.
// If not, add `2` and try again. Do it at most 10 times.
// 3. Check the potentially prime `q`, whether `q = 1 (mod 3)`. This will
// happen for 50% of cases.
// If it is, then `p = 2q+1` will be a multiple of 3, so it will be obviously
// not a prime number. In this case, add `2` and try again. Do it at most 10
// times. If `q != 1 (mod 3)`, go to the point 4.
// 4. Now we know `q` is potentially prime and `p = 2q+1` is not a multiple of
// 3. We execute a preliminary primality test on `p`, checking whether
// it is coprime to all the elements of `smallPrimes` just like we did for
// `q` in point 2. If `p` is not coprime to at least one element of the
// `smallPrimes`, then go back to point 1.
// If `p` is coprime to all the elements of `smallPrimes`, go to point 5.
// 5. At this point, we know `q` is potentially prime, and `p=q+1` is also
// potentially prime. We need to execute a final primality test for `q`.
// We apply Miller-Rabin and Baillie-PSW tests. If they succeed, it means
// that `q` is prime with a very high probability. Knowing `q` is prime,
// we use Pocklington's criterion to prove the primality of `p=2q+1`, that
// is, we execute Fermat primality test to base 2 checking whether
// `2^{p-1} = 1 (mod p)`. It's significantly faster than running full
// Miller-Rabin and Baillie-PSW for `p`.
// If `q` and `p` are found to be prime, return them as a result. If not, go
// back to the point 1.
func runGenPrimeRoutine(
ctx context.Context,
primeCh chan<- *GermainSafePrime,
Expand Down
3 changes: 2 additions & 1 deletion common/safe_prime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package common

import (
"context"
"crypto/rand"
"math/big"
"runtime"
"testing"
Expand Down Expand Up @@ -45,7 +46,7 @@ func Test_Validate_Bad(t *testing.T) {
func TestGetRandomGermainPrimeConcurrent(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute)
defer cancel()
sgps, err := GetRandomSafePrimesConcurrent(ctx, 1024, 2, runtime.NumCPU())
sgps, err := GetRandomSafePrimesConcurrent(ctx, 1024, 2, runtime.NumCPU(), rand.Reader)
assert.NoError(t, err)
assert.Equal(t, 2, len(sgps))
for _, sgp := range sgps {
Expand Down
5 changes: 3 additions & 2 deletions crypto/commitments/commitment.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package commitments

import (
"io"
"math/big"

"github.com/bnb-chain/tss-lib/v2/common"
Expand Down Expand Up @@ -43,8 +44,8 @@ func NewHashCommitmentWithRandomness(r *big.Int, secrets ...*big.Int) *HashCommi
return cmt
}

func NewHashCommitment(secrets ...*big.Int) *HashCommitDecommit {
r := common.MustGetRandomInt(HashLength) // r
func NewHashCommitment(rand io.Reader, secrets ...*big.Int) *HashCommitDecommit {
r := common.MustGetRandomInt(rand, HashLength) // r
return NewHashCommitmentWithRandomness(r, secrets...)
}

Expand Down
5 changes: 3 additions & 2 deletions crypto/commitments/commitment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package commitments_test

import (
"crypto/rand"
"math/big"
"testing"

Expand All @@ -19,7 +20,7 @@ func TestCreateVerify(t *testing.T) {
one := big.NewInt(1)
zero := big.NewInt(0)

commitment := NewHashCommitment(zero, one)
commitment := NewHashCommitment(rand.Reader, zero, one)
pass := commitment.Verify()

assert.True(t, pass, "must pass")
Expand All @@ -29,7 +30,7 @@ func TestDeCommit(t *testing.T) {
one := big.NewInt(1)
zero := big.NewInt(0)

commitment := NewHashCommitment(zero, one)
commitment := NewHashCommitment(rand.Reader, zero, one)
pass, secrets := commitment.DeCommit()

assert.True(t, pass, "must pass")
Expand Down
9 changes: 4 additions & 5 deletions crypto/dlnproof/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package dlnproof

import (
"fmt"
"io"
"math/big"

"github.com/bnb-chain/tss-lib/v2/common"
Expand All @@ -28,17 +29,15 @@ type (
}
)

var (
one = big.NewInt(1)
)
var one = big.NewInt(1)

func NewDLNProof(h1, h2, x, p, q, N *big.Int) *Proof {
func NewDLNProof(h1, h2, x, p, q, N *big.Int, rand io.Reader) *Proof {
pMulQ := new(big.Int).Mul(p, q)
modN, modPQ := common.ModInt(N), common.ModInt(pMulQ)
a := make([]*big.Int, Iterations)
alpha := [Iterations]*big.Int{}
for i := range alpha {
a[i] = common.GetRandomPositiveInt(pMulQ)
a[i] = common.GetRandomPositiveInt(rand, pMulQ)
alpha[i] = modN.Exp(h1, a[i])
}
msg := append([]*big.Int{h1, h2, N}, alpha[:]...)
Expand Down
Loading

0 comments on commit 8abf1d5

Please sign in to comment.