From bbea18471b3577368acb711b003cdaf92db0915a Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Thu, 28 Mar 2024 14:03:52 +0100 Subject: [PATCH] fix: webhook transient payload in OIDC login flows (#3857) * fix: transient payload with OIDC login --- internal/client-go/go.sum | 1 + .../hook/hooktest/web_hook_test_server.go | 74 +++++++++++++++++++ selfservice/strategy/code/hook.jsonnet | 1 - .../strategy/code/strategy_recovery_test.go | 27 ++----- selfservice/strategy/oidc/strategy_login.go | 7 +- selfservice/strategy/oidc/strategy_test.go | 45 ++++++++++- 6 files changed, 127 insertions(+), 28 deletions(-) create mode 100644 selfservice/hook/hooktest/web_hook_test_server.go delete mode 100644 selfservice/strategy/code/hook.jsonnet diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/selfservice/hook/hooktest/web_hook_test_server.go b/selfservice/hook/hooktest/web_hook_test_server.go new file mode 100644 index 000000000000..6111b48540a7 --- /dev/null +++ b/selfservice/hook/hooktest/web_hook_test_server.go @@ -0,0 +1,74 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package hooktest + +import ( + "encoding/base64" + "fmt" + "net/http" + "net/http/httptest" + "slices" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + + "github.com/ory/kratos/driver/config" + "github.com/ory/x/configx" + "github.com/ory/x/ioutilx" +) + +var jsonnet = base64.StdEncoding.EncodeToString([]byte("function(ctx) ctx")) + +type Server struct { + *httptest.Server + LastBody []byte +} + +// NewServer returns a new webhook server for testing. +func NewServer() *Server { + s := new(Server) + httptestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.LastBody = ioutilx.MustReadAll(r.Body) + w.WriteHeader(http.StatusOK) + })) + + s.Server = httptestServer + + return s +} + +// HookConfig returns the hook configuration for calling the webhook server. +func (s *Server) HookConfig() config.SelfServiceHook { + return config.SelfServiceHook{ + Name: "web_hook", + Config: []byte(fmt.Sprintf(` +{ + "method": "POST", + "url": "%s", + "body": "base64://%s" +}`, s.URL, jsonnet)), + } +} + +func (s *Server) AssertTransientPayload(t *testing.T, expected string) { + require.NotEmpty(t, s.LastBody) + actual := gjson.GetBytes(s.LastBody, "flow.transient_payload").String() + assert.JSONEq(t, expected, actual, "%+v", actual) +} + +// SetConfig adds the webhook to the list of hooks for the given key and restores +// the original configuration after the test. +func (s *Server) SetConfig(t *testing.T, conf *configx.Provider, key string) { + var newValue []config.SelfServiceHook + original := conf.Get(key) + if originalHooks, ok := original.([]config.SelfServiceHook); ok { + newValue = slices.Clone(originalHooks) + } + require.NoError(t, conf.Set(key, append(newValue, s.HookConfig()))) + t.Cleanup(func() { + _ = conf.Set(key, original) + }) +} diff --git a/selfservice/strategy/code/hook.jsonnet b/selfservice/strategy/code/hook.jsonnet deleted file mode 100644 index 54223dda2f32..000000000000 --- a/selfservice/strategy/code/hook.jsonnet +++ /dev/null @@ -1 +0,0 @@ -function(ctx) ctx \ No newline at end of file diff --git a/selfservice/strategy/code/strategy_recovery_test.go b/selfservice/strategy/code/strategy_recovery_test.go index a8bea043f91c..98f3d9c7f21c 100644 --- a/selfservice/strategy/code/strategy_recovery_test.go +++ b/selfservice/strategy/code/strategy_recovery_test.go @@ -30,6 +30,7 @@ import ( "github.com/ory/kratos/internal/testhelpers" "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/recovery" + "github.com/ory/kratos/selfservice/hook/hooktest" "github.com/ory/kratos/selfservice/strategy/code" "github.com/ory/kratos/session" "github.com/ory/kratos/text" @@ -294,28 +295,12 @@ func TestRecovery(t *testing.T) { }) t.Run("description=should pass transient data to email template and webhooks", func(t *testing.T) { - var webhookReceivedTransientPayload string - webhookTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - webhookReceivedTransientPayload = gjson.GetBytes(ioutilx.MustReadAll(r.Body), "flow.transient_payload").String() - w.WriteHeader(http.StatusOK) - })) + webhookTS := hooktest.NewServer() t.Cleanup(webhookTS.Close) - conf.MustSet( - ctx, - "selfservice.flows.recovery.after.hooks", - []config.SelfServiceHook{{Name: "web_hook", Config: []byte( - fmt.Sprintf(`{ - "method":"POST", - "url": "%s", - "body":"file://./hook.jsonnet" -}`, webhookTS.URL), - )}}, - ) - - t.Cleanup(func() { - conf.MustSet(ctx, "selfservice.flows.recovery.after.hooks", nil) - }) + conf.MustSet(ctx, "selfservice.flows.recovery.after.hooks", []config.SelfServiceHook{webhookTS.HookConfig()}) + t.Cleanup(func() { conf.MustSet(ctx, "selfservice.flows.recovery.after.hooks", nil) }) + client := testhelpers.NewClientWithCookies(t) email := testhelpers.RandomEmail() createIdentityToRecover(t, reg, email) @@ -347,7 +332,7 @@ func TestRecovery(t *testing.T) { }))) require.NoError(t, err) - assert.JSONEq(t, webhookPayload, webhookReceivedTransientPayload, + assert.JSONEq(t, webhookPayload, gjson.GetBytes(webhookTS.LastBody, "flow.transient_payload").String(), "should pass transient payload to webhook") }) diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 09bb8ab27e6e..42b948ec7c11 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -258,9 +258,10 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } if err := s.d.ContinuityManager().Pause(ctx, w, r, sessionName, continuity.WithPayload(&AuthCodeContainer{ - State: state.String(), - FlowID: f.ID.String(), - Traits: p.Traits, + State: state.String(), + FlowID: f.ID.String(), + Traits: p.Traits, + TransientPayload: f.TransientPayload, }), continuity.WithLifespan(time.Minute*30)); err != nil { return nil, s.handleError(w, r, f, pid, nil, err) diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 951f061995e3..74ea4e0726d6 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -18,6 +18,7 @@ import ( "testing" "time" + "github.com/ory/kratos/selfservice/hook/hooktest" "github.com/ory/x/sqlxx" "github.com/ory/kratos/hydra" @@ -457,38 +458,76 @@ func TestStrategy(t *testing.T) { } t.Run("case=register and then login", func(t *testing.T) { + postRegistrationWebhook := hooktest.NewServer() + t.Cleanup(postRegistrationWebhook.Close) + postLoginWebhook := hooktest.NewServer() + t.Cleanup(postLoginWebhook.Close) + + postRegistrationWebhook.SetConfig(t, conf.GetProvider(ctx), + config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter, identity.CredentialsTypeOIDC.String())) + postLoginWebhook.SetConfig(t, conf.GetProvider(ctx), + config.HookStrategyKey(config.ViperKeySelfServiceLoginAfter, config.HookGlobal)) + subject = "register-then-login@ory.sh" scope = []string{"openid", "offline"} t.Run("case=should pass registration", func(t *testing.T) { + transientPayload := `{"data": "registration"}` r := newBrowserRegistrationFlow(t, returnTS.URL, time.Minute) action := assertFormValues(t, r.ID, "valid") - res, body := makeRequest(t, "valid", action, url.Values{}) + res, body := makeRequest(t, "valid", action, url.Values{ + "transient_payload": {transientPayload}, + }) assertIdentity(t, res, body) expectTokens(t, "valid", body) assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body) + + postRegistrationWebhook.AssertTransientPayload(t, transientPayload) }) t.Run("case=should pass login", func(t *testing.T) { + transientPayload := `{"data": "login"}` r := newBrowserLoginFlow(t, returnTS.URL, time.Minute) action := assertFormValues(t, r.ID, "valid") - res, body := makeRequest(t, "valid", action, url.Values{}) + res, body := makeRequest(t, "valid", action, url.Values{ + "transient_payload": {transientPayload}, + }) assertIdentity(t, res, body) expectTokens(t, "valid", body) assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body) + + postLoginWebhook.AssertTransientPayload(t, transientPayload) }) }) t.Run("case=login without registered account", func(t *testing.T) { + postRegistrationWebhook := hooktest.NewServer() + t.Cleanup(postRegistrationWebhook.Close) + postLoginWebhook := hooktest.NewServer() + t.Cleanup(postLoginWebhook.Close) + + postRegistrationWebhook.SetConfig(t, conf.GetProvider(ctx), + config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter, identity.CredentialsTypeOIDC.String())) + postLoginWebhook.SetConfig(t, conf.GetProvider(ctx), + config.HookStrategyKey(config.ViperKeySelfServiceLoginAfter, config.HookGlobal)) + subject = "login-without-register@ory.sh" scope = []string{"openid"} t.Run("case=should pass login", func(t *testing.T) { + transientPayload := `{"data": "login to registration"}` + r := newBrowserLoginFlow(t, returnTS.URL, time.Minute) action := assertFormValues(t, r.ID, "valid") - res, body := makeRequest(t, "valid", action, url.Values{}) + res, body := makeRequest(t, "valid", action, url.Values{ + "transient_payload": {transientPayload}, + }) assertIdentity(t, res, body) assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body) + + assert.Empty(t, postLoginWebhook.LastBody, + "post login hook should not have been called, because this was a registration flow") + postRegistrationWebhook.AssertTransientPayload(t, transientPayload) }) })