diff --git a/continuity/manager.go b/continuity/manager.go index 65989704faa0..779b0ae5f185 100644 --- a/continuity/manager.go +++ b/continuity/manager.go @@ -27,19 +27,18 @@ type Manager interface { } type managerOptions struct { - iid uuid.UUID - ttl time.Duration - payload json.RawMessage - payloadRaw interface{} - cleanUp bool + iid uuid.UUID + ttl time.Duration + setExpiresIn time.Duration + payload json.RawMessage + payloadRaw interface{} } type ManagerOption func(*managerOptions) error func newManagerOptions(opts []ManagerOption) (*managerOptions, error) { var o = &managerOptions{ - ttl: time.Minute, - cleanUp: true, + ttl: time.Minute * 10, } for _, opt := range opts { if err := opt(o); err != nil { @@ -49,13 +48,6 @@ func newManagerOptions(opts []ManagerOption) (*managerOptions, error) { return o, nil } -func DontCleanUp() ManagerOption { - return func(o *managerOptions) error { - o.cleanUp = false - return nil - } -} - func WithIdentity(i *identity.Identity) ManagerOption { return func(o *managerOptions) error { if i != nil { @@ -83,3 +75,10 @@ func WithPayload(payload interface{}) ManagerOption { return nil } } + +func WithExpireInsteadOfDelete(duration time.Duration) ManagerOption { + return func(o *managerOptions) error { + o.setExpiresIn = duration + return nil + } +} diff --git a/continuity/manager_cookie.go b/continuity/manager_cookie.go index 495800a87736..7d9b40632df5 100644 --- a/continuity/manager_cookie.go +++ b/continuity/manager_cookie.go @@ -8,6 +8,7 @@ import ( "context" "encoding/json" "net/http" + "time" "github.com/gofrs/uuid" "github.com/pkg/errors" @@ -93,12 +94,22 @@ func (m *ManagerCookie) Continue(ctx context.Context, w http.ResponseWriter, r * } } - if err := x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil { - return nil, err - } + if o.setExpiresIn > 0 { + if err := m.d.ContinuityPersister().SetContinuitySessionExpiry( + ctx, + container.ID, + time.Now().UTC().Add(o.setExpiresIn).Truncate(time.Second), + ); err != nil && !errors.Is(err, sqlcon.ErrNoRows) { + return nil, err + } + } else { + if err := x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil { + return nil, err + } - if err := m.d.ContinuityPersister().DeleteContinuitySession(ctx, container.ID); err != nil && !errors.Is(err, sqlcon.ErrNoRows) { - return nil, err + if err := m.d.ContinuityPersister().DeleteContinuitySession(ctx, container.ID); err != nil && !errors.Is(err, sqlcon.ErrNoRows) { + return nil, err + } } return container, nil @@ -136,6 +147,9 @@ func (m *ManagerCookie) container(ctx context.Context, w http.ResponseWriter, r return nil, errors.WithStack(ErrNotResumable.WithDebugf("Resumable ID from cookie could not be found in the datastore: %+v", err)) } else if err != nil { return nil, err + } else if container.ExpiresAt.Before(time.Now()) { + _ = x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name) + return nil, errors.WithStack(ErrNotResumable.WithDebugf("Resumable session has expired")) } return container, err diff --git a/continuity/manager_options_test.go b/continuity/manager_options_test.go index be2e9f73d4d0..8286f12e9e77 100644 --- a/continuity/manager_options_test.go +++ b/continuity/manager_options_test.go @@ -20,7 +20,7 @@ func TestManagerOptions(t *testing.T) { }{ { e: func(t *testing.T, actual *managerOptions) { - assert.EqualValues(t, time.Minute, actual.ttl) + assert.EqualValues(t, time.Minute*10, actual.ttl) }, }, { diff --git a/continuity/manager_test.go b/continuity/manager_test.go index 6790137faa80..8e71024d16cd 100644 --- a/continuity/manager_test.go +++ b/continuity/manager_test.go @@ -12,6 +12,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/ory/kratos/driver/config" @@ -181,6 +182,50 @@ func TestManager(t *testing.T) { assert.Contains(t, href, gjson.GetBytes(body, "name").String(), "%s", body) }) + t.Run("case=pause and use session with expiry", func(t *testing.T) { + cl := newClient() + + tc := &persisterTestCase{ + ro: []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{"bar"}), continuity.WithExpireInsteadOfDelete(time.Minute)}, + wo: []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{}), continuity.WithExpireInsteadOfDelete(time.Minute)}, + } + ts := newServer(t, p, tc) + genid := func() string { + return ts.URL + "/" + x.NewUUID().String() + } + + href := genid() + res, err := cl.Do(testhelpers.NewTestHTTPRequest(t, "PUT", href, nil)) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Equal(t, http.StatusNoContent, res.StatusCode) + + res, err = cl.Do(testhelpers.NewTestHTTPRequest(t, "GET", href, nil)) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Equal(t, http.StatusOK, res.StatusCode) + + res, err = cl.Do(testhelpers.NewTestHTTPRequest(t, "GET", href, nil)) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Equal(t, http.StatusOK, res.StatusCode) + + tc.ro = []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{"bar"}), continuity.WithExpireInsteadOfDelete(-time.Minute)} + tc.wo = []continuity.ManagerOption{continuity.WithPayload(&persisterTestPayload{""}), continuity.WithExpireInsteadOfDelete(-time.Minute)} + + res, err = cl.Do(testhelpers.NewTestHTTPRequest(t, "GET", href, nil)) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + require.Equal(t, http.StatusOK, res.StatusCode) + + res, err = cl.Do(testhelpers.NewTestHTTPRequest(t, "GET", href, nil)) + require.NoError(t, err) + require.Equal(t, http.StatusBadRequest, res.StatusCode) + body := ioutilx.MustReadAll(res.Body) + require.NoError(t, res.Body.Close()) + assert.Contains(t, gjson.GetBytes(body, "error.reason").String(), continuity.ErrNotResumable.ReasonField) + }) + for k, tc := range []persisterTestCase{ {}, { diff --git a/continuity/persistence.go b/continuity/persistence.go index 2499abe59e21..cc912731fa87 100644 --- a/continuity/persistence.go +++ b/continuity/persistence.go @@ -18,5 +18,6 @@ type Persister interface { SaveContinuitySession(ctx context.Context, c *Container) error GetContinuitySession(ctx context.Context, id uuid.UUID) (*Container, error) DeleteContinuitySession(ctx context.Context, id uuid.UUID) error + SetContinuitySessionExpiry(ctx context.Context, id uuid.UUID, expiresAt time.Time) error DeleteExpiredContinuitySessions(ctx context.Context, deleteOlder time.Time, pageSize int) error } diff --git a/continuity/test/persistence.go b/continuity/test/persistence.go index cd6f61188c22..752ccca4d542 100644 --- a/continuity/test/persistence.go +++ b/continuity/test/persistence.go @@ -101,6 +101,30 @@ func TestPersister(ctx context.Context, p interface { }) }) + t.Run("case=set expiry", func(t *testing.T) { + // Create a new continuity session + expected := createContainer(t) + require.NoError(t, p.SaveContinuitySession(ctx, &expected)) + + // Set the expiry of the continuity session + newExpiry := time.Now().Add(48 * time.Hour).UTC().Truncate(time.Second) + require.NoError(t, p.SetContinuitySessionExpiry(ctx, expected.ID, newExpiry)) + + // Retrieve the continuity session + actual, err := p.GetContinuitySession(ctx, expected.ID) + require.NoError(t, err) + + // Check if the expiry has been updated + assert.EqualValues(t, newExpiry, actual.ExpiresAt) + + t.Run("can not update on another network", func(t *testing.T) { + _, p := testhelpers.NewNetwork(t, ctx, p) + newExpiry := time.Now().Add(12 * time.Hour).UTC().Truncate(time.Second) + err := p.SetContinuitySessionExpiry(ctx, expected.ID, newExpiry) + require.ErrorIs(t, err, sqlcon.ErrNoRows) + }) + }) + t.Run("case=cleanup", func(t *testing.T) { id := x.NewUUID() yesterday := time.Now().Add(-24 * time.Hour).UTC().Truncate(time.Second) diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/persistence/sql/persister_continuity.go b/persistence/sql/persister_continuity.go index 73078784766c..ee7a4597e469 100644 --- a/persistence/sql/persister_continuity.go +++ b/persistence/sql/persister_continuity.go @@ -28,6 +28,23 @@ func (p *Persister) SaveContinuitySession(ctx context.Context, c *continuity.Con return sqlcon.HandleError(p.GetConnection(ctx).Create(c)) } +func (p *Persister) SetContinuitySessionExpiry(ctx context.Context, id uuid.UUID, expiresAt time.Time) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.SetContinuitySessionExpiry") + defer otelx.End(span, &err) + + if rows, err := p.GetConnection(ctx). + Where("id = ? AND nid = ?", id, p.NetworkID(ctx)). + UpdateQuery(&continuity.Container{ + ExpiresAt: expiresAt, + }, "expires_at"); err != nil { + return sqlcon.HandleError(err) + } else if rows == 0 { + return errors.WithStack(sqlcon.ErrNoRows) + } + + return nil +} + func (p *Persister) GetContinuitySession(ctx context.Context, id uuid.UUID) (_ *continuity.Container, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetContinuitySession") defer otelx.End(span, &err) diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index 449bfece3878..6515d06367ee 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -14,6 +14,7 @@ import ( "net/url" "path/filepath" "strings" + "time" "golang.org/x/exp/maps" @@ -316,7 +317,10 @@ func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request) (flo cntnr := AuthCodeContainer{} if f.GetType() == flow.TypeBrowser || !hasSessionTokenCode { - if _, err := s.d.ContinuityManager().Continue(r.Context(), w, r, sessionName, continuity.WithPayload(&cntnr)); err != nil { + if _, err := s.d.ContinuityManager().Continue(r.Context(), w, r, sessionName, + continuity.WithPayload(&cntnr), + continuity.WithExpireInsteadOfDelete(time.Minute), + ); err != nil { return nil, nil, err } if stateParam != cntnr.State { @@ -334,6 +338,7 @@ func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request) (flo if errorParam != "" { return f, &cntnr, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`Unable to complete OpenID Connect flow because the OpenID Provider returned error "%s": %s`, r.URL.Query().Get("error"), r.URL.Query().Get("error_description"))) } + if codeParam == "" { return f, &cntnr, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`Unable to complete OpenID Connect flow because the OpenID Provider did not return the code query parameter.`)) } diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index b02e4fc45853..65c8f09b2e06 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -18,6 +18,9 @@ import ( "testing" "time" + "github.com/davecgh/go-spew/spew" + "github.com/samber/lo" + "github.com/ory/kratos/selfservice/hook/hooktest" "github.com/ory/x/sqlxx" @@ -495,6 +498,105 @@ func TestStrategy(t *testing.T) { postLoginWebhook.AssertTransientPayload(t, transientPayload) }) + + t.Run("case=should pass double submit", func(t *testing.T) { + // This test checks that the continuity manager uses a grace period to handle potential double-submit issues. + // + // It addresses issues where Facebook and Apple consent screens on mobile behave in a way that makes it + // easy for users to experience double-submit issues. + j, err := cookiejar.New(nil) + require.NoError(t, err) + + makeInitialRequest := func(t *testing.T, provider, action string, fv url.Values) (*http.Response, []byte, []string) { + fv.Set("provider", provider) + + var lastVia []*http.Request + hc := &http.Client{ + Jar: j, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + lastVia = via + return nil + }, + } + res, err := hc.PostForm(action, fv) + require.NoError(t, err, action) + + body, err := io.ReadAll(res.Body) + require.NoError(t, res.Body.Close()) + require.NoError(t, err) + require.NotEmpty(t, lastVia) + + vias := make([]string, len(lastVia)) + for k, v := range lastVia { + vias[k] = v.URL.String() + } + + return res, body, vias + } + + r := newBrowserLoginFlow(t, returnTS.URL, time.Minute) + action := assertFormValues(t, r.ID, "valid") + + // First login + res, body, via := makeInitialRequest(t, "valid", action, url.Values{}) + assertIdentity(t, res, body) + expectTokens(t, "valid", body) + assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body) + + // We fetch the URL which includes the `?code` query parameter. + result := lo.Filter(via, func(s string, _ int) bool { + return strings.Contains(s, "code=") + }) + require.Len(t, result, 1) + + // And call that URL again. What's interesting here is that the whole requets passes because we are already authenticated. + // + // In this scenario, Ory Kratos correctly forwards the user to the return URL, which in our case returns the identity. + // + // We essentially run into this bit: + // + // if authenticated, err := s.alreadyAuthenticated(w, r, req); err != nil { + // s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) + // } else if authenticated { + // return <-- we end up here on the second call + // } + res, err = (&http.Client{Jar: j}).Get(result[0]) + require.NoError(t, err) + body, err = io.ReadAll(res.Body) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + + assertIdentity(t, res, body) + expectTokens(t, "valid", body) + assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body) + + // Trying this flow again without the Ory Session cookie will fail as we run into code reuse: + cookies := j.Cookies(urlx.ParseOrPanic(ts.URL)) + t.Logf("Cookies: %s", spew.Sdump(cookies)) + + secondJar, err := cookiejar.New(nil) + require.NoError(t, err) + + secondJar.SetCookies(urlx.ParseOrPanic(ts.URL), lo.Filter(cookies, func(item *http.Cookie, index int) bool { + return item.Name != "ory_kratos_session" + })) + + cookies = secondJar.Cookies(urlx.ParseOrPanic(ts.URL)) + t.Logf("Cookies after: %s", spew.Sdump(cookies)) + + // Doing the request but this time without the Ory Session Cookie. This may be the case in scenarios where we run into race conditions + // where the server sent a response but the client did not process it. + res, err = (&http.Client{Jar: secondJar}).Get(result[0]) + require.NoError(t, err) + body, err = io.ReadAll(res.Body) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + + // The reason for `invalid_client` here is that the code was already used and the session was already authenticated. The invalid_client + // happens because of the way Golang's OAuth2 library is trying out different auth methods when a token request fails, which obfuscates + // the underlying error. + assert.Contains(t, string(body), "invalid_client", "%s", body) + }) }) t.Run("case=login without registered account", func(t *testing.T) { diff --git a/x/cookie.go b/x/cookie.go index 4172d2878cf5..897401183c0e 100644 --- a/x/cookie.go +++ b/x/cookie.go @@ -5,6 +5,7 @@ package x import ( "net/http" + "time" "github.com/gorilla/sessions" "github.com/pkg/errors" @@ -71,6 +72,17 @@ func SessionUnset(w http.ResponseWriter, r *http.Request, s sessions.StoreExact, return errors.WithStack(cookie.Save(r, w)) } +func SessionSetExpiresIn(w http.ResponseWriter, r *http.Request, s sessions.StoreExact, id string, expiresIn time.Duration) error { + cookie, err := s.Get(r, id) + if err == nil && cookie.IsNew { + // No cookie was sent in the request. We have nothing to do. + return nil + } + + cookie.Options.MaxAge = int(expiresIn.Seconds()) + return errors.WithStack(cookie.Save(r, w)) +} + func SessionUnsetKey(w http.ResponseWriter, r *http.Request, s sessions.StoreExact, id, key string) error { cookie, err := s.Get(r, id) if err == nil && cookie.IsNew {