diff --git a/ecdsa/signing/finalize.go b/ecdsa/signing/finalize.go index 995702fe..91f65c44 100644 --- a/ecdsa/signing/finalize.go +++ b/ecdsa/signing/finalize.go @@ -61,14 +61,21 @@ func (round *finalization) Start() *tss.Error { round.data.S = padToLengthBytesInPlace(sumS.Bytes(), bitSizeInBytes) round.data.Signature = append(round.data.R, round.data.S...) round.data.SignatureRecovery = []byte{byte(recid)} - round.data.M = round.temp.m.Bytes() + if round.temp.fullBytesLen == 0 { + round.data.M = round.temp.m.Bytes() + } else { + var mBytes = make([]byte, round.temp.fullBytesLen) + round.temp.m.FillBytes(mBytes) + round.data.M = mBytes + } pk := ecdsa.PublicKey{ Curve: round.Params().EC(), X: round.key.ECDSAPub.X(), Y: round.key.ECDSAPub.Y(), } - ok := ecdsa.Verify(&pk, round.temp.m.Bytes(), round.temp.rx, sumS) + + ok := ecdsa.Verify(&pk, round.data.M, round.temp.rx, sumS) if !ok { return round.WrapError(fmt.Errorf("signature verification failed")) } diff --git a/ecdsa/signing/local_party.go b/ecdsa/signing/local_party.go index 3a9bceeb..5d343055 100644 --- a/ecdsa/signing/local_party.go +++ b/ecdsa/signing/local_party.go @@ -65,10 +65,11 @@ type ( sigma, keyDerivationDelta, gamma *big.Int - cis []*big.Int - bigWs []*crypto.ECPoint - pointGamma *crypto.ECPoint - deCommit cmt.HashDeCommitment + fullBytesLen int + cis []*big.Int + bigWs []*crypto.ECPoint + pointGamma *crypto.ECPoint + deCommit cmt.HashDeCommitment // round 2 betas, // return value of Bob_mid @@ -105,8 +106,8 @@ func NewLocalParty( key keygen.LocalPartySaveData, out chan<- tss.Message, end chan<- *common.SignatureData, -) tss.Party { - return NewLocalPartyWithKDD(msg, params, key, nil, out, end) + fullBytesLen ...int) tss.Party { + return NewLocalPartyWithKDD(msg, params, key, nil, out, end, fullBytesLen...) } // NewLocalPartyWithKDD returns a party with key derivation delta for HD support @@ -117,6 +118,7 @@ func NewLocalPartyWithKDD( keyDerivationDelta *big.Int, out chan<- tss.Message, end chan<- *common.SignatureData, + fullBytesLen ...int, ) tss.Party { partyCount := len(params.Parties().IDs()) p := &LocalParty{ @@ -142,6 +144,11 @@ func NewLocalPartyWithKDD( // temp data init p.temp.keyDerivationDelta = keyDerivationDelta p.temp.m = msg + if len(fullBytesLen) > 0 { + p.temp.fullBytesLen = fullBytesLen[0] + } else { + p.temp.fullBytesLen = 0 + } p.temp.cis = make([]*big.Int, partyCount) p.temp.bigWs = make([]*crypto.ECPoint, partyCount) p.temp.betas = make([]*big.Int, partyCount) diff --git a/ecdsa/signing/local_party_test.go b/ecdsa/signing/local_party_test.go index bb07bd4c..33f41dc5 100644 --- a/ecdsa/signing/local_party_test.go +++ b/ecdsa/signing/local_party_test.go @@ -9,6 +9,7 @@ package signing import ( "crypto/ecdsa" "crypto/rand" + "encoding/hex" "fmt" "math/big" "runtime" @@ -56,11 +57,9 @@ func TestE2EConcurrent(t *testing.T) { endCh := make(chan *common.SignatureData, len(signPIDs)) updater := test.SharedPartyUpdater - // init the parties for i := 0; i < len(signPIDs); i++ { params := tss.NewParameters(tss.S256(), p2pCtx, signPIDs[i], len(signPIDs), threshold) - P := NewLocalParty(big.NewInt(42), params, keys[i], outCh, endCh).(*LocalParty) parties = append(parties, P) go func(P *LocalParty) { @@ -132,6 +131,101 @@ signing: } } +func TestE2EConcurrentWithLeadingZeroInMSG(t *testing.T) { + setUp("info") + threshold := testThreshold + + // PHASE: load keygen fixtures + keys, signPIDs, err := keygen.LoadKeygenTestFixturesRandomSet(testThreshold+1, testParticipants) + assert.NoError(t, err, "should load keygen fixtures") + assert.Equal(t, testThreshold+1, len(keys)) + assert.Equal(t, testThreshold+1, len(signPIDs)) + + // PHASE: signing + // use a shuffled selection of the list of parties for this test + p2pCtx := tss.NewPeerContext(signPIDs) + parties := make([]*LocalParty, 0, len(signPIDs)) + + errCh := make(chan *tss.Error, len(signPIDs)) + outCh := make(chan tss.Message, len(signPIDs)) + endCh := make(chan *common.SignatureData, len(signPIDs)) + + updater := test.SharedPartyUpdater + msgData, _ := hex.DecodeString("00f163ee51bcaeff9cdff5e0e3c1a646abd19885fffbab0b3b4236e0cf95c9f5") + // init the parties + for i := 0; i < len(signPIDs); i++ { + params := tss.NewParameters(tss.S256(), p2pCtx, signPIDs[i], len(signPIDs), threshold) + P := NewLocalParty(new(big.Int).SetBytes(msgData), params, keys[i], outCh, endCh, len(msgData)).(*LocalParty) + parties = append(parties, P) + go func(P *LocalParty) { + if err := P.Start(); err != nil { + errCh <- err + } + }(P) + } + + var ended int32 +signing: + for { + fmt.Printf("ACTIVE GOROUTINES: %d\n", runtime.NumGoroutine()) + select { + case err := <-errCh: + common.Logger.Errorf("Error: %s", err) + assert.FailNow(t, err.Error()) + break signing + + case msg := <-outCh: + dest := msg.GetTo() + if dest == nil { + for _, P := range parties { + if P.PartyID().Index == msg.GetFrom().Index { + continue + } + go updater(P, msg, errCh) + } + } else { + if dest[0].Index == msg.GetFrom().Index { + t.Fatalf("party %d tried to send a message to itself (%d)", dest[0].Index, msg.GetFrom().Index) + } + go updater(parties[dest[0].Index], msg, errCh) + } + + case <-endCh: + atomic.AddInt32(&ended, 1) + if atomic.LoadInt32(&ended) == int32(len(signPIDs)) { + t.Logf("Done. Received signature data from %d participants", ended) + R := parties[0].temp.bigR + r := parties[0].temp.rx + fmt.Printf("sign result: R(%s, %s), r=%s\n", R.X().String(), R.Y().String(), r.String()) + + modN := common.ModInt(tss.S256().Params().N) + + // BEGIN check s correctness + sumS := big.NewInt(0) + for _, p := range parties { + sumS = modN.Add(sumS, p.temp.si) + } + fmt.Printf("S: %s\n", sumS.String()) + // END check s correctness + + // BEGIN ECDSA verify + pkX, pkY := keys[0].ECDSAPub.X(), keys[0].ECDSAPub.Y() + pk := ecdsa.PublicKey{ + Curve: tss.EC(), + X: pkX, + Y: pkY, + } + ok := ecdsa.Verify(&pk, msgData, R.X(), sumS) + assert.True(t, ok, "ecdsa verify must pass") + t.Log("ECDSA signing test done.") + // END ECDSA verify + + break signing + } + } + } +} + func TestE2EWithHDKeyDerivation(t *testing.T) { setUp("info") threshold := testThreshold @@ -170,7 +264,7 @@ func TestE2EWithHDKeyDerivation(t *testing.T) { for i := 0; i < len(signPIDs); i++ { params := tss.NewParameters(tss.S256(), p2pCtx, signPIDs[i], len(signPIDs), threshold) - P := NewLocalPartyWithKDD(big.NewInt(42), params, keys[i], keyDerivationDelta, outCh, endCh).(*LocalParty) + P := NewLocalPartyWithKDD(big.NewInt(42), params, keys[i], keyDerivationDelta, outCh, endCh, 0).(*LocalParty) parties = append(parties, P) go func(P *LocalParty) { if err := P.Start(); err != nil { diff --git a/eddsa/signing/finalize.go b/eddsa/signing/finalize.go index aaafd255..f01e8890 100644 --- a/eddsa/signing/finalize.go +++ b/eddsa/signing/finalize.go @@ -12,9 +12,8 @@ import ( "math/big" "github.com/agl/ed25519/edwards25519" - "github.com/decred/dcrd/dcrec/edwards/v2" - "github.com/bnb-chain/tss-lib/v2/tss" + "github.com/decred/dcrd/dcrec/edwards/v2" ) func (round *finalization) Start() *tss.Error { @@ -43,7 +42,13 @@ func (round *finalization) Start() *tss.Error { round.data.Signature = append(bigIntToEncodedBytes(round.temp.r)[:], sumS[:]...) round.data.R = round.temp.r.Bytes() round.data.S = s.Bytes() - round.data.M = round.temp.m.Bytes() + if round.temp.fullBytesLen == 0 { + round.data.M = round.temp.m.Bytes() + } else { + var mBytes = make([]byte, round.temp.fullBytesLen) + round.temp.m.FillBytes(mBytes) + round.data.M = mBytes + } pk := edwards.PublicKey{ Curve: round.Params().EC(), @@ -51,7 +56,7 @@ func (round *finalization) Start() *tss.Error { Y: round.key.EDDSAPub.Y(), } - ok := edwards.Verify(&pk, round.temp.m.Bytes(), round.temp.r, s) + ok := edwards.Verify(&pk, round.data.M, round.temp.r, s) if !ok { return round.WrapError(fmt.Errorf("signature verification failed")) } diff --git a/eddsa/signing/local_party.go b/eddsa/signing/local_party.go index d11a29c4..52e474dc 100644 --- a/eddsa/signing/local_party.go +++ b/eddsa/signing/local_party.go @@ -50,8 +50,9 @@ type ( wi, m, ri *big.Int - pointRi *crypto.ECPoint - deCommit cmt.HashDeCommitment + fullBytesLen int + pointRi *crypto.ECPoint + deCommit cmt.HashDeCommitment // round 2 cjs []*big.Int @@ -71,6 +72,7 @@ func NewLocalParty( key keygen.LocalPartySaveData, out chan<- tss.Message, end chan<- *common.SignatureData, + fullBytesLen ...int, ) tss.Party { partyCount := len(params.Parties().IDs()) p := &LocalParty{ @@ -89,6 +91,11 @@ func NewLocalParty( // temp data init p.temp.m = msg + if len(fullBytesLen) > 0 { + p.temp.fullBytesLen = fullBytesLen[0] + } else { + p.temp.fullBytesLen = 0 + } p.temp.cjs = make([]*big.Int, partyCount) return p } diff --git a/eddsa/signing/local_party_test.go b/eddsa/signing/local_party_test.go index 33f6fa4e..a2c573e8 100644 --- a/eddsa/signing/local_party_test.go +++ b/eddsa/signing/local_party_test.go @@ -7,6 +7,7 @@ package signing import ( + "encoding/hex" "fmt" "math/big" "sync/atomic" @@ -142,3 +143,108 @@ signing: } } } + +func TestE2EConcurrentWithLeadingZeroInMSG(t *testing.T) { + setUp("info") + + threshold := testThreshold + + // PHASE: load keygen fixtures + keys, signPIDs, err := keygen.LoadKeygenTestFixturesRandomSet(testThreshold+1, testParticipants) + assert.NoError(t, err, "should load keygen fixtures") + assert.Equal(t, testThreshold+1, len(keys)) + assert.Equal(t, testThreshold+1, len(signPIDs)) + + // PHASE: signing + + p2pCtx := tss.NewPeerContext(signPIDs) + parties := make([]*LocalParty, 0, len(signPIDs)) + + errCh := make(chan *tss.Error, len(signPIDs)) + outCh := make(chan tss.Message, len(signPIDs)) + endCh := make(chan *common.SignatureData, len(signPIDs)) + + updater := test.SharedPartyUpdater + + msg, _ := hex.DecodeString("00f163ee51bcaeff9cdff5e0e3c1a646abd19885fffbab0b3b4236e0cf95c9f5") + // init the parties + for i := 0; i < len(signPIDs); i++ { + params := tss.NewParameters(tss.Edwards(), p2pCtx, signPIDs[i], len(signPIDs), threshold) + P := NewLocalParty(new(big.Int).SetBytes(msg), params, keys[i], outCh, endCh, len(msg)).(*LocalParty) + parties = append(parties, P) + go func(P *LocalParty) { + if err := P.Start(); err != nil { + errCh <- err + } + }(P) + } + + var ended int32 +signing: + for { + select { + case err := <-errCh: + common.Logger.Errorf("Error: %s", err) + assert.FailNow(t, err.Error()) + break signing + + case msg := <-outCh: + dest := msg.GetTo() + if dest == nil { + for _, P := range parties { + if P.PartyID().Index == msg.GetFrom().Index { + continue + } + go updater(P, msg, errCh) + } + } else { + if dest[0].Index == msg.GetFrom().Index { + t.Fatalf("party %d tried to send a message to itself (%d)", dest[0].Index, msg.GetFrom().Index) + } + go updater(parties[dest[0].Index], msg, errCh) + } + + case <-endCh: + atomic.AddInt32(&ended, 1) + if atomic.LoadInt32(&ended) == int32(len(signPIDs)) { + t.Logf("Done. Received signature data from %d participants", ended) + R := parties[0].temp.r + + // BEGIN check s correctness + sumS := parties[0].temp.si + for i, p := range parties { + if i == 0 { + continue + } + + var tmpSumS [32]byte + edwards25519.ScMulAdd(&tmpSumS, sumS, bigIntToEncodedBytes(big.NewInt(1)), p.temp.si) + sumS = &tmpSumS + } + fmt.Printf("S: %s\n", encodedBytesToBigInt(sumS).String()) + fmt.Printf("R: %s\n", R.String()) + // END check s correctness + + // BEGIN EDDSA verify + pkX, pkY := keys[0].EDDSAPub.X(), keys[0].EDDSAPub.Y() + pk := edwards.PublicKey{ + Curve: tss.Edwards(), + X: pkX, + Y: pkY, + } + + newSig, err := edwards.ParseSignature(parties[0].data.Signature) + if err != nil { + println("new sig error, ", err.Error()) + } + + ok := edwards.Verify(&pk, msg, newSig.R, newSig.S) + assert.True(t, ok, "eddsa verify must pass") + t.Log("EDDSA signing test done.") + // END EDDSA verify + + break signing + } + } + } +} diff --git a/eddsa/signing/round_3.go b/eddsa/signing/round_3.go index d967f8ff..6880e6bf 100644 --- a/eddsa/signing/round_3.go +++ b/eddsa/signing/round_3.go @@ -80,7 +80,13 @@ func (round *round3) Start() *tss.Error { h.Reset() h.Write(encodedR[:]) h.Write(encodedPubKey[:]) - h.Write(round.temp.m.Bytes()) + if round.temp.fullBytesLen == 0 { + h.Write(round.temp.m.Bytes()) + } else { + var mBytes = make([]byte, round.temp.fullBytesLen) + round.temp.m.FillBytes(mBytes) + h.Write(mBytes) + } var lambda [64]byte h.Sum(lambda[:0])