diff --git a/signer/cosigner_nonce_cache.go b/signer/cosigner_nonce_cache.go index eac24d92..6425c8c6 100644 --- a/signer/cosigner_nonce_cache.go +++ b/signer/cosigner_nonce_cache.go @@ -3,6 +3,7 @@ package signer import ( "context" "fmt" + "sort" "sync" "time" @@ -33,6 +34,8 @@ type CosignerNonceCache struct { threshold uint8 cache NonceCache + + pruner NonceCachePruner } type lastCount struct { @@ -58,6 +61,10 @@ func (lc *lastCount) Get() int { return lc.count } +type NonceCachePruner interface { + PruneNonces() +} + type NonceCache struct { cache map[uuid.UUID]*CachedNonce mu sync.RWMutex @@ -88,6 +95,19 @@ func (nc *NonceCache) Set(uuid uuid.UUID, cn *CachedNonce) { nc.cache[uuid] = cn } +func (nc *NonceCache) GetSortedByExpiration() []*CachedNonce { + nc.mu.RLock() + defer nc.mu.RUnlock() + cns := make([]*CachedNonce, 0, len(nc.cache)) + for _, cn := range nc.cache { + cns = append(cns, cn) + } + sort.Slice(cns, func(i, j int) bool { + return cns[i].Expiration.Before(cns[j].Expiration) + }) + return cns +} + type CosignerNoncesRel struct { Cosigner Cosigner Nonces CosignerNonces @@ -116,8 +136,9 @@ func NewCosignerNonceCache( getNoncesInterval time.Duration, getNoncesTimeout time.Duration, threshold uint8, + pruner NonceCachePruner, ) *CosignerNonceCache { - return &CosignerNonceCache{ + cnc := &CosignerNonceCache{ logger: logger, cache: NewNonceCache(), cosigners: cosigners, @@ -125,7 +146,13 @@ func NewCosignerNonceCache( getNoncesInterval: getNoncesInterval, getNoncesTimeout: getNoncesTimeout, threshold: threshold, + pruner: pruner, + } + if pruner == nil { + cnc.pruner = cnc } + + return cnc } func (cnc *CosignerNonceCache) getUuids(n int) []uuid.UUID { @@ -138,7 +165,7 @@ func (cnc *CosignerNonceCache) getUuids(n int) []uuid.UUID { func (cnc *CosignerNonceCache) reconcile(ctx context.Context) { // prune expired nonces - cnc.pruneNonces() + cnc.pruner.PruneNonces() if !cnc.leader.IsLeader() { return @@ -161,15 +188,17 @@ func (cnc *CosignerNonceCache) reconcile(ctx context.Context) { } defer func() { + cnc.lastReconcileNonces.Set(cnc.cache.Size()) cnc.lastReconcileTime = time.Now() }() // calculate how many nonces we need to load to keep up with demand // load 120% the number of nonces we need to keep up with demand, + // plus a couple seconds worth of nonces to account for nonce consumption during LoadN // plus 10 for padding - target := int((cnc.noncesPerMinute/60)*cnc.getNoncesInterval.Seconds()*1.2) + 10 + target := int((cnc.noncesPerMinute/60)*cnc.getNoncesInterval.Seconds()*1.2) + int(cnc.noncesPerMinute/30) + 10 additional := target - remainingNonces if additional <= 0 { // we're ahead of demand, don't load any more @@ -277,7 +306,7 @@ func (cnc *CosignerNonceCache) GetNonces(fastestPeers []Cosigner) (*CosignerUUID cnc.cache.mu.RLock() defer cnc.cache.mu.RUnlock() CheckNoncesLoop: - for u, cn := range cnc.cache.cache { + for _, cn := range cnc.cache.GetSortedByExpiration() { var nonces CosignerNonces for _, p := range fastestPeers { found := false @@ -295,12 +324,12 @@ CheckNoncesLoop: } cnc.cache.mu.RUnlock() - cnc.clearNonce(u) + cnc.clearNonce(cn.UUID) cnc.cache.mu.RLock() // all peers found return &CosignerUUIDNonces{ - UUID: u, + UUID: cn.UUID, Nonces: nonces, }, nil } @@ -316,7 +345,7 @@ CheckNoncesLoop: return nil, fmt.Errorf("no nonces found involving cosigners %+v", cosignerInts) } -func (cnc *CosignerNonceCache) pruneNonces() { +func (cnc *CosignerNonceCache) PruneNonces() { cnc.cache.mu.Lock() defer cnc.cache.mu.Unlock() for u, cn := range cnc.cache.cache { diff --git a/signer/cosigner_nonce_cache_test.go b/signer/cosigner_nonce_cache_test.go index c7ac2586..7a2798d3 100644 --- a/signer/cosigner_nonce_cache_test.go +++ b/signer/cosigner_nonce_cache_test.go @@ -10,6 +10,24 @@ import ( "github.com/stretchr/testify/require" ) +type mockPruner struct { + cnc *CosignerNonceCache + count int + pruned int +} + +func (mp *mockPruner) PruneNonces() { + mp.cnc.cache.mu.Lock() + defer mp.cnc.cache.mu.Unlock() + mp.count++ + for u, cn := range mp.cnc.cache.cache { + if time.Now().After(cn.Expiration) { + mp.pruned++ + delete(mp.cnc.cache.cache, u) + } + } +} + func TestNonceCacheDemand(t *testing.T) { lcs, _ := getTestLocalCosigners(t, 2, 3) cosigners := make([]Cosigner, len(lcs)) @@ -17,6 +35,8 @@ func TestNonceCacheDemand(t *testing.T) { cosigners[i] = lc } + mp := &mockPruner{} + nonceCache := NewCosignerNonceCache( cometlog.NewTMLogger(cometlog.NewSyncWriter(os.Stdout)), cosigners, @@ -24,8 +44,11 @@ func TestNonceCacheDemand(t *testing.T) { 500*time.Millisecond, 100*time.Millisecond, 2, + mp, ) + mp.cnc = nonceCache + ctx, cancel := context.WithCancel(context.Background()) nonceCache.LoadN(ctx, 500) @@ -45,6 +68,9 @@ func TestNonceCacheDemand(t *testing.T) { cancel() - target := int(nonceCache.noncesPerMinute*.01) + 10 + target := int(nonceCache.noncesPerMinute*.01) + int(nonceCache.noncesPerMinute/30) + 100 require.LessOrEqual(t, size, target) + + require.Greater(t, mp.count, 0) + require.Equal(t, 0, mp.pruned) } diff --git a/signer/threshold_validator.go b/signer/threshold_validator.go index 5a201d4b..b8bce14c 100644 --- a/signer/threshold_validator.go +++ b/signer/threshold_validator.go @@ -85,6 +85,7 @@ func NewThresholdValidator( defaultGetNoncesInterval, defaultGetNoncesTimeout, uint8(threshold), + nil, ) return &ThresholdValidator{ logger: logger,