Skip to content

Commit

Permalink
fix: webhook transient payload in OIDC login flows (ory#3857)
Browse files Browse the repository at this point in the history
* fix: transient payload with OIDC login
  • Loading branch information
hperl authored and mpauly-exnaton committed Feb 12, 2025
1 parent 421ee23 commit bbea184
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 28 deletions.
1 change: 1 addition & 0 deletions internal/client-go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
74 changes: 74 additions & 0 deletions selfservice/hook/hooktest/web_hook_test_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package hooktest

import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"slices"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"

"github.com/ory/kratos/driver/config"
"github.com/ory/x/configx"
"github.com/ory/x/ioutilx"
)

var jsonnet = base64.StdEncoding.EncodeToString([]byte("function(ctx) ctx"))

type Server struct {
*httptest.Server
LastBody []byte
}

// NewServer returns a new webhook server for testing.
func NewServer() *Server {
s := new(Server)
httptestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.LastBody = ioutilx.MustReadAll(r.Body)
w.WriteHeader(http.StatusOK)
}))

s.Server = httptestServer

return s
}

// HookConfig returns the hook configuration for calling the webhook server.
func (s *Server) HookConfig() config.SelfServiceHook {
return config.SelfServiceHook{
Name: "web_hook",
Config: []byte(fmt.Sprintf(`
{
"method": "POST",
"url": "%s",
"body": "base64://%s"
}`, s.URL, jsonnet)),
}
}

func (s *Server) AssertTransientPayload(t *testing.T, expected string) {
require.NotEmpty(t, s.LastBody)
actual := gjson.GetBytes(s.LastBody, "flow.transient_payload").String()
assert.JSONEq(t, expected, actual, "%+v", actual)
}

// SetConfig adds the webhook to the list of hooks for the given key and restores
// the original configuration after the test.
func (s *Server) SetConfig(t *testing.T, conf *configx.Provider, key string) {
var newValue []config.SelfServiceHook
original := conf.Get(key)
if originalHooks, ok := original.([]config.SelfServiceHook); ok {
newValue = slices.Clone(originalHooks)
}
require.NoError(t, conf.Set(key, append(newValue, s.HookConfig())))
t.Cleanup(func() {
_ = conf.Set(key, original)
})
}
1 change: 0 additions & 1 deletion selfservice/strategy/code/hook.jsonnet

This file was deleted.

27 changes: 6 additions & 21 deletions selfservice/strategy/code/strategy_recovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/ory/kratos/internal/testhelpers"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/recovery"
"github.com/ory/kratos/selfservice/hook/hooktest"
"github.com/ory/kratos/selfservice/strategy/code"
"github.com/ory/kratos/session"
"github.com/ory/kratos/text"
Expand Down Expand Up @@ -294,28 +295,12 @@ func TestRecovery(t *testing.T) {
})

t.Run("description=should pass transient data to email template and webhooks", func(t *testing.T) {
var webhookReceivedTransientPayload string
webhookTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
webhookReceivedTransientPayload = gjson.GetBytes(ioutilx.MustReadAll(r.Body), "flow.transient_payload").String()
w.WriteHeader(http.StatusOK)
}))
webhookTS := hooktest.NewServer()
t.Cleanup(webhookTS.Close)

conf.MustSet(
ctx,
"selfservice.flows.recovery.after.hooks",
[]config.SelfServiceHook{{Name: "web_hook", Config: []byte(
fmt.Sprintf(`{
"method":"POST",
"url": "%s",
"body":"file://./hook.jsonnet"
}`, webhookTS.URL),
)}},
)

t.Cleanup(func() {
conf.MustSet(ctx, "selfservice.flows.recovery.after.hooks", nil)
})
conf.MustSet(ctx, "selfservice.flows.recovery.after.hooks", []config.SelfServiceHook{webhookTS.HookConfig()})
t.Cleanup(func() { conf.MustSet(ctx, "selfservice.flows.recovery.after.hooks", nil) })

client := testhelpers.NewClientWithCookies(t)
email := testhelpers.RandomEmail()
createIdentityToRecover(t, reg, email)
Expand Down Expand Up @@ -347,7 +332,7 @@ func TestRecovery(t *testing.T) {
})))
require.NoError(t, err)

assert.JSONEq(t, webhookPayload, webhookReceivedTransientPayload,
assert.JSONEq(t, webhookPayload, gjson.GetBytes(webhookTS.LastBody, "flow.transient_payload").String(),
"should pass transient payload to webhook")
})

Expand Down
7 changes: 4 additions & 3 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,10 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
}
if err := s.d.ContinuityManager().Pause(ctx, w, r, sessionName,
continuity.WithPayload(&AuthCodeContainer{
State: state.String(),
FlowID: f.ID.String(),
Traits: p.Traits,
State: state.String(),
FlowID: f.ID.String(),
Traits: p.Traits,
TransientPayload: f.TransientPayload,
}),
continuity.WithLifespan(time.Minute*30)); err != nil {
return nil, s.handleError(w, r, f, pid, nil, err)
Expand Down
45 changes: 42 additions & 3 deletions selfservice/strategy/oidc/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"testing"
"time"

"github.com/ory/kratos/selfservice/hook/hooktest"
"github.com/ory/x/sqlxx"

"github.com/ory/kratos/hydra"
Expand Down Expand Up @@ -457,38 +458,76 @@ func TestStrategy(t *testing.T) {
}

t.Run("case=register and then login", func(t *testing.T) {
postRegistrationWebhook := hooktest.NewServer()
t.Cleanup(postRegistrationWebhook.Close)
postLoginWebhook := hooktest.NewServer()
t.Cleanup(postLoginWebhook.Close)

postRegistrationWebhook.SetConfig(t, conf.GetProvider(ctx),
config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter, identity.CredentialsTypeOIDC.String()))
postLoginWebhook.SetConfig(t, conf.GetProvider(ctx),
config.HookStrategyKey(config.ViperKeySelfServiceLoginAfter, config.HookGlobal))

subject = "[email protected]"
scope = []string{"openid", "offline"}

t.Run("case=should pass registration", func(t *testing.T) {
transientPayload := `{"data": "registration"}`
r := newBrowserRegistrationFlow(t, returnTS.URL, time.Minute)
action := assertFormValues(t, r.ID, "valid")
res, body := makeRequest(t, "valid", action, url.Values{})
res, body := makeRequest(t, "valid", action, url.Values{
"transient_payload": {transientPayload},
})
assertIdentity(t, res, body)
expectTokens(t, "valid", body)
assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body)

postRegistrationWebhook.AssertTransientPayload(t, transientPayload)
})

t.Run("case=should pass login", func(t *testing.T) {
transientPayload := `{"data": "login"}`
r := newBrowserLoginFlow(t, returnTS.URL, time.Minute)
action := assertFormValues(t, r.ID, "valid")
res, body := makeRequest(t, "valid", action, url.Values{})
res, body := makeRequest(t, "valid", action, url.Values{
"transient_payload": {transientPayload},
})
assertIdentity(t, res, body)
expectTokens(t, "valid", body)
assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body)

postLoginWebhook.AssertTransientPayload(t, transientPayload)
})
})

t.Run("case=login without registered account", func(t *testing.T) {
postRegistrationWebhook := hooktest.NewServer()
t.Cleanup(postRegistrationWebhook.Close)
postLoginWebhook := hooktest.NewServer()
t.Cleanup(postLoginWebhook.Close)

postRegistrationWebhook.SetConfig(t, conf.GetProvider(ctx),
config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter, identity.CredentialsTypeOIDC.String()))
postLoginWebhook.SetConfig(t, conf.GetProvider(ctx),
config.HookStrategyKey(config.ViperKeySelfServiceLoginAfter, config.HookGlobal))

subject = "[email protected]"
scope = []string{"openid"}

t.Run("case=should pass login", func(t *testing.T) {
transientPayload := `{"data": "login to registration"}`

r := newBrowserLoginFlow(t, returnTS.URL, time.Minute)
action := assertFormValues(t, r.ID, "valid")
res, body := makeRequest(t, "valid", action, url.Values{})
res, body := makeRequest(t, "valid", action, url.Values{
"transient_payload": {transientPayload},
})
assertIdentity(t, res, body)
assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body)

assert.Empty(t, postLoginWebhook.LastBody,
"post login hook should not have been called, because this was a registration flow")
postRegistrationWebhook.AssertTransientPayload(t, transientPayload)
})
})

Expand Down

0 comments on commit bbea184

Please sign in to comment.