Skip to content

Commit

Permalink
Identify resourceless responses as coming from supported or unsupport…
Browse files Browse the repository at this point in the history
…ed controllers
  • Loading branch information
talanknight committed Dec 18, 2023
1 parent 509969c commit d4e7683
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 97 deletions.
1 change: 0 additions & 1 deletion internal/clientcache/cmd/search/search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ func TestSearch(t *testing.T) {

t.Run("unsupported boundary instance", func(t *testing.T) {
srv.AddUnsupportedCachingData(t, unsupportedAt, boundaryTokenReaderFn)

resp, r, apiErr, err := search(ctx, srv.BaseDotDir(), filterBy{
authTokenId: unsupportedAt.Id,
resource: "targets",
Expand Down
4 changes: 2 additions & 2 deletions internal/clientcache/internal/cache/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,14 @@ func (r *RefreshService) RecheckCachingSupport(ctx context.Context, opt ...Optio
}

if err := r.repo.checkCachingTargets(ctx, u, tokens, opt...); err != nil {
if err == errRefreshNotSupported {
if err == ErrRefreshNotSupported {
// This is expected so no need to propogate the error up
continue
}
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id))))
}
if err := r.repo.checkCachingSessions(ctx, u, tokens, opt...); err != nil {
if err == errRefreshNotSupported {
if err == ErrRefreshNotSupported {
// This is expected so no need to propogate the error up
continue
}
Expand Down
78 changes: 36 additions & 42 deletions internal/clientcache/internal/cache/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ func testStaticResourceRetrievalFunc[T any](t *testing.T, ret [][]T, removed [][

// testNoRefreshRetrievalFunc simulates a controller that doesn't support refresh
// since it does not return any refresh token.
func testNoRefreshRetrievalFunc[T any](t *testing.T, ret []T) func(context.Context, string, string, RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) {
func testNoRefreshRetrievalFunc[T any](t *testing.T) func(context.Context, string, string, RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) {
return func(_ context.Context, _, _ string, _ RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) {
return ret, nil, "", nil
return nil, nil, "", ErrRefreshNotSupported
}
}

Expand Down Expand Up @@ -471,9 +471,9 @@ func TestRefreshForSearch(t *testing.T) {

// Get the first set of resources, but no refresh tokens
err = rs.Refresh(ctx,
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, retTargets)))
assert.ErrorContains(t, err, errRefreshNotSupported.Error())
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))
assert.ErrorContains(t, err, ErrRefreshNotSupported.Error())

got, err := r.ListTargets(ctx, at.Id)
assert.NoError(t, err)
Expand All @@ -483,13 +483,13 @@ func TestRefreshForSearch(t *testing.T) {
// wont be refreshed any more, and we wont see the error when refreshing
// any more.
err = rs.Refresh(ctx,
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, retTargets)))
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))
assert.Nil(t, err)

err = rs.RecheckCachingSupport(ctx,
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, retTargets)))
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))
assert.Nil(t, err)

got, err = r.ListTargets(ctx, at.Id)
Expand All @@ -499,9 +499,9 @@ func TestRefreshForSearch(t *testing.T) {
// Now simulate the controller updating to support refresh tokens and
// the resources starting to be cached.
err = rs.RecheckCachingSupport(ctx,
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)),
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)),
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{retTargets}, [][]string{{}})))
assert.Nil(t, err)
assert.Nil(t, err, err)

got, err = r.ListTargets(ctx, at.Id)
assert.NoError(t, err)
Expand Down Expand Up @@ -799,32 +799,29 @@ func TestRecheckCachingSupport(t *testing.T) {
require.NoError(t, err)
require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id}))

retTargets := []*targets.Target{
target("1"),
}
// Since this user doesn't have any resources, the user's data will still
// only get updated with a call to Refresh.
assert.NoError(t, rs.RecheckCachingSupport(ctx,
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, retTargets))))
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))))

got, err := r.ListTargets(ctx, at.Id)
assert.NoError(t, err)
assert.Empty(t, got)

err = rs.Refresh(ctx,
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, retTargets)))
assert.ErrorIs(t, err, errRefreshNotSupported)
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))
assert.ErrorIs(t, err, ErrRefreshNotSupported)

got, err = r.ListTargets(ctx, at.Id)
assert.NoError(t, err)
assert.Empty(t, got)

// now a full fetch will work since the user has resources and no refresh token
assert.NoError(t, rs.RecheckCachingSupport(ctx,
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, retTargets))))
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))))
})

t.Run("sessions", func(t *testing.T) {
Expand All @@ -836,29 +833,26 @@ func TestRecheckCachingSupport(t *testing.T) {
require.NoError(t, err)
require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id}))

retSess := []*sessions.Session{
session("1"),
}
assert.NoError(t, rs.RecheckCachingSupport(ctx,
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, nil)),
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, retSess))))
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)),
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))))

got, err := r.ListSessions(ctx, at.Id)
assert.NoError(t, err)
assert.Empty(t, got)

err = rs.Refresh(ctx,
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, nil)),
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, retSess)))
assert.ErrorIs(t, err, errRefreshNotSupported)
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)),
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))
assert.ErrorIs(t, err, ErrRefreshNotSupported)

got, err = r.ListSessions(ctx, at.Id)
assert.NoError(t, err)
assert.Empty(t, got)

assert.NoError(t, rs.RecheckCachingSupport(ctx,
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, nil)),
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, retSess))))
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)),
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))))
got, err = r.ListSessions(ctx, at.Id)
assert.NoError(t, err)
assert.Empty(t, got)
Expand All @@ -874,13 +868,13 @@ func TestRecheckCachingSupport(t *testing.T) {
require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id}))

err = rs.Refresh(ctx,
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, []*targets.Target{target("1")})),
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, []*sessions.Session{session("1")})))
assert.ErrorIs(t, err, errRefreshNotSupported)
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)),
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))
assert.ErrorIs(t, err, ErrRefreshNotSupported)

innerErr := errors.New("test error")
err = rs.RecheckCachingSupport(ctx,
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)),
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)),
WithTargetRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) {
require.Equal(t, boundaryAddr, addr)
require.Equal(t, at.Token, token)
Expand All @@ -889,8 +883,8 @@ func TestRecheckCachingSupport(t *testing.T) {
assert.ErrorContains(t, err, innerErr.Error())

err = rs.RecheckCachingSupport(ctx,
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, nil)),
WithSessionRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) {
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)),
WithTargetRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) {
require.Equal(t, boundaryAddr, addr)
require.Equal(t, at.Token, token)
return nil, nil, "", innerErr
Expand All @@ -908,9 +902,9 @@ func TestRecheckCachingSupport(t *testing.T) {
require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id}))

err = rs.Refresh(ctx,
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, []*targets.Target{target("1")})),
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, []*sessions.Session{session("1")})))
assert.ErrorIs(t, err, errRefreshNotSupported)
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)),
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))
assert.ErrorIs(t, err, ErrRefreshNotSupported)

// Remove the token from the keyring, see that we can still see the
// token and then user until a Refresh happens which causes them to be
Expand All @@ -926,8 +920,8 @@ func TestRecheckCachingSupport(t *testing.T) {
assert.Len(t, us, 1)

err = rs.RecheckCachingSupport(ctx,
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t, nil)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t, nil)))
WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)),
WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))
assert.NoError(t, err)

ps, err = r.listTokens(ctx, u)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import (
"github.com/hashicorp/boundary/internal/util"
)

// errRefreshNotSupported is returned whenever it is determined a boundary
// ErrRefreshNotSupported is returned whenever it is determined a boundary
// instance does not support refresh tokens.
var errRefreshNotSupported = stderrors.New("refresh tokens are not supported for this controller")
var ErrRefreshNotSupported = stderrors.New("refresh tokens are not supported for this controller")

// RefreshTokenValue is the the type for the actual refresh token value handled
// by the client cache.
Expand Down
55 changes: 36 additions & 19 deletions internal/clientcache/internal/cache/repository_sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ func defaultSessionFunc(ctx context.Context, addr, authTok string, refreshTok Re
}
return nil, nil, "", errors.Wrap(ctx, err, op)
}
if l.ResponseType == "" {
return nil, nil, "", ErrRefreshNotSupported
}
return l.Items, l.RemovedIds, RefreshTokenValue(l.ListToken), nil
}

Expand Down Expand Up @@ -78,6 +81,7 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au
var gotResponse bool
var resp []*sessions.Session
var newRefreshToken RefreshTokenValue
var unsupportedCacheRequest bool
var removedIds []string
var retErr error
for at, t := range tokens {
Expand All @@ -91,8 +95,12 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au
resp, removedIds, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "")
}
if err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
if err == ErrRefreshNotSupported {
unsupportedCacheRequest = true
} else {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
}
}
gotResponse = true
break
Expand Down Expand Up @@ -122,28 +130,28 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au
}
}
switch {
case unsupportedCacheRequest:
if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil {
return err
}
case newRefreshToken != "":
if err := upsertSessions(ctx, w, u, resp); err != nil {
return err
}
if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil {
return err
}
case len(resp) > 0:
if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil {
return err
}
return errRefreshNotSupported
case len(resp) == 0:
if err := deleteRefreshToken(ctx, w, u, resourceType); err != nil {
return err
}
default:
// controller supports caching, but doesn't have any resources
}
return nil
})
if err != nil {
return errors.Wrap(ctx, err, op)
}
if unsupportedCacheRequest {
return ErrRefreshNotSupported
}
return nil
}

Expand Down Expand Up @@ -174,12 +182,17 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m
var gotResponse bool
var resp []*sessions.Session
var newRefreshToken RefreshTokenValue
var unsupportedCacheRequest bool
var retErr error
for at, t := range tokens {
resp, _, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "")
if err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
if err == ErrRefreshNotSupported {
unsupportedCacheRequest = true
} else {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
}
}
gotResponse = true
break
Expand All @@ -196,6 +209,10 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m

_, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error {
switch {
case unsupportedCacheRequest:
if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil {
return err
}
case newRefreshToken != "":
if _, err := w.Exec(ctx, "delete from session where fk_user_id = @fk_user_id",
[]any{sql.Named("fk_user_id", u.Id)}); err != nil {
Expand All @@ -207,12 +224,9 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m
if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil {
return err
}
case len(resp) > 0:
if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil {
return err
}
return errRefreshNotSupported
case len(resp) == 0:
default:
// This is no longer flagged as not supported, but we dont have a
// refresh token so clear out any refresh token we have stored.
if err := deleteRefreshToken(ctx, w, u, resourceType); err != nil {
return err
}
Expand All @@ -222,6 +236,9 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m
if err != nil {
return errors.Wrap(ctx, err, op)
}
if unsupportedCacheRequest {
return ErrRefreshNotSupported
}
return nil
}

Expand Down
Loading

0 comments on commit d4e7683

Please sign in to comment.