diff --git a/cmd/staking-expiry-checker/main.go b/cmd/staking-expiry-checker/main.go index 12655a5..b24930f 100644 --- a/cmd/staking-expiry-checker/main.go +++ b/cmd/staking-expiry-checker/main.go @@ -3,6 +3,9 @@ package main import ( "context" "fmt" + "os" + "os/signal" + "syscall" "github.com/joho/godotenv" "github.com/rs/zerolog/log" @@ -23,25 +26,31 @@ func init() { } func main() { - ctx := context.Background() + // Create a context that is cancelled on SIGINT or SIGTERM + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - // setup cli commands and flags + // Setup signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Setup CLI commands and flags if err := cli.Setup(); err != nil { log.Fatal().Err(err).Msg("error while setting up cli") } - // load config + // Load config cfgPath := cli.GetConfigPath() cfg, err := config.New(cfgPath) if err != nil { log.Fatal().Err(err).Msg(fmt.Sprintf("error while loading config file: %s", cfgPath)) } - // initialize metrics with the metrics port from config + // Initialize metrics with the metrics port from config metricsPort := cfg.Metrics.GetMetricsPort() metrics.Init(metricsPort) - // create new db client + // Create new DB client dbClient, err := db.New(ctx, cfg.Db) if err != nil { log.Fatal().Err(err).Msg("error while creating db client") @@ -65,21 +74,17 @@ func main() { log.Fatal().Err(err).Msg("error while creating service") } - // Even though we pass service, it's viewed only through the specific interface - expiryPoller := poller.NewExpiryPoller( - cfg.Pollers.ExpiryChecker, - service, // service implements ExpiryChecker - ) - - btcSubscriberPoller := poller.NewBTCSubscriberPoller( - cfg.Pollers.BtcSubscriber, - service, // service implements BTCSubscriber - ) - - // Start pollers in separate goroutines - go expiryPoller.Start(ctx) - go btcSubscriberPoller.Start(ctx) + // Start pollers + go poller.NewExpiryPoller(cfg.Pollers.ExpiryChecker, service).Start(ctx) + go poller.NewBTCSubscriberPoller(cfg.Pollers.BtcSubscriber, service).Start(ctx) + // Start service handlers go service.HandleUnbondingDelegationChannel(ctx) go service.HandleWithdrawnDelegationChannel(ctx) + + // Wait for a signal to shutdown + <-sigChan + + // Cancel the context to signal all goroutines to stop + cancel() } diff --git a/internal/db/dbclient.go b/internal/db/dbclient.go index da5cb4a..b4b244c 100644 --- a/internal/db/dbclient.go +++ b/internal/db/dbclient.go @@ -38,3 +38,7 @@ func (db *Database) Ping(ctx context.Context) error { } return nil } + +func (db *Database) Shutdown(ctx context.Context) error { + return db.client.Disconnect(ctx) +} diff --git a/internal/poller/poller.go b/internal/poller/poller.go index a7f733b..1a91ae7 100644 --- a/internal/poller/poller.go +++ b/internal/poller/poller.go @@ -10,7 +10,6 @@ import ( "github.com/babylonlabs-io/staking-expiry-checker/internal/types" ) -// Define minimal interfaces for each poller type type ExpiryChecker interface { ProcessExpiredDelegations(ctx context.Context) *types.Error } @@ -34,7 +33,6 @@ type Poller struct { quit chan struct{} } -// Constructors now accept interfaces instead of the full service func NewExpiryPoller(cfg config.PollerConfig, checker ExpiryChecker) *Poller { return &Poller{ pollerType: ExpiryPoller, diff --git a/internal/services/expiry_checker.go b/internal/services/expiry_checker.go index 0b3a4f5..db5fc14 100644 --- a/internal/services/expiry_checker.go +++ b/internal/services/expiry_checker.go @@ -9,22 +9,22 @@ import ( "github.com/rs/zerolog/log" ) -// ProcessExpireCheck checks if the staking delegation has expired and updates the database. -// This method tolerate duplicated calls on the same stakingTxHashHex. -func (s *Service) ProcessExpireCheck( - ctx context.Context, stakingTxHashHex string, - startHeight, timelock uint64, txType types.StakingTxType, -) *types.Error { - expireHeight := startHeight + timelock - err := s.db.SaveTimeLockExpireCheck( - ctx, stakingTxHashHex, expireHeight, txType.ToString(), - ) - if err != nil { - log.Ctx(ctx).Err(err).Msg("Failed to save expire check") - return types.NewInternalServiceError(err) - } - return nil -} +// // ProcessExpireCheck checks if the staking delegation has expired and updates the database. +// // This method tolerate duplicated calls on the same stakingTxHashHex. +// func (s *Service) ProcessExpireCheck( +// ctx context.Context, stakingTxHashHex string, +// startHeight, timelock uint64, txType types.StakingTxType, +// ) *types.Error { +// expireHeight := startHeight + timelock +// err := s.db.SaveTimeLockExpireCheck( +// ctx, stakingTxHashHex, expireHeight, txType.ToString(), +// ) +// if err != nil { +// log.Ctx(ctx).Err(err).Msg("Failed to save expire check") +// return types.NewInternalServiceError(err) +// } +// return nil +// } func (s *Service) ProcessExpiredDelegations(ctx context.Context) *types.Error { btcTip, err := s.btc.GetBlockCount() @@ -33,38 +33,33 @@ func (s *Service) ProcessExpiredDelegations(ctx context.Context) *types.Error { return types.NewInternalServiceError(err) } - for { - expiredDelegations, err := s.db.FindExpiredDelegations(ctx, uint64(btcTip)) - if err != nil { - log.Error().Err(err).Msg("Error finding expired delegations") - return types.NewInternalServiceError(err) - } - if len(expiredDelegations) == 0 { - break - } + // Single batch of expired delegations + expiredDelegations, err := s.db.FindExpiredDelegations(ctx, uint64(btcTip)) + if err != nil { + log.Error().Err(err).Msg("Error finding expired delegations") + return types.NewInternalServiceError(err) + } - for _, delegation := range expiredDelegations { - err := s.ProcessExpiredDelegation(ctx, delegation) - if err != nil { - log.Error().Err(err).Msgf("Error processing expired delegation: %v", delegation.ID) - return err - } + // Process each delegation in the batch + for _, delegation := range expiredDelegations { + if err := s.transitionToUnbondedIfEligible(ctx, delegation); err != nil { + log.Error().Err(err). + Msgf("Error transitioning delegation to unbonded: %v", delegation.ID) + return err + } - // After successfully sending the event, delete the entry from the database. - if err := s.db.DeleteExpiredDelegation(ctx, delegation.ID); err != nil { - log.Error().Err(err).Msg("Error deleting expired delegation") - return types.NewInternalServiceError(err) - } + if err := s.db.DeleteExpiredDelegation(ctx, delegation.ID); err != nil { + log.Error().Err(err).Msg("Error deleting expired delegation") + return types.NewInternalServiceError(err) } } return nil } -// ProcessExpiredDelegation processes an expired delegation by -// transitioning it to unbonded. -// Do nothing if the delegation is not in an eligible state to transition. -func (s *Service) ProcessExpiredDelegation( +// transitionToUnbondedIfEligible attempts to transition a delegation to unbonded state +// if it's in an eligible state. +func (s *Service) transitionToUnbondedIfEligible( ctx context.Context, delegation model.TimeLockDocument, ) *types.Error { // Check what type of the timelock is