From d4e7683564cc469e4cb4df2be6cf5d69adf19812 Mon Sep 17 00:00:00 2001 From: Todd Date: Fri, 15 Dec 2023 15:10:07 -0800 Subject: [PATCH] Identify resourceless responses as coming from supported or unsupported controllers --- .../clientcache/cmd/search/search_test.go | 1 - .../clientcache/internal/cache/refresh.go | 4 +- .../internal/cache/refresh_test.go | 78 +++++++++---------- .../cache/repository_refresh_token.go | 4 +- .../internal/cache/repository_sessions.go | 55 ++++++++----- .../internal/cache/repository_targets.go | 64 ++++++++------- .../clientcache/internal/daemon/testing.go | 4 +- 7 files changed, 113 insertions(+), 97 deletions(-) diff --git a/internal/clientcache/cmd/search/search_test.go b/internal/clientcache/cmd/search/search_test.go index 453ba311bc..2536bdc1c4 100644 --- a/internal/clientcache/cmd/search/search_test.go +++ b/internal/clientcache/cmd/search/search_test.go @@ -154,7 +154,6 @@ func TestSearch(t *testing.T) { t.Run("unsupported boundary instance", func(t *testing.T) { srv.AddUnsupportedCachingData(t, unsupportedAt, boundaryTokenReaderFn) - resp, r, apiErr, err := search(ctx, srv.BaseDotDir(), filterBy{ authTokenId: unsupportedAt.Id, resource: "targets", diff --git a/internal/clientcache/internal/cache/refresh.go b/internal/clientcache/internal/cache/refresh.go index 3140d34d36..e19f087a9a 100644 --- a/internal/clientcache/internal/cache/refresh.go +++ b/internal/clientcache/internal/cache/refresh.go @@ -301,14 +301,14 @@ func (r *RefreshService) RecheckCachingSupport(ctx context.Context, opt ...Optio } if err := r.repo.checkCachingTargets(ctx, u, tokens, opt...); err != nil { - if err == errRefreshNotSupported { + if err == ErrRefreshNotSupported { // This is expected so no need to propogate the error up continue } retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) } if err := r.repo.checkCachingSessions(ctx, u, tokens, opt...); err != nil { - if err == errRefreshNotSupported { + if err == ErrRefreshNotSupported { // This is expected so no need to propogate the error up continue } diff --git a/internal/clientcache/internal/cache/refresh_test.go b/internal/clientcache/internal/cache/refresh_test.go index 3211db4613..44e4c242d2 100644 --- a/internal/clientcache/internal/cache/refresh_test.go +++ b/internal/clientcache/internal/cache/refresh_test.go @@ -54,9 +54,9 @@ func testStaticResourceRetrievalFunc[T any](t *testing.T, ret [][]T, removed [][ // testNoRefreshRetrievalFunc simulates a controller that doesn't support refresh // since it does not return any refresh token. -func testNoRefreshRetrievalFunc[T any](t *testing.T, ret []T) func(context.Context, string, string, RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) { +func testNoRefreshRetrievalFunc[T any](t *testing.T) func(context.Context, string, string, RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) { return func(_ context.Context, _, _ string, _ RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) { - return ret, nil, "", nil + return nil, nil, "", ErrRefreshNotSupported } } @@ -471,9 +471,9 @@ func TestRefreshForSearch(t *testing.T) { // Get the first set of resources, but no refresh tokens err = rs.Refresh(ctx, - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, retTargets))) - assert.ErrorContains(t, err, errRefreshNotSupported.Error()) + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))) + assert.ErrorContains(t, err, ErrRefreshNotSupported.Error()) got, err := r.ListTargets(ctx, at.Id) assert.NoError(t, err) @@ -483,13 +483,13 @@ func TestRefreshForSearch(t *testing.T) { // wont be refreshed any more, and we wont see the error when refreshing // any more. err = rs.Refresh(ctx, - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, retTargets))) + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))) assert.Nil(t, err) err = rs.RecheckCachingSupport(ctx, - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, retTargets))) + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))) assert.Nil(t, err) got, err = r.ListTargets(ctx, at.Id) @@ -499,9 +499,9 @@ func TestRefreshForSearch(t *testing.T) { // Now simulate the controller updating to support refresh tokens and // the resources starting to be cached. err = rs.RecheckCachingSupport(ctx, - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)), + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{retTargets}, [][]string{{}}))) - assert.Nil(t, err) + assert.Nil(t, err, err) got, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) @@ -799,23 +799,20 @@ func TestRecheckCachingSupport(t *testing.T) { require.NoError(t, err) require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) - retTargets := []*targets.Target{ - target("1"), - } // Since this user doesn't have any resources, the user's data will still // only get updated with a call to Refresh. assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, retTargets)))) + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) got, err := r.ListTargets(ctx, at.Id) assert.NoError(t, err) assert.Empty(t, got) err = rs.Refresh(ctx, - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, retTargets))) - assert.ErrorIs(t, err, errRefreshNotSupported) + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))) + assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) @@ -823,8 +820,8 @@ func TestRecheckCachingSupport(t *testing.T) { // now a full fetch will work since the user has resources and no refresh token assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, retTargets)))) + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) }) t.Run("sessions", func(t *testing.T) { @@ -836,29 +833,26 @@ func TestRecheckCachingSupport(t *testing.T) { require.NoError(t, err) require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) - retSess := []*sessions.Session{ - session("1"), - } assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, nil)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, retSess)))) + WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err := r.ListSessions(ctx, at.Id) assert.NoError(t, err) assert.Empty(t, got) err = rs.Refresh(ctx, - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, nil)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, retSess))) - assert.ErrorIs(t, err, errRefreshNotSupported) + WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) assert.Empty(t, got) assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, nil)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, retSess)))) + WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) assert.Empty(t, got) @@ -874,13 +868,13 @@ func TestRecheckCachingSupport(t *testing.T) { require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) err = rs.Refresh(ctx, - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, []*targets.Target{target("1")})), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, []*sessions.Session{session("1")}))) - assert.ErrorIs(t, err, errRefreshNotSupported) + WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + assert.ErrorIs(t, err, ErrRefreshNotSupported) innerErr := errors.New("test error") err = rs.RecheckCachingSupport(ctx, - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)), + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), WithTargetRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) @@ -889,8 +883,8 @@ func TestRecheckCachingSupport(t *testing.T) { assert.ErrorContains(t, err, innerErr.Error()) err = rs.RecheckCachingSupport(ctx, - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, nil)), - WithSessionRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithTargetRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) return nil, nil, "", innerErr @@ -908,9 +902,9 @@ func TestRecheckCachingSupport(t *testing.T) { require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) err = rs.Refresh(ctx, - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, []*targets.Target{target("1")})), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, []*sessions.Session{session("1")}))) - assert.ErrorIs(t, err, errRefreshNotSupported) + WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + assert.ErrorIs(t, err, ErrRefreshNotSupported) // Remove the token from the keyring, see that we can still see the // token and then user until a Refresh happens which causes them to be @@ -926,8 +920,8 @@ func TestRecheckCachingSupport(t *testing.T) { assert.Len(t, us, 1) err = rs.RecheckCachingSupport(ctx, - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, nil))) + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))) assert.NoError(t, err) ps, err = r.listTokens(ctx, u) diff --git a/internal/clientcache/internal/cache/repository_refresh_token.go b/internal/clientcache/internal/cache/repository_refresh_token.go index 0570093795..b80ce9b747 100644 --- a/internal/clientcache/internal/cache/repository_refresh_token.go +++ b/internal/clientcache/internal/cache/repository_refresh_token.go @@ -14,9 +14,9 @@ import ( "github.com/hashicorp/boundary/internal/util" ) -// errRefreshNotSupported is returned whenever it is determined a boundary +// ErrRefreshNotSupported is returned whenever it is determined a boundary // instance does not support refresh tokens. -var errRefreshNotSupported = stderrors.New("refresh tokens are not supported for this controller") +var ErrRefreshNotSupported = stderrors.New("refresh tokens are not supported for this controller") // RefreshTokenValue is the the type for the actual refresh token value handled // by the client cache. diff --git a/internal/clientcache/internal/cache/repository_sessions.go b/internal/clientcache/internal/cache/repository_sessions.go index a4a82e54f2..04fbb411d5 100644 --- a/internal/clientcache/internal/cache/repository_sessions.go +++ b/internal/clientcache/internal/cache/repository_sessions.go @@ -40,6 +40,9 @@ func defaultSessionFunc(ctx context.Context, addr, authTok string, refreshTok Re } return nil, nil, "", errors.Wrap(ctx, err, op) } + if l.ResponseType == "" { + return nil, nil, "", ErrRefreshNotSupported + } return l.Items, l.RemovedIds, RefreshTokenValue(l.ListToken), nil } @@ -78,6 +81,7 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au var gotResponse bool var resp []*sessions.Session var newRefreshToken RefreshTokenValue + var unsupportedCacheRequest bool var removedIds []string var retErr error for at, t := range tokens { @@ -91,8 +95,12 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au resp, removedIds, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "") } if err != nil { - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) - continue + if err == ErrRefreshNotSupported { + unsupportedCacheRequest = true + } else { + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) + continue + } } gotResponse = true break @@ -122,6 +130,10 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au } } switch { + case unsupportedCacheRequest: + if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { + return err + } case newRefreshToken != "": if err := upsertSessions(ctx, w, u, resp); err != nil { return err @@ -129,21 +141,17 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { return err } - case len(resp) > 0: - if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { - return err - } - return errRefreshNotSupported - case len(resp) == 0: - if err := deleteRefreshToken(ctx, w, u, resourceType); err != nil { - return err - } + default: + // controller supports caching, but doesn't have any resources } return nil }) if err != nil { return errors.Wrap(ctx, err, op) } + if unsupportedCacheRequest { + return ErrRefreshNotSupported + } return nil } @@ -174,12 +182,17 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m var gotResponse bool var resp []*sessions.Session var newRefreshToken RefreshTokenValue + var unsupportedCacheRequest bool var retErr error for at, t := range tokens { resp, _, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "") if err != nil { - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) - continue + if err == ErrRefreshNotSupported { + unsupportedCacheRequest = true + } else { + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) + continue + } } gotResponse = true break @@ -196,6 +209,10 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error { switch { + case unsupportedCacheRequest: + if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { + return err + } case newRefreshToken != "": if _, err := w.Exec(ctx, "delete from session where fk_user_id = @fk_user_id", []any{sql.Named("fk_user_id", u.Id)}); err != nil { @@ -207,12 +224,9 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { return err } - case len(resp) > 0: - if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { - return err - } - return errRefreshNotSupported - case len(resp) == 0: + default: + // This is no longer flagged as not supported, but we dont have a + // refresh token so clear out any refresh token we have stored. if err := deleteRefreshToken(ctx, w, u, resourceType); err != nil { return err } @@ -222,6 +236,9 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m if err != nil { return errors.Wrap(ctx, err, op) } + if unsupportedCacheRequest { + return ErrRefreshNotSupported + } return nil } diff --git a/internal/clientcache/internal/cache/repository_targets.go b/internal/clientcache/internal/cache/repository_targets.go index 4cea6e6bfc..e042f5d532 100644 --- a/internal/clientcache/internal/cache/repository_targets.go +++ b/internal/clientcache/internal/cache/repository_targets.go @@ -40,6 +40,9 @@ func defaultTargetFunc(ctx context.Context, addr, authTok string, refreshTok Ref } return nil, nil, "", errors.Wrap(ctx, err, op) } + if l.ResponseType == "" { + return nil, nil, "", ErrRefreshNotSupported + } return l.Items, l.RemovedIds, RefreshTokenValue(l.ListToken), nil } @@ -78,6 +81,7 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut var resp []*targets.Target var removedIds []string var newRefreshToken RefreshTokenValue + var unsupportedCacheRequest bool var retErr error for at, t := range tokens { resp, removedIds, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, oldRefreshTokenVal) @@ -90,8 +94,12 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut resp, removedIds, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, "") } if err != nil { - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) - continue + if err == ErrRefreshNotSupported { + unsupportedCacheRequest = true + } else { + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) + continue + } } gotResponse = true break @@ -105,11 +113,10 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut return retErr } - var unsupportedCacheRequest bool event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d targets for user %v", len(resp), u)) _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { switch { - case oldRefreshToken == nil: + case oldRefreshToken == nil || unsupportedCacheRequest: if _, err := w.Exec(ctx, "delete from target where fk_user_id = @fk_user_id", []any{sql.Named("fk_user_id", u.Id)}); err != nil { return err @@ -121,6 +128,10 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut } } switch { + case unsupportedCacheRequest: + if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { + return err + } case newRefreshToken != "": if err := upsertTargets(ctx, w, u, resp); err != nil { return err @@ -128,15 +139,8 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { return err } - case len(resp) > 0: - if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { - return err - } - unsupportedCacheRequest = true - case len(resp) == 0: - if err := deleteRefreshToken(ctx, w, u, resourceType); err != nil { - return err - } + default: + // controller supports caching, but doesn't have any resources } return nil }) @@ -144,7 +148,7 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut return errors.Wrap(ctx, err, op) } if unsupportedCacheRequest { - return errRefreshNotSupported + return ErrRefreshNotSupported } return nil } @@ -177,12 +181,17 @@ func (r *Repository) checkCachingTargets(ctx context.Context, u *user, tokens ma var gotResponse bool var resp []*targets.Target var newRefreshToken RefreshTokenValue + var unsupportedCacheRequest bool var retErr error for at, t := range tokens { resp, _, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, "") if err != nil { - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) - continue + if err == ErrRefreshNotSupported { + unsupportedCacheRequest = true + } else { + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) + continue + } } gotResponse = true break @@ -196,9 +205,14 @@ func (r *Repository) checkCachingTargets(ctx context.Context, u *user, tokens ma return retErr } - var unsupportedCacheRequest bool _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error { switch { + case unsupportedCacheRequest: + // Since we know the controller doesn't support caching, we mark the + // user as unable to cache the data. + if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { + return err + } case newRefreshToken != "": // Now that there is a refresh token, the data can be cached, so // cache it and store the refresh token for future refreshes. @@ -212,17 +226,9 @@ func (r *Repository) checkCachingTargets(ctx context.Context, u *user, tokens ma if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { return err } - case len(resp) > 0: - // There is no refresh token but there is data, so we add the - // sentinel refresh token which marks this user as unable to cache - // this data. - if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { - return err - } - unsupportedCacheRequest = true - case len(resp) == 0: - // removing all refresh tokens for this resource is equivalent to - // saying we do not know if the data can be cached. + default: + // We know the controller supports caching, but doesn't have a + // refresh token so clear out any refresh token we have for this resource. if err := deleteRefreshToken(ctx, w, u, resourceType); err != nil { return err } @@ -233,7 +239,7 @@ func (r *Repository) checkCachingTargets(ctx context.Context, u *user, tokens ma return errors.Wrap(ctx, err, op) } if unsupportedCacheRequest { - return errRefreshNotSupported + return ErrRefreshNotSupported } return nil } diff --git a/internal/clientcache/internal/daemon/testing.go b/internal/clientcache/internal/daemon/testing.go index 799347b784..649bde1b65 100644 --- a/internal/clientcache/internal/daemon/testing.go +++ b/internal/clientcache/internal/daemon/testing.go @@ -122,13 +122,13 @@ func (s *TestServer) AddUnsupportedCachingData(t *testing.T, p *authtokens.AuthT } return []*targets.Target{ {Id: "ttcp_unsupported", Name: "unsupported", Description: "not supported"}, - }, nil, "", nil + }, nil, "", cache.ErrRefreshNotSupported } sessFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*sessions.Session, []string, cache.RefreshTokenValue, error) { if tok != p.Token { return nil, nil, "", nil } - return []*sessions.Session{}, nil, "", nil + return []*sessions.Session{}, nil, "", cache.ErrRefreshNotSupported } rs, err := cache.NewRefreshService(ctx, r, 0, 0) require.NoError(t, err)