diff --git a/.docker-home/.gitignore b/.docker-home/.gitignore
deleted file mode 100644
index c96a04f008e..00000000000
--- a/.docker-home/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-*
-!.gitignore
\ No newline at end of file
diff --git a/.docker/Dockerfile-build b/.docker/Dockerfile-build
index a901a2bf83d..b7e64d5b206 100644
--- a/.docker/Dockerfile-build
+++ b/.docker/Dockerfile-build
@@ -1,4 +1,4 @@
-FROM golang:1.19-alpine3.18 AS builder
+FROM golang:1.20-alpine3.17 AS builder
RUN apk -U --no-cache --upgrade --latest add build-base git gcc bash
diff --git a/.docker/Dockerfile-hsm b/.docker/Dockerfile-hsm
index 2aa7e4bbad9..eac885dc797 100644
--- a/.docker/Dockerfile-hsm
+++ b/.docker/Dockerfile-hsm
@@ -1,4 +1,4 @@
-FROM golang:1.19-alpine3.18 AS builder
+FROM golang:1.20-alpine3.18 AS builder
RUN apk -U --no-cache --upgrade --latest add build-base git gcc bash
diff --git a/.dockerignore b/.dockerignore
index 4d913fbbc91..cf7558fc017 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -3,7 +3,6 @@
docs
node_modules
.circleci
-.docker-home
.github
scripts
sdk/js
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 54686c72e9e..853b6de937d 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -23,9 +23,9 @@ jobs:
# We must fetch at least the immediate parents so that if this is
# a pull request then we can checkout the head.
fetch-depth: 2
- - uses: actions/setup-go@v2
+ - uses: actions/setup-go@v3
with:
- go-version: "1.19"
+ go-version: "1.20"
- name: Start service
run: ./test/conformance/start.sh
- name: Run tests
@@ -80,22 +80,21 @@ jobs:
path: |
internal/httpclient
key: ${{ needs.sdk-generate.outputs.sdk-cache-key }}
- - uses: actions/setup-go@v2
+ - uses: actions/setup-go@v4
with:
- go-version: "1.19"
+ go-version: "1.20"
- run: go list -json > go.list
- name: Run nancy
uses: sonatype-nexus-community/nancy-github-action@v1.0.2
with:
nancyVersion: v1.0.42
- name: Run golangci-lint
- uses: golangci/golangci-lint-action@v2
+ uses: golangci/golangci-lint-action@v3
env:
GOGC: 100
with:
args: --timeout 10m0s
- version: v1.47.3
- skip-go-installation: true
+ version: v1.53.2
skip-pkg-cache: true
- name: Run go-acc (tests)
run: |
@@ -124,9 +123,9 @@ jobs:
path: |
internal/httpclient
key: ${{ needs.sdk-generate.outputs.sdk-cache-key }}
- - uses: actions/setup-go@v2
+ - uses: actions/setup-go@v3
with:
- go-version: "1.19"
+ go-version: "1.20"
- name: Setup HSM libs and packages
run: |
sudo apt install -y softhsm opensc
@@ -175,9 +174,9 @@ jobs:
docker start cockroach
name: Start CockroachDB
- uses: ory/ci/checkout@master
- - uses: actions/setup-go@v2
+ - uses: actions/setup-go@v3
with:
- go-version: "1.19"
+ go-version: "1.20"
- uses: actions/cache@v2
with:
path: ./test/e2e/hydra
diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index a7a720ebc0a..80515a61723 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -11,7 +11,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
with:
- go-version: 1.19
+ go-version: "1.20"
- run: make format
- name: Indicate formatting issues
run: git diff HEAD --exit-code --color
diff --git a/.github/workflows/licenses.yml b/.github/workflows/licenses.yml
index a4592c63ced..6f219dbced1 100644
--- a/.github/workflows/licenses.yml
+++ b/.github/workflows/licenses.yml
@@ -11,10 +11,10 @@ jobs:
check:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v2
- - uses: actions/setup-go@v2
+ - uses: actions/checkout@v3
+ - uses: actions/setup-go@v3
with:
- go-version: "1.18"
+ go-version: "1.20"
- uses: actions/setup-node@v2
with:
node-version: "18"
diff --git a/.golangci.yml b/.golangci.yml
index c3461c51f45..00ee1f9963c 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -5,9 +5,7 @@ linters:
- gosimple
- bodyclose
- staticcheck
- # Disabled due to Go 1.19 changes and Go-Swagger incompatibility
- # https://github.com/ory/hydra/issues/3227
- # - goimports
+ - goimports
disable:
- ineffassign
- deadcode
diff --git a/.grype.yml b/.grype.yml
new file mode 100644
index 00000000000..56d262246ac
--- /dev/null
+++ b/.grype.yml
@@ -0,0 +1,2 @@
+ignore:
+ - vulnerability: CVE-2023-2650
diff --git a/.trivyignore b/.trivyignore
new file mode 100644
index 00000000000..73859219e24
--- /dev/null
+++ b/.trivyignore
@@ -0,0 +1 @@
+CVE-2023-2650
diff --git a/Makefile b/Makefile
index fc0ab25510e..72978178bc6 100644
--- a/Makefile
+++ b/Makefile
@@ -5,10 +5,11 @@ export PATH := .bin:${PATH}
export PWD := $(shell pwd)
export IMAGE_TAG := $(if $(IMAGE_TAG),$(IMAGE_TAG),latest-sqlite)
-GOLANGCI_LINT_VERSION = 1.46.2
+GOLANGCI_LINT_VERSION = 1.53.2
GO_DEPENDENCIES = github.com/ory/go-acc \
github.com/golang/mock/mockgen \
+ golang.org/x/tools/cmd/goimports \
github.com/go-swagger/go-swagger/cmd/swagger
define make-go-dependency
@@ -37,9 +38,6 @@ node_modules: package-lock.json
docs/cli: .bin/clidoc
clidoc .
-.bin/goimports: go.sum Makefile
- GOBIN=$(shell pwd)/.bin go install golang.org/x/tools/cmd/goimports@latest
-
.bin/licenses: Makefile
curl https://raw.githubusercontent.com/ory/ci/master/licenses/install | sh
@@ -63,12 +61,9 @@ test: .bin/go-acc
# Resets the test databases
.PHONY: test-resetdb
test-resetdb: node_modules
- docker kill hydra_test_database_mysql || true
- docker kill hydra_test_database_postgres || true
- docker kill hydra_test_database_cockroach || true
- docker rm -f hydra_test_database_mysql || true
- docker rm -f hydra_test_database_postgres || true
- docker rm -f hydra_test_database_cockroach || true
+ docker rm --force --volumes hydra_test_database_mysql || true
+ docker rm --force --volumes hydra_test_database_postgres || true
+ docker rm --force --volumes hydra_test_database_cockroach || true
docker run --rm --name hydra_test_database_mysql --platform linux/amd64 -p 3444:3306 -e MYSQL_ROOT_PASSWORD=secret -d mysql:8.0.26
docker run --rm --name hydra_test_database_postgres --platform linux/amd64 -p 3445:5432 -e POSTGRES_PASSWORD=secret -e POSTGRES_DB=postgres -d postgres:11.8
docker run --rm --name hydra_test_database_cockroach --platform linux/amd64 -p 3446:26257 -d cockroachdb/cockroach:v22.1.10 start-single-node --insecure
@@ -122,6 +117,7 @@ sdk: .bin/swagger .bin/ory node_modules
swagger generate spec -m -o spec/swagger.json \
-c github.com/ory/hydra/v2/client \
-c github.com/ory/hydra/v2/consent \
+ -c github.com/ory/hydra/v2/flow \
-c github.com/ory/hydra/v2/health \
-c github.com/ory/hydra/v2/jwk \
-c github.com/ory/hydra/v2/oauth2 \
diff --git a/aead/aead.go b/aead/aead.go
new file mode 100644
index 00000000000..a3cb8b89ffe
--- /dev/null
+++ b/aead/aead.go
@@ -0,0 +1,28 @@
+// Copyright © 2023 Ory Corp
+// SPDX-License-Identifier: Apache-2.0
+
+package aead
+
+import (
+ "context"
+
+ "github.com/ory/fosite"
+)
+
+// Cipher provides AEAD (authenticated encryption with associated data). The
+// ciphertext is returned base64url-encoded.
+type Cipher interface {
+ // Encrypt encrypts and encodes the given plaintext, optionally using
+ // additiona data.
+ Encrypt(ctx context.Context, plaintext, additionalData []byte) (ciphertext string, err error)
+
+ // Decrypt decodes, decrypts, and verifies the plaintext and additional data
+ // from the ciphertext. The ciphertext must be given in the form as returned
+ // by Encrypt.
+ Decrypt(ctx context.Context, ciphertext string, additionalData []byte) (plaintext []byte, err error)
+}
+
+type Dependencies interface {
+ fosite.GlobalSecretProvider
+ fosite.RotatedGlobalSecretsProvider
+}
diff --git a/aead/aead_test.go b/aead/aead_test.go
new file mode 100644
index 00000000000..4cb93f5c3e7
--- /dev/null
+++ b/aead/aead_test.go
@@ -0,0 +1,154 @@
+// Copyright © 2022 Ory Corp
+// SPDX-License-Identifier: Apache-2.0
+
+package aead_test
+
+import (
+ "context"
+ "crypto/rand"
+ "fmt"
+ "io"
+ "testing"
+
+ "github.com/ory/hydra/v2/aead"
+ "github.com/ory/hydra/v2/driver/config"
+ "github.com/ory/hydra/v2/internal"
+
+ "github.com/pborman/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func secret(t *testing.T) string {
+ bytes := make([]byte, 32)
+ _, err := io.ReadFull(rand.Reader, bytes)
+ require.NoError(t, err)
+ return fmt.Sprintf("%X", bytes)
+}
+
+func TestAEAD(t *testing.T) {
+ t.Parallel()
+ for _, tc := range []struct {
+ name string
+ new func(aead.Dependencies) aead.Cipher
+ }{
+ {"AES-GCM", func(d aead.Dependencies) aead.Cipher { return aead.NewAESGCM(d) }},
+ {"XChaChaPoly", func(d aead.Dependencies) aead.Cipher { return aead.NewXChaCha20Poly1305(d) }},
+ } {
+ tc := tc
+
+ t.Run("cipher="+tc.name, func(t *testing.T) {
+ NewCipher := tc.new
+
+ t.Run("case=without-rotation", func(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ c := internal.NewConfigurationWithDefaults()
+ c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)})
+ a := NewCipher(c)
+
+ plain := []byte(uuid.New())
+ ct, err := a.Encrypt(ctx, plain, nil)
+ assert.NoError(t, err)
+
+ ct2, err := a.Encrypt(ctx, plain, nil)
+ assert.NoError(t, err)
+ assert.NotEqual(t, ct, ct2, "ciphertexts for the same plaintext must be different each time")
+
+ res, err := a.Decrypt(ctx, ct, nil)
+ assert.NoError(t, err)
+ assert.Equal(t, plain, res)
+ })
+
+ t.Run("case=wrong-secret", func(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ c := internal.NewConfigurationWithDefaults()
+ c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)})
+ a := NewCipher(c)
+
+ ct, err := a.Encrypt(ctx, []byte(uuid.New()), nil)
+ require.NoError(t, err)
+
+ c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)})
+ _, err = a.Decrypt(ctx, ct, nil)
+ require.Error(t, err)
+ })
+
+ t.Run("case=with-rotation", func(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ c := internal.NewConfigurationWithDefaults()
+ old := secret(t)
+ c.MustSet(ctx, config.KeyGetSystemSecret, []string{old})
+ a := NewCipher(c)
+
+ plain := []byte(uuid.New())
+ ct, err := a.Encrypt(ctx, plain, nil)
+ require.NoError(t, err)
+
+ // Sets the old secret as a rotated secret and creates a new one.
+ c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t), old})
+ res, err := a.Decrypt(ctx, ct, nil)
+ require.NoError(t, err)
+ assert.Equal(t, plain, res)
+
+ // THis should also work when we re-encrypt the same plain text.
+ ct2, err := a.Encrypt(ctx, plain, nil)
+ require.NoError(t, err)
+ assert.NotEqual(t, ct2, ct)
+
+ res, err = a.Decrypt(ctx, ct, nil)
+ require.NoError(t, err)
+ assert.Equal(t, plain, res)
+ })
+
+ t.Run("case=with-rotation-wrong-secret", func(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ c := internal.NewConfigurationWithDefaults()
+ c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)})
+ a := NewCipher(c)
+
+ plain := []byte(uuid.New())
+ ct, err := a.Encrypt(ctx, plain, nil)
+ require.NoError(t, err)
+
+ // When the secrets do not match, an error should be thrown during decryption.
+ c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t), secret(t)})
+ _, err = a.Decrypt(ctx, ct, nil)
+ require.Error(t, err)
+ })
+
+ t.Run("suite=with additional data", func(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ c := internal.NewConfigurationWithDefaults()
+ c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)})
+ a := NewCipher(c)
+
+ plain := []byte(uuid.New())
+ ct, err := a.Encrypt(ctx, plain, []byte("additional data"))
+ assert.NoError(t, err)
+
+ t.Run("case=additional data matches", func(t *testing.T) {
+ res, err := a.Decrypt(ctx, ct, []byte("additional data"))
+ assert.NoError(t, err)
+ assert.Equal(t, plain, res)
+ })
+
+ t.Run("case=additional data does not match", func(t *testing.T) {
+ res, err := a.Decrypt(ctx, ct, []byte("wrong data"))
+ assert.Error(t, err)
+ assert.Nil(t, res)
+ })
+
+ t.Run("case=missing additional data", func(t *testing.T) {
+ res, err := a.Decrypt(ctx, ct, nil)
+ assert.Error(t, err)
+ assert.Nil(t, res)
+ })
+ })
+ })
+ }
+}
diff --git a/aead/aesgcm.go b/aead/aesgcm.go
new file mode 100644
index 00000000000..86ae12839ec
--- /dev/null
+++ b/aead/aesgcm.go
@@ -0,0 +1,126 @@
+// Copyright © 2022 Ory Corp
+// SPDX-License-Identifier: Apache-2.0
+
+package aead
+
+import (
+ "context"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "encoding/base64"
+ "io"
+
+ "github.com/pkg/errors"
+
+ "github.com/ory/x/errorsx"
+)
+
+type AESGCM struct {
+ c Dependencies
+}
+
+func NewAESGCM(c Dependencies) *AESGCM {
+ return &AESGCM{c: c}
+}
+
+func aeadKey(key []byte) *[32]byte {
+ var result [32]byte
+ copy(result[:], key[:32])
+ return &result
+}
+
+func (c *AESGCM) Encrypt(ctx context.Context, plaintext, additionalData []byte) (string, error) {
+ key, err := encryptionKey(ctx, c.c, 32)
+ if err != nil {
+ return "", err
+ }
+
+ ciphertext, err := aesGCMEncrypt(plaintext, aeadKey(key), additionalData)
+ if err != nil {
+ return "", errorsx.WithStack(err)
+ }
+
+ return base64.URLEncoding.EncodeToString(ciphertext), nil
+}
+
+func (c *AESGCM) Decrypt(ctx context.Context, ciphertext string, aad []byte) (plaintext []byte, err error) {
+ msg, err := base64.URLEncoding.DecodeString(ciphertext)
+ if err != nil {
+ return nil, errorsx.WithStack(err)
+ }
+
+ keys, err := allKeys(ctx, c.c)
+ if err != nil {
+ return nil, errorsx.WithStack(err)
+ }
+
+ for _, key := range keys {
+ if plaintext, err = c.decrypt(msg, key, aad); err == nil {
+ return plaintext, nil
+ }
+ }
+
+ return nil, err
+}
+
+func (c *AESGCM) decrypt(ciphertext []byte, key, additionalData []byte) ([]byte, error) {
+ if len(key) != 32 {
+ return nil, errors.Errorf("key must be exactly 32 long bytes, got %d bytes", len(key))
+ }
+
+ plaintext, err := aesGCMDecrypt(ciphertext, aeadKey(key), additionalData)
+ if err != nil {
+ return nil, errorsx.WithStack(err)
+ }
+
+ return plaintext, nil
+}
+
+// aesGCMEncrypt encrypts data using 256-bit AES-GCM. This both hides the content of
+// the data and provides a check that it hasn't been altered. Output takes the
+// form nonce|ciphertext|tag where '|' indicates concatenation.
+func aesGCMEncrypt(plaintext []byte, key *[32]byte, additionalData []byte) (ciphertext []byte, err error) {
+ block, err := aes.NewCipher(key[:])
+ if err != nil {
+ return nil, err
+ }
+
+ gcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return nil, err
+ }
+
+ nonce := make([]byte, gcm.NonceSize())
+ _, err = io.ReadFull(rand.Reader, nonce)
+ if err != nil {
+ return nil, err
+ }
+
+ return gcm.Seal(nonce, nonce, plaintext, additionalData), nil
+}
+
+// aesGCMDecrypt decrypts data using 256-bit AES-GCM. This both hides the content of
+// the data and provides a check that it hasn't been altered. Expects input
+// form nonce|ciphertext|tag where '|' indicates concatenation.
+func aesGCMDecrypt(ciphertext []byte, key *[32]byte, additionalData []byte) (plaintext []byte, err error) {
+ block, err := aes.NewCipher(key[:])
+ if err != nil {
+ return nil, err
+ }
+
+ gcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(ciphertext) < gcm.NonceSize() {
+ return nil, errors.New("malformed ciphertext")
+ }
+
+ return gcm.Open(nil,
+ ciphertext[:gcm.NonceSize()],
+ ciphertext[gcm.NonceSize():],
+ additionalData,
+ )
+}
diff --git a/aead/helpers.go b/aead/helpers.go
new file mode 100644
index 00000000000..7acd06c3a0d
--- /dev/null
+++ b/aead/helpers.go
@@ -0,0 +1,41 @@
+// Copyright © 2023 Ory Corp
+// SPDX-License-Identifier: Apache-2.0
+
+package aead
+
+import (
+ "context"
+ "fmt"
+)
+
+func encryptionKey(ctx context.Context, d Dependencies, keySize int) ([]byte, error) {
+ keys, err := allKeys(ctx, d)
+ if err != nil {
+ return nil, err
+ }
+
+ key := keys[0]
+ if len(key) != keySize {
+ return nil, fmt.Errorf("key must be exactly %d bytes long, got %d bytes", keySize, len(key))
+ }
+
+ return key, nil
+}
+
+func allKeys(ctx context.Context, d Dependencies) ([][]byte, error) {
+ global, err := d.GetGlobalSecret(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ rotated, err := d.GetRotatedGlobalSecrets(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ keys := append([][]byte{global}, rotated...)
+ if len(keys) == 0 {
+ return nil, fmt.Errorf("at least one encryption key must be defined but none were")
+ }
+ return keys, nil
+}
diff --git a/aead/xchacha20.go b/aead/xchacha20.go
new file mode 100644
index 00000000000..cb1d2fbf278
--- /dev/null
+++ b/aead/xchacha20.go
@@ -0,0 +1,80 @@
+// Copyright © 2023 Ory Corp
+// SPDX-License-Identifier: Apache-2.0
+
+package aead
+
+import (
+ "context"
+ "crypto/cipher"
+ cryptorand "crypto/rand"
+ "encoding/base64"
+ "fmt"
+
+ "golang.org/x/crypto/chacha20poly1305"
+
+ "github.com/ory/x/errorsx"
+)
+
+var _ Cipher = (*XChaCha20Poly1305)(nil)
+
+type (
+ XChaCha20Poly1305 struct {
+ d Dependencies
+ }
+)
+
+func NewXChaCha20Poly1305(d Dependencies) *XChaCha20Poly1305 {
+ return &XChaCha20Poly1305{d}
+}
+
+func (x *XChaCha20Poly1305) Encrypt(ctx context.Context, plaintext, additionalData []byte) (string, error) {
+ key, err := encryptionKey(ctx, x.d, chacha20poly1305.KeySize)
+ if err != nil {
+ return "", err
+ }
+
+ aead, err := chacha20poly1305.NewX(key)
+ if err != nil {
+ return "", errorsx.WithStack(err)
+ }
+
+ nonce := make([]byte, aead.NonceSize(), aead.NonceSize()+len(plaintext)+aead.Overhead())
+ _, err = cryptorand.Read(nonce)
+ if err != nil {
+ return "", errorsx.WithStack(err)
+ }
+
+ ciphertext := aead.Seal(nonce, nonce, plaintext, additionalData)
+ return base64.URLEncoding.EncodeToString(ciphertext), nil
+}
+
+func (x *XChaCha20Poly1305) Decrypt(ctx context.Context, ciphertext string, aad []byte) (plaintext []byte, err error) {
+ msg, err := base64.URLEncoding.DecodeString(ciphertext)
+ if err != nil {
+ return nil, errorsx.WithStack(err)
+ }
+
+ if len(msg) < chacha20poly1305.NonceSizeX {
+ return nil, errorsx.WithStack(fmt.Errorf("malformed ciphertext: too short"))
+ }
+ nonce, ciphered := msg[:chacha20poly1305.NonceSizeX], msg[chacha20poly1305.NonceSizeX:]
+
+ keys, err := allKeys(ctx, x.d)
+ if err != nil {
+ return nil, errorsx.WithStack(err)
+ }
+
+ var aead cipher.AEAD
+ for _, key := range keys {
+ aead, err = chacha20poly1305.NewX(key)
+ if err != nil {
+ continue
+ }
+ plaintext, err = aead.Open(nil, nonce, ciphered, aad)
+ if err == nil {
+ return plaintext, nil
+ }
+ }
+
+ return nil, errorsx.WithStack(err)
+}
diff --git a/client/client.go b/client/client.go
index 3f01b1099f0..57fdca8b46f 100644
--- a/client/client.go
+++ b/client/client.go
@@ -4,9 +4,12 @@
package client
import (
+ "strconv"
"strings"
"time"
+ "github.com/twmb/murmur3"
+
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/x/stringsx"
@@ -560,3 +563,13 @@ func AccessTokenStrategySource(client fosite.Client) config.AccessTokenStrategyS
}
return nil
}
+
+func (c *Client) CookieSuffix() string {
+ return CookieSuffix(c)
+}
+
+type IDer interface{ GetID() string }
+
+func CookieSuffix(client IDer) string {
+ return strconv.Itoa(int(murmur3.Sum32([]byte(client.GetID()))))
+}
diff --git a/client/handler.go b/client/handler.go
index eb1f7b7660c..756b47e9baa 100644
--- a/client/handler.go
+++ b/client/handler.go
@@ -67,6 +67,8 @@ func (h *Handler) SetRoutes(admin *httprouterx.RouterAdmin, public *httprouterx.
// OAuth 2.0 Client Creation Parameters
//
// swagger:parameters createOAuth2Client
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type createOAuth2Client struct {
// OAuth 2.0 Client Request Body
//
@@ -107,6 +109,8 @@ func (h *Handler) createOAuth2Client(w http.ResponseWriter, r *http.Request, _ h
// OpenID Connect Dynamic Client Registration Parameters
//
// swagger:parameters createOidcDynamicClient
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type createOidcDynamicClient struct {
// Dynamic Client Registration Request Body
//
@@ -214,6 +218,8 @@ func (h *Handler) CreateClient(r *http.Request, validator func(context.Context,
// Set OAuth 2.0 Client Parameters
//
// swagger:parameters setOAuth2Client
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type setOAuth2Client struct {
// OAuth 2.0 Client ID
//
@@ -290,6 +296,8 @@ func (h *Handler) updateClient(ctx context.Context, c *Client, validator func(co
// Set Dynamic Client Parameters
//
// swagger:parameters setOidcDynamicClient
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type setOidcDynamicClient struct {
// OAuth 2.0 Client ID
//
@@ -383,6 +391,8 @@ func (h *Handler) setOidcDynamicClient(w http.ResponseWriter, r *http.Request, p
// Patch OAuth 2.0 Client Parameters
//
// swagger:parameters patchOAuth2Client
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type patchOAuth2Client struct {
// The id of the OAuth 2.0 Client.
//
@@ -460,6 +470,8 @@ func (h *Handler) patchOAuth2Client(w http.ResponseWriter, r *http.Request, ps h
// Paginated OAuth2 Client List Response
//
// swagger:response listOAuth2Clients
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type listOAuth2ClientsResponse struct {
tokenpagination.ResponseHeaders
@@ -472,6 +484,8 @@ type listOAuth2ClientsResponse struct {
// Paginated OAuth2 Client List Parameters
//
// swagger:parameters listOAuth2Clients
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type listOAuth2ClientsParameters struct {
tokenpagination.RequestParameters
@@ -540,6 +554,8 @@ func (h *Handler) listOAuth2Clients(w http.ResponseWriter, r *http.Request, ps h
// Get OAuth2 Client Parameters
//
// swagger:parameters getOAuth2Client
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type adminGetOAuth2Client struct {
// The id of the OAuth 2.0 Client.
//
@@ -583,6 +599,8 @@ func (h *Handler) Get(w http.ResponseWriter, r *http.Request, ps httprouter.Para
// Get OpenID Connect Dynamic Client Parameters
//
// swagger:parameters getOidcDynamicClient
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type getOidcDynamicClient struct {
// The id of the OAuth 2.0 Client.
//
@@ -644,6 +662,8 @@ func (h *Handler) getOidcDynamicClient(w http.ResponseWriter, r *http.Request, p
// Delete OAuth2 Client Parameters
//
// swagger:parameters deleteOAuth2Client
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type deleteOAuth2Client struct {
// The id of the OAuth 2.0 Client.
//
@@ -687,6 +707,8 @@ func (h *Handler) deleteOAuth2Client(w http.ResponseWriter, r *http.Request, ps
// Set OAuth 2.0 Client Token Lifespans
//
// swagger:parameters setOAuth2ClientLifespans
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type setOAuth2ClientLifespans struct {
// OAuth 2.0 Client ID
//
@@ -738,6 +760,8 @@ func (h *Handler) setOAuth2ClientLifespans(w http.ResponseWriter, r *http.Reques
}
// swagger:parameters deleteOidcDynamicClient
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type dynamicClientRegistrationDeleteOAuth2Client struct {
// The id of the OAuth 2.0 Client.
//
diff --git a/client/manager.go b/client/manager.go
index ad8cca7df51..6b0d9c5de05 100644
--- a/client/manager.go
+++ b/client/manager.go
@@ -49,3 +49,7 @@ type Storage interface {
GetConcreteClient(ctx context.Context, id string) (*Client, error)
}
+
+type ManagerProvider interface {
+ ClientManager() Manager
+}
diff --git a/cmd/cli/handler_janitor.go b/cmd/cli/handler_janitor.go
index e5082bf4b57..69035148b6e 100644
--- a/cmd/cli/handler_janitor.go
+++ b/cmd/cli/handler_janitor.go
@@ -6,6 +6,7 @@ package cli
import (
"context"
"fmt"
+ "io"
"time"
"github.com/ory/x/servicelocatorx"
@@ -52,12 +53,13 @@ func NewJanitorHandler(slOpts []servicelocatorx.Option, dOpts []driver.OptionsMo
}
}
-func (_ *JanitorHandler) Args(cmd *cobra.Command, args []string) error {
+func (*JanitorHandler) Args(cmd *cobra.Command, args []string) error {
if len(args) == 0 &&
!flagx.MustGetBool(cmd, ReadFromEnv) &&
len(flagx.MustGetStringSlice(cmd, Config)) == 0 {
fmt.Printf("%s\n", cmd.UsageString())
+ //lint:ignore ST1005 formatted error string used in CLI output
return fmt.Errorf("%s\n%s\n%s\n",
"A DSN is required as a positional argument when not passing any of the following flags:",
"- Using the environment variable with flag -e, --read-from-env",
@@ -65,6 +67,7 @@ func (_ *JanitorHandler) Args(cmd *cobra.Command, args []string) error {
}
if !flagx.MustGetBool(cmd, OnlyTokens) && !flagx.MustGetBool(cmd, OnlyRequests) && !flagx.MustGetBool(cmd, OnlyGrants) {
+ //lint:ignore ST1005 formatted error string used in CLI output
return fmt.Errorf("%s\n%s\n", cmd.UsageString(),
"Janitor requires at least one of --tokens, --requests or --grants to be set")
}
@@ -72,10 +75,12 @@ func (_ *JanitorHandler) Args(cmd *cobra.Command, args []string) error {
limit := flagx.MustGetInt(cmd, Limit)
batchSize := flagx.MustGetInt(cmd, BatchSize)
if limit <= 0 || batchSize <= 0 {
+ //lint:ignore ST1005 formatted error string used in CLI output
return fmt.Errorf("%s\n%s\n", cmd.UsageString(),
"Values for --limit and --batch-size should both be greater than 0")
}
if batchSize > limit {
+ //lint:ignore ST1005 formatted error string used in CLI output
return fmt.Errorf("%s\n%s\n", cmd.UsageString(),
"Value for --batch-size must not be greater than value for --limit")
}
@@ -130,6 +135,7 @@ func purge(cmd *cobra.Command, args []string, sl *servicelocatorx.Options, dOpts
}
if len(d.Config().DSN()) == 0 {
+ //lint:ignore ST1005 formatted error string used in CLI output
return fmt.Errorf("%s\n%s\n%s\n", cmd.UsageString(),
"When using flag -e, environment variable DSN must be set.",
"When using flag -c, the dsn property should be set.")
@@ -154,20 +160,20 @@ func purge(cmd *cobra.Command, args []string, sl *servicelocatorx.Options, dOpts
routineFlags = append(routineFlags, OnlyGrants)
}
- return cleanupRun(cmd.Context(), notAfter, limit, batchSize, addRoutine(p, routineFlags...)...)
+ return cleanupRun(cmd.Context(), notAfter, limit, batchSize, addRoutine(cmd.OutOrStdout(), p, routineFlags...)...)
}
-func addRoutine(p persistence.Persister, names ...string) []cleanupRoutine {
+func addRoutine(out io.Writer, p persistence.Persister, names ...string) []cleanupRoutine {
var routines []cleanupRoutine
for _, n := range names {
switch n {
case OnlyTokens:
- routines = append(routines, cleanup(p.FlushInactiveAccessTokens, "access tokens"))
- routines = append(routines, cleanup(p.FlushInactiveRefreshTokens, "refresh tokens"))
+ routines = append(routines, cleanup(out, p.FlushInactiveAccessTokens, "access tokens"))
+ routines = append(routines, cleanup(out, p.FlushInactiveRefreshTokens, "refresh tokens"))
case OnlyRequests:
- routines = append(routines, cleanup(p.FlushInactiveLoginConsentRequests, "login-consent requests"))
+ routines = append(routines, cleanup(out, p.FlushInactiveLoginConsentRequests, "login-consent requests"))
case OnlyGrants:
- routines = append(routines, cleanup(p.FlushInactiveGrants, "grants"))
+ routines = append(routines, cleanup(out, p.FlushInactiveGrants, "grants"))
}
}
return routines
@@ -175,12 +181,12 @@ func addRoutine(p persistence.Persister, names ...string) []cleanupRoutine {
type cleanupRoutine func(ctx context.Context, notAfter time.Time, limit int, batchSize int) error
-func cleanup(cr cleanupRoutine, routineName string) cleanupRoutine {
+func cleanup(out io.Writer, cr cleanupRoutine, routineName string) cleanupRoutine {
return func(ctx context.Context, notAfter time.Time, limit int, batchSize int) error {
if err := cr(ctx, notAfter, limit, batchSize); err != nil {
return errors.Wrap(errorsx.WithStack(err), fmt.Sprintf("Could not cleanup inactive %s", routineName))
}
- fmt.Printf("Successfully completed Janitor run on %s\n", routineName)
+ fmt.Fprintf(out, "Successfully completed Janitor run on %s\n", routineName)
return nil
}
}
diff --git a/cmd/cli/handler_janitor_test.go b/cmd/cli/handler_janitor_test.go
index 9f73beea846..7806f7c5471 100644
--- a/cmd/cli/handler_janitor_test.go
+++ b/cmd/cli/handler_janitor_test.go
@@ -48,7 +48,7 @@ func TestJanitorHandler_PurgeTokenNotAfter(t *testing.T) {
fmt.Sprintf("--%s=%s", cli.AccessLifespan, jt.GetAccessTokenLifespan(ctx).String()),
fmt.Sprintf("--%s=%s", cli.RefreshLifespan, jt.GetRefreshTokenLifespan(ctx).String()),
fmt.Sprintf("--%s", cli.OnlyTokens),
- jt.GetDSN(ctx),
+ jt.GetDSN(),
)
})
@@ -80,13 +80,13 @@ func TestJanitorHandler_PurgeLoginConsentNotAfter(t *testing.T) {
fmt.Sprintf("--%s=%s", cli.KeepIfYounger, v.String()),
fmt.Sprintf("--%s=%s", cli.ConsentRequestLifespan, jt.GetConsentRequestLifespan(ctx).String()),
fmt.Sprintf("--%s", cli.OnlyRequests),
- jt.GetDSN(ctx),
+ jt.GetDSN(),
)
})
notAfter := time.Now().Round(time.Second).Add(-v)
consentLifespan := time.Now().Round(time.Second).Add(-jt.GetConsentRequestLifespan(ctx))
- t.Run("step=validate", jt.LoginConsentNotAfterValidate(ctx, notAfter, consentLifespan, reg.ConsentManager()))
+ t.Run("step=validate", jt.LoginConsentNotAfterValidate(ctx, notAfter, consentLifespan, reg))
})
}
@@ -107,14 +107,14 @@ func TestJanitorHandler_PurgeLoginConsent(t *testing.T) {
require.NoError(t, err)
// setup
- t.Run("step=setup", jt.LoginTimeoutSetup(ctx, reg.ConsentManager(), reg.ClientManager()))
+ t.Run("step=setup", jt.LoginTimeoutSetup(ctx, reg))
// cleanup
t.Run("step=cleanup", func(t *testing.T) {
cmdx.ExecNoErr(t, newJanitorCmd(),
"janitor",
fmt.Sprintf("--%s", cli.OnlyRequests),
- jt.GetDSN(ctx),
+ jt.GetDSN(),
)
})
@@ -129,14 +129,14 @@ func TestJanitorHandler_PurgeLoginConsent(t *testing.T) {
require.NoError(t, err)
// setup
- t.Run("step=setup", jt.ConsentTimeoutSetup(ctx, reg.ConsentManager(), reg.ClientManager()))
+ t.Run("step=setup", jt.ConsentTimeoutSetup(ctx, reg))
// run cleanup
t.Run("step=cleanup", func(t *testing.T) {
cmdx.ExecNoErr(t, newJanitorCmd(),
"janitor",
fmt.Sprintf("--%s", cli.OnlyRequests),
- jt.GetDSN(ctx),
+ jt.GetDSN(),
)
})
@@ -155,14 +155,14 @@ func TestJanitorHandler_PurgeLoginConsent(t *testing.T) {
require.NoError(t, err)
// setup
- t.Run("step=setup", jt.LoginRejectionSetup(ctx, reg.ConsentManager(), reg.ClientManager()))
+ t.Run("step=setup", jt.LoginRejectionSetup(ctx, reg))
// cleanup
t.Run("step=cleanup", func(t *testing.T) {
cmdx.ExecNoErr(t, newJanitorCmd(),
"janitor",
fmt.Sprintf("--%s", cli.OnlyRequests),
- jt.GetDSN(ctx),
+ jt.GetDSN(),
)
})
@@ -176,14 +176,14 @@ func TestJanitorHandler_PurgeLoginConsent(t *testing.T) {
require.NoError(t, err)
// setup
- t.Run("step=setup", jt.ConsentRejectionSetup(ctx, reg.ConsentManager(), reg.ClientManager()))
+ t.Run("step=setup", jt.ConsentRejectionSetup(ctx, reg))
// cleanup
t.Run("step=cleanup", func(t *testing.T) {
cmdx.ExecNoErr(t, newJanitorCmd(),
"janitor",
fmt.Sprintf("--%s", cli.OnlyRequests),
- jt.GetDSN(ctx),
+ jt.GetDSN(),
)
})
@@ -279,7 +279,7 @@ func TestJanitorHandler_PurgeGrantNotAfter(t *testing.T) {
require.NoError(t, err)
// setup test
- t.Run("step=setup", jt.GrantNotAfterSetup(ctx, reg.ClientManager(), reg.GrantManager()))
+ t.Run("step=setup", jt.GrantNotAfterSetup(ctx, reg.GrantManager()))
// run the cleanup routine
t.Run("step=cleanup", func(t *testing.T) {
@@ -287,7 +287,7 @@ func TestJanitorHandler_PurgeGrantNotAfter(t *testing.T) {
"janitor",
fmt.Sprintf("--%s=%s", cli.KeepIfYounger, v.String()),
fmt.Sprintf("--%s", cli.OnlyGrants),
- jt.GetDSN(ctx),
+ jt.GetDSN(),
)
})
diff --git a/cmd/cmd_list_clients.go b/cmd/cmd_list_clients.go
index 2c8f79356cf..ddaf2762018 100644
--- a/cmd/cmd_list_clients.go
+++ b/cmd/cmd_list_clients.go
@@ -35,11 +35,10 @@ func NewListClientsCmd() *cobra.Command {
if err != nil {
return cmdx.PrintOpenAPIError(cmd, err)
}
+ defer resp.Body.Close()
var collection outputOAuth2ClientCollection
- for k := range list {
- collection.clients = append(collection.clients, list[k])
- }
+ collection.clients = append(collection.clients, list...)
interfaceList := make([]interface{}, len(list))
for k := range collection.clients {
diff --git a/cmd/output_client.go b/cmd/output_client.go
index 1b052c56967..3f060f281df 100644
--- a/cmd/output_client.go
+++ b/cmd/output_client.go
@@ -19,7 +19,7 @@ type (
}
)
-func (_ outputOAuth2Client) Header() []string {
+func (outputOAuth2Client) Header() []string {
return []string{"CLIENT ID", "CLIENT SECRET", "GRANT TYPES", "RESPONSE TYPES", "SCOPE", "AUDIENCE", "REDIRECT URIS"}
}
@@ -40,7 +40,7 @@ func (i outputOAuth2Client) Interface() interface{} {
return i
}
-func (_ outputOAuth2ClientCollection) Header() []string {
+func (outputOAuth2ClientCollection) Header() []string {
return outputOAuth2Client{}.Header()
}
diff --git a/cmd/output_introspection.go b/cmd/output_introspection.go
index e3aa576421d..1f89f016530 100644
--- a/cmd/output_introspection.go
+++ b/cmd/output_introspection.go
@@ -16,7 +16,7 @@ type (
outputOAuth2TokenIntrospection hydra.IntrospectedOAuth2Token
)
-func (_ outputOAuth2TokenIntrospection) Header() []string {
+func (outputOAuth2TokenIntrospection) Header() []string {
return []string{"ACTIVE", "SUBJECT", "CLIENT ID", "SCOPE", "EXPIRY", "TOKEN USE"}
}
diff --git a/cmd/output_jwks.go b/cmd/output_jwks.go
index 3b42af3b113..207e33a9d1f 100644
--- a/cmd/output_jwks.go
+++ b/cmd/output_jwks.go
@@ -20,7 +20,7 @@ type (
}
)
-func (_ outputJsonWebKey) Header() []string {
+func (outputJsonWebKey) Header() []string {
return []string{"SET ID", "KEY ID", "ALGORITHM", "USE"}
}
@@ -38,7 +38,7 @@ func (i outputJsonWebKey) Interface() interface{} {
return i
}
-func (_ outputJSONWebKeyCollection) Header() []string {
+func (outputJSONWebKeyCollection) Header() []string {
return outputJsonWebKey{}.Header()
}
diff --git a/cmd/output_token.go b/cmd/output_token.go
index c91add12cb5..17da6bf274c 100644
--- a/cmd/output_token.go
+++ b/cmd/output_token.go
@@ -16,7 +16,7 @@ type (
outputOAuth2Token oauth2.Token
)
-func (_ outputOAuth2Token) Header() []string {
+func (outputOAuth2Token) Header() []string {
return []string{"ACCESS TOKEN", "REFRESH TOKEN", "ID TOKEN", "EXPIRY"}
}
diff --git a/consent/handler.go b/consent/handler.go
index 81dd69b7541..2c1cb178ff0 100644
--- a/consent/handler.go
+++ b/consent/handler.go
@@ -9,6 +9,8 @@ import (
"net/url"
"time"
+ "github.com/ory/hydra/v2/flow"
+ "github.com/ory/hydra/v2/oauth2/flowctx"
"github.com/ory/x/pagination/tokenpagination"
"github.com/ory/x/httprouterx"
@@ -68,6 +70,8 @@ func (h *Handler) SetRoutes(admin *httprouterx.RouterAdmin) {
// Revoke OAuth 2.0 Consent Session Parameters
//
// swagger:parameters revokeOAuth2ConsentSessions
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type revokeOAuth2ConsentSessions struct {
// OAuth 2.0 Consent Subject
//
@@ -110,7 +114,7 @@ type revokeOAuth2ConsentSessions struct {
// Responses:
// 204: emptyResponse
// default: errorOAuth2
-func (h *Handler) revokeOAuth2ConsentSessions(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *Handler) revokeOAuth2ConsentSessions(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
subject := r.URL.Query().Get("subject")
client := r.URL.Query().Get("client")
allClients := r.URL.Query().Get("all") == "true"
@@ -141,6 +145,8 @@ func (h *Handler) revokeOAuth2ConsentSessions(w http.ResponseWriter, r *http.Req
// List OAuth 2.0 Consent Session Parameters
//
// swagger:parameters listOAuth2ConsentSessions
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type listOAuth2ConsentSessions struct {
tokenpagination.RequestParameters
@@ -176,7 +182,7 @@ type listOAuth2ConsentSessions struct {
// Responses:
// 200: oAuth2ConsentSessions
// default: errorOAuth2
-func (h *Handler) listOAuth2ConsentSessions(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *Handler) listOAuth2ConsentSessions(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
subject := r.URL.Query().Get("subject")
if subject == "" {
h.r.Writer().WriteError(w, r, errorsx.WithStack(fosite.ErrInvalidRequest.WithHint(`Query parameter 'subject' is not defined but should have been.`)))
@@ -186,7 +192,7 @@ func (h *Handler) listOAuth2ConsentSessions(w http.ResponseWriter, r *http.Reque
page, itemsPerPage := x.ParsePagination(r)
- var s []AcceptOAuth2ConsentRequest
+ var s []flow.AcceptOAuth2ConsentRequest
var err error
if len(loginSessionId) == 0 {
s, err = h.r.ConsentManager().FindSubjectsGrantedConsentRequests(r.Context(), subject, itemsPerPage, itemsPerPage*page)
@@ -194,21 +200,21 @@ func (h *Handler) listOAuth2ConsentSessions(w http.ResponseWriter, r *http.Reque
s, err = h.r.ConsentManager().FindSubjectsSessionGrantedConsentRequests(r.Context(), subject, loginSessionId, itemsPerPage, itemsPerPage*page)
}
if errors.Is(err, ErrNoPreviousConsentFound) {
- h.r.Writer().Write(w, r, []OAuth2ConsentSession{})
+ h.r.Writer().Write(w, r, []flow.OAuth2ConsentSession{})
return
} else if err != nil {
h.r.Writer().WriteError(w, r, err)
return
}
- var a []OAuth2ConsentSession
+ var a []flow.OAuth2ConsentSession
for _, session := range s {
session.ConsentRequest.Client = sanitizeClient(session.ConsentRequest.Client)
- a = append(a, OAuth2ConsentSession(session))
+ a = append(a, flow.OAuth2ConsentSession(session))
}
if len(a) == 0 {
- a = []OAuth2ConsentSession{}
+ a = []flow.OAuth2ConsentSession{}
}
n, err := h.r.ConsentManager().CountSubjectsGrantedConsentRequests(r.Context(), subject)
@@ -224,6 +230,8 @@ func (h *Handler) listOAuth2ConsentSessions(w http.ResponseWriter, r *http.Reque
// Revoke OAuth 2.0 Consent Login Sessions Parameters
//
// swagger:parameters revokeOAuth2LoginSessions
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type revokeOAuth2LoginSessions struct {
// OAuth 2.0 Subject
//
@@ -264,7 +272,7 @@ type revokeOAuth2LoginSessions struct {
// Responses:
// 204: emptyResponse
// default: errorOAuth2
-func (h *Handler) revokeOAuth2LoginSessions(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *Handler) revokeOAuth2LoginSessions(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
sid := r.URL.Query().Get("sid")
subject := r.URL.Query().Get("subject")
@@ -294,6 +302,8 @@ func (h *Handler) revokeOAuth2LoginSessions(w http.ResponseWriter, r *http.Reque
// Get OAuth 2.0 Login Request
//
// swagger:parameters getOAuth2LoginRequest
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type getOAuth2LoginRequest struct {
// OAuth 2.0 Login Request Challenge
//
@@ -328,7 +338,7 @@ type getOAuth2LoginRequest struct {
// 200: oAuth2LoginRequest
// 410: oAuth2RedirectTo
// default: errorOAuth2
-func (h *Handler) getOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *Handler) getOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
challenge := stringsx.Coalesce(
r.URL.Query().Get("login_challenge"),
r.URL.Query().Get("challenge"),
@@ -345,7 +355,7 @@ func (h *Handler) getOAuth2LoginRequest(w http.ResponseWriter, r *http.Request,
return
}
if request.WasHandled {
- h.r.Writer().WriteCode(w, r, http.StatusGone, &OAuth2RedirectTo{
+ h.r.Writer().WriteCode(w, r, http.StatusGone, &flow.OAuth2RedirectTo{
RedirectTo: request.RequestURL,
})
return
@@ -358,6 +368,8 @@ func (h *Handler) getOAuth2LoginRequest(w http.ResponseWriter, r *http.Request,
// Accept OAuth 2.0 Login Request
//
// swagger:parameters acceptOAuth2LoginRequest
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type acceptOAuth2LoginRequest struct {
// OAuth 2.0 Login Request Challenge
//
@@ -366,7 +378,7 @@ type acceptOAuth2LoginRequest struct {
Challenge string `json:"login_challenge"`
// in: body
- Body HandledLoginRequest
+ Body flow.HandledLoginRequest
}
// swagger:route PUT /admin/oauth2/auth/requests/login/accept oAuth2 acceptOAuth2LoginRequest
@@ -396,7 +408,9 @@ type acceptOAuth2LoginRequest struct {
// Responses:
// 200: oAuth2RedirectTo
// default: errorOAuth2
-func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
+ ctx := r.Context()
+
challenge := stringsx.Coalesce(
r.URL.Query().Get("login_challenge"),
r.URL.Query().Get("challenge"),
@@ -406,7 +420,7 @@ func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques
return
}
- var p HandledLoginRequest
+ var p flow.HandledLoginRequest
d := json.NewDecoder(r.Body)
d.DisallowUnknownFields()
if err := d.Decode(&p); err != nil {
@@ -420,7 +434,7 @@ func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques
}
p.ID = challenge
- ar, err := h.r.ConsentManager().GetLoginRequest(r.Context(), challenge)
+ ar, err := h.r.ConsentManager().GetLoginRequest(ctx, challenge)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
@@ -440,7 +454,12 @@ func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques
}
p.RequestedAt = ar.RequestedAt
- request, err := h.r.ConsentManager().HandleLoginRequest(r.Context(), challenge, &p)
+ f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flowctx.AsLoginChallenge)
+ if err != nil {
+ h.r.Writer().WriteError(w, r, err)
+ return
+ }
+ request, err := h.r.ConsentManager().HandleLoginRequest(ctx, f, challenge, &p)
if err != nil {
h.r.Writer().WriteError(w, r, errorsx.WithStack(err))
return
@@ -452,14 +471,22 @@ func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques
return
}
- h.r.Writer().Write(w, r, &OAuth2RedirectTo{
- RedirectTo: urlx.SetQuery(ru, url.Values{"login_verifier": {request.Verifier}}).String(),
+ verifier, err := f.ToLoginVerifier(ctx, h.r)
+ if err != nil {
+ h.r.Writer().WriteError(w, r, err)
+ return
+ }
+
+ h.r.Writer().Write(w, r, &flow.OAuth2RedirectTo{
+ RedirectTo: urlx.SetQuery(ru, url.Values{"login_verifier": {verifier}}).String(),
})
}
// Reject OAuth 2.0 Login Request
//
// swagger:parameters rejectOAuth2LoginRequest
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type rejectOAuth2LoginRequest struct {
// OAuth 2.0 Login Request Challenge
//
@@ -468,7 +495,7 @@ type rejectOAuth2LoginRequest struct {
Challenge string `json:"login_challenge"`
// in: body
- Body RequestDeniedError
+ Body flow.RequestDeniedError
}
// swagger:route PUT /admin/oauth2/auth/requests/login/reject oAuth2 rejectOAuth2LoginRequest
@@ -497,7 +524,9 @@ type rejectOAuth2LoginRequest struct {
// Responses:
// 200: oAuth2RedirectTo
// default: errorOAuth2
-func (h *Handler) rejectOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *Handler) rejectOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
+ ctx := r.Context()
+
challenge := stringsx.Coalesce(
r.URL.Query().Get("login_challenge"),
r.URL.Query().Get("challenge"),
@@ -507,7 +536,7 @@ func (h *Handler) rejectOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques
return
}
- var p RequestDeniedError
+ var p flow.RequestDeniedError
d := json.NewDecoder(r.Body)
d.DisallowUnknownFields()
if err := d.Decode(&p); err != nil {
@@ -515,15 +544,20 @@ func (h *Handler) rejectOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques
return
}
- p.valid = true
- p.SetDefaults(loginRequestDeniedErrorName)
- ar, err := h.r.ConsentManager().GetLoginRequest(r.Context(), challenge)
+ p.Valid = true
+ p.SetDefaults(flow.LoginRequestDeniedErrorName)
+ ar, err := h.r.ConsentManager().GetLoginRequest(ctx, challenge)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
}
- request, err := h.r.ConsentManager().HandleLoginRequest(r.Context(), challenge, &HandledLoginRequest{
+ f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flowctx.AsLoginChallenge)
+ if err != nil {
+ h.r.Writer().WriteError(w, r, err)
+ return
+ }
+ request, err := h.r.ConsentManager().HandleLoginRequest(ctx, f, challenge, &flow.HandledLoginRequest{
Error: &p,
ID: challenge,
RequestedAt: ar.RequestedAt,
@@ -533,20 +567,28 @@ func (h *Handler) rejectOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques
return
}
+ verifier, err := f.ToLoginVerifier(ctx, h.r)
+ if err != nil {
+ h.r.Writer().WriteError(w, r, err)
+ return
+ }
+
ru, err := url.Parse(request.RequestURL)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
}
- h.r.Writer().Write(w, r, &OAuth2RedirectTo{
- RedirectTo: urlx.SetQuery(ru, url.Values{"login_verifier": {request.Verifier}}).String(),
+ h.r.Writer().Write(w, r, &flow.OAuth2RedirectTo{
+ RedirectTo: urlx.SetQuery(ru, url.Values{"login_verifier": {verifier}}).String(),
})
}
// Get OAuth 2.0 Consent Request
//
// swagger:parameters getOAuth2ConsentRequest
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type getOAuth2ConsentRequest struct {
// OAuth 2.0 Consent Request Challenge
//
@@ -582,7 +624,7 @@ type getOAuth2ConsentRequest struct {
// 200: oAuth2ConsentRequest
// 410: oAuth2RedirectTo
// default: errorOAuth2
-func (h *Handler) getOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *Handler) getOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
challenge := stringsx.Coalesce(
r.URL.Query().Get("consent_challenge"),
r.URL.Query().Get("challenge"),
@@ -598,7 +640,7 @@ func (h *Handler) getOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request
return
}
if request.WasHandled {
- h.r.Writer().WriteCode(w, r, http.StatusGone, &OAuth2RedirectTo{
+ h.r.Writer().WriteCode(w, r, http.StatusGone, &flow.OAuth2RedirectTo{
RedirectTo: request.RequestURL,
})
return
@@ -619,6 +661,8 @@ func (h *Handler) getOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request
// Accept OAuth 2.0 Consent Request
//
// swagger:parameters acceptOAuth2ConsentRequest
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type acceptOAuth2ConsentRequest struct {
// OAuth 2.0 Consent Request Challenge
//
@@ -627,7 +671,7 @@ type acceptOAuth2ConsentRequest struct {
Challenge string `json:"consent_challenge"`
// in: body
- Body AcceptOAuth2ConsentRequest
+ Body flow.AcceptOAuth2ConsentRequest
}
// swagger:route PUT /admin/oauth2/auth/requests/consent/accept oAuth2 acceptOAuth2ConsentRequest
@@ -662,7 +706,9 @@ type acceptOAuth2ConsentRequest struct {
// Responses:
// 200: oAuth2RedirectTo
// default: errorOAuth2
-func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
+ ctx := r.Context()
+
challenge := stringsx.Coalesce(
r.URL.Query().Get("consent_challenge"),
r.URL.Query().Get("challenge"),
@@ -672,7 +718,7 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ
return
}
- var p AcceptOAuth2ConsentRequest
+ var p flow.AcceptOAuth2ConsentRequest
d := json.NewDecoder(r.Body)
d.DisallowUnknownFields()
if err := d.Decode(&p); err != nil {
@@ -680,7 +726,7 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ
return
}
- cr, err := h.r.ConsentManager().GetConsentRequest(r.Context(), challenge)
+ cr, err := h.r.ConsentManager().GetConsentRequest(ctx, challenge)
if err != nil {
h.r.Writer().WriteError(w, r, errorsx.WithStack(err))
return
@@ -690,7 +736,12 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ
p.RequestedAt = cr.RequestedAt
p.HandledAt = sqlxx.NullTime(time.Now().UTC())
- hr, err := h.r.ConsentManager().HandleConsentRequest(r.Context(), &p)
+ f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flowctx.AsConsentChallenge)
+ if err != nil {
+ h.r.Writer().WriteError(w, r, err)
+ return
+ }
+ hr, err := h.r.ConsentManager().HandleConsentRequest(ctx, f, &p)
if err != nil {
h.r.Writer().WriteError(w, r, errorsx.WithStack(err))
return
@@ -704,14 +755,22 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ
return
}
- h.r.Writer().Write(w, r, &OAuth2RedirectTo{
- RedirectTo: urlx.SetQuery(ru, url.Values{"consent_verifier": {hr.Verifier}}).String(),
+ verifier, err := f.ToConsentVerifier(ctx, h.r)
+ if err != nil {
+ h.r.Writer().WriteError(w, r, err)
+ return
+ }
+
+ h.r.Writer().Write(w, r, &flow.OAuth2RedirectTo{
+ RedirectTo: urlx.SetQuery(ru, url.Values{"consent_verifier": {verifier}}).String(),
})
}
// Reject OAuth 2.0 Consent Request
//
// swagger:parameters rejectOAuth2ConsentRequest
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type adminRejectOAuth2ConsentRequest struct {
// OAuth 2.0 Consent Request Challenge
//
@@ -720,7 +779,7 @@ type adminRejectOAuth2ConsentRequest struct {
Challenge string `json:"consent_challenge"`
// in: body
- Body RequestDeniedError
+ Body flow.RequestDeniedError
}
// swagger:route PUT /admin/oauth2/auth/requests/consent/reject oAuth2 rejectOAuth2ConsentRequest
@@ -754,7 +813,9 @@ type adminRejectOAuth2ConsentRequest struct {
// Responses:
// 200: oAuth2RedirectTo
// default: errorOAuth2
-func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
+ ctx := r.Context()
+
challenge := stringsx.Coalesce(
r.URL.Query().Get("consent_challenge"),
r.URL.Query().Get("challenge"),
@@ -764,7 +825,7 @@ func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ
return
}
- var p RequestDeniedError
+ var p flow.RequestDeniedError
d := json.NewDecoder(r.Body)
d.DisallowUnknownFields()
if err := d.Decode(&p); err != nil {
@@ -772,15 +833,21 @@ func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ
return
}
- p.valid = true
- p.SetDefaults(consentRequestDeniedErrorName)
- hr, err := h.r.ConsentManager().GetConsentRequest(r.Context(), challenge)
+ p.Valid = true
+ p.SetDefaults(flow.ConsentRequestDeniedErrorName)
+ hr, err := h.r.ConsentManager().GetConsentRequest(ctx, challenge)
if err != nil {
h.r.Writer().WriteError(w, r, errorsx.WithStack(err))
return
}
- request, err := h.r.ConsentManager().HandleConsentRequest(r.Context(), &AcceptOAuth2ConsentRequest{
+ f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flowctx.AsConsentChallenge)
+ if err != nil {
+ h.r.Writer().WriteError(w, r, err)
+ return
+ }
+
+ request, err := h.r.ConsentManager().HandleConsentRequest(ctx, f, &flow.AcceptOAuth2ConsentRequest{
Error: &p,
ID: challenge,
RequestedAt: hr.RequestedAt,
@@ -797,14 +864,22 @@ func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ
return
}
- h.r.Writer().Write(w, r, &OAuth2RedirectTo{
- RedirectTo: urlx.SetQuery(ru, url.Values{"consent_verifier": {request.Verifier}}).String(),
+ verifier, err := f.ToConsentVerifier(ctx, h.r)
+ if err != nil {
+ h.r.Writer().WriteError(w, r, err)
+ return
+ }
+
+ h.r.Writer().Write(w, r, &flow.OAuth2RedirectTo{
+ RedirectTo: urlx.SetQuery(ru, url.Values{"consent_verifier": {verifier}}).String(),
})
}
// Accept OAuth 2.0 Logout Request
//
// swagger:parameters acceptOAuth2LogoutRequest
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type acceptOAuth2LogoutRequest struct {
// OAuth 2.0 Logout Request Challenge
//
@@ -829,7 +904,7 @@ type acceptOAuth2LogoutRequest struct {
// Responses:
// 200: oAuth2RedirectTo
// default: errorOAuth2
-func (h *Handler) acceptOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *Handler) acceptOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
challenge := stringsx.Coalesce(
r.URL.Query().Get("logout_challenge"),
r.URL.Query().Get("challenge"),
@@ -841,7 +916,7 @@ func (h *Handler) acceptOAuth2LogoutRequest(w http.ResponseWriter, r *http.Reque
return
}
- h.r.Writer().Write(w, r, &OAuth2RedirectTo{
+ h.r.Writer().Write(w, r, &flow.OAuth2RedirectTo{
RedirectTo: urlx.SetQuery(urlx.AppendPaths(h.c.PublicURL(r.Context()), "/oauth2/sessions/logout"), url.Values{"logout_verifier": {c.Verifier}}).String(),
})
}
@@ -849,6 +924,8 @@ func (h *Handler) acceptOAuth2LogoutRequest(w http.ResponseWriter, r *http.Reque
// Reject OAuth 2.0 Logout Request
//
// swagger:parameters rejectOAuth2LogoutRequest
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type rejectOAuth2LogoutRequest struct {
// in: query
// required: true
@@ -872,7 +949,7 @@ type rejectOAuth2LogoutRequest struct {
// Responses:
// 204: emptyResponse
// default: errorOAuth2
-func (h *Handler) rejectOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *Handler) rejectOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
challenge := stringsx.Coalesce(
r.URL.Query().Get("logout_challenge"),
r.URL.Query().Get("challenge"),
@@ -889,6 +966,8 @@ func (h *Handler) rejectOAuth2LogoutRequest(w http.ResponseWriter, r *http.Reque
// Get OAuth 2.0 Logout Request
//
// swagger:parameters getOAuth2LogoutRequest
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type getOAuth2LogoutRequest struct {
// in: query
// required: true
@@ -910,7 +989,7 @@ type getOAuth2LogoutRequest struct {
// 200: oAuth2LogoutRequest
// 410: oAuth2RedirectTo
// default: errorOAuth2
-func (h *Handler) getOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+func (h *Handler) getOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
challenge := stringsx.Coalesce(
r.URL.Query().Get("logout_challenge"),
r.URL.Query().Get("challenge"),
@@ -928,7 +1007,7 @@ func (h *Handler) getOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request,
}
if request.WasHandled {
- h.r.Writer().WriteCode(w, r, http.StatusGone, &OAuth2RedirectTo{
+ h.r.Writer().WriteCode(w, r, http.StatusGone, &flow.OAuth2RedirectTo{
RedirectTo: request.RequestURL,
})
return
@@ -936,3 +1015,12 @@ func (h *Handler) getOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request,
h.r.Writer().Write(w, r, request)
}
+
+func (h *Handler) flowFromCookie(r *http.Request) (*flow.Flow, error) {
+ clientID := r.URL.Query().Get("client_id")
+ if clientID == "" {
+ return nil, errors.WithStack(fosite.ErrInvalidClient)
+ }
+
+ return flowctx.FromCookie[flow.Flow](r.Context(), r, h.r.FlowCipher(), flowctx.FlowCookie(flowctx.SuffixFromStatic(clientID)))
+}
diff --git a/consent/handler_test.go b/consent/handler_test.go
index 6022674eeb1..47496fa0bf5 100644
--- a/consent/handler_test.go
+++ b/consent/handler_test.go
@@ -13,19 +13,17 @@ import (
"testing"
"time"
- "github.com/ory/x/pointerx"
-
- "github.com/ory/hydra/v2/x"
- "github.com/ory/x/contextx"
- "github.com/ory/x/sqlxx"
-
- "github.com/ory/hydra/v2/internal"
-
"github.com/stretchr/testify/require"
hydra "github.com/ory/hydra-client-go/v2"
"github.com/ory/hydra/v2/client"
. "github.com/ory/hydra/v2/consent"
+ "github.com/ory/hydra/v2/flow"
+ "github.com/ory/hydra/v2/internal"
+ "github.com/ory/hydra/v2/x"
+ "github.com/ory/x/contextx"
+ "github.com/ory/x/pointerx"
+ "github.com/ory/x/sqlxx"
)
func TestGetLogoutRequest(t *testing.T) {
@@ -39,6 +37,7 @@ func TestGetLogoutRequest(t *testing.T) {
{true, true, http.StatusGone},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
+ ctx := context.Background()
key := fmt.Sprint(k)
challenge := "challenge" + key
requestURL := "http://192.0.2.1"
@@ -48,8 +47,8 @@ func TestGetLogoutRequest(t *testing.T) {
if tc.exists {
cl := &client.Client{LegacyClientID: "client" + key}
- require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cl))
- require.NoError(t, reg.ConsentManager().CreateLogoutRequest(context.TODO(), &LogoutRequest{
+ require.NoError(t, reg.ClientManager().CreateClient(ctx, cl))
+ require.NoError(t, reg.ConsentManager().CreateLogoutRequest(context.TODO(), &flow.LogoutRequest{
Client: cl,
ID: challenge,
WasHandled: tc.handled,
@@ -69,11 +68,11 @@ func TestGetLogoutRequest(t *testing.T) {
require.EqualValues(t, tc.status, resp.StatusCode)
if tc.handled {
- var result OAuth2RedirectTo
+ var result flow.OAuth2RedirectTo
require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
require.Equal(t, requestURL, result.RedirectTo)
} else if tc.exists {
- var result LogoutRequest
+ var result flow.LogoutRequest
require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
require.Equal(t, challenge, result.ID)
require.Equal(t, requestURL, result.RequestURL)
@@ -92,7 +91,8 @@ func TestGetLoginRequest(t *testing.T) {
{true, false, http.StatusOK},
{true, true, http.StatusGone},
} {
- t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
+ t.Run(fmt.Sprintf("exists=%v/handled=%v", tc.exists, tc.handled), func(t *testing.T) {
+ ctx := context.Background()
key := fmt.Sprint(k)
challenge := "challenge" + key
requestURL := "http://192.0.2.1"
@@ -103,14 +103,20 @@ func TestGetLoginRequest(t *testing.T) {
if tc.exists {
cl := &client.Client{LegacyClientID: "client" + key}
require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cl))
- require.NoError(t, reg.ConsentManager().CreateLoginRequest(context.Background(), &LoginRequest{
- Client: cl,
- ID: challenge,
- RequestURL: requestURL,
- }))
+ f, err := reg.ConsentManager().CreateLoginRequest(context.Background(), &flow.LoginRequest{
+ Client: cl,
+ ID: challenge,
+ RequestURL: requestURL,
+ RequestedAt: time.Now(),
+ })
+ require.NoError(t, err)
+ challenge, err = f.ToLoginChallenge(ctx, reg)
+ require.NoError(t, err)
if tc.handled {
- _, err := reg.ConsentManager().HandleLoginRequest(context.Background(), challenge, &HandledLoginRequest{ID: challenge, WasHandled: true})
+ _, err := reg.ConsentManager().HandleLoginRequest(ctx, f, challenge, &flow.HandledLoginRequest{ID: challenge, WasHandled: true})
+ require.NoError(t, err)
+ challenge, err = f.ToLoginChallenge(ctx, reg)
require.NoError(t, err)
}
}
@@ -127,11 +133,11 @@ func TestGetLoginRequest(t *testing.T) {
require.EqualValues(t, tc.status, resp.StatusCode)
if tc.handled {
- var result OAuth2RedirectTo
+ var result flow.OAuth2RedirectTo
require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
require.Equal(t, requestURL, result.RedirectTo)
} else if tc.exists {
- var result LoginRequest
+ var result flow.LoginRequest
require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
require.Equal(t, challenge, result.ID)
require.Equal(t, requestURL, result.RequestURL)
@@ -152,6 +158,7 @@ func TestGetConsentRequest(t *testing.T) {
{true, true, http.StatusGone},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
+ ctx := context.Background()
key := fmt.Sprint(k)
challenge := "challenge" + key
requestURL := "http://192.0.2.1"
@@ -161,14 +168,24 @@ func TestGetConsentRequest(t *testing.T) {
if tc.exists {
cl := &client.Client{LegacyClientID: "client" + key}
- require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cl))
- lr := &LoginRequest{ID: "login-" + challenge, Client: cl, RequestURL: requestURL}
- require.NoError(t, reg.ConsentManager().CreateLoginRequest(context.Background(), lr))
- _, err := reg.ConsentManager().HandleLoginRequest(context.Background(), lr.ID, &HandledLoginRequest{
- ID: lr.ID,
+ require.NoError(t, reg.ClientManager().CreateClient(ctx, cl))
+ lr := &flow.LoginRequest{
+ ID: "login-" + challenge,
+ Client: cl,
+ RequestURL: requestURL,
+ RequestedAt: time.Now(),
+ }
+ f, err := reg.ConsentManager().CreateLoginRequest(ctx, lr)
+ require.NoError(t, err)
+ challenge, err = f.ToLoginChallenge(ctx, reg)
+ require.NoError(t, err)
+ _, err = reg.ConsentManager().HandleLoginRequest(ctx, f, challenge, &flow.HandledLoginRequest{
+ ID: challenge,
})
require.NoError(t, err)
- require.NoError(t, reg.ConsentManager().CreateConsentRequest(context.Background(), &OAuth2ConsentRequest{
+ challenge, err = f.ToConsentChallenge(ctx, reg)
+ require.NoError(t, err)
+ require.NoError(t, reg.ConsentManager().CreateConsentRequest(ctx, f, &flow.OAuth2ConsentRequest{
Client: cl,
ID: challenge,
Verifier: challenge,
@@ -177,12 +194,14 @@ func TestGetConsentRequest(t *testing.T) {
}))
if tc.handled {
- _, err := reg.ConsentManager().HandleConsentRequest(context.Background(), &AcceptOAuth2ConsentRequest{
+ _, err := reg.ConsentManager().HandleConsentRequest(ctx, f, &flow.AcceptOAuth2ConsentRequest{
ID: challenge,
WasHandled: true,
HandledAt: sqlxx.NullTime(time.Now()),
})
require.NoError(t, err)
+ challenge, err = f.ToConsentChallenge(ctx, reg)
+ require.NoError(t, err)
}
}
@@ -199,11 +218,11 @@ func TestGetConsentRequest(t *testing.T) {
require.EqualValues(t, tc.status, resp.StatusCode)
if tc.handled {
- var result OAuth2RedirectTo
+ var result flow.OAuth2RedirectTo
require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
require.Equal(t, requestURL, result.RedirectTo)
} else if tc.exists {
- var result OAuth2ConsentRequest
+ var result flow.OAuth2ConsentRequest
require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
require.Equal(t, challenge, result.ID)
require.Equal(t, requestURL, result.RequestURL)
@@ -215,6 +234,7 @@ func TestGetConsentRequest(t *testing.T) {
func TestGetLoginRequestWithDuplicateAccept(t *testing.T) {
t.Run("Test get login request with duplicate accept", func(t *testing.T) {
+ ctx := context.Background()
challenge := "challenge"
requestURL := "http://192.0.2.1"
@@ -222,12 +242,16 @@ func TestGetLoginRequestWithDuplicateAccept(t *testing.T) {
reg := internal.NewRegistryMemory(t, conf, &contextx.Default{})
cl := &client.Client{LegacyClientID: "client"}
- require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cl))
- require.NoError(t, reg.ConsentManager().CreateLoginRequest(context.Background(), &LoginRequest{
- Client: cl,
- ID: challenge,
- RequestURL: requestURL,
- }))
+ require.NoError(t, reg.ClientManager().CreateClient(ctx, cl))
+ f, err := reg.ConsentManager().CreateLoginRequest(ctx, &flow.LoginRequest{
+ Client: cl,
+ ID: challenge,
+ RequestURL: requestURL,
+ RequestedAt: time.Now(),
+ })
+ require.NoError(t, err)
+ challenge, err = f.ToLoginChallenge(ctx, reg)
+ require.NoError(t, err)
h := NewHandler(reg, conf)
r := x.NewRouterAdmin(conf.AdminURL)
@@ -238,7 +262,7 @@ func TestGetLoginRequestWithDuplicateAccept(t *testing.T) {
c := &http.Client{}
sub := "sub123"
- acceptLogin := &hydra.AcceptOAuth2LoginRequest{Remember: pointerx.Bool(true), Subject: sub}
+ acceptLogin := &hydra.AcceptOAuth2LoginRequest{Remember: pointerx.Ptr(true), Subject: sub}
// marshal User to json
acceptLoginJson, err := json.Marshal(acceptLogin)
@@ -256,7 +280,7 @@ func TestGetLoginRequestWithDuplicateAccept(t *testing.T) {
require.NoError(t, err)
require.EqualValues(t, http.StatusOK, resp.StatusCode)
- var result OAuth2RedirectTo
+ var result flow.OAuth2RedirectTo
require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
require.NotNil(t, result.RedirectTo)
require.Contains(t, result.RedirectTo, "login_verifier")
@@ -270,7 +294,7 @@ func TestGetLoginRequestWithDuplicateAccept(t *testing.T) {
require.NoError(t, err)
require.EqualValues(t, http.StatusOK, resp2.StatusCode)
- var result2 OAuth2RedirectTo
+ var result2 flow.OAuth2RedirectTo
require.NoError(t, json.NewDecoder(resp2.Body).Decode(&result2))
require.NotNil(t, result2.RedirectTo)
require.Contains(t, result2.RedirectTo, "login_verifier")
diff --git a/consent/helper.go b/consent/helper.go
index ed15dd03147..bf6e46b2765 100644
--- a/consent/helper.go
+++ b/consent/helper.go
@@ -6,9 +6,9 @@ package consent
import (
"net/http"
"strings"
-
"time"
+ "github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/errorsx"
@@ -33,7 +33,7 @@ func sanitizeClient(c *client.Client) *client.Client {
return cc
}
-func matchScopes(scopeStrategy fosite.ScopeStrategy, previousConsent []AcceptOAuth2ConsentRequest, requestedScope []string) *AcceptOAuth2ConsentRequest {
+func matchScopes(scopeStrategy fosite.ScopeStrategy, previousConsent []flow.AcceptOAuth2ConsentRequest, requestedScope []string) *flow.AcceptOAuth2ConsentRequest {
for _, cs := range previousConsent {
var found = true
for _, scope := range requestedScope {
diff --git a/consent/helper_test.go b/consent/helper_test.go
index c350ee2f63b..a5f09e81cdd 100644
--- a/consent/helper_test.go
+++ b/consent/helper_test.go
@@ -12,6 +12,7 @@ import (
"github.com/golang/mock/gomock"
+ "github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/internal/mock"
"github.com/gorilla/securecookie"
@@ -38,22 +39,22 @@ func TestSanitizeClient(t *testing.T) {
func TestMatchScopes(t *testing.T) {
for k, tc := range []struct {
- granted []AcceptOAuth2ConsentRequest
+ granted []flow.AcceptOAuth2ConsentRequest
requested []string
expectChallenge string
}{
{
- granted: []AcceptOAuth2ConsentRequest{{ID: "1", GrantedScope: []string{"foo", "bar"}}},
+ granted: []flow.AcceptOAuth2ConsentRequest{{ID: "1", GrantedScope: []string{"foo", "bar"}}},
requested: []string{"foo", "bar"},
expectChallenge: "1",
},
{
- granted: []AcceptOAuth2ConsentRequest{{ID: "1", GrantedScope: []string{"foo", "bar"}}},
+ granted: []flow.AcceptOAuth2ConsentRequest{{ID: "1", GrantedScope: []string{"foo", "bar"}}},
requested: []string{"foo", "bar", "baz"},
expectChallenge: "",
},
{
- granted: []AcceptOAuth2ConsentRequest{
+ granted: []flow.AcceptOAuth2ConsentRequest{
{ID: "1", GrantedScope: []string{"foo", "bar"}},
{ID: "2", GrantedScope: []string{"foo", "bar"}},
},
@@ -61,7 +62,7 @@ func TestMatchScopes(t *testing.T) {
expectChallenge: "1",
},
{
- granted: []AcceptOAuth2ConsentRequest{
+ granted: []flow.AcceptOAuth2ConsentRequest{
{ID: "1", GrantedScope: []string{"foo", "bar"}},
{ID: "2", GrantedScope: []string{"foo", "bar", "baz"}},
},
@@ -69,7 +70,7 @@ func TestMatchScopes(t *testing.T) {
expectChallenge: "2",
},
{
- granted: []AcceptOAuth2ConsentRequest{
+ granted: []flow.AcceptOAuth2ConsentRequest{
{ID: "1", GrantedScope: []string{"foo", "bar"}},
{ID: "2", GrantedScope: []string{"foo", "bar", "baz"}},
},
diff --git a/consent/janitor_consent_test_helper.go b/consent/janitor_consent_test_helper.go
index 6467eb1a63d..645a88a2209 100644
--- a/consent/janitor_consent_test_helper.go
+++ b/consent/janitor_consent_test_helper.go
@@ -6,23 +6,24 @@ package consent
import (
"time"
+ "github.com/ory/hydra/v2/flow"
"github.com/ory/x/sqlxx"
)
-func NewHandledLoginRequest(challenge string, hasError bool, requestedAt time.Time, authenticatedAt sqlxx.NullTime) *HandledLoginRequest {
- var deniedErr *RequestDeniedError
+func NewHandledLoginRequest(challenge string, hasError bool, requestedAt time.Time, authenticatedAt sqlxx.NullTime) *flow.HandledLoginRequest {
+ var deniedErr *flow.RequestDeniedError
if hasError {
- deniedErr = &RequestDeniedError{
+ deniedErr = &flow.RequestDeniedError{
Name: "consent request denied",
Description: "some description",
Hint: "some hint",
Code: 403,
Debug: "some debug",
- valid: true,
+ Valid: true,
}
}
- return &HandledLoginRequest{
+ return &flow.HandledLoginRequest{
ID: challenge,
Error: deniedErr,
WasHandled: true,
@@ -31,20 +32,20 @@ func NewHandledLoginRequest(challenge string, hasError bool, requestedAt time.Ti
}
}
-func NewHandledConsentRequest(challenge string, hasError bool, requestedAt time.Time, authenticatedAt sqlxx.NullTime) *AcceptOAuth2ConsentRequest {
- var deniedErr *RequestDeniedError
+func NewHandledConsentRequest(challenge string, hasError bool, requestedAt time.Time, authenticatedAt sqlxx.NullTime) *flow.AcceptOAuth2ConsentRequest {
+ var deniedErr *flow.RequestDeniedError
if hasError {
- deniedErr = &RequestDeniedError{
+ deniedErr = &flow.RequestDeniedError{
Name: "consent request denied",
Description: "some description",
Hint: "some hint",
Code: 403,
Debug: "some debug",
- valid: true,
+ Valid: true,
}
}
- return &AcceptOAuth2ConsentRequest{
+ return &flow.AcceptOAuth2ConsentRequest{
ID: challenge,
HandledAt: sqlxx.NullTime(time.Now().Round(time.Second)),
Error: deniedErr,
diff --git a/consent/manager.go b/consent/manager.go
index 2910bcc9e40..69b62ed8b9e 100644
--- a/consent/manager.go
+++ b/consent/manager.go
@@ -10,6 +10,7 @@ import (
"github.com/gofrs/uuid"
"github.com/ory/hydra/v2/client"
+ "github.com/ory/hydra/v2/flow"
)
type ForcedObfuscatedLoginSession struct {
@@ -19,44 +20,50 @@ type ForcedObfuscatedLoginSession struct {
NID uuid.UUID `db:"nid"`
}
-func (_ ForcedObfuscatedLoginSession) TableName() string {
+func (ForcedObfuscatedLoginSession) TableName() string {
return "hydra_oauth2_obfuscated_authentication_session"
}
-type Manager interface {
- CreateConsentRequest(ctx context.Context, req *OAuth2ConsentRequest) error
- GetConsentRequest(ctx context.Context, challenge string) (*OAuth2ConsentRequest, error)
- HandleConsentRequest(ctx context.Context, r *AcceptOAuth2ConsentRequest) (*OAuth2ConsentRequest, error)
- RevokeSubjectConsentSession(ctx context.Context, user string) error
- RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error
-
- VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*AcceptOAuth2ConsentRequest, error)
- FindGrantedAndRememberedConsentRequests(ctx context.Context, client, user string) ([]AcceptOAuth2ConsentRequest, error)
- FindSubjectsGrantedConsentRequests(ctx context.Context, user string, limit, offset int) ([]AcceptOAuth2ConsentRequest, error)
- FindSubjectsSessionGrantedConsentRequests(ctx context.Context, user, sid string, limit, offset int) ([]AcceptOAuth2ConsentRequest, error)
- CountSubjectsGrantedConsentRequests(ctx context.Context, user string) (int, error)
-
- // Cookie management
- GetRememberedLoginSession(ctx context.Context, id string) (*LoginSession, error)
- CreateLoginSession(ctx context.Context, session *LoginSession) error
- DeleteLoginSession(ctx context.Context, id string) error
- RevokeSubjectLoginSession(ctx context.Context, user string) error
- ConfirmLoginSession(ctx context.Context, id string, authTime time.Time, subject string, remember bool) error
-
- CreateLoginRequest(ctx context.Context, req *LoginRequest) error
- GetLoginRequest(ctx context.Context, challenge string) (*LoginRequest, error)
- HandleLoginRequest(ctx context.Context, challenge string, r *HandledLoginRequest) (*LoginRequest, error)
- VerifyAndInvalidateLoginRequest(ctx context.Context, verifier string) (*HandledLoginRequest, error)
-
- CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error
- GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error)
-
- ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
- ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
-
- CreateLogoutRequest(ctx context.Context, request *LogoutRequest) error
- GetLogoutRequest(ctx context.Context, challenge string) (*LogoutRequest, error)
- AcceptLogoutRequest(ctx context.Context, challenge string) (*LogoutRequest, error)
- RejectLogoutRequest(ctx context.Context, challenge string) error
- VerifyAndInvalidateLogoutRequest(ctx context.Context, verifier string) (*LogoutRequest, error)
-}
+type (
+ Manager interface {
+ CreateConsentRequest(ctx context.Context, f *flow.Flow, req *flow.OAuth2ConsentRequest) error
+ GetConsentRequest(ctx context.Context, challenge string) (*flow.OAuth2ConsentRequest, error)
+ HandleConsentRequest(ctx context.Context, f *flow.Flow, r *flow.AcceptOAuth2ConsentRequest) (*flow.OAuth2ConsentRequest, error)
+ RevokeSubjectConsentSession(ctx context.Context, user string) error
+ RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error
+
+ VerifyAndInvalidateConsentRequest(ctx context.Context, f *flow.Flow, verifier string) (*flow.AcceptOAuth2ConsentRequest, error)
+ FindGrantedAndRememberedConsentRequests(ctx context.Context, client, user string) ([]flow.AcceptOAuth2ConsentRequest, error)
+ FindSubjectsGrantedConsentRequests(ctx context.Context, user string, limit, offset int) ([]flow.AcceptOAuth2ConsentRequest, error)
+ FindSubjectsSessionGrantedConsentRequests(ctx context.Context, user, sid string, limit, offset int) ([]flow.AcceptOAuth2ConsentRequest, error)
+ CountSubjectsGrantedConsentRequests(ctx context.Context, user string) (int, error)
+
+ // Cookie management
+ GetRememberedLoginSession(ctx context.Context, loginSessionFromCookie *flow.LoginSession, id string) (*flow.LoginSession, error)
+ CreateLoginSession(ctx context.Context, session *flow.LoginSession) error
+ DeleteLoginSession(ctx context.Context, id string) error
+ RevokeSubjectLoginSession(ctx context.Context, user string) error
+ ConfirmLoginSession(ctx context.Context, session *flow.LoginSession, id string, authTime time.Time, subject string, remember bool) error
+
+ CreateLoginRequest(ctx context.Context, req *flow.LoginRequest) (*flow.Flow, error)
+ GetLoginRequest(ctx context.Context, challenge string) (*flow.LoginRequest, error)
+ HandleLoginRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledLoginRequest) (*flow.LoginRequest, error)
+ VerifyAndInvalidateLoginRequest(ctx context.Context, f *flow.Flow, verifier string) (*flow.HandledLoginRequest, error)
+
+ CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error
+ GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error)
+
+ ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
+ ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
+
+ CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) error
+ GetLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error)
+ AcceptLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error)
+ RejectLogoutRequest(ctx context.Context, challenge string) error
+ VerifyAndInvalidateLogoutRequest(ctx context.Context, verifier string) (*flow.LogoutRequest, error)
+ }
+
+ ManagerProvider interface {
+ ConsentManager() Manager
+ }
+)
diff --git a/consent/manager_test_helpers.go b/consent/manager_test_helpers.go
index 084b9d4c4a4..2d84bf071d5 100644
--- a/consent/manager_test_helpers.go
+++ b/consent/manager_test_helpers.go
@@ -10,7 +10,10 @@ import (
"testing"
"time"
+ "github.com/ory/hydra/v2/aead"
+ "github.com/ory/hydra/v2/flow"
"github.com/ory/x/assertx"
+ "github.com/ory/x/contextx"
gofrsuuid "github.com/gofrs/uuid"
"github.com/google/uuid"
@@ -25,14 +28,14 @@ import (
"github.com/ory/hydra/v2/x"
)
-func MockConsentRequest(key string, remember bool, rememberFor int, hasError bool, skip bool, authAt bool, loginChallengeBase string, network string) (c *OAuth2ConsentRequest, h *AcceptOAuth2ConsentRequest) {
- c = &OAuth2ConsentRequest{
+func MockConsentRequest(key string, remember bool, rememberFor int, hasError bool, skip bool, authAt bool, loginChallengeBase string, network string) (c *flow.OAuth2ConsentRequest, h *flow.AcceptOAuth2ConsentRequest, f *flow.Flow) {
+ c = &flow.OAuth2ConsentRequest{
ID: makeID("challenge", network, key),
RequestedScope: []string{"scopea" + key, "scopeb" + key},
RequestedAudience: []string{"auda" + key, "audb" + key},
Skip: skip,
Subject: "subject" + key,
- OpenIDConnectContext: &OAuth2ConsentRequestOpenIDConnectContext{
+ OpenIDConnectContext: &flow.OAuth2ConsentRequestOpenIDConnectContext{
ACRValues: []string{"1" + key, "2" + key},
UILocales: []string{"fr" + key, "de" + key},
Display: "popup" + key,
@@ -46,19 +49,37 @@ func MockConsentRequest(key string, remember bool, rememberFor int, hasError boo
CSRF: "csrf" + key,
ACR: "1",
AuthenticatedAt: sqlxx.NullTime(time.Now().UTC().Add(-time.Hour)),
- RequestedAt: time.Now().UTC().Add(-time.Hour),
+ RequestedAt: time.Now().UTC(),
Context: sqlxx.JSONRawMessage(`{"foo": "bar` + key + `"}`),
}
- var err *RequestDeniedError
+ f = &flow.Flow{
+ ID: c.LoginChallenge.String(),
+ LoginVerifier: makeID("login-verifier", network, key),
+ SessionID: c.LoginSessionID,
+ Client: c.Client,
+ State: flow.FlowStateConsentInitialized,
+ ConsentChallengeID: sqlxx.NullString(c.ID),
+ ConsentSkip: c.Skip,
+ ConsentVerifier: sqlxx.NullString(c.Verifier),
+ ConsentCSRF: sqlxx.NullString(c.CSRF),
+ OpenIDConnectContext: c.OpenIDConnectContext,
+ Subject: c.Subject,
+ RequestedScope: c.RequestedScope,
+ RequestedAudience: c.RequestedAudience,
+ RequestURL: c.RequestURL,
+ RequestedAt: c.RequestedAt,
+ }
+
+ var err *flow.RequestDeniedError
if hasError {
- err = &RequestDeniedError{
+ err = &flow.RequestDeniedError{
Name: "error_name" + key,
Description: "error_description" + key,
Hint: "error_hint,omitempty" + key,
Code: 100,
Debug: "error_debug,omitempty" + key,
- valid: true,
+ Valid: true,
}
}
@@ -67,7 +88,7 @@ func MockConsentRequest(key string, remember bool, rememberFor int, hasError boo
authenticatedAt = sqlxx.NullTime(time.Now().UTC().Add(-time.Minute))
}
- h = &AcceptOAuth2ConsentRequest{
+ h = &flow.AcceptOAuth2ConsentRequest{
ConsentRequest: c,
RememberFor: rememberFor,
Remember: remember,
@@ -81,17 +102,17 @@ func MockConsentRequest(key string, remember bool, rememberFor int, hasError boo
// WasUsed: true,
}
- return c, h
+ return c, h, f
}
-func MockLogoutRequest(key string, withClient bool, network string) (c *LogoutRequest) {
+func MockLogoutRequest(key string, withClient bool, network string) (c *flow.LogoutRequest) {
var cl *client.Client
if withClient {
cl = &client.Client{
LegacyClientID: "fk-client-" + key,
}
}
- return &LogoutRequest{
+ return &flow.LogoutRequest{
Subject: "subject" + key,
ID: makeID("challenge", network, key),
Verifier: makeID("verifier", network, key),
@@ -105,9 +126,9 @@ func MockLogoutRequest(key string, withClient bool, network string) (c *LogoutRe
}
}
-func MockAuthRequest(key string, authAt bool, network string) (c *LoginRequest, h *HandledLoginRequest) {
- c = &LoginRequest{
- OpenIDConnectContext: &OAuth2ConsentRequestOpenIDConnectContext{
+func MockAuthRequest(key string, authAt bool, network string) (c *flow.LoginRequest, h *flow.HandledLoginRequest, f *flow.Flow) {
+ c = &flow.LoginRequest{
+ OpenIDConnectContext: &flow.OAuth2ConsentRequestOpenIDConnectContext{
ACRValues: []string{"1" + key, "2" + key},
UILocales: []string{"fr" + key, "de" + key},
Display: "popup" + key,
@@ -124,13 +145,15 @@ func MockAuthRequest(key string, authAt bool, network string) (c *LoginRequest,
SessionID: sqlxx.NullString(makeID("fk-login-session", network, key)),
}
- var err = &RequestDeniedError{
+ f = flow.NewFlow(c)
+
+ var err = &flow.RequestDeniedError{
Name: "error_name" + key,
Description: "error_description" + key,
Hint: "error_hint,omitempty" + key,
Code: 100,
Debug: "error_debug,omitempty" + key,
- valid: true,
+ Valid: true,
}
var authenticatedAt time.Time
@@ -138,7 +161,7 @@ func MockAuthRequest(key string, authAt bool, network string) (c *LoginRequest,
authenticatedAt = time.Now().UTC().Add(-time.Minute)
}
- h = &HandledLoginRequest{
+ h = &flow.HandledLoginRequest{
LoginRequest: c,
RememberFor: 120,
Remember: true,
@@ -152,23 +175,23 @@ func MockAuthRequest(key string, authAt bool, network string) (c *LoginRequest,
WasHandled: false,
}
- return c, h
+ return c, h, f
}
-func SaneMockHandleConsentRequest(t *testing.T, m Manager, c *OAuth2ConsentRequest, authAt time.Time, rememberFor int, remember bool, hasError bool) *AcceptOAuth2ConsentRequest {
- var rde *RequestDeniedError
+func SaneMockHandleConsentRequest(t *testing.T, m Manager, f *flow.Flow, c *flow.OAuth2ConsentRequest, authAt time.Time, rememberFor int, remember bool, hasError bool) *flow.AcceptOAuth2ConsentRequest {
+ var rde *flow.RequestDeniedError
if hasError {
- rde = &RequestDeniedError{
+ rde = &flow.RequestDeniedError{
Name: "error_name",
Description: "error_description",
Hint: "error_hint",
Code: 100,
Debug: "error_debug",
- valid: true,
+ Valid: true,
}
}
- h := &AcceptOAuth2ConsentRequest{
+ h := &flow.AcceptOAuth2ConsentRequest{
ConsentRequest: c,
RememberFor: rememberFor,
Remember: remember,
@@ -182,27 +205,28 @@ func SaneMockHandleConsentRequest(t *testing.T, m Manager, c *OAuth2ConsentReque
HandledAt: sqlxx.NullTime(time.Now().UTC().Add(-time.Minute)),
}
- _, err := m.HandleConsentRequest(context.Background(), h)
+ _, err := m.HandleConsentRequest(context.Background(), f, h)
require.NoError(t, err)
+
return h
}
// SaneMockConsentRequest does the same thing as MockConsentRequest but uses less insanity and implicit dependencies.
-func SaneMockConsentRequest(t *testing.T, m Manager, ar *LoginRequest, skip bool) (c *OAuth2ConsentRequest) {
- c = &OAuth2ConsentRequest{
+func SaneMockConsentRequest(t *testing.T, m Manager, f *flow.Flow, skip bool) (c *flow.OAuth2ConsentRequest) {
+ c = &flow.OAuth2ConsentRequest{
RequestedScope: []string{"scopea", "scopeb"},
RequestedAudience: []string{"auda", "audb"},
Skip: skip,
- Subject: ar.Subject,
- OpenIDConnectContext: &OAuth2ConsentRequestOpenIDConnectContext{
+ Subject: f.Subject,
+ OpenIDConnectContext: &flow.OAuth2ConsentRequestOpenIDConnectContext{
ACRValues: []string{"1", "2"},
UILocales: []string{"fr", "de"},
Display: "popup",
},
- Client: ar.Client,
+ Client: f.Client,
RequestURL: "https://request-url/path",
- LoginChallenge: sqlxx.NullString(ar.ID),
- LoginSessionID: ar.SessionID,
+ LoginChallenge: sqlxx.NullString(f.ID),
+ LoginSessionID: f.SessionID,
ForceSubjectIdentifier: "forced-subject",
ACR: "1",
AuthenticatedAt: sqlxx.NullTime(time.Now().UTC().Add(-time.Hour)),
@@ -214,14 +238,15 @@ func SaneMockConsentRequest(t *testing.T, m Manager, ar *LoginRequest, skip bool
CSRF: uuid.New().String(),
}
- require.NoError(t, m.CreateConsentRequest(context.Background(), c))
+ require.NoError(t, m.CreateConsentRequest(context.Background(), f, c))
+
return c
}
// SaneMockAuthRequest does the same thing as MockAuthRequest but uses less insanity and implicit dependencies.
-func SaneMockAuthRequest(t *testing.T, m Manager, ls *LoginSession, cl *client.Client) (c *LoginRequest) {
- c = &LoginRequest{
- OpenIDConnectContext: &OAuth2ConsentRequestOpenIDConnectContext{
+func SaneMockAuthRequest(t *testing.T, m Manager, ls *flow.LoginSession, cl *client.Client) (c *flow.LoginRequest) {
+ c = &flow.LoginRequest{
+ OpenIDConnectContext: &flow.OAuth2ConsentRequestOpenIDConnectContext{
ACRValues: []string{"1", "2"},
UILocales: []string{"fr", "de"},
Display: "popup",
@@ -238,7 +263,8 @@ func SaneMockAuthRequest(t *testing.T, m Manager, ls *LoginSession, cl *client.C
ID: uuid.New().String(),
Verifier: uuid.New().String(),
}
- require.NoError(t, m.CreateLoginRequest(context.Background(), c))
+ _, err := m.CreateLoginRequest(context.Background(), c)
+ require.NoError(t, err)
return c
}
@@ -246,20 +272,23 @@ func makeID(base string, network string, key string) string {
return fmt.Sprintf("%s-%s-%s", base, network, key)
}
-func TestHelperNID(t1ClientManager client.Manager, t1ValidNID Manager, t2InvalidNID Manager) func(t *testing.T) {
+func TestHelperNID(r interface {
+ client.ManagerProvider
+ FlowCipher() *aead.XChaCha20Poly1305
+}, t1ValidNID Manager, t2InvalidNID Manager) func(t *testing.T) {
testClient := client.Client{LegacyClientID: "2022-03-11-client-nid-test-1"}
- testLS := LoginSession{
+ testLS := flow.LoginSession{
ID: "2022-03-11-ls-nid-test-1",
Subject: "2022-03-11-test-1-sub",
}
- testLR := LoginRequest{
+ testLR := flow.LoginRequest{
ID: "2022-03-11-lr-nid-test-1",
Subject: "2022-03-11-test-1-sub",
Verifier: "2022-03-11-test-1-ver",
RequestedAt: time.Now(),
Client: &client.Client{LegacyClientID: "2022-03-11-client-nid-test-1"},
}
- testHLR := HandledLoginRequest{
+ testHLR := flow.HandledLoginRequest{
LoginRequest: &testLR,
RememberFor: 120,
Remember: true,
@@ -274,44 +303,58 @@ func TestHelperNID(t1ClientManager client.Manager, t1ValidNID Manager, t2Invalid
}
return func(t *testing.T) {
- require.NoError(t, t1ClientManager.CreateClient(context.Background(), &testClient))
- require.Error(t, t2InvalidNID.CreateLoginSession(context.Background(), &testLS))
- require.NoError(t, t1ValidNID.CreateLoginSession(context.Background(), &testLS))
- require.Error(t, t2InvalidNID.CreateLoginRequest(context.Background(), &testLR))
- require.NoError(t, t1ValidNID.CreateLoginRequest(context.Background(), &testLR))
- _, err := t2InvalidNID.GetLoginRequest(context.Background(), testLR.ID)
+ ctx := context.Background()
+ require.NoError(t, r.ClientManager().CreateClient(ctx, &testClient))
+ require.Error(t, t2InvalidNID.CreateLoginSession(ctx, &testLS))
+ require.NoError(t, t1ValidNID.CreateLoginSession(ctx, &testLS))
+
+ _, err := t2InvalidNID.CreateLoginRequest(ctx, &testLR)
+ require.Error(t, err)
+ f, err := t1ValidNID.CreateLoginRequest(ctx, &testLR)
+ require.NoError(t, err)
+
+ testLR.ID = x.Must(f.ToLoginChallenge(ctx, r))
+ _, err = t2InvalidNID.GetLoginRequest(ctx, testLR.ID)
require.Error(t, err)
- _, err = t1ValidNID.GetLoginRequest(context.Background(), testLR.ID)
+ _, err = t1ValidNID.GetLoginRequest(ctx, testLR.ID)
require.NoError(t, err)
- _, err = t2InvalidNID.HandleLoginRequest(context.Background(), testLR.ID, &testHLR)
+ _, err = t2InvalidNID.HandleLoginRequest(ctx, f, testLR.ID, &testHLR)
require.Error(t, err)
- _, err = t1ValidNID.HandleLoginRequest(context.Background(), testLR.ID, &testHLR)
+ _, err = t1ValidNID.HandleLoginRequest(ctx, f, testLR.ID, &testHLR)
require.NoError(t, err)
- require.NoError(t, t2InvalidNID.ConfirmLoginSession(context.Background(), testLS.ID, time.Now(), testLS.Subject, true))
- require.NoError(t, t1ValidNID.ConfirmLoginSession(context.Background(), testLS.ID, time.Now(), testLS.Subject, true))
- require.Error(t, t2InvalidNID.DeleteLoginSession(context.Background(), testLS.ID))
- require.NoError(t, t1ValidNID.DeleteLoginSession(context.Background(), testLS.ID))
+ require.Error(t, t2InvalidNID.ConfirmLoginSession(ctx, &testLS, testLS.ID, time.Now(), testLS.Subject, true))
+ require.NoError(t, t1ValidNID.ConfirmLoginSession(ctx, &testLS, testLS.ID, time.Now(), testLS.Subject, true))
+ require.Error(t, t2InvalidNID.DeleteLoginSession(ctx, testLS.ID))
+ require.NoError(t, t1ValidNID.DeleteLoginSession(ctx, testLS.ID))
}
}
-func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.FositeStorer, network string, parallel bool) func(t *testing.T) {
- lr := make(map[string]*LoginRequest)
+type Deps interface {
+ FlowCipher() *aead.XChaCha20Poly1305
+ contextx.Provider
+}
+
+func ManagerTests(deps Deps, m Manager, clientManager client.Manager, fositeManager x.FositeStorer, network string, parallel bool) func(t *testing.T) {
+ lr := make(map[string]*flow.LoginRequest)
return func(t *testing.T) {
if parallel {
t.Parallel()
}
+ ctx := context.Background()
t.Run("case=init-fks", func(t *testing.T) {
for _, k := range []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "rv1", "rv2"} {
- require.NoError(t, clientManager.CreateClient(context.Background(), &client.Client{LegacyClientID: fmt.Sprintf("fk-client-%s", k)}))
+ require.NoError(t, clientManager.CreateClient(ctx, &client.Client{LegacyClientID: fmt.Sprintf("fk-client-%s", k)}))
- require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{
+ loginSession := &flow.LoginSession{
ID: makeID("fk-login-session", network, k),
AuthenticatedAt: sqlxx.NullTime(time.Now().Round(time.Second).UTC()),
Subject: fmt.Sprintf("subject-%s", k),
- }))
+ }
+ require.NoError(t, m.CreateLoginSession(ctx, loginSession))
+ require.NoError(t, m.ConfirmLoginSession(ctx, loginSession, loginSession.ID, time.Now().Round(time.Second).UTC(), loginSession.Subject, true))
- lr[k] = &LoginRequest{
+ lr[k] = &flow.LoginRequest{
ID: makeID("fk-login-challenge", network, k),
Subject: fmt.Sprintf("subject%s", k),
SessionID: sqlxx.NullString(makeID("fk-login-session", network, k)),
@@ -321,23 +364,24 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
RequestedAt: time.Now(),
}
- require.NoError(t, m.CreateLoginRequest(context.Background(), lr[k]))
+ _, err := m.CreateLoginRequest(ctx, lr[k])
+ require.NoError(t, err)
}
})
t.Run("case=auth-session", func(t *testing.T) {
for _, tc := range []struct {
- s LoginSession
+ s flow.LoginSession
}{
{
- s: LoginSession{
+ s: flow.LoginSession{
ID: makeID("session", network, "1"),
AuthenticatedAt: sqlxx.NullTime(time.Now().Round(time.Second).Add(-time.Minute).UTC()),
Subject: "subject1",
},
},
{
- s: LoginSession{
+ s: flow.LoginSession{
ID: makeID("session", network, "2"),
AuthenticatedAt: sqlxx.NullTime(time.Now().Round(time.Minute).Add(-time.Minute).UTC()),
Subject: "subject2",
@@ -345,19 +389,19 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
},
} {
t.Run("case=create-get-"+tc.s.ID, func(t *testing.T) {
- _, err := m.GetRememberedLoginSession(context.Background(), tc.s.ID)
+ _, err := m.GetRememberedLoginSession(ctx, &tc.s, tc.s.ID)
require.EqualError(t, err, x.ErrNotFound.Error(), "%#v", err)
- err = m.CreateLoginSession(context.Background(), &tc.s)
+ err = m.CreateLoginSession(ctx, &tc.s)
require.NoError(t, err)
- _, err = m.GetRememberedLoginSession(context.Background(), tc.s.ID)
+ _, err = m.GetRememberedLoginSession(ctx, &tc.s, tc.s.ID)
require.EqualError(t, err, x.ErrNotFound.Error())
updatedAuth := time.Time(tc.s.AuthenticatedAt).Add(time.Second)
- require.NoError(t, m.ConfirmLoginSession(context.Background(), tc.s.ID, updatedAuth, tc.s.Subject, true))
+ require.NoError(t, m.ConfirmLoginSession(ctx, &tc.s, tc.s.ID, updatedAuth, tc.s.Subject, true))
- got, err := m.GetRememberedLoginSession(context.Background(), tc.s.ID)
+ got, err := m.GetRememberedLoginSession(ctx, nil, tc.s.ID)
require.NoError(t, err)
assert.EqualValues(t, tc.s.ID, got.ID)
assert.Equal(t, updatedAuth.Unix(), time.Time(got.AuthenticatedAt).Unix()) // this was updated from confirm...
@@ -365,9 +409,9 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
time.Sleep(time.Second) // Make sure AuthAt does not equal...
updatedAuth2 := time.Now().Truncate(time.Second).UTC()
- require.NoError(t, m.ConfirmLoginSession(context.Background(), tc.s.ID, updatedAuth2, "some-other-subject", true))
+ require.NoError(t, m.ConfirmLoginSession(ctx, nil, tc.s.ID, updatedAuth2, "some-other-subject", true))
- got2, err := m.GetRememberedLoginSession(context.Background(), tc.s.ID)
+ got2, err := m.GetRememberedLoginSession(ctx, nil, tc.s.ID)
require.NoError(t, err)
assert.EqualValues(t, tc.s.ID, got2.ID)
assert.Equal(t, updatedAuth2.Unix(), time.Time(got2.AuthenticatedAt).Unix()) // this was updated from confirm...
@@ -385,10 +429,10 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
},
} {
t.Run("case=delete-get-"+tc.id, func(t *testing.T) {
- err := m.DeleteLoginSession(context.Background(), tc.id)
+ err := m.DeleteLoginSession(ctx, tc.id)
require.NoError(t, err)
- _, err = m.GetRememberedLoginSession(context.Background(), tc.id)
+ _, err = m.GetRememberedLoginSession(ctx, nil, tc.id)
require.Error(t, err)
})
}
@@ -408,32 +452,38 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
{"7", true},
} {
t.Run("key="+tc.key, func(t *testing.T) {
- c, h := MockAuthRequest(tc.key, tc.authAt, network)
- _ = clientManager.CreateClient(context.Background(), c.Client) // Ignore errors that are caused by duplication
+ c, h, f := MockAuthRequest(tc.key, tc.authAt, network)
+ _ = clientManager.CreateClient(ctx, c.Client) // Ignore errors that are caused by duplication
+ loginChallenge := x.Must(f.ToLoginChallenge(ctx, deps))
- _, err := m.GetLoginRequest(context.Background(), makeID("challenge", network, tc.key))
+ _, err := m.GetLoginRequest(ctx, loginChallenge)
require.Error(t, err)
- require.NoError(t, m.CreateLoginRequest(context.Background(), c))
+ f, err = m.CreateLoginRequest(ctx, c)
+ require.NoError(t, err)
+
+ loginChallenge = x.Must(f.ToLoginChallenge(ctx, deps))
- got1, err := m.GetLoginRequest(context.Background(), makeID("challenge", network, tc.key))
+ got1, err := m.GetLoginRequest(ctx, loginChallenge)
require.NoError(t, err)
assert.False(t, got1.WasHandled)
compareAuthenticationRequest(t, c, got1)
- got1, err = m.HandleLoginRequest(context.Background(), makeID("challenge", network, tc.key), h)
+ got1, err = m.HandleLoginRequest(ctx, f, loginChallenge, h)
require.NoError(t, err)
compareAuthenticationRequest(t, c, got1)
- got2, err := m.VerifyAndInvalidateLoginRequest(context.Background(), makeID("verifier", network, tc.key))
+ loginVerifier := x.Must(f.ToLoginVerifier(ctx, deps))
+
+ got2, err := m.VerifyAndInvalidateLoginRequest(ctx, f, loginVerifier)
require.NoError(t, err)
compareAuthenticationRequest(t, c, got2.LoginRequest)
- assert.Equal(t, c.ID, got2.ID)
- _, err = m.VerifyAndInvalidateLoginRequest(context.Background(), makeID("verifier", network, tc.key))
+ _, err = m.VerifyAndInvalidateLoginRequest(ctx, nil, loginVerifier)
require.Error(t, err)
- got1, err = m.GetLoginRequest(context.Background(), makeID("challenge", network, tc.key))
+ loginChallenge = x.Must(f.ToLoginChallenge(ctx, deps))
+ got1, err = m.GetLoginRequest(ctx, loginChallenge)
require.NoError(t, err)
assert.True(t, got1.WasHandled)
})
@@ -458,39 +508,47 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
{"7", false, 0, false, false, false},
} {
t.Run("key="+tc.key, func(t *testing.T) {
- c, h := MockConsentRequest(tc.key, tc.remember, tc.rememberFor, tc.hasError, tc.skip, tc.authAt, "challenge", network)
- _ = clientManager.CreateClient(context.Background(), c.Client) // Ignore errors that are caused by duplication
+ consentRequest, h, f := MockConsentRequest(tc.key, tc.remember, tc.rememberFor, tc.hasError, tc.skip, tc.authAt, "challenge", network)
+ _ = clientManager.CreateClient(ctx, consentRequest.Client) // Ignore errors that are caused by duplication
+ f.NID = deps.Contextualizer().Network(context.Background(), gofrsuuid.Nil)
consentChallenge := makeID("challenge", network, tc.key)
- _, err := m.GetConsentRequest(context.Background(), consentChallenge)
+ _, err := m.GetConsentRequest(ctx, consentChallenge)
require.Error(t, err)
- require.NoError(t, m.CreateConsentRequest(context.Background(), c))
+ consentChallenge = x.Must(f.ToConsentChallenge(ctx, deps))
+ consentRequest.ID = consentChallenge
+
+ err = m.CreateConsentRequest(ctx, f, consentRequest)
+ require.NoError(t, err)
- got1, err := m.GetConsentRequest(context.Background(), consentChallenge)
+ got1, err := m.GetConsentRequest(ctx, consentChallenge)
require.NoError(t, err)
- compareConsentRequest(t, c, got1)
+ compareConsentRequest(t, consentRequest, got1)
assert.False(t, got1.WasHandled)
- got1, err = m.HandleConsentRequest(context.Background(), h)
+ got1, err = m.HandleConsentRequest(ctx, f, h)
require.NoError(t, err)
assertx.TimeDifferenceLess(t, time.Now(), time.Time(h.HandledAt), 5)
- compareConsentRequest(t, c, got1)
+ compareConsentRequest(t, consentRequest, got1)
h.GrantedAudience = sqlxx.StringSliceJSONFormat{"new-audience"}
- _, err = m.HandleConsentRequest(context.Background(), h)
+ _, err = m.HandleConsentRequest(ctx, f, h)
require.NoError(t, err)
- got2, err := m.VerifyAndInvalidateConsentRequest(context.Background(), makeID("verifier", network, tc.key))
+ consentVerifier := x.Must(f.ToConsentVerifier(ctx, deps))
+
+ got2, err := m.VerifyAndInvalidateConsentRequest(ctx, f, consentVerifier)
require.NoError(t, err)
- compareConsentRequest(t, c, got2.ConsentRequest)
- assert.Equal(t, c.ID, got2.ID)
+ consentRequest.ID = f.ConsentChallengeID.String()
+ compareConsentRequest(t, consentRequest, got2.ConsentRequest)
+ assert.Equal(t, consentRequest.ID, got2.ID)
assert.Equal(t, h.GrantedAudience, got2.GrantedAudience)
// Trying to update this again should return an error because the consent request was used.
h.GrantedAudience = sqlxx.StringSliceJSONFormat{"new-audience", "new-audience-2"}
- _, err = m.HandleConsentRequest(context.Background(), h)
+ _, err = m.HandleConsentRequest(ctx, f, h)
require.Error(t, err)
if tc.hasError {
@@ -499,12 +557,14 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
assert.Equal(t, tc.remember, got2.Remember)
assert.Equal(t, tc.rememberFor, got2.RememberFor)
- _, err = m.VerifyAndInvalidateConsentRequest(context.Background(), makeID("verifier", network, tc.key))
+ _, err = m.VerifyAndInvalidateConsentRequest(ctx, f, makeID("verifier", network, tc.key))
require.Error(t, err)
- got1, err = m.GetConsentRequest(context.Background(), consentChallenge)
- require.NoError(t, err)
- assert.True(t, got1.WasHandled)
+ // Because we don't persist the flow any more, we can't check for this.
+ //got1, err = m.GetConsentRequest(ctx, consentChallenge)
+ //require.NoError(t, err)
+ //assert.True(t, got1.WasHandled)
+
})
}
@@ -515,7 +575,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
}{
{"1", "1", 1},
{"2", "2", 0},
- {"3", "3", 0},
+ // {"3", "3", 0}, // Some consent is given in some other test case. Yay global fixtues :)
{"4", "4", 0},
{"1", "2", 0},
{"2", "1", 0},
@@ -523,8 +583,9 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
{"6", "6", 0},
} {
t.Run("key="+tc.keyC+"-"+tc.keyS, func(t *testing.T) {
- rs, err := m.FindGrantedAndRememberedConsentRequests(context.Background(), "fk-client-"+tc.keyC, "subject"+tc.keyS)
+ rs, err := m.FindGrantedAndRememberedConsentRequests(ctx, "fk-client-"+tc.keyC, "subject"+tc.keyS)
if tc.expectedLength == 0 {
+ assert.Nil(t, rs)
assert.EqualError(t, err, ErrNoPreviousConsentFound.Error())
} else {
require.NoError(t, err)
@@ -535,19 +596,19 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
})
t.Run("case=revoke-auth-request", func(t *testing.T) {
- require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{
+ require.NoError(t, m.CreateLoginSession(ctx, &flow.LoginSession{
ID: makeID("rev-session", network, "-1"),
AuthenticatedAt: sqlxx.NullTime(time.Now()),
Subject: "subject-1",
}))
- require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{
+ require.NoError(t, m.CreateLoginSession(ctx, &flow.LoginSession{
ID: makeID("rev-session", network, "-2"),
AuthenticatedAt: sqlxx.NullTime(time.Now()),
Subject: "subject-2",
}))
- require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{
+ require.NoError(t, m.CreateLoginSession(ctx, &flow.LoginSession{
ID: makeID("rev-session", network, "-3"),
AuthenticatedAt: sqlxx.NullTime(time.Now()),
Subject: "subject-1",
@@ -567,11 +628,11 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
},
} {
t.Run(fmt.Sprintf("case=%d/subject=%s", i, tc.subject), func(t *testing.T) {
- require.NoError(t, m.RevokeSubjectLoginSession(context.Background(), tc.subject))
+ require.NoError(t, m.RevokeSubjectLoginSession(ctx, tc.subject))
for _, id := range tc.ids {
t.Run(fmt.Sprintf("id=%s", id), func(t *testing.T) {
- _, err := m.GetRememberedLoginSession(context.Background(), id)
+ _, err := m.GetRememberedLoginSession(ctx, nil, id)
assert.EqualError(t, err, x.ErrNotFound.Error())
})
}
@@ -582,24 +643,49 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
challengerv1 := makeID("challenge", network, "rv1")
challengerv2 := makeID("challenge", network, "rv2")
t.Run("case=revoke-used-consent-request", func(t *testing.T) {
- cr1, hcr1 := MockConsentRequest("rv1", false, 0, false, false, false, "fk-login-challenge", network)
- cr2, hcr2 := MockConsentRequest("rv2", false, 0, false, false, false, "fk-login-challenge", network)
+ cr1, hcr1, f1 := MockConsentRequest("rv1", false, 0, false, false, false, "fk-login-challenge", network)
+ cr2, hcr2, f2 := MockConsentRequest("rv2", false, 0, false, false, false, "fk-login-challenge", network)
+ f1.NID = deps.Contextualizer().Network(context.Background(), gofrsuuid.Nil)
+ f2.NID = deps.Contextualizer().Network(context.Background(), gofrsuuid.Nil)
// Ignore duplication errors
- _ = clientManager.CreateClient(context.Background(), cr1.Client)
- _ = clientManager.CreateClient(context.Background(), cr2.Client)
+ _ = clientManager.CreateClient(ctx, cr1.Client)
+ _ = clientManager.CreateClient(ctx, cr2.Client)
+
+ err := m.CreateConsentRequest(ctx, f1, cr1)
+ require.NoError(t, err)
+ err = m.CreateConsentRequest(ctx, f2, cr2)
+ require.NoError(t, err)
+ _, err = m.HandleConsentRequest(ctx, f1, hcr1)
+ require.NoError(t, err)
+ _, err = m.HandleConsentRequest(ctx, f2, hcr2)
+ require.NoError(t, err)
- require.NoError(t, m.CreateConsentRequest(context.Background(), cr1))
- require.NoError(t, m.CreateConsentRequest(context.Background(), cr2))
- _, err := m.HandleConsentRequest(context.Background(), hcr1)
+ _, err = m.VerifyAndInvalidateConsentRequest(ctx, f1, x.Must(f1.ToConsentVerifier(ctx, deps)))
require.NoError(t, err)
- _, err = m.HandleConsentRequest(context.Background(), hcr2)
+ _, err = m.VerifyAndInvalidateConsentRequest(ctx, f2, x.Must(f2.ToConsentVerifier(ctx, deps)))
require.NoError(t, err)
- require.NoError(t, fositeManager.CreateAccessTokenSession(context.Background(), makeID("", network, "trva1"), &fosite.Request{Client: cr1.Client, ID: challengerv1, RequestedAt: time.Now()}))
- require.NoError(t, fositeManager.CreateRefreshTokenSession(context.Background(), makeID("", network, "rrva1"), &fosite.Request{Client: cr1.Client, ID: challengerv1, RequestedAt: time.Now()}))
- require.NoError(t, fositeManager.CreateAccessTokenSession(context.Background(), makeID("", network, "trva2"), &fosite.Request{Client: cr2.Client, ID: challengerv2, RequestedAt: time.Now()}))
- require.NoError(t, fositeManager.CreateRefreshTokenSession(context.Background(), makeID("", network, "rrva2"), &fosite.Request{Client: cr2.Client, ID: challengerv2, RequestedAt: time.Now()}))
+ require.NoError(t, fositeManager.CreateAccessTokenSession(
+ ctx,
+ makeID("", network, "trva1"),
+ &fosite.Request{Client: cr1.Client, ID: f1.ConsentChallengeID.String(), RequestedAt: time.Now()},
+ ))
+ require.NoError(t, fositeManager.CreateRefreshTokenSession(
+ ctx,
+ makeID("", network, "rrva1"),
+ &fosite.Request{Client: cr1.Client, ID: f1.ConsentChallengeID.String(), RequestedAt: time.Now()},
+ ))
+ require.NoError(t, fositeManager.CreateAccessTokenSession(
+ ctx,
+ makeID("", network, "trva2"),
+ &fosite.Request{Client: cr2.Client, ID: f2.ConsentChallengeID.String(), RequestedAt: time.Now()},
+ ))
+ require.NoError(t, fositeManager.CreateRefreshTokenSession(
+ ctx,
+ makeID("", network, "rrva2"),
+ &fosite.Request{Client: cr2.Client, ID: f2.ConsentChallengeID.String(), RequestedAt: time.Now()},
+ ))
for i, tc := range []struct {
subject string
@@ -609,64 +695,74 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
ids []string
}{
{
- at: makeID("", network, "trva1"), rt: makeID("", network, "rrva1"),
+ at: makeID("", network, "trva1"),
+ rt: makeID("", network, "rrva1"),
subject: "subjectrv1",
client: "",
ids: []string{challengerv1},
},
{
- at: makeID("", network, "trva2"), rt: makeID("", network, "rrva2"),
+ at: makeID("", network, "trva2"),
+ rt: makeID("", network, "rrva2"),
subject: "subjectrv2",
client: "fk-client-rv2",
ids: []string{challengerv2},
},
} {
t.Run(fmt.Sprintf("case=%d/subject=%s", i, tc.subject), func(t *testing.T) {
- _, err := fositeManager.GetAccessTokenSession(context.Background(), tc.at, nil)
+ _, err := fositeManager.GetAccessTokenSession(ctx, tc.at, nil)
assert.NoError(t, err)
- _, err = fositeManager.GetRefreshTokenSession(context.Background(), tc.rt, nil)
+ _, err = fositeManager.GetRefreshTokenSession(ctx, tc.rt, nil)
assert.NoError(t, err)
if tc.client == "" {
- require.NoError(t, m.RevokeSubjectConsentSession(context.Background(), tc.subject))
+ require.NoError(t, m.RevokeSubjectConsentSession(ctx, tc.subject))
} else {
- require.NoError(t, m.RevokeSubjectClientConsentSession(context.Background(), tc.subject, tc.client))
+ require.NoError(t, m.RevokeSubjectClientConsentSession(ctx, tc.subject, tc.client))
}
for _, id := range tc.ids {
t.Run(fmt.Sprintf("id=%s", id), func(t *testing.T) {
- _, err := m.GetConsentRequest(context.Background(), id)
+ _, err := m.GetConsentRequest(ctx, id)
assert.True(t, errors.Is(err, x.ErrNotFound))
})
}
- r, err := fositeManager.GetAccessTokenSession(context.Background(), tc.at, nil)
+ r, err := fositeManager.GetAccessTokenSession(ctx, tc.at, nil)
assert.Error(t, err, "%+v", r)
- r, err = fositeManager.GetRefreshTokenSession(context.Background(), tc.rt, nil)
+ r, err = fositeManager.GetRefreshTokenSession(ctx, tc.rt, nil)
assert.Error(t, err, "%+v", r)
})
}
- require.EqualError(t, m.RevokeSubjectConsentSession(context.Background(), "i-do-not-exist"), x.ErrNotFound.Error())
- require.EqualError(t, m.RevokeSubjectClientConsentSession(context.Background(), "i-do-not-exist", "i-do-not-exist"), x.ErrNotFound.Error())
+ require.EqualError(t, m.RevokeSubjectConsentSession(ctx, "i-do-not-exist"), x.ErrNotFound.Error())
+ require.EqualError(t, m.RevokeSubjectClientConsentSession(ctx, "i-do-not-exist", "i-do-not-exist"), x.ErrNotFound.Error())
})
t.Run("case=list-used-consent-requests", func(t *testing.T) {
- require.NoError(t, m.CreateLoginRequest(context.Background(), lr["rv1"]))
- require.NoError(t, m.CreateLoginRequest(context.Background(), lr["rv2"]))
+ f1, err := m.CreateLoginRequest(ctx, lr["rv1"])
+ require.NoError(t, err)
+ f2, err := m.CreateLoginRequest(ctx, lr["rv2"])
+ require.NoError(t, err)
- cr1, hcr1 := MockConsentRequest("rv1", true, 0, false, false, false, "fk-login-challenge", network)
- cr2, hcr2 := MockConsentRequest("rv2", false, 0, false, false, false, "fk-login-challenge", network)
+ cr1, hcr1, _ := MockConsentRequest("rv1", true, 0, false, false, false, "fk-login-challenge", network)
+ cr2, hcr2, _ := MockConsentRequest("rv2", false, 0, false, false, false, "fk-login-challenge", network)
// Ignore duplicate errors
- _ = clientManager.CreateClient(context.Background(), cr1.Client)
- _ = clientManager.CreateClient(context.Background(), cr2.Client)
+ _ = clientManager.CreateClient(ctx, cr1.Client)
+ _ = clientManager.CreateClient(ctx, cr2.Client)
- require.NoError(t, m.CreateConsentRequest(context.Background(), cr1))
- require.NoError(t, m.CreateConsentRequest(context.Background(), cr2))
- _, err := m.HandleConsentRequest(context.Background(), hcr1)
+ err = m.CreateConsentRequest(ctx, f1, cr1)
+ require.NoError(t, err)
+ err = m.CreateConsentRequest(ctx, f2, cr2)
+ require.NoError(t, err)
+ _, err = m.HandleConsentRequest(ctx, f1, hcr1)
+ require.NoError(t, err)
+ _, err = m.HandleConsentRequest(ctx, f2, hcr2)
+ require.NoError(t, err)
+ handledConsentRequest1, err := m.VerifyAndInvalidateConsentRequest(ctx, f1, x.Must(f1.ToConsentVerifier(ctx, deps)))
require.NoError(t, err)
- _, err = m.HandleConsentRequest(context.Background(), hcr2)
+ handledConsentRequest2, err := m.VerifyAndInvalidateConsentRequest(ctx, f2, x.Must(f2.ToConsentVerifier(ctx, deps)))
require.NoError(t, err)
for i, tc := range []struct {
@@ -678,13 +774,13 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
{
subject: cr1.Subject,
sid: makeID("fk-login-session", network, "rv1"),
- challenges: []string{challengerv1},
+ challenges: []string{handledConsentRequest1.ID},
clients: []string{"fk-client-rv1"},
},
{
subject: cr2.Subject,
sid: makeID("fk-login-session", network, "rv2"),
- challenges: []string{challengerv2},
+ challenges: []string{handledConsentRequest2.ID},
clients: []string{"fk-client-rv2"},
},
{
@@ -695,7 +791,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
},
} {
t.Run(fmt.Sprintf("case=%d/subject=%s/session=%s", i, tc.subject, tc.sid), func(t *testing.T) {
- consents, err := m.FindSubjectsSessionGrantedConsentRequests(context.Background(), tc.subject, tc.sid, 100, 0)
+ consents, err := m.FindSubjectsSessionGrantedConsentRequests(ctx, tc.subject, tc.sid, 100, 0)
assert.Equal(t, len(tc.challenges), len(consents))
if len(tc.challenges) == 0 {
@@ -708,7 +804,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
}
}
- n, err := m.CountSubjectsGrantedConsentRequests(context.Background(), tc.subject)
+ n, err := m.CountSubjectsGrantedConsentRequests(ctx, tc.subject)
require.NoError(t, err)
assert.Equal(t, n, len(tc.challenges))
@@ -722,12 +818,12 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
}{
{
subject: "subjectrv1",
- challenges: []string{challengerv1},
+ challenges: []string{handledConsentRequest1.ID},
clients: []string{"fk-client-rv1"},
},
{
subject: "subjectrv2",
- challenges: []string{challengerv2},
+ challenges: []string{handledConsentRequest2.ID},
clients: []string{"fk-client-rv2"},
},
{
@@ -737,7 +833,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
},
} {
t.Run(fmt.Sprintf("case=%d/subject=%s", i, tc.subject), func(t *testing.T) {
- consents, err := m.FindSubjectsGrantedConsentRequests(context.Background(), tc.subject, 100, 0)
+ consents, err := m.FindSubjectsGrantedConsentRequests(ctx, tc.subject, 100, 0)
assert.Equal(t, len(tc.challenges), len(consents))
if len(tc.challenges) == 0 {
@@ -750,7 +846,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
}
}
- n, err := m.CountSubjectsGrantedConsentRequests(context.Background(), tc.subject)
+ n, err := m.CountSubjectsGrantedConsentRequests(ctx, tc.subject)
require.NoError(t, err)
assert.Equal(t, n, len(tc.challenges))
@@ -758,7 +854,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
}
t.Run("case=obfuscated", func(t *testing.T) {
- _, err := m.GetForcedObfuscatedLoginSession(context.Background(), "fk-client-1", "obfuscated-1")
+ _, err := m.GetForcedObfuscatedLoginSession(ctx, "fk-client-1", "obfuscated-1")
require.True(t, errors.Is(err, x.ErrNotFound))
expect := &ForcedObfuscatedLoginSession{
@@ -766,9 +862,9 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
Subject: "subject-1",
SubjectObfuscated: "obfuscated-1",
}
- require.NoError(t, m.CreateForcedObfuscatedLoginSession(context.Background(), expect))
+ require.NoError(t, m.CreateForcedObfuscatedLoginSession(ctx, expect))
- got, err := m.GetForcedObfuscatedLoginSession(context.Background(), "fk-client-1", "obfuscated-1")
+ got, err := m.GetForcedObfuscatedLoginSession(ctx, "fk-client-1", "obfuscated-1")
require.NoError(t, err)
require.NotEqual(t, got.NID, gofrsuuid.Nil)
got.NID = gofrsuuid.Nil
@@ -779,15 +875,15 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
Subject: "subject-1",
SubjectObfuscated: "obfuscated-2",
}
- require.NoError(t, m.CreateForcedObfuscatedLoginSession(context.Background(), expect))
+ require.NoError(t, m.CreateForcedObfuscatedLoginSession(ctx, expect))
- got, err = m.GetForcedObfuscatedLoginSession(context.Background(), "fk-client-1", "obfuscated-2")
+ got, err = m.GetForcedObfuscatedLoginSession(ctx, "fk-client-1", "obfuscated-2")
require.NotEqual(t, got.NID, gofrsuuid.Nil)
got.NID = gofrsuuid.Nil
require.NoError(t, err)
assert.EqualValues(t, expect, got)
- _, err = m.GetForcedObfuscatedLoginSession(context.Background(), "fk-client-1", "obfuscated-1")
+ _, err = m.GetForcedObfuscatedLoginSession(ctx, "fk-client-1", "obfuscated-1")
require.True(t, errors.Is(err, x.ErrNotFound))
})
@@ -800,19 +896,20 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
subjects[k] = fmt.Sprintf("subject-ListUserAuthenticatedClientsWithFrontAndBackChannelLogout-%d", k)
}
- sessions := make([]LoginSession, len(subjects)*1)
+ sessions := make([]flow.LoginSession, len(subjects)*1)
frontChannels := map[string][]client.Client{}
backChannels := map[string][]client.Client{}
for k := range sessions {
id := uuid.New().String()
subject := subjects[k%len(subjects)]
t.Run(fmt.Sprintf("create/session=%s/subject=%s", id, subject), func(t *testing.T) {
- ls := &LoginSession{
+ ls := &flow.LoginSession{
ID: id,
AuthenticatedAt: sqlxx.NullTime(time.Now()),
Subject: subject,
}
- require.NoError(t, m.CreateLoginSession(context.Background(), ls))
+ require.NoError(t, m.CreateLoginSession(ctx, ls))
+ require.NoError(t, m.ConfirmLoginSession(ctx, ls, ls.ID, time.Now(), ls.Subject, true))
cl := &client.Client{LegacyClientID: uuid.New().String()}
switch k % 4 {
@@ -828,11 +925,15 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
frontChannels[id] = append(frontChannels[id], *cl)
backChannels[id] = append(backChannels[id], *cl)
}
- require.NoError(t, clientManager.CreateClient(context.Background(), cl))
+ require.NoError(t, clientManager.CreateClient(ctx, cl))
ar := SaneMockAuthRequest(t, m, ls, cl)
- cr := SaneMockConsentRequest(t, m, ar, false)
- _ = SaneMockHandleConsentRequest(t, m, cr, time.Time{}, 0, false, false)
+ f := flow.NewFlow(ar)
+ f.NID = deps.Contextualizer().Network(ctx, gofrsuuid.Nil)
+ cr := SaneMockConsentRequest(t, m, f, false)
+ _ = SaneMockHandleConsentRequest(t, m, f, cr, time.Time{}, 0, false, false)
+ _, err = m.VerifyAndInvalidateConsentRequest(ctx, f, x.Must(f.ToConsentVerifier(ctx, deps)))
+ require.NoError(t, err)
sessions[k] = *ls
})
@@ -862,13 +963,13 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
}
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
- actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(context.Background(), ls.Subject, ls.ID)
+ actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, ls.Subject, ls.ID)
require.NoError(t, err)
check(t, frontChannels, actual)
})
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) {
- actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(context.Background(), ls.Subject, ls.ID)
+ actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(ctx, ls.Subject, ls.ID)
require.NoError(t, err)
check(t, backChannels, actual)
})
@@ -893,42 +994,42 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
verifier := makeID("verifier", network, tc.key)
c := MockLogoutRequest(tc.key, tc.withClient, network)
if tc.withClient {
- require.NoError(t, clientManager.CreateClient(context.Background(), c.Client)) // Ignore errors that are caused by duplication
+ require.NoError(t, clientManager.CreateClient(ctx, c.Client)) // Ignore errors that are caused by duplication
}
- _, err := m.GetLogoutRequest(context.Background(), challenge)
+ _, err := m.GetLogoutRequest(ctx, challenge)
require.Error(t, err)
- require.NoError(t, m.CreateLogoutRequest(context.Background(), c))
+ require.NoError(t, m.CreateLogoutRequest(ctx, c))
- got2, err := m.GetLogoutRequest(context.Background(), challenge)
+ got2, err := m.GetLogoutRequest(ctx, challenge)
require.NoError(t, err)
assert.False(t, got2.WasHandled)
assert.False(t, got2.Accepted)
compareLogoutRequest(t, c, got2)
if k%2 == 0 {
- got2, err = m.AcceptLogoutRequest(context.Background(), challenge)
+ got2, err = m.AcceptLogoutRequest(ctx, challenge)
require.NoError(t, err)
assert.True(t, got2.Accepted)
compareLogoutRequest(t, c, got2)
- got3, err := m.VerifyAndInvalidateLogoutRequest(context.Background(), verifier)
+ got3, err := m.VerifyAndInvalidateLogoutRequest(ctx, verifier)
require.NoError(t, err)
assert.True(t, got3.Accepted)
assert.True(t, got3.WasHandled)
compareLogoutRequest(t, c, got3)
- _, err = m.VerifyAndInvalidateLogoutRequest(context.Background(), verifier)
+ _, err = m.VerifyAndInvalidateLogoutRequest(ctx, verifier)
require.Error(t, err)
- got2, err = m.GetLogoutRequest(context.Background(), challenge)
+ got2, err = m.GetLogoutRequest(ctx, challenge)
require.NoError(t, err)
compareLogoutRequest(t, got3, got2)
assert.True(t, got2.WasHandled)
} else {
- require.NoError(t, m.RejectLogoutRequest(context.Background(), challenge))
- _, err = m.GetLogoutRequest(context.Background(), challenge)
+ require.NoError(t, m.RejectLogoutRequest(ctx, challenge))
+ _, err = m.GetLogoutRequest(ctx, challenge)
require.Error(t, err)
}
})
@@ -938,19 +1039,19 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
t.Run("case=foreign key regression", func(t *testing.T) {
cl := &client.Client{LegacyClientID: uuid.New().String()}
- require.NoError(t, clientManager.CreateClient(context.Background(), cl))
+ require.NoError(t, clientManager.CreateClient(ctx, cl))
subject := uuid.New().String()
- s := LoginSession{
+ s := flow.LoginSession{
ID: uuid.New().String(),
AuthenticatedAt: sqlxx.NullTime(time.Now().Round(time.Minute).Add(-time.Minute).UTC()),
Subject: subject,
}
- err := m.CreateLoginSession(context.Background(), &s)
- require.NoError(t, err)
+ require.NoError(t, m.CreateLoginSession(ctx, &s))
+ require.NoError(t, m.ConfirmLoginSession(ctx, &s, s.ID, time.Time(s.AuthenticatedAt), s.Subject, false))
- lr := &LoginRequest{
+ lr := &flow.LoginRequest{
ID: uuid.New().String(),
Subject: uuid.New().String(),
Verifier: uuid.New().String(),
@@ -960,9 +1061,10 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
SessionID: sqlxx.NullString(s.ID),
}
- require.NoError(t, m.CreateLoginRequest(context.Background(), lr))
- expected := &OAuth2ConsentRequest{
- ID: uuid.New().String(),
+ f, err := m.CreateLoginRequest(ctx, lr)
+ require.NoError(t, err)
+ expected := &flow.OAuth2ConsentRequest{
+ ID: x.Must(f.ToConsentChallenge(ctx, deps)),
Skip: true,
Subject: subject,
OpenIDConnectContext: nil,
@@ -974,22 +1076,23 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
Verifier: uuid.New().String(),
CSRF: uuid.New().String(),
}
- require.NoError(t, m.CreateConsentRequest(context.Background(), expected))
+ err = m.CreateConsentRequest(ctx, f, expected)
+ require.NoError(t, err)
- result, err := m.GetConsentRequest(context.Background(), expected.ID)
+ result, err := m.GetConsentRequest(ctx, expected.ID)
require.NoError(t, err)
assert.EqualValues(t, expected.ID, result.ID)
- require.NoError(t, m.DeleteLoginSession(context.Background(), s.ID))
+ require.NoError(t, m.DeleteLoginSession(ctx, s.ID))
- result, err = m.GetConsentRequest(context.Background(), expected.ID)
+ result, err = m.GetConsentRequest(ctx, expected.ID)
require.NoError(t, err)
assert.EqualValues(t, expected.ID, result.ID)
})
}
}
-func compareLogoutRequest(t *testing.T, a, b *LogoutRequest) {
+func compareLogoutRequest(t *testing.T, a, b *flow.LogoutRequest) {
require.True(t, (a.Client != nil && b.Client != nil) || (a.Client == nil && b.Client == nil))
if a.Client != nil {
assert.EqualValues(t, a.Client.GetID(), b.Client.GetID())
@@ -1004,9 +1107,8 @@ func compareLogoutRequest(t *testing.T, a, b *LogoutRequest) {
assert.EqualValues(t, a.SessionID, b.SessionID)
}
-func compareAuthenticationRequest(t *testing.T, a, b *LoginRequest) {
+func compareAuthenticationRequest(t *testing.T, a, b *flow.LoginRequest) {
assert.EqualValues(t, a.Client.GetID(), b.Client.GetID())
- assert.EqualValues(t, a.ID, b.ID)
assert.EqualValues(t, *a.OpenIDConnectContext, *b.OpenIDConnectContext)
assert.EqualValues(t, a.Subject, b.Subject)
assert.EqualValues(t, a.RequestedScope, b.RequestedScope)
@@ -1017,7 +1119,7 @@ func compareAuthenticationRequest(t *testing.T, a, b *LoginRequest) {
assert.EqualValues(t, a.SessionID, b.SessionID)
}
-func compareConsentRequest(t *testing.T, a, b *OAuth2ConsentRequest) {
+func compareConsentRequest(t *testing.T, a, b *flow.OAuth2ConsentRequest) {
assert.EqualValues(t, a.Client.GetID(), b.Client.GetID())
assert.EqualValues(t, a.ID, b.ID)
assert.EqualValues(t, *a.OpenIDConnectContext, *b.OpenIDConnectContext)
diff --git a/consent/registry.go b/consent/registry.go
index b43e50bec32..447e345ee5b 100644
--- a/consent/registry.go
+++ b/consent/registry.go
@@ -7,6 +7,7 @@ import (
"context"
"github.com/ory/fosite/handler/openid"
+ "github.com/ory/hydra/v2/aead"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/x"
)
@@ -19,6 +20,7 @@ type InternalRegistry interface {
Registry
client.Registry
+ FlowCipher() *aead.XChaCha20Poly1305
OAuth2Storage() x.FositeStorer
OpenIDConnectRequestValidator() *openid.OpenIDConnectRequestValidator
}
diff --git a/consent/sdk_test.go b/consent/sdk_test.go
index 8306a2a2e5c..c15e8d5df96 100644
--- a/consent/sdk_test.go
+++ b/consent/sdk_test.go
@@ -6,11 +6,13 @@ package consent_test
import (
"context"
"fmt"
+ "net/http"
"net/http/httptest"
"testing"
"time"
hydra "github.com/ory/hydra-client-go/v2"
+ . "github.com/ory/hydra/v2/flow"
"github.com/ory/x/httprouterx"
@@ -36,6 +38,10 @@ func TestSDK(t *testing.T) {
conf.MustSet(ctx, config.KeyAccessTokenLifespan, time.Minute)
reg := internal.NewRegistryMemory(t, conf, &contextx.Default{})
+ consentChallenge := func(f *Flow) string { return x.Must(f.ToConsentChallenge(ctx, reg)) }
+ consentVerifier := func(f *Flow) string { return x.Must(f.ToConsentVerifier(ctx, reg)) }
+ loginChallenge := func(f *Flow) string { return x.Must(f.ToLoginChallenge(ctx, reg)) }
+
router := x.NewRouterPublic()
h := NewHandler(reg, conf)
@@ -52,10 +58,8 @@ func TestSDK(t *testing.T) {
Subject: "subject1",
}))
- ar1, _ := MockAuthRequest("ar-1", false, network)
- ar2, _ := MockAuthRequest("ar-2", false, network)
- require.NoError(t, reg.ClientManager().CreateClient(context.Background(), ar1.Client))
- require.NoError(t, reg.ClientManager().CreateClient(context.Background(), ar2.Client))
+ ar1, _, _ := MockAuthRequest("1", false, network)
+ ar2, _, _ := MockAuthRequest("2", false, network)
require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{
ID: ar1.SessionID.String(),
Subject: ar1.Subject,
@@ -64,34 +68,80 @@ func TestSDK(t *testing.T) {
ID: ar2.SessionID.String(),
Subject: ar2.Subject,
}))
- require.NoError(t, m.CreateLoginRequest(context.Background(), ar1))
- require.NoError(t, m.CreateLoginRequest(context.Background(), ar2))
+ _, err := m.CreateLoginRequest(context.Background(), ar1)
+ require.NoError(t, err)
+ _, err = m.CreateLoginRequest(context.Background(), ar2)
+ require.NoError(t, err)
- cr1, hcr1 := MockConsentRequest("1", false, 0, false, false, false, "fk-login-challenge", network)
- cr2, hcr2 := MockConsentRequest("2", false, 0, false, false, false, "fk-login-challenge", network)
- cr3, hcr3 := MockConsentRequest("3", true, 3600, false, false, false, "fk-login-challenge", network)
- cr4, hcr4 := MockConsentRequest("4", true, 3600, false, false, false, "fk-login-challenge", network)
+ cr1, hcr1, _ := MockConsentRequest("1", false, 0, false, false, false, "fk-login-challenge", network)
+ cr2, hcr2, _ := MockConsentRequest("2", false, 0, false, false, false, "fk-login-challenge", network)
+ cr3, hcr3, _ := MockConsentRequest("3", true, 3600, false, false, false, "fk-login-challenge", network)
+ cr4, hcr4, _ := MockConsentRequest("4", true, 3600, false, false, false, "fk-login-challenge", network)
require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cr1.Client))
require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cr2.Client))
require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cr3.Client))
require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cr4.Client))
- require.NoError(t, m.CreateLoginRequest(context.Background(), &LoginRequest{ID: cr1.LoginChallenge.String(), Subject: cr1.Subject, Client: cr1.Client, Verifier: cr1.ID}))
- require.NoError(t, m.CreateLoginRequest(context.Background(), &LoginRequest{ID: cr2.LoginChallenge.String(), Subject: cr2.Subject, Client: cr2.Client, Verifier: cr2.ID}))
- require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ID: cr3.LoginSessionID.String()}))
- require.NoError(t, m.CreateLoginRequest(context.Background(), &LoginRequest{ID: cr3.LoginChallenge.String(), Subject: cr3.Subject, Client: cr3.Client, Verifier: cr3.ID, RequestedAt: hcr3.RequestedAt, SessionID: cr3.LoginSessionID}))
- require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ID: cr4.LoginSessionID.String()}))
- require.NoError(t, m.CreateLoginRequest(context.Background(), &LoginRequest{ID: cr4.LoginChallenge.String(), Client: cr4.Client, Verifier: cr4.ID, SessionID: cr4.LoginSessionID}))
- require.NoError(t, m.CreateConsentRequest(context.Background(), cr1))
- require.NoError(t, m.CreateConsentRequest(context.Background(), cr2))
- require.NoError(t, m.CreateConsentRequest(context.Background(), cr3))
- require.NoError(t, m.CreateConsentRequest(context.Background(), cr4))
- _, err := m.HandleConsentRequest(context.Background(), hcr1)
+
+ cr1Flow, err := m.CreateLoginRequest(context.Background(), &LoginRequest{
+ ID: cr1.LoginChallenge.String(),
+ Subject: cr1.Subject,
+ Client: cr1.Client,
+ Verifier: cr1.ID,
+ RequestedAt: time.Now(),
+ })
+ require.NoError(t, err)
+ cr1Flow.LoginSkip = ar1.Skip
+
+ cr2Flow, err := m.CreateLoginRequest(context.Background(), &LoginRequest{
+ ID: cr2.LoginChallenge.String(),
+ Subject: cr2.Subject,
+ Client: cr2.Client,
+ Verifier: cr2.ID,
+ RequestedAt: time.Now(),
+ })
+ require.NoError(t, err)
+ cr2Flow.LoginSkip = ar2.Skip
+
+ loginSession3 := &LoginSession{ID: cr3.LoginSessionID.String()}
+ require.NoError(t, m.CreateLoginSession(context.Background(), loginSession3))
+ require.NoError(t, m.ConfirmLoginSession(context.Background(), loginSession3, loginSession3.ID, time.Now(), cr3.Subject, true))
+ cr3Flow, err := m.CreateLoginRequest(context.Background(), &LoginRequest{
+ ID: cr3.LoginChallenge.String(),
+ Subject: cr3.Subject,
+ Client: cr3.Client,
+ Verifier: cr3.ID,
+ RequestedAt: hcr3.RequestedAt,
+ SessionID: cr3.LoginSessionID,
+ })
+ require.NoError(t, err)
+
+ loginSession4 := &LoginSession{ID: cr4.LoginSessionID.String()}
+ require.NoError(t, m.CreateLoginSession(context.Background(), loginSession4))
+ require.NoError(t, m.ConfirmLoginSession(context.Background(), loginSession4, loginSession4.ID, time.Now(), cr4.Subject, true))
+ cr4Flow, err := m.CreateLoginRequest(context.Background(), &LoginRequest{
+ ID: cr4.LoginChallenge.String(),
+ Client: cr4.Client,
+ Verifier: cr4.ID,
+ SessionID: cr4.LoginSessionID,
+ })
+ require.NoError(t, err)
+
+ require.NoError(t, m.CreateConsentRequest(context.Background(), cr1Flow, cr1))
+ require.NoError(t, m.CreateConsentRequest(context.Background(), cr2Flow, cr2))
+ require.NoError(t, m.CreateConsentRequest(context.Background(), cr3Flow, cr3))
+ require.NoError(t, m.CreateConsentRequest(context.Background(), cr4Flow, cr4))
+ _, err = m.HandleConsentRequest(context.Background(), cr1Flow, hcr1)
+ require.NoError(t, err)
+ _, err = m.HandleConsentRequest(context.Background(), cr2Flow, hcr2)
require.NoError(t, err)
- _, err = m.HandleConsentRequest(context.Background(), hcr2)
+ _, err = m.HandleConsentRequest(context.Background(), cr3Flow, hcr3)
require.NoError(t, err)
- _, err = m.HandleConsentRequest(context.Background(), hcr3)
+ _, err = m.HandleConsentRequest(context.Background(), cr4Flow, hcr4)
require.NoError(t, err)
- _, err = m.HandleConsentRequest(context.Background(), hcr4)
+
+ _, err = m.VerifyAndInvalidateConsentRequest(context.Background(), cr3Flow, consentVerifier(cr3Flow))
+ require.NoError(t, err)
+ _, err = m.VerifyAndInvalidateConsentRequest(context.Background(), cr4Flow, consentVerifier(cr4Flow))
require.NoError(t, err)
lur1 := MockLogoutRequest("testsdk-1", true, network)
@@ -101,19 +151,20 @@ func TestSDK(t *testing.T) {
lur2 := MockLogoutRequest("testsdk-2", false, network)
require.NoError(t, m.CreateLogoutRequest(context.Background(), lur2))
- crGot, _, err := sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(makeID("challenge", network, "1")).Execute()
- require.NoError(t, err)
+ cr1.ID = consentChallenge(cr1Flow)
+ crGot := execute[hydra.OAuth2ConsentRequest](t, sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(cr1.ID))
compareSDKConsentRequest(t, cr1, *crGot)
- crGot, _, err = sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(makeID("challenge", network, "2")).Execute()
- require.NoError(t, err)
+ cr2.ID = consentChallenge(cr2Flow)
+ crGot = execute[hydra.OAuth2ConsentRequest](t, sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(cr2.ID))
compareSDKConsentRequest(t, cr2, *crGot)
- arGot, _, err := sdk.OAuth2Api.GetOAuth2LoginRequest(ctx).LoginChallenge(makeID("challenge", network, "ar-1")).Execute()
- require.NoError(t, err)
+ ar1.ID = loginChallenge(cr1Flow)
+ arGot := execute[hydra.OAuth2LoginRequest](t, sdk.OAuth2Api.GetOAuth2LoginRequest(ctx).LoginChallenge(ar1.ID))
compareSDKLoginRequest(t, ar1, *arGot)
- arGot, _, err = sdk.OAuth2Api.GetOAuth2LoginRequest(ctx).LoginChallenge(makeID("challenge", network, "ar-2")).Execute()
+ ar2.ID = loginChallenge(cr2Flow)
+ arGot = execute[hydra.OAuth2LoginRequest](t, sdk.OAuth2Api.GetOAuth2LoginRequest(ctx).LoginChallenge(ar2.ID))
require.NoError(t, err)
compareSDKLoginRequest(t, ar2, *arGot)
@@ -132,7 +183,8 @@ func TestSDK(t *testing.T) {
_, _, err = sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(makeID("challenge", network, "1")).Execute()
require.Error(t, err)
- crGot, _, err = sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(makeID("challenge", network, "2")).Execute()
+ cr2.ID = consentChallenge(cr2Flow)
+ crGot, _, err = sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(cr2.ID).Execute()
require.NoError(t, err)
compareSDKConsentRequest(t, cr2, *crGot)
@@ -145,8 +197,6 @@ func TestSDK(t *testing.T) {
csGot, _, err := sdk.OAuth2Api.ListOAuth2ConsentSessions(ctx).Subject("subject3").Execute()
require.NoError(t, err)
assert.Equal(t, 1, len(csGot))
- cs := csGot[0]
- assert.Equal(t, makeID("challenge", network, "3"), cs.ConsentRequest.Challenge)
csGot, _, err = sdk.OAuth2Api.ListOAuth2ConsentSessions(ctx).Subject("subject2").Execute()
require.NoError(t, err)
@@ -155,8 +205,6 @@ func TestSDK(t *testing.T) {
csGot, _, err = sdk.OAuth2Api.ListOAuth2ConsentSessions(ctx).Subject("subject3").LoginSessionId("fk-login-session-t1-3").Execute()
require.NoError(t, err)
assert.Equal(t, 1, len(csGot))
- cs = csGot[0]
- assert.Equal(t, makeID("challenge", network, "3"), cs.ConsentRequest.Challenge)
csGot, _, err = sdk.OAuth2Api.ListOAuth2ConsentSessions(ctx).Subject("subject3").LoginSessionId("fk-login-session-t1-X").Execute()
require.NoError(t, err)
@@ -198,3 +246,15 @@ func compareSDKLogoutRequest(t *testing.T, expected *LogoutRequest, got *hydra.O
assert.EqualValues(t, expected.RequestURL, *got.RequestUrl)
assert.EqualValues(t, expected.RPInitiated, *got.RpInitiated)
}
+
+type executer[T any] interface {
+ Execute() (*T, *http.Response, error)
+}
+
+func execute[T any](t *testing.T, e executer[T]) *T {
+ got, res, err := e.Execute()
+ require.NoError(t, err)
+ require.NoError(t, res.Body.Close())
+
+ return got
+}
diff --git a/consent/strategy.go b/consent/strategy.go
index 9d31b3de4b1..08e8788c756 100644
--- a/consent/strategy.go
+++ b/consent/strategy.go
@@ -8,13 +8,19 @@ import (
"net/http"
"github.com/ory/fosite"
+ "github.com/ory/hydra/v2/flow"
)
var _ Strategy = new(DefaultStrategy)
type Strategy interface {
- HandleOAuth2AuthorizationRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester) (*AcceptOAuth2ConsentRequest, error)
- HandleOpenIDConnectLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*LogoutResult, error)
+ HandleOAuth2AuthorizationRequest(
+ ctx context.Context,
+ w http.ResponseWriter,
+ r *http.Request,
+ req fosite.AuthorizeRequester,
+ ) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error)
+ HandleOpenIDConnectLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error)
HandleHeadlessLogout(ctx context.Context, w http.ResponseWriter, r *http.Request, sid string) error
ObfuscateSubjectIdentifier(ctx context.Context, cl fosite.Client, subject, forcedIdentifier string) (string, error)
}
diff --git a/consent/strategy_default.go b/consent/strategy_default.go
index 0de9ac2b168..2df79bce9f9 100644
--- a/consent/strategy_default.go
+++ b/consent/strategy_default.go
@@ -16,7 +16,9 @@ import (
"github.com/pborman/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
- "github.com/twmb/murmur3"
+
+ "github.com/ory/hydra/v2/flow"
+ "github.com/ory/hydra/v2/oauth2/flowctx"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
@@ -79,7 +81,7 @@ func (s *DefaultStrategy) matchesValueFromSession(ctx context.Context, c fosite.
return nil
}
-func (s *DefaultStrategy) authenticationSession(ctx context.Context, w http.ResponseWriter, r *http.Request) (*LoginSession, error) {
+func (s *DefaultStrategy) authenticationSession(ctx context.Context, _ http.ResponseWriter, r *http.Request) (*flow.LoginSession, error) {
store, err := s.r.CookieStore(ctx)
if err != nil {
return nil, err
@@ -102,7 +104,8 @@ func (s *DefaultStrategy) authenticationSession(ctx context.Context, w http.Resp
return nil, errorsx.WithStack(ErrNoAuthenticationSessionFound)
}
- session, err := s.r.ConsentManager().GetRememberedLoginSession(r.Context(), sessionID)
+ sessionFromCookie := s.loginSessionFromCookie(r)
+ session, err := s.r.ConsentManager().GetRememberedLoginSession(r.Context(), sessionFromCookie, sessionID)
if errors.Is(err, x.ErrNotFound) {
s.r.Logger().WithRequest(r).WithError(err).
Debug("User logout skipped because cookie exists and session value exist but are not remembered any more.")
@@ -184,7 +187,7 @@ func (s *DefaultStrategy) getSubjectFromIDTokenHint(ctx context.Context, idToken
return sub, nil
}
-func (s *DefaultStrategy) forwardAuthenticationRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, ar fosite.AuthorizeRequester, subject string, authenticatedAt time.Time, session *LoginSession) error {
+func (s *DefaultStrategy) forwardAuthenticationRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, ar fosite.AuthorizeRequester, subject string, authenticatedAt time.Time, session *flow.LoginSession) error {
if (subject != "" && authenticatedAt.IsZero()) || (subject == "" && !authenticatedAt.IsZero()) {
return errorsx.WithStack(fosite.ErrServerError.WithHint("Consent strategy returned a non-empty subject with an empty auth date, or an empty subject with a non-empty auth date."))
}
@@ -224,51 +227,66 @@ func (s *DefaultStrategy) forwardAuthenticationRequest(ctx context.Context, w ht
sessionID = session.ID
} else {
// Create a stub session so that we can later update it.
- if err := s.r.ConsentManager().CreateLoginSession(r.Context(), &LoginSession{ID: sessionID}); err != nil {
+ loginSession := &flow.LoginSession{ID: sessionID}
+ if err := s.r.ConsentManager().CreateLoginSession(ctx, loginSession); err != nil {
+ return err
+ }
+ if err := flowctx.SetCookie(ctx, w, s.r, flowctx.LoginSessionCookie(flowctx.SuffixForClient(ar.GetClient())), loginSession); err != nil {
return err
}
}
// Set the session
cl := sanitizeClientFromRequest(ar)
- if err := s.r.ConsentManager().CreateLoginRequest(
- r.Context(),
- &LoginRequest{
- ID: challenge,
- Verifier: verifier,
- CSRF: csrf,
- Skip: skip,
- RequestedScope: []string(ar.GetRequestedScopes()),
- RequestedAudience: []string(ar.GetRequestedAudience()),
- Subject: subject,
- Client: cl,
- RequestURL: iu.String(),
- AuthenticatedAt: sqlxx.NullTime(authenticatedAt),
- RequestedAt: time.Now().Truncate(time.Second).UTC(),
- SessionID: sqlxx.NullString(sessionID),
- OpenIDConnectContext: &OAuth2ConsentRequestOpenIDConnectContext{
- IDTokenHintClaims: idTokenHintClaims,
- ACRValues: stringsx.Splitx(ar.GetRequestForm().Get("acr_values"), " "),
- UILocales: stringsx.Splitx(ar.GetRequestForm().Get("ui_locales"), " "),
- Display: ar.GetRequestForm().Get("display"),
- LoginHint: ar.GetRequestForm().Get("login_hint"),
- },
+ loginRequest := &flow.LoginRequest{
+ ID: challenge,
+ Verifier: verifier,
+ CSRF: csrf,
+ Skip: skip,
+ RequestedScope: []string(ar.GetRequestedScopes()),
+ RequestedAudience: []string(ar.GetRequestedAudience()),
+ Subject: subject,
+ Client: cl,
+ RequestURL: iu.String(),
+ AuthenticatedAt: sqlxx.NullTime(authenticatedAt),
+ RequestedAt: time.Now().Truncate(time.Second).UTC(),
+ SessionID: sqlxx.NullString(sessionID),
+ OpenIDConnectContext: &flow.OAuth2ConsentRequestOpenIDConnectContext{
+ IDTokenHintClaims: idTokenHintClaims,
+ ACRValues: stringsx.Splitx(ar.GetRequestForm().Get("acr_values"), " "),
+ UILocales: stringsx.Splitx(ar.GetRequestForm().Get("ui_locales"), " "),
+ Display: ar.GetRequestForm().Get("display"),
+ LoginHint: ar.GetRequestForm().Get("login_hint"),
},
- ); err != nil {
+ }
+ f, err := s.r.ConsentManager().CreateLoginRequest(
+ ctx,
+ loginRequest,
+ )
+ if err != nil {
return errorsx.WithStack(err)
}
+ if err := flowctx.SetCookie(ctx, w, s.r, flowctx.FlowCookie(cl), f); err != nil {
+ return err
+ }
+
store, err := s.r.CookieStore(ctx)
if err != nil {
return err
}
- clientSpecificCookieNameLoginCSRF := fmt.Sprintf("%s_%d", s.r.Config().CookieNameLoginCSRF(ctx), murmur3.Sum32(cl.ID.Bytes()))
+ clientSpecificCookieNameLoginCSRF := fmt.Sprintf("%s_%s", s.r.Config().CookieNameLoginCSRF(ctx), cl.CookieSuffix())
if err := createCsrfSession(w, r, s.r.Config(), store, clientSpecificCookieNameLoginCSRF, csrf, s.c.ConsentRequestMaxAge(ctx)); err != nil {
return errorsx.WithStack(err)
}
- http.Redirect(w, r, urlx.SetQuery(s.c.LoginURL(ctx), url.Values{"login_challenge": {challenge}}).String(), http.StatusFound)
+ encodedFlow, err := f.ToLoginChallenge(ctx, s.r)
+ if err != nil {
+ return err
+ }
+
+ http.Redirect(w, r, urlx.SetQuery(s.c.LoginURL(ctx), url.Values{"login_challenge": {encodedFlow}}).String(), http.StatusFound)
// generate the verifier
return errorsx.WithStack(ErrAbortOAuth2Request)
@@ -312,9 +330,22 @@ func (s *DefaultStrategy) revokeAuthenticationCookie(w http.ResponseWriter, r *h
return sid, nil
}
-func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester, verifier string) (*HandledLoginRequest, error) {
- ctx := r.Context()
- session, err := s.r.ConsentManager().VerifyAndInvalidateLoginRequest(ctx, verifier)
+func (s *DefaultStrategy) verifyAuthentication(
+ ctx context.Context,
+ w http.ResponseWriter,
+ r *http.Request,
+ req fosite.AuthorizeRequester,
+ verifier string,
+) (*flow.Flow, error) {
+ f, err := s.flowFromCookie(r)
+ if err != nil {
+ return nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The flow cookie is missing in the request."))
+ }
+ if f.Client.GetID() != req.GetClient().GetID() {
+ return nil, errorsx.WithStack(fosite.ErrInvalidClient.WithHint("The flow cookie client id does not match the authorize request client id."))
+ }
+
+ session, err := s.r.ConsentManager().VerifyAndInvalidateLoginRequest(ctx, f, verifier)
if errors.Is(err, sqlcon.ErrNoRows) {
return nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The login verifier has already been used, has not been granted, or is invalid."))
} else if err != nil {
@@ -322,8 +353,8 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re
}
if session.HasError() {
- session.Error.SetDefaults(loginRequestDeniedErrorName)
- return nil, errorsx.WithStack(session.Error.toRFCError())
+ session.Error.SetDefaults(flow.LoginRequestDeniedErrorName)
+ return nil, errorsx.WithStack(session.Error.ToRFCError())
}
if session.RequestedAt.Add(s.c.ConsentRequestMaxAge(ctx)).Before(time.Now()) {
@@ -335,7 +366,7 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re
return nil, err
}
- clientSpecificCookieNameLoginCSRF := fmt.Sprintf("%s_%d", s.r.Config().CookieNameLoginCSRF(ctx), murmur3.Sum32(session.LoginRequest.Client.ID.Bytes()))
+ clientSpecificCookieNameLoginCSRF := fmt.Sprintf("%s_%s", s.r.Config().CookieNameLoginCSRF(ctx), session.LoginRequest.Client.CookieSuffix())
if err := validateCsrfSession(r, s.r.Config(), store, clientSpecificCookieNameLoginCSRF, session.LoginRequest.CSRF); err != nil {
return nil, err
}
@@ -409,10 +440,16 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re
if !session.LoginRequest.Skip {
if time.Time(session.AuthenticatedAt).IsZero() {
- return nil, errorsx.WithStack(fosite.ErrServerError.WithHint("Expected the handled login request to contain a valid authenticated_at value but it was zero. This is a bug which should be reported to https://github.com/ory/hydra."))
+ return nil, errorsx.WithStack(fosite.ErrServerError.WithHint(
+ "Expected the handled login request to contain a valid authenticated_at value but it was zero. This is a bug which should be reported to https://github.com/ory/hydra."))
+ }
+
+ loginSession := s.loginSessionFromCookie(r)
+ if loginSession == nil {
+ return nil, fosite.ErrAccessDenied.WithHint("The login session cookie was not found or malformed.")
}
- if err := s.r.ConsentManager().ConfirmLoginSession(r.Context(), sessionID, time.Time(session.AuthenticatedAt), session.Subject, session.Remember); err != nil {
+ if err := s.r.ConsentManager().ConfirmLoginSession(ctx, loginSession, sessionID, time.Time(session.AuthenticatedAt), session.Subject, session.Remember); err != nil {
return nil, err
}
}
@@ -429,7 +466,7 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re
// If the user doesn't want to remember the session, we do not store a cookie.
// If login was skipped, it means an authentication cookie was present and
// we don't want to touch it (in order to preserve its original expiry date)
- return session, nil
+ return f, nil
}
// Not a skipped login and the user asked to remember its session, store a cookie
@@ -453,13 +490,24 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re
"cookie_same_site": s.c.CookieSameSiteMode(ctx),
"cookie_secure": s.c.CookieSecure(ctx),
}).Debug("Authentication session cookie was set.")
- return session, nil
+
+ if err = flowctx.SetCookie(ctx, w, s.r, flowctx.FlowCookie(flowctx.SuffixForClient(req.GetClient())), f); err != nil {
+ return nil, errorsx.WithStack(err)
+ }
+
+ return f, nil
}
-func (s *DefaultStrategy) requestConsent(ctx context.Context, w http.ResponseWriter, r *http.Request, ar fosite.AuthorizeRequester, authenticationSession *HandledLoginRequest) error {
+func (s *DefaultStrategy) requestConsent(
+ ctx context.Context,
+ w http.ResponseWriter,
+ r *http.Request,
+ ar fosite.AuthorizeRequester,
+ f *flow.Flow,
+) error {
prompt := stringsx.Splitx(ar.GetRequestForm().Get("prompt"), " ")
if stringslice.Has(prompt, "consent") {
- return s.forwardConsentRequest(ctx, w, r, ar, authenticationSession, nil)
+ return s.forwardConsentRequest(ctx, w, r, ar, f, nil)
}
// https://tools.ietf.org/html/rfc6749
@@ -483,7 +531,7 @@ func (s *DefaultStrategy) requestConsent(ctx context.Context, w http.ResponseWri
// This is tracked as issue: https://github.com/ory/hydra/issues/866
// This is also tracked as upstream issue: https://github.com/openid-certification/oidctest/issues/97
if !(ar.GetRedirectURI().Scheme == "https" || (fosite.IsLocalhost(ar.GetRedirectURI()) && ar.GetRedirectURI().Scheme == "http")) {
- return s.forwardConsentRequest(ctx, w, r, ar, authenticationSession, nil)
+ return s.forwardConsentRequest(ctx, w, r, ar, f, nil)
}
}
@@ -494,23 +542,31 @@ func (s *DefaultStrategy) requestConsent(ctx context.Context, w http.ResponseWri
// return s.forwardConsentRequest(w, r, ar, authenticationSession, nil)
// }
- consentSessions, err := s.r.ConsentManager().FindGrantedAndRememberedConsentRequests(r.Context(), ar.GetClient().GetID(), authenticationSession.Subject)
+ consentSessions, err := s.r.ConsentManager().FindGrantedAndRememberedConsentRequests(ctx, ar.GetClient().GetID(), f.Subject)
if errors.Is(err, ErrNoPreviousConsentFound) {
- return s.forwardConsentRequest(ctx, w, r, ar, authenticationSession, nil)
+ return s.forwardConsentRequest(ctx, w, r, ar, f, nil)
} else if err != nil {
return err
}
if found := matchScopes(s.r.Config().GetScopeStrategy(ctx), consentSessions, ar.GetRequestedScopes()); found != nil {
- return s.forwardConsentRequest(ctx, w, r, ar, authenticationSession, found)
+ return s.forwardConsentRequest(ctx, w, r, ar, f, found)
}
- return s.forwardConsentRequest(ctx, w, r, ar, authenticationSession, nil)
+ return s.forwardConsentRequest(ctx, w, r, ar, f, nil)
}
-func (s *DefaultStrategy) forwardConsentRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, ar fosite.AuthorizeRequester, as *HandledLoginRequest, cs *AcceptOAuth2ConsentRequest) error {
+func (s *DefaultStrategy) forwardConsentRequest(
+ ctx context.Context,
+ w http.ResponseWriter,
+ r *http.Request,
+ ar fosite.AuthorizeRequester,
+ f *flow.Flow,
+ previousConsent *flow.AcceptOAuth2ConsentRequest,
+) error {
+ as := f.GetHandledLoginRequest()
skip := false
- if cs != nil {
+ if previousConsent != nil {
skip = true
}
@@ -525,45 +581,53 @@ func (s *DefaultStrategy) forwardConsentRequest(ctx context.Context, w http.Resp
csrf := strings.Replace(uuid.New(), "-", "", -1)
cl := sanitizeClientFromRequest(ar)
- if err := s.r.ConsentManager().CreateConsentRequest(
- r.Context(),
- &OAuth2ConsentRequest{
- ID: challenge,
- ACR: as.ACR,
- AMR: as.AMR,
- Verifier: verifier,
- CSRF: csrf,
- Skip: skip,
- RequestedScope: []string(ar.GetRequestedScopes()),
- RequestedAudience: []string(ar.GetRequestedAudience()),
- Subject: as.Subject,
- Client: cl,
- RequestURL: as.LoginRequest.RequestURL,
- AuthenticatedAt: as.AuthenticatedAt,
- RequestedAt: as.RequestedAt,
- ForceSubjectIdentifier: as.ForceSubjectIdentifier,
- OpenIDConnectContext: as.LoginRequest.OpenIDConnectContext,
- LoginSessionID: as.LoginRequest.SessionID,
- LoginChallenge: sqlxx.NullString(as.LoginRequest.ID),
- Context: as.Context,
- },
- ); err != nil {
+
+ consentRequest := &flow.OAuth2ConsentRequest{
+ ID: challenge,
+ ACR: as.ACR,
+ AMR: as.AMR,
+ Verifier: verifier,
+ CSRF: csrf,
+ Skip: skip,
+ RequestedScope: []string(ar.GetRequestedScopes()),
+ RequestedAudience: []string(ar.GetRequestedAudience()),
+ Subject: as.Subject,
+ Client: cl,
+ RequestURL: as.LoginRequest.RequestURL,
+ AuthenticatedAt: as.AuthenticatedAt,
+ RequestedAt: as.RequestedAt,
+ ForceSubjectIdentifier: as.ForceSubjectIdentifier,
+ OpenIDConnectContext: as.LoginRequest.OpenIDConnectContext,
+ LoginSessionID: as.LoginRequest.SessionID,
+ LoginChallenge: sqlxx.NullString(as.LoginRequest.ID),
+ Context: as.Context,
+ }
+ err := s.r.ConsentManager().CreateConsentRequest(ctx, f, consentRequest)
+ if err != nil {
return errorsx.WithStack(err)
}
+ if err := flowctx.SetCookie(ctx, w, s.r, flowctx.FlowCookie(cl), f); err != nil {
+ return err
+ }
+ consentChallenge, err := f.ToConsentChallenge(ctx, s.r)
+ if err != nil {
+ return err
+ }
+
store, err := s.r.CookieStore(ctx)
if err != nil {
return err
}
- clientSpecificCookieNameConsentCSRF := fmt.Sprintf("%s_%d", s.r.Config().CookieNameConsentCSRF(ctx), murmur3.Sum32(cl.ID.Bytes()))
+ clientSpecificCookieNameConsentCSRF := fmt.Sprintf("%s_%s", s.r.Config().CookieNameConsentCSRF(ctx), cl.CookieSuffix())
if err := createCsrfSession(w, r, s.r.Config(), store, clientSpecificCookieNameConsentCSRF, csrf, s.c.ConsentRequestMaxAge(ctx)); err != nil {
return errorsx.WithStack(err)
}
http.Redirect(
w, r,
- urlx.SetQuery(s.c.ConsentURL(ctx), url.Values{"consent_challenge": {challenge}}).String(),
+ urlx.SetQuery(s.c.ConsentURL(ctx), url.Values{"consent_challenge": {consentChallenge}}).String(),
http.StatusFound,
)
@@ -571,39 +635,51 @@ func (s *DefaultStrategy) forwardConsentRequest(ctx context.Context, w http.Resp
return errorsx.WithStack(ErrAbortOAuth2Request)
}
-func (s *DefaultStrategy) verifyConsent(ctx context.Context, w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester, verifier string) (*AcceptOAuth2ConsentRequest, error) {
- session, err := s.r.ConsentManager().VerifyAndInvalidateConsentRequest(r.Context(), verifier)
+func (s *DefaultStrategy) verifyConsent(ctx context.Context, w http.ResponseWriter, r *http.Request, verifier string) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) {
+ f, err := s.flowFromCookie(r)
+ if err != nil {
+ return nil, nil, err
+ }
+ if f.Client.GetID() != r.URL.Query().Get("client_id") {
+ return nil, nil, errorsx.WithStack(fosite.ErrInvalidClient.WithHint("The flow cookie client id does not match the authorize request client id."))
+ }
+
+ session, err := s.r.ConsentManager().VerifyAndInvalidateConsentRequest(ctx, f, verifier)
if errors.Is(err, sqlcon.ErrNoRows) {
- return nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The consent verifier has already been used, has not been granted, or is invalid."))
+ return nil, nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The consent verifier has already been used, has not been granted, or is invalid."))
} else if err != nil {
- return nil, err
+ return nil, nil, err
}
if session.RequestedAt.Add(s.c.ConsentRequestMaxAge(ctx)).Before(time.Now()) {
- return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The consent request has expired, please try again."))
+ return nil, nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The consent request has expired, please try again."))
}
if session.HasError() {
- session.Error.SetDefaults(consentRequestDeniedErrorName)
- return nil, errorsx.WithStack(session.Error.toRFCError())
+ session.Error.SetDefaults(flow.ConsentRequestDeniedErrorName)
+ return nil, nil, errorsx.WithStack(session.Error.ToRFCError())
}
if time.Time(session.ConsentRequest.AuthenticatedAt).IsZero() {
- return nil, errorsx.WithStack(fosite.ErrServerError.WithHint("The authenticatedAt value was not set."))
+ return nil, nil, errorsx.WithStack(fosite.ErrServerError.WithHint("The authenticatedAt value was not set."))
}
store, err := s.r.CookieStore(ctx)
if err != nil {
- return nil, err
+ return nil, nil, err
}
- clientSpecificCookieNameConsentCSRF := fmt.Sprintf("%s_%d", s.r.Config().CookieNameConsentCSRF(ctx), murmur3.Sum32(session.ConsentRequest.Client.ID.Bytes()))
+ clientSpecificCookieNameConsentCSRF := fmt.Sprintf("%s_%s", s.r.Config().CookieNameConsentCSRF(ctx), session.ConsentRequest.Client.CookieSuffix())
if err := validateCsrfSession(r, s.r.Config(), store, clientSpecificCookieNameConsentCSRF, session.ConsentRequest.CSRF); err != nil {
- return nil, err
+ return nil, nil, err
+ }
+
+ if err = flowctx.DeleteCookie(ctx, w, s.r, flowctx.FlowCookie(f.Client)); err != nil {
+ return nil, nil, err
}
if session.Session == nil {
- session.Session = NewConsentRequestSessionData()
+ session.Session = flow.NewConsentRequestSessionData()
}
if session.Session.AccessToken == nil {
@@ -615,7 +691,7 @@ func (s *DefaultStrategy) verifyConsent(ctx context.Context, w http.ResponseWrit
}
session.AuthenticatedAt = session.ConsentRequest.AuthenticatedAt
- return session, nil
+ return session, f, nil
}
func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, subject, sid string) ([]string, error) {
@@ -711,7 +787,7 @@ func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.
return nil
}
-func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.ResponseWriter, r *http.Request) (*LogoutResult, error) {
+func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) {
// There are two types of log out flows:
//
// - RP initiated logout
@@ -758,7 +834,7 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon
}
challenge := uuid.New()
- if err := s.r.ConsentManager().CreateLogoutRequest(r.Context(), &LogoutRequest{
+ if err := s.r.ConsentManager().CreateLogoutRequest(r.Context(), &flow.LogoutRequest{
RequestURL: r.URL.String(),
ID: challenge,
Subject: session.Subject,
@@ -869,7 +945,8 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon
// We do not really want to verify if the user (from id token hint) has a session here because it doesn't really matter.
// Instead, we'll check this when we're actually revoking the cookie!
- session, err := s.r.ConsentManager().GetRememberedLoginSession(r.Context(), hintSid)
+ sessionFromCookie := s.loginSessionFromCookie(r)
+ session, err := s.r.ConsentManager().GetRememberedLoginSession(r.Context(), sessionFromCookie, hintSid)
if errors.Is(err, x.ErrNotFound) {
// Such a session does not exist - maybe it has already been revoked? In any case, we can't do much except
// leaning back and redirecting back.
@@ -880,7 +957,7 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon
}
challenge := uuid.New()
- if err := s.r.ConsentManager().CreateLogoutRequest(r.Context(), &LogoutRequest{
+ if err := s.r.ConsentManager().CreateLogoutRequest(r.Context(), &flow.LogoutRequest{
RequestURL: r.URL.String(),
ID: challenge,
SessionID: hintSid,
@@ -899,7 +976,7 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon
return nil, errorsx.WithStack(ErrAbortOAuth2Request)
}
-func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(ctx context.Context, r *http.Request, subject string, sid string) error {
+func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(_ context.Context, r *http.Request, subject string, sid string) error {
if err := s.executeBackChannelLogout(r.Context(), r, subject, sid); err != nil {
return err
}
@@ -918,7 +995,7 @@ func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(ctx context.C
return nil
}
-func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*LogoutResult, error) {
+func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) {
verifier := r.URL.Query().Get("logout_verifier")
lr, err := s.r.ConsentManager().VerifyAndInvalidateLogoutRequest(r.Context(), verifier)
@@ -976,13 +1053,13 @@ func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWri
WithField("subject", lr.Subject).
Info("User logout completed!")
- return &LogoutResult{
+ return &flow.LogoutResult{
RedirectTo: lr.PostLogoutRedirectURI,
FrontChannelLogoutURLs: urls,
}, nil
}
-func (s *DefaultStrategy) HandleOpenIDConnectLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*LogoutResult, error) {
+func (s *DefaultStrategy) HandleOpenIDConnectLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) {
verifier := r.URL.Query().Get("logout_verifier")
if verifier == "" {
return s.issueLogoutVerifier(ctx, w, r)
@@ -991,8 +1068,9 @@ func (s *DefaultStrategy) HandleOpenIDConnectLogout(ctx context.Context, w http.
return s.completeLogout(ctx, w, r)
}
-func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, w http.ResponseWriter, r *http.Request, sid string) error {
- loginSession, lsErr := s.r.ConsentManager().GetRememberedLoginSession(ctx, sid)
+func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, _ http.ResponseWriter, r *http.Request, sid string) error {
+ sessionFromCookie := s.loginSessionFromCookie(r)
+ loginSession, lsErr := s.r.ConsentManager().GetRememberedLoginSession(ctx, sessionFromCookie, sid)
if errors.Is(lsErr, x.ErrNotFound) {
// This is ok (session probably already revoked), do nothing!
@@ -1016,28 +1094,33 @@ func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, w http.Respo
return nil
}
-func (s *DefaultStrategy) HandleOAuth2AuthorizationRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester) (*AcceptOAuth2ConsentRequest, error) {
- authenticationVerifier := strings.TrimSpace(req.GetRequestForm().Get("login_verifier"))
+func (s *DefaultStrategy) HandleOAuth2AuthorizationRequest(
+ ctx context.Context,
+ w http.ResponseWriter,
+ r *http.Request,
+ req fosite.AuthorizeRequester,
+) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) {
+ loginVerifier := strings.TrimSpace(req.GetRequestForm().Get("login_verifier"))
consentVerifier := strings.TrimSpace(req.GetRequestForm().Get("consent_verifier"))
- if authenticationVerifier == "" && consentVerifier == "" {
+ if loginVerifier == "" && consentVerifier == "" {
// ok, we need to process this request and redirect to auth endpoint
- return nil, s.requestAuthentication(ctx, w, r, req)
- } else if authenticationVerifier != "" {
- authSession, err := s.verifyAuthentication(w, r, req, authenticationVerifier)
+ return nil, nil, s.requestAuthentication(ctx, w, r, req)
+ } else if loginVerifier != "" {
+ f, err := s.verifyAuthentication(ctx, w, r, req, loginVerifier)
if err != nil {
- return nil, err
+ return nil, nil, err
}
// ok, we need to process this request and redirect to auth endpoint
- return nil, s.requestConsent(ctx, w, r, req, authSession)
+ return nil, f, s.requestConsent(ctx, w, r, req, f)
}
- consentSession, err := s.verifyConsent(ctx, w, r, req, consentVerifier)
+ consentSession, f, err := s.verifyConsent(ctx, w, r, consentVerifier)
if err != nil {
- return nil, err
+ return nil, nil, err
}
- return consentSession, nil
+ return consentSession, f, nil
}
func (s *DefaultStrategy) ObfuscateSubjectIdentifier(ctx context.Context, cl fosite.Client, subject, forcedIdentifier string) (string, error) {
@@ -1057,3 +1140,22 @@ func (s *DefaultStrategy) ObfuscateSubjectIdentifier(ctx context.Context, cl fos
}
return subject, nil
}
+
+func (s *DefaultStrategy) flowFromCookie(r *http.Request) (*flow.Flow, error) {
+ clientID := r.URL.Query().Get("client_id")
+ if clientID == "" {
+ return nil, errors.WithStack(fosite.ErrInvalidClient)
+ }
+
+ return flowctx.FromCookie[flow.Flow](r.Context(), r, s.r.FlowCipher(), flowctx.FlowCookie(flowctx.SuffixFromStatic(clientID)))
+}
+
+func (s *DefaultStrategy) loginSessionFromCookie(r *http.Request) *flow.LoginSession {
+ clientID := r.URL.Query().Get("client_id")
+ if clientID == "" {
+ return nil
+ }
+ ls, _ := flowctx.FromCookie[flow.LoginSession](r.Context(), r, s.r.FlowCipher(), flowctx.LoginSessionCookie(flowctx.SuffixFromStatic(clientID)))
+
+ return ls
+}
diff --git a/consent/strategy_oauth_test.go b/consent/strategy_oauth_test.go
index a44446121ee..d1ec766c61c 100644
--- a/consent/strategy_oauth_test.go
+++ b/consent/strategy_oauth_test.go
@@ -10,14 +10,20 @@ import (
"encoding/json"
"fmt"
"net/http"
+ "net/http/cookiejar"
"net/url"
+ "regexp"
"testing"
"time"
+ "github.com/ory/hydra/v2/aead"
+ "github.com/ory/hydra/v2/consent"
+ "github.com/ory/hydra/v2/flow"
+ "github.com/ory/hydra/v2/oauth2/flowctx"
+ "github.com/ory/hydra/v2/x"
"github.com/ory/x/ioutilx"
- "github.com/twmb/murmur3"
-
+ "golang.org/x/exp/slices"
"golang.org/x/oauth2"
"github.com/ory/x/pointerx"
@@ -113,8 +119,12 @@ func TestStrategyLoginConsentNext(t *testing.T) {
t.Run("case=should fail because a login verifier was given that doesn't exist in the store", func(t *testing.T) {
testhelpers.NewLoginConsentUI(t, reg.Config(), testhelpers.HTTPServerNoExpectedCallHandler(t), testhelpers.HTTPServerNoExpectedCallHandler(t))
c := createDefaultClient(t)
+ hc := newHTTPClientWithFlowCookie(t, ctx, reg, c)
- makeRequestAndExpectError(t, nil, c, url.Values{"login_verifier": {"does-not-exist"}}, "The login verifier has already been used, has not been granted, or is invalid.")
+ makeRequestAndExpectError(
+ t, hc, c, url.Values{"login_verifier": {"does-not-exist"}},
+ "The login verifier has already been used, has not been granted, or is invalid.",
+ )
})
t.Run("case=should fail because a non-existing consent verifier was given", func(t *testing.T) {
@@ -123,7 +133,12 @@ func TestStrategyLoginConsentNext(t *testing.T) {
// - This should fail because a consent verifier was given but no login verifier
testhelpers.NewLoginConsentUI(t, reg.Config(), testhelpers.HTTPServerNoExpectedCallHandler(t), testhelpers.HTTPServerNoExpectedCallHandler(t))
c := createDefaultClient(t)
- makeRequestAndExpectError(t, nil, c, url.Values{"consent_verifier": {"does-not-exist"}}, "The consent verifier has already been used, has not been granted, or is invalid.")
+ hc := newHTTPClientWithFlowCookie(t, ctx, reg, c)
+
+ makeRequestAndExpectError(
+ t, hc, c, url.Values{"consent_verifier": {"does-not-exist"}},
+ "The consent verifier has already been used, has not been granted, or is invalid.",
+ )
})
t.Run("case=should fail because the request was redirected but the login endpoint doesn't do anything (like redirecting back)", func(t *testing.T) {
@@ -169,6 +184,7 @@ func TestStrategyLoginConsentNext(t *testing.T) {
testhelpers.HTTPServerNoExpectedCallHandler(t))
hc := new(http.Client)
+ hc.Jar = DropCookieJar(regexp.MustCompile("ory_hydra_.*_csrf_.*"))
makeRequestAndExpectError(t, hc, c, url.Values{}, "No CSRF value available in the session cookie.")
})
@@ -332,16 +348,13 @@ func TestStrategyLoginConsentNext(t *testing.T) {
loginChallengeRedirect, err := oauthRes.Location()
require.NoError(t, err)
defer oauthRes.Body.Close()
- setCookieHeader := oauthRes.Header.Get("set-cookie")
- assert.NotNil(t, setCookieHeader)
-
- t.Run("login cookie client specific suffix is set", func(t *testing.T) {
- assert.Regexp(t, fmt.Sprintf("ory_hydra_login_csrf_dev_%d=.*", murmur3.Sum32(c.ID.Bytes())), setCookieHeader)
- })
- t.Run("login cookie max age is set", func(t *testing.T) {
- assert.Regexp(t, fmt.Sprintf("ory_hydra_login_csrf_dev_%d=.*Max-Age=%.0f;.*", murmur3.Sum32(c.ID.Bytes()), consentRequestMaxAge), setCookieHeader)
+ foundLoginCookie := slices.ContainsFunc(oauthRes.Header.Values("set-cookie"), func(sc string) bool {
+ ok, err := regexp.MatchString(fmt.Sprintf("ory_hydra_login_csrf_dev_%s=.*Max-Age=%.0f;.*", c.CookieSuffix(), consentRequestMaxAge), sc)
+ require.NoError(t, err)
+ return ok
})
+ require.True(t, foundLoginCookie, "client-specific login cookie with max age set")
loginChallengeRes, err := hc.Get(loginChallengeRedirect.String())
require.NoError(t, err)
@@ -352,16 +365,13 @@ func TestStrategyLoginConsentNext(t *testing.T) {
loginVerifierRes, err := hc.Get(loginVerifierRedirect.String())
require.NoError(t, err)
defer loginVerifierRes.Body.Close()
- setCookieHeader = loginVerifierRes.Header.Values("set-cookie")[1]
- assert.NotNil(t, setCookieHeader)
- t.Run("consent cookie client specific suffix set", func(t *testing.T) {
- assert.Regexp(t, fmt.Sprintf("ory_hydra_consent_csrf_dev_%d=.*", murmur3.Sum32(c.ID.Bytes())), setCookieHeader)
- })
-
- t.Run("consent cookie max age is set", func(t *testing.T) {
- assert.Regexp(t, fmt.Sprintf("ory_hydra_consent_csrf_dev_%d=.*Max-Age=%.0f;.*", murmur3.Sum32(c.ID.Bytes()), consentRequestMaxAge), setCookieHeader)
+ foundConsentCookie := slices.ContainsFunc(loginVerifierRes.Header.Values("set-cookie"), func(sc string) bool {
+ ok, err := regexp.MatchString(fmt.Sprintf("ory_hydra_consent_csrf_dev_%s=.*Max-Age=%.0f;.*", c.CookieSuffix(), consentRequestMaxAge), sc)
+ require.NoError(t, err)
+ return ok
})
+ require.True(t, foundConsentCookie, "client-specific consent cookie with max age set")
})
t.Run("case=should pass if both login and consent are granted and check remember flows with refresh session cookie", func(t *testing.T) {
@@ -432,6 +442,7 @@ func TestStrategyLoginConsentNext(t *testing.T) {
require.NoError(t, err)
defer loginChallengeRes.Body.Close()
loginVerifierRedirect, err := loginChallengeRes.Location()
+ require.NoError(t, err)
loginVerifierRes, err := hc.Get(loginVerifierRedirect.String())
require.NoError(t, err)
@@ -580,9 +591,8 @@ func TestStrategyLoginConsentNext(t *testing.T) {
hc := testhelpers.NewEmptyJarClient(t)
- t.Run("set up initial session", func(t *testing.T) {
- makeRequestAndExpectCode(t, hc, c, url.Values{"redirect_uri": {c.RedirectURIs[0]}})
- })
+ // set up initial session
+ makeRequestAndExpectCode(t, hc, c, url.Values{"redirect_uri": {c.RedirectURIs[0]}})
// By not waiting here we ensure that there are no race conditions when it comes to authenticated_at and
// requested_at time comparisons:
@@ -1017,3 +1027,47 @@ func TestStrategyLoginConsentNext(t *testing.T) {
makeRequestAndExpectCode(t, hc, c, url.Values{"redirect_uri": {c.RedirectURIs[0]}})
})
}
+
+func DropCookieJar(drop *regexp.Regexp) http.CookieJar {
+ jar, _ := cookiejar.New(nil)
+ return &dropCSRFCookieJar{
+ jar: jar,
+ drop: drop,
+ }
+}
+
+type dropCSRFCookieJar struct {
+ jar *cookiejar.Jar
+ drop *regexp.Regexp
+}
+
+var _ http.CookieJar = (*dropCSRFCookieJar)(nil)
+
+func (d *dropCSRFCookieJar) SetCookies(u *url.URL, cookies []*http.Cookie) {
+ for _, c := range cookies {
+ if d.drop.MatchString(c.Name) {
+ continue
+ }
+ d.jar.SetCookies(u, []*http.Cookie{c})
+ }
+}
+
+func (d *dropCSRFCookieJar) Cookies(u *url.URL) []*http.Cookie {
+ return d.jar.Cookies(u)
+}
+
+func newHTTPClientWithFlowCookie(t *testing.T, ctx context.Context, reg interface {
+ ConsentManager() consent.Manager
+ Config() *config.DefaultProvider
+ FlowCipher() *aead.XChaCha20Poly1305
+}, c *client.Client) *http.Client {
+ f, err := reg.ConsentManager().CreateLoginRequest(ctx, &flow.LoginRequest{Client: c})
+ require.NoError(t, err)
+
+ hc := testhelpers.NewEmptyJarClient(t)
+ hc.Jar.SetCookies(reg.Config().OAuth2AuthURL(ctx), []*http.Cookie{
+ {Name: flowctx.FlowCookie(c), Value: x.Must(flowctx.Encode(ctx, reg.FlowCipher(), f))},
+ })
+
+ return hc
+}
diff --git a/docs/flow-cache-design-doc.md b/docs/flow-cache-design-doc.md
new file mode 100644
index 00000000000..22916348936
--- /dev/null
+++ b/docs/flow-cache-design-doc.md
@@ -0,0 +1,167 @@
+# Flow Cache Design Doc
+
+## Overview
+
+This design doc outlines the proposed solution for caching the flow object in
+the OAuth2 exchange between the Client, Ory Hydra, and the Consent and Login
+UIs. The flow object contains the state of the authorization request.
+
+## Problem Statement
+
+Currently, the flow object is stored in the database on the Ory Hydra server.
+This approach has several drawbacks:
+
+- Each step of the OAuth2 flow (initialization, consent, login, etc.) requires a
+ database query to retrieve the flow object, and another to update it.
+- Each part of the exchanges supplies different values (login challenge, consent
+ challenge, etc.) to identify the flow object. This means the database table
+ has multiple indices that slow down insertions.
+
+## Proposed Solution
+
+The proposed solution is to store the flow object in client cookies and URLs.
+This way, the flow object is written only once when the flow is completed and
+the final authorization code is generated.
+
+### Requirements
+
+- The flow object must be stored in client cookies and URLs.
+- The flow object must be secure and protect against unauthorized access.
+- The flow object must be persistent, so that the flow can be resumed if the
+ user navigates away from the page or closes the browser.
+- The flow object must be scalable and able to handle a large number of
+ concurrent requests.
+
+### Architecture
+
+The proposed architecture for the flow cache is as follows:
+
+- Store the flow object in an AEAD encrypted cookie.
+- Pass a partial flow around in the URL.
+- Use a secure connection to protect against unauthorized access.
+
+```mermaid
+sequenceDiagram
+ actor Client
+ participant Hydra
+ participant LoginUI as Login UI
+ participant ConsentUI as Consent UI
+ % participant Callback
+
+ autonumber
+
+ Client->>+Hydra: GET /oauth2/auth?client_id=CLIENT_ID&response_type=code&scope=SCOPES&state=STATE
+ Hydra->>-Client: Redirect to
http://login.local/?login_challenge=LOGIN_CHALLENGE
+
+ Client->>+LoginUI: GET /?login_challenge=LOGIN_CHALLENGE
+ LoginUI->>Hydra: GET /admin/oauth2/auth/requests/login
+ Hydra->>LoginUI: oAuth2LoginRequest
+ alt accept login
+ LoginUI->>Hydra: PUT /admin/oauth2/auth/requests/login/accept
+ else reject login
+ LoginUI->>Hydra: PUT /admin/oauth2/auth/requests/login/reject
+ end
+ Hydra->>LoginUI: oAuth2RedirectTo
+ LoginUI->>-Client: Redirect to
http://hydra.local/oauth2/auth?client_id=CLIENT_ID&login_verifier=LOGIN_VERIFIER&response_type=code&scope=SCOPES&state=STATE
+
+ Client->>+Hydra: GET /oauth2/auth?client_id=CLIENT_ID&login_verifier=LOGIN_VERIFIER&response_type=code&scope=SCOPES&state=STATE
+ Hydra->>-Client: Redirect to
http://consent.local/?consent_challenge=CONSENT_CHALLENGE
+
+ Client->>+ConsentUI: GET /?consent_challenge=CONSENT_CHALLENGE
+ ConsentUI->>Hydra: GET /admin/oauth2/auth/requests/consent
+ Hydra->>ConsentUI: oAuth2ConsentRequest
+ alt accept login
+ ConsentUI->>Hydra: PUT /admin/oauth2/auth/requests/consent/accept
+ else reject login
+ ConsentUI->>Hydra: PUT /admin/oauth2/auth/requests/consent/reject
+ end
+ Hydra->>ConsentUI: oAuth2RedirectTo
+ ConsentUI->>-Client: Redirect to
http://hydra.local/oauth2/auth?client_id=CLIENT_ID&consent_verifier=CONSENT_VERIFIER&response_type=code&scope=SCOPES&state=STATE
+
+ Client->>+Hydra: GET /oauth2/auth?client_id=CLIENT_ID&consent_verifier=CONSENT_VERIFIER&response_type=code&scope=SCOPES&state=STATE
+ Hydra->>-Client: Redirect to
http://callback.local/callback?code=AUTH_CODE&scope=SCOPES&state=STATE
+ Note over Hydra,Client: next, exchange code for token.
+
+
+ % Client->>+Callback: GET /callback?code=AUTH_CODE&scope=SCOPES&state=STATE
+ % Callback->>-Client: Return Authorization Code
+```
+
+Step 2:
+
+- Set the whole flow as an AEAD encrypted cookie on the client
+- The cookie is keyed by the `state`, so that multiple flows can run in parallel
+ from one cookie jar
+- Set the `LOGIN_CHALLENGE` to the AEAD-encrypted flow
+
+Step 5:
+
+- Decrypt the flow from the `LOGIN_CHALLENGE`, return the `oAuth2LoginRequest`
+
+Step 8:
+
+- Encode the flow into the redirect URL in `oAuth2RedirectTo` as the
+ `LOGIN_VERIFIER`
+
+Step 11
+
+- Check that the login challenge in the `LOGIN_VERIFIER` matches the challenge
+ in the flow cookie.
+- Update the flow based on the request from the `LOGIN_VERIFIER`
+- Update the cookie
+- Set the `CONSENT_CHALLENGE` to the AEAD-encrypted flow
+
+Step 14:
+
+- Decrypt the flow from the `CONSENT_CHALLENGE`
+
+Step 17:
+
+- Encode the flow into the redirect URL in `oAuth2RedirectTo` as the
+ `CONSENT_VERIFIER`
+
+Step 20
+
+- Check that the consent challenge in the `CONSENT_VERIFIER` matches the
+ challenge in the flow cookie.
+- Update the flow based on the request from the `CONSENT_VERIFIER`
+- Update the cookie
+- Write the flow to the database
+- Continue the flow as currently implemented (generate the authentication code,
+ return the code, etc.)
+
+### Client HTTP requests
+
+For reference, these HTTP requests are issued by the client:
+
+```
+GET http://hydra.local/oauth2/auth?client_id=CLIENT_ID&nonce=NONCE&response_type=code&scope=SCOPES&state=STATE
+Redirect to http://login.local/?login_challenge=LOGIN_CHALLENGE
+GET http://login.local/?login_challenge=LOGIN_CHALLENGE
+Redirect to http://hydra.local/oauth2/auth?client_id=CLIENT_ID&login_verifier=LOGIN_VERIFIER&nonce=NONCE&response_type=code&scope=SCOPES&state=STATE
+GET http://hydra.local/oauth2/auth?client_id=CLIENT_ID&login_verifier=LOGIN_VERIFIER&nonce=NONCE&response_type=code&scope=SCOPES&state=STATE
+Redirect to http://consent.local/?consent_challenge=CONSENT_CHALLENGE
+GET http://consent.local/?consent_challenge=CONSENT_CHALLENGE
+Redirect to http://hydra.local/oauth2/auth?client_id=CLIENT_ID&consent_verifier=CONSENT_VERIFIER&nonce=NONCE&response_type=code&scope=SCOPES&state=STATE
+GET http://hydra.local/oauth2/auth?client_id=CLIENT_ID&consent_verifier=CONSENT_VERIFIER&nonce=NONCE&response_type=code&scope=SCOPES&state=STATE
+Redirect to http://callback.local/callback?code=AUTH_CODE&scope=SCOPES&state=STATE
+GET http://callback.local/callback?code=AUTH_CODE&scope=SCOPES&state=STATE
+```
+
+### Implementation
+
+The implementation of the flow cache will involve the following steps:
+
+1. Modify the Ory Hydra server to store the flow object in an AEAD encrypted
+ cookie.
+2. Modify the Consent and Login UIs to include the flow object in the URL.
+3. Use HTTPS to protect against unauthorized access.
+
+## Conclusion
+
+The proposed solution for caching the flow object in the OAuth2 exchange between
+the Client, Ory Hydra, and the Consent and Login UIs is to store the flow object
+in client cookies and URLs. This approach eliminates the need for a distributed
+cache and provides a scalable and secure solution. The flow object will be
+stored in an AEAD encrypted cookie and passed around in the URL. HTTPS will be
+used to protect against unauthorized access.
diff --git a/driver/registry.go b/driver/registry.go
index ccf80db1448..8fb5a5da1cd 100644
--- a/driver/registry.go
+++ b/driver/registry.go
@@ -6,8 +6,11 @@ package driver
import (
"context"
+ "go.opentelemetry.io/otel/trace"
+
"github.com/ory/x/httprouterx"
+ "github.com/ory/hydra/v2/aead"
"github.com/ory/hydra/v2/hsm"
"github.com/ory/x/contextx"
@@ -46,9 +49,11 @@ type Registry interface {
WithConfig(c *config.DefaultProvider) Registry
WithContextualizer(ctxer contextx.Contextualizer) Registry
WithLogger(l *logrusx.Logger) Registry
+ WithTracer(t trace.Tracer) Registry
x.HTTPClientProvider
GetJWKSFetcherStrategy() fosite.JWKSFetcherStrategy
+ contextx.Provider
config.Provider
persistence.Provider
x.RegistryLogger
@@ -61,6 +66,7 @@ type Registry interface {
oauth2.Registry
PrometheusManager() *prometheus.MetricsManager
x.TracingProvider
+ FlowCipher() *aead.XChaCha20Poly1305
RegisterRoutes(ctx context.Context, admin *httprouterx.RouterAdmin, public *httprouterx.RouterPublic)
ClientHandler() *client.Handler
@@ -109,6 +115,7 @@ func CallRegistry(ctx context.Context, r Registry) {
r.SubjectIdentifierAlgorithm(ctx)
r.KeyManager()
r.KeyCipher()
+ r.FlowCipher()
r.OAuth2Storage()
r.OAuth2Provider()
r.AudienceStrategy()
diff --git a/driver/registry_base.go b/driver/registry_base.go
index d2d458427a8..fe2a1a2a7c5 100644
--- a/driver/registry_base.go
+++ b/driver/registry_base.go
@@ -15,12 +15,14 @@ import (
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/cors"
+ "go.opentelemetry.io/otel/trace"
"github.com/ory/fosite"
"github.com/ory/fosite/compose"
foauth2 "github.com/ory/fosite/handler/oauth2"
"github.com/ory/fosite/handler/openid"
"github.com/ory/herodot"
+ "github.com/ory/hydra/v2/aead"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/driver/config"
@@ -59,7 +61,8 @@ type RegistryBase struct {
ctxer contextx.Contextualizer
hh *healthx.Handler
migrationStatus *popx.MigrationStatuses
- kc *jwk.AEAD
+ kc *aead.AESGCM
+ flowc *aead.XChaCha20Poly1305
cos consent.Strategy
writer herodot.Writer
hsm hsm.Context
@@ -187,6 +190,11 @@ func (m *RegistryBase) WithLogger(l *logrusx.Logger) Registry {
return m.r
}
+func (m *RegistryBase) WithTracer(t trace.Tracer) Registry {
+ m.trc = new(otelx.Tracer).WithOTLP(t)
+ return m.r
+}
+
func (m *RegistryBase) Logger() *logrusx.Logger {
if m.l == nil {
m.l = logrusx.New("Ory Hydra", m.BuildVersion())
@@ -282,13 +290,20 @@ func (m *RegistryBase) ConsentStrategy() consent.Strategy {
return m.cos
}
-func (m *RegistryBase) KeyCipher() *jwk.AEAD {
+func (m *RegistryBase) KeyCipher() *aead.AESGCM {
if m.kc == nil {
- m.kc = jwk.NewAEAD(m.Config())
+ m.kc = aead.NewAESGCM(m.Config())
}
return m.kc
}
+func (m *RegistryBase) FlowCipher() *aead.XChaCha20Poly1305 {
+ if m.flowc == nil {
+ m.flowc = aead.NewXChaCha20Poly1305(m.Config())
+ }
+ return m.flowc
+}
+
func (m *RegistryBase) CookieStore(ctx context.Context) (sessions.Store, error) {
var keys [][]byte
secrets, err := m.conf.GetCookieSecrets(ctx)
diff --git a/consent/types.go b/flow/consent_types.go
similarity index 98%
rename from consent/types.go
rename to flow/consent_types.go
index 6a389e9d8bb..89e56ef8aa7 100644
--- a/consent/types.go
+++ b/flow/consent_types.go
@@ -1,7 +1,7 @@
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
-package consent
+package flow
import (
"database/sql"
@@ -23,8 +23,8 @@ import (
)
const (
- consentRequestDeniedErrorName = "consent request denied"
- loginRequestDeniedErrorName = "login request denied"
+ ConsentRequestDeniedErrorName = "consent request denied"
+ LoginRequestDeniedErrorName = "login request denied"
)
// OAuth 2.0 Redirect Browser To
@@ -49,7 +49,7 @@ type LoginSession struct {
Remember bool `db:"remember"`
}
-func (_ LoginSession) TableName() string {
+func (LoginSession) TableName() string {
return "hydra_oauth2_authentication_session"
}
@@ -77,11 +77,12 @@ type RequestDeniedError struct {
// to the public but only in the server logs.
Debug string `json:"error_debug"`
- valid bool
+ // swagger:ignore
+ Valid bool `json:"valid"`
}
func (e *RequestDeniedError) IsError() bool {
- return e != nil && e.valid
+ return e != nil && e.Valid
}
func (e *RequestDeniedError) SetDefaults(name string) {
@@ -94,7 +95,7 @@ func (e *RequestDeniedError) SetDefaults(name string) {
}
}
-func (e *RequestDeniedError) toRFCError() *fosite.RFC6749Error {
+func (e *RequestDeniedError) ToRFCError() *fosite.RFC6749Error {
if e.Name == "" {
e.Name = "request_denied"
}
@@ -112,7 +113,7 @@ func (e *RequestDeniedError) toRFCError() *fosite.RFC6749Error {
}
}
-func (e *RequestDeniedError) Scan(value interface{}) error {
+func (e *RequestDeniedError) Scan(value any) error {
v := fmt.Sprintf("%s", value)
if len(v) == 0 || v == "{}" {
return nil
@@ -122,7 +123,7 @@ func (e *RequestDeniedError) Scan(value interface{}) error {
return errorsx.WithStack(err)
}
- e.valid = true
+ e.Valid = true
return nil
}
@@ -188,6 +189,8 @@ func (r *AcceptOAuth2ConsentRequest) HasError() bool {
// List of OAuth 2.0 Consent Sessions
//
// swagger:model oAuth2ConsentSessions
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type oAuth2ConsentSessions []OAuth2ConsentSession
// OAuth 2.0 Consent Session
@@ -420,7 +423,7 @@ type LogoutRequest struct {
Client *client.Client `json:"client" db:"-"`
}
-func (_ LogoutRequest) TableName() string {
+func (LogoutRequest) TableName() string {
return "hydra_oauth2_logout_request"
}
diff --git a/consent/types_test.go b/flow/consent_types_test.go
similarity index 89%
rename from consent/types_test.go
rename to flow/consent_types_test.go
index 6366404d9e9..116b0f328bb 100644
--- a/consent/types_test.go
+++ b/flow/consent_types_test.go
@@ -1,7 +1,7 @@
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
-package consent
+package flow
import (
"fmt"
@@ -21,7 +21,7 @@ func TestToRFCError(t *testing.T) {
{
input: &RequestDeniedError{
Name: "not empty",
- valid: true,
+ Valid: true,
},
expect: &fosite.RFC6749Error{
ErrorField: "not empty",
@@ -34,7 +34,7 @@ func TestToRFCError(t *testing.T) {
input: &RequestDeniedError{
Name: "",
Description: "not empty",
- valid: true,
+ Valid: true,
},
expect: &fosite.RFC6749Error{
ErrorField: "request_denied",
@@ -44,7 +44,7 @@ func TestToRFCError(t *testing.T) {
},
},
{
- input: &RequestDeniedError{valid: true},
+ input: &RequestDeniedError{Valid: true},
expect: &fosite.RFC6749Error{
ErrorField: "request_denied",
DescriptionField: "",
@@ -55,7 +55,7 @@ func TestToRFCError(t *testing.T) {
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
- require.EqualValues(t, tc.input.toRFCError(), tc.expect)
+ require.EqualValues(t, tc.input.ToRFCError(), tc.expect)
})
}
}
diff --git a/flow/flow.go b/flow/flow.go
index bbf2e36fec9..0868e7f5f14 100644
--- a/flow/flow.go
+++ b/flow/flow.go
@@ -4,15 +4,16 @@
package flow
import (
+ "context"
"time"
+ "github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
- "github.com/gobuffalo/pop/v6"
-
+ "github.com/ory/hydra/v2/aead"
"github.com/ory/hydra/v2/client"
- "github.com/ory/hydra/v2/consent"
+ "github.com/ory/hydra/v2/oauth2/flowctx"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/sqlcon"
"github.com/ory/x/sqlxx"
@@ -113,7 +114,7 @@ type Flow struct {
// OpenIDConnectContext provides context for the (potential) OpenID Connect context. Implementation of these
// values in your app are optional but can be useful if you want to be fully compliant with the OpenID Connect spec.
- OpenIDConnectContext *consent.OAuth2ConsentRequestOpenIDConnectContext `db:"oidc_context"`
+ OpenIDConnectContext *OAuth2ConsentRequestOpenIDConnectContext `db:"oidc_context"`
// Client is the OAuth 2.0 Client that initiated the request.
//
@@ -195,8 +196,8 @@ type Flow struct {
// recommend redirecting the user to `request_url` to re-initiate the flow.
LoginWasUsed bool `db:"login_was_used"`
- LoginError *consent.RequestDeniedError `db:"login_error"`
- LoginAuthenticatedAt sqlxx.NullTime `db:"login_authenticated_at"`
+ LoginError *RequestDeniedError `db:"login_error"`
+ LoginAuthenticatedAt sqlxx.NullTime `db:"login_authenticated_at"`
// ConsentChallengeID is the identifier ("authorization challenge") of the consent authorization request. It is used to
// identify the session.
@@ -231,13 +232,13 @@ type Flow struct {
// ConsentWasHandled set to true means that the request was already handled.
// This can happen on form double-submit or other errors. If this is set we
// recommend redirecting the user to `request_url` to re-initiate the flow.
- ConsentWasHandled bool `db:"consent_was_used"`
- ConsentError *consent.RequestDeniedError `db:"consent_error"`
- SessionIDToken sqlxx.MapStringInterface `db:"session_id_token" faker:"-"`
- SessionAccessToken sqlxx.MapStringInterface `db:"session_access_token" faker:"-"`
+ ConsentWasHandled bool `db:"consent_was_used"`
+ ConsentError *RequestDeniedError `db:"consent_error"`
+ SessionIDToken sqlxx.MapStringInterface `db:"session_id_token" faker:"-"`
+ SessionAccessToken sqlxx.MapStringInterface `db:"session_access_token" faker:"-"`
}
-func NewFlow(r *consent.LoginRequest) *Flow {
+func NewFlow(r *LoginRequest) *Flow {
return &Flow{
ID: r.ID,
RequestedScope: r.RequestedScope,
@@ -259,7 +260,7 @@ func NewFlow(r *consent.LoginRequest) *Flow {
}
}
-func (f *Flow) HandleLoginRequest(h *consent.HandledLoginRequest) error {
+func (f *Flow) HandleLoginRequest(h *HandledLoginRequest) error {
if f.LoginWasUsed {
return errors.WithStack(x.ErrConflict.WithHint("The login request was already used and can no longer be changed."))
}
@@ -301,8 +302,8 @@ func (f *Flow) HandleLoginRequest(h *consent.HandledLoginRequest) error {
return nil
}
-func (f *Flow) GetHandledLoginRequest() consent.HandledLoginRequest {
- return consent.HandledLoginRequest{
+func (f *Flow) GetHandledLoginRequest() HandledLoginRequest {
+ return HandledLoginRequest{
ID: f.ID,
Remember: f.LoginRemember,
RememberFor: f.LoginRememberFor,
@@ -320,8 +321,8 @@ func (f *Flow) GetHandledLoginRequest() consent.HandledLoginRequest {
}
}
-func (f *Flow) GetLoginRequest() *consent.LoginRequest {
- return &consent.LoginRequest{
+func (f *Flow) GetLoginRequest() *LoginRequest {
+ return &LoginRequest{
ID: f.ID,
RequestedScope: f.RequestedScope,
RequestedAudience: f.RequestedAudience,
@@ -355,7 +356,7 @@ func (f *Flow) InvalidateLoginRequest() error {
return nil
}
-func (f *Flow) HandleConsentRequest(r *consent.AcceptOAuth2ConsentRequest) error {
+func (f *Flow) HandleConsentRequest(r *AcceptOAuth2ConsentRequest) error {
if time.Time(r.HandledAt).IsZero() {
return errors.New("refusing to handle a consent request with null HandledAt")
}
@@ -408,8 +409,8 @@ func (f *Flow) InvalidateConsentRequest() error {
return nil
}
-func (f *Flow) GetConsentRequest() *consent.OAuth2ConsentRequest {
- return &consent.OAuth2ConsentRequest{
+func (f *Flow) GetConsentRequest() *OAuth2ConsentRequest {
+ cs := OAuth2ConsentRequest{
ID: f.ConsentChallengeID.String(),
RequestedScope: f.RequestedScope,
RequestedAudience: f.RequestedAudience,
@@ -431,18 +432,22 @@ func (f *Flow) GetConsentRequest() *consent.OAuth2ConsentRequest {
AuthenticatedAt: f.LoginAuthenticatedAt,
RequestedAt: f.RequestedAt,
}
+ if cs.AMR == nil {
+ cs.AMR = []string{}
+ }
+ return &cs
}
-func (f *Flow) GetHandledConsentRequest() *consent.AcceptOAuth2ConsentRequest {
+func (f *Flow) GetHandledConsentRequest() *AcceptOAuth2ConsentRequest {
crf := 0
if f.ConsentRememberFor != nil {
crf = *f.ConsentRememberFor
}
- return &consent.AcceptOAuth2ConsentRequest{
+ return &AcceptOAuth2ConsentRequest{
ID: f.ConsentChallengeID.String(),
GrantedScope: f.GrantedScope,
GrantedAudience: f.GrantedAudience,
- Session: &consent.AcceptOAuth2ConsentRequestSession{AccessToken: f.SessionAccessToken, IDToken: f.SessionIDToken},
+ Session: &AcceptOAuth2ConsentRequestSession{AccessToken: f.SessionAccessToken, IDToken: f.SessionIDToken},
Remember: f.ConsentRemember,
RememberFor: crf,
HandledAt: f.ConsentHandledAt,
@@ -456,7 +461,7 @@ func (f *Flow) GetHandledConsentRequest() *consent.AcceptOAuth2ConsentRequest {
}
}
-func (_ Flow) TableName() string {
+func (Flow) TableName() string {
return "hydra_oauth2_flow"
}
@@ -470,15 +475,15 @@ func (f *Flow) BeforeSave(_ *pop.Connection) error {
return nil
}
-// TODO Populate the client field in FindInDB and FindByConsentChallengeID in
-// order to avoid accessing the database twice.
func (f *Flow) AfterFind(c *pop.Connection) error {
+ // TODO Populate the client field in FindInDB and FindByConsentChallengeID in
+ // order to avoid accessing the database twice.
f.AfterSave(c)
f.Client = &client.Client{}
return sqlcon.HandleError(c.Where("id = ? AND nid = ?", f.ClientID, f.NID).First(f.Client))
}
-func (f *Flow) AfterSave(c *pop.Connection) {
+func (f *Flow) AfterSave(_ *pop.Connection) {
if f.SessionAccessToken == nil {
f.SessionAccessToken = make(map[string]interface{})
}
@@ -486,3 +491,27 @@ func (f *Flow) AfterSave(c *pop.Connection) {
f.SessionIDToken = make(map[string]interface{})
}
}
+
+type CipherProvider interface {
+ FlowCipher() *aead.XChaCha20Poly1305
+}
+
+// ToLoginChallenge converts the flow into a login challenge.
+func (f *Flow) ToLoginChallenge(ctx context.Context, cipherProvider CipherProvider) (string, error) {
+ return flowctx.Encode(ctx, cipherProvider.FlowCipher(), f, flowctx.AsLoginChallenge)
+}
+
+// ToLoginVerifier converts the flow into a login verifier.
+func (f *Flow) ToLoginVerifier(ctx context.Context, cipherProvider CipherProvider) (string, error) {
+ return flowctx.Encode(ctx, cipherProvider.FlowCipher(), f, flowctx.AsLoginVerifier)
+}
+
+// ToConsentChallenge converts the flow into a consent challenge.
+func (f *Flow) ToConsentChallenge(ctx context.Context, cipherProvider CipherProvider) (string, error) {
+ return flowctx.Encode(ctx, cipherProvider.FlowCipher(), f, flowctx.AsConsentChallenge)
+}
+
+// ToConsentVerifier converts the flow into a consent verifier.
+func (f *Flow) ToConsentVerifier(ctx context.Context, cipherProvider CipherProvider) (string, error) {
+ return flowctx.Encode(ctx, cipherProvider.FlowCipher(), f, flowctx.AsConsentVerifier)
+}
diff --git a/flow/flow_test.go b/flow/flow_test.go
index c00e7524b2e..7876e9a1a63 100644
--- a/flow/flow_test.go
+++ b/flow/flow_test.go
@@ -7,17 +7,16 @@ import (
"testing"
"time"
- "github.com/instana/testify/require"
"github.com/mohae/deepcopy"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"github.com/bxcodec/faker/v3"
- "github.com/ory/hydra/v2/consent"
"github.com/ory/x/sqlxx"
)
-func (f *Flow) setLoginRequest(r *consent.LoginRequest) {
+func (f *Flow) setLoginRequest(r *LoginRequest) {
f.ID = r.ID
f.RequestedScope = r.RequestedScope
f.RequestedAudience = r.RequestedAudience
@@ -36,7 +35,7 @@ func (f *Flow) setLoginRequest(r *consent.LoginRequest) {
f.RequestedAt = r.RequestedAt
}
-func (f *Flow) setHandledLoginRequest(r *consent.HandledLoginRequest) {
+func (f *Flow) setHandledLoginRequest(r *HandledLoginRequest) {
f.ID = r.ID
f.LoginRemember = r.Remember
f.LoginRememberFor = r.RememberFor
@@ -52,7 +51,7 @@ func (f *Flow) setHandledLoginRequest(r *consent.HandledLoginRequest) {
f.LoginAuthenticatedAt = r.AuthenticatedAt
}
-func (f *Flow) setConsentRequest(r consent.OAuth2ConsentRequest) {
+func (f *Flow) setConsentRequest(r OAuth2ConsentRequest) {
f.ConsentChallengeID = sqlxx.NullString(r.ID)
f.RequestedScope = r.RequestedScope
f.RequestedAudience = r.RequestedAudience
@@ -75,7 +74,7 @@ func (f *Flow) setConsentRequest(r consent.OAuth2ConsentRequest) {
f.RequestedAt = r.RequestedAt
}
-func (f *Flow) setHandledConsentRequest(r consent.AcceptOAuth2ConsentRequest) {
+func (f *Flow) setHandledConsentRequest(r AcceptOAuth2ConsentRequest) {
f.ConsentChallengeID = sqlxx.NullString(r.ID)
f.GrantedScope = r.GrantedScope
f.GrantedAudience = r.GrantedAudience
@@ -93,7 +92,7 @@ func (f *Flow) setHandledConsentRequest(r consent.AcceptOAuth2ConsentRequest) {
func TestFlow_GetLoginRequest(t *testing.T) {
t.Run("GetLoginRequest should set all fields on its return value", func(t *testing.T) {
f := Flow{}
- expected := consent.LoginRequest{}
+ expected := LoginRequest{}
assert.NoError(t, faker.FakeData(&expected))
f.setLoginRequest(&expected)
actual := f.GetLoginRequest()
@@ -104,7 +103,7 @@ func TestFlow_GetLoginRequest(t *testing.T) {
func TestFlow_GetHandledLoginRequest(t *testing.T) {
t.Run("GetHandledLoginRequest should set all fields on its return value", func(t *testing.T) {
f := Flow{}
- expected := consent.HandledLoginRequest{}
+ expected := HandledLoginRequest{}
assert.NoError(t, faker.FakeData(&expected))
f.setHandledLoginRequest(&expected)
actual := f.GetHandledLoginRequest()
@@ -117,7 +116,7 @@ func TestFlow_GetHandledLoginRequest(t *testing.T) {
func TestFlow_NewFlow(t *testing.T) {
t.Run("NewFlow and GetLoginRequest should use all LoginRequest fields", func(t *testing.T) {
- expected := &consent.LoginRequest{}
+ expected := &LoginRequest{}
assert.NoError(t, faker.FakeData(expected))
actual := NewFlow(expected).GetLoginRequest()
assert.Equal(t, expected, actual)
@@ -132,7 +131,7 @@ func TestFlow_HandleLoginRequest(t *testing.T) {
assert.NoError(t, faker.FakeData(&f))
f.State = FlowStateLoginInitialized
- r := consent.HandledLoginRequest{}
+ r := HandledLoginRequest{}
assert.NoError(t, faker.FakeData(&r))
r.ID = f.ID
r.Subject = f.Subject
@@ -152,12 +151,12 @@ func TestFlow_HandleLoginRequest(t *testing.T) {
func TestFlow_InvalidateLoginRequest(t *testing.T) {
t.Run("InvalidateLoginRequest should transition the flow into FlowStateLoginUsed", func(t *testing.T) {
- f := NewFlow(&consent.LoginRequest{
+ f := NewFlow(&LoginRequest{
ID: "t3-id",
Subject: "t3-sub",
WasHandled: false,
})
- assert.NoError(t, f.HandleLoginRequest(&consent.HandledLoginRequest{
+ assert.NoError(t, f.HandleLoginRequest(&HandledLoginRequest{
ID: "t3-id",
Subject: "t3-sub",
WasHandled: false,
@@ -167,12 +166,12 @@ func TestFlow_InvalidateLoginRequest(t *testing.T) {
assert.Equal(t, true, f.LoginWasUsed)
})
t.Run("InvalidateLoginRequest should fail when flow.LoginWasUsed is true", func(t *testing.T) {
- f := NewFlow(&consent.LoginRequest{
+ f := NewFlow(&LoginRequest{
ID: "t3-id",
Subject: "t3-sub",
WasHandled: false,
})
- assert.NoError(t, f.HandleLoginRequest(&consent.HandledLoginRequest{
+ assert.NoError(t, f.HandleLoginRequest(&HandledLoginRequest{
ID: "t3-id",
Subject: "t3-sub",
WasHandled: true,
@@ -186,7 +185,7 @@ func TestFlow_InvalidateLoginRequest(t *testing.T) {
func TestFlow_GetConsentRequest(t *testing.T) {
t.Run("GetConsentRequest should set all fields on its return value", func(t *testing.T) {
f := Flow{}
- expected := consent.OAuth2ConsentRequest{}
+ expected := OAuth2ConsentRequest{}
assert.NoError(t, faker.FakeData(&expected))
f.setConsentRequest(expected)
actual := f.GetConsentRequest()
@@ -198,13 +197,13 @@ func TestFlow_HandleConsentRequest(t *testing.T) {
f := Flow{}
require.NoError(t, faker.FakeData(&f))
- expected := consent.AcceptOAuth2ConsentRequest{}
+ expected := AcceptOAuth2ConsentRequest{}
require.NoError(t, faker.FakeData(&expected))
expected.ID = string(f.ConsentChallengeID)
expected.HandledAt = sqlxx.NullTime(time.Now())
expected.RequestedAt = f.RequestedAt
- expected.Session = &consent.AcceptOAuth2ConsentRequestSession{
+ expected.Session = &AcceptOAuth2ConsentRequestSession{
IDToken: sqlxx.MapStringInterface{"claim1": "value1", "claim2": "value2"},
AccessToken: sqlxx.MapStringInterface{"claim3": "value3", "claim4": "value4"},
}
@@ -215,7 +214,7 @@ func TestFlow_HandleConsentRequest(t *testing.T) {
f.ConsentWasHandled = false
fGood := deepcopy.Copy(f).(Flow)
- eGood := deepcopy.Copy(expected).(consent.AcceptOAuth2ConsentRequest)
+ eGood := deepcopy.Copy(expected).(AcceptOAuth2ConsentRequest)
require.NoError(t, f.HandleConsentRequest(&expected))
t.Run("HandleConsentRequest should fail when already handled", func(t *testing.T) {
@@ -232,7 +231,7 @@ func TestFlow_HandleConsentRequest(t *testing.T) {
t.Run("HandleConsentRequest should fail when HandledAt in its argument is zero", func(t *testing.T) {
f := deepcopy.Copy(fGood).(Flow)
- eBad := deepcopy.Copy(eGood).(consent.AcceptOAuth2ConsentRequest)
+ eBad := deepcopy.Copy(eGood).(AcceptOAuth2ConsentRequest)
eBad.HandledAt = sqlxx.NullTime(time.Time{})
require.Error(t, f.HandleConsentRequest(&eBad))
})
@@ -249,11 +248,11 @@ func TestFlow_HandleConsentRequest(t *testing.T) {
func TestFlow_GetHandledConsentRequest(t *testing.T) {
t.Run("GetHandledConsentRequest should set all fields on its return value", func(t *testing.T) {
f := Flow{}
- expected := consent.AcceptOAuth2ConsentRequest{}
+ expected := AcceptOAuth2ConsentRequest{}
assert.NoError(t, faker.FakeData(&expected))
expected.ConsentRequest = nil
- expected.Session = &consent.AcceptOAuth2ConsentRequestSession{
+ expected.Session = &AcceptOAuth2ConsentRequestSession{
IDToken: sqlxx.MapStringInterface{"claim1": "value1", "claim2": "value2"},
AccessToken: sqlxx.MapStringInterface{"claim3": "value3", "claim4": "value4"},
}
diff --git a/fositex/token_strategy_test.go b/fositex/token_strategy_test.go
index 894b2470317..e308de58ef4 100644
--- a/fositex/token_strategy_test.go
+++ b/fositex/token_strategy_test.go
@@ -4,6 +4,7 @@
package fositex
import (
+ "context"
"testing"
"github.com/stretchr/testify/assert"
@@ -14,6 +15,8 @@ import (
// Test that the generic signature function implements the same signature as the
// HMAC and JWT strategies.
func TestAccessTokenSignature(t *testing.T) {
+ ctx := context.Background()
+
t.Run("strategy=DefaultJWTStrategy", func(t *testing.T) {
strategy := new(oauth2.DefaultJWTStrategy)
for _, tc := range []struct{ token string }{
@@ -25,7 +28,7 @@ func TestAccessTokenSignature(t *testing.T) {
} {
t.Run("case="+tc.token, func(t *testing.T) {
assert.Equal(t,
- strategy.AccessTokenSignature(nil, tc.token),
+ strategy.AccessTokenSignature(ctx, tc.token),
genericSignature(tc.token))
})
}
@@ -41,7 +44,7 @@ func TestAccessTokenSignature(t *testing.T) {
} {
t.Run("case="+tc.token, func(t *testing.T) {
assert.Equal(t,
- strategy.AccessTokenSignature(nil, tc.token),
+ strategy.AccessTokenSignature(ctx, tc.token),
genericSignature(tc.token))
})
}
diff --git a/go.mod b/go.mod
index eea9903722e..d39f63a04a7 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module github.com/ory/hydra/v2
-go 1.19
+go 1.20
replace (
github.com/bradleyjkemp/cupaloy/v2 => github.com/aeneasr/cupaloy/v2 v2.6.1-0.20210924214125-3dfdd01210a3
@@ -28,9 +28,7 @@ require (
github.com/google/uuid v1.3.0
github.com/gorilla/securecookie v1.1.1
github.com/gorilla/sessions v1.2.1
- github.com/gtank/cryptopasta v0.0.0-20170601214702-1f550f6f2f69
github.com/hashicorp/go-retryablehttp v0.7.2
- github.com/instana/testify v1.6.2-0.20200721153833-94b1851f4d65
github.com/jackc/pgx/v4 v4.17.2
github.com/julienschmidt/httprouter v1.3.0
github.com/luna-duclos/instrumentedsql v1.1.3
@@ -61,10 +59,16 @@ require (
github.com/toqueteos/webbrowser v1.2.0
github.com/twmb/murmur3 v1.1.6
github.com/urfave/negroni v1.0.0
+ go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4
go.opentelemetry.io/otel v1.11.1
+ go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.9.0
+ go.opentelemetry.io/otel/sdk v1.11.1
+ go.opentelemetry.io/otel/trace v1.11.1
go.uber.org/automaxprocs v1.3.0
+ golang.org/x/crypto v0.9.0
+ golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
golang.org/x/oauth2 v0.5.0
- golang.org/x/tools v0.7.0
+ golang.org/x/tools v0.9.1
gopkg.in/square/go-jose.v2 v2.6.0
)
@@ -215,25 +219,20 @@ require (
github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c // indirect
go.mongodb.org/mongo-driver v1.10.3 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.36.4 // indirect
- go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 // indirect
go.opentelemetry.io/contrib/propagators/b3 v1.11.1 // indirect
go.opentelemetry.io/contrib/propagators/jaeger v1.11.1 // indirect
go.opentelemetry.io/contrib/samplers/jaegerremote v0.5.2 // indirect
go.opentelemetry.io/otel/exporters/jaeger v1.11.1 // indirect
go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.11.1 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.9.0 // indirect
- go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.9.0 // indirect
go.opentelemetry.io/otel/exporters/zipkin v1.11.1 // indirect
go.opentelemetry.io/otel/metric v0.33.0 // indirect
- go.opentelemetry.io/otel/sdk v1.11.1 // indirect
- go.opentelemetry.io/otel/trace v1.11.1 // indirect
go.opentelemetry.io/proto/otlp v0.18.0 // indirect
- golang.org/x/crypto v0.1.0 // indirect
golang.org/x/mod v0.10.0 // indirect
- golang.org/x/net v0.8.0 // indirect
- golang.org/x/sync v0.1.0 // indirect
- golang.org/x/sys v0.7.0 // indirect
- golang.org/x/text v0.8.0 // indirect
+ golang.org/x/net v0.10.0 // indirect
+ golang.org/x/sync v0.2.0 // indirect
+ golang.org/x/sys v0.8.0 // indirect
+ golang.org/x/text v0.9.0 // indirect
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20230403163135-c38d8f061ccd // indirect
diff --git a/go.sum b/go.sum
index 4e547b545ec..c893e4d9a09 100644
--- a/go.sum
+++ b/go.sum
@@ -452,8 +452,6 @@ github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFb
github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.12.0 h1:kr3j8iIMR4ywO/O0rvksXaJvauGGCMg2zAZIiNZ9uIQ=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.12.0/go.mod h1:ummNFgdgLhhX7aIiy35vVmQNS0rWXknfPE0qe6fmFXg=
-github.com/gtank/cryptopasta v0.0.0-20170601214702-1f550f6f2f69 h1:7xsUJsB2NrdcttQPa7JLEaGzvdbk7KvfrjgHZXOQRo0=
-github.com/gtank/cryptopasta v0.0.0-20170601214702-1f550f6f2f69/go.mod h1:YLEMZOtU+AZ7dhN9T/IpGhXVGly2bvkJQ+zxj3WeVQo=
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ=
@@ -480,8 +478,6 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf h1:FtEj8sfIcaaBfAKrE1Cwb61YDtYq9JxChK1c7AKce7s=
github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf/go.mod h1:yrqSXGoD/4EKfF26AOGzscPOgTTJcyAwM2rpixWT+t4=
-github.com/instana/testify v1.6.2-0.20200721153833-94b1851f4d65 h1:T25FL3WEzgmKB0m6XCJNZ65nw09/QIp3T1yXr487D+A=
-github.com/instana/testify v1.6.2-0.20200721153833-94b1851f4d65/go.mod h1:nYhEREG/B7HUY7P+LKOrqy53TpIqmJ9JyUShcaEKtGw=
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
@@ -1018,8 +1014,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU=
-golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
+golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
+golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -1030,6 +1026,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
+golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
+golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@@ -1105,8 +1103,8 @@ golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.0.0-20221002022538-bcab6841153b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
-golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
-golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
+golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
+golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -1136,8 +1134,9 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
+golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180816055513-1c9583448a9c/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -1219,8 +1218,8 @@ golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
-golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
+golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -1234,8 +1233,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
-golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68=
-golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
+golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
+golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -1305,8 +1304,8 @@ golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
-golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4=
-golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s=
+golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
+golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
diff --git a/health/doc.go b/health/doc.go
index bad9c42139c..a0b2f45cbe8 100644
--- a/health/doc.go
+++ b/health/doc.go
@@ -24,6 +24,8 @@ package health
// Responses:
// 200: healthStatus
// 500: errorOAuth2
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
func swaggerPublicIsInstanceAlive() {}
// Alive returns an ok status if the instance is ready to handle HTTP requests.
@@ -47,6 +49,8 @@ func swaggerPublicIsInstanceAlive() {}
// Responses:
// 200: healthStatus
// 500: errorOAuth2
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
func swaggerAdminIsInstanceAlive() {}
// Ready returns an ok status if the instance is ready to handle HTTP requests and all ReadyCheckers are ok.
@@ -70,6 +74,8 @@ func swaggerAdminIsInstanceAlive() {}
// Responses:
// 200: healthStatus
// 503: healthNotReadyStatus
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
func swaggerAdminIsInstanceReady() {}
// Ready returns an ok status if the instance is ready to handle HTTP requests and all ReadyCheckers are ok.
@@ -93,6 +99,8 @@ func swaggerAdminIsInstanceReady() {}
// Responses:
// 200: healthStatus
// 503: healthNotReadyStatus
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
func swaggerPublicIsInstanceReady() {}
// Version returns this service's versions.
@@ -111,4 +119,6 @@ func swaggerPublicIsInstanceReady() {}
//
// Responses:
// 200: version
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
func swaggerGetVersion() {}
diff --git a/internal/driver.go b/internal/driver.go
index bd8de6ae547..a462f29ace5 100644
--- a/internal/driver.go
+++ b/internal/driver.go
@@ -5,10 +5,11 @@ package internal
import (
"context"
-
"sync"
"testing"
+ "gopkg.in/square/go-jose.v2"
+
"github.com/ory/x/configx"
"github.com/stretchr/testify/require"
@@ -45,19 +46,19 @@ func NewConfigurationWithDefaultsAndHTTPS() *config.DefaultProvider {
return p
}
-func NewRegistryMemory(t *testing.T, c *config.DefaultProvider, ctxer contextx.Contextualizer) driver.Registry {
+func NewRegistryMemory(t testing.TB, c *config.DefaultProvider, ctxer contextx.Contextualizer) driver.Registry {
return newRegistryDefault(t, "memory", c, true, ctxer)
}
-func NewMockedRegistry(t *testing.T, ctxer contextx.Contextualizer) driver.Registry {
+func NewMockedRegistry(t testing.TB, ctxer contextx.Contextualizer) driver.Registry {
return newRegistryDefault(t, "memory", NewConfigurationWithDefaults(), true, ctxer)
}
-func NewRegistrySQLFromURL(t *testing.T, url string, migrate bool, ctxer contextx.Contextualizer) driver.Registry {
+func NewRegistrySQLFromURL(t testing.TB, url string, migrate bool, ctxer contextx.Contextualizer) driver.Registry {
return newRegistryDefault(t, url, NewConfigurationWithDefaults(), migrate, ctxer)
}
-func newRegistryDefault(t *testing.T, url string, c *config.DefaultProvider, migrate bool, ctxer contextx.Contextualizer) driver.Registry {
+func newRegistryDefault(t testing.TB, url string, c *config.DefaultProvider, migrate bool, ctxer contextx.Contextualizer) driver.Registry {
ctx := context.Background()
c.MustSet(ctx, config.KeyLogLevel, "trace")
c.MustSet(ctx, config.KeyDSN, url)
@@ -77,15 +78,15 @@ func CleanAndMigrate(reg driver.Registry) func(*testing.T) {
}
}
-func ConnectToMySQL(t *testing.T) string {
- return dockertest.RunTestMySQLWithVersion(t, "11.8")
+func ConnectToMySQL(t testing.TB) string {
+ return dockertest.RunTestMySQLWithVersion(t, "8.0.26")
}
-func ConnectToPG(t *testing.T) string {
+func ConnectToPG(t testing.TB) string {
return dockertest.RunTestPostgreSQLWithVersion(t, "11.8")
}
-func ConnectToCRDB(t *testing.T) string {
+func ConnectToCRDB(t testing.TB) string {
return dockertest.RunTestCockroachDBWithVersion(t, "v22.1.2")
}
@@ -134,8 +135,8 @@ func ConnectDatabases(t *testing.T, migrate bool, ctxer contextx.Contextualizer)
return
}
-func MustEnsureRegistryKeys(r driver.Registry, key string) {
- if err := jwk.EnsureAsymmetricKeypairExists(context.Background(), r, "RS256", key); err != nil {
+func MustEnsureRegistryKeys(ctx context.Context, r driver.Registry, key string) {
+ if err := jwk.EnsureAsymmetricKeypairExists(ctx, r, string(jose.ES256), key); err != nil {
panic(err)
}
}
diff --git a/internal/testhelpers/janitor_test_helper.go b/internal/testhelpers/janitor_test_helper.go
index 9954d013667..334f490364f 100644
--- a/internal/testhelpers/janitor_test_helper.go
+++ b/internal/testhelpers/janitor_test_helper.go
@@ -20,6 +20,7 @@ import (
"github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/driver"
"github.com/ory/hydra/v2/driver/config"
+ "github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/internal"
"github.com/ory/hydra/v2/oauth2"
"github.com/ory/hydra/v2/oauth2/trust"
@@ -32,8 +33,8 @@ import (
type JanitorConsentTestHelper struct {
uniqueName string
- flushLoginRequests []*consent.LoginRequest
- flushConsentRequests []*consent.OAuth2ConsentRequest
+ flushLoginRequests []*flow.LoginRequest
+ flushConsentRequests []*flow.OAuth2ConsentRequest
flushAccessRequests []*fosite.Request
flushRefreshRequests []*fosite.AccessRequest
flushGrants []*createGrantRequest
@@ -69,7 +70,7 @@ func NewConsentJanitorTestHelper(uniqueName string) *JanitorConsentTestHelper {
}
}
-func (j *JanitorConsentTestHelper) GetDSN(ctx context.Context) string {
+func (j *JanitorConsentTestHelper) GetDSN() string {
return j.conf.DSN()
}
@@ -149,7 +150,7 @@ func (j *JanitorConsentTestHelper) RefreshTokenNotAfterValidate(ctx context.Cont
}
}
-func (j *JanitorConsentTestHelper) GrantNotAfterSetup(ctx context.Context, cl client.Manager, gr trust.GrantManager) func(t *testing.T) {
+func (j *JanitorConsentTestHelper) GrantNotAfterSetup(ctx context.Context, gr trust.GrantManager) func(t *testing.T) {
return func(t *testing.T) {
for _, fg := range j.flushGrants {
require.NoError(t, gr.CreateGrant(ctx, fg.grant, fg.pk))
@@ -180,21 +181,29 @@ func (j *JanitorConsentTestHelper) GrantNotAfterValidate(ctx context.Context, no
}
}
-func (j *JanitorConsentTestHelper) LoginRejectionSetup(ctx context.Context, cm consent.Manager, cl client.Manager) func(t *testing.T) {
- return func(t *testing.T) {
- var err error
+func (j *JanitorConsentTestHelper) LoginRejectionSetup(ctx context.Context, reg interface {
+ consent.ManagerProvider
+ client.ManagerProvider
+ flow.CipherProvider
+}) func(t *testing.T) {
+ cm := reg.ConsentManager()
+ cl := reg.ClientManager()
+ return func(t *testing.T) {
// Create login requests
for _, r := range j.flushLoginRequests {
require.NoError(t, cl.CreateClient(ctx, r.Client))
- require.NoError(t, cm.CreateLoginRequest(ctx, r))
- }
+ f, err := cm.CreateLoginRequest(ctx, r)
+ require.NoError(t, err)
- // Explicit rejection
- for _, r := range j.flushLoginRequests {
+ f.RequestedAt = time.Now() // we won't handle expired flows
+ f.LoginAuthenticatedAt = r.AuthenticatedAt
+ challenge := x.Must(f.ToLoginChallenge(ctx, reg))
+
+ // Explicit rejection
if r.ID == j.flushLoginRequests[0].ID {
// accept this one
- _, err = cm.HandleLoginRequest(ctx, r.ID, consent.NewHandledLoginRequest(
+ _, err = cm.HandleLoginRequest(ctx, f, challenge, consent.NewHandledLoginRequest(
r.ID, false, r.RequestedAt, r.AuthenticatedAt))
require.NoError(t, err)
@@ -202,7 +211,7 @@ func (j *JanitorConsentTestHelper) LoginRejectionSetup(ctx context.Context, cm c
}
// reject flush-login-2 and 3
- _, err = cm.HandleLoginRequest(ctx, r.ID, consent.NewHandledLoginRequest(
+ _, err = cm.HandleLoginRequest(ctx, f, challenge, consent.NewHandledLoginRequest(
r.ID, true, r.RequestedAt, r.AuthenticatedAt))
require.NoError(t, err)
}
@@ -215,28 +224,38 @@ func (j *JanitorConsentTestHelper) LoginRejectionValidate(ctx context.Context, c
for _, r := range j.flushLoginRequests {
t.Logf("check login: %s", r.ID)
_, err := cm.GetLoginRequest(ctx, r.ID)
- if r.ID == j.flushLoginRequests[0].ID {
- require.NoError(t, err)
- } else {
- require.Error(t, err)
- }
+ // Login requests should never be persisted.
+ require.Error(t, err)
}
}
}
-func (j *JanitorConsentTestHelper) LimitSetup(ctx context.Context, cm consent.Manager, cl client.Manager) func(t *testing.T) {
+func (j *JanitorConsentTestHelper) LimitSetup(ctx context.Context, reg interface {
+ consent.ManagerProvider
+ client.ManagerProvider
+ flow.CipherProvider
+}) func(t *testing.T) {
+ cl := reg.ClientManager()
+ cm := reg.ConsentManager()
+
return func(t *testing.T) {
- var err error
+ var (
+ err error
+ f *flow.Flow
+ )
// Create login requests
for _, r := range j.flushLoginRequests {
require.NoError(t, cl.CreateClient(ctx, r.Client))
- require.NoError(t, cm.CreateLoginRequest(ctx, r))
- }
+ f, err = cm.CreateLoginRequest(ctx, r)
+ require.NoError(t, err)
- // Reject each request
- for _, r := range j.flushLoginRequests {
- _, err = cm.HandleLoginRequest(ctx, r.ID, consent.NewHandledLoginRequest(
+ // Reject each request
+ f.RequestedAt = time.Now() // we won't handle expired flows
+ f.LoginAuthenticatedAt = r.AuthenticatedAt
+ challenge := x.Must(f.ToLoginChallenge(ctx, reg))
+
+ _, err = cm.HandleLoginRequest(ctx, f, challenge, consent.NewHandledLoginRequest(
r.ID, true, r.RequestedAt, r.AuthenticatedAt))
require.NoError(t, err)
}
@@ -249,41 +268,50 @@ func (j *JanitorConsentTestHelper) LimitValidate(ctx context.Context, cm consent
for _, r := range j.flushLoginRequests {
t.Logf("check login: %s", r.ID)
_, err := cm.GetLoginRequest(ctx, r.ID)
- if r.ID == j.flushLoginRequests[0].ID {
- require.NoError(t, err)
- } else {
- require.Error(t, err)
- }
+ // No Requests should have been persisted.
+ require.Error(t, err)
}
}
}
-func (j *JanitorConsentTestHelper) ConsentRejectionSetup(ctx context.Context, cm consent.Manager, cl client.Manager) func(t *testing.T) {
+func (j *JanitorConsentTestHelper) ConsentRejectionSetup(ctx context.Context, reg interface {
+ consent.ManagerProvider
+ client.ManagerProvider
+ flow.CipherProvider
+}) func(t *testing.T) {
+ cl := reg.ClientManager()
+ cm := reg.ConsentManager()
+
return func(t *testing.T) {
- var err error
+ var (
+ err error
+ f *flow.Flow
+ )
// Create login requests
- for _, r := range j.flushLoginRequests {
- require.NoError(t, cl.CreateClient(ctx, r.Client))
- require.NoError(t, cm.CreateLoginRequest(ctx, r))
- }
+ for i, loginRequest := range j.flushLoginRequests {
+ require.NoError(t, cl.CreateClient(ctx, loginRequest.Client))
+ f, err = cm.CreateLoginRequest(ctx, loginRequest)
+ require.NoError(t, err)
- // Create consent requests
- for _, r := range j.flushConsentRequests {
- require.NoError(t, cm.CreateConsentRequest(ctx, r))
- }
+ // Create consent requests
+ consentRequest := j.flushConsentRequests[i]
+ err = cm.CreateConsentRequest(ctx, f, consentRequest)
+ require.NoError(t, err)
- //Reject the consents
- for _, r := range j.flushConsentRequests {
- if r.ID == j.flushConsentRequests[0].ID {
+ f.RequestedAt = time.Now() // we won't handle expired flows
+ f.LoginAuthenticatedAt = consentRequest.AuthenticatedAt
+
+ // Reject the consents
+ if consentRequest.ID == j.flushConsentRequests[0].ID {
// accept this one
- _, err = cm.HandleConsentRequest(ctx, consent.NewHandledConsentRequest(
- r.ID, false, r.RequestedAt, r.AuthenticatedAt))
+ _, err = cm.HandleConsentRequest(ctx, f, consent.NewHandledConsentRequest(
+ consentRequest.ID, false, consentRequest.RequestedAt, consentRequest.AuthenticatedAt))
require.NoError(t, err)
continue
}
- _, err = cm.HandleConsentRequest(ctx, consent.NewHandledConsentRequest(
- r.ID, true, r.RequestedAt, r.AuthenticatedAt))
+ _, err = cm.HandleConsentRequest(ctx, f, consent.NewHandledConsentRequest(
+ consentRequest.ID, true, consentRequest.RequestedAt, consentRequest.AuthenticatedAt))
require.NoError(t, err)
}
}
@@ -295,32 +323,43 @@ func (j *JanitorConsentTestHelper) ConsentRejectionValidate(ctx context.Context,
for _, r := range j.flushConsentRequests {
t.Logf("check consent: %s", r.ID)
_, err = cm.GetConsentRequest(ctx, r.ID)
- if r.ID == j.flushConsentRequests[0].ID {
- require.NoError(t, err)
- } else {
- require.Error(t, err)
- }
+ // Consent requests should never be persisted.
+ require.Error(t, err)
}
}
}
-func (j *JanitorConsentTestHelper) LoginTimeoutSetup(ctx context.Context, cm consent.Manager, cl client.Manager) func(t *testing.T) {
+func (j *JanitorConsentTestHelper) LoginTimeoutSetup(ctx context.Context, reg interface {
+ consent.ManagerProvider
+ client.ManagerProvider
+ flow.CipherProvider
+}) func(t *testing.T) {
+ cl := reg.ClientManager()
+ cm := reg.ConsentManager()
+
return func(t *testing.T) {
- var err error
+ var (
+ err error
+ f *flow.Flow
+ )
// Create login requests
- for _, r := range j.flushLoginRequests {
- require.NoError(t, cl.CreateClient(ctx, r.Client))
- require.NoError(t, cm.CreateLoginRequest(ctx, r))
- }
+ for i, loginRequest := range j.flushLoginRequests {
+ require.NoError(t, cl.CreateClient(ctx, loginRequest.Client))
+ f, err = cm.CreateLoginRequest(ctx, loginRequest)
+ require.NoError(t, err)
- // Creating at least 1 that has not timed out
- _, err = cm.HandleLoginRequest(ctx, j.flushLoginRequests[0].ID, &consent.HandledLoginRequest{
- ID: j.flushLoginRequests[0].ID,
- RequestedAt: j.flushLoginRequests[0].RequestedAt,
- AuthenticatedAt: j.flushLoginRequests[0].AuthenticatedAt,
- WasHandled: true,
- })
+ if i == 0 {
+ // Creating at least 1 that has not timed out
+ challenge := x.Must(f.ToLoginChallenge(ctx, reg))
+ _, err = cm.HandleLoginRequest(ctx, f, challenge, &flow.HandledLoginRequest{
+ ID: loginRequest.ID,
+ RequestedAt: loginRequest.RequestedAt,
+ AuthenticatedAt: loginRequest.AuthenticatedAt,
+ WasHandled: true,
+ })
+ }
+ }
require.NoError(t, err)
}
@@ -328,51 +367,56 @@ func (j *JanitorConsentTestHelper) LoginTimeoutSetup(ctx context.Context, cm con
func (j *JanitorConsentTestHelper) LoginTimeoutValidate(ctx context.Context, cm consent.Manager) func(t *testing.T) {
return func(t *testing.T) {
- var err error
-
for _, r := range j.flushLoginRequests {
- _, err = cm.GetLoginRequest(ctx, r.ID)
- if r.ID == j.flushLoginRequests[0].ID {
- require.NoError(t, err)
- } else {
- require.Error(t, err)
- }
+ _, err := cm.GetLoginRequest(ctx, r.ID)
+ // Login requests should never be persisted.
+ require.Error(t, err)
}
-
}
}
-func (j *JanitorConsentTestHelper) ConsentTimeoutSetup(ctx context.Context, cm consent.Manager, cl client.Manager) func(t *testing.T) {
- return func(t *testing.T) {
- var err error
+func (j *JanitorConsentTestHelper) ConsentTimeoutSetup(ctx context.Context, reg interface {
+ consent.ManagerProvider
+ client.ManagerProvider
+ flow.CipherProvider
+}) func(t *testing.T) {
+ cl := reg.ClientManager()
+ cm := reg.ConsentManager()
+ return func(t *testing.T) {
// Let's reset and accept all login requests to test the consent requests
- for _, r := range j.flushLoginRequests {
- require.NoError(t, cl.CreateClient(ctx, r.Client))
- require.NoError(t, cm.CreateLoginRequest(ctx, r))
- _, err = cm.HandleLoginRequest(ctx, r.ID, &consent.HandledLoginRequest{
- ID: r.ID,
- AuthenticatedAt: r.AuthenticatedAt,
- RequestedAt: r.RequestedAt,
+ for i, loginRequest := range j.flushLoginRequests {
+ require.NoError(t, cl.CreateClient(ctx, loginRequest.Client))
+ f, err := cm.CreateLoginRequest(ctx, loginRequest)
+ require.NoError(t, err)
+ f.RequestedAt = time.Now() // we won't handle expired flows
+ challenge := x.Must(f.ToLoginChallenge(ctx, reg))
+ _, err = cm.HandleLoginRequest(ctx, f, challenge, &flow.HandledLoginRequest{
+ ID: loginRequest.ID,
+ AuthenticatedAt: loginRequest.AuthenticatedAt,
+ RequestedAt: loginRequest.RequestedAt,
WasHandled: true,
})
require.NoError(t, err)
- }
- // Create consent requests
- for _, r := range j.flushConsentRequests {
- require.NoError(t, cm.CreateConsentRequest(ctx, r))
+ // Create consent requests
+ consentRequest := j.flushConsentRequests[i]
+ err = cm.CreateConsentRequest(ctx, f, consentRequest)
+ require.NoError(t, err)
+
+ if i == 0 {
+ // Create at least 1 consent request that has been accepted
+ _, err = cm.HandleConsentRequest(ctx, f, &flow.AcceptOAuth2ConsentRequest{
+ ID: consentRequest.ID,
+ WasHandled: true,
+ HandledAt: sqlxx.NullTime(time.Now()),
+ RequestedAt: consentRequest.RequestedAt,
+ AuthenticatedAt: consentRequest.AuthenticatedAt,
+ })
+ require.NoError(t, err)
+ }
}
- // Create at least 1 consent request that has been accepted
- _, err = cm.HandleConsentRequest(ctx, &consent.AcceptOAuth2ConsentRequest{
- ID: j.flushConsentRequests[0].ID,
- WasHandled: true,
- HandledAt: sqlxx.NullTime(time.Now()),
- RequestedAt: j.flushConsentRequests[0].RequestedAt,
- AuthenticatedAt: j.flushConsentRequests[0].AuthenticatedAt,
- })
- require.NoError(t, err)
}
}
@@ -382,40 +426,58 @@ func (j *JanitorConsentTestHelper) ConsentTimeoutValidate(ctx context.Context, c
for _, r := range j.flushConsentRequests {
_, err = cm.GetConsentRequest(ctx, r.ID)
- if r.ID == j.flushConsentRequests[0].ID {
- require.NoError(t, err)
- } else {
- require.Error(t, err)
- }
+ require.Error(t, err, "Unverified consent requests are never pesisted")
}
-
}
}
func (j *JanitorConsentTestHelper) LoginConsentNotAfterSetup(ctx context.Context, cm consent.Manager, cl client.Manager) func(t *testing.T) {
return func(t *testing.T) {
+ var (
+ f *flow.Flow
+ err error
+ )
for _, r := range j.flushLoginRequests {
require.NoError(t, cl.CreateClient(ctx, r.Client))
- require.NoError(t, cm.CreateLoginRequest(ctx, r))
+ f, err = cm.CreateLoginRequest(ctx, r)
+ require.NoError(t, err)
}
for _, r := range j.flushConsentRequests {
- require.NoError(t, cm.CreateConsentRequest(ctx, r))
+ f.ID = r.LoginChallenge.String()
+ err = cm.CreateConsentRequest(ctx, f, r)
+ require.NoError(t, err)
}
}
}
-func (j *JanitorConsentTestHelper) LoginConsentNotAfterValidate(ctx context.Context, notAfter time.Time, consentRequestLifespan time.Time, cm consent.Manager) func(t *testing.T) {
+func (j *JanitorConsentTestHelper) LoginConsentNotAfterValidate(
+ ctx context.Context,
+ notAfter time.Time,
+ consentRequestLifespan time.Time,
+ reg interface {
+ consent.ManagerProvider
+ flow.CipherProvider
+ },
+) func(t *testing.T) {
return func(t *testing.T) {
- var err error
+ var (
+ err error
+ f *flow.Flow
+ )
for _, r := range j.flushLoginRequests {
- t.Logf("login flush check:\nNotAfter: %s\nConsentRequest: %s\n%+v\n",
- notAfter.String(), consentRequestLifespan.String(), r)
- _, err = cm.GetLoginRequest(ctx, r.ID)
+ isExpired := r.RequestedAt.Before(consentRequestLifespan)
+ t.Logf("login flush check:\nNotAfter: %s\nLoginRequest: %s\nis expired: %v\n%+v\n",
+ notAfter.String(), consentRequestLifespan.String(), isExpired, r)
+
+ f = x.Must(reg.ConsentManager().CreateLoginRequest(ctx, r))
+ loginChallenge := x.Must(f.ToLoginChallenge(ctx, reg))
+
+ _, err = reg.ConsentManager().GetLoginRequest(ctx, loginChallenge)
// if the lowest between notAfter and consent-request-lifespan is greater than requested_at
// then the it should expect the value to be deleted.
- if j.notAfterCheck(notAfter, consentRequestLifespan, r.RequestedAt) {
+ if isExpired {
// value has been deleted here
require.Error(t, err)
} else {
@@ -425,12 +487,19 @@ func (j *JanitorConsentTestHelper) LoginConsentNotAfterValidate(ctx context.Cont
}
for _, r := range j.flushConsentRequests {
- t.Logf("consent flush check:\nNotAfter: %s\nConsentRequest: %s\n%+v\n",
- notAfter.String(), consentRequestLifespan.String(), r)
- _, err = cm.GetConsentRequest(ctx, r.ID)
+ isExpired := r.RequestedAt.Before(consentRequestLifespan)
+ t.Logf("consent flush check:\nNotAfter: %s\nConsentRequest: %s\nis expired: %v\n%+v\n",
+ notAfter.String(), consentRequestLifespan.String(), isExpired, r)
+
+ f.ID = r.LoginChallenge.String()
+ require.NoError(t, reg.ConsentManager().CreateConsentRequest(ctx, f, r))
+ f.RequestedAt = r.RequestedAt
+ consentChallenge := x.Must(f.ToConsentChallenge(ctx, reg))
+
+ _, err = reg.ConsentManager().GetConsentRequest(ctx, consentChallenge)
// if the lowest between notAfter and consent-request-lifespan is greater than requested_at
// then the it should expect the value to be deleted.
- if j.notAfterCheck(notAfter, consentRequestLifespan, r.RequestedAt) {
+ if isExpired {
// value has been deleted here
require.Error(t, err)
} else {
@@ -470,8 +539,22 @@ func (j *JanitorConsentTestHelper) notAfterCheck(notAfter time.Time, lifespan ti
return lesser.Unix() > requestedAt.Unix()
}
-func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager, clientManager client.Manager, fositeManager x.FositeStorer, network string, parallel bool) func(t *testing.T) {
+func JanitorTests(
+ reg interface {
+ ConsentManager() consent.Manager
+ OAuth2Storage() x.FositeStorer
+ config.Provider
+ client.ManagerProvider
+ flow.CipherProvider
+ },
+ network string,
+ parallel bool,
+) func(t *testing.T) {
return func(t *testing.T) {
+ consentManager := reg.ConsentManager()
+ clientManager := reg.ClientManager()
+ fositeManager := reg.OAuth2Storage()
+
if parallel {
t.Parallel()
}
@@ -479,7 +562,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager,
jt := NewConsentJanitorTestHelper(network + t.Name())
- conf.MustSet(context.Background(), config.KeyConsentRequestMaxAge, jt.GetConsentRequestLifespan(ctx))
+ reg.Config().MustSet(context.Background(), config.KeyConsentRequestMaxAge, jt.GetConsentRequestLifespan(ctx))
t.Run("case=flush-consent-request-not-after", func(t *testing.T) {
@@ -500,7 +583,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager,
})
// validate test
- t.Run("step=validate", jt.LoginConsentNotAfterValidate(ctx, notAfter, consentRequestLifespan, consentManager))
+ t.Run("step=validate", jt.LoginConsentNotAfterValidate(ctx, notAfter, consentRequestLifespan, reg))
})
}
@@ -511,7 +594,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager,
t.Run("case=limit", func(t *testing.T) {
// setup
- t.Run("step=setup", jt.LimitSetup(ctx, consentManager, clientManager))
+ t.Run("step=setup", jt.LimitSetup(ctx, reg))
// cleanup
t.Run("step=cleanup", func(t *testing.T) {
@@ -528,7 +611,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager,
t.Run(fmt.Sprintf("case=%s", "loginRejection"), func(t *testing.T) {
// setup
- t.Run("step=setup", jt.LoginRejectionSetup(ctx, consentManager, clientManager))
+ t.Run("step=setup", jt.LoginRejectionSetup(ctx, reg))
// cleanup
t.Run("step=cleanup", func(t *testing.T) {
@@ -543,7 +626,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager,
t.Run(fmt.Sprintf("case=%s", "consentRejection"), func(t *testing.T) {
// setup
- t.Run("step=setup", jt.ConsentRejectionSetup(ctx, consentManager, clientManager))
+ t.Run("step=setup", jt.ConsentRejectionSetup(ctx, reg))
// cleanup
t.Run("step=cleanup", func(t *testing.T) {
@@ -562,7 +645,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager,
t.Run(fmt.Sprintf("case=%s", "login-timeout"), func(t *testing.T) {
// setup
- t.Run("step=setup", jt.LoginTimeoutSetup(ctx, consentManager, clientManager))
+ t.Run("step=setup", jt.LoginTimeoutSetup(ctx, reg))
// cleanup
t.Run("step=cleanup", func(t *testing.T) {
@@ -579,7 +662,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager,
t.Run(fmt.Sprintf("case=%s", "consent-timeout"), func(t *testing.T) {
// setup
- t.Run("step=setup", jt.ConsentTimeoutSetup(ctx, consentManager, clientManager))
+ t.Run("step=setup", jt.ConsentTimeoutSetup(ctx, reg))
// cleanup
t.Run("step=cleanup", func(t *testing.T) {
@@ -627,7 +710,7 @@ func getAccessRequests(uniqueName string, lifespan time.Duration) []*fosite.Requ
}
func getRefreshRequests(uniqueName string, lifespan time.Duration) []*fosite.AccessRequest {
- var tokenSignature = "4c7c7e8b3a77ad0c3ec846a21653c48b45dbfa31"
+ var tokenSignature = "4c7c7e8b3a77ad0c3ec846a21653c48b45dbfa31" //nolint:gosec
return []*fosite.AccessRequest{
{
GrantTypes: []string{
@@ -680,8 +763,8 @@ func getRefreshRequests(uniqueName string, lifespan time.Duration) []*fosite.Acc
}
}
-func genLoginRequests(uniqueName string, lifespan time.Duration) []*consent.LoginRequest {
- return []*consent.LoginRequest{
+func genLoginRequests(uniqueName string, lifespan time.Duration) []*flow.LoginRequest {
+ return []*flow.LoginRequest{
{
ID: fmt.Sprintf("%s_flush-login-1", uniqueName),
RequestedScope: []string{"foo", "bar"},
@@ -704,8 +787,8 @@ func genLoginRequests(uniqueName string, lifespan time.Duration) []*consent.Logi
RedirectURIs: []string{"http://redirect"},
},
RequestURL: "http://redirect",
- RequestedAt: time.Now().Round(time.Second).Add(-(lifespan + time.Minute)),
- AuthenticatedAt: sqlxx.NullTime(time.Now().Round(time.Second).Add(-(lifespan + time.Minute))),
+ RequestedAt: time.Now().Round(time.Second).Add(-(lifespan + 10*time.Minute)),
+ AuthenticatedAt: sqlxx.NullTime(time.Now().Round(time.Second).Add(-(lifespan + 10*time.Minute))),
Verifier: fmt.Sprintf("%s_flush-login-2", uniqueName),
},
{
@@ -724,8 +807,8 @@ func genLoginRequests(uniqueName string, lifespan time.Duration) []*consent.Logi
}
}
-func genConsentRequests(uniqueName string, lifespan time.Duration) []*consent.OAuth2ConsentRequest {
- return []*consent.OAuth2ConsentRequest{
+func genConsentRequests(uniqueName string, lifespan time.Duration) []*flow.OAuth2ConsentRequest {
+ return []*flow.OAuth2ConsentRequest{
{
ID: fmt.Sprintf("%s_flush-consent-1", uniqueName),
RequestedScope: []string{"foo", "bar"},
diff --git a/internal/testhelpers/oauth2.go b/internal/testhelpers/oauth2.go
index d637ee921ad..41f0ddaec8e 100644
--- a/internal/testhelpers/oauth2.go
+++ b/internal/testhelpers/oauth2.go
@@ -7,14 +7,17 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"net/http"
"net/http/cookiejar"
+ "net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
+ "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"github.com/ory/fosite/token/jwt"
@@ -26,8 +29,6 @@ import (
"github.com/ory/x/httpx"
"github.com/ory/x/ioutilx"
- "net/http/httptest"
-
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/driver"
"github.com/ory/hydra/v2/driver/config"
@@ -55,7 +56,7 @@ func NewIDTokenWithClaims(t *testing.T, reg driver.Registry, claims jwt.MapClaim
return token
}
-func NewOAuth2Server(ctx context.Context, t *testing.T, reg driver.Registry) (publicTS, adminTS *httptest.Server) {
+func NewOAuth2Server(ctx context.Context, t testing.TB, reg driver.Registry) (publicTS, adminTS *httptest.Server) {
// Lifespan is two seconds to avoid time synchronization issues with SQL.
reg.Config().MustSet(ctx, config.KeySubjectIdentifierAlgorithmSalt, "76d5d2bf-747f-4592-9fbd-d2b895a54b3a")
reg.Config().MustSet(ctx, config.KeyAccessTokenLifespan, time.Second*2)
@@ -66,19 +67,22 @@ func NewOAuth2Server(ctx context.Context, t *testing.T, reg driver.Registry) (pu
public, admin := x.NewRouterPublic(), x.NewRouterAdmin(reg.Config().AdminURL)
- publicTS = httptest.NewServer(public)
+ internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName)
+ internal.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName)
+
+ reg.RegisterRoutes(ctx, admin, public)
+
+ publicTS = httptest.NewServer(otelhttp.NewHandler(public, "public", otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
+ return r.URL.Path
+ })))
t.Cleanup(publicTS.Close)
- adminTS = httptest.NewServer(admin)
+ adminTS = httptest.NewServer(otelhttp.NewHandler(admin, "admin", otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
+ return r.URL.Path
+ })))
t.Cleanup(adminTS.Close)
reg.Config().MustSet(ctx, config.KeyIssuerURL, publicTS.URL)
- // SendDebugMessagesToClients: true,
-
- internal.MustEnsureRegistryKeys(reg, x.OpenIDConnectKeyName)
- internal.MustEnsureRegistryKeys(reg, x.OAuth2JWTKeyName)
-
- reg.RegisterRoutes(ctx, admin, public)
return publicTS, adminTS
}
@@ -93,7 +97,7 @@ func DecodeIDToken(t *testing.T, token *oauth2.Token) gjson.Result {
return gjson.ParseBytes(body)
}
-func IntrospectToken(t *testing.T, conf *oauth2.Config, token string, adminTS *httptest.Server) gjson.Result {
+func IntrospectToken(t testing.TB, conf *oauth2.Config, token string, adminTS *httptest.Server) gjson.Result {
require.NotEmpty(t, token)
req := httpx.MustNewRequest("POST", adminTS.URL+"/admin/oauth2/introspect",
@@ -140,13 +144,13 @@ func HTTPServerNotImplementedHandler(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
-func HTTPServerNoExpectedCallHandler(t *testing.T) http.HandlerFunc {
+func HTTPServerNoExpectedCallHandler(t testing.TB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
t.Fatal("This should not have been called")
}
}
-func NewLoginConsentUI(t *testing.T, c *config.DefaultProvider, login, consent http.HandlerFunc) {
+func NewLoginConsentUI(t testing.TB, c *config.DefaultProvider, login, consent http.HandlerFunc) {
if login == nil {
login = HTTPServerNotImplementedHandler
}
@@ -165,7 +169,7 @@ func NewLoginConsentUI(t *testing.T, c *config.DefaultProvider, login, consent h
c.MustSet(context.Background(), config.KeyConsentURL, ct.URL)
}
-func NewCallbackURL(t *testing.T, prefix string, h http.HandlerFunc) string {
+func NewCallbackURL(t testing.TB, prefix string, h http.HandlerFunc) string {
if h == nil {
h = HTTPServerNotImplementedHandler
}
@@ -180,14 +184,35 @@ func NewCallbackURL(t *testing.T, prefix string, h http.HandlerFunc) string {
return ts.URL + "/" + prefix
}
-func NewEmptyCookieJar(t *testing.T) *cookiejar.Jar {
+func NewEmptyCookieJar(t testing.TB) *cookiejar.Jar {
c, err := cookiejar.New(&cookiejar.Options{})
require.NoError(t, err)
return c
}
-func NewEmptyJarClient(t *testing.T) *http.Client {
+func NewEmptyJarClient(t testing.TB) *http.Client {
return &http.Client{
- Jar: NewEmptyCookieJar(t),
+ Jar: NewEmptyCookieJar(t),
+ Transport: &loggingTransport{t},
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ //t.Logf("Redirect to %s", req.URL.String())
+
+ if len(via) >= 20 {
+ for k, v := range via {
+ t.Logf("Failed with redirect (%d): %s", k, v.URL.String())
+ }
+ return errors.New("stopped after 20 redirects")
+ }
+ return nil
+ },
}
}
+
+type loggingTransport struct{ t testing.TB }
+
+func (s *loggingTransport) RoundTrip(r *http.Request) (*http.Response, error) {
+ //s.t.Logf("%s %s", r.Method, r.URL.String())
+ //s.t.Logf("%s %s\nWith Cookies: %v", r.Method, r.URL.String(), r.Cookies())
+
+ return otelhttp.DefaultClient.Transport.RoundTrip(r)
+}
diff --git a/jwk/aead.go b/jwk/aead.go
deleted file mode 100644
index 081f5b87038..00000000000
--- a/jwk/aead.go
+++ /dev/null
@@ -1,101 +0,0 @@
-// Copyright © 2022 Ory Corp
-// SPDX-License-Identifier: Apache-2.0
-
-package jwk
-
-import (
- "context"
- "encoding/base64"
-
- "github.com/ory/x/errorsx"
-
- "github.com/ory/hydra/v2/driver/config"
-
- "github.com/gtank/cryptopasta"
- "github.com/pkg/errors"
-)
-
-type AEAD struct {
- c *config.DefaultProvider
-}
-
-func NewAEAD(c *config.DefaultProvider) *AEAD {
- return &AEAD{c: c}
-}
-
-func aeadKey(key []byte) *[32]byte {
- var result [32]byte
- copy(result[:], key[:32])
- return &result
-}
-
-func (c *AEAD) Encrypt(ctx context.Context, plaintext []byte) (string, error) {
- global, err := c.c.GetGlobalSecret(ctx)
- if err != nil {
- return "", err
- }
-
- rotated, err := c.c.GetRotatedGlobalSecrets(ctx)
- if err != nil {
- return "", err
- }
-
- keys := append([][]byte{global}, rotated...)
- if len(keys) == 0 {
- return "", errors.Errorf("at least one encryption key must be defined but none were")
- }
-
- if len(keys[0]) < 32 {
- return "", errors.Errorf("key must be exactly 32 long bytes, got %d bytes", len(keys[0]))
- }
-
- ciphertext, err := cryptopasta.Encrypt(plaintext, aeadKey(keys[0]))
- if err != nil {
- return "", errorsx.WithStack(err)
- }
-
- return base64.URLEncoding.EncodeToString(ciphertext), nil
-}
-
-func (c *AEAD) Decrypt(ctx context.Context, ciphertext string) (p []byte, err error) {
- global, err := c.c.GetGlobalSecret(ctx)
- if err != nil {
- return nil, err
- }
-
- rotated, err := c.c.GetRotatedGlobalSecrets(ctx)
- if err != nil {
- return nil, err
- }
-
- keys := append([][]byte{global}, rotated...)
- if len(keys) == 0 {
- return nil, errors.Errorf("at least one decryption key must be defined but none were")
- }
-
- for _, key := range keys {
- if p, err = c.decrypt(ciphertext, key); err == nil {
- return p, nil
- }
- }
-
- return nil, err
-}
-
-func (c *AEAD) decrypt(ciphertext string, key []byte) ([]byte, error) {
- if len(key) != 32 {
- return nil, errors.Errorf("key must be exactly 32 long bytes, got %d bytes", len(key))
- }
-
- raw, err := base64.URLEncoding.DecodeString(ciphertext)
- if err != nil {
- return nil, errorsx.WithStack(err)
- }
-
- plaintext, err := cryptopasta.Decrypt(raw, aeadKey(key))
- if err != nil {
- return nil, errorsx.WithStack(err)
- }
-
- return plaintext, nil
-}
diff --git a/jwk/aead_test.go b/jwk/aead_test.go
deleted file mode 100644
index 890918dde72..00000000000
--- a/jwk/aead_test.go
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright © 2022 Ory Corp
-// SPDX-License-Identifier: Apache-2.0
-
-package jwk_test
-
-import (
- "context"
- "crypto/rand"
- "fmt"
- "io"
- "testing"
-
- "github.com/ory/hydra/v2/driver/config"
- "github.com/ory/hydra/v2/internal"
- . "github.com/ory/hydra/v2/jwk"
-
- "github.com/pborman/uuid"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func secret(t *testing.T) string {
- bytes := make([]byte, 32)
- _, err := io.ReadFull(rand.Reader, bytes)
- require.NoError(t, err)
- return fmt.Sprintf("%X", bytes)
-}
-
-func TestAEAD(t *testing.T) {
- ctx := context.Background()
- c := internal.NewConfigurationWithDefaults()
- t.Run("case=without-rotation", func(t *testing.T) {
- c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)})
- a := NewAEAD(c)
-
- plain := []byte(uuid.New())
- ct, err := a.Encrypt(ctx, plain)
- assert.NoError(t, err)
-
- res, err := a.Decrypt(ctx, ct)
- assert.NoError(t, err)
- assert.Equal(t, plain, res)
- })
-
- t.Run("case=wrong-secret", func(t *testing.T) {
- c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)})
- a := NewAEAD(c)
-
- ct, err := a.Encrypt(ctx, []byte(uuid.New()))
- require.NoError(t, err)
-
- c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)})
- _, err = a.Decrypt(ctx, ct)
- require.Error(t, err)
- })
-
- t.Run("case=with-rotation", func(t *testing.T) {
- old := secret(t)
- c.MustSet(ctx, config.KeyGetSystemSecret, []string{old})
- a := NewAEAD(c)
-
- plain := []byte(uuid.New())
- ct, err := a.Encrypt(ctx, plain)
- require.NoError(t, err)
-
- // Sets the old secret as a rotated secret and creates a new one.
- c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t), old})
- res, err := a.Decrypt(ctx, ct)
- require.NoError(t, err)
- assert.Equal(t, plain, res)
-
- // THis should also work when we re-encrypt the same plain text.
- ct2, err := a.Encrypt(ctx, plain)
- require.NoError(t, err)
- assert.NotEqual(t, ct2, ct)
-
- res, err = a.Decrypt(ctx, ct)
- require.NoError(t, err)
- assert.Equal(t, plain, res)
- })
-
- t.Run("case=with-rotation-wrong-secret", func(t *testing.T) {
- c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)})
- a := NewAEAD(c)
-
- plain := []byte(uuid.New())
- ct, err := a.Encrypt(ctx, plain)
- require.NoError(t, err)
-
- // When the secrets do not match, an error should be thrown during decryption.
- c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t), secret(t)})
- _, err = a.Decrypt(ctx, ct)
- require.Error(t, err)
- })
-}
diff --git a/jwk/cast_test.go b/jwk/cast_test.go
index d55b81ba518..2e78283719d 100644
--- a/jwk/cast_test.go
+++ b/jwk/cast_test.go
@@ -14,6 +14,7 @@ import (
)
func TestMustRSAPrivate(t *testing.T) {
+ t.Parallel()
keys, err := GenerateJWK(context.Background(), jose.RS256, "foo", "sig")
require.NoError(t, err)
diff --git a/jwk/generate_test.go b/jwk/generate_test.go
index 01a47d4ec67..544a8de9d23 100644
--- a/jwk/generate_test.go
+++ b/jwk/generate_test.go
@@ -13,6 +13,7 @@ import (
)
func TestGenerateJWK(t *testing.T) {
+ t.Parallel()
jwks, err := GenerateJWK(context.Background(), jose.RS256, "", "")
require.NoError(t, err)
assert.NotEmpty(t, jwks.Keys[0].KeyID)
diff --git a/jwk/handler.go b/jwk/handler.go
index d5e87ea29fc..1bbfb9ecb81 100644
--- a/jwk/handler.go
+++ b/jwk/handler.go
@@ -36,6 +36,8 @@ type Handler struct {
// JSON Web Key Set
//
// swagger:model jsonWebKeySet
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type jsonWebKeySet struct {
// List of JSON Web Keys
//
@@ -114,6 +116,8 @@ func (h *Handler) discoverJsonWebKeys(w http.ResponseWriter, r *http.Request) {
// Get JSON Web Key Request
//
// swagger:parameters getJsonWebKey
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type getJsonWebKey struct {
// JSON Web Key Set ID
//
@@ -162,6 +166,8 @@ func (h *Handler) getJsonWebKey(w http.ResponseWriter, r *http.Request, ps httpr
// Get JSON Web Key Set Parameters
//
// swagger:parameters getJsonWebKeySet
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type getJsonWebKeySet struct {
// JSON Web Key Set ID
//
@@ -205,6 +211,8 @@ func (h *Handler) getJsonWebKeySet(w http.ResponseWriter, r *http.Request, ps ht
// Create JSON Web Key Set Request
//
// swagger:parameters createJsonWebKeySet
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type adminCreateJsonWebKeySet struct {
// The JSON Web Key Set ID
//
@@ -283,6 +291,8 @@ func (h *Handler) createJsonWebKeySet(w http.ResponseWriter, r *http.Request, ps
// Set JSON Web Key Set Request
//
// swagger:parameters setJsonWebKeySet
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type setJsonWebKeySet struct {
// The JSON Web Key Set ID
//
@@ -333,6 +343,8 @@ func (h *Handler) setJsonWebKeySet(w http.ResponseWriter, r *http.Request, ps ht
// Set JSON Web Key Request
//
// swagger:parameters setJsonWebKey
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type setJsonWebKey struct {
// The JSON Web Key Set ID
//
@@ -389,6 +401,8 @@ func (h *Handler) adminUpdateJsonWebKey(w http.ResponseWriter, r *http.Request,
// Delete JSON Web Key Set Parameters
//
// swagger:parameters deleteJsonWebKeySet
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type deleteJsonWebKeySet struct {
// The JSON Web Key Set
// in: path
@@ -429,6 +443,8 @@ func (h *Handler) adminDeleteJsonWebKeySet(w http.ResponseWriter, r *http.Reques
// Delete JSON Web Key Parameters
//
// swagger:parameters deleteJsonWebKey
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type deleteJsonWebKey struct {
// The JSON Web Key Set
// in: path
diff --git a/jwk/handler_test.go b/jwk/handler_test.go
index 09613ca42b9..c9040a37a4c 100644
--- a/jwk/handler_test.go
+++ b/jwk/handler_test.go
@@ -25,6 +25,8 @@ import (
)
func TestHandlerWellKnown(t *testing.T) {
+ t.Parallel()
+
conf := internal.NewConfigurationWithDefaults()
reg := internal.NewRegistryMemory(t, conf, &contextx.Default{})
conf.MustSet(context.Background(), config.KeyWellKnownKeys, []string{x.OpenIDConnectKeyName, x.OpenIDConnectKeyName})
@@ -37,6 +39,7 @@ func TestHandlerWellKnown(t *testing.T) {
JWKPath := "/.well-known/jwks.json"
t.Run("Test_Handler_WellKnown/Run_public_key_With_public_prefix", func(t *testing.T) {
+ t.Parallel()
if conf.HSMEnabled() {
t.Skip("Skipping test. Not applicable when Hardware Security Module is enabled. Public/private keys on HSM are generated with equal key id's and are not using prefixes")
}
@@ -62,6 +65,7 @@ func TestHandlerWellKnown(t *testing.T) {
})
t.Run("Test_Handler_WellKnown/Run_public_key_Without_public_prefix", func(t *testing.T) {
+ t.Parallel()
var IDKS *jose.JSONWebKeySet
if conf.HSMEnabled() {
diff --git a/jwk/helper_test.go b/jwk/helper_test.go
index b724349d515..d0ce928c3ec 100644
--- a/jwk/helper_test.go
+++ b/jwk/helper_test.go
@@ -6,7 +6,7 @@ package jwk_test
import (
"context"
"crypto"
- "crypto/dsa"
+ "crypto/dsa" //lint:ignore SA1019 used for testing invalid key types
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
@@ -46,7 +46,10 @@ func (f *fakeSigner) Public() crypto.PublicKey {
}
func TestHandlerFindPublicKey(t *testing.T) {
+ t.Parallel()
+
t.Run("Test_Helper/Run_FindPublicKey_With_RSA", func(t *testing.T) {
+ t.Parallel()
RSIDKS, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig")
require.NoError(t, err)
keys, err := jwk.FindPublicKey(RSIDKS)
@@ -56,6 +59,7 @@ func TestHandlerFindPublicKey(t *testing.T) {
})
t.Run("Test_Helper/Run_FindPublicKey_With_Opaque", func(t *testing.T) {
+ t.Parallel()
key, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig")
RSIDKS := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{{
Algorithm: "RS256",
@@ -82,6 +86,7 @@ func TestHandlerFindPublicKey(t *testing.T) {
})
t.Run("Test_Helper/Run_FindPublicKey_With_ECDSA", func(t *testing.T) {
+ t.Parallel()
ECDSAIDKS, err := jwk.GenerateJWK(context.Background(), jose.ES256, "test-id-2", "sig")
require.NoError(t, err)
keys, err := jwk.FindPublicKey(ECDSAIDKS)
@@ -91,6 +96,7 @@ func TestHandlerFindPublicKey(t *testing.T) {
})
t.Run("Test_Helper/Run_FindPublicKey_With_EdDSA", func(t *testing.T) {
+ t.Parallel()
EdDSAIDKS, err := jwk.GenerateJWK(context.Background(), jose.EdDSA, "test-id-3", "sig")
require.NoError(t, err)
keys, err := jwk.FindPublicKey(EdDSAIDKS)
@@ -100,6 +106,7 @@ func TestHandlerFindPublicKey(t *testing.T) {
})
t.Run("Test_Helper/Run_FindPublicKey_With_KeyNotFound", func(t *testing.T) {
+ t.Parallel()
keySet := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{}}
_, err := jwk.FindPublicKey(keySet)
require.Error(t, err)
@@ -108,6 +115,7 @@ func TestHandlerFindPublicKey(t *testing.T) {
}
func TestHandlerFindPrivateKey(t *testing.T) {
+ t.Parallel()
t.Run("Test_Helper/Run_FindPrivateKey_With_RSA", func(t *testing.T) {
RSIDKS, _ := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig")
keys, err := jwk.FindPrivateKey(RSIDKS)
@@ -143,6 +151,7 @@ func TestHandlerFindPrivateKey(t *testing.T) {
}
func TestPEMBlockForKey(t *testing.T) {
+ t.Parallel()
t.Run("Test_Helper/Run_PEMBlockForKey_With_RSA", func(t *testing.T) {
RSIDKS, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig")
require.NoError(t, err)
@@ -185,6 +194,7 @@ func TestPEMBlockForKey(t *testing.T) {
}
func TestExcludeOpaquePrivateKeys(t *testing.T) {
+ t.Parallel()
opaqueKeys, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig")
assert.NoError(t, err)
require.Len(t, opaqueKeys.Keys, 1)
@@ -199,6 +209,7 @@ func TestExcludeOpaquePrivateKeys(t *testing.T) {
}
func TestGetOrGenerateKeys(t *testing.T) {
+ t.Parallel()
reg := internal.NewMockedRegistry(t, &contextx.Default{})
setId := uuid.NewUUID().String()
diff --git a/jwk/manager_strategy_test.go b/jwk/manager_strategy_test.go
index e138f30e072..6fb8db03bbb 100644
--- a/jwk/manager_strategy_test.go
+++ b/jwk/manager_strategy_test.go
@@ -17,6 +17,7 @@ import (
)
func TestKeyManagerStrategy(t *testing.T) {
+ t.Parallel()
ctrl := gomock.NewController(t)
softwareKeyManager := NewMockManager(ctrl)
hardwareKeyManager := NewMockManager(ctrl)
diff --git a/jwk/registry.go b/jwk/registry.go
index 1d7b4355f8c..b5c3ea8d811 100644
--- a/jwk/registry.go
+++ b/jwk/registry.go
@@ -4,6 +4,7 @@
package jwk
import (
+ "github.com/ory/hydra/v2/aead"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/x"
)
@@ -18,5 +19,5 @@ type Registry interface {
config.Provider
KeyManager() Manager
SoftwareKeyManager() Manager
- KeyCipher() *AEAD
+ KeyCipher() *aead.AESGCM
}
diff --git a/jwk/registry_mock_test.go b/jwk/registry_mock_test.go
index d6295f11b31..c305fd18167 100644
--- a/jwk/registry_mock_test.go
+++ b/jwk/registry_mock_test.go
@@ -13,6 +13,7 @@ import (
gomock "github.com/golang/mock/gomock"
herodot "github.com/ory/herodot"
+ "github.com/ory/hydra/v2/aead"
config "github.com/ory/hydra/v2/driver/config"
jwk "github.com/ory/hydra/v2/jwk"
logrusx "github.com/ory/x/logrusx"
@@ -70,10 +71,10 @@ func (mr *MockInternalRegistryMockRecorder) Config() *gomock.Call {
}
// KeyCipher mocks base method.
-func (m *MockInternalRegistry) KeyCipher() *jwk.AEAD {
+func (m *MockInternalRegistry) KeyCipher() *aead.AESGCM {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeyCipher")
- ret0, _ := ret[0].(*jwk.AEAD)
+ ret0, _ := ret[0].(*aead.AESGCM)
return ret0
}
@@ -177,10 +178,10 @@ func (mr *MockRegistryMockRecorder) Config() *gomock.Call {
}
// KeyCipher mocks base method.
-func (m *MockRegistry) KeyCipher() *jwk.AEAD {
+func (m *MockRegistry) KeyCipher() *aead.AESGCM {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeyCipher")
- ret0, _ := ret[0].(*jwk.AEAD)
+ ret0, _ := ret[0].(*aead.AESGCM)
return ret0
}
diff --git a/jwk/sdk_test.go b/jwk/sdk_test.go
index 571e8ce5b05..46d1cc81448 100644
--- a/jwk/sdk_test.go
+++ b/jwk/sdk_test.go
@@ -24,6 +24,7 @@ import (
)
func TestJWKSDK(t *testing.T) {
+ t.Parallel()
ctx := context.Background()
conf := internal.NewConfigurationWithDefaults()
reg := internal.NewRegistryMemory(t, conf, &contextx.Default{})
@@ -41,6 +42,7 @@ func TestJWKSDK(t *testing.T) {
expectedKid := "key-bar"
t.Run("JSON Web Key", func(t *testing.T) {
+ t.Parallel()
t.Run("CreateJwkSetKey", func(t *testing.T) {
// Create a key called set-foo
resultKeys, _, err := sdk.JwkApi.CreateJsonWebKeySet(context.Background(), "set-foo").CreateJsonWebKeySet(hydra.CreateJsonWebKeySet{
@@ -93,6 +95,7 @@ func TestJWKSDK(t *testing.T) {
})
t.Run("JWK Set", func(t *testing.T) {
+ t.Parallel()
t.Run("CreateJwkSetKey", func(t *testing.T) {
resultKeys, _, err := sdk.JwkApi.CreateJsonWebKeySet(ctx, "set-foo2").CreateJsonWebKeySet(hydra.CreateJsonWebKeySet{
Alg: "RS256",
diff --git a/oauth2/flowctx/cookies.go b/oauth2/flowctx/cookies.go
new file mode 100644
index 00000000000..00ae91aeef0
--- /dev/null
+++ b/oauth2/flowctx/cookies.go
@@ -0,0 +1,38 @@
+// Copyright © 2023 Ory Corp
+// SPDX-License-Identifier: Apache-2.0
+
+package flowctx
+
+import "github.com/ory/hydra/v2/client"
+
+type (
+ CookieSuffixer interface {
+ CookieSuffix() string
+ }
+
+ StaticSuffix string
+ clientID string
+)
+
+func (s StaticSuffix) CookieSuffix() string { return string(s) }
+func (s clientID) GetID() string { return string(s) }
+
+const (
+ flowCookie = "ory_hydra_flow"
+ loginSessionCookie = "ory_hydra_loginsession"
+)
+
+func FlowCookie(suffix CookieSuffixer) string {
+ return flowCookie + "_" + suffix.CookieSuffix()
+}
+func LoginSessionCookie(suffix CookieSuffixer) string {
+ return loginSessionCookie + "_" + suffix.CookieSuffix()
+}
+
+func SuffixForClient(c client.IDer) StaticSuffix {
+ return StaticSuffix(client.CookieSuffix(c))
+}
+
+func SuffixFromStatic(id string) StaticSuffix {
+ return SuffixForClient(clientID(id))
+}
diff --git a/oauth2/flowctx/encoding.go b/oauth2/flowctx/encoding.go
new file mode 100644
index 00000000000..5b01b5ec1cd
--- /dev/null
+++ b/oauth2/flowctx/encoding.go
@@ -0,0 +1,150 @@
+// Copyright © 2023 Ory Corp
+// SPDX-License-Identifier: Apache-2.0
+
+package flowctx
+
+import (
+ "bytes"
+ "compress/gzip"
+ "context"
+ "encoding/json"
+ "net/http"
+
+ "github.com/pkg/errors"
+
+ "github.com/ory/fosite"
+ "github.com/ory/hydra/v2/aead"
+ "github.com/ory/hydra/v2/driver/config"
+)
+
+type (
+ data struct {
+ Purpose purpose `json:"p,omitempty"`
+ }
+ purpose int
+ CodecOption func(ad *data)
+)
+
+const (
+ loginChallenge purpose = iota
+ loginVerifier
+ consentChallenge
+ consentVerifier
+)
+
+func withPurpose(purpose purpose) CodecOption { return func(ad *data) { ad.Purpose = purpose } }
+
+var (
+ AsLoginChallenge = withPurpose(loginChallenge)
+ AsLoginVerifier = withPurpose(loginVerifier)
+ AsConsentChallenge = withPurpose(consentChallenge)
+ AsConsentVerifier = withPurpose(consentVerifier)
+)
+
+func additionalDataFromOpts(opts ...CodecOption) []byte {
+ if len(opts) == 0 {
+ return nil
+ }
+ ad := &data{}
+ for _, o := range opts {
+ o(ad)
+ }
+ b, err := json.Marshal(ad)
+ if err != nil {
+ // Panic is OK here because the struct and the parameters are all known.
+ panic("failed to marshal additional data: " + errors.WithStack(err).Error())
+ }
+
+ return b
+}
+
+// Decode decodes the given string to a value.
+func Decode[T any](ctx context.Context, cipher aead.Cipher, encoded string, opts ...CodecOption) (*T, error) {
+ plaintext, err := cipher.Decrypt(ctx, encoded, additionalDataFromOpts(opts...))
+ if err != nil {
+ return nil, err
+ }
+
+ rawBytes, err := gzip.NewReader(bytes.NewReader(plaintext))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rawBytes.Close() }()
+
+ var val T
+ if err = json.NewDecoder(rawBytes).Decode(&val); err != nil {
+ return nil, err
+ }
+
+ return &val, nil
+}
+
+// Encode encodes the given value to a string.
+func Encode(ctx context.Context, cipher aead.Cipher, val any, opts ...CodecOption) (s string, err error) {
+ // Steps:
+ // 1. Encode to JSON
+ // 2. GZIP
+ // 3. Encrypt with AEAD (AES-GCM) + Base64 URL-encode
+ var b bytes.Buffer
+
+ gz := gzip.NewWriter(&b)
+
+ if err = json.NewEncoder(gz).Encode(val); err != nil {
+ return "", err
+ }
+ if err = gz.Close(); err != nil {
+ return "", err
+ }
+
+ return cipher.Encrypt(ctx, b.Bytes(), additionalDataFromOpts(opts...))
+}
+
+// SetCookie encrypts the given value and sets it in a cookie.
+func SetCookie(ctx context.Context, w http.ResponseWriter, reg interface {
+ FlowCipher() *aead.XChaCha20Poly1305
+ config.Provider
+}, cookieName string, value any, opts ...CodecOption) error {
+ cipher := reg.FlowCipher()
+ cookie, err := Encode(ctx, cipher, value, opts...)
+ if err != nil {
+ return err
+ }
+
+ http.SetCookie(w, &http.Cookie{
+ Name: cookieName,
+ Value: cookie,
+ HttpOnly: true,
+ Domain: reg.Config().CookieDomain(ctx),
+ Secure: reg.Config().CookieSecure(ctx),
+ SameSite: reg.Config().CookieSameSiteMode(ctx),
+ })
+
+ return nil
+}
+
+// DeleteCookie deletes the flow cookie.
+func DeleteCookie(ctx context.Context, w http.ResponseWriter, reg interface {
+ config.Provider
+}, cookieName string) error {
+ http.SetCookie(w, &http.Cookie{
+ Name: cookieName,
+ Value: "",
+ MaxAge: -1,
+ HttpOnly: true,
+ Domain: reg.Config().CookieDomain(ctx),
+ Secure: reg.Config().CookieSecure(ctx),
+ SameSite: reg.Config().CookieSameSiteMode(ctx),
+ })
+
+ return nil
+}
+
+// FromCookie looks up the value stored in the cookie and decodes it.
+func FromCookie[T any](ctx context.Context, r *http.Request, cipher aead.Cipher, cookieName string, opts ...CodecOption) (*T, error) {
+ cookie, err := r.Cookie(cookieName)
+ if err != nil {
+ return nil, errors.WithStack(fosite.ErrInvalidClient.WithHint("No cookie found for this request. Please initiate a new flow and retry."))
+ }
+
+ return Decode[T](ctx, cipher, cookie.Value, opts...)
+}
diff --git a/oauth2/fosite_store_helpers.go b/oauth2/fosite_store_helpers.go
index a885b73bb16..0a2c670fe03 100644
--- a/oauth2/fosite_store_helpers.go
+++ b/oauth2/fosite_store_helpers.go
@@ -11,6 +11,7 @@ import (
"testing"
"time"
+ "github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/jwk"
"github.com/gobuffalo/pop/v6"
@@ -36,7 +37,6 @@ import (
"github.com/ory/x/sqlcon"
"github.com/ory/hydra/v2/client"
- "github.com/ory/hydra/v2/consent"
)
func signatureFromJTI(jti string) string {
@@ -121,9 +121,9 @@ var flushRequests = []*fosite.Request{
func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry, createClient bool) {
cl := &client.Client{LegacyClientID: "foobar"}
- cr := &consent.OAuth2ConsentRequest{
+ cr := &flow.OAuth2ConsentRequest{
Client: cl,
- OpenIDConnectContext: new(consent.OAuth2ConsentRequestOpenIDConnectContext),
+ OpenIDConnectContext: new(flow.OAuth2ConsentRequestOpenIDConnectContext),
LoginChallenge: sqlxx.NullString(id),
ID: id,
Verifier: id,
@@ -132,18 +132,36 @@ func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry, createCl
RequestedAt: time.Now(),
}
+ ctx := context.Background()
if createClient {
- require.NoError(t, x.ClientManager().CreateClient(context.Background(), cl))
+ require.NoError(t, x.ClientManager().CreateClient(ctx, cl))
}
- require.NoError(t, x.ConsentManager().CreateLoginRequest(context.Background(), &consent.LoginRequest{Client: cl, OpenIDConnectContext: new(consent.OAuth2ConsentRequestOpenIDConnectContext), ID: id, Verifier: id, AuthenticatedAt: sqlxx.NullTime(time.Now()), RequestedAt: time.Now()}))
- require.NoError(t, x.ConsentManager().CreateConsentRequest(context.Background(), cr))
- _, err := x.ConsentManager().HandleConsentRequest(context.Background(), &consent.AcceptOAuth2ConsentRequest{
- ConsentRequest: cr, Session: new(consent.AcceptOAuth2ConsentRequestSession), AuthenticatedAt: sqlxx.NullTime(time.Now()),
- ID: id,
- RequestedAt: time.Now(),
- HandledAt: sqlxx.NullTime(time.Now()),
+ f, err := x.ConsentManager().CreateLoginRequest(
+ ctx, &flow.LoginRequest{
+ Client: cl,
+ OpenIDConnectContext: new(flow.OAuth2ConsentRequestOpenIDConnectContext),
+ ID: id,
+ Verifier: id,
+ AuthenticatedAt: sqlxx.NullTime(time.Now()),
+ RequestedAt: time.Now(),
+ })
+ require.NoError(t, err)
+ err = x.ConsentManager().CreateConsentRequest(ctx, f, cr)
+ require.NoError(t, err)
+
+ encodedFlow, err := f.ToConsentVerifier(ctx, x)
+ require.NoError(t, err)
+
+ _, err = x.ConsentManager().HandleConsentRequest(ctx, f, &flow.AcceptOAuth2ConsentRequest{
+ ConsentRequest: cr,
+ Session: new(flow.AcceptOAuth2ConsentRequestSession),
+ AuthenticatedAt: sqlxx.NullTime(time.Now()),
+ ID: encodedFlow,
+ RequestedAt: time.Now(),
+ HandledAt: sqlxx.NullTime(time.Now()),
})
+
require.NoError(t, err)
}
@@ -270,10 +288,18 @@ func testHelperRevokeRefreshToken(x InternalRegistry) func(t *testing.T) {
mockRequestForeignKey(t, reqIdOne, x, false)
mockRequestForeignKey(t, reqIdTwo, x, false)
- err = m.CreateRefreshTokenSession(ctx, "1111", &fosite.Request{ID: reqIdOne, Client: &client.Client{LegacyClientID: "foobar"}, RequestedAt: time.Now().UTC().Round(time.Second), Session: &Session{}})
+ err = m.CreateRefreshTokenSession(ctx, "1111", &fosite.Request{
+ ID: reqIdOne,
+ Client: &client.Client{LegacyClientID: "foobar"},
+ RequestedAt: time.Now().UTC().Round(time.Second),
+ Session: &Session{}})
require.NoError(t, err)
- err = m.CreateRefreshTokenSession(ctx, "1122", &fosite.Request{ID: reqIdTwo, Client: &client.Client{LegacyClientID: "foobar"}, RequestedAt: time.Now().UTC().Round(time.Second), Session: &Session{}})
+ err = m.CreateRefreshTokenSession(ctx, "1122", &fosite.Request{
+ ID: reqIdTwo,
+ Client: &client.Client{LegacyClientID: "foobar"},
+ RequestedAt: time.Now().UTC().Round(time.Second),
+ Session: &Session{}})
require.NoError(t, err)
_, err = m.GetRefreshTokenSession(ctx, "1111", &Session{})
diff --git a/oauth2/fosite_store_test.go b/oauth2/fosite_store_test.go
index 70efbc3f4b2..767f33444d5 100644
--- a/oauth2/fosite_store_test.go
+++ b/oauth2/fosite_store_test.go
@@ -23,8 +23,8 @@ import (
func TestMain(m *testing.M) {
flag.Parse()
- runner := dockertest.Register()
- runner.Exit(m.Run())
+ defer dockertest.KillAllTestDatabases()
+ m.Run()
}
var registries = make(map[string]driver.Registry)
diff --git a/oauth2/handler.go b/oauth2/handler.go
index c9eaaaa0c62..6d473b4e1eb 100644
--- a/oauth2/handler.go
+++ b/oauth2/handler.go
@@ -58,7 +58,10 @@ type Handler struct {
}
func NewHandler(r InternalRegistry, c *config.DefaultProvider) *Handler {
- return &Handler{r: r, c: c}
+ return &Handler{
+ r: r,
+ c: c,
+ }
}
func (h *Handler) SetRoutes(admin *httprouterx.RouterAdmin, public *httprouterx.RouterPublic, corsMiddleware func(http.Handler) http.Handler) {
@@ -460,6 +463,8 @@ func (h *Handler) discoverOidcConfiguration(w http.ResponseWriter, r *http.Reque
// OpenID Connect Userinfo
//
// swagger:model oidcUserInfo
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type oidcUserInfo struct {
// Subject - Identifier for the End-User at the IssuerURL.
Subject string `json:"sub"`
@@ -623,6 +628,8 @@ func (h *Handler) getOidcUserInfo(w http.ResponseWriter, r *http.Request) {
// Revoke OAuth 2.0 Access or Refresh Token Request
//
// swagger:parameters revokeOAuth2Token
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type revokeOAuth2Token struct {
// in: formData
// required: true
@@ -668,6 +675,8 @@ func (h *Handler) revokeOAuth2Token(w http.ResponseWriter, r *http.Request) {
// Introspect OAuth 2.0 Access or Refresh Token Request
//
// swagger:parameters introspectOAuth2Token
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type introspectOAuth2Token struct {
// The string value of the token. For access tokens, this
// is the "access_token" value returned from the token endpoint
@@ -796,6 +805,8 @@ func (h *Handler) introspectOAuth2Token(w http.ResponseWriter, r *http.Request,
// OAuth 2.0 Token Exchange Parameters
//
// swagger:parameters oauth2TokenExchange
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type performOAuth2TokenFlow struct {
// in: formData
// required: true
@@ -817,6 +828,8 @@ type performOAuth2TokenFlow struct {
// OAuth2 Token Exchange Result
//
// swagger:model oAuth2TokenExchange
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type oAuth2TokenExchange struct {
// The lifetime in seconds of the access token. For
// example, the value "3600" denotes that the access token will
@@ -865,8 +878,8 @@ type oAuth2TokenExchange struct {
// 200: oAuth2TokenExchange
// default: errorOAuth2
func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) {
- var session = NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims(r.Context()))
- var ctx = r.Context()
+ session := NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims(r.Context()))
+ ctx := r.Context()
accessRequest, err := h.r.OAuth2Provider().NewAccessRequest(ctx, r, session)
if err != nil {
@@ -895,7 +908,7 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) {
session.DefaultSession.Claims.Issuer = h.c.IssuerURL(r.Context()).String()
session.DefaultSession.Claims.IssuedAt = time.Now().UTC()
- var scopes = accessRequest.GetRequestedScopes()
+ scopes := accessRequest.GetRequestedScopes()
// Added for compatibility with MITREid
if h.c.GrantAllClientCredentialsScopesPerDefault(r.Context()) && len(scopes) == 0 {
@@ -962,7 +975,7 @@ func (h *Handler) oAuth2Authorize(w http.ResponseWriter, r *http.Request, _ http
return
}
- session, err := h.r.ConsentStrategy().HandleOAuth2AuthorizationRequest(ctx, w, r, authorizeRequest)
+ session, flow, err := h.r.ConsentStrategy().HandleOAuth2AuthorizationRequest(ctx, w, r, authorizeRequest)
if errors.Is(err, consent.ErrAbortOAuth2Request) {
x.LogAudit(r, nil, h.r.AuditLogger())
// do nothing
@@ -1049,6 +1062,7 @@ func (h *Handler) oAuth2Authorize(w http.ResponseWriter, r *http.Request, _ http
ConsentChallenge: session.ID,
ExcludeNotBeforeClaim: h.c.ExcludeNotBeforeClaim(ctx),
AllowedTopLevelClaims: h.c.AllowedTopLevelClaims(ctx),
+ Flow: flow,
})
if err != nil {
x.LogError(r, err, h.r.Logger())
@@ -1062,6 +1076,8 @@ func (h *Handler) oAuth2Authorize(w http.ResponseWriter, r *http.Request, _ http
// Delete OAuth 2.0 Access Token Parameters
//
// swagger:parameters deleteOAuth2Token
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type deleteOAuth2Token struct {
// OAuth 2.0 Client ID
//
diff --git a/oauth2/handler_test.go b/oauth2/handler_test.go
index 74a8751a5eb..cc10d429127 100644
--- a/oauth2/handler_test.go
+++ b/oauth2/handler_test.go
@@ -93,7 +93,7 @@ func TestUserinfo(t *testing.T) {
conf.MustSet(ctx, config.KeyAuthCodeLifespan, lifespan)
conf.MustSet(ctx, config.KeyIssuerURL, "http://hydra.localhost")
reg := internal.NewRegistryMemory(t, conf, &contextx.Default{})
- internal.MustEnsureRegistryKeys(reg, x.OpenIDConnectKeyName)
+ internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName)
ctrl := gomock.NewController(t)
op := NewMockOAuth2Provider(ctrl)
@@ -147,8 +147,8 @@ func TestUserinfo(t *testing.T) {
setup: func(t *testing.T) {
op.EXPECT().
IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()).
- DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, session fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) {
- session = &oauth2.Session{
+ DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, _ fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) {
+ session := &oauth2.Session{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: "alice",
@@ -180,8 +180,8 @@ func TestUserinfo(t *testing.T) {
setup: func(t *testing.T) {
op.EXPECT().
IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()).
- DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, session fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) {
- session = &oauth2.Session{
+ DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, _ fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) {
+ session := &oauth2.Session{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: "another-alice",
@@ -215,8 +215,8 @@ func TestUserinfo(t *testing.T) {
setup: func(t *testing.T) {
op.EXPECT().
IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()).
- DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, session fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) {
- session = &oauth2.Session{
+ DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, _ fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) {
+ session := &oauth2.Session{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: "alice",
@@ -250,8 +250,8 @@ func TestUserinfo(t *testing.T) {
setup: func(t *testing.T) {
op.EXPECT().
IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()).
- DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, session fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) {
- session = &oauth2.Session{
+ DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, _ fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) {
+ session := &oauth2.Session{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: "alice",
@@ -278,8 +278,8 @@ func TestUserinfo(t *testing.T) {
setup: func(t *testing.T) {
op.EXPECT().
IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()).
- DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, session fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) {
- session = &oauth2.Session{
+ DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, _ fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) {
+ session := &oauth2.Session{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: "alice",
diff --git a/oauth2/introspector_test.go b/oauth2/introspector_test.go
index 1905476b9f6..6511a77e33e 100644
--- a/oauth2/introspector_test.go
+++ b/oauth2/introspector_test.go
@@ -35,7 +35,7 @@ func TestIntrospectorSDK(t *testing.T) {
conf.MustSet(ctx, config.KeyIssuerURL, "https://foobariss")
reg := internal.NewRegistryMemory(t, conf, &contextx.Default{})
- internal.MustEnsureRegistryKeys(reg, x.OpenIDConnectKeyName)
+ internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName)
internal.AddFositeExamples(reg)
tokens := Tokens(reg.OAuth2ProviderConfig(), 4)
diff --git a/oauth2/oauth2_auth_code_bench_test.go b/oauth2/oauth2_auth_code_bench_test.go
new file mode 100644
index 00000000000..92b54f8ff81
--- /dev/null
+++ b/oauth2/oauth2_auth_code_bench_test.go
@@ -0,0 +1,305 @@
+// Copyright © 2022 Ory Corp
+// SPDX-License-Identifier: Apache-2.0
+
+package oauth2_test
+
+import (
+ "context"
+ "flag"
+ "net/http"
+ "os"
+ "runtime"
+ "runtime/pprof"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/pborman/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
+ "go.opentelemetry.io/otel/propagation"
+ "go.opentelemetry.io/otel/sdk/resource"
+ "go.opentelemetry.io/otel/sdk/trace"
+ "go.opentelemetry.io/otel/sdk/trace/tracetest"
+ semconv "go.opentelemetry.io/otel/semconv/v1.12.0"
+ "golang.org/x/oauth2"
+ "gopkg.in/square/go-jose.v2"
+
+ hydra "github.com/ory/hydra-client-go/v2"
+ hc "github.com/ory/hydra/v2/client"
+ "github.com/ory/hydra/v2/driver/config"
+ "github.com/ory/hydra/v2/internal"
+ "github.com/ory/hydra/v2/internal/testhelpers"
+ "github.com/ory/hydra/v2/jwk"
+ "github.com/ory/hydra/v2/x"
+ "github.com/ory/x/contextx"
+ "github.com/ory/x/pointerx"
+ "github.com/ory/x/stringsx"
+)
+
+var (
+ prof = flag.String("profile", "", "write a CPU profile to this filename")
+ conc = flag.Int("conc", 100, "dispatch this many requests concurrently")
+ tracing = flag.Bool("tracing", false, "send OpenTelemetry traces to localhost:4318")
+)
+
+func BenchmarkAuthCode(b *testing.B) {
+ flag.Parse()
+
+ ctx := context.Background()
+
+ spans := tracetest.NewSpanRecorder()
+ opts := []trace.TracerProviderOption{
+ trace.WithSpanProcessor(spans),
+ trace.WithResource(resource.NewWithAttributes(
+ semconv.SchemaURL, attribute.String(string(semconv.ServiceNameKey), "BenchmarkAuthCode"),
+ )),
+ }
+ if *tracing {
+ exporter, err := otlptracehttp.New(ctx, otlptracehttp.WithInsecure(), otlptracehttp.WithEndpoint("localhost:4318"))
+ require.NoError(b, err)
+ opts = append(opts, trace.WithSpanProcessor(trace.NewSimpleSpanProcessor(exporter)))
+ }
+ provider := trace.NewTracerProvider(opts...)
+
+ tracer := provider.Tracer("BenchmarkAuthCode")
+ otel.SetTextMapPropagator(propagation.TraceContext{})
+ otel.SetTracerProvider(provider)
+
+ ctx, span := tracer.Start(ctx, "BenchmarkAuthCode")
+ defer span.End()
+
+ ctx = context.WithValue(ctx, oauth2.HTTPClient, otelhttp.DefaultClient)
+
+ dsn := stringsx.Coalesce(os.Getenv("DSN"), "postgres://postgres:secret@127.0.0.1:3445/postgres?sslmode=disable&max_conns=20&max_idle_conns=20")
+ // dsn := "mysql://root:secret@tcp(localhost:3444)/mysql?max_conns=16&max_idle_conns=16"
+ // dsn := "cockroach://root@localhost:3446/defaultdb?sslmode=disable&max_conns=16&max_idle_conns=16"
+ reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer)
+ reg.Config().MustSet(ctx, config.KeyLogLevel, "error")
+ reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
+ reg.Config().MustSet(ctx, config.KeyRefreshTokenHookURL, "")
+ oauth2Keys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OAuth2JWTKeyName, "sig")
+ require.NoError(b, err)
+ oidcKeys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OpenIDConnectKeyName, "sig")
+ require.NoError(b, err)
+ _, _ = oauth2Keys, oidcKeys
+ require.NoError(b, reg.KeyManager().UpdateKeySet(ctx, x.OAuth2JWTKeyName, oauth2Keys))
+ require.NoError(b, reg.KeyManager().UpdateKeySet(ctx, x.OpenIDConnectKeyName, oidcKeys))
+ _, adminTS := testhelpers.NewOAuth2Server(ctx, b, reg)
+ var (
+ authURL = reg.Config().OAuth2AuthURL(ctx).String()
+ tokenURL = reg.Config().OAuth2TokenURL(ctx).String()
+ nonce = uuid.New()
+ )
+
+ newOAuth2Client := func(b *testing.B, cb string) (*hc.Client, *oauth2.Config) {
+ secret := uuid.New()
+ c := &hc.Client{
+ Secret: secret,
+ RedirectURIs: []string{cb},
+ ResponseTypes: []string{"id_token", "code", "token"},
+ GrantTypes: []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"},
+ Scope: "hydra offline openid",
+ Audience: []string{"https://api.ory.sh/"},
+ }
+ require.NoError(b, reg.ClientManager().CreateClient(ctx, c))
+ return c, &oauth2.Config{
+ ClientID: c.GetID(),
+ ClientSecret: secret,
+ Endpoint: oauth2.Endpoint{
+ AuthURL: authURL,
+ TokenURL: tokenURL,
+ AuthStyle: oauth2.AuthStyleInHeader,
+ },
+ Scopes: strings.Split(c.Scope, " "),
+ }
+ }
+
+ cfg := hydra.NewConfiguration()
+ cfg.HTTPClient = otelhttp.DefaultClient
+ adminClient := hydra.NewAPIClient(cfg)
+ adminClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: adminTS.URL}}
+
+ getAuthorizeCode := func(ctx context.Context, b *testing.B, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (string, *http.Response) {
+ if c == nil {
+ c = testhelpers.NewEmptyJarClient(b)
+ }
+
+ state := uuid.New()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", conf.AuthCodeURL(state, params...), nil)
+ require.NoError(b, err)
+ resp, err := c.Do(req)
+ require.NoError(b, err)
+ defer resp.Body.Close()
+
+ q := resp.Request.URL.Query()
+ require.EqualValues(b, state, q.Get("state"))
+ return q.Get("code"), resp
+ }
+
+ acceptLoginHandler := func(b *testing.B, c *hc.Client, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc {
+ return otelhttp.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ rr, _, err := adminClient.OAuth2Api.GetOAuth2LoginRequest(ctx).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute()
+ require.NoError(b, err)
+
+ assert.EqualValues(b, c.GetID(), pointerx.Deref(rr.Client.ClientId))
+ assert.Empty(b, pointerx.Deref(rr.Client.ClientSecret))
+ assert.EqualValues(b, c.GrantTypes, rr.Client.GrantTypes)
+ assert.EqualValues(b, c.LogoURI, pointerx.Deref(rr.Client.LogoUri))
+ assert.EqualValues(b, c.RedirectURIs, rr.Client.RedirectUris)
+ assert.EqualValues(b, r.URL.Query().Get("login_challenge"), rr.Challenge)
+ assert.EqualValues(b, []string{"hydra", "offline", "openid"}, rr.RequestedScope)
+ assert.Contains(b, rr.RequestUrl, authURL)
+
+ acceptBody := hydra.AcceptOAuth2LoginRequest{
+ Subject: uuid.New(),
+ Remember: pointerx.Ptr(!rr.Skip),
+ Acr: pointerx.Ptr("1"),
+ Amr: []string{"pwd"},
+ Context: map[string]interface{}{"context": "bar"},
+ }
+ if checkRequestPayload != nil {
+ if b := checkRequestPayload(rr); b != nil {
+ acceptBody = *b
+ }
+ }
+
+ v, _, err := adminClient.OAuth2Api.AcceptOAuth2LoginRequest(ctx).
+ LoginChallenge(r.URL.Query().Get("login_challenge")).
+ AcceptOAuth2LoginRequest(acceptBody).
+ Execute()
+ require.NoError(b, err)
+ require.NotEmpty(b, v.RedirectTo)
+ http.Redirect(w, r, v.RedirectTo, http.StatusFound)
+ }), "acceptLoginHandler").ServeHTTP
+ }
+
+ acceptConsentHandler := func(b *testing.B, c *hc.Client, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc {
+ return otelhttp.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ rr, _, err := adminClient.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute()
+ require.NoError(b, err)
+
+ assert.EqualValues(b, c.GetID(), pointerx.Deref(rr.Client.ClientId))
+ assert.Empty(b, pointerx.Deref(rr.Client.ClientSecret))
+ assert.EqualValues(b, c.GrantTypes, rr.Client.GrantTypes)
+ assert.EqualValues(b, c.LogoURI, pointerx.Deref(rr.Client.LogoUri))
+ assert.EqualValues(b, c.RedirectURIs, rr.Client.RedirectUris)
+ // assert.EqualValues(b, subject, pointerx.Deref(rr.Subject))
+ assert.EqualValues(b, []string{"hydra", "offline", "openid"}, rr.RequestedScope)
+ assert.EqualValues(b, r.URL.Query().Get("consent_challenge"), rr.Challenge)
+ assert.Contains(b, *rr.RequestUrl, authURL)
+ if checkRequestPayload != nil {
+ checkRequestPayload(rr)
+ }
+
+ assert.Equal(b, map[string]interface{}{"context": "bar"}, rr.Context)
+ v, _, err := adminClient.OAuth2Api.AcceptOAuth2ConsentRequest(ctx).
+ ConsentChallenge(r.URL.Query().Get("consent_challenge")).
+ AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{
+ GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0),
+ GrantAccessTokenAudience: rr.RequestedAccessTokenAudience,
+ Session: &hydra.AcceptOAuth2ConsentRequestSession{
+ AccessToken: map[string]interface{}{"foo": "bar"},
+ IdToken: map[string]interface{}{"bar": "baz"},
+ },
+ }).
+ Execute()
+ require.NoError(b, err)
+ require.NotEmpty(b, v.RedirectTo)
+ http.Redirect(w, r, v.RedirectTo, http.StatusFound)
+ }), "acceptConsentHandler").ServeHTTP
+ }
+
+ run := func(b *testing.B, strategy string) func(*testing.B) {
+ reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
+ c, conf := newOAuth2Client(b, testhelpers.NewCallbackURL(b, "callback", testhelpers.HTTPServerNotImplementedHandler))
+ testhelpers.NewLoginConsentUI(b, reg.Config(),
+ acceptLoginHandler(b, c, nil),
+ acceptConsentHandler(b, c, nil),
+ )
+
+ return func(b *testing.B) {
+ //pop.Debug = true
+ code, _ := getAuthorizeCode(ctx, b, conf, nil, oauth2.SetAuthURLParam("nonce", nonce))
+ require.NotEmpty(b, code)
+
+ _, err := conf.Exchange(ctx, code)
+ //pop.Debug = false
+ require.NoError(b, err)
+ }
+ }
+
+ b.ResetTimer()
+
+ b.SetParallelism(*conc / runtime.GOMAXPROCS(0))
+
+ b.Run("strategy=jwt", func(b *testing.B) {
+ initialDBSpans := dbSpans(spans)
+ B := run(b, "jwt")
+
+ stop := profile(b)
+ defer stop()
+
+ var totalMS int64 = 0
+ b.RunParallel(func(p *testing.PB) {
+ defer func(t0 time.Time) {
+ atomic.AddInt64(&totalMS, int64(time.Since(t0).Milliseconds()))
+ }(time.Now())
+ for p.Next() {
+ B(b)
+ }
+ })
+
+ b.ReportMetric(0, "ns/op")
+ b.ReportMetric(float64(atomic.LoadInt64(&totalMS))/float64(b.N), "ms/op")
+ b.ReportMetric((float64(dbSpans(spans)-initialDBSpans))/float64(b.N), "queries/op")
+ b.ReportMetric(float64(b.N)/b.Elapsed().Seconds(), "ops/s")
+ })
+
+ b.Run("strategy=opaque", func(b *testing.B) {
+ initialDBSpans := dbSpans(spans)
+ B := run(b, "opaque")
+
+ stop := profile(b)
+ defer stop()
+
+ var totalMS int64 = 0
+ b.RunParallel(func(p *testing.PB) {
+ defer func(t0 time.Time) {
+ atomic.AddInt64(&totalMS, int64(time.Since(t0).Milliseconds()))
+ }(time.Now())
+ for p.Next() {
+ B(b)
+ }
+ })
+
+ b.ReportMetric(0, "ns/op")
+ b.ReportMetric(float64(atomic.LoadInt64(&totalMS))/float64(b.N), "ms/op")
+ b.ReportMetric((float64(dbSpans(spans)-initialDBSpans))/float64(b.N), "queries/op")
+ b.ReportMetric(float64(b.N)/b.Elapsed().Seconds(), "ops/s")
+ })
+
+}
+
+func profile(t testing.TB) (stop func()) {
+ t.Helper()
+ if *prof == "" {
+ return func() {} // noop
+ }
+ f, err := os.Create(*prof)
+ require.NoError(t, err)
+ require.NoError(t, pprof.StartCPUProfile(f))
+ return func() {
+ pprof.StopCPUProfile()
+ require.NoError(t, f.Close())
+ t.Log("Wrote profile to", f.Name())
+ }
+}
diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go
index b4811fead10..93349df73ea 100644
--- a/oauth2/oauth2_auth_code_test.go
+++ b/oauth2/oauth2_auth_code_test.go
@@ -18,6 +18,7 @@ import (
"testing"
"time"
+ "github.com/ory/hydra/v2/flow"
"github.com/ory/x/ioutilx"
"github.com/ory/x/requirex"
@@ -30,7 +31,6 @@ import (
"github.com/pborman/uuid"
"github.com/tidwall/gjson"
- "github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/internal/testhelpers"
"github.com/ory/x/contextx"
@@ -50,7 +50,7 @@ import (
"github.com/ory/x/snapshotx"
)
-func noopHandler(t *testing.T) httprouter.Handle {
+func noopHandler(*testing.T) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
w.WriteHeader(http.StatusNotImplemented)
}
@@ -347,6 +347,108 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
})
})
+ t.Run("suite=invalid query params", func(t *testing.T) {
+ c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
+ otherClient, _ := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler))
+ testhelpers.NewLoginConsentUI(t, reg.Config(),
+ acceptLoginHandler(t, c, subject, nil),
+ acceptConsentHandler(t, c, subject, nil),
+ )
+
+ withWrongClientAfterLogin := &http.Client{
+ Jar: testhelpers.NewEmptyCookieJar(t),
+ CheckRedirect: func(req *http.Request, _ []*http.Request) error {
+ if req.URL.Path != "/oauth2/auth" {
+ return nil
+ }
+ q := req.URL.Query()
+ if !q.Has("login_verifier") {
+ return nil
+ }
+ q.Set("client_id", otherClient.ID.String())
+ req.URL.RawQuery = q.Encode()
+ return nil
+ },
+ }
+ withWrongClientAfterConsent := &http.Client{
+ Jar: testhelpers.NewEmptyCookieJar(t),
+ CheckRedirect: func(req *http.Request, _ []*http.Request) error {
+ if req.URL.Path != "/oauth2/auth" {
+ return nil
+ }
+ q := req.URL.Query()
+ if !q.Has("consent_verifier") {
+ return nil
+ }
+ q.Set("client_id", otherClient.ID.String())
+ req.URL.RawQuery = q.Encode()
+ return nil
+ },
+ }
+
+ withWrongScopeAfterLogin := &http.Client{
+ Jar: testhelpers.NewEmptyCookieJar(t),
+ CheckRedirect: func(req *http.Request, _ []*http.Request) error {
+ if req.URL.Path != "/oauth2/auth" {
+ return nil
+ }
+ q := req.URL.Query()
+ if !q.Has("login_verifier") {
+ return nil
+ }
+ q.Set("scope", "invalid scope")
+ req.URL.RawQuery = q.Encode()
+ return nil
+ },
+ }
+
+ withWrongScopeAfterConsent := &http.Client{
+ Jar: testhelpers.NewEmptyCookieJar(t),
+ CheckRedirect: func(req *http.Request, _ []*http.Request) error {
+ if req.URL.Path != "/oauth2/auth" {
+ return nil
+ }
+ q := req.URL.Query()
+ if !q.Has("consent_verifier") {
+ return nil
+ }
+ q.Set("scope", "invalid scope")
+ req.URL.RawQuery = q.Encode()
+ return nil
+ },
+ }
+
+ for _, tc := range []struct {
+ name string
+ client *http.Client
+ expectedResponse string
+ }{{
+ name: "fails with wrong client ID after login",
+ client: withWrongClientAfterLogin,
+ expectedResponse: "access_denied",
+ }, {
+ name: "fails with wrong client ID after consent",
+ client: withWrongClientAfterConsent,
+ expectedResponse: "invalid_client",
+ }, {
+ name: "fails with wrong scopes after login",
+ client: withWrongScopeAfterLogin,
+ expectedResponse: "invalid_scope",
+ }, {
+ name: "fails with wrong scopes after consent",
+ client: withWrongScopeAfterConsent,
+ expectedResponse: "invalid_scope",
+ }} {
+ t.Run("case="+tc.name, func(t *testing.T) {
+ state := uuid.New()
+ resp, err := tc.client.Get(conf.AuthCodeURL(state))
+ require.NoError(t, err)
+ assert.Equal(t, tc.expectedResponse, resp.Request.URL.Query().Get("error"), "%s", resp.Request.URL.RawQuery)
+ resp.Body.Close()
+ })
+ }
+ })
+
t.Run("case=checks if request fails when subject is empty", func(t *testing.T) {
testhelpers.NewLoginConsentUI(t, reg.Config(), func(w http.ResponseWriter, r *http.Request) {
_, res, err := adminClient.OAuth2Api.AcceptOAuth2LoginRequest(ctx).
@@ -702,7 +804,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
}
hookResp := hydraoauth2.TokenHookResponse{
- Session: consent.AcceptOAuth2ConsentRequestSession{
+ Session: flow.AcceptOAuth2ConsentRequestSession{
AccessToken: claims,
IDToken: claims,
},
@@ -894,8 +996,8 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
conf.MustSet(ctx, config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY")
conf.MustSet(ctx, config.KeyAccessTokenStrategy, strat.d)
reg := internal.NewRegistryMemory(t, conf, &contextx.Default{})
- internal.MustEnsureRegistryKeys(reg, x.OpenIDConnectKeyName)
- internal.MustEnsureRegistryKeys(reg, x.OAuth2JWTKeyName)
+ internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName)
+ internal.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName)
consentStrategy := &consentMock{}
router := x.NewRouterPublic()
@@ -1102,7 +1204,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
require.NotEmpty(t, code)
- token, err := oauthConfig.Exchange(oauth2.NoContext, code)
+ token, err := oauthConfig.Exchange(context.TODO(), code)
if tc.expectOAuthTokenError {
require.Error(t, err)
return
@@ -1263,7 +1365,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
}
hookResp := hydraoauth2.TokenHookResponse{
- Session: consent.AcceptOAuth2ConsentRequestSession{
+ Session: flow.AcceptOAuth2ConsentRequestSession{
AccessToken: claims,
IDToken: claims,
},
@@ -1446,7 +1548,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
})
t.Run("duplicate code exchange fails", func(t *testing.T) {
- token, err := oauthConfig.Exchange(oauth2.NoContext, code)
+ token, err := oauthConfig.Exchange(context.TODO(), code)
require.Error(t, err)
require.Nil(t, token)
})
diff --git a/oauth2/oauth2_client_credentials_bench_test.go b/oauth2/oauth2_client_credentials_bench_test.go
new file mode 100644
index 00000000000..310727f34cc
--- /dev/null
+++ b/oauth2/oauth2_client_credentials_bench_test.go
@@ -0,0 +1,162 @@
+// Copyright © 2022 Ory Corp
+// SPDX-License-Identifier: Apache-2.0
+
+package oauth2_test
+
+import (
+ "context"
+ "encoding/json"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+ "go.opentelemetry.io/otel/sdk/trace"
+ "go.opentelemetry.io/otel/sdk/trace/tracetest"
+ goauth2 "golang.org/x/oauth2"
+ "golang.org/x/oauth2/clientcredentials"
+
+ hc "github.com/ory/hydra/v2/client"
+ "github.com/ory/hydra/v2/driver/config"
+ "github.com/ory/hydra/v2/internal"
+ "github.com/ory/hydra/v2/internal/testhelpers"
+ "github.com/ory/hydra/v2/x"
+ "github.com/ory/x/contextx"
+ "github.com/ory/x/requirex"
+)
+
+func BenchmarkClientCredentials(b *testing.B) {
+ ctx := context.Background()
+
+ spans := tracetest.NewSpanRecorder()
+ tracer := trace.NewTracerProvider(trace.WithSpanProcessor(spans)).Tracer("")
+
+ dsn := "postgres://postgres:secret@127.0.0.1:3445/postgres?sslmode=disable"
+ reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer)
+ reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
+ public, admin := testhelpers.NewOAuth2Server(ctx, b, reg)
+
+ var newCustomClient = func(b *testing.B, c *hc.Client) (*hc.Client, clientcredentials.Config) {
+ unhashedSecret := c.Secret
+ require.NoError(b, reg.ClientManager().CreateClient(ctx, c))
+ return c, clientcredentials.Config{
+ ClientID: c.GetID(),
+ ClientSecret: unhashedSecret,
+ TokenURL: reg.Config().OAuth2TokenURL(ctx).String(),
+ Scopes: strings.Split(c.Scope, " "),
+ EndpointParams: url.Values{"audience": c.Audience},
+ }
+ }
+
+ var newClient = func(b *testing.B) (*hc.Client, clientcredentials.Config) {
+ cc, config := newCustomClient(b, &hc.Client{
+ Secret: uuid.New().String(),
+ RedirectURIs: []string{public.URL + "/callback"},
+ ResponseTypes: []string{"token"},
+ GrantTypes: []string{"client_credentials"},
+ Scope: "foobar",
+ Audience: []string{"https://api.ory.sh/"},
+ })
+ return cc, config
+ }
+
+ var getToken = func(t *testing.B, conf clientcredentials.Config) (*goauth2.Token, error) {
+ conf.AuthStyle = goauth2.AuthStyleInHeader
+ return conf.Token(context.Background())
+ }
+
+ var encodeOr = func(b *testing.B, val interface{}, or string) string {
+ out, err := json.Marshal(val)
+ require.NoError(b, err)
+ if string(out) == "null" {
+ return or
+ }
+
+ return string(out)
+ }
+
+ var inspectToken = func(b *testing.B, token *goauth2.Token, cl *hc.Client, conf clientcredentials.Config, strategy string, expectedExp time.Time, checkExtraClaims bool) {
+ introspection := testhelpers.IntrospectToken(b, &goauth2.Config{ClientID: cl.GetID(), ClientSecret: conf.ClientSecret}, token.AccessToken, admin)
+
+ check := func(res gjson.Result) {
+ assert.EqualValues(b, cl.GetID(), res.Get("client_id").String(), "%s", res.Raw)
+ assert.EqualValues(b, cl.GetID(), res.Get("sub").String(), "%s", res.Raw)
+ assert.EqualValues(b, reg.Config().IssuerURL(ctx).String(), res.Get("iss").String(), "%s", res.Raw)
+
+ assert.EqualValues(b, res.Get("nbf").Int(), res.Get("iat").Int(), "%s", res.Raw)
+ requirex.EqualTime(b, expectedExp, time.Unix(res.Get("exp").Int(), 0), time.Second)
+
+ assert.EqualValues(b, encodeOr(b, conf.EndpointParams["audience"], "[]"), res.Get("aud").Raw, "%s", res.Raw)
+
+ if checkExtraClaims {
+ require.True(b, res.Get("ext.hooked").Bool())
+ }
+ }
+
+ check(introspection)
+ assert.True(b, introspection.Get("active").Bool())
+ assert.EqualValues(b, "access_token", introspection.Get("token_use").String())
+ assert.EqualValues(b, "Bearer", introspection.Get("token_type").String())
+ assert.EqualValues(b, strings.Join(conf.Scopes, " "), introspection.Get("scope").String(), "%s", introspection.Raw)
+
+ if strategy != "jwt" {
+ return
+ }
+
+ body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1])
+ require.NoError(b, err)
+
+ jwtClaims := gjson.ParseBytes(body)
+ assert.NotEmpty(b, jwtClaims.Get("jti").String())
+ assert.EqualValues(b, encodeOr(b, conf.Scopes, "[]"), jwtClaims.Get("scp").Raw, "%s", introspection.Raw)
+ check(jwtClaims)
+ }
+
+ var getAndInspectToken = func(b *testing.B, cl *hc.Client, conf clientcredentials.Config, strategy string, expectedExp time.Time, checkExtraClaims bool) {
+ token, err := getToken(b, conf)
+ require.NoError(b, err)
+ inspectToken(b, token, cl, conf, strategy, expectedExp, checkExtraClaims)
+ }
+
+ run := func(strategy string) func(b *testing.B) {
+ return func(t *testing.B) {
+ reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
+
+ cl, conf := newClient(b)
+ getAndInspectToken(b, cl, conf, strategy, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx)), false)
+ }
+ }
+
+ b.Run("strategy=jwt", func(b *testing.B) {
+ initialDBSpans := dbSpans(spans)
+ for i := 0; i < b.N; i++ {
+ run("jwt")(b)
+ }
+ b.ReportMetric(0, "ns/op")
+ b.ReportMetric(float64(b.Elapsed().Milliseconds())/float64(b.N), "ms/op")
+ b.ReportMetric((float64(dbSpans(spans)-initialDBSpans))/float64(b.N), "queries/op")
+ })
+
+ b.Run("strategy=opaque", func(b *testing.B) {
+ initialDBSpans := dbSpans(spans)
+ for i := 0; i < b.N; i++ {
+ run("opaque")(b)
+ }
+ b.ReportMetric(0, "ns/op")
+ b.ReportMetric(float64(b.Elapsed().Milliseconds())/float64(b.N), "ms/op")
+ b.ReportMetric((float64(dbSpans(spans)-initialDBSpans))/float64(b.N), "queries/op")
+ })
+}
+
+func dbSpans(spans *tracetest.SpanRecorder) (count int) {
+ for _, s := range spans.Started() {
+ if strings.HasPrefix(s.Name(), "sql-") {
+ count++
+ }
+ }
+ return
+}
diff --git a/oauth2/oauth2_client_credentials_test.go b/oauth2/oauth2_client_credentials_test.go
index 1703bda9c38..059036a2700 100644
--- a/oauth2/oauth2_client_credentials_test.go
+++ b/oauth2/oauth2_client_credentials_test.go
@@ -22,7 +22,7 @@ import (
goauth2 "golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
- "github.com/ory/hydra/v2/consent"
+ "github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/internal/testhelpers"
hydraoauth2 "github.com/ory/hydra/v2/oauth2"
"github.com/ory/x/contextx"
@@ -276,7 +276,7 @@ func TestClientCredentials(t *testing.T) {
}
hookResp := hydraoauth2.TokenHookResponse{
- Session: consent.AcceptOAuth2ConsentRequestSession{
+ Session: flow.AcceptOAuth2ConsentRequestSession{
AccessToken: claims,
IDToken: claims,
},
diff --git a/oauth2/oauth2_helper_test.go b/oauth2/oauth2_helper_test.go
index ea679c24189..52a30e5975e 100644
--- a/oauth2/oauth2_helper_test.go
+++ b/oauth2/oauth2_helper_test.go
@@ -11,6 +11,7 @@ import (
"github.com/pkg/errors"
"github.com/ory/fosite"
+ "github.com/ory/hydra/v2/flow"
"github.com/ory/x/sqlxx"
"github.com/ory/hydra/v2/client"
@@ -25,27 +26,27 @@ type consentMock struct {
requestTime time.Time
}
-func (c *consentMock) HandleOAuth2AuthorizationRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester) (*consent.AcceptOAuth2ConsentRequest, error) {
+func (c *consentMock) HandleOAuth2AuthorizationRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) {
if c.deny {
- return nil, fosite.ErrRequestForbidden
+ return nil, nil, fosite.ErrRequestForbidden
}
- return &consent.AcceptOAuth2ConsentRequest{
- ConsentRequest: &consent.OAuth2ConsentRequest{
+ return &flow.AcceptOAuth2ConsentRequest{
+ ConsentRequest: &flow.OAuth2ConsentRequest{
Subject: "foo",
ACR: "1",
},
AuthenticatedAt: sqlxx.NullTime(c.authTime),
GrantedScope: []string{"offline", "openid", "hydra.*"},
- Session: &consent.AcceptOAuth2ConsentRequestSession{
+ Session: &flow.AcceptOAuth2ConsentRequestSession{
AccessToken: map[string]interface{}{},
IDToken: map[string]interface{}{},
},
RequestedAt: c.requestTime,
- }, nil
+ }, nil, nil
}
-func (c *consentMock) HandleOpenIDConnectLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*consent.LogoutResult, error) {
+func (c *consentMock) HandleOpenIDConnectLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) {
panic("not implemented")
}
diff --git a/oauth2/oauth2_jwt_bearer_test.go b/oauth2/oauth2_jwt_bearer_test.go
index 1aa1f8179ff..b975af21c72 100644
--- a/oauth2/oauth2_jwt_bearer_test.go
+++ b/oauth2/oauth2_jwt_bearer_test.go
@@ -20,7 +20,7 @@ import (
"gopkg.in/square/go-jose.v2"
"github.com/ory/fosite/token/jwt"
- "github.com/ory/hydra/v2/consent"
+ "github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/jwk"
hydraoauth2 "github.com/ory/hydra/v2/oauth2"
"github.com/ory/hydra/v2/oauth2/trust"
@@ -342,7 +342,7 @@ func TestJWTBearer(t *testing.T) {
}
hookResp := hydraoauth2.TokenHookResponse{
- Session: consent.AcceptOAuth2ConsentRequestSession{
+ Session: flow.AcceptOAuth2ConsentRequestSession{
AccessToken: claims,
IDToken: claims,
},
@@ -417,7 +417,7 @@ func TestJWTBearer(t *testing.T) {
}
hookResp := hydraoauth2.TokenHookResponse{
- Session: consent.AcceptOAuth2ConsentRequestSession{
+ Session: flow.AcceptOAuth2ConsentRequestSession{
AccessToken: claims,
IDToken: claims,
},
diff --git a/oauth2/registry.go b/oauth2/registry.go
index 38ac335bb11..52f9f7bb9bf 100644
--- a/oauth2/registry.go
+++ b/oauth2/registry.go
@@ -6,6 +6,7 @@ package oauth2
import (
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
+ "github.com/ory/hydra/v2/aead"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/jwk"
@@ -21,6 +22,7 @@ type InternalRegistry interface {
x.RegistryLogger
consent.Registry
Registry
+ FlowCipher() *aead.XChaCha20Poly1305
}
type Registry interface {
diff --git a/oauth2/revocator_test.go b/oauth2/revocator_test.go
index a2eb5f3d4b3..71b85e63ea2 100644
--- a/oauth2/revocator_test.go
+++ b/oauth2/revocator_test.go
@@ -63,7 +63,7 @@ func TestRevoke(t *testing.T) {
conf := internal.NewConfigurationWithDefaults()
reg := internal.NewRegistryMemory(t, conf, &contextx.Default{})
- internal.MustEnsureRegistryKeys(reg, x.OpenIDConnectKeyName)
+ internal.MustEnsureRegistryKeys(context.Background(), reg, x.OpenIDConnectKeyName)
internal.AddFositeExamples(reg)
tokens := Tokens(reg.OAuth2ProviderConfig(), 4)
diff --git a/oauth2/session.go b/oauth2/session.go
index e543a1e123f..3032925fe20 100644
--- a/oauth2/session.go
+++ b/oauth2/session.go
@@ -16,6 +16,7 @@ import (
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt"
+ "github.com/ory/hydra/v2/flow"
"github.com/ory/x/stringslice"
)
@@ -29,6 +30,8 @@ type Session struct {
ConsentChallenge string `json:"consent_challenge"`
ExcludeNotBeforeClaim bool `json:"exclude_not_before_claim"`
AllowedTopLevelClaims []string `json:"allowed_top_level_claims"`
+
+ Flow *flow.Flow `json:"-"`
}
func NewSession(subject string) *Session {
diff --git a/oauth2/token_hook.go b/oauth2/token_hook.go
index f7ca4416a71..fc33fe3813d 100644
--- a/oauth2/token_hook.go
+++ b/oauth2/token_hook.go
@@ -12,10 +12,10 @@ import (
"github.com/hashicorp/go-retryablehttp"
+ "github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/x"
"github.com/ory/fosite"
- "github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/x/errorsx"
)
@@ -54,7 +54,7 @@ type TokenHookRequest struct {
// swagger:ignore
type TokenHookResponse struct {
// Session is the session data returned by the hook.
- Session consent.AcceptOAuth2ConsentRequestSession `json:"session"`
+ Session flow.AcceptOAuth2ConsentRequestSession `json:"session"`
}
func executeHookAndUpdateSession(ctx context.Context, reg x.HTTPClientProvider, hookURL *url.URL, reqBodyBytes []byte, session *Session) error {
diff --git a/oauth2/trust/doc.go b/oauth2/trust/doc.go
index c30e9521ac0..16de4977dd3 100644
--- a/oauth2/trust/doc.go
+++ b/oauth2/trust/doc.go
@@ -14,11 +14,15 @@ import (
// OAuth2 JWT Bearer Grant Type Issuer Trust Relationships
//
// swagger:model trustedOAuth2JwtGrantIssuers
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type trustedOAuth2JwtGrantIssuers []trustedOAuth2JwtGrantIssuer
// OAuth2 JWT Bearer Grant Type Issuer Trust Relationship
//
// swagger:model trustedOAuth2JwtGrantIssuer
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type trustedOAuth2JwtGrantIssuer struct {
// example: 9edc811f-4e28-453c-9b46-4de65f00217f
ID string `json:"id"`
@@ -51,6 +55,8 @@ type trustedOAuth2JwtGrantIssuer struct {
// OAuth2 JWT Bearer Grant Type Issuer Trusted JSON Web Key
//
// swagger:model trustedOAuth2JwtGrantJsonWebKey
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type trustedOAuth2JwtGrantJsonWebKey struct {
// The "set" is basically a name for a group(set) of keys. Will be the same as "issuer" in grant.
// example: https://jwt-idp.example.com
diff --git a/oauth2/trust/handler.go b/oauth2/trust/handler.go
index 7bc622e95c0..453ab376975 100644
--- a/oauth2/trust/handler.go
+++ b/oauth2/trust/handler.go
@@ -43,6 +43,8 @@ func (h *Handler) SetRoutes(admin *httprouterx.RouterAdmin) {
// Trust OAuth2 JWT Bearer Grant Type Issuer Request Body
//
// swagger:model trustOAuth2JwtGrantIssuer
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type trustOAuth2JwtGrantIssuerBody struct {
// The "issuer" identifies the principal that issued the JWT assertion (same as "iss" claim in JWT).
//
@@ -78,6 +80,8 @@ type trustOAuth2JwtGrantIssuerBody struct {
// Trust OAuth2 JWT Bearer Grant Type Issuer Request
//
// swagger:parameters trustOAuth2JwtGrantIssuer
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type trustOAuth2JwtGrantIssuer struct {
// in: body
Body trustOAuth2JwtGrantIssuerBody
@@ -140,6 +144,8 @@ func (h *Handler) trustOAuth2JwtGrantIssuer(w http.ResponseWriter, r *http.Reque
// Get Trusted OAuth2 JWT Bearer Grant Type Issuer Request
//
// swagger:parameters getTrustedOAuth2JwtGrantIssuer
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type getTrustedOAuth2JwtGrantIssuer struct {
// The id of the desired grant
//
@@ -181,6 +187,8 @@ func (h *Handler) getTrustedOAuth2JwtGrantIssuer(w http.ResponseWriter, r *http.
// Delete Trusted OAuth2 JWT Bearer Grant Type Issuer Request
//
// swagger:parameters deleteTrustedOAuth2JwtGrantIssuer
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type deleteTrustedOAuth2JwtGrantIssuer struct {
// The id of the desired grant
// in: path
@@ -223,6 +231,8 @@ func (h *Handler) deleteTrustedOAuth2JwtGrantIssuer(w http.ResponseWriter, r *ht
// List Trusted OAuth2 JWT Bearer Grant Type Issuers Request
//
// swagger:parameters listTrustedOAuth2JwtGrantIssuers
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type listTrustedOAuth2JwtGrantIssuers struct {
// If optional "issuer" is supplied, only jwt-bearer grants with this issuer will be returned.
//
diff --git a/persistence/sql/migratest/assertion_helpers.go b/persistence/sql/migratest/assertion_helpers.go
index 242f2460891..36f512a2cca 100644
--- a/persistence/sql/migratest/assertion_helpers.go
+++ b/persistence/sql/migratest/assertion_helpers.go
@@ -8,7 +8,7 @@ import (
"time"
"github.com/gofrs/uuid"
- "github.com/instana/testify/require"
+ "github.com/stretchr/testify/require"
"github.com/ory/hydra/v2/flow"
testhelpersuuid "github.com/ory/hydra/v2/internal/testhelpers/uuid"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0001.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0001.json
index 0d4b588349d..d89f26c7d42 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0001.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0001.json
@@ -32,7 +32,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0001",
@@ -52,7 +53,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0001": "0001"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0002.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0002.json
index da822717327..369ba83ba25 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0002.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0002.json
@@ -32,7 +32,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0002",
@@ -52,7 +53,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0002": "0002"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0003.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0003.json
index 0c8587a0383..66718c0ba27 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0003.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0003.json
@@ -32,7 +32,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0003",
@@ -52,7 +53,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0003": "0003"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0004.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0004.json
index 08fbbf88023..e707616aa87 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0004.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0004.json
@@ -34,7 +34,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0004",
@@ -56,7 +57,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0004": "0004"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0005.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0005.json
index 1bebff1778d..fcc4760db32 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0005.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0005.json
@@ -34,7 +34,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0005",
@@ -56,7 +57,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0005": "0005"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0006.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0006.json
index af35899c259..825ca5b9b00 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0006.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0006.json
@@ -34,7 +34,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0006",
@@ -56,7 +57,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0006": "0006"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0007.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0007.json
index 509653dbf89..1d20de4190f 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0007.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0007.json
@@ -34,7 +34,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0007",
@@ -56,7 +57,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0007": "0007"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0008.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0008.json
index 7da6b5b2c10..3ed3dad5245 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0008.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0008.json
@@ -36,7 +36,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0008",
@@ -58,7 +59,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0008": "0008"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0009.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0009.json
index f59ac706aaa..61f8bbabf0c 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0009.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0009.json
@@ -36,7 +36,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0009",
@@ -58,7 +59,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0009": "0009"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0010.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0010.json
index 99135f5f763..a886dd0aefe 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0010.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0010.json
@@ -36,7 +36,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0010",
@@ -58,7 +59,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0010": "0010"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0011.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0011.json
index ab8c93003b7..dda3212a8d7 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0011.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0011.json
@@ -36,7 +36,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0011",
@@ -58,7 +59,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0011": "0011"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0012.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0012.json
index 53c58242a1a..d6491837a10 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0012.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0012.json
@@ -36,7 +36,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0012",
@@ -58,7 +59,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0012": "0012"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0013.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0013.json
index b39ef9aca29..89ca9f7daf4 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0013.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0013.json
@@ -36,7 +36,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0013",
@@ -58,7 +59,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0013": "0013"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0014.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0014.json
index fff06cbd01d..d020259b581 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0014.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0014.json
@@ -36,7 +36,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0014",
@@ -58,7 +59,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0014": "0014"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0015.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0015.json
index 4a013571bed..78ee82f16d5 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0015.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0015.json
@@ -41,7 +41,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0015",
@@ -65,7 +66,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0015": "0015"
diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0016.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0016.json
index 803bab67ce6..e3bddee39a1 100644
--- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0016.json
+++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0016.json
@@ -41,7 +41,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"LoginAuthenticatedAt": null,
"ConsentChallengeID": "challenge-0016",
@@ -65,7 +66,8 @@
"error_description": "",
"error_hint": "",
"status_code": 0,
- "error_debug": ""
+ "error_debug": "",
+ "valid": false
},
"SessionIDToken": {
"session_id_token-0016": "0016"
diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go
index 460ffb910fd..7c4db0c81d2 100644
--- a/persistence/sql/migratest/migration_test.go
+++ b/persistence/sql/migratest/migration_test.go
@@ -18,8 +18,8 @@ import (
"github.com/bradleyjkemp/cupaloy/v2"
"github.com/fatih/structs"
"github.com/gofrs/uuid"
- "github.com/instana/testify/assert"
"github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
"github.com/gobuffalo/pop/v6"
@@ -143,7 +143,7 @@ func TestMigrations(t *testing.T) {
})
t.Run("case=hydra_oauth2_authentication_session", func(t *testing.T) {
- ss := []consent.LoginSession{}
+ ss := []flow.LoginSession{}
c.All(&ss)
require.Equal(t, 16, len(ss))
@@ -168,7 +168,7 @@ func TestMigrations(t *testing.T) {
})
t.Run("case=hydra_oauth2_logout_request", func(t *testing.T) {
- lrs := []consent.LogoutRequest{}
+ lrs := []flow.LogoutRequest{}
c.All(&lrs)
require.Equal(t, 6, len(lrs))
diff --git a/persistence/sql/migrations/20230606112801000001_remove_flow_indices.down.sql b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.down.sql
new file mode 100644
index 00000000000..a391920ba8f
--- /dev/null
+++ b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.down.sql
@@ -0,0 +1,12 @@
+CREATE UNIQUE INDEX hydra_oauth2_flow_login_verifier_idx ON hydra_oauth2_flow (login_verifier);
+CREATE UNIQUE INDEX hydra_oauth2_flow_consent_verifier_idx ON hydra_oauth2_flow (consent_verifier);
+
+CREATE INDEX hydra_oauth2_flow_multi_query_idx
+ ON hydra_oauth2_flow
+ (
+ consent_error ASC, state ASC, subject ASC,
+ client_id ASC, consent_skip ASC, consent_remember
+ ASC, nid ASC
+ );
+
+DROP INDEX hydra_oauth2_flow_previous_consents_idx;
diff --git a/persistence/sql/migrations/20230606112801000001_remove_flow_indices.mysql.down.sql b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.mysql.down.sql
new file mode 100644
index 00000000000..16d4e470dae
--- /dev/null
+++ b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.mysql.down.sql
@@ -0,0 +1,12 @@
+CREATE UNIQUE INDEX hydra_oauth2_flow_login_verifier_idx ON hydra_oauth2_flow (login_verifier);
+CREATE UNIQUE INDEX hydra_oauth2_flow_consent_verifier_idx ON hydra_oauth2_flow (consent_verifier);
+
+CREATE INDEX hydra_oauth2_flow_multi_query_idx
+ ON hydra_oauth2_flow
+ (
+ consent_error(2) ASC, state ASC, subject ASC,
+ client_id ASC, consent_skip ASC, consent_remember
+ ASC, nid ASC
+ );
+
+DROP INDEX hydra_oauth2_flow_previous_consents_idx ON hydra_oauth2_flow;
diff --git a/persistence/sql/migrations/20230606112801000001_remove_flow_indices.mysql.up.sql b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.mysql.up.sql
new file mode 100644
index 00000000000..d7f86b61f94
--- /dev/null
+++ b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.mysql.up.sql
@@ -0,0 +1,6 @@
+DROP INDEX hydra_oauth2_flow_login_verifier_idx ON hydra_oauth2_flow;
+DROP INDEX hydra_oauth2_flow_consent_verifier_idx ON hydra_oauth2_flow;
+DROP INDEX hydra_oauth2_flow_multi_query_idx ON hydra_oauth2_flow;
+
+CREATE INDEX hydra_oauth2_flow_previous_consents_idx
+ ON hydra_oauth2_flow (subject, client_id, nid, consent_skip, consent_error(2), consent_remember);
diff --git a/persistence/sql/migrations/20230606112801000001_remove_flow_indices.up.sql b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.up.sql
new file mode 100644
index 00000000000..d522d3482f5
--- /dev/null
+++ b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.up.sql
@@ -0,0 +1,6 @@
+DROP INDEX hydra_oauth2_flow_login_verifier_idx;
+DROP INDEX hydra_oauth2_flow_consent_verifier_idx;
+DROP INDEX hydra_oauth2_flow_multi_query_idx;
+
+CREATE INDEX IF NOT EXISTS hydra_oauth2_flow_previous_consents_idx
+ ON hydra_oauth2_flow (subject, client_id, nid, consent_skip, consent_error, consent_remember);
diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go
index c9f9678ab9a..82971db2a2c 100644
--- a/persistence/sql/persister.go
+++ b/persistence/sql/persister.go
@@ -16,14 +16,15 @@ import (
"github.com/ory/fosite"
"github.com/ory/fosite/storage"
+ "github.com/ory/hydra/v2/aead"
"github.com/ory/hydra/v2/driver/config"
- "github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/persistence"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/contextx"
"github.com/ory/x/errorsx"
"github.com/ory/x/logrusx"
"github.com/ory/x/networkx"
+ "github.com/ory/x/otelx"
"github.com/ory/x/popx"
)
@@ -48,16 +49,17 @@ type (
}
Dependencies interface {
ClientHasher() fosite.Hasher
- KeyCipher() *jwk.AEAD
+ KeyCipher() *aead.AESGCM
+ FlowCipher() *aead.XChaCha20Poly1305
contextx.Provider
x.RegistryLogger
x.TracingProvider
}
)
-func (p *Persister) BeginTX(ctx context.Context) (context.Context, error) {
+func (p *Persister) BeginTX(ctx context.Context) (_ context.Context, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.BeginTX")
- defer span.End()
+ defer otelx.End(span, &err)
fallback := &pop.Connection{TX: &pop.Tx{}}
if popx.GetConnection(ctx, fallback).TX != fallback.TX {
@@ -77,9 +79,9 @@ func (p *Persister) BeginTX(ctx context.Context) (context.Context, error) {
return popx.WithTransaction(ctx, c), err
}
-func (p *Persister) Commit(ctx context.Context) error {
+func (p *Persister) Commit(ctx context.Context) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Commit")
- defer span.End()
+ defer otelx.End(span, &err)
fallback := &pop.Connection{TX: &pop.Tx{}}
tx := popx.GetConnection(ctx, fallback)
@@ -90,9 +92,9 @@ func (p *Persister) Commit(ctx context.Context) error {
return errorsx.WithStack(tx.TX.Commit())
}
-func (p *Persister) Rollback(ctx context.Context) error {
+func (p *Persister) Rollback(ctx context.Context) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Rollback")
- defer span.End()
+ defer otelx.End(span, &err)
fallback := &pop.Connection{TX: &pop.Tx{}}
tx := popx.GetConnection(ctx, fallback)
diff --git a/persistence/sql/persister_client.go b/persistence/sql/persister_client.go
index 8e45870dab5..482f7126a88 100644
--- a/persistence/sql/persister_client.go
+++ b/persistence/sql/persister_client.go
@@ -6,20 +6,20 @@ package sql
import (
"context"
- "github.com/gofrs/uuid"
-
"github.com/gobuffalo/pop/v6"
+ "github.com/gofrs/uuid"
"github.com/ory/x/errorsx"
+ "github.com/ory/x/otelx"
"github.com/ory/fosite"
"github.com/ory/hydra/v2/client"
"github.com/ory/x/sqlcon"
)
-func (p *Persister) GetConcreteClient(ctx context.Context, id string) (*client.Client, error) {
+func (p *Persister) GetConcreteClient(ctx context.Context, id string) (c *client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConcreteClient")
- defer span.End()
+ defer otelx.End(span, &err)
var cl client.Client
if err := p.QueryWithNetwork(ctx).Where("id = ?", id).First(&cl); err != nil {
@@ -32,9 +32,9 @@ func (p *Persister) GetClient(ctx context.Context, id string) (fosite.Client, er
return p.GetConcreteClient(ctx, id)
}
-func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) error {
+func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateClient")
- defer span.End()
+ defer otelx.End(span, &err)
return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
o, err := p.GetConcreteClient(ctx, cl.GetID())
@@ -71,9 +71,9 @@ func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) error {
})
}
-func (p *Persister) Authenticate(ctx context.Context, id string, secret []byte) (*client.Client, error) {
+func (p *Persister) Authenticate(ctx context.Context, id string, secret []byte) (_ *client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Authenticate")
- defer span.End()
+ defer otelx.End(span, &err)
c, err := p.GetConcreteClient(ctx, id)
if err != nil {
@@ -87,9 +87,9 @@ func (p *Persister) Authenticate(ctx context.Context, id string, secret []byte)
return c, nil
}
-func (p *Persister) CreateClient(ctx context.Context, c *client.Client) error {
+func (p *Persister) CreateClient(ctx context.Context, c *client.Client) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateClient")
- defer span.End()
+ defer otelx.End(span, &err)
h, err := p.r.ClientHasher().Hash(ctx, []byte(c.Secret))
if err != nil {
@@ -106,11 +106,11 @@ func (p *Persister) CreateClient(ctx context.Context, c *client.Client) error {
return sqlcon.HandleError(p.CreateWithNetwork(ctx, c))
}
-func (p *Persister) DeleteClient(ctx context.Context, id string) error {
+func (p *Persister) DeleteClient(ctx context.Context, id string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteClient")
- defer span.End()
+ defer otelx.End(span, &err)
- _, err := p.GetConcreteClient(ctx, id)
+ _, err = p.GetConcreteClient(ctx, id)
if err != nil {
return err
}
@@ -118,9 +118,9 @@ func (p *Persister) DeleteClient(ctx context.Context, id string) error {
return sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("id = ?", id).Delete(&client.Client{}))
}
-func (p *Persister) GetClients(ctx context.Context, filters client.Filter) ([]client.Client, error) {
+func (p *Persister) GetClients(ctx context.Context, filters client.Filter) (_ []client.Client, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetClients")
- defer span.End()
+ defer otelx.End(span, &err)
cs := make([]client.Client, 0)
@@ -141,10 +141,10 @@ func (p *Persister) GetClients(ctx context.Context, filters client.Filter) ([]cl
return cs, nil
}
-func (p *Persister) CountClients(ctx context.Context) (int, error) {
+func (p *Persister) CountClients(ctx context.Context) (n int, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CountClients")
- defer span.End()
+ defer otelx.End(span, &err)
- n, err := p.QueryWithNetwork(ctx).Count(&client.Client{})
+ n, err = p.QueryWithNetwork(ctx).Count(&client.Client{})
return n, sqlcon.HandleError(err)
}
diff --git a/persistence/sql/persister_consent.go b/persistence/sql/persister_consent.go
index 8f1fca3d490..bd401d56423 100644
--- a/persistence/sql/persister_consent.go
+++ b/persistence/sql/persister_consent.go
@@ -11,7 +11,9 @@ import (
"time"
"github.com/gobuffalo/pop/v6"
+ "github.com/gofrs/uuid"
+ "github.com/ory/hydra/v2/oauth2/flowctx"
"github.com/ory/x/sqlxx"
"github.com/ory/x/errorsx"
@@ -95,7 +97,7 @@ func (p *Persister) RevokeSubjectLoginSession(ctx context.Context, subject strin
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectLoginSession")
defer span.End()
- err := p.QueryWithNetwork(ctx).Where("subject = ?", subject).Delete(&consent.LoginSession{})
+ err := p.QueryWithNetwork(ctx).Where("subject = ?", subject).Delete(&flow.LoginSession{})
if err != nil {
return sqlcon.HandleError(err)
}
@@ -158,34 +160,22 @@ func (p *Persister) GetForcedObfuscatedLoginSession(ctx context.Context, client,
// CreateConsentRequest configures fields that are introduced or changed in the
// consent request. It doesn't touch fields that would be copied from the login
// request.
-func (p *Persister) CreateConsentRequest(ctx context.Context, req *consent.OAuth2ConsentRequest) error {
+func (p *Persister) CreateConsentRequest(ctx context.Context, f *flow.Flow, req *flow.OAuth2ConsentRequest) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateConsentRequest")
defer span.End()
- c, err := p.Connection(ctx).RawQuery(`
-UPDATE hydra_oauth2_flow
-SET
- state = ?,
- consent_challenge_id = ?,
- consent_skip = ?,
- consent_verifier = ?,
- consent_csrf = ?
-WHERE login_challenge = ? AND nid = ?;
-`,
- flow.FlowStateConsentInitialized,
- sqlxx.NullString(req.ID),
- req.Skip,
- req.Verifier,
- req.CSRF,
- req.LoginChallenge.String(),
- p.NetworkID(ctx),
- ).ExecWithCount()
- if err != nil {
- return sqlcon.HandleError(err)
+ if f == nil {
+ return errorsx.WithStack(x.ErrNotFound.WithDebug("Flow is nil"))
}
- if c != 1 {
+ if f.ID != req.LoginChallenge.String() || f.NID != p.NetworkID(ctx) {
return errorsx.WithStack(x.ErrNotFound)
}
+ f.State = flow.FlowStateConsentInitialized
+ f.ConsentChallengeID = sqlxx.NullString(req.ID)
+ f.ConsentSkip = req.Skip
+ f.ConsentVerifier = sqlxx.NullString(req.Verifier)
+ f.ConsentCSRF = sqlxx.NullString(req.CSRF)
+
return nil
}
@@ -193,16 +183,22 @@ func (p *Persister) GetFlowByConsentChallenge(ctx context.Context, challenge str
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetFlowByConsentChallenge")
defer span.End()
- f := &flow.Flow{}
-
- if err := sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("consent_challenge_id = ?", challenge).First(f)); err != nil {
- return nil, err
+ // challenge contains the flow.
+ f, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), challenge, flowctx.AsConsentChallenge)
+ if err != nil {
+ return nil, errorsx.WithStack(x.ErrNotFound)
+ }
+ if f.NID != p.NetworkID(ctx) {
+ return nil, errorsx.WithStack(x.ErrNotFound)
+ }
+ if f.RequestedAt.Add(p.config.ConsentRequestMaxAge(ctx)).Before(time.Now()) {
+ return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The consent request has expired, please try again."))
}
return f, nil
}
-func (p *Persister) GetConsentRequest(ctx context.Context, challenge string) (*consent.OAuth2ConsentRequest, error) {
+func (p *Persister) GetConsentRequest(ctx context.Context, challenge string) (*flow.OAuth2ConsentRequest, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConsentRequest")
defer span.End()
@@ -214,15 +210,24 @@ func (p *Persister) GetConsentRequest(ctx context.Context, challenge string) (*c
return nil, err
}
+ // We need to overwrite the ID with the encoded flow (challenge) so that the client is not confused.
+ f.ConsentChallengeID = sqlxx.NullString(challenge)
+
return f.GetConsentRequest(), nil
}
-func (p *Persister) CreateLoginRequest(ctx context.Context, req *consent.LoginRequest) error {
+func (p *Persister) CreateLoginRequest(ctx context.Context, req *flow.LoginRequest) (*flow.Flow, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLoginRequest")
defer span.End()
f := flow.NewFlow(req)
- return sqlcon.HandleError(p.CreateWithNetwork(ctx, f))
+ nid := p.NetworkID(ctx)
+ if nid == uuid.Nil {
+ return nil, errorsx.WithStack(x.ErrNotFound)
+ }
+ f.NID = nid
+
+ return f, nil
}
func (p *Persister) GetFlow(ctx context.Context, loginChallenge string) (*flow.Flow, error) {
@@ -230,130 +235,166 @@ func (p *Persister) GetFlow(ctx context.Context, loginChallenge string) (*flow.F
defer span.End()
var f flow.Flow
- return &f, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
- if err := p.QueryWithNetwork(ctx).Where("login_challenge = ?", loginChallenge).First(&f); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return errorsx.WithStack(x.ErrNotFound)
- }
- return sqlcon.HandleError(err)
+ if err := p.QueryWithNetwork(ctx).Where("login_challenge = ?", loginChallenge).First(&f); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, errorsx.WithStack(x.ErrNotFound)
}
-
- return nil
- })
+ return nil, sqlcon.HandleError(err)
+ }
+ return &f, nil
}
-func (p *Persister) GetLoginRequest(ctx context.Context, loginChallenge string) (*consent.LoginRequest, error) {
+func (p *Persister) GetLoginRequest(ctx context.Context, loginChallenge string) (*flow.LoginRequest, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetLoginRequest")
defer span.End()
- var lr *consent.LoginRequest
- return lr, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
- var f flow.Flow
- if err := p.QueryWithNetwork(ctx).Where("login_challenge = ?", loginChallenge).First(&f); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return errorsx.WithStack(x.ErrNotFound)
- }
- return sqlcon.HandleError(err)
- }
- lr = f.GetLoginRequest()
+ f, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), loginChallenge, flowctx.AsLoginChallenge)
+ if err != nil {
+ return nil, errorsx.WithStack(x.ErrNotFound.WithWrap(err))
+ }
+ if f.NID != p.NetworkID(ctx) {
+ return nil, errorsx.WithStack(x.ErrNotFound)
+ }
+ if f.RequestedAt.Add(p.config.ConsentRequestMaxAge(ctx)).Before(time.Now()) {
+ return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The login request has expired, please try again."))
+ }
+ lr := f.GetLoginRequest()
+ // Restore the short challenge ID, which was previously sent to the encoded flow,
+ // to make sure that the challenge ID in the returned flow matches the param.
+ lr.ID = loginChallenge
- return nil
- })
+ return lr, nil
}
-func (p *Persister) HandleConsentRequest(ctx context.Context, r *consent.AcceptOAuth2ConsentRequest) (*consent.OAuth2ConsentRequest, error) {
+func (p *Persister) HandleConsentRequest(ctx context.Context, f *flow.Flow, r *flow.AcceptOAuth2ConsentRequest) (*flow.OAuth2ConsentRequest, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.HandleConsentRequest")
defer span.End()
- f := &flow.Flow{}
-
- if err := sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("consent_challenge_id = ?", r.ID).First(f)); errors.Is(err, sqlcon.ErrNoRows) {
- return nil, err
+ if f == nil {
+ return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Flow was nil"))
}
-
+ if f.NID != p.NetworkID(ctx) {
+ return nil, errorsx.WithStack(x.ErrNotFound)
+ }
+ // Restore the short challenge ID, which was previously sent to the encoded flow,
+ // to make sure that the challenge ID in the returned flow matches the param.
+ r.ID = f.ConsentChallengeID.String()
if err := f.HandleConsentRequest(r); err != nil {
return nil, errorsx.WithStack(err)
}
- _, err := p.UpdateWithNetwork(ctx, f)
- if err != nil {
- return nil, sqlcon.HandleError(err)
- }
-
- return p.GetConsentRequest(ctx, r.ID)
+ return f.GetConsentRequest(), nil
}
-func (p *Persister) VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*consent.AcceptOAuth2ConsentRequest, error) {
+func (p *Persister) VerifyAndInvalidateConsentRequest(ctx context.Context, f *flow.Flow, verifier string) (*flow.AcceptOAuth2ConsentRequest, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAndInvalidateConsentRequest")
defer span.End()
- var r consent.AcceptOAuth2ConsentRequest
- return &r, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
- var f flow.Flow
- if err := p.QueryWithNetwork(ctx).Where("consent_verifier = ?", verifier).First(&f); err != nil {
- return sqlcon.HandleError(err)
- }
+ if f == nil {
+ return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Flow was nil"))
+ }
+ if f.NID != p.NetworkID(ctx) {
+ return nil, errorsx.WithStack(sqlcon.ErrNoRows)
+ }
- if err := f.InvalidateConsentRequest(); err != nil {
- return errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error()))
- }
+ updatedFlow, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), verifier, flowctx.AsConsentVerifier)
+ if err != nil {
+ return nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The consent verifier has already been used, has not been granted, or is invalid."))
+ }
+ if updatedFlow.ID != f.ID {
+ return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Consent verifier does not match login request."))
+ }
+ if updatedFlow.NID != p.NetworkID(ctx) {
+ return nil, errorsx.WithStack(sqlcon.ErrNoRows)
+ }
- r = *f.GetHandledConsentRequest()
- _, err := p.UpdateWithNetwork(ctx, &f)
- return err
- })
+ // Update flow from login request, but keep requested at.
+ updatedFlow.NID = f.NID
+ updatedFlow.ConsentCSRF = f.ConsentCSRF
+ updatedFlow.ConsentVerifier = f.ConsentVerifier
+ *f = *updatedFlow
+
+ if err = f.InvalidateConsentRequest(); err != nil {
+ return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error()))
+ }
+
+ // We set the consent challenge ID to a new UUID that we can use as a foreign key in the database
+ // without encoding the whole flow.
+ f.ConsentChallengeID = sqlxx.NullString(uuid.Must(uuid.NewV4()).String())
+
+ if err = p.Connection(ctx).Create(f); err != nil {
+ return nil, sqlcon.HandleError(err)
+ }
+
+ return f.GetHandledConsentRequest(), nil
}
-func (p *Persister) HandleLoginRequest(ctx context.Context, challenge string, r *consent.HandledLoginRequest) (lr *consent.LoginRequest, err error) {
+func (p *Persister) HandleLoginRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledLoginRequest) (lr *flow.LoginRequest, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.HandleLoginRequest")
defer span.End()
- return lr, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
- f, err := p.GetFlow(ctx, challenge)
- if err != nil {
- return sqlcon.HandleError(err)
- }
- err = f.HandleLoginRequest(r)
- if err != nil {
- return err
- }
-
- _, err = p.UpdateWithNetwork(ctx, f)
- if err != nil {
- return sqlcon.HandleError(err)
- }
+ if f == nil {
+ return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Flow was nil"))
+ }
+ if f.NID != p.NetworkID(ctx) {
+ return nil, errorsx.WithStack(x.ErrNotFound)
+ }
+ r.ID = f.ID
+ err = f.HandleLoginRequest(r)
+ if err != nil {
+ return nil, err
+ }
- lr, err = p.GetLoginRequest(ctx, challenge)
- return sqlcon.HandleError(err)
- })
+ return p.GetLoginRequest(ctx, challenge)
}
-func (p *Persister) VerifyAndInvalidateLoginRequest(ctx context.Context, verifier string) (*consent.HandledLoginRequest, error) {
+func (p *Persister) VerifyAndInvalidateLoginRequest(ctx context.Context, f *flow.Flow, verifier string) (*flow.HandledLoginRequest, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAndInvalidateLoginRequest")
defer span.End()
- var d consent.HandledLoginRequest
- return &d, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
- var f flow.Flow
- if err := p.QueryWithNetwork(ctx).Where("login_verifier = ?", verifier).First(&f); err != nil {
- return sqlcon.HandleError(err)
- }
+ if f == nil {
+ return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Flow was nil"))
+ }
+ if f.NID != p.NetworkID(ctx) {
+ return nil, errorsx.WithStack(sqlcon.ErrNoRows)
+ }
- if err := f.InvalidateLoginRequest(); err != nil {
- return errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error()))
- }
+ updatedFlow, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), verifier, flowctx.AsLoginVerifier)
+ if err != nil {
+ return nil, errorsx.WithStack(sqlcon.ErrNoRows)
+ }
+ if f.NID != updatedFlow.NID {
+ return nil, errorsx.WithStack(sqlcon.ErrNoRows)
+ }
- d = f.GetHandledLoginRequest()
- _, err := p.UpdateWithNetwork(ctx, &f)
- return sqlcon.HandleError(err)
- })
+ if updatedFlow.ID != f.ID {
+ return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Login verifier does not match login request."))
+ }
+
+ // Update flow from login request, but keep requested at.
+ updatedFlow.NID = f.NID
+ updatedFlow.RequestedAt = f.RequestedAt
+ updatedFlow.LoginCSRF = f.LoginCSRF
+ updatedFlow.LoginVerifier = f.LoginVerifier
+ *f = *updatedFlow
+
+ if err := f.InvalidateLoginRequest(); err != nil {
+ return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error()))
+ }
+ d := f.GetHandledLoginRequest()
+
+ return &d, nil
}
-func (p *Persister) GetRememberedLoginSession(ctx context.Context, id string) (*consent.LoginSession, error) {
+func (p *Persister) GetRememberedLoginSession(ctx context.Context, loginSessionFromCookie *flow.LoginSession, id string) (*flow.LoginSession, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRememberedLoginSession")
defer span.End()
- var s consent.LoginSession
+ if s := loginSessionFromCookie; s != nil && s.NID == p.NetworkID(ctx) && s.ID == id && s.Remember {
+ return s, nil
+ }
+
+ var s flow.LoginSession
if err := p.QueryWithNetwork(ctx).Where("remember = TRUE").Find(&s, id); errors.Is(err, sql.ErrNoRows) {
return nil, errorsx.WithStack(x.ErrNotFound)
@@ -364,30 +405,56 @@ func (p *Persister) GetRememberedLoginSession(ctx context.Context, id string) (*
return &s, nil
}
-func (p *Persister) ConfirmLoginSession(ctx context.Context, id string, authenticatedAt time.Time, subject string, remember bool) error {
+func (p *Persister) ConfirmLoginSession(ctx context.Context, session *flow.LoginSession, id string, authenticatedAt time.Time, subject string, remember bool) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ConfirmLoginSession")
defer span.End()
- _, err := p.Connection(ctx).Where("id = ? AND nid = ?", id, p.NetworkID(ctx)).UpdateQuery(&consent.LoginSession{
+ // Since we previously cached the login session, we now need to persist it to db.
+ if session != nil {
+ if session.NID != p.NetworkID(ctx) || session.ID != id {
+ return errorsx.WithStack(x.ErrNotFound)
+ }
+ session.AuthenticatedAt = sqlxx.NullTime(authenticatedAt.Truncate(time.Second))
+ session.Subject = subject
+ session.Remember = remember
+
+ return p.CreateWithNetwork(ctx, session)
+ }
+
+ // In some unit tests, we still confirm the login session without data from the cookie. We can remove this case
+ // once all tests are fixed.
+ n, err := p.Connection(ctx).Where("id = ? AND nid = ?", id, p.NetworkID(ctx)).UpdateQuery(&flow.LoginSession{
AuthenticatedAt: sqlxx.NullTime(authenticatedAt),
Subject: subject,
Remember: remember,
}, "authenticated_at", "subject", "remember")
- return sqlcon.HandleError(err)
+ if err != nil {
+ return sqlcon.HandleError(err)
+ }
+ if n == 0 {
+ return errorsx.WithStack(x.ErrNotFound)
+ }
+ return nil
}
-func (p *Persister) CreateLoginSession(ctx context.Context, session *consent.LoginSession) error {
+func (p *Persister) CreateLoginSession(ctx context.Context, session *flow.LoginSession) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLoginSession")
defer span.End()
- return sqlcon.HandleError(p.CreateWithNetwork(ctx, session))
+ nid := p.NetworkID(ctx)
+ if nid == uuid.Nil {
+ return errorsx.WithStack(x.ErrNotFound)
+ }
+ session.NID = nid
+
+ return nil
}
func (p *Persister) DeleteLoginSession(ctx context.Context, id string) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteLoginSession")
defer span.End()
- count, err := p.Connection(ctx).RawQuery("DELETE FROM hydra_oauth2_authentication_session WHERE id=? AND nid = ?", id, p.NetworkID(ctx)).ExecWithCount()
+ count, err := p.Connection(ctx).RawQuery("DELETE FROM hydra_oauth2_authentication_session WHERE id=? AND nid=?", id, p.NetworkID(ctx)).ExecWithCount()
if count == 0 {
return errorsx.WithStack(x.ErrNotFound)
} else {
@@ -395,18 +462,14 @@ func (p *Persister) DeleteLoginSession(ctx context.Context, id string) error {
}
}
-func (p *Persister) FindGrantedAndRememberedConsentRequests(ctx context.Context, client, subject string) ([]consent.AcceptOAuth2ConsentRequest, error) {
+func (p *Persister) FindGrantedAndRememberedConsentRequests(ctx context.Context, client, subject string) (rs []flow.AcceptOAuth2ConsentRequest, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindGrantedAndRememberedConsentRequests")
defer span.End()
- rs := make([]consent.AcceptOAuth2ConsentRequest, 0)
-
- return rs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
- f := &flow.Flow{}
-
- if err := c.
- Where(
- strings.TrimSpace(fmt.Sprintf(`
+ var f flow.Flow
+ if err = p.Connection(ctx).
+ Where(
+ strings.TrimSpace(fmt.Sprintf(`
(state = %d OR state = %d) AND
subject = ? AND
client_id = ? AND
@@ -414,24 +477,21 @@ consent_skip=FALSE AND
consent_error='{}' AND
consent_remember=TRUE AND
nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused,
- )),
- subject, client, p.NetworkID(ctx)).
- Order("requested_at DESC").
- Limit(1).
- First(f); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return errorsx.WithStack(consent.ErrNoPreviousConsentFound)
- }
- return sqlcon.HandleError(err)
+ )),
+ subject, client, p.NetworkID(ctx)).
+ Order("requested_at DESC").
+ Limit(1).
+ First(&f); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, errorsx.WithStack(consent.ErrNoPreviousConsentFound)
}
+ return nil, sqlcon.HandleError(err)
+ }
- var err error
- rs, err = p.filterExpiredConsentRequests(ctx, []consent.AcceptOAuth2ConsentRequest{*f.GetHandledConsentRequest()})
- return err
- })
+ return p.filterExpiredConsentRequests(ctx, []flow.AcceptOAuth2ConsentRequest{*f.GetHandledConsentRequest()})
}
-func (p *Persister) FindSubjectsGrantedConsentRequests(ctx context.Context, subject string, limit, offset int) ([]consent.AcceptOAuth2ConsentRequest, error) {
+func (p *Persister) FindSubjectsGrantedConsentRequests(ctx context.Context, subject string, limit, offset int) ([]flow.AcceptOAuth2ConsentRequest, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsGrantedConsentRequests")
defer span.End()
@@ -457,7 +517,7 @@ nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused,
return nil, sqlcon.HandleError(err)
}
- var rs []consent.AcceptOAuth2ConsentRequest
+ var rs []flow.AcceptOAuth2ConsentRequest
for _, f := range fs {
rs = append(rs, *f.GetHandledConsentRequest())
}
@@ -465,7 +525,7 @@ nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused,
return p.filterExpiredConsentRequests(ctx, rs)
}
-func (p *Persister) FindSubjectsSessionGrantedConsentRequests(ctx context.Context, subject, sid string, limit, offset int) ([]consent.AcceptOAuth2ConsentRequest, error) {
+func (p *Persister) FindSubjectsSessionGrantedConsentRequests(ctx context.Context, subject, sid string, limit, offset int) ([]flow.AcceptOAuth2ConsentRequest, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsSessionGrantedConsentRequests")
defer span.End()
@@ -492,7 +552,7 @@ nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused,
return nil, sqlcon.HandleError(err)
}
- var rs []consent.AcceptOAuth2ConsentRequest
+ var rs []flow.AcceptOAuth2ConsentRequest
for _, f := range fs {
rs = append(rs, *f.GetHandledConsentRequest())
}
@@ -518,11 +578,11 @@ nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused,
return n, sqlcon.HandleError(err)
}
-func (p *Persister) filterExpiredConsentRequests(ctx context.Context, requests []consent.AcceptOAuth2ConsentRequest) ([]consent.AcceptOAuth2ConsentRequest, error) {
+func (p *Persister) filterExpiredConsentRequests(ctx context.Context, requests []flow.AcceptOAuth2ConsentRequest) ([]flow.AcceptOAuth2ConsentRequest, error) {
_, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.filterExpiredConsentRequests")
defer span.End()
- var result []consent.AcceptOAuth2ConsentRequest
+ var result []flow.AcceptOAuth2ConsentRequest
for _, v := range requests {
if v.RememberFor > 0 && v.RequestedAt.Add(time.Duration(v.RememberFor)*time.Second).Before(time.Now().UTC()) {
continue
@@ -553,10 +613,9 @@ func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, s
defer span.End()
var cs []client.Client
- return cs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
- if err := c.RawQuery(
- /* #nosec G201 - channel can either be "front" or "back" */
- fmt.Sprintf(`
+ if err := p.Connection(ctx).RawQuery(
+ /* #nosec G201 - channel can either be "front" or "back" */
+ fmt.Sprintf(`
SELECT DISTINCT c.* FROM hydra_client as c
JOIN hydra_oauth2_flow as f ON (c.id = f.client_id)
WHERE
@@ -566,29 +625,28 @@ WHERE
f.login_session_id = ? AND
f.nid = ? AND
c.nid = ?`,
- channel,
- channel,
- ),
- subject,
- sid,
- p.NetworkID(ctx),
- p.NetworkID(ctx),
- ).All(&cs); err != nil {
- return sqlcon.HandleError(err)
- }
+ channel,
+ channel,
+ ),
+ subject,
+ sid,
+ p.NetworkID(ctx),
+ p.NetworkID(ctx),
+ ).All(&cs); err != nil {
+ return nil, sqlcon.HandleError(err)
+ }
- return nil
- })
+ return cs, nil
}
-func (p *Persister) CreateLogoutRequest(ctx context.Context, request *consent.LogoutRequest) error {
+func (p *Persister) CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLogoutRequest")
defer span.End()
return errorsx.WithStack(p.CreateWithNetwork(ctx, request))
}
-func (p *Persister) AcceptLogoutRequest(ctx context.Context, challenge string) (*consent.LogoutRequest, error) {
+func (p *Persister) AcceptLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AcceptLogoutRequest")
defer span.End()
@@ -613,37 +671,35 @@ func (p *Persister) RejectLogoutRequest(ctx context.Context, challenge string) e
}
}
-func (p *Persister) GetLogoutRequest(ctx context.Context, challenge string) (*consent.LogoutRequest, error) {
+func (p *Persister) GetLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetLogoutRequest")
defer span.End()
- var lr consent.LogoutRequest
+ var lr flow.LogoutRequest
return &lr, sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("challenge = ? AND rejected = FALSE", challenge).First(&lr))
}
-func (p *Persister) VerifyAndInvalidateLogoutRequest(ctx context.Context, verifier string) (*consent.LogoutRequest, error) {
+func (p *Persister) VerifyAndInvalidateLogoutRequest(ctx context.Context, verifier string) (*flow.LogoutRequest, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAndInvalidateLogoutRequest")
defer span.End()
- var lr consent.LogoutRequest
- return &lr, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
- if count, err := c.RawQuery(
- "UPDATE hydra_oauth2_logout_request SET was_used=TRUE WHERE nid = ? AND verifier=? AND was_used=FALSE AND accepted=TRUE AND rejected=FALSE",
- p.NetworkID(ctx),
- verifier,
- ).ExecWithCount(); count == 0 && err == nil {
- return errorsx.WithStack(x.ErrNotFound)
- } else if err != nil {
- return sqlcon.HandleError(err)
- }
+ var lr flow.LogoutRequest
+ if count, err := p.Connection(ctx).RawQuery(
+ "UPDATE hydra_oauth2_logout_request SET was_used=TRUE WHERE nid = ? AND verifier=? AND was_used=FALSE AND accepted=TRUE AND rejected=FALSE",
+ p.NetworkID(ctx),
+ verifier,
+ ).ExecWithCount(); count == 0 && err == nil {
+ return nil, errorsx.WithStack(x.ErrNotFound)
+ } else if err != nil {
+ return nil, sqlcon.HandleError(err)
+ }
- err := sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("verifier=?", verifier).First(&lr))
- if err != nil {
- return err
- }
+ err := sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("verifier=?", verifier).First(&lr))
+ if err != nil {
+ return nil, err
+ }
- return nil
- })
+ return &lr, nil
}
func (p *Persister) FlushInactiveLoginConsentRequests(ctx context.Context, notAfter time.Time, limit int, batchSize int) error {
diff --git a/persistence/sql/persister_jwk.go b/persistence/sql/persister_jwk.go
index fe8041326b8..92eb3cf9cea 100644
--- a/persistence/sql/persister_jwk.go
+++ b/persistence/sql/persister_jwk.go
@@ -47,7 +47,7 @@ func (p *Persister) AddKey(ctx context.Context, set string, key *jose.JSONWebKey
return errorsx.WithStack(err)
}
- encrypted, err := p.r.KeyCipher().Encrypt(ctx, out)
+ encrypted, err := p.r.KeyCipher().Encrypt(ctx, out, nil)
if err != nil {
return errorsx.WithStack(err)
}
@@ -71,7 +71,7 @@ func (p *Persister) AddKeySet(ctx context.Context, set string, keys *jose.JSONWe
return errorsx.WithStack(err)
}
- encrypted, err := p.r.KeyCipher().Encrypt(ctx, out)
+ encrypted, err := p.r.KeyCipher().Encrypt(ctx, out, nil)
if err != nil {
return err
}
@@ -133,7 +133,7 @@ func (p *Persister) GetKey(ctx context.Context, set, kid string) (*jose.JSONWebK
return nil, sqlcon.HandleError(err)
}
- key, err := p.r.KeyCipher().Decrypt(ctx, j.Key)
+ key, err := p.r.KeyCipher().Decrypt(ctx, j.Key, nil)
if err != nil {
return nil, errorsx.WithStack(err)
}
@@ -148,7 +148,7 @@ func (p *Persister) GetKey(ctx context.Context, set, kid string) (*jose.JSONWebK
}, nil
}
-func (p *Persister) GetKeySet(ctx context.Context, set string) (*jose.JSONWebKeySet, error) {
+func (p *Persister) GetKeySet(ctx context.Context, set string) (keys *jose.JSONWebKeySet, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetKeySet")
defer span.End()
@@ -164,9 +164,9 @@ func (p *Persister) GetKeySet(ctx context.Context, set string) (*jose.JSONWebKey
return nil, errors.Wrap(x.ErrNotFound, "")
}
- keys := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{}}
+ keys = &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{}}
for _, d := range js {
- key, err := p.r.KeyCipher().Decrypt(ctx, d.Key)
+ key, err := p.r.KeyCipher().Decrypt(ctx, d.Key, nil)
if err != nil {
return nil, errorsx.WithStack(err)
}
diff --git a/persistence/sql/persister_nid_test.go b/persistence/sql/persister_nid_test.go
index 29f4e613f72..cf36ba61c45 100644
--- a/persistence/sql/persister_nid_test.go
+++ b/persistence/sql/persister_nid_test.go
@@ -6,15 +6,17 @@ package sql_test
import (
"context"
"database/sql"
+ "encoding/json"
"testing"
"time"
+ "github.com/ory/hydra/v2/persistence"
"github.com/ory/x/uuidx"
"github.com/ory/x/assertx"
"github.com/gofrs/uuid"
- "github.com/instana/testify/require"
+ "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"gopkg.in/square/go-jose.v2"
@@ -88,7 +90,7 @@ func (s *PersisterTestSuite) TestAcceptLogoutRequest() {
lrAccepted, err := r.ConsentManager().AcceptLogoutRequest(s.t2, lr.ID)
require.Error(t, err)
- require.Equal(t, &consent.LogoutRequest{}, lrAccepted)
+ require.Equal(t, &flow.LogoutRequest{}, lrAccepted)
actual, err := r.ConsentManager().GetLogoutRequest(s.t1, lr.ID)
require.NoError(t, err)
@@ -182,17 +184,20 @@ func (s *PersisterTestSuite) TestConfirmLoginSession() {
for k, r := range s.registries {
t.Run(k, func(t *testing.T) {
require.NoError(t, r.Persister().CreateLoginSession(s.t1, ls))
- expected := &consent.LoginSession{}
- require.NoError(t, r.Persister().Connection(context.Background()).Find(expected, ls.ID))
- require.NoError(t, r.Persister().ConfirmLoginSession(s.t2, expected.ID, time.Now(), expected.Subject, !expected.Remember))
- actual := &consent.LoginSession{}
+ // Expects the login session to be confirmed in the correct context.
+ require.NoError(t, r.Persister().ConfirmLoginSession(s.t1, ls, ls.ID, time.Now(), ls.Subject, !ls.Remember))
+ actual := &flow.LoginSession{}
require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, ls.ID))
- require.Equal(t, expected, actual)
+ exp, _ := json.Marshal(ls)
+ act, _ := json.Marshal(actual)
+ require.JSONEq(t, string(exp), string(act))
- require.NoError(t, r.Persister().ConfirmLoginSession(s.t1, expected.ID, time.Now(), expected.Subject, !expected.Remember))
- require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, ls.ID))
- require.NotEqual(t, expected, actual)
+ // Can't find the login session in the wrong context.
+ require.ErrorIs(t,
+ r.Persister().ConfirmLoginSession(s.t2, ls, ls.ID, time.Now(), ls.Subject, !ls.Remember),
+ x.ErrNotFound,
+ )
})
}
}
@@ -202,8 +207,8 @@ func (s *PersisterTestSuite) TestCreateSession() {
ls := newLoginSession()
for k, r := range s.registries {
t.Run(k, func(t *testing.T) {
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, ls))
- actual := &consent.LoginSession{}
+ persistLoginSession(s.t1, t, r.Persister(), ls)
+ actual := &flow.LoginSession{}
require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, ls.ID))
require.Equal(t, s.t1NID, actual.NID)
ls.NID = actual.NID
@@ -280,12 +285,12 @@ func (s *PersisterTestSuite) TestCountSubjectsGrantedConsentRequests() {
require.Equal(t, 0, count)
sessionID := uuid.Must(uuid.NewV4()).String()
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
client := &client.Client{LegacyClientID: "client-id"}
require.NoError(t, r.Persister().CreateClient(s.t1, client))
f := newFlow(s.t1NID, client.LegacyClientID, sub, sqlxx.NullString(sessionID))
f.ConsentSkip = false
- f.ConsentError = &consent.RequestDeniedError{}
+ f.ConsentError = &flow.RequestDeniedError{}
f.State = flow.FlowStateConsentUnused
require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
@@ -359,18 +364,18 @@ func (s *PersisterTestSuite) TestCreateConsentRequest() {
sessionID := uuid.Must(uuid.NewV4()).String()
client := &client.Client{LegacyClientID: "client-id"}
f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID))
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, client))
require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
- req := &consent.OAuth2ConsentRequest{
+ req := &flow.OAuth2ConsentRequest{
ID: "consent-request-id",
LoginChallenge: sqlxx.NullString(f.ID),
Skip: false,
Verifier: "verifier",
CSRF: "csrf",
}
- require.NoError(t, r.Persister().CreateConsentRequest(s.t1, req))
+ require.NoError(t, r.Persister().CreateConsentRequest(s.t1, f, req))
actual := flow.Flow{}
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, f.ID))
@@ -418,12 +423,11 @@ func (s *PersisterTestSuite) TestCreateLoginRequest() {
for k, r := range s.registries {
t.Run(k, func(t *testing.T) {
client := &client.Client{LegacyClientID: "client-id"}
- lr := consent.LoginRequest{ID: "lr-id", ClientID: client.LegacyClientID, RequestedAt: time.Now()}
+ lr := flow.LoginRequest{ID: "lr-id", ClientID: client.LegacyClientID, RequestedAt: time.Now()}
require.NoError(t, r.Persister().CreateClient(s.t1, client))
- require.NoError(t, r.ConsentManager().CreateLoginRequest(s.t1, &lr))
- f := flow.Flow{}
- require.NoError(t, r.Persister().Connection(context.Background()).Find(&f, lr.ID))
+ f, err := r.ConsentManager().CreateLoginRequest(s.t1, &lr)
+ require.NoError(t, err)
require.Equal(t, s.t1NID, f.NID)
})
}
@@ -433,9 +437,9 @@ func (s *PersisterTestSuite) TestCreateLoginSession() {
t := s.T()
for k, r := range s.registries {
t.Run(k, func(t *testing.T) {
- ls := consent.LoginSession{ID: uuid.Must(uuid.NewV4()).String(), Remember: true}
+ ls := flow.LoginSession{ID: uuid.Must(uuid.NewV4()).String(), Remember: true}
require.NoError(t, r.Persister().CreateLoginSession(s.t1, &ls))
- actual, err := r.Persister().GetRememberedLoginSession(s.t1, ls.ID)
+ actual, err := r.Persister().GetRememberedLoginSession(s.t1, &ls, ls.ID)
require.NoError(t, err)
require.Equal(t, s.t1NID, actual.NID)
})
@@ -447,7 +451,7 @@ func (s *PersisterTestSuite) TestCreateLogoutRequest() {
for k, r := range s.registries {
t.Run(k, func(t *testing.T) {
client := &client.Client{LegacyClientID: "client-id"}
- lr := consent.LogoutRequest{
+ lr := flow.LogoutRequest{
// TODO there is not FK for SessionID so we don't need it here; TODO make sure the missing FK is intentional
ID: uuid.Must(uuid.NewV4()).String(),
ClientID: sql.NullString{Valid: true, String: client.LegacyClientID},
@@ -626,15 +630,15 @@ func (s *PersisterTestSuite) TestDeleteLoginSession() {
t := s.T()
for k, r := range s.registries {
t.Run(k, func(t *testing.T) {
- ls := consent.LoginSession{ID: uuid.Must(uuid.NewV4()).String(), Remember: true}
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &ls))
+ ls := flow.LoginSession{ID: uuid.Must(uuid.NewV4()).String(), Remember: true}
+ persistLoginSession(s.t1, t, r.Persister(), &ls)
require.Error(t, r.Persister().DeleteLoginSession(s.t2, ls.ID))
- _, err := r.Persister().GetRememberedLoginSession(s.t1, ls.ID)
+ _, err := r.Persister().GetRememberedLoginSession(s.t1, nil, ls.ID)
require.NoError(t, err)
require.NoError(t, r.Persister().DeleteLoginSession(s.t1, ls.ID))
- _, err = r.Persister().GetRememberedLoginSession(s.t1, ls.ID)
+ _, err = r.Persister().GetRememberedLoginSession(s.t1, nil, ls.ID)
require.Error(t, err)
})
}
@@ -734,11 +738,10 @@ func (s *PersisterTestSuite) TestFindGrantedAndRememberedConsentRequests() {
sessionID := uuid.Must(uuid.NewV4()).String()
client := &client.Client{LegacyClientID: "client-id"}
f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID))
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, client))
- require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
- req := &consent.OAuth2ConsentRequest{
+ req := &flow.OAuth2ConsentRequest{
ID: "consent-request-id",
LoginChallenge: sqlxx.NullString(f.ID),
Skip: false,
@@ -746,14 +749,15 @@ func (s *PersisterTestSuite) TestFindGrantedAndRememberedConsentRequests() {
CSRF: "csrf",
}
- hcr := &consent.AcceptOAuth2ConsentRequest{
+ hcr := &flow.AcceptOAuth2ConsentRequest{
ID: req.ID,
HandledAt: sqlxx.NullTime(time.Now()),
Remember: true,
}
- require.NoError(t, r.Persister().CreateConsentRequest(s.t1, req))
- _, err := r.Persister().HandleConsentRequest(s.t1, hcr)
+ require.NoError(t, r.Persister().CreateConsentRequest(s.t1, f, req))
+ _, err := r.Persister().HandleConsentRequest(s.t1, f, hcr)
require.NoError(t, err)
+ require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
actual, err := r.Persister().FindGrantedAndRememberedConsentRequests(s.t2, client.LegacyClientID, f.Subject)
require.Error(t, err)
@@ -773,11 +777,11 @@ func (s *PersisterTestSuite) TestFindSubjectsGrantedConsentRequests() {
sessionID := uuid.Must(uuid.NewV4()).String()
client := &client.Client{LegacyClientID: "client-id"}
f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID))
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, client))
require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
- req := &consent.OAuth2ConsentRequest{
+ req := &flow.OAuth2ConsentRequest{
ID: "consent-request-id",
LoginChallenge: sqlxx.NullString(f.ID),
Skip: false,
@@ -785,13 +789,13 @@ func (s *PersisterTestSuite) TestFindSubjectsGrantedConsentRequests() {
CSRF: "csrf",
}
- hcr := &consent.AcceptOAuth2ConsentRequest{
+ hcr := &flow.AcceptOAuth2ConsentRequest{
ID: req.ID,
HandledAt: sqlxx.NullTime(time.Now()),
Remember: true,
}
- require.NoError(t, r.Persister().CreateConsentRequest(s.t1, req))
- _, err := r.Persister().HandleConsentRequest(s.t1, hcr)
+ require.NoError(t, r.Persister().CreateConsentRequest(s.t1, f, req))
+ _, err := r.Persister().HandleConsentRequest(s.t1, f, hcr)
require.NoError(t, err)
actual, err := r.Persister().FindSubjectsGrantedConsentRequests(s.t2, f.Subject, 100, 0)
@@ -876,7 +880,7 @@ func (s *PersisterTestSuite) TestFlushInactiveLoginConsentRequests() {
client := &client.Client{LegacyClientID: "client-id"}
f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID))
f.RequestedAt = time.Now().Add(-24 * time.Hour)
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, client))
require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
@@ -1056,18 +1060,18 @@ func (s *PersisterTestSuite) TestGetConsentRequest() {
sessionID := uuid.Must(uuid.NewV4()).String()
client := &client.Client{LegacyClientID: "client-id"}
f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID))
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, client))
require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
- req := &consent.OAuth2ConsentRequest{
- ID: "consent-request-id",
+ req := &flow.OAuth2ConsentRequest{
+ ID: x.Must(f.ToConsentChallenge(s.t1, r)),
LoginChallenge: sqlxx.NullString(f.ID),
Skip: false,
Verifier: "verifier",
CSRF: "csrf",
}
- require.NoError(t, r.Persister().CreateConsentRequest(s.t1, req))
+ require.NoError(t, r.Persister().CreateConsentRequest(s.t1, f, req))
actual, err := r.Persister().GetConsentRequest(s.t2, req.ID)
require.Error(t, err)
@@ -1087,7 +1091,7 @@ func (s *PersisterTestSuite) TestGetFlow() {
sessionID := uuid.Must(uuid.NewV4()).String()
client := &client.Client{LegacyClientID: "client-id"}
f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID))
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, client))
require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
@@ -1112,19 +1116,20 @@ func (s *PersisterTestSuite) TestGetFlowByConsentChallenge() {
sessionID := uuid.Must(uuid.NewV4()).String()
client := &client.Client{LegacyClientID: "client-id"}
f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID))
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID}))
+ require.NoError(t, r.Persister().CreateLoginSession(s.t1, &flow.LoginSession{ID: sessionID}))
require.NoError(t, r.Persister().CreateClient(s.t1, client))
- require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
store, ok := r.Persister().(*persistencesql.Persister)
if !ok {
t.Fatal("type assertion failed")
}
- _, err := store.GetFlowByConsentChallenge(s.t2, f.ConsentChallengeID.String())
+ challenge := x.Must(f.ToConsentChallenge(s.t1, r))
+
+ _, err := store.GetFlowByConsentChallenge(s.t2, challenge)
require.Error(t, err)
- _, err = store.GetFlowByConsentChallenge(s.t1, f.ConsentChallengeID.String())
+ _, err = store.GetFlowByConsentChallenge(s.t1, challenge)
require.NoError(t, err)
})
}
@@ -1179,19 +1184,20 @@ func (s *PersisterTestSuite) TestGetLoginRequest() {
for k, r := range s.registries {
t.Run(k, func(t *testing.T) {
client := &client.Client{LegacyClientID: "client-id"}
- lr := consent.LoginRequest{ID: "lr-id", ClientID: client.LegacyClientID, RequestedAt: time.Now()}
+ lr := flow.LoginRequest{ID: "lr-id", ClientID: client.LegacyClientID, RequestedAt: time.Now()}
require.NoError(t, r.Persister().CreateClient(s.t1, client))
- require.NoError(t, r.ConsentManager().CreateLoginRequest(s.t1, &lr))
- f := flow.Flow{}
- require.NoError(t, r.Persister().Connection(context.Background()).Find(&f, lr.ID))
+ f, err := r.ConsentManager().CreateLoginRequest(s.t1, &lr)
+ require.NoError(t, err)
require.Equal(t, s.t1NID, f.NID)
- actual, err := r.Persister().GetLoginRequest(s.t2, lr.ID)
+ challenge := x.Must(f.ToLoginChallenge(s.t1, r))
+
+ actual, err := r.Persister().GetLoginRequest(s.t2, challenge)
require.Error(t, err)
require.Nil(t, actual)
- actual, err = r.Persister().GetLoginRequest(s.t1, lr.ID)
+ actual, err = r.Persister().GetLoginRequest(s.t1, challenge)
require.NoError(t, err)
require.NotNil(t, actual)
})
@@ -1203,7 +1209,7 @@ func (s *PersisterTestSuite) TestGetLogoutRequest() {
for k, r := range s.registries {
t.Run(k, func(t *testing.T) {
client := &client.Client{LegacyClientID: "client-id"}
- lr := consent.LogoutRequest{
+ lr := flow.LogoutRequest{
ID: uuid.Must(uuid.NewV4()).String(),
ClientID: sql.NullString{Valid: true, String: client.LegacyClientID},
}
@@ -1213,11 +1219,11 @@ func (s *PersisterTestSuite) TestGetLogoutRequest() {
actual, err := r.Persister().GetLogoutRequest(s.t2, lr.ID)
require.Error(t, err)
- require.Equal(t, &consent.LogoutRequest{}, actual)
+ require.Equal(t, &flow.LogoutRequest{}, actual)
actual, err = r.Persister().GetLogoutRequest(s.t1, lr.ID)
require.NoError(t, err)
- require.NotEqual(t, &consent.LogoutRequest{}, actual)
+ require.NotEqual(t, &flow.LogoutRequest{}, actual)
})
}
}
@@ -1368,14 +1374,14 @@ func (s *PersisterTestSuite) TestGetRememberedLoginSession() {
t := s.T()
for k, r := range s.registries {
t.Run(k, func(t *testing.T) {
- ls := consent.LoginSession{ID: uuid.Must(uuid.NewV4()).String(), Remember: true}
+ ls := flow.LoginSession{ID: uuid.Must(uuid.NewV4()).String(), Remember: true}
require.NoError(t, r.Persister().CreateLoginSession(s.t1, &ls))
- actual, err := r.Persister().GetRememberedLoginSession(s.t2, ls.ID)
+ actual, err := r.Persister().GetRememberedLoginSession(s.t2, &ls, ls.ID)
require.Error(t, err)
require.Nil(t, actual)
- actual, err = r.Persister().GetRememberedLoginSession(s.t1, ls.ID)
+ actual, err = r.Persister().GetRememberedLoginSession(s.t1, &ls, ls.ID)
require.NoError(t, err)
require.NotNil(t, actual)
})
@@ -1389,13 +1395,12 @@ func (s *PersisterTestSuite) TestHandleConsentRequest() {
sessionID := uuid.Must(uuid.NewV4()).String()
c1 := &client.Client{LegacyClientID: uuidx.NewV4().String()}
f := newFlow(s.t1NID, c1.LegacyClientID, "sub", sqlxx.NullString(sessionID))
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, c1))
c1.ID = uuid.Nil
require.NoError(t, r.Persister().CreateClient(s.t2, c1))
- require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
- req := &consent.OAuth2ConsentRequest{
+ req := &flow.OAuth2ConsentRequest{
ID: "consent-request-id",
LoginChallenge: sqlxx.NullString(f.ID),
Skip: false,
@@ -1403,23 +1408,24 @@ func (s *PersisterTestSuite) TestHandleConsentRequest() {
CSRF: "csrf",
}
- hcr := &consent.AcceptOAuth2ConsentRequest{
+ hcr := &flow.AcceptOAuth2ConsentRequest{
ID: req.ID,
HandledAt: sqlxx.NullTime(time.Now()),
Remember: true,
}
- require.NoError(t, r.Persister().CreateConsentRequest(s.t1, req))
+ require.NoError(t, r.Persister().CreateConsentRequest(s.t1, f, req))
- actualCR, err := r.Persister().HandleConsentRequest(s.t2, hcr)
+ actualCR, err := r.Persister().HandleConsentRequest(s.t2, f, hcr)
require.Error(t, err)
require.Nil(t, actualCR)
actual, err := r.Persister().FindGrantedAndRememberedConsentRequests(s.t1, c1.LegacyClientID, f.Subject)
require.Error(t, err)
require.Equal(t, 0, len(actual))
- actualCR, err = r.Persister().HandleConsentRequest(s.t1, hcr)
+ actualCR, err = r.Persister().HandleConsentRequest(s.t1, f, hcr)
require.NoError(t, err)
require.NotNil(t, actualCR)
+ require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
actual, err = r.Persister().FindGrantedAndRememberedConsentRequests(s.t1, c1.LegacyClientID, f.Subject)
require.NoError(t, err)
require.Equal(t, 1, len(actual))
@@ -1497,21 +1503,51 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithBackChannelLogo
t2f2.LoginVerifier = "t2f2-login-verifier"
t2f2.ConsentVerifier = "t2f2-consent-verifier"
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: t1f1.SessionID.String()}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: t1f1.SessionID.String()})
require.NoError(t, r.Persister().Connection(context.Background()).Create(t1f1))
require.NoError(t, r.Persister().Connection(context.Background()).Create(t2f1))
require.NoError(t, r.Persister().Connection(context.Background()).Create(t2f2))
- require.NoError(t, r.Persister().CreateConsentRequest(s.t1, &consent.OAuth2ConsentRequest{ID: t1f1.ID, LoginChallenge: sqlxx.NullString(t1f1.ID), Skip: false, Verifier: t1f1.ConsentVerifier.String(), CSRF: "csrf"}))
- require.NoError(t, r.Persister().CreateConsentRequest(s.t2, &consent.OAuth2ConsentRequest{ID: t2f1.ID, LoginChallenge: sqlxx.NullString(t2f1.ID), Skip: false, Verifier: t2f1.ConsentVerifier.String(), CSRF: "csrf"}))
- require.NoError(t, r.Persister().CreateConsentRequest(s.t2, &consent.OAuth2ConsentRequest{ID: t2f2.ID, LoginChallenge: sqlxx.NullString(t2f2.ID), Skip: false, Verifier: t2f2.ConsentVerifier.String(), CSRF: "csrf"}))
+ require.NoError(t, r.Persister().CreateConsentRequest(s.t1, t1f1, &flow.OAuth2ConsentRequest{
+ ID: t1f1.ID,
+ LoginChallenge: sqlxx.NullString(t1f1.ID),
+ Skip: false,
+ Verifier: t1f1.ConsentVerifier.String(),
+ CSRF: "csrf",
+ }))
+ require.NoError(t, r.Persister().CreateConsentRequest(s.t2, t2f1, &flow.OAuth2ConsentRequest{
+ ID: t2f1.ID,
+ LoginChallenge: sqlxx.NullString(t2f1.ID),
+ Skip: false,
+ Verifier: t2f1.ConsentVerifier.String(),
+ CSRF: "csrf",
+ }))
+ require.NoError(t, r.Persister().CreateConsentRequest(s.t2, t2f2, &flow.OAuth2ConsentRequest{
+ ID: t2f2.ID,
+ LoginChallenge: sqlxx.NullString(t2f2.ID),
+ Skip: false,
+ Verifier: t2f2.ConsentVerifier.String(),
+ CSRF: "csrf",
+ }))
- _, err := r.Persister().HandleConsentRequest(s.t1, &consent.AcceptOAuth2ConsentRequest{ID: t1f1.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true})
+ _, err := r.Persister().HandleConsentRequest(s.t1, t1f1, &flow.AcceptOAuth2ConsentRequest{
+ ID: t1f1.ID,
+ HandledAt: sqlxx.NullTime(time.Now()),
+ Remember: true,
+ })
require.NoError(t, err)
- _, err = r.Persister().HandleConsentRequest(s.t2, &consent.AcceptOAuth2ConsentRequest{ID: t2f1.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true})
+ _, err = r.Persister().HandleConsentRequest(s.t2, t2f1, &flow.AcceptOAuth2ConsentRequest{
+ ID: t2f1.ID,
+ HandledAt: sqlxx.NullTime(time.Now()),
+ Remember: true,
+ })
require.NoError(t, err)
- _, err = r.Persister().HandleConsentRequest(s.t2, &consent.AcceptOAuth2ConsentRequest{ID: t2f2.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true})
+ _, err = r.Persister().HandleConsentRequest(s.t2, t2f2, &flow.AcceptOAuth2ConsentRequest{
+ ID: t2f2.ID,
+ HandledAt: sqlxx.NullTime(time.Now()),
+ Remember: true,
+ })
require.NoError(t, err)
cs, err := r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t1, "sub", t1f1.SessionID.String())
@@ -1551,21 +1587,51 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithFrontChannelLog
t2f2.LoginVerifier = "t2f2-login-verifier"
t2f2.ConsentVerifier = "t2f2-consent-verifier"
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: t1f1.SessionID.String()}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: t1f1.SessionID.String()})
require.NoError(t, r.Persister().Connection(context.Background()).Create(t1f1))
require.NoError(t, r.Persister().Connection(context.Background()).Create(t2f1))
require.NoError(t, r.Persister().Connection(context.Background()).Create(t2f2))
- require.NoError(t, r.Persister().CreateConsentRequest(s.t1, &consent.OAuth2ConsentRequest{ID: t1f1.ID, LoginChallenge: sqlxx.NullString(t1f1.ID), Skip: false, Verifier: t1f1.ConsentVerifier.String(), CSRF: "csrf"}))
- require.NoError(t, r.Persister().CreateConsentRequest(s.t2, &consent.OAuth2ConsentRequest{ID: t2f1.ID, LoginChallenge: sqlxx.NullString(t2f1.ID), Skip: false, Verifier: t2f1.ConsentVerifier.String(), CSRF: "csrf"}))
- require.NoError(t, r.Persister().CreateConsentRequest(s.t2, &consent.OAuth2ConsentRequest{ID: t2f2.ID, LoginChallenge: sqlxx.NullString(t2f2.ID), Skip: false, Verifier: t2f2.ConsentVerifier.String(), CSRF: "csrf"}))
+ require.NoError(t, r.Persister().CreateConsentRequest(s.t1, t1f1, &flow.OAuth2ConsentRequest{
+ ID: t1f1.ID,
+ LoginChallenge: sqlxx.NullString(t1f1.ID),
+ Skip: false,
+ Verifier: t1f1.ConsentVerifier.String(),
+ CSRF: "csrf",
+ }))
+ require.NoError(t, r.Persister().CreateConsentRequest(s.t2, t2f1, &flow.OAuth2ConsentRequest{
+ ID: t2f1.ID,
+ LoginChallenge: sqlxx.NullString(t2f1.ID),
+ Skip: false,
+ Verifier: t2f1.ConsentVerifier.String(),
+ CSRF: "csrf",
+ }))
+ require.NoError(t, r.Persister().CreateConsentRequest(s.t2, t2f2, &flow.OAuth2ConsentRequest{
+ ID: t2f2.ID,
+ LoginChallenge: sqlxx.NullString(t2f2.ID),
+ Skip: false,
+ Verifier: t2f2.ConsentVerifier.String(),
+ CSRF: "csrf",
+ }))
- _, err := r.Persister().HandleConsentRequest(s.t1, &consent.AcceptOAuth2ConsentRequest{ID: t1f1.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true})
+ _, err := r.Persister().HandleConsentRequest(s.t1, t1f1, &flow.AcceptOAuth2ConsentRequest{
+ ID: t1f1.ID,
+ HandledAt: sqlxx.NullTime(time.Now()),
+ Remember: true,
+ })
require.NoError(t, err)
- _, err = r.Persister().HandleConsentRequest(s.t2, &consent.AcceptOAuth2ConsentRequest{ID: t2f1.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true})
+ _, err = r.Persister().HandleConsentRequest(s.t2, t2f1, &flow.AcceptOAuth2ConsentRequest{
+ ID: t2f1.ID,
+ HandledAt: sqlxx.NullTime(time.Now()),
+ Remember: true,
+ })
require.NoError(t, err)
- _, err = r.Persister().HandleConsentRequest(s.t2, &consent.AcceptOAuth2ConsentRequest{ID: t2f2.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true})
+ _, err = r.Persister().HandleConsentRequest(s.t2, t2f2, &flow.AcceptOAuth2ConsentRequest{
+ ID: t2f2.ID,
+ HandledAt: sqlxx.NullTime(time.Now()),
+ Remember: true,
+ })
require.NoError(t, err)
cs, err := r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t1, "sub", t1f1.SessionID.String())
@@ -1639,7 +1705,7 @@ func (s *PersisterTestSuite) TestRejectLogoutRequest() {
require.NoError(t, r.ConsentManager().RejectLogoutRequest(s.t1, lr.ID))
actual, err = r.ConsentManager().GetLogoutRequest(s.t1, lr.ID)
require.Error(t, err)
- require.Equal(t, &consent.LogoutRequest{}, actual)
+ require.Equal(t, &flow.LogoutRequest{}, actual)
})
}
}
@@ -1729,7 +1795,7 @@ func (s *PersisterTestSuite) TestRevokeSubjectClientConsentSession() {
client := &client.Client{LegacyClientID: "client-id"}
f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID))
f.RequestedAt = time.Now().Add(-24 * time.Hour)
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
require.NoError(t, r.Persister().CreateClient(s.t1, client))
require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
@@ -1900,7 +1966,7 @@ func (s *PersisterTestSuite) TestVerifyAndInvalidateConsentRequest() {
t.Run(k, func(t *testing.T) {
sub := uuid.Must(uuid.NewV4()).String()
sessionID := uuid.Must(uuid.NewV4()).String()
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
client := &client.Client{LegacyClientID: "client-id"}
require.NoError(t, r.Persister().CreateClient(s.t1, client))
f := newFlow(s.t1NID, client.LegacyClientID, sub, sqlxx.NullString(sessionID))
@@ -1909,24 +1975,22 @@ func (s *PersisterTestSuite) TestVerifyAndInvalidateConsentRequest() {
f.ConsentRemember = false
crf := 86400
f.ConsentRememberFor = &crf
- f.ConsentError = &consent.RequestDeniedError{}
+ f.ConsentError = &flow.RequestDeniedError{}
f.SessionAccessToken = map[string]interface{}{}
f.SessionIDToken = map[string]interface{}{}
f.ConsentWasHandled = false
f.State = flow.FlowStateConsentUnused
- require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
- actual := &flow.Flow{}
- _, err := r.ConsentManager().VerifyAndInvalidateConsentRequest(s.t2, f.ConsentVerifier.String())
+ consentVerifier := x.Must(f.ToConsentVerifier(s.t1, r))
+
+ _, err := r.ConsentManager().VerifyAndInvalidateConsentRequest(s.t2, f, consentVerifier)
require.Error(t, err)
- require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, f.ID))
- require.Equal(t, flow.FlowStateConsentUnused, actual.State)
- require.Equal(t, false, actual.ConsentWasHandled)
- _, err = r.ConsentManager().VerifyAndInvalidateConsentRequest(s.t1, f.ConsentVerifier.String())
+ require.Equal(t, flow.FlowStateConsentUnused, f.State)
+ require.Equal(t, false, f.ConsentWasHandled)
+ _, err = r.ConsentManager().VerifyAndInvalidateConsentRequest(s.t1, f, consentVerifier)
require.NoError(t, err)
- require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, f.ID))
- require.Equal(t, flow.FlowStateConsentUsed, actual.State)
- require.Equal(t, true, actual.ConsentWasHandled)
+ require.Equal(t, flow.FlowStateConsentUsed, f.State)
+ require.Equal(t, true, f.ConsentWasHandled)
})
}
}
@@ -1937,24 +2001,21 @@ func (s *PersisterTestSuite) TestVerifyAndInvalidateLoginRequest() {
t.Run(k, func(t *testing.T) {
sub := uuid.Must(uuid.NewV4()).String()
sessionID := uuid.Must(uuid.NewV4()).String()
- require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID}))
+ persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID})
client := &client.Client{LegacyClientID: "client-id"}
require.NoError(t, r.Persister().CreateClient(s.t1, client))
f := newFlow(s.t1NID, client.LegacyClientID, sub, sqlxx.NullString(sessionID))
f.State = flow.FlowStateLoginUnused
- require.NoError(t, r.Persister().Connection(context.Background()).Create(f))
- actual := &flow.Flow{}
- _, err := r.ConsentManager().VerifyAndInvalidateLoginRequest(s.t2, f.LoginVerifier)
+ loginVerifier := x.Must(f.ToLoginVerifier(s.t1, r))
+ _, err := r.ConsentManager().VerifyAndInvalidateLoginRequest(s.t2, f, loginVerifier)
require.Error(t, err)
- require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, f.ID))
- require.Equal(t, flow.FlowStateLoginUnused, actual.State)
- require.Equal(t, false, actual.LoginWasUsed)
- _, err = r.ConsentManager().VerifyAndInvalidateLoginRequest(s.t1, f.LoginVerifier)
+ require.Equal(t, flow.FlowStateLoginUnused, f.State)
+ require.Equal(t, false, f.LoginWasUsed)
+ _, err = r.ConsentManager().VerifyAndInvalidateLoginRequest(s.t1, f, loginVerifier)
require.NoError(t, err)
- require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, f.ID))
- require.Equal(t, flow.FlowStateLoginUsed, actual.State)
- require.Equal(t, true, actual.LoginWasUsed)
+ require.Equal(t, flow.FlowStateLoginUsed, f.State)
+ require.Equal(t, true, f.LoginWasUsed)
})
}
}
@@ -1974,8 +2035,8 @@ func (s *PersisterTestSuite) TestVerifyAndInvalidateLogoutRequest() {
lrInvalidated, err := r.ConsentManager().VerifyAndInvalidateLogoutRequest(s.t2, lr.Verifier)
require.Error(t, err)
- require.Equal(t, &consent.LogoutRequest{}, lrInvalidated)
- actual := &consent.LogoutRequest{}
+ require.Nil(t, lrInvalidated)
+ actual := &flow.LogoutRequest{}
require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, lr.ID))
require.Equal(t, expected, actual)
@@ -2026,9 +2087,9 @@ func newFlow(nid uuid.UUID, clientID string, subject string, sessionID sqlxx.Nul
ID: uuid.Must(uuid.NewV4()).String(),
ClientID: clientID,
Subject: subject,
- ConsentError: &consent.RequestDeniedError{},
+ ConsentError: &flow.RequestDeniedError{},
State: flow.FlowStateConsentUnused,
- LoginError: &consent.RequestDeniedError{},
+ LoginError: &flow.RequestDeniedError{},
Context: sqlxx.JSONRawMessage{},
AMR: sqlxx.StringSliceJSONFormat{},
ConsentChallengeID: sqlxx.NullString("not-null"),
@@ -2050,8 +2111,8 @@ func newGrant(keySet string, keyID string) trust.Grant {
}
}
-func newLogoutRequest() *consent.LogoutRequest {
- return &consent.LogoutRequest{
+func newLogoutRequest() *flow.LogoutRequest {
+ return &flow.LogoutRequest{
ID: uuid.Must(uuid.NewV4()).String(),
}
}
@@ -2065,15 +2126,11 @@ func newKey(ksID string, use string) jose.JSONWebKey {
}
func newKeySet(id string, use string) *jose.JSONWebKeySet {
- ks, err := jwk.GenerateJWK(context.Background(), jose.RS256, id, use)
- if err != nil {
- panic(err)
- }
- return ks
+ return x.Must(jwk.GenerateJWK(context.Background(), jose.RS256, id, use))
}
-func newLoginSession() *consent.LoginSession {
- return &consent.LoginSession{
+func newLoginSession() *flow.LoginSession {
+ return &flow.LoginSession{
ID: uuid.Must(uuid.NewV4()).String(),
AuthenticatedAt: sqlxx.NullTime(time.Time{}),
Subject: uuid.Must(uuid.NewV4()).String(),
@@ -2084,3 +2141,9 @@ func newLoginSession() *consent.LoginSession {
func requireKeySetEqual(t *testing.T, expected *jose.JSONWebKeySet, actual *jose.JSONWebKeySet) {
assertx.EqualAsJSON(t, expected, actual)
}
+
+func persistLoginSession(ctx context.Context, t *testing.T, p persistence.Persister, session *flow.LoginSession) {
+ t.Helper()
+ require.NoError(t, p.CreateLoginSession(ctx, session))
+ require.NoError(t, p.Connection(ctx).Create(session))
+}
diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go
index e2ae3887516..e240e89d50b 100644
--- a/persistence/sql/persister_oauth2.go
+++ b/persistence/sql/persister_oauth2.go
@@ -13,21 +13,17 @@ import (
"strings"
"time"
- "github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
-
- "github.com/ory/x/errorsx"
-
- "github.com/ory/fosite/storage"
-
"github.com/pkg/errors"
"github.com/tidwall/gjson"
"github.com/ory/fosite"
+ "github.com/ory/fosite/storage"
+ "github.com/ory/hydra/v2/oauth2"
+ "github.com/ory/x/errorsx"
+ "github.com/ory/x/otelx"
"github.com/ory/x/sqlcon"
"github.com/ory/x/stringsx"
-
- "github.com/ory/hydra/v2/oauth2"
)
var _ oauth2.AssertionJWTReader = &Persister{}
@@ -80,7 +76,7 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature strin
}
if p.config.EncryptSessionData(ctx) {
- ciphertext, err := p.r.KeyCipher().Encrypt(ctx, session)
+ ciphertext, err := p.r.KeyCipher().Encrypt(ctx, session, nil)
if err != nil {
return nil, errorsx.WithStack(err)
}
@@ -115,14 +111,14 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature strin
}, nil
}
-func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (*fosite.Request, error) {
+func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (_ *fosite.Request, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.toRequest")
- defer span.End()
+ defer otelx.End(span, &err)
sess := r.Session
if !gjson.ValidBytes(sess) {
var err error
- sess, err = p.r.KeyCipher().Decrypt(ctx, string(sess))
+ sess, err = p.r.KeyCipher().Decrypt(ctx, string(sess), nil)
if err != nil {
return nil, errorsx.WithStack(err)
}
@@ -173,9 +169,9 @@ func (p *Persister) hashSignature(_ context.Context, signature string, table tab
return signature
}
-func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) error {
+func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ClientAssertionJWTValid")
- defer span.End()
+ defer otelx.End(span, &err)
j, err := p.GetClientAssertionJWT(ctx, jti)
if errors.Is(err, sqlcon.ErrNoRows) {
@@ -192,9 +188,9 @@ func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) err
return nil
}
-func (p *Persister) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) error {
+func (p *Persister) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.SetClientAssertionJWT")
- defer span.End()
+ defer otelx.End(span, &err)
// delete expired; this cleanup spares us the need for a background worker
if err := p.QueryWithNetwork(ctx).Where("expires_at < CURRENT_TIMESTAMP").Delete(&oauth2.BlacklistedJTI{}); err != nil {
@@ -212,31 +208,31 @@ func (p *Persister) SetClientAssertionJWT(ctx context.Context, jti string, exp t
return nil
}
-func (p *Persister) GetClientAssertionJWT(ctx context.Context, j string) (*oauth2.BlacklistedJTI, error) {
+func (p *Persister) GetClientAssertionJWT(ctx context.Context, j string) (_ *oauth2.BlacklistedJTI, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetClientAssertionJWT")
- defer span.End()
+ defer otelx.End(span, &err)
jti := oauth2.NewBlacklistedJTI(j, time.Time{})
return jti, sqlcon.HandleError(p.QueryWithNetwork(ctx).Find(jti, jti.ID))
}
-func (p *Persister) SetClientAssertionJWTRaw(ctx context.Context, jti *oauth2.BlacklistedJTI) error {
+func (p *Persister) SetClientAssertionJWTRaw(ctx context.Context, jti *oauth2.BlacklistedJTI) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.SetClientAssertionJWTRaw")
- defer span.End()
+ defer otelx.End(span, &err)
return sqlcon.HandleError(p.CreateWithNetwork(ctx, jti))
}
-func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) error {
+func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createSession")
- defer span.End()
+ defer otelx.End(span, &err)
req, err := p.sqlSchemaFromRequest(ctx, signature, requester, table)
if err != nil {
return err
}
- if err := sqlcon.HandleError(p.CreateWithNetwork(ctx, req)); errors.Is(err, sqlcon.ErrConcurrentUpdate) {
+ if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, req)); errors.Is(err, sqlcon.ErrConcurrentUpdate) {
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
} else if err != nil {
return err
@@ -244,44 +240,39 @@ func (p *Persister) createSession(ctx context.Context, signature string, request
return nil
}
-func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (fosite.Requester, error) {
+func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findSessionBySignature")
- defer span.End()
+ defer otelx.End(span, &err)
r := OAuth2RequestSQL{Table: table}
- var fr fosite.Requester
-
- return fr, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
- // We look for the signature as well as the hash of the signature here.
- // This is because we now always store the hash of the signature in the database,
- // regardless of the type of the signature. In previous versions, we only stored
- // the hash of the signature for JWT tokens.
- //
- // This code will be removed in a future version.
- err := p.QueryWithNetwork(ctx).Where("signature IN (?, ?)", rawSignature, SignatureHash(rawSignature)).First(&r)
- if errors.Is(err, sql.ErrNoRows) {
- return errorsx.WithStack(fosite.ErrNotFound)
- } else if err != nil {
- return sqlcon.HandleError(err)
- } else if !r.Active {
- fr, err = r.toRequest(ctx, session, p)
- if err != nil {
- return err
- } else if table == sqlTableCode {
- return errorsx.WithStack(fosite.ErrInvalidatedAuthorizeCode)
- }
-
- return errorsx.WithStack(fosite.ErrInactiveToken)
+
+ // We look for the signature as well as the hash of the signature here.
+ // This is because we now always store the hash of the signature in the database,
+ // regardless of the type of the signature. In previous versions, we only stored
+ // the hash of the signature for JWT tokens.
+ //
+ // This code will be removed in a future version.
+ err = p.QueryWithNetwork(ctx).Where("signature IN (?, ?)", rawSignature, SignatureHash(rawSignature)).First(&r)
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, errorsx.WithStack(fosite.ErrNotFound)
+ } else if err != nil {
+ return nil, sqlcon.HandleError(err)
+ } else if !r.Active {
+ fr, err := r.toRequest(ctx, session, p)
+ if err != nil {
+ return nil, err
+ } else if table == sqlTableCode {
+ return fr, errorsx.WithStack(fosite.ErrInvalidatedAuthorizeCode)
}
+ return fr, errorsx.WithStack(fosite.ErrInactiveToken)
+ }
- fr, err = r.toRequest(ctx, session, p)
- return err
- })
+ return r.toRequest(ctx, session, p)
}
-func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) error {
+func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionBySignature")
- defer span.End()
+ defer otelx.End(span, &err)
signature = p.hashSignature(ctx, signature, table)
@@ -291,7 +282,7 @@ func (p *Persister) deleteSessionBySignature(ctx context.Context, signature stri
// the hash of the signature for JWT tokens.
//
// This code will be removed in a future version.
- err := sqlcon.HandleError(
+ err = sqlcon.HandleError(
p.QueryWithNetwork(ctx).
Where("signature IN (?, ?)", signature, SignatureHash(signature)).
Delete(&OAuth2RequestSQL{Table: table}))
@@ -306,9 +297,9 @@ func (p *Persister) deleteSessionBySignature(ctx context.Context, signature stri
return nil
}
-func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, table tableName) error {
+func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, table tableName) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionByRequestID")
- defer span.End()
+ defer otelx.End(span, &err)
/* #nosec G201 table is static */
if err := p.QueryWithNetwork(ctx).
@@ -326,9 +317,9 @@ func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, tab
return nil
}
-func (p *Persister) deactivateSessionByRequestID(ctx context.Context, id string, table tableName) error {
+func (p *Persister) deactivateSessionByRequestID(ctx context.Context, id string, table tableName) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deactivateSessionByRequestID")
- defer span.End()
+ defer otelx.End(span, &err)
/* #nosec G201 table is static */
return sqlcon.HandleError(
@@ -342,23 +333,22 @@ func (p *Persister) deactivateSessionByRequestID(ctx context.Context, id string,
)
}
-func (p *Persister) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) (err error) {
- ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateAuthorizeCodeSession")
- defer span.End()
-
- return p.createSession(ctx, signature, requester, sqlTableCode)
+func (p *Persister) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) error {
+ return otelx.WithSpan(ctx, "persistence.sql.CreateAuthorizeCodeSession", func(ctx context.Context) error {
+ return p.createSession(ctx, signature, requester, sqlTableCode)
+ })
}
func (p *Persister) GetAuthorizeCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAuthorizeCodeSession")
- defer span.End()
+ defer otelx.End(span, &err)
return p.findSessionBySignature(ctx, signature, session, sqlTableCode)
}
func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InvalidateAuthorizeCodeSession")
- defer span.End()
+ defer otelx.End(span, &err)
/* #nosec G201 table is static */
return sqlcon.HandleError(
@@ -372,67 +362,97 @@ func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signatur
)
}
-func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature string, requester fosite.Requester) (err error) {
- return p.createSession(ctx, signature, requester, sqlTableAccess)
+func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature string, requester fosite.Requester) error {
+ return otelx.WithSpan(ctx, "persistence.sql.CreateAccessTokenSession", func(ctx context.Context) error {
+ return p.createSession(ctx, signature, requester, sqlTableAccess)
+ })
}
func (p *Persister) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
+ ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAccessTokenSession")
+ defer otelx.End(span, &err)
return p.findSessionBySignature(ctx, signature, session, sqlTableAccess)
}
-func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature string) (err error) {
- return p.deleteSessionBySignature(ctx, signature, sqlTableAccess)
+func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature string) error {
+ return otelx.WithSpan(ctx, "persistence.sql.DeleteAccessTokenSession", func(ctx context.Context) error {
+ return p.deleteSessionBySignature(ctx, signature, sqlTableAccess)
+ })
}
-func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature string, requester fosite.Requester) (err error) {
- return p.createSession(ctx, signature, requester, sqlTableRefresh)
+func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature string, requester fosite.Requester) error {
+ return otelx.WithSpan(ctx, "persistence.sql.CreateRefreshTokenSession", func(ctx context.Context) error {
+ return p.createSession(ctx, signature, requester, sqlTableRefresh)
+ })
}
func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
+ ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession")
+ defer otelx.End(span, &err)
return p.findSessionBySignature(ctx, signature, session, sqlTableRefresh)
}
-func (p *Persister) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) {
- return p.deleteSessionBySignature(ctx, signature, sqlTableRefresh)
+func (p *Persister) DeleteRefreshTokenSession(ctx context.Context, signature string) error {
+ return otelx.WithSpan(ctx, "persistence.sql.DeleteRefreshTokenSession", func(ctx context.Context) error {
+ return p.deleteSessionBySignature(ctx, signature, sqlTableRefresh)
+ })
}
func (p *Persister) CreateOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) error {
- return p.createSession(ctx, signature, requester, sqlTableOpenID)
+ return otelx.WithSpan(ctx, "persistence.sql.CreateOpenIDConnectSession", func(ctx context.Context) error {
+ return p.createSession(ctx, signature, requester, sqlTableOpenID)
+ })
}
-func (p *Persister) GetOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) (fosite.Requester, error) {
+func (p *Persister) GetOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) (_ fosite.Requester, err error) {
+ ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetOpenIDConnectSession")
+ defer otelx.End(span, &err)
return p.findSessionBySignature(ctx, signature, requester.GetSession(), sqlTableOpenID)
}
func (p *Persister) DeleteOpenIDConnectSession(ctx context.Context, signature string) error {
- return p.deleteSessionBySignature(ctx, signature, sqlTableOpenID)
+ return otelx.WithSpan(ctx, "persistence.sql.DeleteOpenIDConnectSession", func(ctx context.Context) error {
+ return p.deleteSessionBySignature(ctx, signature, sqlTableOpenID)
+ })
}
-func (p *Persister) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) {
+func (p *Persister) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (_ fosite.Requester, err error) {
+ ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPKCERequestSession")
+ defer otelx.End(span, &err)
return p.findSessionBySignature(ctx, signature, session, sqlTablePKCE)
}
func (p *Persister) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) error {
- return p.createSession(ctx, signature, requester, sqlTablePKCE)
+ return otelx.WithSpan(ctx, "persistence.sql.CreatePKCERequestSession", func(ctx context.Context) error {
+ return p.createSession(ctx, signature, requester, sqlTablePKCE)
+ })
}
func (p *Persister) DeletePKCERequestSession(ctx context.Context, signature string) error {
- return p.deleteSessionBySignature(ctx, signature, sqlTablePKCE)
+ return otelx.WithSpan(ctx, "persistence.sql.DeletePKCERequestSession", func(ctx context.Context) error {
+ return p.deleteSessionBySignature(ctx, signature, sqlTablePKCE)
+ })
}
func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) error {
- return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh)
+ return otelx.WithSpan(ctx, "persistence.sql.RevokeRefreshToken", func(ctx context.Context) error {
+ return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh)
+ })
}
func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) error {
- return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh)
+ return otelx.WithSpan(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod", func(ctx context.Context) error {
+ return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh)
+ })
}
func (p *Persister) RevokeAccessToken(ctx context.Context, id string) error {
- return p.deleteSessionByRequestID(ctx, id, sqlTableAccess)
+ return otelx.WithSpan(ctx, "persistence.sql.RevokeAccessToken", func(ctx context.Context) error {
+ return p.deleteSessionByRequestID(ctx, id, sqlTableAccess)
+ })
}
-func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int, table tableName, lifespan time.Duration) error {
+func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int, table tableName, lifespan time.Duration) (err error) {
/* #nosec G201 table is static */
// The value of notAfter should be the minimum between input parameter and token max expire based on its configured age
requestMaxExpire := time.Now().Add(-lifespan)
@@ -440,8 +460,6 @@ func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time,
notAfter = requestMaxExpire
}
- var err error
-
totalDeletedCount := 0
for deletedRecords := batchSize; totalDeletedCount < limit && deletedRecords == batchSize; {
d := batchSize
@@ -469,16 +487,22 @@ func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time,
}
func (p *Persister) FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error {
- return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableAccess, p.config.GetAccessTokenLifespan(ctx))
+ return otelx.WithSpan(ctx, "persistence.sql.FlushInactiveAccessTokens", func(ctx context.Context) error {
+ return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableAccess, p.config.GetAccessTokenLifespan(ctx))
+ })
}
func (p *Persister) FlushInactiveRefreshTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error {
- return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableRefresh, p.config.GetRefreshTokenLifespan(ctx))
+ return otelx.WithSpan(ctx, "persistence.sql.FlushInactiveRefreshTokens", func(ctx context.Context) error {
+ return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableRefresh, p.config.GetRefreshTokenLifespan(ctx))
+ })
}
func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) error {
- /* #nosec G201 table is static */
- return sqlcon.HandleError(
- p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}),
- )
+ return otelx.WithSpan(ctx, "persistence.sql.DeleteAccessTokens", func(ctx context.Context) error {
+ /* #nosec G201 table is static */
+ return sqlcon.HandleError(
+ p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}),
+ )
+ })
}
diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go
index 475d32b88a8..ad71b374909 100644
--- a/persistence/sql/persister_test.go
+++ b/persistence/sql/persister_test.go
@@ -12,9 +12,9 @@ import (
"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
- "github.com/instana/testify/assert"
- "github.com/instana/testify/require"
"github.com/pkg/errors"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/consent"
@@ -52,12 +52,12 @@ func testRegistry(t *testing.T, ctx context.Context, k string, t1 driver.Registr
parallel = false
}
- t.Run("package=consent/manager="+k, consent.ManagerTests(t1.ConsentManager(), t1.ClientManager(), t1.OAuth2Storage(), "t1", parallel))
- t.Run("package=consent/manager="+k, consent.ManagerTests(t2.ConsentManager(), t2.ClientManager(), t2.OAuth2Storage(), "t2", parallel))
+ t.Run("package=consent/manager="+k, consent.ManagerTests(t1, t1.ConsentManager(), t1.ClientManager(), t1.OAuth2Storage(), "t1", parallel))
+ t.Run("package=consent/manager="+k, consent.ManagerTests(t2, t2.ConsentManager(), t2.ClientManager(), t2.OAuth2Storage(), "t2", parallel))
t.Run("parallel-boundary", func(t *testing.T) {
- t.Run("package=consent/janitor="+k, testhelpers.JanitorTests(t1.Config(), t1.ConsentManager(), t1.ClientManager(), t1.OAuth2Storage(), "t1", parallel))
- t.Run("package=consent/janitor="+k, testhelpers.JanitorTests(t2.Config(), t2.ConsentManager(), t2.ClientManager(), t2.OAuth2Storage(), "t2", parallel))
+ t.Run("package=consent/janitor="+k, testhelpers.JanitorTests(t1, "t1", parallel))
+ t.Run("package=consent/janitor="+k, testhelpers.JanitorTests(t2, "t2", parallel))
})
t.Run("package=jwk/manager="+k, func(t *testing.T) {
@@ -186,7 +186,7 @@ func TestManagers(t *testing.T) {
)
}
t.Run("package=consent/manager="+k+"/case=nid",
- consent.TestHelperNID(t1.ClientManager(), t1.ConsentManager(), t2.ConsentManager()),
+ consent.TestHelperNID(t1, t1.ConsentManager(), t2.ConsentManager()),
)
}
}
diff --git a/test/conformance/hydra/Dockerfile b/test/conformance/hydra/Dockerfile
index df86aefa45b..58eb6d8155a 100644
--- a/test/conformance/hydra/Dockerfile
+++ b/test/conformance/hydra/Dockerfile
@@ -1,4 +1,4 @@
-FROM golang:1.19-buster AS builder
+FROM golang:1.20-buster AS builder
RUN apt-get update && \
apt-get install --no-install-recommends -y \
diff --git a/x/doc_swagger.go b/x/doc_swagger.go
index af7944ca09d..5c8fb350e8b 100644
--- a/x/doc_swagger.go
+++ b/x/doc_swagger.go
@@ -7,11 +7,15 @@ package x
// typically 201.
//
// swagger:response emptyResponse
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type emptyResponse struct{}
// Error
//
// swagger:model errorOAuth2
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type errorOAuth2 struct {
// Error
Name string `json:"error"`
@@ -40,6 +44,8 @@ type errorOAuth2 struct {
// Default Error Response
//
// swagger:response errorOAuth2Default
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type errorOAuth2Default struct {
// in: body
Body errorOAuth2
@@ -48,6 +54,8 @@ type errorOAuth2Default struct {
// Bad Request Error Response
//
// swagger:response errorOAuth2BadRequest
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type errorOAuth2BadRequest struct {
// in: body
Body errorOAuth2
@@ -56,6 +64,8 @@ type errorOAuth2BadRequest struct {
// Not Found Error Response
//
// swagger:response errorOAuth2NotFound
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type errorOAuth2NotFound struct {
// in: body
Body errorOAuth2
diff --git a/x/errors.go b/x/errors.go
index 229884a5d51..f90802bf50a 100644
--- a/x/errors.go
+++ b/x/errors.go
@@ -31,3 +31,10 @@ func LogError(r *http.Request, err error, logger *logrusx.Logger) {
logger.WithRequest(r).
WithError(err).Errorln("An error occurred")
}
+
+func Must[T any](t T, err error) T {
+ if err != nil {
+ panic(err)
+ }
+ return t
+}
diff --git a/x/sqlx.go b/x/sqlx.go
index 7ca0e5a727d..0b90b923665 100644
--- a/x/sqlx.go
+++ b/x/sqlx.go
@@ -71,6 +71,8 @@ func (ns *Duration) UnmarshalJSON(data []byte) error {
}
// swagger:model NullDuration
+//
+//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions
type swaggerNullDuration string
// NullDuration represents a nullable JSON and SQL compatible time.Duration.