From 8abf1d5952761c12988e45ee272af9db947f0702 Mon Sep 17 00:00:00 2001 From: Guilherme Versiani Date: Fri, 29 Sep 2023 21:38:10 +0000 Subject: [PATCH] Exposed the random sources, so the user can customize according to their needs. Signed-off-by: Guilherme Balena Versiani --- common/hash_utils_test.go | 7 +-- common/random.go | 33 +++++++------- common/random_test.go | 13 +++--- common/safe_prime.go | 63 +++++++++++++-------------- common/safe_prime_test.go | 3 +- crypto/commitments/commitment.go | 5 ++- crypto/commitments/commitment_test.go | 5 ++- crypto/dlnproof/proof.go | 9 ++-- crypto/facproof/proof.go | 19 ++++---- crypto/facproof/proof_test.go | 21 +++++---- crypto/modproof/proof.go | 9 ++-- crypto/modproof/proof_test.go | 7 ++- crypto/mta/proofs.go | 21 ++++----- crypto/mta/range_proof.go | 11 ++--- crypto/mta/range_proof_test.go | 47 ++++++++++---------- crypto/mta/share_protocol.go | 20 +++++---- crypto/mta/share_protocol_test.go | 25 +++++------ crypto/paillier/paillier.go | 13 +++--- crypto/paillier/paillier_test.go | 33 +++++++------- crypto/schnorr/schnorr_proof.go | 9 ++-- crypto/schnorr/schnorr_proof_test.go | 45 ++++++++++--------- crypto/utils.go | 7 +-- crypto/vss/feldman_vss.go | 9 ++-- crypto/vss/feldman_vss_test.go | 23 +++++----- ecdsa/keygen/dln_verifier_test.go | 3 ++ ecdsa/keygen/local_party.go | 6 ++- ecdsa/keygen/prepare.go | 18 ++++++-- ecdsa/keygen/round_1.go | 31 +++++++------ ecdsa/keygen/round_2.go | 16 +++---- ecdsa/resharing/round_1_old_step_1.go | 7 +-- ecdsa/resharing/round_2_new_step_1.go | 13 +++--- ecdsa/resharing/round_4_new_step_2.go | 11 ++--- ecdsa/signing/local_party.go | 9 ++-- ecdsa/signing/local_party_test.go | 3 +- ecdsa/signing/round_1.go | 15 +++---- ecdsa/signing/round_2.go | 8 +++- ecdsa/signing/round_4.go | 2 +- ecdsa/signing/round_5.go | 6 +-- ecdsa/signing/round_6.go | 4 +- ecdsa/signing/round_7.go | 2 +- eddsa/keygen/round_1.go | 13 +++--- eddsa/keygen/round_2.go | 2 +- eddsa/resharing/round_1_old_step_1.go | 7 +-- eddsa/signing/round_1.go | 7 +-- eddsa/signing/round_2.go | 2 +- eddsa/signing/round_3.go | 2 +- eddsa/signing/utils.go | 5 ++- tss/params.go | 22 ++++++++++ tss/party_id.go | 3 +- 49 files changed, 368 insertions(+), 306 deletions(-) diff --git a/common/hash_utils_test.go b/common/hash_utils_test.go index c0fdce2c..b89310e0 100644 --- a/common/hash_utils_test.go +++ b/common/hash_utils_test.go @@ -7,6 +7,7 @@ package common_test import ( + "crypto/rand" "math/big" "reflect" "testing" @@ -15,12 +16,12 @@ import ( ) func TestRejectionSample(t *testing.T) { - curveQ := common.GetRandomPrimeInt(256) - randomQ := common.MustGetRandomInt(64) + curveQ := common.GetRandomPrimeInt(rand.Reader, 256) + randomQ := common.MustGetRandomInt(rand.Reader, 64) hash := common.SHA512_256iOne(big.NewInt(123)) rs1 := common.RejectionSample(curveQ, hash) rs2 := common.RejectionSample(randomQ, hash) - rs3 := common.RejectionSample(common.MustGetRandomInt(64), hash) + rs3 := common.RejectionSample(common.MustGetRandomInt(rand.Reader, 64), hash) type args struct { q *big.Int eHash *big.Int diff --git a/common/random.go b/common/random.go index 333cf549..9396d67d 100644 --- a/common/random.go +++ b/common/random.go @@ -7,8 +7,9 @@ package common import ( - "crypto/rand" + cryptorand "crypto/rand" "fmt" + "io" "math/big" "github.com/pkg/errors" @@ -18,8 +19,8 @@ const ( mustGetRandomIntMaxBits = 5000 ) -// MustGetRandomInt panics if it is unable to gather entropy from `rand.Reader` or when `bits` is <= 0 -func MustGetRandomInt(bits int) *big.Int { +// MustGetRandomInt panics if it is unable to gather entropy from `io.Reader` or when `bits` is <= 0 +func MustGetRandomInt(rand io.Reader, bits int) *big.Int { if bits <= 0 || mustGetRandomIntMaxBits < bits { panic(fmt.Errorf("MustGetRandomInt: bits should be positive, non-zero and less than %d", mustGetRandomIntMaxBits)) } @@ -28,20 +29,20 @@ func MustGetRandomInt(bits int) *big.Int { max = max.Exp(two, big.NewInt(int64(bits)), nil).Sub(max, one) // Generate cryptographically strong pseudo-random int between 0 - max - n, err := rand.Int(rand.Reader, max) + n, err := cryptorand.Int(rand, max) if err != nil { panic(errors.Wrap(err, "rand.Int failure in MustGetRandomInt!")) } return n } -func GetRandomPositiveInt(lessThan *big.Int) *big.Int { +func GetRandomPositiveInt(rand io.Reader, lessThan *big.Int) *big.Int { if lessThan == nil || zero.Cmp(lessThan) != -1 { return nil } var try *big.Int for { - try = MustGetRandomInt(lessThan.BitLen()) + try = MustGetRandomInt(rand, lessThan.BitLen()) if try.Cmp(lessThan) < 0 && try.Cmp(zero) >= 0 { break } @@ -49,16 +50,16 @@ func GetRandomPositiveInt(lessThan *big.Int) *big.Int { return try } -func GetRandomPrimeInt(bits int) *big.Int { +func GetRandomPrimeInt(rand io.Reader, bits int) *big.Int { if bits <= 0 { return nil } - try, err := rand.Prime(rand.Reader, bits) + try, err := cryptorand.Prime(rand, bits) if err != nil || try.Cmp(zero) == 0 { // fallback to older method for { - try = MustGetRandomInt(bits) + try = MustGetRandomInt(rand, bits) if probablyPrime(try) { break } @@ -69,13 +70,13 @@ func GetRandomPrimeInt(bits int) *big.Int { // Generate a random element in the group of all the elements in Z/nZ that // has a multiplicative inverse. -func GetRandomPositiveRelativelyPrimeInt(n *big.Int) *big.Int { +func GetRandomPositiveRelativelyPrimeInt(rand io.Reader, n *big.Int) *big.Int { if n == nil || zero.Cmp(n) != -1 { return nil } var try *big.Int for { - try = MustGetRandomInt(n.BitLen()) + try = MustGetRandomInt(rand, n.BitLen()) if IsNumberInMultiplicativeGroup(n, try) { break } @@ -96,16 +97,16 @@ func IsNumberInMultiplicativeGroup(n, v *big.Int) bool { // THIS METHOD ONLY WORKS IF N IS THE PRODUCT OF TWO SAFE PRIMES! // // https://github.com/didiercrunch/paillier/blob/d03e8850a8e4c53d04e8016a2ce8762af3278b71/utils.go#L39 -func GetRandomGeneratorOfTheQuadraticResidue(n *big.Int) *big.Int { - f := GetRandomPositiveRelativelyPrimeInt(n) +func GetRandomGeneratorOfTheQuadraticResidue(rand io.Reader, n *big.Int) *big.Int { + f := GetRandomPositiveRelativelyPrimeInt(rand, n) fSq := new(big.Int).Mul(f, f) return fSq.Mod(fSq, n) } // GetRandomQuadraticNonResidue returns a quadratic non residue of odd n. -func GetRandomQuadraticNonResidue(n *big.Int) *big.Int { +func GetRandomQuadraticNonResidue(rand io.Reader, n *big.Int) *big.Int { for { - w := GetRandomPositiveInt(n) + w := GetRandomPositiveInt(rand, n) if big.Jacobi(w, n) == -1 { return w } @@ -113,7 +114,7 @@ func GetRandomQuadraticNonResidue(n *big.Int) *big.Int { } // GetRandomBytes returns random bytes of length. -func GetRandomBytes(length int) ([]byte, error) { +func GetRandomBytes(rand io.Reader, length int) ([]byte, error) { // Per [BIP32], the seed must be in range [MinSeedBytes, MaxSeedBytes]. if length <= 0 { return nil, errors.New("invalid length") diff --git a/common/random_test.go b/common/random_test.go index b7ca5f95..ac898c8b 100644 --- a/common/random_test.go +++ b/common/random_test.go @@ -7,6 +7,7 @@ package common_test import ( + "crypto/rand" "math/big" "testing" @@ -20,20 +21,20 @@ const ( ) func TestGetRandomInt(t *testing.T) { - rnd := common.MustGetRandomInt(randomIntBitLen) + rnd := common.MustGetRandomInt(rand.Reader, randomIntBitLen) assert.NotZero(t, rnd, "rand int should not be zero") } func TestGetRandomPositiveInt(t *testing.T) { - rnd := common.MustGetRandomInt(randomIntBitLen) - rndPos := common.GetRandomPositiveInt(rnd) + rnd := common.MustGetRandomInt(rand.Reader, randomIntBitLen) + rndPos := common.GetRandomPositiveInt(rand.Reader, rnd) assert.NotZero(t, rndPos, "rand int should not be zero") assert.True(t, rndPos.Cmp(big.NewInt(0)) == 1, "rand int should be positive") } func TestGetRandomPositiveRelativelyPrimeInt(t *testing.T) { - rnd := common.MustGetRandomInt(randomIntBitLen) - rndPosRP := common.GetRandomPositiveRelativelyPrimeInt(rnd) + rnd := common.MustGetRandomInt(rand.Reader, randomIntBitLen) + rndPosRP := common.GetRandomPositiveRelativelyPrimeInt(rand.Reader, rnd) assert.NotZero(t, rndPosRP, "rand int should not be zero") assert.True(t, common.IsNumberInMultiplicativeGroup(rnd, rndPosRP)) assert.True(t, rndPosRP.Cmp(big.NewInt(0)) == 1, "rand int should be positive") @@ -41,7 +42,7 @@ func TestGetRandomPositiveRelativelyPrimeInt(t *testing.T) { } func TestGetRandomPrimeInt(t *testing.T) { - prime := common.GetRandomPrimeInt(randomIntBitLen) + prime := common.GetRandomPrimeInt(rand.Reader, randomIntBitLen) assert.NotZero(t, prime, "rand prime should not be zero") assert.True(t, prime.ProbablyPrime(50), "rand prime should be prime") } diff --git a/common/safe_prime.go b/common/safe_prime.go index fad2d1f1..51b8adcf 100644 --- a/common/safe_prime.go +++ b/common/safe_prime.go @@ -8,7 +8,6 @@ package common import ( "context" - "crypto/rand" "errors" "fmt" "io" @@ -125,7 +124,7 @@ var ErrGeneratorCancelled = fmt.Errorf("generator work cancelled") // This function generates safe primes of at least 6 `bitLen`. For every // generated safe prime, the two most significant bits are always set to `1` // - we don't want the generated number to be too small. -func GetRandomSafePrimesConcurrent(ctx context.Context, bitLen, numPrimes int, concurrency int) ([]*GermainSafePrime, error) { +func GetRandomSafePrimesConcurrent(ctx context.Context, bitLen, numPrimes int, concurrency int, rand io.Reader) ([]*GermainSafePrime, error) { if bitLen < 6 { return nil, errors.New("safe prime size must be at least 6 bits") } @@ -149,7 +148,7 @@ func GetRandomSafePrimesConcurrent(ctx context.Context, bitLen, numPrimes int, c for i := 0; i < concurrency; i++ { waitGroup.Add(1) runGenPrimeRoutine( - generatorCtx, primeCh, errCh, waitGroup, rand.Reader, bitLen, + generatorCtx, primeCh, errCh, waitGroup, rand, bitLen, ) } @@ -175,35 +174,35 @@ func GetRandomSafePrimesConcurrent(ctx context.Context, bitLen, numPrimes int, c // a bit length equal to `pBitLen-1`. // // The algorithm is as follows: -// 1. Generate a random odd number `q` of length `pBitLen-1` with two the most -// significant bits set to `1`. -// 2. Execute preliminary primality test on `q` checking whether it is coprime -// to all the elements of `smallPrimes`. It allows to eliminate trivial -// cases quickly, when `q` is obviously no prime, without running an -// expensive final primality tests. -// If `q` is coprime to all of the `smallPrimes`, then go to the point 3. -// If not, add `2` and try again. Do it at most 10 times. -// 3. Check the potentially prime `q`, whether `q = 1 (mod 3)`. This will -// happen for 50% of cases. -// If it is, then `p = 2q+1` will be a multiple of 3, so it will be obviously -// not a prime number. In this case, add `2` and try again. Do it at most 10 -// times. If `q != 1 (mod 3)`, go to the point 4. -// 4. Now we know `q` is potentially prime and `p = 2q+1` is not a multiple of -// 3. We execute a preliminary primality test on `p`, checking whether -// it is coprime to all the elements of `smallPrimes` just like we did for -// `q` in point 2. If `p` is not coprime to at least one element of the -// `smallPrimes`, then go back to point 1. -// If `p` is coprime to all the elements of `smallPrimes`, go to point 5. -// 5. At this point, we know `q` is potentially prime, and `p=q+1` is also -// potentially prime. We need to execute a final primality test for `q`. -// We apply Miller-Rabin and Baillie-PSW tests. If they succeed, it means -// that `q` is prime with a very high probability. Knowing `q` is prime, -// we use Pocklington's criterion to prove the primality of `p=2q+1`, that -// is, we execute Fermat primality test to base 2 checking whether -// `2^{p-1} = 1 (mod p)`. It's significantly faster than running full -// Miller-Rabin and Baillie-PSW for `p`. -// If `q` and `p` are found to be prime, return them as a result. If not, go -// back to the point 1. +// 1. Generate a random odd number `q` of length `pBitLen-1` with two the most +// significant bits set to `1`. +// 2. Execute preliminary primality test on `q` checking whether it is coprime +// to all the elements of `smallPrimes`. It allows to eliminate trivial +// cases quickly, when `q` is obviously no prime, without running an +// expensive final primality tests. +// If `q` is coprime to all of the `smallPrimes`, then go to the point 3. +// If not, add `2` and try again. Do it at most 10 times. +// 3. Check the potentially prime `q`, whether `q = 1 (mod 3)`. This will +// happen for 50% of cases. +// If it is, then `p = 2q+1` will be a multiple of 3, so it will be obviously +// not a prime number. In this case, add `2` and try again. Do it at most 10 +// times. If `q != 1 (mod 3)`, go to the point 4. +// 4. Now we know `q` is potentially prime and `p = 2q+1` is not a multiple of +// 3. We execute a preliminary primality test on `p`, checking whether +// it is coprime to all the elements of `smallPrimes` just like we did for +// `q` in point 2. If `p` is not coprime to at least one element of the +// `smallPrimes`, then go back to point 1. +// If `p` is coprime to all the elements of `smallPrimes`, go to point 5. +// 5. At this point, we know `q` is potentially prime, and `p=q+1` is also +// potentially prime. We need to execute a final primality test for `q`. +// We apply Miller-Rabin and Baillie-PSW tests. If they succeed, it means +// that `q` is prime with a very high probability. Knowing `q` is prime, +// we use Pocklington's criterion to prove the primality of `p=2q+1`, that +// is, we execute Fermat primality test to base 2 checking whether +// `2^{p-1} = 1 (mod p)`. It's significantly faster than running full +// Miller-Rabin and Baillie-PSW for `p`. +// If `q` and `p` are found to be prime, return them as a result. If not, go +// back to the point 1. func runGenPrimeRoutine( ctx context.Context, primeCh chan<- *GermainSafePrime, diff --git a/common/safe_prime_test.go b/common/safe_prime_test.go index 083c736b..2f40d134 100644 --- a/common/safe_prime_test.go +++ b/common/safe_prime_test.go @@ -8,6 +8,7 @@ package common import ( "context" + "crypto/rand" "math/big" "runtime" "testing" @@ -45,7 +46,7 @@ func Test_Validate_Bad(t *testing.T) { func TestGetRandomGermainPrimeConcurrent(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute) defer cancel() - sgps, err := GetRandomSafePrimesConcurrent(ctx, 1024, 2, runtime.NumCPU()) + sgps, err := GetRandomSafePrimesConcurrent(ctx, 1024, 2, runtime.NumCPU(), rand.Reader) assert.NoError(t, err) assert.Equal(t, 2, len(sgps)) for _, sgp := range sgps { diff --git a/crypto/commitments/commitment.go b/crypto/commitments/commitment.go index d273fe74..6d7fa836 100644 --- a/crypto/commitments/commitment.go +++ b/crypto/commitments/commitment.go @@ -10,6 +10,7 @@ package commitments import ( + "io" "math/big" "github.com/bnb-chain/tss-lib/v2/common" @@ -43,8 +44,8 @@ func NewHashCommitmentWithRandomness(r *big.Int, secrets ...*big.Int) *HashCommi return cmt } -func NewHashCommitment(secrets ...*big.Int) *HashCommitDecommit { - r := common.MustGetRandomInt(HashLength) // r +func NewHashCommitment(rand io.Reader, secrets ...*big.Int) *HashCommitDecommit { + r := common.MustGetRandomInt(rand, HashLength) // r return NewHashCommitmentWithRandomness(r, secrets...) } diff --git a/crypto/commitments/commitment_test.go b/crypto/commitments/commitment_test.go index 3dedf214..f400528c 100644 --- a/crypto/commitments/commitment_test.go +++ b/crypto/commitments/commitment_test.go @@ -7,6 +7,7 @@ package commitments_test import ( + "crypto/rand" "math/big" "testing" @@ -19,7 +20,7 @@ func TestCreateVerify(t *testing.T) { one := big.NewInt(1) zero := big.NewInt(0) - commitment := NewHashCommitment(zero, one) + commitment := NewHashCommitment(rand.Reader, zero, one) pass := commitment.Verify() assert.True(t, pass, "must pass") @@ -29,7 +30,7 @@ func TestDeCommit(t *testing.T) { one := big.NewInt(1) zero := big.NewInt(0) - commitment := NewHashCommitment(zero, one) + commitment := NewHashCommitment(rand.Reader, zero, one) pass, secrets := commitment.DeCommit() assert.True(t, pass, "must pass") diff --git a/crypto/dlnproof/proof.go b/crypto/dlnproof/proof.go index 43dd1556..e41dbeff 100644 --- a/crypto/dlnproof/proof.go +++ b/crypto/dlnproof/proof.go @@ -13,6 +13,7 @@ package dlnproof import ( "fmt" + "io" "math/big" "github.com/bnb-chain/tss-lib/v2/common" @@ -28,17 +29,15 @@ type ( } ) -var ( - one = big.NewInt(1) -) +var one = big.NewInt(1) -func NewDLNProof(h1, h2, x, p, q, N *big.Int) *Proof { +func NewDLNProof(h1, h2, x, p, q, N *big.Int, rand io.Reader) *Proof { pMulQ := new(big.Int).Mul(p, q) modN, modPQ := common.ModInt(N), common.ModInt(pMulQ) a := make([]*big.Int, Iterations) alpha := [Iterations]*big.Int{} for i := range alpha { - a[i] = common.GetRandomPositiveInt(pMulQ) + a[i] = common.GetRandomPositiveInt(rand, pMulQ) alpha[i] = modN.Exp(h1, a[i]) } msg := append([]*big.Int{h1, h2, N}, alpha[:]...) diff --git a/crypto/facproof/proof.go b/crypto/facproof/proof.go index 32ac00cf..f2c7ac56 100644 --- a/crypto/facproof/proof.go +++ b/crypto/facproof/proof.go @@ -10,6 +10,7 @@ import ( "crypto/elliptic" "errors" "fmt" + "io" "math/big" "github.com/bnb-chain/tss-lib/v2/common" @@ -32,7 +33,7 @@ var ( ) // NewProof implements prooffac -func NewProof(Session []byte, ec elliptic.Curve, N0, NCap, s, t, N0p, N0q *big.Int) (*ProofFac, error) { +func NewProof(Session []byte, ec elliptic.Curve, N0, NCap, s, t, N0p, N0q *big.Int, rand io.Reader) (*ProofFac, error) { if ec == nil || N0 == nil || NCap == nil || s == nil || t == nil || N0p == nil || N0q == nil { return nil, errors.New("ProveFac constructor received nil value(s)") } @@ -48,14 +49,14 @@ func NewProof(Session []byte, ec elliptic.Curve, N0, NCap, s, t, N0p, N0q *big.I q3SqrtN0 := new(big.Int).Mul(q3, sqrtN0) // Fig 28.1 sample - alpha := common.GetRandomPositiveInt(q3SqrtN0) - beta := common.GetRandomPositiveInt(q3SqrtN0) - mu := common.GetRandomPositiveInt(qNCap) - nu := common.GetRandomPositiveInt(qNCap) - sigma := common.GetRandomPositiveInt(qN0NCap) - r := common.GetRandomPositiveRelativelyPrimeInt(q3N0NCap) - x := common.GetRandomPositiveInt(q3NCap) - y := common.GetRandomPositiveInt(q3NCap) + alpha := common.GetRandomPositiveInt(rand, q3SqrtN0) + beta := common.GetRandomPositiveInt(rand, q3SqrtN0) + mu := common.GetRandomPositiveInt(rand, qNCap) + nu := common.GetRandomPositiveInt(rand, qNCap) + sigma := common.GetRandomPositiveInt(rand, qN0NCap) + r := common.GetRandomPositiveRelativelyPrimeInt(rand, q3N0NCap) + x := common.GetRandomPositiveInt(rand, q3NCap) + y := common.GetRandomPositiveInt(rand, q3NCap) // Fig 28.1 compute modNCap := common.ModInt(NCap) diff --git a/crypto/facproof/proof_test.go b/crypto/facproof/proof_test.go index 7676f19f..a6af7d61 100644 --- a/crypto/facproof/proof_test.go +++ b/crypto/facproof/proof_test.go @@ -7,6 +7,7 @@ package facproof_test import ( + "crypto/rand" "math/big" "testing" @@ -23,31 +24,29 @@ const ( testSafePrimeBits = 1024 ) -var ( - Session = []byte("session") -) +var Session = []byte("session") func TestFac(test *testing.T) { ec := tss.EC() - N0p := common.GetRandomPrimeInt(testSafePrimeBits) - N0q := common.GetRandomPrimeInt(testSafePrimeBits) + N0p := common.GetRandomPrimeInt(rand.Reader, testSafePrimeBits) + N0q := common.GetRandomPrimeInt(rand.Reader, testSafePrimeBits) N0 := new(big.Int).Mul(N0p, N0q) - primes := [2]*big.Int{common.GetRandomPrimeInt(testSafePrimeBits), common.GetRandomPrimeInt(testSafePrimeBits)} - NCap, s, t, err := crypto.GenerateNTildei(primes) + primes := [2]*big.Int{common.GetRandomPrimeInt(rand.Reader, testSafePrimeBits), common.GetRandomPrimeInt(rand.Reader, testSafePrimeBits)} + NCap, s, t, err := crypto.GenerateNTildei(rand.Reader, primes) assert.NoError(test, err) - proof, err := NewProof(Session, ec, N0, NCap, s, t, N0p, N0q) + proof, err := NewProof(Session, ec, N0, NCap, s, t, N0p, N0q, rand.Reader) assert.NoError(test, err) ok := proof.Verify(Session, ec, N0, NCap, s, t) assert.True(test, ok, "proof must verify") - N0p = common.GetRandomPrimeInt(1024) - N0q = common.GetRandomPrimeInt(1024) + N0p = common.GetRandomPrimeInt(rand.Reader, 1024) + N0q = common.GetRandomPrimeInt(rand.Reader, 1024) N0 = new(big.Int).Mul(N0p, N0q) - proof, err = NewProof(Session, ec, N0, NCap, s, t, N0p, N0q) + proof, err = NewProof(Session, ec, N0, NCap, s, t, N0p, N0q, rand.Reader) assert.NoError(test, err) ok = proof.Verify(Session, ec, N0, NCap, s, t) diff --git a/crypto/modproof/proof.go b/crypto/modproof/proof.go index 3277db12..8ec8a270 100644 --- a/crypto/modproof/proof.go +++ b/crypto/modproof/proof.go @@ -8,6 +8,7 @@ package modproof import ( "fmt" + "io" "math/big" "github.com/bnb-chain/tss-lib/v2/common" @@ -18,9 +19,7 @@ const ( ProofModBytesParts = Iterations*2 + 3 ) -var ( - one = big.NewInt(1) -) +var one = big.NewInt(1) type ( ProofMod struct { @@ -37,10 +36,10 @@ func isQuadraticResidue(X, N *big.Int) bool { return big.Jacobi(X, N) == 1 } -func NewProof(Session []byte, N, P, Q *big.Int) (*ProofMod, error) { +func NewProof(Session []byte, N, P, Q *big.Int, rand io.Reader) (*ProofMod, error) { Phi := new(big.Int).Mul(new(big.Int).Sub(P, one), new(big.Int).Sub(Q, one)) // Fig 16.1 - W := common.GetRandomQuadraticNonResidue(N) + W := common.GetRandomQuadraticNonResidue(rand, N) // Fig 16.2 Y := [Iterations]*big.Int{} diff --git a/crypto/modproof/proof_test.go b/crypto/modproof/proof_test.go index 5a6dc419..9d800df7 100644 --- a/crypto/modproof/proof_test.go +++ b/crypto/modproof/proof_test.go @@ -7,6 +7,7 @@ package modproof_test import ( + "crypto/rand" "testing" "time" @@ -15,9 +16,7 @@ import ( "github.com/stretchr/testify/assert" ) -var ( - Session = []byte("session") -) +var Session = []byte("session") func TestMod(test *testing.T) { preParams, err := keygen.GeneratePreParams(time.Minute*10, 8) @@ -25,7 +24,7 @@ func TestMod(test *testing.T) { P, Q, N := preParams.PaillierSK.P, preParams.PaillierSK.Q, preParams.PaillierSK.N - proof, err := NewProof(Session, N, P, Q) + proof, err := NewProof(Session, N, P, Q, rand.Reader) assert.NoError(test, err) proofBzs := proof.Bytes() diff --git a/crypto/mta/proofs.go b/crypto/mta/proofs.go index 2bd61174..bf1809d2 100644 --- a/crypto/mta/proofs.go +++ b/crypto/mta/proofs.go @@ -10,6 +10,7 @@ import ( "crypto/elliptic" "errors" "fmt" + "io" "math/big" "github.com/bnb-chain/tss-lib/v2/common" @@ -36,7 +37,7 @@ type ( // ProveBobWC implements Bob's proof both with or without check "ProveMtawc_Bob" and "ProveMta_Bob" used in the MtA protocol from GG18Spec (9) Figs. 10 & 11. // an absent `X` generates the proof without the X consistency check X = g^x -func ProveBobWC(Session []byte, ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c2, x, y, r *big.Int, X *crypto.ECPoint) (*ProofBobWC, error) { +func ProveBobWC(Session []byte, ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c2, x, y, r *big.Int, X *crypto.ECPoint, rand io.Reader) (*ProofBobWC, error) { if pk == nil || NTilde == nil || h1 == nil || h2 == nil || c1 == nil || c2 == nil || x == nil || y == nil || r == nil { return nil, errors.New("ProveBob() received a nil argument") } @@ -53,20 +54,20 @@ func ProveBobWC(Session []byte, ec elliptic.Curve, pk *paillier.PublicKey, NTild // steps are numbered as shown in Fig. 10, but diverge slightly for Fig. 11 // 1. - alpha := common.GetRandomPositiveInt(q3) + alpha := common.GetRandomPositiveInt(rand, q3) // 2. - rho := common.GetRandomPositiveInt(qNTilde) - sigma := common.GetRandomPositiveInt(qNTilde) - tau := common.GetRandomPositiveInt(q3NTilde) + rho := common.GetRandomPositiveInt(rand, qNTilde) + sigma := common.GetRandomPositiveInt(rand, qNTilde) + tau := common.GetRandomPositiveInt(rand, q3NTilde) // 3. - rhoPrm := common.GetRandomPositiveInt(q3NTilde) + rhoPrm := common.GetRandomPositiveInt(rand, q3NTilde) // 4. - beta := common.GetRandomPositiveRelativelyPrimeInt(pk.N) + beta := common.GetRandomPositiveRelativelyPrimeInt(rand, pk.N) - gamma := common.GetRandomPositiveInt(q7) + gamma := common.GetRandomPositiveInt(rand, q7) // 5. u := crypto.NewECPointNoCurveCheck(ec, zero, zero) // initialization suppresses an IDE warning @@ -139,10 +140,10 @@ func ProveBobWC(Session []byte, ec elliptic.Curve, pk *paillier.PublicKey, NTild } // ProveBob implements Bob's proof "ProveMta_Bob" used in the MtA protocol from GG18Spec (9) Fig. 11. -func ProveBob(Session []byte, ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c2, x, y, r *big.Int) (*ProofBob, error) { +func ProveBob(Session []byte, ec elliptic.Curve, pk *paillier.PublicKey, NTilde, h1, h2, c1, c2, x, y, r *big.Int, rand io.Reader) (*ProofBob, error) { // the Bob proof ("with check") contains the ProofBob "without check"; this method extracts and returns it // X is supplied as nil to exclude it from the proof hash - pf, err := ProveBobWC(Session, ec, pk, NTilde, h1, h2, c1, c2, x, y, r, nil) + pf, err := ProveBobWC(Session, ec, pk, NTilde, h1, h2, c1, c2, x, y, r, nil, rand) if err != nil { return nil, err } diff --git a/crypto/mta/range_proof.go b/crypto/mta/range_proof.go index 11bb999f..1c5b3a0e 100644 --- a/crypto/mta/range_proof.go +++ b/crypto/mta/range_proof.go @@ -10,6 +10,7 @@ import ( "crypto/elliptic" "errors" "fmt" + "io" "math/big" "github.com/bnb-chain/tss-lib/v2/common" @@ -32,7 +33,7 @@ type ( ) // ProveRangeAlice implements Alice's range proof used in the MtA and MtAwc protocols from GG18Spec (9) Fig. 9. -func ProveRangeAlice(ec elliptic.Curve, pk *paillier.PublicKey, c, NTilde, h1, h2, m, r *big.Int) (*RangeProofAlice, error) { +func ProveRangeAlice(ec elliptic.Curve, pk *paillier.PublicKey, c, NTilde, h1, h2, m, r *big.Int, rand io.Reader) (*RangeProofAlice, error) { if pk == nil || NTilde == nil || h1 == nil || h2 == nil || c == nil || m == nil || r == nil { return nil, errors.New("ProveRangeAlice constructor received nil value(s)") } @@ -44,15 +45,15 @@ func ProveRangeAlice(ec elliptic.Curve, pk *paillier.PublicKey, c, NTilde, h1, h q3NTilde := new(big.Int).Mul(q3, NTilde) // 1. - alpha := common.GetRandomPositiveInt(q3) + alpha := common.GetRandomPositiveInt(rand, q3) // 2. - beta := common.GetRandomPositiveRelativelyPrimeInt(pk.N) + beta := common.GetRandomPositiveRelativelyPrimeInt(rand, pk.N) // 3. - gamma := common.GetRandomPositiveInt(q3NTilde) + gamma := common.GetRandomPositiveInt(rand, q3NTilde) // 4. - rho := common.GetRandomPositiveInt(qNTilde) + rho := common.GetRandomPositiveInt(rand, qNTilde) // 5. modNTilde := common.ModInt(NTilde) diff --git a/crypto/mta/range_proof_test.go b/crypto/mta/range_proof_test.go index 318335f1..151126fe 100644 --- a/crypto/mta/range_proof_test.go +++ b/crypto/mta/range_proof_test.go @@ -8,6 +8,7 @@ package mta import ( "context" + "crypto/rand" "fmt" "math/big" "testing" @@ -32,17 +33,17 @@ func TestProveRangeAlice(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() - sk, pk, err := paillier.GenerateKeyPair(ctx, testPaillierKeyLength) + sk, pk, err := paillier.GenerateKeyPair(ctx, rand.Reader, testPaillierKeyLength) assert.NoError(t, err) - m := common.GetRandomPositiveInt(q) - c, r, err := sk.EncryptAndReturnRandomness(m) + m := common.GetRandomPositiveInt(rand.Reader, q) + c, r, err := sk.EncryptAndReturnRandomness(rand.Reader, m) assert.NoError(t, err) - primes := [2]*big.Int{common.GetRandomPrimeInt(testSafePrimeBits), common.GetRandomPrimeInt(testSafePrimeBits)} - NTildei, h1i, h2i, err := crypto.GenerateNTildei(primes) + primes := [2]*big.Int{common.GetRandomPrimeInt(rand.Reader, testSafePrimeBits), common.GetRandomPrimeInt(rand.Reader, testSafePrimeBits)} + NTildei, h1i, h2i, err := crypto.GenerateNTildei(rand.Reader, primes) assert.NoError(t, err) - proof, err := ProveRangeAlice(tss.EC(), pk, c, NTildei, h1i, h2i, m, r) + proof, err := ProveRangeAlice(tss.EC(), pk, c, NTildei, h1i, h2i, m, r, rand.Reader) assert.NoError(t, err) ok := proof.Verify(tss.EC(), pk, NTildei, h1i, h2i, c) @@ -55,34 +56,34 @@ func TestProveRangeAliceBypassed(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() - sk0, pk0, err := paillier.GenerateKeyPair(ctx, testPaillierKeyLength) + sk0, pk0, err := paillier.GenerateKeyPair(ctx, rand.Reader, testPaillierKeyLength) assert.NoError(t, err) - m0 := common.GetRandomPositiveInt(q) - c0, r0, err := sk0.EncryptAndReturnRandomness(m0) + m0 := common.GetRandomPositiveInt(rand.Reader, q) + c0, r0, err := sk0.EncryptAndReturnRandomness(rand.Reader, m0) assert.NoError(t, err) - primes0 := [2]*big.Int{common.GetRandomPrimeInt(testSafePrimeBits), common.GetRandomPrimeInt(testSafePrimeBits)} - Ntildei0, h1i0, h2i0, err := crypto.GenerateNTildei(primes0) + primes0 := [2]*big.Int{common.GetRandomPrimeInt(rand.Reader, testSafePrimeBits), common.GetRandomPrimeInt(rand.Reader, testSafePrimeBits)} + Ntildei0, h1i0, h2i0, err := crypto.GenerateNTildei(rand.Reader, primes0) assert.NoError(t, err) - proof0, err := ProveRangeAlice(tss.EC(), pk0, c0, Ntildei0, h1i0, h2i0, m0, r0) + proof0, err := ProveRangeAlice(tss.EC(), pk0, c0, Ntildei0, h1i0, h2i0, m0, r0, rand.Reader) assert.NoError(t, err) ok0 := proof0.Verify(tss.EC(), pk0, Ntildei0, h1i0, h2i0, c0) assert.True(t, ok0, "proof must verify") - //proof 2 - sk1, pk1, err := paillier.GenerateKeyPair(ctx, testPaillierKeyLength) + // proof 2 + sk1, pk1, err := paillier.GenerateKeyPair(ctx, rand.Reader, testPaillierKeyLength) assert.NoError(t, err) - m1 := common.GetRandomPositiveInt(q) - c1, r1, err := sk1.EncryptAndReturnRandomness(m1) + m1 := common.GetRandomPositiveInt(rand.Reader, q) + c1, r1, err := sk1.EncryptAndReturnRandomness(rand.Reader, m1) assert.NoError(t, err) - primes1 := [2]*big.Int{common.GetRandomPrimeInt(testSafePrimeBits), common.GetRandomPrimeInt(testSafePrimeBits)} - Ntildei1, h1i1, h2i1, err := crypto.GenerateNTildei(primes1) + primes1 := [2]*big.Int{common.GetRandomPrimeInt(rand.Reader, testSafePrimeBits), common.GetRandomPrimeInt(rand.Reader, testSafePrimeBits)} + Ntildei1, h1i1, h2i1, err := crypto.GenerateNTildei(rand.Reader, primes1) assert.NoError(t, err) - proof1, err := ProveRangeAlice(tss.EC(), pk1, c1, Ntildei1, h1i1, h2i1, m1, r1) + proof1, err := ProveRangeAlice(tss.EC(), pk1, c1, Ntildei1, h1i1, h2i1, m1, r1, rand.Reader) assert.NoError(t, err) ok1 := proof1.Verify(tss.EC(), pk1, Ntildei1, h1i1, h2i1, c1) @@ -100,7 +101,7 @@ func TestProveRangeAliceBypassed(t *testing.T) { fmt.Println("Did verify proof 0 with data from 1?", cross0) fmt.Println("Did verify proof 1 with data from 0?", cross1) - //new bypass + // new bypass bypassedproofNew := &RangeProofAlice{ S: big.NewInt(1), S1: big.NewInt(0), @@ -111,13 +112,13 @@ func TestProveRangeAliceBypassed(t *testing.T) { } cBogus := big.NewInt(1) - proofBogus, _ := ProveRangeAlice(tss.EC(), pk1, cBogus, Ntildei1, h1i1, h2i1, m1, r1) + proofBogus, _ := ProveRangeAlice(tss.EC(), pk1, cBogus, Ntildei1, h1i1, h2i1, m1, r1, rand.Reader) ok2 := proofBogus.Verify(tss.EC(), pk1, Ntildei1, h1i1, h2i1, cBogus) bypassresult3 := bypassedproofNew.Verify(tss.EC(), pk1, Ntildei1, h1i1, h2i1, cBogus) - //c = 1 is not valid, even though we can find a range proof for it that passes! - //this also means that the homo mul and add needs to be checked with this! + // c = 1 is not valid, even though we can find a range proof for it that passes! + // this also means that the homo mul and add needs to be checked with this! fmt.Println("Did verify proof bogus with data from bogus?", ok2) fmt.Println("Did we bypass proof 3?", bypassresult3) } diff --git a/crypto/mta/share_protocol.go b/crypto/mta/share_protocol.go index 94c1c877..035f8d92 100644 --- a/crypto/mta/share_protocol.go +++ b/crypto/mta/share_protocol.go @@ -9,6 +9,7 @@ package mta import ( "crypto/elliptic" "errors" + "io" "math/big" "github.com/bnb-chain/tss-lib/v2/common" @@ -20,12 +21,13 @@ func AliceInit( ec elliptic.Curve, pkA *paillier.PublicKey, a, NTildeB, h1B, h2B *big.Int, + rand io.Reader, ) (cA *big.Int, pf *RangeProofAlice, err error) { - cA, rA, err := pkA.EncryptAndReturnRandomness(a) + cA, rA, err := pkA.EncryptAndReturnRandomness(rand, a) if err != nil { return nil, nil, err } - pf, err = ProveRangeAlice(ec, pkA, cA, NTildeB, h1B, h2B, a, rA) + pf, err = ProveRangeAlice(ec, pkA, cA, NTildeB, h1B, h2B, a, rA, rand) return cA, pf, err } @@ -35,6 +37,7 @@ func BobMid( pkA *paillier.PublicKey, pf *RangeProofAlice, b, cA, NTildeA, h1A, h2A, NTildeB, h1B, h2B *big.Int, + rand io.Reader, ) (beta, cB, betaPrm *big.Int, piB *ProofBob, err error) { if !pf.Verify(ec, pkA, NTildeB, h1B, h2B, cA) { err = errors.New("RangeProofAlice.Verify() returned false") @@ -44,8 +47,8 @@ func BobMid( q5 := new(big.Int).Mul(q, q) // q^2 q5 = new(big.Int).Mul(q5, q5) // q^4 q5 = new(big.Int).Mul(q5, q) // q^5 - betaPrm = common.GetRandomPositiveInt(q5) - cBetaPrm, cRand, err := pkA.EncryptAndReturnRandomness(betaPrm) + betaPrm = common.GetRandomPositiveInt(rand, q5) + cBetaPrm, cRand, err := pkA.EncryptAndReturnRandomness(rand, betaPrm) if err != nil { return } @@ -58,7 +61,7 @@ func BobMid( return } beta = common.ModInt(q).Sub(zero, betaPrm) - piB, err = ProveBob(Session, ec, pkA, NTildeA, h1A, h2A, cA, cB, b, betaPrm, cRand) + piB, err = ProveBob(Session, ec, pkA, NTildeA, h1A, h2A, cA, cB, b, betaPrm, cRand, rand) return } @@ -69,6 +72,7 @@ func BobMidWC( pf *RangeProofAlice, b, cA, NTildeA, h1A, h2A, NTildeB, h1B, h2B *big.Int, B *crypto.ECPoint, + rand io.Reader, ) (beta, cB, betaPrm *big.Int, piB *ProofBobWC, err error) { if !pf.Verify(ec, pkA, NTildeB, h1B, h2B, cA) { err = errors.New("RangeProofAlice.Verify() returned false") @@ -78,8 +82,8 @@ func BobMidWC( q5 := new(big.Int).Mul(q, q) // q^2 q5 = new(big.Int).Mul(q5, q5) // q^4 q5 = new(big.Int).Mul(q5, q) // q^5 - betaPrm = common.GetRandomPositiveInt(q5) - cBetaPrm, cRand, err := pkA.EncryptAndReturnRandomness(betaPrm) + betaPrm = common.GetRandomPositiveInt(rand, q5) + cBetaPrm, cRand, err := pkA.EncryptAndReturnRandomness(rand, betaPrm) if err != nil { return } @@ -92,7 +96,7 @@ func BobMidWC( return } beta = common.ModInt(q).Sub(zero, betaPrm) - piB, err = ProveBobWC(Session, ec, pkA, NTildeA, h1A, h2A, cA, cB, b, betaPrm, cRand, B) + piB, err = ProveBobWC(Session, ec, pkA, NTildeA, h1A, h2A, cA, cB, b, betaPrm, cRand, B, rand) return } diff --git a/crypto/mta/share_protocol_test.go b/crypto/mta/share_protocol_test.go index aed84f39..fa4f80b8 100644 --- a/crypto/mta/share_protocol_test.go +++ b/crypto/mta/share_protocol_test.go @@ -8,6 +8,7 @@ package mta import ( "context" + "crypto/rand" "math/big" "testing" "time" @@ -26,9 +27,7 @@ const ( testPaillierKeyLength = 2048 ) -var ( - Session = []byte("session") -) +var Session = []byte("session") func TestShareProtocol(t *testing.T) { q := tss.EC().Params().N @@ -36,21 +35,21 @@ func TestShareProtocol(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() - sk, pk, err := paillier.GenerateKeyPair(ctx, testPaillierKeyLength) + sk, pk, err := paillier.GenerateKeyPair(ctx, rand.Reader, testPaillierKeyLength) assert.NoError(t, err) - a := common.GetRandomPositiveInt(q) - b := common.GetRandomPositiveInt(q) + a := common.GetRandomPositiveInt(rand.Reader, q) + b := common.GetRandomPositiveInt(rand.Reader, q) NTildei, h1i, h2i, err := keygen.LoadNTildeH1H2FromTestFixture(0) assert.NoError(t, err) NTildej, h1j, h2j, err := keygen.LoadNTildeH1H2FromTestFixture(1) assert.NoError(t, err) - cA, pf, err := AliceInit(tss.EC(), pk, a, NTildej, h1j, h2j) + cA, pf, err := AliceInit(tss.EC(), pk, a, NTildej, h1j, h2j, rand.Reader) assert.NoError(t, err) - _, cB, betaPrm, pfB, err := BobMid(Session, tss.EC(), pk, pf, b, cA, NTildei, h1i, h2i, NTildej, h1j, h2j) + _, cB, betaPrm, pfB, err := BobMid(Session, tss.EC(), pk, pf, b, cA, NTildei, h1i, h2i, NTildej, h1j, h2j, rand.Reader) assert.NoError(t, err) alpha, err := AliceEnd(Session, tss.EC(), pk, pfB, h1i, h2i, cA, cB, NTildei, sk) @@ -69,11 +68,11 @@ func TestShareProtocolWC(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() - sk, pk, err := paillier.GenerateKeyPair(ctx, testPaillierKeyLength) + sk, pk, err := paillier.GenerateKeyPair(ctx, rand.Reader, testPaillierKeyLength) assert.NoError(t, err) - a := common.GetRandomPositiveInt(q) - b := common.GetRandomPositiveInt(q) + a := common.GetRandomPositiveInt(rand.Reader, q) + b := common.GetRandomPositiveInt(rand.Reader, q) gBX, gBY := tss.EC().ScalarBaseMult(b.Bytes()) NTildei, h1i, h2i, err := keygen.LoadNTildeH1H2FromTestFixture(0) @@ -81,12 +80,12 @@ func TestShareProtocolWC(t *testing.T) { NTildej, h1j, h2j, err := keygen.LoadNTildeH1H2FromTestFixture(1) assert.NoError(t, err) - cA, pf, err := AliceInit(tss.EC(), pk, a, NTildej, h1j, h2j) + cA, pf, err := AliceInit(tss.EC(), pk, a, NTildej, h1j, h2j, rand.Reader) assert.NoError(t, err) gBPoint, err := crypto.NewECPoint(tss.EC(), gBX, gBY) assert.NoError(t, err) - _, cB, betaPrm, pfB, err := BobMidWC(Session, tss.EC(), pk, pf, b, cA, NTildei, h1i, h2i, NTildej, h1j, h2j, gBPoint) + _, cB, betaPrm, pfB, err := BobMidWC(Session, tss.EC(), pk, pf, b, cA, NTildei, h1i, h2i, NTildej, h1j, h2j, gBPoint, rand.Reader) assert.NoError(t, err) alpha, err := AliceEndWC(Session, tss.EC(), pk, pfB, gBPoint, cA, cB, NTildei, h1i, h2i, sk) diff --git a/crypto/paillier/paillier.go b/crypto/paillier/paillier.go index 1b249b12..846d0392 100644 --- a/crypto/paillier/paillier.go +++ b/crypto/paillier/paillier.go @@ -19,6 +19,7 @@ import ( "context" "errors" "fmt" + "io" gmath "math" "math/big" "runtime" @@ -66,7 +67,7 @@ func init() { } // len is the length of the modulus (each prime = len / 2) -func GenerateKeyPair(ctx context.Context, modulusBitLen int, optionalConcurrency ...int) (privateKey *PrivateKey, publicKey *PublicKey, err error) { +func GenerateKeyPair(ctx context.Context, rand io.Reader, modulusBitLen int, optionalConcurrency ...int) (privateKey *PrivateKey, publicKey *PublicKey, err error) { var concurrency int if 0 < len(optionalConcurrency) { if 1 < len(optionalConcurrency) { @@ -82,7 +83,7 @@ func GenerateKeyPair(ctx context.Context, modulusBitLen int, optionalConcurrency { tmp := new(big.Int) for { - sgps, err := common.GetRandomSafePrimesConcurrent(ctx, modulusBitLen/2, 2, concurrency) + sgps, err := common.GetRandomSafePrimesConcurrent(ctx, modulusBitLen/2, 2, concurrency, rand) if err != nil { return nil, nil, err } @@ -110,11 +111,11 @@ func GenerateKeyPair(ctx context.Context, modulusBitLen int, optionalConcurrency // ----- // -func (publicKey *PublicKey) EncryptAndReturnRandomness(m *big.Int) (c *big.Int, x *big.Int, err error) { +func (publicKey *PublicKey) EncryptAndReturnRandomness(rand io.Reader, m *big.Int) (c *big.Int, x *big.Int, err error) { if m.Cmp(zero) == -1 || m.Cmp(publicKey.N) != -1 { // m < 0 || m >= N ? return nil, nil, ErrMessageTooLong } - x = common.GetRandomPositiveRelativelyPrimeInt(publicKey.N) + x = common.GetRandomPositiveRelativelyPrimeInt(rand, publicKey.N) N2 := publicKey.NSquare() // 1. gamma^m mod N2 Gm := new(big.Int).Exp(publicKey.Gamma(), m, N2) @@ -125,8 +126,8 @@ func (publicKey *PublicKey) EncryptAndReturnRandomness(m *big.Int) (c *big.Int, return } -func (publicKey *PublicKey) Encrypt(m *big.Int) (c *big.Int, err error) { - c, _, err = publicKey.EncryptAndReturnRandomness(m) +func (publicKey *PublicKey) Encrypt(rand io.Reader, m *big.Int) (c *big.Int, err error) { + c, _, err = publicKey.EncryptAndReturnRandomness(rand, m) return } diff --git a/crypto/paillier/paillier_test.go b/crypto/paillier/paillier_test.go index a9ebff73..97c4d05d 100644 --- a/crypto/paillier/paillier_test.go +++ b/crypto/paillier/paillier_test.go @@ -8,6 +8,7 @@ package paillier_test import ( "context" + "crypto/rand" "math/big" "testing" "time" @@ -39,7 +40,7 @@ func setUp(t *testing.T) { defer cancel() var err error - privateKey, publicKey, err = GenerateKeyPair(ctx, testPaillierKeyLength) + privateKey, publicKey, err = GenerateKeyPair(ctx, rand.Reader, testPaillierKeyLength) assert.NoError(t, err) } @@ -52,7 +53,7 @@ func TestGenerateKeyPair(t *testing.T) { func TestEncrypt(t *testing.T) { setUp(t) - cipher, err := publicKey.Encrypt(big.NewInt(1)) + cipher, err := publicKey.Encrypt(rand.Reader, big.NewInt(1)) assert.NoError(t, err, "must not error") assert.NotZero(t, cipher) t.Log(cipher) @@ -61,7 +62,7 @@ func TestEncrypt(t *testing.T) { func TestEncryptDecrypt(t *testing.T) { setUp(t) exp := big.NewInt(100) - cypher, err := privateKey.Encrypt(exp) + cypher, err := privateKey.Encrypt(rand.Reader, exp) if err != nil { t.Error(err) } @@ -77,7 +78,7 @@ func TestEncryptDecrypt(t *testing.T) { func TestHomoMul(t *testing.T) { setUp(t) - three, err := privateKey.Encrypt(big.NewInt(3)) + three, err := privateKey.Encrypt(rand.Reader, big.NewInt(3)) assert.NoError(t, err) // for HomoMul, the first argument `m` is not ciphered @@ -98,8 +99,8 @@ func TestHomoAdd(t *testing.T) { num1 := big.NewInt(10) num2 := big.NewInt(32) - one, _ := publicKey.Encrypt(num1) - two, _ := publicKey.Encrypt(num2) + one, _ := publicKey.Encrypt(rand.Reader, num1) + two, _ := publicKey.Encrypt(rand.Reader, num2) ciphered, _ := publicKey.HomoAdd(one, two) @@ -110,9 +111,9 @@ func TestHomoAdd(t *testing.T) { func TestProofVerify(t *testing.T) { setUp(t) - ki := common.MustGetRandomInt(256) // index - ui := common.GetRandomPositiveInt(tss.EC().Params().N) // ECDSA private - yX, yY := tss.EC().ScalarBaseMult(ui.Bytes()) // ECDSA public + ki := common.MustGetRandomInt(rand.Reader, 256) // index + ui := common.GetRandomPositiveInt(rand.Reader, tss.EC().Params().N) // ECDSA private + yX, yY := tss.EC().ScalarBaseMult(ui.Bytes()) // ECDSA public proof := privateKey.Proof(ki, crypto.NewECPointNoCurveCheck(tss.EC(), yX, yY)) res, err := proof.Verify(publicKey.N, ki, crypto.NewECPointNoCurveCheck(tss.EC(), yX, yY)) assert.NoError(t, err) @@ -121,9 +122,9 @@ func TestProofVerify(t *testing.T) { func TestProofVerifyFail(t *testing.T) { setUp(t) - ki := common.MustGetRandomInt(256) // index - ui := common.GetRandomPositiveInt(tss.EC().Params().N) // ECDSA private - yX, yY := tss.EC().ScalarBaseMult(ui.Bytes()) // ECDSA public + ki := common.MustGetRandomInt(rand.Reader, 256) // index + ui := common.GetRandomPositiveInt(rand.Reader, tss.EC().Params().N) // ECDSA private + yX, yY := tss.EC().ScalarBaseMult(ui.Bytes()) // ECDSA public proof := privateKey.Proof(ki, crypto.NewECPointNoCurveCheck(tss.EC(), yX, yY)) last := proof[len(proof)-1] last.Sub(last, big.NewInt(1)) @@ -143,10 +144,10 @@ func TestComputeL(t *testing.T) { } func TestGenerateXs(t *testing.T) { - k := common.MustGetRandomInt(256) - sX := common.MustGetRandomInt(256) - sY := common.MustGetRandomInt(256) - N := common.GetRandomPrimeInt(2048) + k := common.MustGetRandomInt(rand.Reader, 256) + sX := common.MustGetRandomInt(rand.Reader, 256) + sY := common.MustGetRandomInt(rand.Reader, 256) + N := common.GetRandomPrimeInt(rand.Reader, 2048) xs := GenerateXs(13, k, N, crypto.NewECPointNoCurveCheck(tss.EC(), sX, sY)) assert.Equal(t, 13, len(xs)) diff --git a/crypto/schnorr/schnorr_proof.go b/crypto/schnorr/schnorr_proof.go index 3a4b6aa8..8c1bedcd 100644 --- a/crypto/schnorr/schnorr_proof.go +++ b/crypto/schnorr/schnorr_proof.go @@ -8,6 +8,7 @@ package schnorr import ( "errors" + "io" "math/big" "github.com/bnb-chain/tss-lib/v2/common" @@ -27,7 +28,7 @@ type ( ) // NewZKProof constructs a new Schnorr ZK proof of knowledge of the discrete logarithm (GG18Spec Fig. 16) -func NewZKProof(Session []byte, x *big.Int, X *crypto.ECPoint) (*ZKProof, error) { +func NewZKProof(Session []byte, x *big.Int, X *crypto.ECPoint, rand io.Reader) (*ZKProof, error) { if x == nil || X == nil || !X.ValidateBasic() { return nil, errors.New("ZKProof constructor received nil or invalid value(s)") } @@ -36,7 +37,7 @@ func NewZKProof(Session []byte, x *big.Int, X *crypto.ECPoint) (*ZKProof, error) q := ecParams.N g := crypto.NewECPointNoCurveCheck(ec, ecParams.Gx, ecParams.Gy) // already on the curve. - a := common.GetRandomPositiveInt(q) + a := common.GetRandomPositiveInt(rand, q) alpha := crypto.ScalarBaseMult(ec, a) var c *big.Int @@ -79,7 +80,7 @@ func (pf *ZKProof) ValidateBasic() bool { } // NewZKProof constructs a new Schnorr ZK proof of knowledge s_i, l_i such that V_i = R^s_i, g^l_i (GG18Spec Fig. 17) -func NewZKVProof(Session []byte, V, R *crypto.ECPoint, s, l *big.Int) (*ZKVProof, error) { +func NewZKVProof(Session []byte, V, R *crypto.ECPoint, s, l *big.Int, rand io.Reader) (*ZKVProof, error) { if V == nil || R == nil || s == nil || l == nil || !V.ValidateBasic() || !R.ValidateBasic() { return nil, errors.New("ZKVProof constructor received nil value(s)") } @@ -88,7 +89,7 @@ func NewZKVProof(Session []byte, V, R *crypto.ECPoint, s, l *big.Int) (*ZKVProof q := ecParams.N g := crypto.NewECPointNoCurveCheck(ec, ecParams.Gx, ecParams.Gy) - a, b := common.GetRandomPositiveInt(q), common.GetRandomPositiveInt(q) + a, b := common.GetRandomPositiveInt(rand, q), common.GetRandomPositiveInt(rand, q) aR := R.ScalarMult(a) bG := crypto.ScalarBaseMult(ec, b) alpha, _ := aR.Add(bG) // already on the curve. diff --git a/crypto/schnorr/schnorr_proof_test.go b/crypto/schnorr/schnorr_proof_test.go index fbf633ea..89a7e964 100644 --- a/crypto/schnorr/schnorr_proof_test.go +++ b/crypto/schnorr/schnorr_proof_test.go @@ -7,6 +7,7 @@ package schnorr_test import ( + "crypto/rand" "testing" "github.com/stretchr/testify/assert" @@ -17,15 +18,13 @@ import ( "github.com/bnb-chain/tss-lib/v2/tss" ) -var ( - Session = []byte("session") -) +var Session = []byte("session") func TestSchnorrProof(t *testing.T) { q := tss.EC().Params().N - u := common.GetRandomPositiveInt(q) + u := common.GetRandomPositiveInt(rand.Reader, q) uG := crypto.ScalarBaseMult(tss.EC(), u) - proof, _ := NewZKProof(Session, u, uG) + proof, _ := NewZKProof(Session, u, uG, rand.Reader) assert.True(t, proof.Alpha.IsOnCurve()) assert.NotZero(t, proof.Alpha.X()) @@ -35,10 +34,10 @@ func TestSchnorrProof(t *testing.T) { func TestSchnorrProofVerify(t *testing.T) { q := tss.EC().Params().N - u := common.GetRandomPositiveInt(q) + u := common.GetRandomPositiveInt(rand.Reader, q) X := crypto.ScalarBaseMult(tss.EC(), u) - proof, _ := NewZKProof(Session, u, X) + proof, _ := NewZKProof(Session, u, X, rand.Reader) res := proof.Verify(Session, X) assert.True(t, res, "verify result must be true") @@ -46,12 +45,12 @@ func TestSchnorrProofVerify(t *testing.T) { func TestSchnorrProofVerifyBadX(t *testing.T) { q := tss.EC().Params().N - u := common.GetRandomPositiveInt(q) - u2 := common.GetRandomPositiveInt(q) + u := common.GetRandomPositiveInt(rand.Reader, q) + u2 := common.GetRandomPositiveInt(rand.Reader, q) X := crypto.ScalarBaseMult(tss.EC(), u) X2 := crypto.ScalarBaseMult(tss.EC(), u2) - proof, _ := NewZKProof(Session, u2, X2) + proof, _ := NewZKProof(Session, u2, X2, rand.Reader) res := proof.Verify(Session, X) assert.False(t, res, "verify result must be false") @@ -59,15 +58,15 @@ func TestSchnorrProofVerifyBadX(t *testing.T) { func TestSchnorrVProofVerify(t *testing.T) { q := tss.EC().Params().N - k := common.GetRandomPositiveInt(q) - s := common.GetRandomPositiveInt(q) - l := common.GetRandomPositiveInt(q) + k := common.GetRandomPositiveInt(rand.Reader, q) + s := common.GetRandomPositiveInt(rand.Reader, q) + l := common.GetRandomPositiveInt(rand.Reader, q) R := crypto.ScalarBaseMult(tss.EC(), k) // k_-1 * G Rs := R.ScalarMult(s) lG := crypto.ScalarBaseMult(tss.EC(), l) V, _ := Rs.Add(lG) - proof, _ := NewZKVProof(Session, V, R, s, l) + proof, _ := NewZKVProof(Session, V, R, s, l, rand.Reader) res := proof.Verify(Session, V, R) assert.True(t, res, "verify result must be true") @@ -75,14 +74,14 @@ func TestSchnorrVProofVerify(t *testing.T) { func TestSchnorrVProofVerifyBadPartialV(t *testing.T) { q := tss.EC().Params().N - k := common.GetRandomPositiveInt(q) - s := common.GetRandomPositiveInt(q) - l := common.GetRandomPositiveInt(q) + k := common.GetRandomPositiveInt(rand.Reader, q) + s := common.GetRandomPositiveInt(rand.Reader, q) + l := common.GetRandomPositiveInt(rand.Reader, q) R := crypto.ScalarBaseMult(tss.EC(), k) // k_-1 * G Rs := R.ScalarMult(s) V := Rs - proof, _ := NewZKVProof(Session, V, R, s, l) + proof, _ := NewZKVProof(Session, V, R, s, l, rand.Reader) res := proof.Verify(Session, V, R) assert.False(t, res, "verify result must be false") @@ -90,16 +89,16 @@ func TestSchnorrVProofVerifyBadPartialV(t *testing.T) { func TestSchnorrVProofVerifyBadS(t *testing.T) { q := tss.EC().Params().N - k := common.GetRandomPositiveInt(q) - s := common.GetRandomPositiveInt(q) - s2 := common.GetRandomPositiveInt(q) - l := common.GetRandomPositiveInt(q) + k := common.GetRandomPositiveInt(rand.Reader, q) + s := common.GetRandomPositiveInt(rand.Reader, q) + s2 := common.GetRandomPositiveInt(rand.Reader, q) + l := common.GetRandomPositiveInt(rand.Reader, q) R := crypto.ScalarBaseMult(tss.EC(), k) // k_-1 * G Rs := R.ScalarMult(s) lG := crypto.ScalarBaseMult(tss.EC(), l) V, _ := Rs.Add(lG) - proof, _ := NewZKVProof(Session, V, R, s2, l) + proof, _ := NewZKVProof(Session, V, R, s2, l, rand.Reader) res := proof.Verify(Session, V, R) assert.False(t, res, "verify result must be false") diff --git a/crypto/utils.go b/crypto/utils.go index 5448113a..2277a395 100644 --- a/crypto/utils.go +++ b/crypto/utils.go @@ -8,12 +8,13 @@ package crypto import ( "fmt" + "io" "math/big" "github.com/bnb-chain/tss-lib/v2/common" ) -func GenerateNTildei(safePrimes [2]*big.Int) (NTildei, h1i, h2i *big.Int, err error) { +func GenerateNTildei(rand io.Reader, safePrimes [2]*big.Int) (NTildei, h1i, h2i *big.Int, err error) { if safePrimes[0] == nil || safePrimes[1] == nil { return nil, nil, nil, fmt.Errorf("GenerateNTildei: needs two primes, got %v", safePrimes) } @@ -21,7 +22,7 @@ func GenerateNTildei(safePrimes [2]*big.Int) (NTildei, h1i, h2i *big.Int, err er return nil, nil, nil, fmt.Errorf("GenerateNTildei: expected two primes") } NTildei = new(big.Int).Mul(safePrimes[0], safePrimes[1]) - h1 := common.GetRandomGeneratorOfTheQuadraticResidue(NTildei) - h2 := common.GetRandomGeneratorOfTheQuadraticResidue(NTildei) + h1 := common.GetRandomGeneratorOfTheQuadraticResidue(rand, NTildei) + h2 := common.GetRandomGeneratorOfTheQuadraticResidue(rand, NTildei) return NTildei, h1, h2, nil } diff --git a/crypto/vss/feldman_vss.go b/crypto/vss/feldman_vss.go index 21c49714..5e80ccc4 100644 --- a/crypto/vss/feldman_vss.go +++ b/crypto/vss/feldman_vss.go @@ -14,6 +14,7 @@ import ( "crypto/elliptic" "errors" "fmt" + "io" "math/big" "github.com/bnb-chain/tss-lib/v2/common" @@ -58,7 +59,7 @@ func CheckIndexes(ec elliptic.Curve, indexes []*big.Int) ([]*big.Int, error) { // Returns a new array of secret shares created by Shamir's Secret Sharing Algorithm, // requiring a minimum number of shares to recreate, of length shares, from the input secret -func Create(ec elliptic.Curve, threshold int, secret *big.Int, indexes []*big.Int) (Vs, Shares, error) { +func Create(ec elliptic.Curve, threshold int, secret *big.Int, indexes []*big.Int, rand io.Reader) (Vs, Shares, error) { if secret == nil || indexes == nil { return nil, nil, fmt.Errorf("vss secret or indexes == nil: %v %v", secret, indexes) } @@ -76,7 +77,7 @@ func Create(ec elliptic.Curve, threshold int, secret *big.Int, indexes []*big.In return nil, nil, ErrNumSharesBelowThreshold } - poly := samplePolynomial(ec, threshold, secret) + poly := samplePolynomial(ec, threshold, secret, rand) poly[0] = secret // becomes sigma*G in v v := make(Vs, len(poly)) for i, ai := range poly { @@ -144,12 +145,12 @@ func (shares Shares) ReConstruct(ec elliptic.Curve) (secret *big.Int, err error) return secret, nil } -func samplePolynomial(ec elliptic.Curve, threshold int, secret *big.Int) []*big.Int { +func samplePolynomial(ec elliptic.Curve, threshold int, secret *big.Int, rand io.Reader) []*big.Int { q := ec.Params().N v := make([]*big.Int, threshold+1) v[0] = secret for i := 1; i <= threshold; i++ { - ai := common.GetRandomPositiveInt(q) + ai := common.GetRandomPositiveInt(rand, q) v[i] = ai } return v diff --git a/crypto/vss/feldman_vss_test.go b/crypto/vss/feldman_vss_test.go index ef55f685..7e12fff0 100644 --- a/crypto/vss/feldman_vss_test.go +++ b/crypto/vss/feldman_vss_test.go @@ -7,6 +7,7 @@ package vss_test import ( + "crypto/rand" "math/big" "testing" @@ -20,7 +21,7 @@ import ( func TestCheckIndexesDup(t *testing.T) { indexes := make([]*big.Int, 0) for i := 0; i < 1000; i++ { - indexes = append(indexes, common.GetRandomPositiveInt(tss.EC().Params().N)) + indexes = append(indexes, common.GetRandomPositiveInt(rand.Reader, tss.EC().Params().N)) } _, e := CheckIndexes(tss.EC(), indexes) assert.NoError(t, e) @@ -33,7 +34,7 @@ func TestCheckIndexesDup(t *testing.T) { func TestCheckIndexesZero(t *testing.T) { indexes := make([]*big.Int, 0) for i := 0; i < 1000; i++ { - indexes = append(indexes, common.GetRandomPositiveInt(tss.EC().Params().N)) + indexes = append(indexes, common.GetRandomPositiveInt(rand.Reader, tss.EC().Params().N)) } _, e := CheckIndexes(tss.EC(), indexes) assert.NoError(t, e) @@ -46,14 +47,14 @@ func TestCheckIndexesZero(t *testing.T) { func TestCreate(t *testing.T) { num, threshold := 5, 3 - secret := common.GetRandomPositiveInt(tss.EC().Params().N) + secret := common.GetRandomPositiveInt(rand.Reader, tss.EC().Params().N) ids := make([]*big.Int, 0) for i := 0; i < num; i++ { - ids = append(ids, common.GetRandomPositiveInt(tss.EC().Params().N)) + ids = append(ids, common.GetRandomPositiveInt(rand.Reader, tss.EC().Params().N)) } - vs, _, err := Create(tss.EC(), threshold, secret, ids) + vs, _, err := Create(tss.EC(), threshold, secret, ids, rand.Reader) assert.Nil(t, err) assert.Equal(t, threshold+1, len(vs)) @@ -74,14 +75,14 @@ func TestCreate(t *testing.T) { func TestVerify(t *testing.T) { num, threshold := 5, 3 - secret := common.GetRandomPositiveInt(tss.EC().Params().N) + secret := common.GetRandomPositiveInt(rand.Reader, tss.EC().Params().N) ids := make([]*big.Int, 0) for i := 0; i < num; i++ { - ids = append(ids, common.GetRandomPositiveInt(tss.EC().Params().N)) + ids = append(ids, common.GetRandomPositiveInt(rand.Reader, tss.EC().Params().N)) } - vs, shares, err := Create(tss.EC(), threshold, secret, ids) + vs, shares, err := Create(tss.EC(), threshold, secret, ids, rand.Reader) assert.NoError(t, err) for i := 0; i < num; i++ { @@ -92,14 +93,14 @@ func TestVerify(t *testing.T) { func TestReconstruct(t *testing.T) { num, threshold := 5, 3 - secret := common.GetRandomPositiveInt(tss.EC().Params().N) + secret := common.GetRandomPositiveInt(rand.Reader, tss.EC().Params().N) ids := make([]*big.Int, 0) for i := 0; i < num; i++ { - ids = append(ids, common.GetRandomPositiveInt(tss.EC().Params().N)) + ids = append(ids, common.GetRandomPositiveInt(rand.Reader, tss.EC().Params().N)) } - _, shares, err := Create(tss.EC(), threshold, secret, ids) + _, shares, err := Create(tss.EC(), threshold, secret, ids, rand.Reader) assert.NoError(t, err) secret2, err2 := shares[:threshold-1].ReConstruct(tss.EC()) diff --git a/ecdsa/keygen/dln_verifier_test.go b/ecdsa/keygen/dln_verifier_test.go index 89d76bba..5b9f65aa 100644 --- a/ecdsa/keygen/dln_verifier_test.go +++ b/ecdsa/keygen/dln_verifier_test.go @@ -7,6 +7,7 @@ package keygen import ( + "crypto/rand" "math/big" "runtime" "testing" @@ -29,6 +30,7 @@ func BenchmarkDlnProof_Verify(b *testing.B) { params.P, params.Q, params.NTildei, + rand.Reader, ) b.ResetTimer() @@ -228,6 +230,7 @@ func prepareProof() (*LocalPreParams, [][]byte, error) { preParams.P, preParams.Q, preParams.NTildei, + rand.Reader, ) serialized, err := proof.Serialize() diff --git a/ecdsa/keygen/local_party.go b/ecdsa/keygen/local_party.go index 68e2b772..93d39f02 100644 --- a/ecdsa/keygen/local_party.go +++ b/ecdsa/keygen/local_party.go @@ -19,8 +19,10 @@ import ( // Implements Party // Implements Stringer -var _ tss.Party = (*LocalParty)(nil) -var _ fmt.Stringer = (*LocalParty)(nil) +var ( + _ tss.Party = (*LocalParty)(nil) + _ fmt.Stringer = (*LocalParty)(nil) +) type ( LocalParty struct { diff --git a/ecdsa/keygen/prepare.go b/ecdsa/keygen/prepare.go index cfbc127f..0f5e6603 100644 --- a/ecdsa/keygen/prepare.go +++ b/ecdsa/keygen/prepare.go @@ -8,7 +8,9 @@ package keygen import ( "context" + "crypto/rand" "errors" + "io" "math/big" "runtime" "time" @@ -43,6 +45,14 @@ func GeneratePreParams(timeout time.Duration, optionalConcurrency ...int) (*Loca // If not specified, a concurrency value equal to the number of available CPU cores will be used. // If pre-parameters could not be generated before the context is done, an error is returned. func GeneratePreParamsWithContext(ctx context.Context, optionalConcurrency ...int) (*LocalPreParams, error) { + return GeneratePreParamsWithContextAndRandom(ctx, rand.Reader, optionalConcurrency...) +} + +// GeneratePreParams finds two safe primes and computes the Paillier secret required for the protocol. +// This can be a time consuming process so it is recommended to do it out-of-band. +// If not specified, a concurrency value equal to the number of available CPU cores will be used. +// If pre-parameters could not be generated before the context is done, an error is returned. +func GeneratePreParamsWithContextAndRandom(ctx context.Context, rand io.Reader, optionalConcurrency ...int) (*LocalPreParams, error) { var concurrency int if 0 < len(optionalConcurrency) { if 1 < len(optionalConcurrency) { @@ -65,7 +75,7 @@ func GeneratePreParamsWithContext(ctx context.Context, optionalConcurrency ...in common.Logger.Info("generating the Paillier modulus, please wait...") start := time.Now() // more concurrency weight is assigned here because the paillier primes have a requirement of having "large" P-Q - PiPaillierSk, _, err := paillier.GenerateKeyPair(ctx, paillierModulusLen, concurrency*2) + PiPaillierSk, _, err := paillier.GenerateKeyPair(ctx, rand, paillierModulusLen, concurrency*2) if err != nil { ch <- nil return @@ -79,7 +89,7 @@ func GeneratePreParamsWithContext(ctx context.Context, optionalConcurrency ...in var err error common.Logger.Info("generating the safe primes for the signing proofs, please wait...") start := time.Now() - sgps, err := common.GetRandomSafePrimesConcurrent(ctx, safePrimeBitLen, 2, concurrency) + sgps, err := common.GetRandomSafePrimesConcurrent(ctx, safePrimeBitLen, 2, concurrency, rand) if err != nil { ch <- nil return @@ -126,8 +136,8 @@ consumer: p, q := sgps[0].Prime(), sgps[1].Prime() modPQ := common.ModInt(new(big.Int).Mul(p, q)) - f1 := common.GetRandomPositiveRelativelyPrimeInt(NTildei) - alpha := common.GetRandomPositiveRelativelyPrimeInt(NTildei) + f1 := common.GetRandomPositiveRelativelyPrimeInt(rand, NTildei) + alpha := common.GetRandomPositiveRelativelyPrimeInt(rand, NTildei) beta := modPQ.ModInverse(alpha) h1i := modNTildeI.Mul(f1, f1) h2i := modNTildeI.Exp(h1i, alpha) diff --git a/ecdsa/keygen/round_1.go b/ecdsa/keygen/round_1.go index 4854e1c4..93b25bfa 100644 --- a/ecdsa/keygen/round_1.go +++ b/ecdsa/keygen/round_1.go @@ -7,6 +7,7 @@ package keygen import ( + "context" "errors" "math/big" @@ -18,14 +19,13 @@ import ( "github.com/bnb-chain/tss-lib/v2/tss" ) -var ( - zero = big.NewInt(0) -) +var zero = big.NewInt(0) // round 1 represents round 1 of the keygen part of the GG18 ECDSA TSS spec (Gennaro, Goldfeder; 2018) func newRound1(params *tss.Parameters, save *LocalPartySaveData, temp *localTempData, out chan<- tss.Message, end chan<- *LocalPartySaveData) tss.Round { return &round1{ - &base{params, save, temp, out, end, make([]bool, len(params.Parties().IDs())), false, 1}} + &base{params, save, temp, out, end, make([]bool, len(params.Parties().IDs())), false, 1}, + } } func (round *round1) Start() *tss.Error { @@ -40,13 +40,13 @@ func (round *round1) Start() *tss.Error { i := Pi.Index // 1. calculate "partial" key share ui - ui := common.GetRandomPositiveInt(round.Params().EC().Params().N) + ui := common.GetRandomPositiveInt(round.PartialKeyRand(), round.EC().Params().N) round.temp.ui = ui // 2. compute the vss shares ids := round.Parties().IDs().Keys() - vs, shares, err := vss.Create(round.Params().EC(), round.Threshold(), ui, ids) + vs, shares, err := vss.Create(round.EC(), round.Threshold(), ui, ids, round.Rand()) if err != nil { return round.WrapError(err, Pi) } @@ -61,7 +61,7 @@ func (round *round1) Start() *tss.Error { if err != nil { return round.WrapError(err, Pi) } - cmt := cmts.NewHashCommitment(pGFlat...) + cmt := cmts.NewHashCommitment(round.Rand(), pGFlat...) // 4. generate Paillier public key E_i, private key and proof // 5-7. generate safe primes for ZKPs used later on @@ -74,9 +74,13 @@ func (round *round1) Start() *tss.Error { } else if round.save.LocalPreParams.ValidateWithProof() { preParams = &round.save.LocalPreParams } else { - preParams, err = GeneratePreParams(round.SafePrimeGenTimeout(), round.Concurrency()) - if err != nil { - return round.WrapError(errors.New("pre-params generation failed"), Pi) + { + ctx, cancel := context.WithTimeout(context.Background(), round.SafePrimeGenTimeout()) + defer cancel() + preParams, err = GeneratePreParamsWithContextAndRandom(ctx, round.Rand(), round.Concurrency()) + if err != nil { + return round.WrapError(errors.New("pre-params generation failed"), Pi) + } } } round.save.LocalPreParams = *preParams @@ -84,16 +88,15 @@ func (round *round1) Start() *tss.Error { round.save.H1j[i], round.save.H2j[i] = preParams.H1i, preParams.H2i // generate the dlnproofs for keygen - h1i, h2i, alpha, beta, p, q, NTildei := - preParams.H1i, + h1i, h2i, alpha, beta, p, q, NTildei := preParams.H1i, preParams.H2i, preParams.Alpha, preParams.Beta, preParams.P, preParams.Q, preParams.NTildei - dlnProof1 := dlnproof.NewDLNProof(h1i, h2i, alpha, p, q, NTildei) - dlnProof2 := dlnproof.NewDLNProof(h2i, h1i, beta, p, q, NTildei) + dlnProof1 := dlnproof.NewDLNProof(h1i, h2i, alpha, p, q, NTildei, round.Rand()) + dlnProof2 := dlnproof.NewDLNProof(h2i, h1i, beta, p, q, NTildei, round.Rand()) // for this P: SAVE // - shareID diff --git a/ecdsa/keygen/round_2.go b/ecdsa/keygen/round_2.go index 3eac8eba..bb6c81b8 100644 --- a/ecdsa/keygen/round_2.go +++ b/ecdsa/keygen/round_2.go @@ -47,8 +47,7 @@ func (round *round2) Start() *tss.Error { wg := new(sync.WaitGroup) for j, msg := range round.temp.kgRound1Messages { r1msg := msg.Content().(*KGRound1Message) - H1j, H2j, NTildej, paillierPKj := - r1msg.UnmarshalH1(), + H1j, H2j, NTildej, paillierPKj := r1msg.UnmarshalH1(), r1msg.UnmarshalH2(), r1msg.UnmarshalNTilde(), r1msg.UnmarshalPaillierPK() @@ -99,8 +98,7 @@ func (round *round2) Start() *tss.Error { continue } r1msg := msg.Content().(*KGRound1Message) - paillierPK, H1j, H2j, NTildej, KGC := - r1msg.UnmarshalPaillierPK(), + paillierPK, H1j, H2j, NTildej, KGC := r1msg.UnmarshalPaillierPK(), r1msg.UnmarshalH1(), r1msg.UnmarshalH2(), r1msg.UnmarshalNTilde(), @@ -116,12 +114,14 @@ func (round *round2) Start() *tss.Error { ContextI := append(round.temp.ssid, big.NewInt(int64(i)).Bytes()...) for j, Pj := range round.Parties().IDs() { - facProof := &facproof.ProofFac{P: zero, Q: zero, A: zero, B: zero, T: zero, Sigma: zero, - Z1: zero, Z2: zero, W1: zero, W2: zero, V: zero} + facProof := &facproof.ProofFac{ + P: zero, Q: zero, A: zero, B: zero, T: zero, Sigma: zero, + Z1: zero, Z2: zero, W1: zero, W2: zero, V: zero, + } if !round.Params().NoProofFac() { var err error facProof, err = facproof.NewProof(ContextI, round.EC(), round.save.PaillierSK.N, round.save.NTildej[j], - round.save.H1j[j], round.save.H2j[j], round.save.PaillierSK.P, round.save.PaillierSK.Q) + round.save.H1j[j], round.save.H2j[j], round.save.PaillierSK.P, round.save.PaillierSK.Q, round.Rand()) if err != nil { return round.WrapError(err, round.PartyID()) } @@ -141,7 +141,7 @@ func (round *round2) Start() *tss.Error { if !round.Parameters.NoProofMod() { var err error modProof, err = modproof.NewProof(ContextI, round.save.PaillierSK.N, - round.save.PaillierSK.P, round.save.PaillierSK.Q) + round.save.PaillierSK.P, round.save.PaillierSK.Q, round.Rand()) if err != nil { return round.WrapError(err, round.PartyID()) } diff --git a/ecdsa/resharing/round_1_old_step_1.go b/ecdsa/resharing/round_1_old_step_1.go index be8a2ec8..167365cf 100644 --- a/ecdsa/resharing/round_1_old_step_1.go +++ b/ecdsa/resharing/round_1_old_step_1.go @@ -22,7 +22,8 @@ import ( // round 1 represents round 1 of the keygen part of the GG18 ECDSA TSS spec (Gennaro, Goldfeder; 2018) func newRound1(params *tss.ReSharingParameters, input, save *keygen.LocalPartySaveData, temp *localTempData, out chan<- tss.Message, end chan<- *keygen.LocalPartySaveData) tss.Round { return &round1{ - &base{params, temp, input, save, out, end, make([]bool, len(params.OldParties().IDs())), make([]bool, len(params.NewParties().IDs())), false, 1}} + &base{params, temp, input, save, out, end, make([]bool, len(params.OldParties().IDs())), make([]bool, len(params.NewParties().IDs())), false, 1}, + } } func (round *round1) Start() *tss.Error { @@ -57,7 +58,7 @@ func (round *round1) Start() *tss.Error { wi, _ := signing.PrepareForSigning(round.Params().EC(), i, len(round.OldParties().IDs()), xi, ks, bigXj) // 2. - vi, shares, err := vss.Create(round.Params().EC(), round.NewThreshold(), wi, newKs) + vi, shares, err := vss.Create(round.Params().EC(), round.NewThreshold(), wi, newKs, round.Rand()) if err != nil { return round.WrapError(err, round.PartyID()) } @@ -67,7 +68,7 @@ func (round *round1) Start() *tss.Error { if err != nil { return round.WrapError(err, round.PartyID()) } - vCmt := commitments.NewHashCommitment(flatVis...) + vCmt := commitments.NewHashCommitment(round.Rand(), flatVis...) // 4. populate temp data round.temp.VD = vCmt.D diff --git a/ecdsa/resharing/round_2_new_step_1.go b/ecdsa/resharing/round_2_new_step_1.go index 8a47283a..2ed876fe 100644 --- a/ecdsa/resharing/round_2_new_step_1.go +++ b/ecdsa/resharing/round_2_new_step_1.go @@ -18,9 +18,7 @@ import ( "github.com/bnb-chain/tss-lib/v2/tss" ) -var ( - zero = big.NewInt(0) -) +var zero = big.NewInt(0) func (round *round2) Start() *tss.Error { if round.started { @@ -82,22 +80,21 @@ func (round *round2) Start() *tss.Error { round.save.H1j[i], round.save.H2j[i] = preParams.H1i, preParams.H2i // generate the dlnproofs for resharing - h1i, h2i, alpha, beta, p, q, NTildei := - preParams.H1i, + h1i, h2i, alpha, beta, p, q, NTildei := preParams.H1i, preParams.H2i, preParams.Alpha, preParams.Beta, preParams.P, preParams.Q, preParams.NTildei - dlnProof1 := dlnproof.NewDLNProof(h1i, h2i, alpha, p, q, NTildei) - dlnProof2 := dlnproof.NewDLNProof(h2i, h1i, beta, p, q, NTildei) + dlnProof1 := dlnproof.NewDLNProof(h1i, h2i, alpha, p, q, NTildei, round.Rand()) + dlnProof2 := dlnproof.NewDLNProof(h2i, h1i, beta, p, q, NTildei, round.Rand()) modProof := &modproof.ProofMod{W: zero, X: *new([80]*big.Int), A: zero, B: zero, Z: *new([80]*big.Int)} ContextI := append(round.temp.ssid, big.NewInt(int64(i)).Bytes()...) if !round.Parameters.NoProofMod() { var err error - modProof, err = modproof.NewProof(ContextI, preParams.PaillierSK.N, preParams.PaillierSK.P, preParams.PaillierSK.Q) + modProof, err = modproof.NewProof(ContextI, preParams.PaillierSK.N, preParams.PaillierSK.P, preParams.PaillierSK.Q, round.Rand()) if err != nil { return round.WrapError(err, Pi) } diff --git a/ecdsa/resharing/round_4_new_step_2.go b/ecdsa/resharing/round_4_new_step_2.go index f3115c59..5453bc62 100644 --- a/ecdsa/resharing/round_4_new_step_2.go +++ b/ecdsa/resharing/round_4_new_step_2.go @@ -58,8 +58,7 @@ func (round *round4) Start() *tss.Error { wg := new(sync.WaitGroup) for j, msg := range round.temp.dgRound2Message1s { r2msg1 := msg.Content().(*DGRound2Message1) - paiPK, NTildej, H1j, H2j := - r2msg1.UnmarshalPaillierPK(), + paiPK, NTildej, H1j, H2j := r2msg1.UnmarshalPaillierPK(), r2msg1.UnmarshalNTilde(), r2msg1.UnmarshalH1(), r2msg1.UnmarshalH2() @@ -218,11 +217,13 @@ func (round *round4) Start() *tss.Error { continue } ContextJ := common.AppendBigIntToBytesSlice(round.temp.ssid, big.NewInt(int64(j))) - facProof := &facproof.ProofFac{P: zero, Q: zero, A: zero, B: zero, T: zero, Sigma: zero, - Z1: zero, Z2: zero, W1: zero, W2: zero, V: zero} + facProof := &facproof.ProofFac{ + P: zero, Q: zero, A: zero, B: zero, T: zero, Sigma: zero, + Z1: zero, Z2: zero, W1: zero, W2: zero, V: zero, + } if !round.Parameters.NoProofFac() { facProof, err = facproof.NewProof(ContextJ, round.EC(), round.save.PaillierSK.N, round.save.NTildej[j], - round.save.H1j[j], round.save.H2j[j], round.save.PaillierSK.P, round.save.PaillierSK.Q) + round.save.H1j[j], round.save.H2j[j], round.save.PaillierSK.P, round.save.PaillierSK.Q, round.Rand()) if err != nil { return round.WrapError(err, Pi) } diff --git a/ecdsa/signing/local_party.go b/ecdsa/signing/local_party.go index ae34ad00..3a9bceeb 100644 --- a/ecdsa/signing/local_party.go +++ b/ecdsa/signing/local_party.go @@ -21,8 +21,10 @@ import ( // Implements Party // Implements Stringer -var _ tss.Party = (*LocalParty)(nil) -var _ fmt.Stringer = (*LocalParty)(nil) +var ( + _ tss.Party = (*LocalParty)(nil) + _ fmt.Stringer = (*LocalParty)(nil) +) type ( LocalParty struct { @@ -102,7 +104,8 @@ func NewLocalParty( params *tss.Parameters, key keygen.LocalPartySaveData, out chan<- tss.Message, - end chan<- *common.SignatureData) tss.Party { + end chan<- *common.SignatureData, +) tss.Party { return NewLocalPartyWithKDD(msg, params, key, nil, out, end) } diff --git a/ecdsa/signing/local_party_test.go b/ecdsa/signing/local_party_test.go index a1680b5f..bb07bd4c 100644 --- a/ecdsa/signing/local_party_test.go +++ b/ecdsa/signing/local_party_test.go @@ -8,6 +8,7 @@ package signing import ( "crypto/ecdsa" + "crypto/rand" "fmt" "math/big" "runtime" @@ -144,7 +145,7 @@ func TestE2EWithHDKeyDerivation(t *testing.T) { chainCode := make([]byte, 32) max32b := new(big.Int).Lsh(new(big.Int).SetUint64(1), 256) max32b = new(big.Int).Sub(max32b, new(big.Int).SetUint64(1)) - fillBytes(common.GetRandomPositiveInt(max32b), chainCode) + fillBytes(common.GetRandomPositiveInt(rand.Reader, max32b), chainCode) il, extendedChildPk, errorDerivation := derivingPubkeyFromPath(keys[0].ECDSAPub, chainCode, []uint32{12, 209, 3}, btcec.S256()) assert.NoErrorf(t, errorDerivation, "there should not be an error deriving the child public key") diff --git a/ecdsa/signing/round_1.go b/ecdsa/signing/round_1.go index 924b080f..7089848f 100644 --- a/ecdsa/signing/round_1.go +++ b/ecdsa/signing/round_1.go @@ -19,14 +19,13 @@ import ( "github.com/bnb-chain/tss-lib/v2/tss" ) -var ( - zero = big.NewInt(0) -) +var zero = big.NewInt(0) // round 1 represents round 1 of the signing part of the GG18 ECDSA TSS spec (Gennaro, Goldfeder; 2018) func newRound1(params *tss.Parameters, key *keygen.LocalPartySaveData, data *common.SignatureData, temp *localTempData, out chan<- tss.Message, end chan<- *common.SignatureData) tss.Round { return &round1{ - &base{params, key, data, temp, out, end, make([]bool, len(params.Parties().IDs())), false, 1}} + &base{params, key, data, temp, out, end, make([]bool, len(params.Parties().IDs())), false, 1}, + } } func (round *round1) Start() *tss.Error { @@ -52,11 +51,11 @@ func (round *round1) Start() *tss.Error { } round.temp.ssid = ssid - k := common.GetRandomPositiveInt(round.Params().EC().Params().N) - gamma := common.GetRandomPositiveInt(round.Params().EC().Params().N) + k := common.GetRandomPositiveInt(round.Rand(), round.EC().Params().N) + gamma := common.GetRandomPositiveInt(round.Rand(), round.EC().Params().N) pointGamma := crypto.ScalarBaseMult(round.Params().EC(), gamma) - cmt := commitments.NewHashCommitment(pointGamma.X(), pointGamma.Y()) + cmt := commitments.NewHashCommitment(round.Rand(), pointGamma.X(), pointGamma.Y()) round.temp.k = k round.temp.gamma = gamma round.temp.pointGamma = pointGamma @@ -69,7 +68,7 @@ func (round *round1) Start() *tss.Error { if j == i { continue } - cA, pi, err := mta.AliceInit(round.Params().EC(), round.key.PaillierPKs[i], k, round.key.NTildej[j], round.key.H1j[j], round.key.H2j[j]) + cA, pi, err := mta.AliceInit(round.Params().EC(), round.key.PaillierPKs[i], k, round.key.NTildej[j], round.key.H1j[j], round.key.H2j[j], round.Rand()) if err != nil { return round.WrapError(fmt.Errorf("failed to init mta: %v", err)) } diff --git a/ecdsa/signing/round_2.go b/ecdsa/signing/round_2.go index 0568f3dd..003359df 100644 --- a/ecdsa/signing/round_2.go +++ b/ecdsa/signing/round_2.go @@ -57,7 +57,9 @@ func (round *round2) Start() *tss.Error { round.key.H2j[j], round.key.NTildej[i], round.key.H1j[i], - round.key.H2j[i]) + round.key.H2j[i], + round.Rand(), + ) // should be thread safe as these are pre-allocated round.temp.betas[j] = beta round.temp.c1jis[j] = c1ji @@ -88,7 +90,9 @@ func (round *round2) Start() *tss.Error { round.key.NTildej[i], round.key.H1j[i], round.key.H2j[i], - round.temp.bigWs[i]) + round.temp.bigWs[i], + round.Rand(), + ) round.temp.vs[j] = v round.temp.c2jis[j] = c2ji round.temp.pi2jis[j] = pi2ji diff --git a/ecdsa/signing/round_4.go b/ecdsa/signing/round_4.go index e7183690..8ee0cff1 100644 --- a/ecdsa/signing/round_4.go +++ b/ecdsa/signing/round_4.go @@ -43,7 +43,7 @@ func (round *round4) Start() *tss.Error { thetaInverse = modN.ModInverse(thetaInverse) i := round.PartyID().Index ContextI := append(round.temp.ssid, new(big.Int).SetUint64(uint64(i)).Bytes()...) - piGamma, err := schnorr.NewZKProof(ContextI, round.temp.gamma, round.temp.pointGamma) + piGamma, err := schnorr.NewZKProof(ContextI, round.temp.gamma, round.temp.pointGamma, round.Rand()) if err != nil { return round.WrapError(errors2.Wrapf(err, "NewZKProof(gamma, bigGamma)")) } diff --git a/ecdsa/signing/round_5.go b/ecdsa/signing/round_5.go index 3cbe19c5..01b5a726 100644 --- a/ecdsa/signing/round_5.go +++ b/ecdsa/signing/round_5.go @@ -69,8 +69,8 @@ func (round *round5) Start() *tss.Error { round.temp.w = zero round.temp.k = zero - li := common.GetRandomPositiveInt(N) // li - roI := common.GetRandomPositiveInt(N) // pi + li := common.GetRandomPositiveInt(round.Rand(), N) // li + roI := common.GetRandomPositiveInt(round.Rand(), N) // pi rToSi := R.ScalarMult(si) liPoint := crypto.ScalarBaseMult(round.Params().EC(), li) bigAi := crypto.ScalarBaseMult(round.Params().EC(), roI) @@ -79,7 +79,7 @@ func (round *round5) Start() *tss.Error { return round.WrapError(errors2.Wrapf(err, "rToSi.Add(li)")) } - cmt := commitments.NewHashCommitment(bigVi.X(), bigVi.Y(), bigAi.X(), bigAi.Y()) + cmt := commitments.NewHashCommitment(round.Rand(), bigVi.X(), bigVi.Y(), bigAi.X(), bigAi.Y()) r5msg := NewSignRound5Message(round.PartyID(), cmt.C) round.temp.signRound5Messages[round.PartyID().Index] = r5msg round.out <- r5msg diff --git a/ecdsa/signing/round_6.go b/ecdsa/signing/round_6.go index 278d19be..228be118 100644 --- a/ecdsa/signing/round_6.go +++ b/ecdsa/signing/round_6.go @@ -26,11 +26,11 @@ func (round *round6) Start() *tss.Error { i := round.PartyID().Index ContextI := append(round.temp.ssid, new(big.Int).SetUint64(uint64(i)).Bytes()...) - piAi, err := schnorr.NewZKProof(ContextI, round.temp.roi, round.temp.bigAi) + piAi, err := schnorr.NewZKProof(ContextI, round.temp.roi, round.temp.bigAi, round.Rand()) if err != nil { return round.WrapError(errors2.Wrapf(err, "NewZKProof(roi, bigAi)")) } - piV, err := schnorr.NewZKVProof(ContextI, round.temp.bigVi, round.temp.bigR, round.temp.si, round.temp.li) + piV, err := schnorr.NewZKVProof(ContextI, round.temp.bigVi, round.temp.bigR, round.temp.si, round.temp.li, round.Rand()) if err != nil { return round.WrapError(errors2.Wrapf(err, "NewZKVProof(bigVi, bigR, si, li)")) } diff --git a/ecdsa/signing/round_7.go b/ecdsa/signing/round_7.go index a50be65d..072bbf54 100644 --- a/ecdsa/signing/round_7.go +++ b/ecdsa/signing/round_7.go @@ -83,7 +83,7 @@ func (round *round7) Start() *tss.Error { TiX, TiY := round.Params().EC().ScalarMult(AX, AY, round.temp.li.Bytes()) round.temp.Ui = crypto.NewECPointNoCurveCheck(round.Params().EC(), UiX, UiY) round.temp.Ti = crypto.NewECPointNoCurveCheck(round.Params().EC(), TiX, TiY) - cmt := commitments.NewHashCommitment(UiX, UiY, TiX, TiY) + cmt := commitments.NewHashCommitment(round.Rand(), UiX, UiY, TiX, TiY) r7msg := NewSignRound7Message(round.PartyID(), cmt.C) round.temp.signRound7Messages[round.PartyID().Index] = r7msg round.out <- r7msg diff --git a/eddsa/keygen/round_1.go b/eddsa/keygen/round_1.go index 9fc694e0..f8679596 100644 --- a/eddsa/keygen/round_1.go +++ b/eddsa/keygen/round_1.go @@ -17,14 +17,13 @@ import ( "github.com/bnb-chain/tss-lib/v2/tss" ) -var ( - zero = big.NewInt(0) -) +var zero = big.NewInt(0) // round 1 represents round 1 of the keygen part of the EDDSA TSS spec func newRound1(params *tss.Parameters, save *LocalPartySaveData, temp *localTempData, out chan<- tss.Message, end chan<- *LocalPartySaveData) tss.Round { return &round1{ - &base{params, save, temp, out, end, make([]bool, len(params.Parties().IDs())), false, 1}} + &base{params, save, temp, out, end, make([]bool, len(params.Parties().IDs())), false, 1}, + } } func (round *round1) Start() *tss.Error { @@ -46,12 +45,12 @@ func (round *round1) Start() *tss.Error { round.temp.ssid = ssid // 1. calculate "partial" key share ui - ui := common.GetRandomPositiveInt(round.Params().EC().Params().N) + ui := common.GetRandomPositiveInt(round.PartialKeyRand(), round.Params().EC().Params().N) round.temp.ui = ui // 2. compute the vss shares ids := round.Parties().IDs().Keys() - vs, shares, err := vss.Create(round.Params().EC(), round.Threshold(), ui, ids) + vs, shares, err := vss.Create(round.EC(), round.Threshold(), ui, ids, round.Rand()) if err != nil { return round.WrapError(err, Pi) } @@ -66,7 +65,7 @@ func (round *round1) Start() *tss.Error { if err != nil { return round.WrapError(err, Pi) } - cmt := cmts.NewHashCommitment(pGFlat...) + cmt := cmts.NewHashCommitment(round.Rand(), pGFlat...) // for this P: SAVE // - shareID diff --git a/eddsa/keygen/round_2.go b/eddsa/keygen/round_2.go index d1a603b4..0e5ea25e 100644 --- a/eddsa/keygen/round_2.go +++ b/eddsa/keygen/round_2.go @@ -47,7 +47,7 @@ func (round *round2) Start() *tss.Error { // 5. compute Schnorr prove ContextI := append(round.temp.ssid, new(big.Int).SetUint64(uint64(i)).Bytes()...) - pii, err := schnorr.NewZKProof(ContextI, round.temp.ui, round.temp.vs[0]) + pii, err := schnorr.NewZKProof(ContextI, round.temp.ui, round.temp.vs[0], round.Rand()) if err != nil { return round.WrapError(errors2.Wrapf(err, "NewZKProof(ui, vi0)")) } diff --git a/eddsa/resharing/round_1_old_step_1.go b/eddsa/resharing/round_1_old_step_1.go index 9f997d58..d717bfd2 100644 --- a/eddsa/resharing/round_1_old_step_1.go +++ b/eddsa/resharing/round_1_old_step_1.go @@ -21,7 +21,8 @@ import ( // round 1 represents round 1 of the keygen part of the EDDSA TSS spec func newRound1(params *tss.ReSharingParameters, input, save *keygen.LocalPartySaveData, temp *localTempData, out chan<- tss.Message, end chan<- *keygen.LocalPartySaveData) tss.Round { return &round1{ - &base{params, temp, input, save, out, end, make([]bool, len(params.OldParties().IDs())), make([]bool, len(params.NewParties().IDs())), false, 1}} + &base{params, temp, input, save, out, end, make([]bool, len(params.OldParties().IDs())), make([]bool, len(params.NewParties().IDs())), false, 1}, + } } func (round *round1) Start() *tss.Error { @@ -50,7 +51,7 @@ func (round *round1) Start() *tss.Error { wi := signing.PrepareForSigning(round.Params().EC(), i, len(round.OldParties().IDs()), xi, ks) // 2. - vi, shares, err := vss.Create(round.Params().EC(), round.NewThreshold(), wi, newKs) + vi, shares, err := vss.Create(round.Params().EC(), round.NewThreshold(), wi, newKs, round.Rand()) if err != nil { return round.WrapError(err, round.PartyID()) } @@ -60,7 +61,7 @@ func (round *round1) Start() *tss.Error { if err != nil { return round.WrapError(err, round.PartyID()) } - vCmt := commitments.NewHashCommitment(flatVis...) + vCmt := commitments.NewHashCommitment(round.Rand(), flatVis...) // 4. populate temp data round.temp.VD = vCmt.D diff --git a/eddsa/signing/round_1.go b/eddsa/signing/round_1.go index 629d0b4d..82e434fd 100644 --- a/eddsa/signing/round_1.go +++ b/eddsa/signing/round_1.go @@ -21,7 +21,8 @@ import ( // round 1 represents round 1 of the signing part of the EDDSA TSS spec func newRound1(params *tss.Parameters, key *keygen.LocalPartySaveData, data *common.SignatureData, temp *localTempData, out chan<- tss.Message, end chan<- *common.SignatureData) tss.Round { return &round1{ - &base{params, key, data, temp, out, end, make([]bool, len(params.Parties().IDs())), false, 1}} + &base{params, key, data, temp, out, end, make([]bool, len(params.Parties().IDs())), false, 1}, + } } func (round *round1) Start() *tss.Error { @@ -40,11 +41,11 @@ func (round *round1) Start() *tss.Error { return round.WrapError(err) } // 1. select ri - ri := common.GetRandomPositiveInt(round.Params().EC().Params().N) + ri := common.GetRandomPositiveInt(round.Rand(), round.Params().EC().Params().N) // 2. make commitment pointRi := crypto.ScalarBaseMult(round.Params().EC(), ri) - cmt := commitments.NewHashCommitment(pointRi.X(), pointRi.Y()) + cmt := commitments.NewHashCommitment(round.Rand(), pointRi.X(), pointRi.Y()) // 3. store r1 message pieces round.temp.ri = ri diff --git a/eddsa/signing/round_2.go b/eddsa/signing/round_2.go index 82f6c12a..2f1802f5 100644 --- a/eddsa/signing/round_2.go +++ b/eddsa/signing/round_2.go @@ -34,7 +34,7 @@ func (round *round2) Start() *tss.Error { // 2. compute Schnorr prove ContextI := append(round.temp.ssid, new(big.Int).SetUint64(uint64(i)).Bytes()...) - pir, err := schnorr.NewZKProof(ContextI, round.temp.ri, round.temp.pointRi) + pir, err := schnorr.NewZKProof(ContextI, round.temp.ri, round.temp.pointRi, round.Rand()) if err != nil { return round.WrapError(errors2.Wrapf(err, "NewZKProof(ri, pointRi)")) } diff --git a/eddsa/signing/round_3.go b/eddsa/signing/round_3.go index 715c3a73..e63d3dba 100644 --- a/eddsa/signing/round_3.go +++ b/eddsa/signing/round_3.go @@ -66,7 +66,7 @@ func (round *round3) Start() *tss.Error { return round.WrapError(errors.New("failed to prove Rj"), Pj) } - extendedRj := ecPointToExtendedElement(round.Params().EC(), Rj.X(), Rj.Y()) + extendedRj := ecPointToExtendedElement(round.Params().EC(), Rj.X(), Rj.Y(), round.Rand()) R = addExtendedElements(R, extendedRj) } diff --git a/eddsa/signing/utils.go b/eddsa/signing/utils.go index d3b41f8f..1aa44769 100644 --- a/eddsa/signing/utils.go +++ b/eddsa/signing/utils.go @@ -8,6 +8,7 @@ package signing import ( "crypto/elliptic" + "io" "math/big" "github.com/agl/ed25519/edwards25519" @@ -100,11 +101,11 @@ func addExtendedElements(p, q edwards25519.ExtendedGroupElement) edwards25519.Ex return result } -func ecPointToExtendedElement(ec elliptic.Curve, x *big.Int, y *big.Int) edwards25519.ExtendedGroupElement { +func ecPointToExtendedElement(ec elliptic.Curve, x *big.Int, y *big.Int, rand io.Reader) edwards25519.ExtendedGroupElement { encodedXBytes := bigIntToEncodedBytes(x) encodedYBytes := bigIntToEncodedBytes(y) - z := common.GetRandomPositiveInt(ec.Params().N) + z := common.GetRandomPositiveInt(rand, ec.Params().N) encodedZBytes := bigIntToEncodedBytes(z) var fx, fy, fxy edwards25519.FieldElement diff --git a/tss/params.go b/tss/params.go index 114f4848..ee17759f 100644 --- a/tss/params.go +++ b/tss/params.go @@ -8,6 +8,8 @@ package tss import ( "crypto/elliptic" + "crypto/rand" + "io" "runtime" "time" ) @@ -26,6 +28,8 @@ type ( // for keygen noProofMod bool noProofFac bool + // random sources + partialKeyRand, rand io.Reader } ReSharingParameters struct { @@ -50,6 +54,8 @@ func NewParameters(ec elliptic.Curve, ctx *PeerContext, partyID *PartyID, partyC threshold: threshold, concurrency: runtime.GOMAXPROCS(0), safePrimeGenTimeout: defaultSafePrimeGenTimeout, + partialKeyRand: rand.Reader, + rand: rand.Reader, } } @@ -106,6 +112,22 @@ func (params *Parameters) SetNoProofFac() { params.noProofFac = true } +func (params *Parameters) PartialKeyRand() io.Reader { + return params.partialKeyRand +} + +func (params *Parameters) Rand() io.Reader { + return params.rand +} + +func (params *Parameters) SetPartialKeyRand(rand io.Reader) { + params.partialKeyRand = rand +} + +func (params *Parameters) SetRand(rand io.Reader) { + params.rand = rand +} + // ----- // // Exported, used in `tss` client diff --git a/tss/party_id.go b/tss/party_id.go index e21c9310..08ec9a5a 100644 --- a/tss/party_id.go +++ b/tss/party_id.go @@ -7,6 +7,7 @@ package tss import ( + "crypto/rand" "fmt" "math/big" "sort" @@ -80,7 +81,7 @@ func SortPartyIDs(ids UnSortedPartyIDs, startAt ...int) SortedPartyIDs { // GenerateTestPartyIDs generates a list of mock PartyIDs for tests func GenerateTestPartyIDs(count int, startAt ...int) SortedPartyIDs { ids := make(UnSortedPartyIDs, 0, count) - key := common.MustGetRandomInt(256) + key := common.MustGetRandomInt(rand.Reader, 256) frm := 0 i := 0 // default `i` if len(startAt) > 0 {