diff --git a/pkg/agent/svid/rotator.go b/pkg/agent/svid/rotator.go index ca8324e6b6..a608505f4f 100644 --- a/pkg/agent/svid/rotator.go +++ b/pkg/agent/svid/rotator.go @@ -61,9 +61,13 @@ type rotator struct { // Mutex used to prevent rotations when a new connection is being created rotMtx *sync.RWMutex - // Hook that will be called when the SVID rotation finishes - rotationFinishedHook func() + hooks struct { + // Hook that will be called when the SVID rotation finishes + rotationFinishedHook func() + // Hook that is called when the rotator starts running + runRotatorSignal chan struct{} + } tainted bool } @@ -83,6 +87,10 @@ func (r *rotator) Run(ctx context.Context) error { } func (r *rotator) runRotation(ctx context.Context) error { + if r.hooks.runRotatorSignal != nil { + r.hooks.runRotatorSignal <- struct{}{} + } + for { err := r.rotateSVIDIfNeeded(ctx) state, ok := r.state.Value().(State) @@ -179,7 +187,7 @@ func (r *rotator) GetRotationMtx() *sync.RWMutex { } func (r *rotator) SetRotationFinishedHook(f func()) { - r.rotationFinishedHook = f + r.hooks.rotationFinishedHook = f } func (r *rotator) Reattest(ctx context.Context) error { @@ -193,8 +201,8 @@ func (r *rotator) Reattest(ctx context.Context) error { } err := r.reattest(ctx) - if err == nil && r.rotationFinishedHook != nil { - r.rotationFinishedHook() + if err == nil && r.hooks.rotationFinishedHook != nil { + r.hooks.rotationFinishedHook() } return err @@ -213,8 +221,8 @@ func (r *rotator) rotateSVIDIfNeeded(ctx context.Context) (err error) { err = r.rotateSVID(ctx) } - if err == nil && r.rotationFinishedHook != nil { - r.rotationFinishedHook() + if err == nil && r.hooks.rotationFinishedHook != nil { + r.hooks.rotationFinishedHook() } } diff --git a/pkg/agent/svid/rotator_test.go b/pkg/agent/svid/rotator_test.go index cacc69c3b8..ff355e50f2 100644 --- a/pkg/agent/svid/rotator_test.go +++ b/pkg/agent/svid/rotator_test.go @@ -158,6 +158,7 @@ func TestRotator(t *testing.T) { RotationStrategy: rotationutil.NewRotationStrategy(0), }) rotator.client = mockClient + rotator.hooks.runRotatorSignal = make(chan struct{}) // Hook the rotation loop so we can determine when the rotator // has finished a rotation evaluation (does not imply anything @@ -179,6 +180,9 @@ func TestRotator(t *testing.T) { errCh <- rotator.Run(ctx) }() + // Make sure that the rotator is running + <-rotator.hooks.runRotatorSignal + // All tests should get through one rotation loop or error select { case <-clk.WaitForAfterCh(): @@ -513,7 +517,7 @@ func TestTaintedSVIDIsRotated(t *testing.T) { }) rotator.client = mockClient rotationFinishedCh := make(chan struct{}, 1) - rotator.rotationFinishedHook = func() { + rotator.hooks.rotationFinishedHook = func() { close(rotationFinishedCh) }