diff --git a/das/daser.go b/das/daser.go index 40eee3d316..7d569f7e0b 100644 --- a/das/daser.go +++ b/das/daser.go @@ -151,6 +151,8 @@ func (d *DASer) sample(ctx context.Context, h *header.ExtendedHeader) error { // short-circuit if pruning is enabled and the header is outside the // availability window if !d.isWithinSamplingWindow(h) { + log.Debugw("skipping header outside sampling window", "height", h.Height(), + "time", h.Time()) return nil } diff --git a/das/daser_test.go b/das/daser_test.go index fd1eb39f7d..9eec6392cc 100644 --- a/das/daser_test.go +++ b/das/daser_test.go @@ -2,6 +2,7 @@ package das import ( "context" + "strconv" "testing" "time" @@ -244,6 +245,42 @@ func TestDASerSampleTimeout(t *testing.T) { } } +// TestDASer_SamplingWindow tests the sampling window determination +// for headers. +func TestDASer_SamplingWindow(t *testing.T) { + ds := ds_sync.MutexWrap(datastore.NewMapDatastore()) + sub := new(headertest.Subscriber) + fserv := &fraudtest.DummyService[*header.ExtendedHeader]{} + getter := getterStub{} + avail := mocks.NewMockAvailability(gomock.NewController(t)) + + // create and start DASer + daser, err := NewDASer(avail, sub, getter, ds, fserv, newBroadcastMock(1), + WithSamplingWindow(time.Second)) + require.NoError(t, err) + + var tests = []struct { + timestamp time.Time + withinWindow bool + }{ + {timestamp: time.Now().Add(-(time.Second * 5)), withinWindow: false}, + {timestamp: time.Now().Add(-(time.Millisecond * 800)), withinWindow: true}, + {timestamp: time.Now().Add(-(time.Hour)), withinWindow: false}, + {timestamp: time.Now().Add(-(time.Hour * 24 * 30)), withinWindow: false}, + {timestamp: time.Now(), withinWindow: true}, + } + + for i, tt := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + eh := headertest.RandExtendedHeader(t) + eh.RawHeader.Time = tt.timestamp + + assert.Equal(t, tt.withinWindow, daser.isWithinSamplingWindow(eh)) + }) + } + +} + // createDASerSubcomponents takes numGetter (number of headers // to store in mockGetter) and numSub (number of headers to store // in the mock header.Subscriber), returning a newly instantiated diff --git a/header/headertest/testing.go b/header/headertest/testing.go index 9907fd7eb4..05f325bcbb 100644 --- a/header/headertest/testing.go +++ b/header/headertest/testing.go @@ -42,6 +42,14 @@ func NewStore(t *testing.T) libhead.Store[*header.ExtendedHeader] { return headertest.NewStore[*header.ExtendedHeader](t, NewTestSuite(t, 3), 10) } +func NewCustomStore( + t *testing.T, + generator headertest.Generator[*header.ExtendedHeader], + numHeaders int, +) libhead.Store[*header.ExtendedHeader] { + return headertest.NewStore[*header.ExtendedHeader](t, generator, numHeaders) +} + // NewTestSuite setups a new test suite with a given number of validators. 
func NewTestSuite(t *testing.T, num int) *TestSuite { valSet, vals := RandValidatorSet(num, 10) @@ -77,8 +85,10 @@ func (s *TestSuite) genesis() *header.ExtendedHeader { return eh } -func MakeCommit(blockID types.BlockID, height int64, round int32, - voteSet *types.VoteSet, validators []types.PrivValidator, now time.Time) (*types.Commit, error) { +func MakeCommit( + blockID types.BlockID, height int64, round int32, + voteSet *types.VoteSet, validators []types.PrivValidator, now time.Time, +) (*types.Commit, error) { // all sign for i := 0; i < len(validators); i++ { @@ -152,7 +162,8 @@ func (s *TestSuite) NextHeader() *header.ExtendedHeader { } func (s *TestSuite) GenRawHeader( - height uint64, lastHeader, lastCommit, dataHash libhead.Hash) *header.RawHeader { + height uint64, lastHeader, lastCommit, dataHash libhead.Hash, +) *header.RawHeader { rh := RandRawHeader(s.t) rh.Height = int64(height) rh.Time = time.Now() @@ -204,6 +215,11 @@ func (s *TestSuite) nextProposer() *types.Validator { // RandExtendedHeader provides an ExtendedHeader fixture. func RandExtendedHeader(t testing.TB) *header.ExtendedHeader { + timestamp := time.Now() + return RandExtendedHeaderAtTimestamp(t, timestamp) +} + +func RandExtendedHeaderAtTimestamp(t testing.TB, timestamp time.Time) *header.ExtendedHeader { dah := share.EmptyRoot() rh := RandRawHeader(t) @@ -214,7 +230,7 @@ func RandExtendedHeader(t testing.TB) *header.ExtendedHeader { voteSet := types.NewVoteSet(rh.ChainID, rh.Height, 0, tmproto.PrecommitType, valSet) blockID := RandBlockID(t) blockID.Hash = rh.Hash() - commit, err := MakeCommit(blockID, rh.Height, 0, voteSet, vals, time.Now()) + commit, err := MakeCommit(blockID, rh.Height, 0, voteSet, vals, timestamp) require.NoError(t, err) return &header.ExtendedHeader{ diff --git a/libs/utils/resetctx.go b/libs/utils/resetctx.go index 3014ba81db..a108cc27b4 100644 --- a/libs/utils/resetctx.go +++ b/libs/utils/resetctx.go @@ -1,6 +1,8 @@ package utils -import "context" +import ( + "context" +) // ResetContextOnError returns a fresh context if the given context has an error. func ResetContextOnError(ctx context.Context) context.Context { diff --git a/nodebuilder/prune/constructors.go b/nodebuilder/prune/constructors.go new file mode 100644 index 0000000000..8cc58aecd9 --- /dev/null +++ b/nodebuilder/prune/constructors.go @@ -0,0 +1,23 @@ +package prune + +import ( + "github.com/ipfs/go-datastore" + + hdr "github.com/celestiaorg/go-header" + + "github.com/celestiaorg/celestia-node/header" + "github.com/celestiaorg/celestia-node/nodebuilder/p2p" + "github.com/celestiaorg/celestia-node/pruner" +) + +func newPrunerService( + p pruner.Pruner, + window pruner.AvailabilityWindow, + getter hdr.Store[*header.ExtendedHeader], + ds datastore.Batching, + opts ...pruner.Option, +) *pruner.Service { + // TODO @renaynay: remove this once pruning implementation + opts = append(opts, pruner.WithDisabledGC()) + return pruner.NewService(p, window, getter, ds, p2p.BlockTime, opts...) 
+} diff --git a/nodebuilder/prune/module.go b/nodebuilder/prune/module.go index 330ef21cdc..9120d196d7 100644 --- a/nodebuilder/prune/module.go +++ b/nodebuilder/prune/module.go @@ -8,13 +8,16 @@ import ( "github.com/celestiaorg/celestia-node/nodebuilder/node" "github.com/celestiaorg/celestia-node/pruner" "github.com/celestiaorg/celestia-node/pruner/archival" + "github.com/celestiaorg/celestia-node/pruner/full" "github.com/celestiaorg/celestia-node/pruner/light" + "github.com/celestiaorg/celestia-node/share/eds" ) func ConstructModule(tp node.Type) fx.Option { + baseComponents := fx.Options( fx.Provide(fx.Annotate( - pruner.NewService, + newPrunerService, fx.OnStart(func(ctx context.Context, p *pruner.Service) error { return p.Start(ctx) }), @@ -22,10 +25,21 @@ func ConstructModule(tp node.Type) fx.Option { return p.Stop(ctx) }), )), + // This is necessary to invoke the pruner service as independent thanks to a + // quirk in FX. + fx.Invoke(func(p *pruner.Service) {}), ) switch tp { - case node.Full, node.Bridge: + case node.Full: + return fx.Module("prune", + baseComponents, + fx.Provide(func(store *eds.Store) pruner.Pruner { + return full.NewPruner(store) + }), + fx.Supply(full.Window), + ) + case node.Bridge: return fx.Module("prune", baseComponents, fx.Provide(func() pruner.Pruner { @@ -39,7 +53,7 @@ func ConstructModule(tp node.Type) fx.Option { fx.Provide(func() pruner.Pruner { return light.NewPruner() }), - fx.Supply(archival.Window), // TODO @renaynay: turn this into light.Window in following PR + fx.Supply(light.Window), ) default: panic("unknown node type") diff --git a/nodebuilder/settings.go b/nodebuilder/settings.go index 298976fda4..136b75e15f 100644 --- a/nodebuilder/settings.go +++ b/nodebuilder/settings.go @@ -30,6 +30,7 @@ import ( "github.com/celestiaorg/celestia-node/nodebuilder/node" "github.com/celestiaorg/celestia-node/nodebuilder/p2p" "github.com/celestiaorg/celestia-node/nodebuilder/share" + "github.com/celestiaorg/celestia-node/pruner" "github.com/celestiaorg/celestia-node/state" ) @@ -93,6 +94,7 @@ func WithMetrics(metricOpts []otlpmetrichttp.Option, nodeType node.Type) fx.Opti fx.Invoke(fraud.WithMetrics[*header.ExtendedHeader]), fx.Invoke(node.WithMetrics), fx.Invoke(share.WithDiscoveryMetrics), + fx.Invoke(pruner.WithPrunerMetrics), ) samplingMetrics := fx.Options( diff --git a/pruner/archival/pruner.go b/pruner/archival/pruner.go index 7b1cb935f3..a1a55db0da 100644 --- a/pruner/archival/pruner.go +++ b/pruner/archival/pruner.go @@ -15,6 +15,6 @@ func NewPruner() *Pruner { return &Pruner{} } -func (p *Pruner) Prune(context.Context, ...*header.ExtendedHeader) error { +func (p *Pruner) Prune(context.Context, *header.ExtendedHeader) error { return nil } diff --git a/pruner/finder.go b/pruner/finder.go new file mode 100644 index 0000000000..482a7478b9 --- /dev/null +++ b/pruner/finder.go @@ -0,0 +1,132 @@ +package pruner + +import ( + "context" + "fmt" + "sync/atomic" + "time" + + "github.com/ipfs/go-datastore" + + "github.com/celestiaorg/celestia-node/header" +) + +var ( + lastPrunedHeaderKey = datastore.NewKey("last_pruned_header") +) + +type checkpoint struct { + ds datastore.Datastore + + lastPrunedHeader atomic.Pointer[header.ExtendedHeader] + + // TODO @renaynay: keep track of failed roots to retry in separate job +} + +func newCheckpoint(ds datastore.Datastore) *checkpoint { + return &checkpoint{ds: ds} +} + +// findPruneableHeaders returns all headers that are eligible for pruning +// (outside the sampling window). 
+func (s *Service) findPruneableHeaders(ctx context.Context) ([]*header.ExtendedHeader, error) { + lastPruned := s.lastPruned() + + pruneCutoff := time.Now().Add(time.Duration(-s.window)) + estimatedCutoffHeight := lastPruned.Height() + s.numBlocksInWindow + + head, err := s.getter.Head(ctx) + if err != nil { + return nil, err + } + if head.Height() < estimatedCutoffHeight { + estimatedCutoffHeight = head.Height() + } + + headers, err := s.getter.GetRangeByHeight(ctx, lastPruned, estimatedCutoffHeight) + if err != nil { + return nil, err + } + + // if our estimated range didn't cover enough headers, we need to fetch more + // TODO: This is really inefficient in the case that lastPruned is the default value, or if the + // node has been offline for a long time. Instead of increasing the boundary by one in the for + // loop we could increase by a range every iteration + headerCount := len(headers) + for { + if headerCount > int(s.maxPruneablePerGC) { + headers = headers[:s.maxPruneablePerGC] + break + } + lastHeader := headers[len(headers)-1] + if lastHeader.Time().After(pruneCutoff) { + break + } + + nextHeader, err := s.getter.GetByHeight(ctx, lastHeader.Height()+1) + if err != nil { + return nil, err + } + headers = append(headers, nextHeader) + headerCount++ + } + + for i, h := range headers { + if h.Time().After(pruneCutoff) { + if i == 0 { + // we can't prune anything + return nil, nil + } + + // we can ignore the rest of the headers since they are all newer than the cutoff + return headers[:i], nil + } + } + return headers, nil +} + +// initializeCheckpoint initializes the checkpoint, storing the earliest header in the chain. +func (s *Service) initializeCheckpoint(ctx context.Context) error { + firstHeader, err := s.getter.GetByHeight(ctx, 1) + if err != nil { + return fmt.Errorf("failed to initialize checkpoint: %w", err) + } + + return s.updateCheckpoint(ctx, firstHeader) +} + +// loadCheckpoint loads the last checkpoint from disk, initializing it if it does not already exist. +func (s *Service) loadCheckpoint(ctx context.Context) error { + bin, err := s.checkpoint.ds.Get(ctx, lastPrunedHeaderKey) + if err != nil { + if err == datastore.ErrNotFound { + return s.initializeCheckpoint(ctx) + } + return fmt.Errorf("failed to load checkpoint: %w", err) + } + + var lastPruned header.ExtendedHeader + if err := lastPruned.UnmarshalJSON(bin); err != nil { + return fmt.Errorf("failed to load checkpoint: %w", err) + } + + s.checkpoint.lastPrunedHeader.Store(&lastPruned) + return nil +} + +// updateCheckpoint updates the checkpoint with the last pruned header height +// and persists it to disk. 
+func (s *Service) updateCheckpoint(ctx context.Context, lastPruned *header.ExtendedHeader) error { + s.checkpoint.lastPrunedHeader.Store(lastPruned) + + bin, err := lastPruned.MarshalJSON() + if err != nil { + return err + } + + return s.checkpoint.ds.Put(ctx, lastPrunedHeaderKey, bin) +} + +func (s *Service) lastPruned() *header.ExtendedHeader { + return s.checkpoint.lastPrunedHeader.Load() +} diff --git a/pruner/full/pruner.go b/pruner/full/pruner.go new file mode 100644 index 0000000000..a5fc38e78e --- /dev/null +++ b/pruner/full/pruner.go @@ -0,0 +1,33 @@ +package full + +import ( + "context" + + logging "github.com/ipfs/go-log/v2" + + "github.com/celestiaorg/celestia-node/header" + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds" +) + +var log = logging.Logger("pruner/full") + +type Pruner struct { + store *eds.Store +} + +func NewPruner(store *eds.Store) *Pruner { + return &Pruner{ + store: store, + } +} + +func (p *Pruner) Prune(ctx context.Context, eh *header.ExtendedHeader) error { + // short circuit on empty roots + if eh.DAH.Equals(share.EmptyRoot()) { + return nil + } + + log.Debugf("pruning header %s", eh.DAH.Hash()) + return p.store.Remove(ctx, eh.DAH.Hash()) +} diff --git a/pruner/full/window.go b/pruner/full/window.go new file mode 100644 index 0000000000..c2a64288ce --- /dev/null +++ b/pruner/full/window.go @@ -0,0 +1,11 @@ +package full + +import ( + "time" + + "github.com/celestiaorg/celestia-node/pruner" +) + +// Window is the availability window for full nodes in the Celestia +// network (30 days). +const Window = pruner.AvailabilityWindow(time.Second * 86400 * 30) diff --git a/pruner/light/pruner.go b/pruner/light/pruner.go index 513bfa2b66..61401bae74 100644 --- a/pruner/light/pruner.go +++ b/pruner/light/pruner.go @@ -12,6 +12,6 @@ func NewPruner() *Pruner { return &Pruner{} } -func (p *Pruner) Prune(context.Context, ...*header.ExtendedHeader) error { +func (p *Pruner) Prune(context.Context, *header.ExtendedHeader) error { return nil } diff --git a/pruner/light/window.go b/pruner/light/window.go index 53bfe4a163..dc1a9e4444 100644 --- a/pruner/light/window.go +++ b/pruner/light/window.go @@ -6,4 +6,6 @@ import ( "github.com/celestiaorg/celestia-node/pruner" ) +// Window is the availability window for light nodes in the Celestia +// network (30 days).
const Window = pruner.AvailabilityWindow(time.Second * 86400 * 30) diff --git a/pruner/metrics.go b/pruner/metrics.go new file mode 100644 index 0000000000..fd6c50788a --- /dev/null +++ b/pruner/metrics.go @@ -0,0 +1,79 @@ +package pruner + +import ( + "context" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +var ( + meter = otel.Meter("storage_pruner") +) + +type metrics struct { + prunedCounter metric.Int64Counter + + lastPruned metric.Int64ObservableGauge + failedPrunes metric.Int64ObservableGauge +} + +func (s *Service) WithMetrics() error { + prunedCounter, err := meter.Int64Counter("pruner_pruned_counter", + metric.WithDescription("pruner pruned header counter")) + if err != nil { + return err + } + + failedPrunes, err := meter.Int64ObservableGauge("pruner_failed_counter", + metric.WithDescription("pruner failed prunes counter")) + if err != nil { + return err + } + + lastPruned, err := meter.Int64ObservableGauge("pruner_last_pruned", + metric.WithDescription("pruner highest pruned height")) + if err != nil { + return err + } + + callback := func(ctx context.Context, observer metric.Observer) error { + observer.ObserveInt64(failedPrunes, int64(len(s.failedHeaders))) + return nil + } + + if _, err := meter.RegisterCallback(callback, failedPrunes); err != nil { + return err + } + + callback = func(ctx context.Context, observer metric.Observer) error { + lastPrunedHeader := s.checkpoint.lastPrunedHeader.Load() + if lastPrunedHeader != nil { + observer.ObserveInt64(lastPruned, int64(lastPrunedHeader.Height())) + } + return nil + } + + if _, err := meter.RegisterCallback(callback, lastPruned); err != nil { + return err + } + + s.metrics = &metrics{ + prunedCounter: prunedCounter, + lastPruned: lastPruned, + failedPrunes: failedPrunes, + } + return nil +} + +func (m *metrics) observePrune(ctx context.Context, failed bool) { + if m == nil { + return + } + if ctx.Err() != nil { + ctx = context.Background() + } + m.prunedCounter.Add(ctx, 1, metric.WithAttributes( + attribute.Bool("failed", failed))) +} diff --git a/pruner/params.go b/pruner/params.go new file mode 100644 index 0000000000..c699ff9e11 --- /dev/null +++ b/pruner/params.go @@ -0,0 +1,41 @@ +package pruner + +import ( + "time" +) + +type Option func(*Params) + +type Params struct { + // gcCycle is the frequency at which the pruning Service + // runs the ticker. If set to 0, the Service will not run. + gcCycle time.Duration +} + +func DefaultParams() Params { + return Params{ + gcCycle: time.Hour, + } +} + +// WithGCCycle configures how often the pruning Service +// triggers a pruning cycle. +func WithGCCycle(cycle time.Duration) Option { + return func(p *Params) { + p.gcCycle = cycle + } +} + +// WithDisabledGC disables the pruning Service's pruning +// routine. +func WithDisabledGC() Option { + return func(p *Params) { + p.gcCycle = time.Duration(0) + } +} + +// WithPrunerMetrics is a utility function to turn on pruner metrics and that is +// expected to be "invoked" by the fx lifecycle. +func WithPrunerMetrics(s *Service) error { + return s.WithMetrics() +} diff --git a/pruner/pruner.go b/pruner/pruner.go index fae60e483c..a591a65392 100644 --- a/pruner/pruner.go +++ b/pruner/pruner.go @@ -9,5 +9,5 @@ import ( // Pruner contains methods necessary to prune data // from the node's datastore. 
type Pruner interface { - Prune(context.Context, ...*header.ExtendedHeader) error + Prune(context.Context, *header.ExtendedHeader) error } diff --git a/pruner/service.go b/pruner/service.go index f67265977a..6816d7a6e7 100644 --- a/pruner/service.go +++ b/pruner/service.go @@ -2,24 +2,131 @@ package pruner import ( "context" + "fmt" + "time" + + "github.com/ipfs/go-datastore" + logging "github.com/ipfs/go-log/v2" + + hdr "github.com/celestiaorg/go-header" + + "github.com/celestiaorg/celestia-node/header" ) +var log = logging.Logger("pruner/service") + // Service handles the pruning routine for the node using the // prune Pruner. type Service struct { pruner Pruner + window AvailabilityWindow + + getter hdr.Getter[*header.ExtendedHeader] // TODO @renaynay: expects a header service with access to sync head + + checkpoint *checkpoint + failedHeaders map[uint64]error + maxPruneablePerGC uint64 + numBlocksInWindow uint64 + + ctx context.Context + cancel context.CancelFunc + doneCh chan struct{} + + params Params + metrics *metrics } -func NewService(p Pruner) *Service { +func NewService( + p Pruner, + window AvailabilityWindow, + getter hdr.Getter[*header.ExtendedHeader], + ds datastore.Datastore, + blockTime time.Duration, + opts ...Option, +) *Service { + params := DefaultParams() + for _, opt := range opts { + opt(&params) + } + + // TODO @renaynay + numBlocksInWindow := uint64(time.Duration(window) / blockTime) + return &Service{ - pruner: p, + pruner: p, + window: window, + getter: getter, + checkpoint: newCheckpoint(ds), + failedHeaders: map[uint64]error{}, + numBlocksInWindow: numBlocksInWindow, + // TODO @distractedmind: make this configurable? + maxPruneablePerGC: numBlocksInWindow * 2, + doneCh: make(chan struct{}), + params: params, } } func (s *Service) Start(context.Context) error { + s.ctx, s.cancel = context.WithCancel(context.Background()) + + err := s.loadCheckpoint(s.ctx) + if err != nil { + return err + } + + go s.prune() return nil } -func (s *Service) Stop(context.Context) error { - return nil +func (s *Service) Stop(ctx context.Context) error { + s.cancel() + + select { + case <-s.doneCh: + return nil + case <-ctx.Done(): + return fmt.Errorf("pruner unable to exit within context deadline") + } +} + +func (s *Service) prune() { + if s.params.gcCycle == time.Duration(0) { + // Service is disabled, exit + close(s.doneCh) + return + } + + ticker := time.NewTicker(s.params.gcCycle) + defer ticker.Stop() + + for { + select { + case <-s.ctx.Done(): + close(s.doneCh) + return + case <-ticker.C: + headers, err := s.findPruneableHeaders(s.ctx) + if err != nil { + // TODO @renaynay: record + report errors properly + continue + } + if len(headers) == 0 { + // nothing to prune in this cycle + continue + } + // TODO @renaynay: make deadline a param ? / configurable?
+ pruneCtx, cancel := context.WithDeadline(s.ctx, time.Now().Add(time.Minute)) + for _, eh := range headers { + err = s.pruner.Prune(pruneCtx, eh) + if err != nil { + // TODO: @distractedm1nd: updatecheckpoint should be called on the last NON-ERRORED header + log.Errorf("failed to prune header %d: %s", eh.Height(), err) + s.failedHeaders[eh.Height()] = err + } + s.metrics.observePrune(pruneCtx, err != nil) + } + cancel() + + err = s.updateCheckpoint(s.ctx, headers[len(headers)-1]) + if err != nil { + // TODO @renaynay: record + report errors properly + continue + } + } + } } diff --git a/pruner/service_test.go b/pruner/service_test.go new file mode 100644 index 0000000000..c67133e39f --- /dev/null +++ b/pruner/service_test.go @@ -0,0 +1,186 @@ +package pruner + +import ( + "context" + "testing" + "time" + + "github.com/ipfs/go-datastore" + "github.com/ipfs/go-datastore/sync" + "github.com/stretchr/testify/require" + + hdr "github.com/celestiaorg/go-header" + + "github.com/celestiaorg/celestia-node/header" + "github.com/celestiaorg/celestia-node/header/headertest" +) + +/* + | toPrune | availability window | +*/ + +// TODO @renaynay: tweak/document +var ( + availWindow = AvailabilityWindow(time.Millisecond * 200) + blockTime = time.Millisecond * 100 + gcCycle = time.Millisecond * 500 +) + +func TestService(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + store := headertest.NewStore(t) + + mp := &mockPruner{} + + serv := NewService( + mp, + availWindow, + store, + sync.MutexWrap(datastore.NewMapDatastore()), + blockTime, + WithGCCycle(gcCycle), + ) + + gen, err := store.GetByHeight(ctx, 1) + require.NoError(t, err) + + err = serv.updateCheckpoint(ctx, gen) + require.NoError(t, err) + + err = serv.Start(ctx) + require.NoError(t, err) + + time.Sleep(time.Second) + + err = serv.Stop(ctx) + require.NoError(t, err) + + expected := time.Second/blockTime - time.Duration(availWindow)/blockTime + require.Len(t, mp.deletedHeaderHashes, int(expected)) +} + +func TestFindPruneableHeaders(t *testing.T) { + testCases := []struct { + name string + availWindow AvailabilityWindow + blockTime time.Duration + startTime time.Time + headerAmount int + expectedLength int + }{ + { + name: "Estimated range matches expected", + // Availability window is one week + availWindow: AvailabilityWindow(time.Hour * 24 * 7), + blockTime: time.Hour, + // Make two weeks of headers + headerAmount: 2 * (24 * 7), + startTime: time.Now().Add(-2 * time.Hour * 24 * 7), + // One week of headers are pruneable + expectedLength: 24 * 7, + }, + { + name: "Estimated range not sufficient but finds the correct tail", + // Availability window is one week + availWindow: AvailabilityWindow(time.Hour * 24 * 7), + blockTime: time.Hour, + // Make three weeks of headers + headerAmount: 3 * (24 * 7), + startTime: time.Now().Add(-3 * time.Hour * 24 * 7), + // Two weeks of headers are pruneable + expectedLength: 2 * 24 * 7, + }, + { + name: "No pruneable headers", + // Availability window is two weeks + availWindow: AvailabilityWindow(2 * time.Hour * 24 * 7), + blockTime: time.Hour, + // Make one week of headers + headerAmount: 24 * 7, + startTime: time.Now().Add(-time.Hour * 24 * 7), + // No headers are pruneable + expectedLength: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + headerGenerator := NewSpacedHeaderGenerator(t, tc.startTime, tc.blockTime) + store := 
headertest.NewCustomStore(t, headerGenerator, tc.headerAmount) + + mp := &mockPruner{} + + serv := NewService( + mp, + tc.availWindow, + store, + sync.MutexWrap(datastore.NewMapDatastore()), + tc.blockTime, + ) + + err := serv.Start(ctx) + require.NoError(t, err) + + pruneable, err := serv.findPruneableHeaders(ctx) + require.NoError(t, err) + require.Len(t, pruneable, tc.expectedLength) + + pruneableCutoff := time.Now().Add(-time.Duration(tc.availWindow)) + // All returned headers are older than the availability window + for _, h := range pruneable { + require.WithinRange(t, h.Time(), tc.startTime, pruneableCutoff) + } + + // The next header after the last pruneable header is too new to prune + if len(pruneable) != 0 { + lastPruneable := pruneable[len(pruneable)-1] + if lastPruneable.Height() != store.Height() { + firstUnpruneable, err := store.GetByHeight(ctx, lastPruneable.Height()+1) + require.NoError(t, err) + require.WithinRange(t, firstUnpruneable.Time(), pruneableCutoff, time.Now()) + } + } + }) + } +} + +type mockPruner struct { + deletedHeaderHashes []hdr.Hash +} + +func (mp *mockPruner) Prune(_ context.Context, h *header.ExtendedHeader) error { + mp.deletedHeaderHashes = append(mp.deletedHeaderHashes, h.Hash()) + return nil +} + +type SpacedHeaderGenerator struct { + t *testing.T + TimeBetweenHeaders time.Duration + currentTime time.Time + currentHeight int64 +} + +func NewSpacedHeaderGenerator( + t *testing.T, startTime time.Time, timeBetweenHeaders time.Duration, +) *SpacedHeaderGenerator { + return &SpacedHeaderGenerator{ + t: t, + TimeBetweenHeaders: timeBetweenHeaders, + currentTime: startTime, + currentHeight: 1, + } +} + +func (shg *SpacedHeaderGenerator) NextHeader() *header.ExtendedHeader { + h := headertest.RandExtendedHeaderAtTimestamp(shg.t, shg.currentTime) + h.RawHeader.Height = shg.currentHeight + h.RawHeader.Time = shg.currentTime + shg.currentHeight++ + shg.currentTime = shg.currentTime.Add(shg.TimeBetweenHeaders) + return h +}
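Note on the das/daser.go hunk: neither isWithinSamplingWindow nor the WithSamplingWindow option is defined in this diff, so the sketch below only approximates the semantics implied by TestDASer_SamplingWindow (timestamps inside the window are sampleable, older ones are skipped). The helper name and the zero-window behaviour are assumptions, not part of the patch.

```go
package main

import (
	"fmt"
	"time"
)

// isWithinSamplingWindow (assumed semantics): a header is still worth sampling
// if its timestamp falls inside the configured sampling window; a zero window
// is treated as "pruning disabled", so everything remains sampleable.
func isWithinSamplingWindow(headerTime time.Time, window time.Duration) bool {
	if window == 0 {
		return true
	}
	return time.Since(headerTime) < window
}

func main() {
	window := time.Second // mirrors WithSamplingWindow(time.Second) in the test
	fmt.Println(isWithinSamplingWindow(time.Now().Add(-800*time.Millisecond), window)) // true
	fmt.Println(isWithinSamplingWindow(time.Now().Add(-5*time.Second), window))        // false
}
```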
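To show how the new single-header Pruner interface and the extended NewService constructor fit together, here is a minimal wiring sketch in the style of TestService. The countingPruner type and the concrete window, block-time, and GC-cycle values are illustrative only and not part of this patch.

```go
package pruner_test

import (
	"context"
	"testing"
	"time"

	"github.com/ipfs/go-datastore"
	"github.com/ipfs/go-datastore/sync"
	"github.com/stretchr/testify/require"

	"github.com/celestiaorg/celestia-node/header"
	"github.com/celestiaorg/celestia-node/header/headertest"
	"github.com/celestiaorg/celestia-node/pruner"
)

// countingPruner is a stand-in Pruner: it satisfies the single-header
// Prune signature introduced in this patch and only counts invocations.
type countingPruner struct {
	pruned int
}

func (c *countingPruner) Prune(context.Context, *header.ExtendedHeader) error {
	c.pruned++
	return nil
}

func TestPrunerWiringSketch(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	store := headertest.NewStore(t) // in-memory store of test headers

	serv := pruner.NewService(
		&countingPruner{},
		pruner.AvailabilityWindow(time.Second), // illustrative window
		store,
		sync.MutexWrap(datastore.NewMapDatastore()),
		100*time.Millisecond,                     // illustrative block time
		pruner.WithGCCycle(200*time.Millisecond), // run GC frequently for the example
	)

	require.NoError(t, serv.Start(ctx))
	time.Sleep(time.Second) // allow a few GC cycles to run
	require.NoError(t, serv.Stop(ctx))
}
```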
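For reference, the height estimate in findPruneableHeaders works out as follows for the "Estimated range matches expected" test case. The snippet just reproduces that arithmetic (numBlocksInWindow = window / blockTime, with a hard cap of twice that per GC cycle) and is not part of the patch.

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// Values from the "Estimated range matches expected" case in TestFindPruneableHeaders.
	window := 7 * 24 * time.Hour // one-week availability window
	blockTime := time.Hour

	// NewService: numBlocksInWindow := uint64(time.Duration(window) / blockTime)
	numBlocksInWindow := uint64(window / blockTime) // 168
	maxPruneablePerGC := numBlocksInWindow * 2      // 336 headers, hard cap per GC cycle

	lastPrunedHeight := uint64(1) // the checkpoint is initialized to the first header
	estimatedCutoffHeight := lastPrunedHeight + numBlocksInWindow

	fmt.Println(numBlocksInWindow, maxPruneablePerGC, estimatedCutoffHeight) // 168 336 169
	// findPruneableHeaders clamps this estimate to the current head, then walks
	// forward header by header until a timestamp crosses the time-based cutoff.
}
```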