diff --git a/lib/auth/init.go b/lib/auth/init.go index 491a3629f379..737f1566ac41 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -311,9 +311,9 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e if err := backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{ LockConfiguration: backend.LockConfiguration{ - Backend: cfg.Backend, - LockName: domainName, - TTL: 30 * time.Second, + Backend: cfg.Backend, + LockNameComponents: []string{domainName}, + TTL: 30 * time.Second, }, RefreshLockInterval: 20 * time.Second, }, func(ctx context.Context) error { diff --git a/lib/backend/key.go b/lib/backend/key.go index 4c7f25c604ed..1dcb213d5c93 100644 --- a/lib/backend/key.go +++ b/lib/backend/key.go @@ -52,6 +52,11 @@ func (k Key) String() string { return string(k) } +// IsZero reports whether k represents the zero key. +func (k Key) IsZero() bool { + return len(k) == 0 +} + // HasPrefix reports whether the key begins with prefix. func (k Key) HasPrefix(prefix Key) bool { return bytes.HasPrefix(k, prefix) diff --git a/lib/backend/key_test.go b/lib/backend/key_test.go index d554fb692235..894c39f49aff 100644 --- a/lib/backend/key_test.go +++ b/lib/backend/key_test.go @@ -473,3 +473,10 @@ func TestKeyCompare(t *testing.T) { }) } } + +func TestKeyIsZero(t *testing.T) { + assert.True(t, backend.Key{}.IsZero()) + assert.True(t, backend.NewKey().IsZero()) + assert.False(t, backend.NewKey("a", "b").IsZero()) + assert.False(t, backend.ExactKey("a", "b").IsZero()) +} diff --git a/lib/backend/lock.go b/lib/backend/lock.go index f9eae9443703..220f1f7a315e 100644 --- a/lib/backend/lock.go +++ b/lib/backend/lock.go @@ -50,8 +50,14 @@ func randomID() ([]byte, error) { } type LockConfiguration struct { - Backend Backend + // Backend to create the lock in. + Backend Backend + // LockName the precomputed lock name. + // TODO(tross) DELETE WHEN teleport.e is updated to use LockNameComponents. LockName string + // LockNameComponents are subcomponents to be used when constructing + // the lock name. + LockNameComponents []string // TTL defines when lock will be released automatically TTL time.Duration // RetryInterval defines interval which is used to retry locking after @@ -63,9 +69,14 @@ func (l *LockConfiguration) CheckAndSetDefaults() error { if l.Backend == nil { return trace.BadParameter("missing Backend") } - if l.LockName == "" { - return trace.BadParameter("missing LockName") + if l.LockName == "" && len(l.LockNameComponents) == 0 { + return trace.BadParameter("missing LockName/LockNameComponents") } + + if len(l.LockNameComponents) == 0 { + l.LockNameComponents = []string{l.LockName} + } + if l.TTL == 0 { return trace.BadParameter("missing TTL") } @@ -81,7 +92,7 @@ func AcquireLock(ctx context.Context, cfg LockConfiguration) (Lock, error) { if err != nil { return Lock{}, trace.Wrap(err) } - key := lockKey(cfg.LockName) + key := lockKey(cfg.LockNameComponents...) id, err := randomID() if err != nil { return Lock{}, trace.Wrap(err) diff --git a/lib/backend/lock_test.go b/lib/backend/lock_test.go index 822ede66236f..a9c807a2ec0f 100644 --- a/lib/backend/lock_test.go +++ b/lib/backend/lock_test.go @@ -51,30 +51,30 @@ func TestLockConfiguration_CheckAndSetDefaults(t *testing.T) { { name: "minimum valid", in: LockConfiguration{ - Backend: mockBackend{}, - LockName: "lock", - TTL: 30 * time.Second, + Backend: mockBackend{}, + LockNameComponents: []string{"lock"}, + TTL: 30 * time.Second, }, want: LockConfiguration{ - Backend: mockBackend{}, - LockName: "lock", - TTL: 30 * time.Second, - RetryInterval: 250 * time.Millisecond, + Backend: mockBackend{}, + LockNameComponents: []string{"lock"}, + TTL: 30 * time.Second, + RetryInterval: 250 * time.Millisecond, }, }, { name: "set RetryAcquireLockTimeout", in: LockConfiguration{ - Backend: mockBackend{}, - LockName: "lock", - TTL: 30 * time.Second, - RetryInterval: 10 * time.Second, + Backend: mockBackend{}, + LockNameComponents: []string{"lock"}, + TTL: 30 * time.Second, + RetryInterval: 10 * time.Second, }, want: LockConfiguration{ - Backend: mockBackend{}, - LockName: "lock", - TTL: 30 * time.Second, - RetryInterval: 10 * time.Second, + Backend: mockBackend{}, + LockNameComponents: []string{"lock"}, + TTL: 30 * time.Second, + RetryInterval: 10 * time.Second, }, }, { @@ -95,9 +95,9 @@ func TestLockConfiguration_CheckAndSetDefaults(t *testing.T) { { name: "missing TTL", in: LockConfiguration{ - Backend: mockBackend{}, - LockName: "lock", - TTL: 0, + Backend: mockBackend{}, + LockNameComponents: []string{"lock"}, + TTL: 0, }, wantErr: "missing TTL", }, @@ -124,9 +124,9 @@ func TestRunWhileLockedConfigCheckAndSetDefaults(t *testing.T) { ttl := 1 * time.Minute minimumValidConfig := RunWhileLockedConfig{ LockConfiguration: LockConfiguration{ - Backend: mockBackend{}, - LockName: lockName, - TTL: ttl, + Backend: mockBackend{}, + LockNameComponents: []string{lockName}, + TTL: ttl, }, } tests := []struct { @@ -142,10 +142,10 @@ func TestRunWhileLockedConfigCheckAndSetDefaults(t *testing.T) { }, want: RunWhileLockedConfig{ LockConfiguration: LockConfiguration{ - Backend: mockBackend{}, - LockName: lockName, - TTL: ttl, - RetryInterval: 250 * time.Millisecond, + Backend: mockBackend{}, + LockNameComponents: []string{lockName}, + TTL: ttl, + RetryInterval: 250 * time.Millisecond, }, ReleaseCtxTimeout: time.Second, // defaults to halft of TTL. @@ -157,6 +157,7 @@ func TestRunWhileLockedConfigCheckAndSetDefaults(t *testing.T) { input: func() RunWhileLockedConfig { cfg := minimumValidConfig cfg.LockName = "" + cfg.LockNameComponents = nil return cfg }, wantErr: "missing LockName", diff --git a/lib/backend/test/suite.go b/lib/backend/test/suite.go index 950e3a025e73..5597ccd1d051 100644 --- a/lib/backend/test/suite.go +++ b/lib/backend/test/suite.go @@ -833,7 +833,7 @@ func testLocking(t *testing.T, newBackend Constructor) { defer requireNoAsyncErrors() // Given a lock named `tok1` on the backend... - lock, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl}) + lock, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok1}, TTL: ttl}) require.NoError(t, err) // When I asynchronously release the lock... @@ -848,7 +848,7 @@ func testLocking(t *testing.T, newBackend Constructor) { }() // ...and simultaneously attempt to create a new lock with the same name - lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl}) + lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok1}, TTL: ttl}) // expect that the asynchronous Release() has executed - we're using the // change in the value of the marker value as a proxy for the Release(). @@ -860,7 +860,7 @@ func testLocking(t *testing.T, newBackend Constructor) { require.NoError(t, lock.Release(ctx, uut)) // Given a lock with the same name as previously-existing, manually-released lock - lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl}) + lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok1}, TTL: ttl}) require.NoError(t, err) atomic.StoreInt32(&marker, 7) @@ -875,7 +875,7 @@ func testLocking(t *testing.T, newBackend Constructor) { }() // ...and simultaneously try to acquire another lock with the same name - lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl}) + lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok1}, TTL: ttl}) // expect that the asynchronous Release() has executed - we're using the // change in the value of the marker value as a proxy for the call to @@ -889,9 +889,9 @@ func testLocking(t *testing.T, newBackend Constructor) { // Given a pair of locks named `tok1` and `tok2` y := int32(0) - lock1, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl}) + lock1, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok1}, TTL: ttl}) require.NoError(t, err) - lock2, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok2, TTL: ttl}) + lock2, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok2}, TTL: ttl}) require.NoError(t, err) // When I asynchronously release the locks... @@ -908,7 +908,7 @@ func testLocking(t *testing.T, newBackend Constructor) { } }() - lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl}) + lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockNameComponents: []string{tok1}, TTL: ttl}) require.NoError(t, err) require.Equal(t, int32(15), atomic.LoadInt32(&y)) require.NoError(t, lock.Release(ctx, uut)) diff --git a/lib/events/athena/consumer.go b/lib/events/athena/consumer.go index 2332667f4b21..df1a27f909a7 100644 --- a/lib/events/athena/consumer.go +++ b/lib/events/athena/consumer.go @@ -253,8 +253,8 @@ func (c *consumer) runContinuouslyOnSingleAuth(ctx context.Context, eventsProces default: err := backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{ LockConfiguration: backend.LockConfiguration{ - Backend: c.backend, - LockName: "athena_lock", + Backend: c.backend, + LockNameComponents: []string{"athena_lock"}, // TTL is higher then batchMaxInterval because we want to optimize // for low backend writes. TTL: 5 * c.batchMaxInterval, diff --git a/lib/services/local/access.go b/lib/services/local/access.go index d985d81e17af..2afb23464734 100644 --- a/lib/services/local/access.go +++ b/lib/services/local/access.go @@ -327,9 +327,9 @@ func (s *AccessService) DeleteAllLocks(ctx context.Context) error { func (s *AccessService) ReplaceRemoteLocks(ctx context.Context, clusterName string, newRemoteLocks []types.Lock) error { return backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{ LockConfiguration: backend.LockConfiguration{ - Backend: s.Backend, - LockName: "ReplaceRemoteLocks/" + clusterName, - TTL: time.Minute, + Backend: s.Backend, + LockNameComponents: []string{"ReplaceRemoteLocks", clusterName}, + TTL: time.Minute, }, }, func(ctx context.Context) error { remoteLocksKey := backend.ExactKey(locksPrefix, clusterName) diff --git a/lib/services/local/access_list.go b/lib/services/local/access_list.go index 5a155327c761..19e7c76d4934 100644 --- a/lib/services/local/access_list.go +++ b/lib/services/local/access_list.go @@ -18,7 +18,6 @@ package local import ( "context" - "strings" "time" "github.com/google/go-cmp/cmp" @@ -196,7 +195,7 @@ func (a *AccessListService) runOpWithLock(ctx context.Context, accessList *acces var err error if feature := modules.GetModules().Features(); !feature.IGSEnabled() { - err = a.service.RunWhileLocked(ctx, "createAccessListLimitLock", accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + err = a.service.RunWhileLocked(ctx, []string{"createAccessListLimitLock"}, accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { if err := a.VerifyAccessListCreateLimit(ctx, accessList.GetName()); err != nil { return trace.Wrap(err) } @@ -453,7 +452,7 @@ func (a *AccessListService) UpsertAccessListWithMembers(ctx context.Context, acc var err error if feature := modules.GetModules().Features(); !feature.IGSEnabled() { - err = a.service.RunWhileLocked(ctx, "createAccessListWithMembersLimitLock", accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { + err = a.service.RunWhileLocked(ctx, []string{"createAccessListWithMembersLimitLock"}, accessListLockTTL, func(ctx context.Context, _ backend.Backend) error { if err := a.VerifyAccessListCreateLimit(ctx, accessList.GetName()); err != nil { return trace.Wrap(err) } @@ -638,8 +637,8 @@ func (a *AccessListService) DeleteAllAccessListReviews(ctx context.Context) erro return trace.Wrap(a.reviewService.DeleteAllResources(ctx)) } -func lockName(accessListName string) string { - return strings.Join([]string{"access_list", accessListName}, string(backend.Separator)) +func lockName(accessListName string) []string { + return []string{"access_list", accessListName} } // VerifyAccessListCreateLimit ensures creating access list is limited to no more than 1 (updating is allowed). diff --git a/lib/services/local/externalauditstorage.go b/lib/services/local/externalauditstorage.go index f42e0ec33fca..9f6d3bc04921 100644 --- a/lib/services/local/externalauditstorage.go +++ b/lib/services/local/externalauditstorage.go @@ -83,9 +83,9 @@ func (s *ExternalAuditStorageService) CreateDraftExternalAuditStorage(ctx contex var lease *backend.Lease err = backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{ LockConfiguration: backend.LockConfiguration{ - Backend: s.backend, - LockName: externalAuditStorageLockName, - TTL: externalAuditStorageLockTTL, + Backend: s.backend, + LockNameComponents: []string{externalAuditStorageLockName}, + TTL: externalAuditStorageLockTTL, }, }, func(ctx context.Context) error { // Check that the referenced AWS OIDC integration actually exists. @@ -122,9 +122,9 @@ func (s *ExternalAuditStorageService) UpsertDraftExternalAuditStorage(ctx contex var lease *backend.Lease err = backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{ LockConfiguration: backend.LockConfiguration{ - Backend: s.backend, - LockName: externalAuditStorageLockName, - TTL: externalAuditStorageLockTTL, + Backend: s.backend, + LockNameComponents: []string{externalAuditStorageLockName}, + TTL: externalAuditStorageLockTTL, }, }, func(ctx context.Context) error { // Check that the referenced AWS OIDC integration actually exists. @@ -185,9 +185,9 @@ func (s *ExternalAuditStorageService) GetClusterExternalAuditStorage(ctx context func (s *ExternalAuditStorageService) PromoteToClusterExternalAuditStorage(ctx context.Context) error { err := backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{ LockConfiguration: backend.LockConfiguration{ - Backend: s.backend, - LockName: externalAuditStorageLockName, - TTL: externalAuditStorageLockTTL, + Backend: s.backend, + LockNameComponents: []string{externalAuditStorageLockName}, + TTL: externalAuditStorageLockTTL, }, }, func(ctx context.Context) error { draft, err := s.GetDraftExternalAuditStorage(ctx) diff --git a/lib/services/local/generic/generic.go b/lib/services/local/generic/generic.go index 3682b7946aa2..2f4940fe69e0 100644 --- a/lib/services/local/generic/generic.go +++ b/lib/services/local/generic/generic.go @@ -421,14 +421,14 @@ func (s *Service[T]) MakeKey(name string) backend.Key { } // RunWhileLocked will run the given function in a backend lock. This is a wrapper around the backend.RunWhileLocked function. -func (s *Service[T]) RunWhileLocked(ctx context.Context, lockName string, ttl time.Duration, fn func(context.Context, backend.Backend) error) error { +func (s *Service[T]) RunWhileLocked(ctx context.Context, lockNameComponents []string, ttl time.Duration, fn func(context.Context, backend.Backend) error) error { return trace.Wrap(backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{ LockConfiguration: backend.LockConfiguration{ - Backend: s.backend, - LockName: lockName, - TTL: ttl, - RetryInterval: s.runWhileLockedRetryInterval, + Backend: s.backend, + LockNameComponents: lockNameComponents, + TTL: ttl, + RetryInterval: s.runWhileLockedRetryInterval, }, }, func(ctx context.Context) error { return fn(ctx, s.backend) diff --git a/lib/services/local/generic/generic_test.go b/lib/services/local/generic/generic_test.go index a530dca74661..2b6d267db212 100644 --- a/lib/services/local/generic/generic_test.go +++ b/lib/services/local/generic/generic_test.go @@ -256,7 +256,7 @@ func TestGenericCRUD(t *testing.T) { require.ErrorIs(t, err, trace.NotFound(`generic resource "doesnotexist" doesn't exist`)) // Test running while locked. - err = service.RunWhileLocked(ctx, "test-lock", time.Second*5, func(ctx context.Context, backend backend.Backend) error { + err = service.RunWhileLocked(ctx, []string{"test-lock"}, time.Second*5, func(ctx context.Context, backend backend.Backend) error { item, err := backend.Get(ctx, service.MakeKey(r1.GetName())) require.NoError(t, err) diff --git a/lib/services/local/integrations.go b/lib/services/local/integrations.go index 5d8f340b9bde..9eb7c9099c09 100644 --- a/lib/services/local/integrations.go +++ b/lib/services/local/integrations.go @@ -124,9 +124,9 @@ func (s *IntegrationsService) DeleteIntegration(ctx context.Context, name string // so that no new EAS integrations can be concurrently created. err := backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{ LockConfiguration: backend.LockConfiguration{ - Backend: s.backend, - LockName: externalAuditStorageLockName, - TTL: externalAuditStorageLockTTL, + Backend: s.backend, + LockNameComponents: []string{externalAuditStorageLockName}, + TTL: externalAuditStorageLockTTL, }, }, func(ctx context.Context) error { if err := notReferencedByEAS(ctx, s.backend, name); err != nil { diff --git a/lib/services/local/saml_idp_service_provider.go b/lib/services/local/saml_idp_service_provider.go index adeb0f9cecb1..0c7f624d9e60 100644 --- a/lib/services/local/saml_idp_service_provider.go +++ b/lib/services/local/saml_idp_service_provider.go @@ -143,7 +143,7 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context return trace.Wrap(err) } - return trace.Wrap(s.svc.RunWhileLocked(ctx, samlIDPServiceProviderModifyLock, samlIDPServiceProviderModifyLockTTL, + return trace.Wrap(s.svc.RunWhileLocked(ctx, []string{samlIDPServiceProviderModifyLock}, samlIDPServiceProviderModifyLockTTL, func(ctx context.Context, backend backend.Backend) error { if err := s.ensureEntityIDIsUnique(ctx, sp); err != nil { return trace.Wrap(err) @@ -181,7 +181,7 @@ func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context return trace.Wrap(err) } - return trace.Wrap(s.svc.RunWhileLocked(ctx, samlIDPServiceProviderModifyLock, samlIDPServiceProviderModifyLockTTL, + return trace.Wrap(s.svc.RunWhileLocked(ctx, []string{samlIDPServiceProviderModifyLock}, samlIDPServiceProviderModifyLockTTL, func(ctx context.Context, backend backend.Backend) error { if err := s.ensureEntityIDIsUnique(ctx, sp); err != nil { return trace.Wrap(err)