Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: account linking with 2FA #4188

Merged
merged 4 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions selfservice/flow/duplicate_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package flow
import (
"encoding/json"

"github.com/gofrs/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

Expand All @@ -20,7 +19,6 @@ type DuplicateCredentialsData struct {
CredentialsType identity.CredentialsType
CredentialsConfig sqlxx.JSONRawMessage
DuplicateIdentifier string
OrganizationID uuid.UUID
}

type InternalContexter interface {
Expand Down
7 changes: 2 additions & 5 deletions selfservice/flow/login/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ import (

"github.com/tidwall/gjson"

"github.com/ory/x/jsonx"
"github.com/ory/x/sqlxx"
"github.com/ory/x/uuidx"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/x/jsonx"
"github.com/ory/x/sqlxx"

"github.com/ory/kratos/internal"

Expand Down Expand Up @@ -225,7 +223,6 @@ func TestDuplicateCredentials(t *testing.T) {
CredentialsType: "foo",
CredentialsConfig: sqlxx.JSONRawMessage(`{"bar":"baz"}`),
DuplicateIdentifier: "bar",
OrganizationID: uuidx.NewV4(),
}

require.NoError(t, flow.SetDuplicateCredentials(f, dc))
Expand Down
20 changes: 19 additions & 1 deletion selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ func WithFlowReturnTo(returnTo string) FlowOption {
}
}

func WithOrganizationID(organizationID uuid.NullUUID) FlowOption {
return func(f *Flow) {
f.OrganizationID = organizationID
}
}

func WithRequestedAAL(aal identity.AuthenticatorAssuranceLevel) FlowOption {
return func(f *Flow) {
f.RequestedAAL = aal
}
}

func WithInternalContext(internalContext []byte) FlowOption {
return func(f *Flow) {
f.InternalContext = internalContext
Expand Down Expand Up @@ -217,7 +229,13 @@ preLoginHook:

if orgID.Valid {
f.OrganizationID = orgID
strategyFilters = []StrategyFilter{func(s Strategy) bool { return s.ID() == identity.CredentialsTypeOIDC }}
if f.RequestedAAL == identity.AuthenticatorAssuranceLevel1 {
// We only apply the filter on AAL1, because the OIDC strategy can only satsify
// AAL1.
strategyFilters = []StrategyFilter{func(s Strategy) bool {
return s.ID() == identity.CredentialsTypeOIDC
}}
}
}

for _, s := range h.d.LoginStrategies(r.Context(), strategyFilters...) {
Expand Down
55 changes: 49 additions & 6 deletions selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"time"

"github.com/gofrs/uuid"
Expand Down Expand Up @@ -55,6 +56,7 @@ type (
x.LoggingProvider
x.TracingProvider
sessiontokenexchange.PersistenceProvider
HandlerProvider

FlowPersistenceProvider
HooksProvider
Expand Down Expand Up @@ -273,8 +275,28 @@ func (e *HookExecutor) PostLoginHook(
// If we detect that whoami would require a higher AAL, we redirect!
if err := e.checkAAL(ctx, classified, f); err != nil {
if aalErr := new(session.ErrAALNotSatisfied); errors.As(err, &aalErr) {
span.SetAttributes(attribute.String("return_to", aalErr.RedirectTo), attribute.String("redirect_reason", "requires aal2"))
e.d.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(aalErr.RedirectTo))
if data, _ := flow.DuplicateCredentials(f); data == nil {
span.SetAttributes(attribute.String("return_to", aalErr.RedirectTo), attribute.String("redirect_reason", "requires aal2"))
e.d.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(aalErr.RedirectTo))
return nil
}

// Special case: If we are in a flow that wants to link credentials, we create a
// new login flow here that asks for the require AAL, but also copies over the
// internal context and the organization ID.
r.URL, err = url.Parse(aalErr.RedirectTo)
if err != nil {
return errors.WithStack(err)
}
newFlow, _, err := e.d.LoginHandler().NewLoginFlow(w, r, flow.TypeBrowser,
WithInternalContext(f.InternalContext),
WithOrganizationID(f.OrganizationID),
)
if err != nil {
return errors.WithStack(err)
}

x.AcceptToRedirectOrJSON(w, r, e.d.Writer(), newFlow, newFlow.AppendTo(e.d.Config().SelfServiceFlowLoginUI(ctx)).String())
return nil
}
return err
Expand Down Expand Up @@ -309,7 +331,27 @@ func (e *HookExecutor) PostLoginHook(
// If we detect that whoami would require a higher AAL, we redirect!
if err := e.checkAAL(ctx, classified, f); err != nil {
if aalErr := new(session.ErrAALNotSatisfied); errors.As(err, &aalErr) {
http.Redirect(w, r, aalErr.RedirectTo, http.StatusSeeOther)
if data, _ := flow.DuplicateCredentials(f); data == nil {
http.Redirect(w, r, aalErr.RedirectTo, http.StatusSeeOther)
return nil
}

// Special case: If we are in a flow that wants to link credentials, we create a
// new login flow here that asks for the require AAL, but also copies over the
// internal context and the organization ID.
r.URL, err = url.Parse(aalErr.RedirectTo)
if err != nil {
return errors.WithStack(err)
}
newFlow, _, err := e.d.LoginHandler().NewLoginFlow(w, r, flow.TypeBrowser,
WithInternalContext(f.InternalContext),
WithOrganizationID(f.OrganizationID),
)
if err != nil {
return errors.WithStack(err)
}

x.AcceptToRedirectOrJSON(w, r, e.d.Writer(), newFlow, newFlow.AppendTo(e.d.Config().SelfServiceFlowLoginUI(ctx)).String())
return nil
}
return errors.WithStack(err)
Expand Down Expand Up @@ -362,7 +404,7 @@ func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.S
return nil
}

if err := e.checkDuplicateCredentialsIdentifierMatch(ctx, ident.ID, lc.DuplicateIdentifier); err != nil {
if err = e.checkDuplicateCredentialsIdentifierMatch(ctx, ident.ID, lc.DuplicateIdentifier); err != nil {
return err
}
strategy, err := e.d.AllLoginStrategies().Strategy(lc.CredentialsType)
Expand All @@ -380,8 +422,9 @@ func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.S
return err
}

method := strategy.CompletedAuthenticationMethod(ctx)
sess.CompletedLoginForMethod(method)
if err = linkableStrategy.CompletedLogin(sess, lc); err != nil {
return err
}

return nil
}
Expand Down
48 changes: 39 additions & 9 deletions selfservice/flow/login/hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"
"time"

"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -27,6 +28,7 @@ import (
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/session"
"github.com/ory/kratos/ui/node"
"github.com/ory/kratos/x"
)

Expand All @@ -42,6 +44,7 @@ func TestLoginExecutor(t *testing.T) {
reg.WithHydra(hydra.NewFake())
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/login.schema.json")
conf.MustSet(ctx, config.ViperKeySelfServiceBrowserDefaultReturnTo, "https://www.ory.sh/")
_ = testhelpers.NewLoginUIFlowEchoServer(t, reg)

newServer := func(t *testing.T, ft flow.Type, useIdentity *identity.Identity, flowCallback ...func(*login.Flow)) *httptest.Server {
router := httprouter.New()
Expand Down Expand Up @@ -222,7 +225,6 @@ func TestLoginExecutor(t *testing.T) {

t.Run("case=work normally if AAL is satisfied", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySessionWhoAmIAAL, "aal1")
_ = testhelpers.NewLoginUIFlowEchoServer(t, reg)
t.Cleanup(testhelpers.SelfServiceHookConfigReset(t, conf))

useIdentity := &identity.Identity{Credentials: map[identity.CredentialsType]identity.Credentials{
Expand Down Expand Up @@ -255,7 +257,6 @@ func TestLoginExecutor(t *testing.T) {

t.Run("case=redirect to login if AAL is too low", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySessionWhoAmIAAL, "highest_available")
_ = testhelpers.NewLoginUIFlowEchoServer(t, reg)
t.Cleanup(func() {
conf.MustSet(ctx, config.ViperKeySessionWhoAmIAAL, "aal1")
})
Expand Down Expand Up @@ -320,6 +321,7 @@ func TestLoginExecutor(t *testing.T) {
t.Run("case=maybe links credential", func(t *testing.T) {
t.Cleanup(testhelpers.SelfServiceHookConfigReset(t, conf))
conf.MustSet(ctx, config.ViperKeySessionWhoAmIAAL, config.HighestAvailableAAL)
conf.MustSet(ctx, "selfservice.methods.totp.enabled", true)

email1, email2 := testhelpers.RandomEmail(), testhelpers.RandomEmail()
passwordOnlyIdentity := &identity.Identity{Credentials: map[identity.CredentialsType]identity.Credentials{
Expand Down Expand Up @@ -360,15 +362,43 @@ func TestLoginExecutor(t *testing.T) {
require.NoError(t, err)

t.Run("sub-case=does not link after first factor when second factor is available", func(t *testing.T) {
duplicateCredentialsData := flow.DuplicateCredentialsData{
CredentialsType: identity.CredentialsTypeOIDC,
CredentialsConfig: credsOIDC2FA.Config,
DuplicateIdentifier: email2,
}
ts := newServer(t, flow.TypeBrowser, twoFAIdentitiy, func(l *login.Flow) {
require.NoError(t, flow.SetDuplicateCredentials(l, flow.DuplicateCredentialsData{
CredentialsType: identity.CredentialsTypeOIDC,
CredentialsConfig: credsOIDC2FA.Config,
DuplicateIdentifier: email2,
}))
require.NoError(t, flow.SetDuplicateCredentials(l, duplicateCredentialsData))
})
res, body := makeRequestPost(t, ts, false, url.Values{})
assert.Equal(t, res.Request.URL.String(), ts.URL+login.RouteInitBrowserFlow+"?aal=aal2", "%s", body)
res, _ := makeRequestPost(t, ts, false, url.Values{})

assert.Equal(t, reg.Config().SelfServiceFlowLoginUI(ctx).Host, res.Request.URL.Host)
assert.Equal(t, reg.Config().SelfServiceFlowLoginUI(ctx).Path, res.Request.URL.Path)
newFlowID := res.Request.URL.Query().Get("flow")
assert.NotEmpty(t, newFlowID)

newFlow, err := reg.LoginFlowPersister().GetLoginFlow(ctx, uuid.Must(uuid.FromString(newFlowID)))
require.NoError(t, err)
newFlowDuplicateCredentialsData, err := flow.DuplicateCredentials(newFlow)
require.NoError(t, err)

// Duplicate credentials data should have been copied over
assert.Equal(t, duplicateCredentialsData.CredentialsType, newFlowDuplicateCredentialsData.CredentialsType)
assert.Equal(t, duplicateCredentialsData.DuplicateIdentifier, newFlowDuplicateCredentialsData.DuplicateIdentifier)
assert.JSONEq(t, string(duplicateCredentialsData.CredentialsConfig), string(newFlowDuplicateCredentialsData.CredentialsConfig))

// AAL should be AAL2
assert.Equal(t, identity.AuthenticatorAssuranceLevel2, newFlow.RequestedAAL)

// TOTP nodes should be present
found := false
for _, n := range newFlow.UI.Nodes {
if n.Group == node.TOTPGroup {
found = true
break
}
}
assert.True(t, found, "could not find TOTP nodes in %+v", newFlow.UI.Nodes)

ident, err := reg.Persister().GetIdentity(ctx, twoFAIdentitiy.ID, identity.ExpandCredentials)
require.NoError(t, err)
Expand Down
3 changes: 3 additions & 0 deletions selfservice/flow/login/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/pkg/errors"

"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/session"
"github.com/ory/kratos/ui/node"
"github.com/ory/kratos/x"
Expand All @@ -28,6 +29,8 @@ type Strategies []Strategy

type LinkableStrategy interface {
Link(ctx context.Context, i *identity.Identity, credentials sqlxx.JSONRawMessage) error
CompletedLogin(sess *session.Session, data *flow.DuplicateCredentialsData) error
SetDuplicateCredentials(f flow.InternalContexter, duplicateIdentifier string, credentials identity.Credentials, provider string) error
}

func (s Strategies) Strategy(id identity.CredentialsType) (Strategy, error) {
Expand Down
17 changes: 7 additions & 10 deletions selfservice/flow/registration/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,21 +161,18 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque
return err
}

if _, ok := strategy.(login.LinkableStrategy); ok {
if strategy, ok := strategy.(login.LinkableStrategy); ok {
duplicateIdentifier, err := e.getDuplicateIdentifier(ctx, i)
if err != nil {
return err
}
registrationDuplicateCredentials := flow.DuplicateCredentialsData{
CredentialsType: ct,
CredentialsConfig: i.Credentials[ct].Config,
DuplicateIdentifier: duplicateIdentifier,
}
if registrationFlow.OrganizationID.Valid {
registrationDuplicateCredentials.OrganizationID = registrationFlow.OrganizationID.UUID
}

if err := flow.SetDuplicateCredentials(registrationFlow, registrationDuplicateCredentials); err != nil {
if err := strategy.SetDuplicateCredentials(
registrationFlow,
duplicateIdentifier,
i.Credentials[ct],
provider,
); err != nil {
return err
}
}
Expand Down
23 changes: 14 additions & 9 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,16 +169,21 @@ func (s *Strategy) processLogin(ctx context.Context, w http.ResponseWriter, r *h

sess := session.NewInactiveSession()
sess.CompletedLoginForWithProvider(s.ID(), identity.AuthenticatorAssuranceLevel1, provider.Config().ID, provider.Config().OrganizationID)
for _, c := range oidcCredentials.Providers {
if c.Subject == claims.Subject && c.Provider == provider.Config().ID {
if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID); err != nil {
return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, err)
}
return nil, nil
}
}

return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to find matching OpenID Connect Credentials.").WithDebugf(`Unable to find credentials that match the given provider "%s" and subject "%s".`, provider.Config().ID, claims.Subject)))
hperl marked this conversation as resolved.
Show resolved Hide resolved
if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID); err != nil {
return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, err)
}
return nil, nil

//for _, c := range oidcCredVentials.Providers {
// if c.Subject == claims.Subject && c.Provider == provider.Config().ID {
// if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID); err != nil {
// return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, err)
// }
// return nil, nil
// }
//}
//return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to find matching OpenID Connect Credentials.").WithDebugf(`Unable to find credentials that match the given provider "%s" and subject "%s".`, provider.Config().ID, claims.Subject)))
}

func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, _ *session.Session) (i *identity.Identity, err error) {
Expand Down
Loading
Loading