From 8f5192fbb74c4b952029a6856284de8d59027770 Mon Sep 17 00:00:00 2001 From: Jonas Hungershausen Date: Thu, 22 Feb 2024 11:28:43 +0100 Subject: [PATCH] fix: ignore decrypt errors in WithDeclassifiedCredentials (#3731) --- cipher/chacha20.go | 2 +- driver/registry_default.go | 2 +- ...Credentials-case=oidc-credential=oidc.json | 22 ++++++++ ...entials-case=oidc-credential=password.json | 10 ++++ ...entials-case=oidc-credential=webauthn.json | 10 ++++ identity/handler_test.go | 48 +++++++++++++---- identity/identity.go | 51 +++++++++---------- identity/identity_test.go | 27 ++++++++-- 8 files changed, 131 insertions(+), 41 deletions(-) create mode 100644 identity/.snapshots/TestWithDeclassifiedCredentials-case=oidc-credential=oidc.json create mode 100644 identity/.snapshots/TestWithDeclassifiedCredentials-case=oidc-credential=password.json create mode 100644 identity/.snapshots/TestWithDeclassifiedCredentials-case=oidc-credential=webauthn.json diff --git a/cipher/chacha20.go b/cipher/chacha20.go index 6ee73ea055f8..46cf1efc85d9 100644 --- a/cipher/chacha20.go +++ b/cipher/chacha20.go @@ -72,7 +72,7 @@ func (c *XChaCha20Poly1305) Decrypt(ctx context.Context, ciphertext string) ([]b for i := range secrets { aead, err := chacha20poly1305.NewX(secrets[i][:]) if err != nil { - return nil, errors.WithStack(herodot.ErrInternalServerError.WithWrap(err).WithReason("Unable to instanciate chacha20")) + return nil, errors.WithStack(herodot.ErrInternalServerError.WithWrap(err).WithReason("Unable to instantiate chacha20")) } if len(ciphertext) < aead.NonceSize() { diff --git a/driver/registry_default.go b/driver/registry_default.go index 9317846d81f0..f622d4c20336 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -473,7 +473,7 @@ func (m *RegistryDefault) Cipher(ctx context.Context) cipher.Cipher { m.crypter = cipher.NewCryptAES(m) default: m.crypter = cipher.NewNoop(m) - m.l.Logger.Warning("No encryption configuration found. Default algorithm (noop) will be use that mean sensitive data will be recorded in plaintext") + m.l.Logger.Warning("No encryption configuration found. The default algorithm (noop) will be used, resulting in sensitive data being stored in plaintext") } } return m.crypter diff --git a/identity/.snapshots/TestWithDeclassifiedCredentials-case=oidc-credential=oidc.json b/identity/.snapshots/TestWithDeclassifiedCredentials-case=oidc-credential=oidc.json new file mode 100644 index 000000000000..a967e155d02a --- /dev/null +++ b/identity/.snapshots/TestWithDeclassifiedCredentials-case=oidc-credential=oidc.json @@ -0,0 +1,22 @@ +{ + "type": "oidc", + "identifiers": [ + "bar", + "baz" + ], + "config": { + "providers": [ + { + "initial_id_token": "foo", + "initial_access_token": "", + "initial_refresh_token": "", + "subject": "", + "provider": "", + "organization": "" + } + ] + }, + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestWithDeclassifiedCredentials-case=oidc-credential=password.json b/identity/.snapshots/TestWithDeclassifiedCredentials-case=oidc-credential=password.json new file mode 100644 index 000000000000..1939a8fe4f71 --- /dev/null +++ b/identity/.snapshots/TestWithDeclassifiedCredentials-case=oidc-credential=password.json @@ -0,0 +1,10 @@ +{ + "type": "password", + "identifiers": [ + "zab", + "bar" + ], + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/.snapshots/TestWithDeclassifiedCredentials-case=oidc-credential=webauthn.json b/identity/.snapshots/TestWithDeclassifiedCredentials-case=oidc-credential=webauthn.json new file mode 100644 index 000000000000..1b7dcd8f6204 --- /dev/null +++ b/identity/.snapshots/TestWithDeclassifiedCredentials-case=oidc-credential=webauthn.json @@ -0,0 +1,10 @@ +{ + "type": "webauthn", + "identifiers": [ + "foo", + "bar" + ], + "version": 0, + "created_at": "0001-01-01T00:00:00Z", + "updated_at": "0001-01-01T00:00:00Z" +} diff --git a/identity/handler_test.go b/identity/handler_test.go index c28d67638266..1de5f6d0864b 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -364,17 +364,19 @@ func TestHandler(t *testing.T) { identities := res.Array() require.Equal(t, len(identities), listAmount) }) - }) t.Run("suite=create and update", func(t *testing.T) { var i identity.Identity createOidcIdentity := func(t *testing.T, identifier, accessToken, refreshToken, idToken string, encrypt bool) string { - transform := func(token string) string { + transform := func(token, suffix string) string { if !encrypt { return token } - c, err := reg.Cipher(ctx).Encrypt(context.Background(), []byte(token)) + if token == "" { + return "" + } + c, err := reg.Cipher(ctx).Encrypt(context.Background(), []byte(token+suffix)) require.NoError(t, err) return c } @@ -396,16 +398,16 @@ func TestHandler(t *testing.T) { { Subject: "foo", Provider: "bar", - InitialAccessToken: transform(accessToken + "0"), - InitialRefreshToken: transform(refreshToken + "0"), - InitialIDToken: transform(idToken + "0"), + InitialAccessToken: transform(accessToken, "0"), + InitialRefreshToken: transform(refreshToken, "0"), + InitialIDToken: transform(idToken, "0"), }, { Subject: "baz", Provider: "zab", - InitialAccessToken: transform(accessToken + "1"), - InitialRefreshToken: transform(refreshToken + "1"), - InitialIDToken: transform(idToken + "1"), + InitialAccessToken: transform(accessToken, "1"), + InitialRefreshToken: transform(refreshToken, "1"), + InitialIDToken: transform(idToken, "1"), }, }}), }, @@ -537,6 +539,34 @@ func TestHandler(t *testing.T) { } }) + t.Run("case=should not fail on empty tokens", func(t *testing.T) { + id := createOidcIdentity(t, "foo.oidc.empty-tokens@bar.com", "", "", "", true) + for name, ts := range map[string]*httptest.Server{"public": publicTS, "admin": adminTS} { + t.Run("endpoint="+name, func(t *testing.T) { + res := get(t, ts, "/identities/"+id, http.StatusOK) + assert.False(t, res.Get("credentials.oidc.config").Exists(), "credentials config should be omitted: %s", res.Raw) + assert.False(t, res.Get("credentials.password.config").Exists(), "credentials config should be omitted: %s", res.Raw) + + res = get(t, ts, "/identities/"+id+"?include_credential=oidc", http.StatusOK) + assert.True(t, res.Get("credentials").Exists(), "credentials should be included: %s", res.Raw) + assert.True(t, res.Get("credentials.password").Exists(), "password meta should be included: %s", res.Raw) + assert.False(t, res.Get("credentials.password.false").Exists(), "password credentials should not be included: %s", res.Raw) + assert.True(t, res.Get("credentials.oidc.config").Exists(), "oidc credentials should be included: %s", res.Raw) + + assert.EqualValues(t, "foo", res.Get("credentials.oidc.config.providers.0.subject").String(), "credentials should be included: %s", res.Raw) + assert.EqualValues(t, "bar", res.Get("credentials.oidc.config.providers.0.provider").String(), "credentials should be included: %s", res.Raw) + assert.EqualValues(t, "access_token0", res.Get("credentials.oidc.config.providers.0.initial_access_token").String(), "credentials should be included: %s", res.Raw) + assert.EqualValues(t, "refresh_token0", res.Get("credentials.oidc.config.providers.0.initial_refresh_token").String(), "credentials should be included: %s", res.Raw) + assert.EqualValues(t, "id_token0", res.Get("credentials.oidc.config.providers.0.initial_id_token").String(), "credentials should be included: %s", res.Raw) + assert.EqualValues(t, "baz", res.Get("credentials.oidc.config.providers.1.subject").String(), "credentials should be included: %s", res.Raw) + assert.EqualValues(t, "zab", res.Get("credentials.oidc.config.providers.1.provider").String(), "credentials should be included: %s", res.Raw) + assert.EqualValues(t, "access_token1", res.Get("credentials.oidc.config.providers.1.initial_access_token").String(), "credentials should be included: %s", res.Raw) + assert.EqualValues(t, "refresh_token1", res.Get("credentials.oidc.config.providers.1.initial_refresh_token").String(), "credentials should be included: %s", res.Raw) + assert.EqualValues(t, "id_token1", res.Get("credentials.oidc.config.providers.1.initial_id_token").String(), "credentials should be included: %s", res.Raw) + }) + } + }) + t.Run("case=should get identity with credentials", func(t *testing.T) { i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) credentials := map[identity.CredentialsType]identity.Credentials{ diff --git a/identity/identity.go b/identity/identity.go index d763e688a8e3..63839a8fad05 100644 --- a/identity/identity.go +++ b/identity/identity.go @@ -441,49 +441,48 @@ func (i *Identity) WithDeclassifiedCredentials(ctx context.Context, c cipher.Pro toPublish := original toPublish.Config = []byte{} - for _, token := range []string{"initial_id_token", "initial_access_token", "initial_refresh_token"} { - var i int - var err error - gjson.GetBytes(original.Config, "providers").ForEach(func(_, v gjson.Result) bool { + var i int + var err error + gjson.GetBytes(original.Config, "providers").ForEach(func(_, v gjson.Result) bool { + for _, token := range []string{"initial_id_token", "initial_access_token", "initial_refresh_token"} { key := fmt.Sprintf("%d.%s", i, token) ciphertext := v.Get(token).String() var plaintext []byte - plaintext, err = c.Cipher(ctx).Decrypt(ctx, ciphertext) + plaintext, err := c.Cipher(ctx).Decrypt(ctx, ciphertext) if err != nil { - return false + plaintext = []byte("") } - toPublish.Config, err = sjson.SetBytes(toPublish.Config, "providers."+key, string(plaintext)) if err != nil { return false } + } - toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.subject", i), v.Get("subject").String()) - if err != nil { - return false - } - - toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.provider", i), v.Get("provider").String()) - if err != nil { - return false - } - - toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.organization", i), v.Get("organization").String()) - if err != nil { - return false - } + toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.subject", i), v.Get("subject").String()) + if err != nil { + return false + } - i++ - return true - }) + toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.provider", i), v.Get("provider").String()) + if err != nil { + return false + } + toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.organization", i), v.Get("organization").String()) if err != nil { - return nil, err + return false } - credsToPublish[ct] = toPublish + i++ + return true + }) + + if err != nil { + return nil, err } + + credsToPublish[ct] = toPublish default: credsToPublish[ct] = original } diff --git a/identity/identity_test.go b/identity/identity_test.go index b20ee23e1652..726011fd00eb 100644 --- a/identity/identity_test.go +++ b/identity/identity_test.go @@ -5,12 +5,14 @@ package identity import ( "bytes" + "context" "encoding/json" "fmt" "testing" "github.com/ory/x/snapshotx" + "github.com/ory/kratos/cipher" "github.com/ory/kratos/x" "github.com/stretchr/testify/require" @@ -314,6 +316,12 @@ func TestVerifiableAddresses(t *testing.T) { assert.Equal(t, addresses, CollectVerifiableAddresses([]*Identity{id1, id2, id3})) } +type cipherProvider struct{} + +func (c *cipherProvider) Cipher(ctx context.Context) cipher.Cipher { + return cipher.NewNoop(nil) +} + func TestWithDeclassifiedCredentials(t *testing.T) { i := NewIdentity(config.DefaultIdentityTraitsSchemaID) credentials := map[CredentialsType]Credentials{ @@ -325,7 +333,7 @@ func TestWithDeclassifiedCredentials(t *testing.T) { CredentialsTypeOIDC: { Type: CredentialsTypeOIDC, Identifiers: []string{"bar", "baz"}, - Config: sqlxx.JSONRawMessage("{\"some\" : \"secret\"}"), + Config: sqlxx.JSONRawMessage(`{"providers": [{"initial_id_token": "666f6f"}]}`), }, CredentialsTypeWebAuthn: { Type: CredentialsTypeWebAuthn, @@ -336,7 +344,7 @@ func TestWithDeclassifiedCredentials(t *testing.T) { i.Credentials = credentials t.Run("case=no-include", func(t *testing.T) { - actualIdentity, err := i.WithDeclassifiedCredentials(ctx, nil, nil) + actualIdentity, err := i.WithDeclassifiedCredentials(ctx, &cipherProvider{}, nil) require.NoError(t, err) for ct, actual := range actualIdentity.Credentials { @@ -347,7 +355,7 @@ func TestWithDeclassifiedCredentials(t *testing.T) { }) t.Run("case=include-webauthn", func(t *testing.T) { - actualIdentity, err := i.WithDeclassifiedCredentials(ctx, nil, []CredentialsType{CredentialsTypeWebAuthn}) + actualIdentity, err := i.WithDeclassifiedCredentials(ctx, &cipherProvider{}, []CredentialsType{CredentialsTypeWebAuthn}) require.NoError(t, err) for ct, actual := range actualIdentity.Credentials { @@ -358,7 +366,18 @@ func TestWithDeclassifiedCredentials(t *testing.T) { }) t.Run("case=include-multi", func(t *testing.T) { - actualIdentity, err := i.WithDeclassifiedCredentials(ctx, nil, []CredentialsType{CredentialsTypeWebAuthn, CredentialsTypePassword}) + actualIdentity, err := i.WithDeclassifiedCredentials(ctx, &cipherProvider{}, []CredentialsType{CredentialsTypeWebAuthn, CredentialsTypePassword}) + require.NoError(t, err) + + for ct, actual := range actualIdentity.Credentials { + t.Run("credential="+string(ct), func(t *testing.T) { + snapshotx.SnapshotT(t, actual) + }) + } + }) + + t.Run("case=oidc", func(t *testing.T) { + actualIdentity, err := i.WithDeclassifiedCredentials(ctx, &cipherProvider{}, []CredentialsType{CredentialsTypeOIDC}) require.NoError(t, err) for ct, actual := range actualIdentity.Credentials {