Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jimlambrt committed Dec 19, 2024
1 parent 8f73ffd commit 7b4029e
Show file tree
Hide file tree
Showing 5 changed files with 454 additions and 2 deletions.
15 changes: 15 additions & 0 deletions internal/clientcache/internal/cache/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,21 @@ func (r *RefreshService) Refresh(ctx context.Context, opt ...Option) error {
if err := r.repo.refreshTargets(ctx, u, tokens, opt...); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id))))
}
{
// now, we need to refreshTargetsDuringRefreshWindow, so we don't
// have to wait for the refresh token to expire before refreshing
// the targets.

// notice that we still have the "targets" cache key locked,
// since we don't want to run this in parallel with the regular
// targets refresh.
semaphore, _ = r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool))
if semaphore.(*atomic.Bool).CompareAndSwap(false, true) {
if err := r.repo.refreshTargetsDuringRefreshWindow(ctx, u, tokens, opt...); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id))))
}
}
}
semaphore.(*atomic.Bool).Store(false)
}

Expand Down
188 changes: 186 additions & 2 deletions internal/clientcache/internal/cache/repository_targets.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/json"
stderrors "errors"
"fmt"
"time"

"github.com/hashicorp/boundary/api"
"github.com/hashicorp/boundary/api/targets"
Expand Down Expand Up @@ -278,8 +279,182 @@ func (r *Repository) checkCachingTargets(ctx context.Context, u *user, tokens ma
return nil
}

// refreshTargetsDuringRefreshWindow refreshes the targets for the provided user
// using the provided auth tokens. This function is intended to be used during
// the refresh window for the user's current target refresh token. It will store
// the targets in the target_refresh_window table and then swap the resources in
// the target_refresh_window for the user's resources in the target table.
//
// IMPORTANT: the caller is responsible for ensuring that there's no inflight
// refresh into the target_refresh_window. See
// cache.RefreshService.syncSemaphore for how to ensure there's no inflight
// refresh into the target_refresh_window table
//
// if the existing refresh token is not within the refresh window, this function
// will return without doing anything.
//
// if controller list target doesn't support refresh token, we will not refresh
// and we will NOT return an error
func (r *Repository) refreshTargetsDuringRefreshWindow(ctx context.Context, u *user, tokens map[AuthToken]string, opt ...Option) error {
const (
refreshWindowLookBackInDays = -10
op = "cache.(Repository).refreshTargetsIntoTmpTable"
emptyRefreshToken = ""
deleteRemovedTargets = "delete from %s where id in @ids"
)
switch {
case util.IsNil(u):
return errors.New(ctx, errors.InvalidParameter, op, "user is nil")
case u.Id == "":
return errors.New(ctx, errors.InvalidParameter, op, "user id is missing")
}

existingRefreshToken, err := r.lookupRefreshToken(ctx, u, targetResourceType)
if err != nil {
return errors.Wrap(ctx, err, op)
}
// check if existing refresh token will expire within the refresh window
// look back (next 10 days) and if not then just return
if existingRefreshToken != nil {
fmt.Println("*********************")
fmt.Println("existing...: ", existingRefreshToken.CreateTime)
fmt.Println("10 days ago: ", time.Now().AddDate(0, 0, refreshWindowLookBackInDays))
fmt.Println("expr.......: ", existingRefreshToken.CreateTime.After(time.Now().AddDate(0, 0, refreshWindowLookBackInDays)))
}
if existingRefreshToken != nil && !existingRefreshToken.CreateTime.After(time.Now().AddDate(0, 0, refreshWindowLookBackInDays)) {
return nil
}

opts, err := getOpts(opt...)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if opts.withTargetRetrievalFunc == nil {
opts.withTargetRetrievalFunc = defaultTargetFunc
}

var gotResponse bool
var currentPage *targets.TargetListResult
var newRefreshToken RefreshTokenValue
var foundAuthToken string
var retErr error
for at, t := range tokens {
// we want to start from the beginning of the list using an empty refresh token
currentPage, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, emptyRefreshToken, currentPage)
if err != nil {
if err == ErrRefreshNotSupported {
return nil // this particular controller doesn't support refresh tokens and it's not an error, so just return
} else {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
}
}
foundAuthToken = t
gotResponse = true
break
}
// if we got an error, then we need to save it so it can be retrieved later
// via the status API by users
if retErr != nil {
if saveErr := r.saveError(r.serverCtx, u, targetResourceType, retErr); saveErr != nil {
return stderrors.Join(err, errors.Wrap(ctx, saveErr, op))
}
}
if !gotResponse {
return retErr
}

tmpTblName, err := tempTableName(ctx, &Target{})
if err != nil {
return errors.Wrap(ctx, err, op)
}

clearTmpTableFn := func(w db.Writer) error {
if _, err := w.Exec(ctx, fmt.Sprintf("delete from %s where fk_user_id = @fk_user_id", tmpTblName),
[]any{sql.Named("fk_user_id", u.Id)}); err != nil {
return err
}
return nil
}

// first, clear out the temporary table
if _, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error {
return clearTmpTableFn(w)
}); err != nil {
return errors.Wrap(ctx, err, op)
}

// okay we have the data, let's store it in a temporary table
var numDeleted int
var numUpserted int
for {
if _, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error {
switch {
case newRefreshToken != "":
if err := upsertTargets(ctx, w, u, currentPage.Items, WithTable(tmpTblName)); err != nil {
return err
}
numUpserted += len(currentPage.Items)
// we need to store the refresh token for the next page of
// targets, it's okay to do this since we are the only inflight
// refresh. Note the caller must ensure there's no inflight
// refresh! See the comment at the top of this function.
if err := upsertRefreshToken(ctx, w, u, targetResourceType, newRefreshToken); err != nil {
return err
}
default:
// controller supports caching, but doesn't have any resources
}
if len(currentPage.RemovedIds) > 0 {
if numDeleted, err = w.Exec(ctx, fmt.Sprintf(deleteRemovedTargets, tmpTblName),
[]any{sql.Named("ids", currentPage.RemovedIds)}); err != nil {
return err
}
}
return nil // we're done with this DoTx(...) containing a "page" of targets
}); err != nil {
return errors.Wrap(ctx, err, op)
}
if currentPage.ResponseType == "" || currentPage.ResponseType == "complete" {
break
}
currentPage, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, foundAuthToken, newRefreshToken, currentPage)
if err != nil {
break
}
}

// now that we have the data in the temporary table, let's swap it with the
// user's target table
_, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error {
// first, delete the existing targets which were cached for the user
if _, err := w.Exec(ctx, "delete from target where fk_user_id = @fk_user_id",
[]any{sql.Named("fk_user_id", u.Id)}); err != nil {
return err
}
// now, insert the targets from the temporary table into the user's
// target table
var numMoved int
if numMoved, err = w.Exec(ctx, fmt.Sprintf("insert into target select * from %s where fk_user_id = @fk_user_id", tmpTblName),
[]any{sql.Named("fk_user_id", u.Id)}); err != nil {
return err
}
if numMoved != numUpserted {
return errors.New(ctx, errors.Internal, op, fmt.Sprintf("number of targets moved %d doesn't match number of targets upserted %d", numMoved, numUpserted))
}
// finally, delete the rows from the temporary table
return clearTmpTableFn(w)
})
if err != nil {
return errors.Wrap(ctx, err, op)
}

event.WriteSysEvent(ctx, op, "targets updated", "deleted", numDeleted, "upserted", numUpserted, "user_id", u.Id)
return nil
}

// upsertTargets upserts the provided targets to be stored for the provided user.
func upsertTargets(ctx context.Context, w db.Writer, u *user, in []*targets.Target) error {
func upsertTargets(ctx context.Context, w db.Writer, u *user, in []*targets.Target, opt ...Option) error {
const op = "cache.upsertTargets"
switch {
case util.IsNil(w):
Expand All @@ -290,6 +465,11 @@ func upsertTargets(ctx context.Context, w db.Writer, u *user, in []*targets.Targ
return errors.New(ctx, errors.InvalidParameter, op, "user is nil")
}

opts, err := getOpts(opt...)
if err != nil {
return errors.Wrap(ctx, err, op)
}

for _, t := range in {
item, err := json.Marshal(t)
if err != nil {
Expand All @@ -309,7 +489,11 @@ func upsertTargets(ctx context.Context, w db.Writer, u *user, in []*targets.Targ
Target: db.Columns{"fk_user_id", "id"},
Action: db.SetColumns([]string{"name", "description", "address", "scope_id", "type", "item"}),
}
if err := w.Create(ctx, newTarget, db.WithOnConflict(&onConflict)); err != nil {
dbOpts := []db.Option{db.WithOnConflict(&onConflict)}
if opts.withTable != "" {
dbOpts = append(dbOpts, db.WithTable(opts.withTable))
}
if err := w.Create(ctx, newTarget, dbOpts...); err != nil {
return errors.Wrap(ctx, err, op)
}
}
Expand Down
Loading

0 comments on commit 7b4029e

Please sign in to comment.