diff --git a/internal/clientcache/internal/cache/refresh.go b/internal/clientcache/internal/cache/refresh.go
index a7193ac963..f2961b5e1a 100644
--- a/internal/clientcache/internal/cache/refresh.go
+++ b/internal/clientcache/internal/cache/refresh.go
@@ -364,6 +364,21 @@ func (r *RefreshService) Refresh(ctx context.Context, opt ...Option) error {
 			if err := r.repo.refreshTargets(ctx, u, tokens, opt...); err != nil {
 				retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id))))
 			}
+			{
+				// Now run refreshTargetsDuringRefreshWindow, so we don't
+				// have to wait for the current refresh token to expire
+				// before refreshing the targets.
+
+				// Notice that we still have the "targets" cache key locked,
+				// since we don't want to run this in parallel with the
+				// regular targets refresh.
+				semaphore, _ = r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool))
+				if semaphore.(*atomic.Bool).CompareAndSwap(false, true) {
+					if err := r.repo.refreshTargetsDuringRefreshWindow(ctx, u, tokens, opt...); err != nil {
+						retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id))))
+					}
+				}
+			}
 			semaphore.(*atomic.Bool).Store(false)
 		}
diff --git a/internal/clientcache/internal/cache/repository_targets.go b/internal/clientcache/internal/cache/repository_targets.go
index 56ccc16f20..615c713df4 100644
--- a/internal/clientcache/internal/cache/repository_targets.go
+++ b/internal/clientcache/internal/cache/repository_targets.go
@@ -9,6 +9,7 @@ import (
 	"encoding/json"
 	stderrors "errors"
 	"fmt"
+	"time"
 
 	"github.com/hashicorp/boundary/api"
 	"github.com/hashicorp/boundary/api/targets"
@@ -278,8 +279,182 @@ func (r *Repository) checkCachingTargets(ctx context.Context, u *user, tokens ma
 	return nil
 }
 
+// refreshTargetsDuringRefreshWindow refreshes the targets for the provided
+// user using the provided auth tokens. This function is intended to be used
+// during the refresh window for the user's current target refresh token. It
+// will store the targets in the target_refresh_window table and then swap
+// those rows in for the user's rows in the target table.
+//
+// IMPORTANT: the caller is responsible for ensuring that there's no inflight
+// refresh into the target_refresh_window table; see
+// cache.RefreshService.syncSemaphores for how that's coordinated.
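+//
+// At a high level, the flow is: list all of the user's targets from the
+// controller starting with an empty refresh token, stage each page of
+// results in the target_refresh_window table, and, once the listing is
+// complete, swap the staged rows into the target table in a single
+// transaction.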
+//
+// If the existing refresh token is not within the refresh window, this
+// function will return without doing anything.
+//
+// If the controller's list targets endpoint doesn't support refresh tokens,
+// we will not refresh and we will NOT return an error.
+func (r *Repository) refreshTargetsDuringRefreshWindow(ctx context.Context, u *user, tokens map[AuthToken]string, opt ...Option) error {
+	const (
+		refreshWindowLookBackInDays = -10
+		op                          = "cache.(Repository).refreshTargetsDuringRefreshWindow"
+		emptyRefreshToken           = ""
+		deleteRemovedTargets        = "delete from %s where id in @ids"
+	)
+	switch {
+	case util.IsNil(u):
+		return errors.New(ctx, errors.InvalidParameter, op, "user is nil")
+	case u.Id == "":
+		return errors.New(ctx, errors.InvalidParameter, op, "user id is missing")
+	}
+
+	existingRefreshToken, err := r.lookupRefreshToken(ctx, u, targetResourceType)
+	if err != nil {
+		return errors.Wrap(ctx, err, op)
+	}
+	// If there's an existing refresh token and it was created outside the
+	// look-back window (more than 10 days ago), just return without doing
+	// anything.
+	if existingRefreshToken != nil && !existingRefreshToken.CreateTime.After(time.Now().AddDate(0, 0, refreshWindowLookBackInDays)) {
+		return nil
+	}
+
+	opts, err := getOpts(opt...)
+	if err != nil {
+		return errors.Wrap(ctx, err, op)
+	}
+	if opts.withTargetRetrievalFunc == nil {
+		opts.withTargetRetrievalFunc = defaultTargetFunc
+	}
+
+	var gotResponse bool
+	var currentPage *targets.TargetListResult
+	var newRefreshToken RefreshTokenValue
+	var foundAuthToken string
+	var retErr error
+	for at, t := range tokens {
+		// We want to start from the beginning of the list, so we use an
+		// empty refresh token.
+		currentPage, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, emptyRefreshToken, currentPage)
+		if err != nil {
+			if err == ErrRefreshNotSupported {
+				// This controller doesn't support refresh tokens, which is
+				// not an error, so just return.
+				return nil
+			}
+			retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for token %q", at.Id))))
+			continue
+		}
+		foundAuthToken = t
+		gotResponse = true
+		break
+	}
+	// If we got an error, save it so users can retrieve it later via the
+	// status API.
+	if retErr != nil {
+		if saveErr := r.saveError(r.serverCtx, u, targetResourceType, retErr); saveErr != nil {
+			return stderrors.Join(retErr, errors.Wrap(ctx, saveErr, op))
+		}
+	}
+	if !gotResponse {
+		return retErr
+	}
+
+	tmpTblName, err := tempTableName(ctx, &Target{})
+	if err != nil {
+		return errors.Wrap(ctx, err, op)
+	}
+
+	clearTmpTableFn := func(w db.Writer) error {
+		if _, err := w.Exec(ctx, fmt.Sprintf("delete from %s where fk_user_id = @fk_user_id", tmpTblName),
+			[]any{sql.Named("fk_user_id", u.Id)}); err != nil {
+			return err
+		}
+		return nil
+	}
+
+	// First, clear out any leftover rows in the temporary table.
+	if _, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error {
+		return clearTmpTableFn(w)
+	}); err != nil {
+		return errors.Wrap(ctx, err, op)
+	}
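+
+	// Staging into the temporary table means the user's existing rows in the
+	// target table remain searchable until the full listing has succeeded; a
+	// failed listing never clobbers the current cache.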
+
+	// Okay, we have the data; let's store it in the temporary table.
+	var numDeleted int
+	var numUpserted int
+	for {
+		if _, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error {
+			switch {
+			case newRefreshToken != "":
+				if err := upsertTargets(ctx, w, u, currentPage.Items, WithTable(tmpTblName)); err != nil {
+					return err
+				}
+				numUpserted += len(currentPage.Items)
+				// Store the refresh token for the next page of targets. It's
+				// okay to do this since we are the only inflight refresh;
+				// the caller must ensure that (see the comment at the top of
+				// this function).
+				if err := upsertRefreshToken(ctx, w, u, targetResourceType, newRefreshToken); err != nil {
+					return err
+				}
+			default:
+				// The controller supports caching, but doesn't have any
+				// resources.
+			}
+			if len(currentPage.RemovedIds) > 0 {
+				if numDeleted, err = w.Exec(ctx, fmt.Sprintf(deleteRemovedTargets, tmpTblName),
+					[]any{sql.Named("ids", currentPage.RemovedIds)}); err != nil {
+					return err
+				}
+			}
+			return nil // we're done with this DoTx(...) containing a "page" of targets
+		}); err != nil {
+			return errors.Wrap(ctx, err, op)
+		}
+		if currentPage.ResponseType == "" || currentPage.ResponseType == "complete" {
+			break
+		}
+		currentPage, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, foundAuthToken, newRefreshToken, currentPage)
+		if err != nil {
+			// Return instead of swapping a partial listing into the user's
+			// target table.
+			return errors.Wrap(ctx, err, op)
+		}
+	}
+
+	// Now that we have the data in the temporary table, swap it in for the
+	// user's rows in the target table.
+	_, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error {
+		// First, delete the targets currently cached for the user.
+		if _, err := w.Exec(ctx, "delete from target where fk_user_id = @fk_user_id",
+			[]any{sql.Named("fk_user_id", u.Id)}); err != nil {
+			return err
+		}
+		// Now, insert the targets from the temporary table into the user's
+		// target table.
+		var numMoved int
+		if numMoved, err = w.Exec(ctx, fmt.Sprintf("insert into target select * from %s where fk_user_id = @fk_user_id", tmpTblName),
+			[]any{sql.Named("fk_user_id", u.Id)}); err != nil {
+			return err
+		}
+		if numMoved != numUpserted {
+			return errors.New(ctx, errors.Internal, op, fmt.Sprintf("number of targets moved %d doesn't match number of targets upserted %d", numMoved, numUpserted))
+		}
+		// Finally, delete the rows from the temporary table.
+		return clearTmpTableFn(w)
+	})
+	if err != nil {
+		return errors.Wrap(ctx, err, op)
+	}
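+
+	// Note: the "insert into target select *" swap above relies on
+	// target_refresh_window having exactly the same columns, in the same
+	// order, as the target table; see the schema comments in schema.sql.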
+
+	event.WriteSysEvent(ctx, op, "targets updated", "deleted", numDeleted, "upserted", numUpserted, "user_id", u.Id)
+	return nil
+}
+
 // upsertTargets upserts the provided targets to be stored for the provided user.
-func upsertTargets(ctx context.Context, w db.Writer, u *user, in []*targets.Target) error {
+func upsertTargets(ctx context.Context, w db.Writer, u *user, in []*targets.Target, opt ...Option) error {
 	const op = "cache.upsertTargets"
 	switch {
 	case util.IsNil(w):
@@ -290,6 +465,11 @@ func upsertTargets(ctx context.Context, w db.Writer, u *user, in []*targets.Targ
 		return errors.New(ctx, errors.InvalidParameter, op, "user is nil")
 	}
 
+	opts, err := getOpts(opt...)
+	if err != nil {
+		return errors.Wrap(ctx, err, op)
+	}
+
 	for _, t := range in {
 		item, err := json.Marshal(t)
 		if err != nil {
@@ -309,7 +489,11 @@ func upsertTargets(ctx context.Context, w db.Writer, u *user, in []*targets.Targ
 			Target: db.Columns{"fk_user_id", "id"},
 			Action: db.SetColumns([]string{"name", "description", "address", "scope_id", "type", "item"}),
 		}
-		if err := w.Create(ctx, newTarget, db.WithOnConflict(&onConflict)); err != nil {
+		dbOpts := []db.Option{db.WithOnConflict(&onConflict)}
+		if opts.withTable != "" {
+			dbOpts = append(dbOpts, db.WithTable(opts.withTable))
+		}
+		if err := w.Create(ctx, newTarget, dbOpts...); err != nil {
 			return errors.Wrap(ctx, err, op)
 		}
 	}
diff --git a/internal/clientcache/internal/cache/repository_targets_test.go b/internal/clientcache/internal/cache/repository_targets_test.go
index deb87d692b..9ae65431a5 100644
--- a/internal/clientcache/internal/cache/repository_targets_test.go
+++ b/internal/clientcache/internal/cache/repository_targets_test.go
@@ -6,9 +6,11 @@ package cache
 import (
 	"context"
 	"encoding/json"
 	"strconv"
 	"sync"
 	"testing"
+	"time"
 
 	"github.com/hashicorp/boundary/api"
 	"github.com/hashicorp/boundary/api/authtokens"
@@ -164,6 +166,187 @@ func TestRepository_refreshTargets(t *testing.T) {
 	}
 }
 
+func TestRepository_refreshTargetsDuringRefreshWindow(t *testing.T) {
+	ctx := context.Background()
+	s, err := cachedb.Open(ctx)
+	require.NoError(t, err)
+
+	addr := "address"
+	u := user{
+		Id: "u1",
+	}
+	at := &authtokens.AuthToken{
+		Id:     "at_1",
+		Token:  "at_1_token",
+		UserId: u.Id,
+	}
+	kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id}
+	atMap := map[ringToken]*authtokens.AuthToken{
+		{"k", "t"}: at,
+	}
+	r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap)))
+	require.NoError(t, err)
+	require.NoError(t, r.AddKeyringToken(ctx, addr, kt))
+
+	ts := []*targets.Target{
+		{
+			Id:                "ttcp_1",
+			Name:              "name1",
+			Description:       "description1",
+			Type:              "tcp",
+			ScopeId:           "p_123",
+			SessionMaxSeconds: 111,
+		},
+		{
+			Id:                "ttcp_2",
+			Name:              "name2",
+			Address:           "address2",
+			Type:              "tcp",
+			ScopeId:           "p_123",
+			SessionMaxSeconds: 222,
+		},
+		{
+			Id:                "ttcp_3",
+			Name:              "name3",
+			Address:           "address3",
+			Type:              "tcp",
+			ScopeId:           "p_123",
+			SessionMaxSeconds: 333,
+		},
+	}
+	var want []*Target
+	for _, tar := range ts {
+		ti, err := json.Marshal(tar)
+		require.NoError(t, err)
+		want = append(want, &Target{
+			FkUserId:    u.Id,
+			Id:          tar.Id,
+			Name:        tar.Name,
+			Description: tar.Description,
+			Address:     tar.Address,
+			ScopeId:     tar.ScopeId,
+			Type:        tar.Type,
+			Item:        string(ti),
+		})
+	}
+
+	defaultCleanupFn := func() {
+		refTok := &refreshToken{
+			UserId:       at.UserId,
+			ResourceType: targetResourceType,
+		}
+		_, err := r.rw.Delete(ctx, refTok)
+		require.NoError(t, err)
+	}
+	cases := []struct {
+		name          string
+		u             *user
+		targets       []*targets.Target
+		want          []*Target
+		setup         func()
+		cleanup       func()
+		errorContains string
+	}{
+		{
+			name: "success-no-token",
+			u: &user{
+				Id:      at.UserId,
+				Address: addr,
+			},
+			targets: ts,
+			want:    want,
+			cleanup: defaultCleanupFn,
+		},
+		// This test case must run after the success-no-token case above so
+		// that it exercises running with an existing refresh token that is
+		// within the window.
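+		// Note that defaultCleanupFn only removes the refresh token between
+		// cases; the cached targets themselves persist across cases, which
+		// is why later cases can still expect `want`.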
+ { + name: "success-with-existing-token-within-10-day-window", + u: &user{ + Id: at.UserId, + Address: addr, + }, + targets: ts, + cleanup: defaultCleanupFn, + setup: func() { + // insert refresh token + refTok := &refreshToken{ + UserId: at.UserId, + ResourceType: targetResourceType, + RefreshToken: "1", + UpdateTime: time.Now(), + CreateTime: time.Now(), + } + require.NoError(t, db.New(s).Create(ctx, refTok)) + }, + want: want, + }, + { + name: "no-op-with-existing-token-outside-10-day-window", + u: &user{ + Id: at.UserId, + Address: addr, + }, + targets: ts, + cleanup: defaultCleanupFn, + setup: func() { + n := time.Now().AddDate(0, 0, -11) + // insert refresh token + refTok := &refreshToken{ + UserId: at.UserId, + ResourceType: targetResourceType, + RefreshToken: "1", + UpdateTime: n, + CreateTime: n, + } + fmt.Println("inserting") + require.NoError(t, db.New(s).Create(ctx, refTok)) + fmt.Println("done") + }, + want: want, + }, + { + name: "nil user", + u: nil, + targets: ts, + errorContains: "user is nil", + }, + { + name: "missing user Id", + u: &user{ + Address: addr, + }, + targets: ts, + errorContains: "user id is missing", + }, + } + // u1 / target / + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.cleanup != nil { + t.Cleanup(tc.cleanup) + } + if tc.setup != nil { + tc.setup() + } + err := r.refreshTargetsDuringRefreshWindow(ctx, tc.u, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{tc.targets}, [][]string{nil})))) + if tc.errorContains == "" { + assert.NoError(t, err) + rw := db.New(s) + var got []*Target + require.NoError(t, rw.SearchWhere(ctx, &got, "true", nil)) + assert.ElementsMatch(t, got, tc.want) + + } else { + assert.ErrorContains(t, err, tc.errorContains) + } + }) + } +} + func TestRepository_RefreshTargets_InvalidListTokenError(t *testing.T) { ctx := context.Background() s, err := cachedb.Open(ctx) diff --git a/internal/clientcache/internal/cache/temp_table.go b/internal/clientcache/internal/cache/temp_table.go new file mode 100644 index 0000000000..d51990b686 --- /dev/null +++ b/internal/clientcache/internal/cache/temp_table.go @@ -0,0 +1,36 @@ +package cache + +import ( + "context" + "fmt" + "strings" + + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/util" +) + +type resourceTabler interface { + TableName() string +} + +const ( + tmpTblSuffix = "_refresh_window" + targetTblName = "target" + sessionTblName = "session" + resolvableAliasTblName = "resolvable_alias" +) + +func tempTableName(ctx context.Context, resource resourceTabler) (string, error) { + const op = "cache.tempTableName" + switch { + case util.IsNil(resource): + return "", errors.New(ctx, errors.InvalidParameter, op, "missing resource tabler") + } + baseTableName := strings.ToLower(resource.TableName()) + switch baseTableName { + case targetTblName, sessionTblName, resolvableAliasTblName: + default: + return "", errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unable to create a temp table for %s, it is not a supported base table for creating a temp table", baseTableName)) + } + return baseTableName + tmpTblSuffix, nil +} diff --git a/internal/clientcache/internal/db/schema.sql b/internal/clientcache/internal/db/schema.sql index 4b20cd910f..4b43cacc37 100644 --- a/internal/clientcache/internal/db/schema.sql +++ b/internal/clientcache/internal/db/schema.sql @@ -178,6 +178,9 @@ create table if not exists 
+func tempTableName(ctx context.Context, resource resourceTabler) (string, error) {
+	const op = "cache.tempTableName"
+	switch {
+	case util.IsNil(resource):
+		return "", errors.New(ctx, errors.InvalidParameter, op, "missing resource tabler")
+	}
+	baseTableName := strings.ToLower(resource.TableName())
+	switch baseTableName {
+	case targetTblName, sessionTblName, resolvableAliasTblName:
+	default:
+		return "", errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unable to create a temp table for %s: it is not a supported base table", baseTableName))
+	}
+	return baseTableName + tmpTblSuffix, nil
+}
diff --git a/internal/clientcache/internal/db/schema.sql b/internal/clientcache/internal/db/schema.sql
index 4b20cd910f..4b43cacc37 100644
--- a/internal/clientcache/internal/db/schema.sql
+++ b/internal/clientcache/internal/db/schema.sql
@@ -178,6 +178,9 @@ create table if not exists keyring_token (
 
 -- target contains cached boundary target resource for a specific user and with
 -- specific fields extracted to facilitate searching over those fields
+--
+-- any changes to this table must be reflected in the target_refresh_window
+-- table as well.
 create table if not exists target (
   -- the boundary user id of the user who has was able to read/list this target
   fk_user_id text not null
@@ -202,6 +205,37 @@ create table if not exists target (
 -- index for implicit scope search
 create index target_scope_id_ix on target(scope_id);
 
+-- target_refresh_window contains targets that are being refreshed with a new
+-- refresh token before the existing target refresh token expires. This is
+-- used to prevent having no cached targets when the current refresh token
+-- expires.
+-- NOTE: this table is not used for searching, and it is also not used when
+-- refreshing the targets with the current unexpired refresh token. Targets
+-- are copied from this table to the target table when a refresh is
+-- successfully completed, and at that time the refresh token is updated in
+-- the refresh_token table.
+--
+-- IMPORTANT: this table must be an exact copy of the target table.
+create table if not exists target_refresh_window (
+  -- the boundary user id of the user who was able to read/list this target
+  fk_user_id text not null
+    references user(id)
+    on delete cascade,
+  -- the boundary id of the target
+  id text not null
+    check (length(id) > 0),
+  -- the following fields are used for searching and are set to the values
+  -- from the boundary resource
+  name text,
+  description text,
+  type text,
+  address text,
+  scope_id text,
+  -- item is the json representation of this resource from the perspective of
+  -- the requesting user
+  item text,
+  primary key (fk_user_id, id)
+);
+
 -- session contains cached boundary session resource for a specific user and
 -- with specific fields extracted to facilitate searching over those fields
 create table if not exists session (