Skip to content

Commit

Permalink
fix: identity schema fallback URL
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Apr 30, 2024
1 parent 3ecdf2b commit 6d32f5e
Show file tree
Hide file tree
Showing 14 changed files with 81 additions and 51 deletions.
18 changes: 14 additions & 4 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ const (
ViperKeySelfServiceVerificationNotifyUnknownRecipients = "selfservice.flows.verification.notify_unknown_recipients"
ViperKeyDefaultIdentitySchemaID = "identity.default_schema_id"
ViperKeyIdentitySchemas = "identity.schemas"
ViperKeyIdentitySchemaFallbackTemplate = "identity.schema_fallback_template"
ViperKeyHasherAlgorithm = "hashers.algorithm"
ViperKeyHasherArgon2ConfigMemory = "hashers.argon2.memory"
ViperKeyHasherArgon2ConfigIterations = "hashers.argon2.iterations"
Expand Down Expand Up @@ -356,14 +357,16 @@ func HookStrategyKey(key, strategy string) string {
}
}

var ErrSchemaNotFound = errors.New("schema not found")

func (s Schemas) FindSchemaByID(id string) (*Schema, error) {
for _, sc := range s {
if sc.ID == id {
return &sc, nil
}
}

return nil, errors.Errorf("unable to find identity schema with id: %s", id)
return nil, errors.Wrapf(ErrSchemaNotFound, "schema with id %s not found", id)
}

func MustNew(t testing.TB, l *logrusx.Logger, stdOutOrErr io.Writer, opts ...configx.OptionModifier) *Config {
Expand Down Expand Up @@ -467,7 +470,7 @@ func (p *Config) validateIdentitySchemas(ctx context.Context) error {
return err
}

ss, err := p.IdentityTraitsSchemas(ctx)
ss, err := p.ConfiguredSchemas(ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -571,7 +574,7 @@ func (p *Config) listenOn(ctx context.Context, key string) string {
}

func (p *Config) DefaultIdentityTraitsSchemaURL(ctx context.Context) (*url.URL, error) {
ss, err := p.IdentityTraitsSchemas(ctx)
ss, err := p.ConfiguredSchemas(ctx)
if err != nil {
return nil, err
}
Expand All @@ -597,14 +600,21 @@ func (p *Config) OIDCRedirectURIBase(ctx context.Context) *url.URL {
return p.GetProvider(ctx).URIF(ViperKeyOIDCBaseRedirectURL, p.SelfPublicURL(ctx))
}

func (p *Config) IdentityTraitsSchemas(ctx context.Context) (ss Schemas, err error) {
// ConfiguredSchemas returns the identity traits schemas.
//
// Please use driver.Registry.IdentityTraitsSchemas() instead.
func (p *Config) ConfiguredSchemas(ctx context.Context) (ss Schemas, err error) {
if err = p.GetProvider(ctx).Koanf.Unmarshal(ViperKeyIdentitySchemas, &ss); err != nil {
return ss, nil
}

return ss, nil
}

func (p *Config) IdentityTraitsSchemaFallback(ctx context.Context) (fallback string) {
return p.GetProvider(ctx).StringF(ViperKeyIdentitySchemaFallbackTemplate, "")
}

func (p *Config) AdminListenOn(ctx context.Context) string {
return p.listenOn(ctx, "admin")
}
Expand Down
2 changes: 1 addition & 1 deletion driver/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func TestViperProvider(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, "http://test.kratos.ory.sh/default-identity.schema.json", ds.String())

ss, err := c.IdentityTraitsSchemas(ctx)
ss, err := c.ConfiguredSchemas(ctx)
require.NoError(t, err)
assert.Equal(t, 2, len(ss))

Expand Down
2 changes: 1 addition & 1 deletion driver/registry_default_schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

func (m *RegistryDefault) IdentityTraitsSchemas(ctx context.Context) (schema.Schemas, error) {
ms, err := m.Config().IdentityTraitsSchemas(ctx)
ms, err := m.Config().ConfiguredSchemas(ctx)
if err != nil {
return nil, err
}
Expand Down
4 changes: 4 additions & 0 deletions embedx/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2502,6 +2502,10 @@
},
"required": ["id", "url"]
}
},
"schema_fallback_template": {
"title": "Identity Schema Fallback Template",
"type": "string"
}
},
"required": ["schemas"],
Expand Down
2 changes: 1 addition & 1 deletion identity/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (v *Validator) ValidateWithRunner(ctx context.Context, i *Identity, runners
return err
}

s, err := ss.GetByID(i.SchemaID)
s, err := ss.GetByID(i.SchemaID, v.d.Config().IdentityTraitsSchemaFallback(ctx))
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/testhelpers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func SetDefaultIdentitySchema(conf *config.Config, url string) func() {
// It also registeres a test cleanup function, to reset the schemas to the original values, after the test finishes
func UseIdentitySchema(t *testing.T, conf *config.Config, url string) (id string) {
id = randx.MustString(16, randx.Alpha)
schemas, err := conf.IdentityTraitsSchemas(context.Background())
schemas, err := conf.ConfiguredSchemas(context.Background())
require.NoError(t, err)
conf.MustSet(context.Background(), config.ViperKeyIdentitySchemas, append(schemas, config.Schema{
ID: id,
Expand Down
27 changes: 5 additions & 22 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -922,17 +922,11 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.
return nil, nil, err
}

schemaCache := map[string]string{}
for k := range is {
i := &is[k]

if u, ok := schemaCache[i.SchemaID]; ok {
i.SchemaURL = u
} else {
if err := p.InjectTraitsSchemaURL(ctx, i); err != nil {
return nil, nil, err
}
schemaCache[i.SchemaID] = i.SchemaURL
if err := p.InjectTraitsSchemaURL(ctx, i); err != nil {
return nil, nil, err

Check warning on line 929 in persistence/sql/identity/persister_identity.go

View check run for this annotation

Codecov / codecov/patch

persistence/sql/identity/persister_identity.go#L929

Added line #L929 was not covered by tests
}

if err := i.Validate(); err != nil {
Expand Down Expand Up @@ -1133,20 +1127,9 @@ func (p *IdentityPersister) validateIdentity(ctx context.Context, i *identity.Id
return nil
}

// InjectTraitsSchemaURL sets the schema URL on the identity. The schema URL is the one hosted by Ory Kratos, and not the actual
// schema URL, which might very well be a base64 encoded string, or an internal URL. The Schema URL is not used internally, but only exposed when the identity is sent over the REST API.
func (p *IdentityPersister) InjectTraitsSchemaURL(ctx context.Context, i *identity.Identity) (err error) {
// This trace is more noisy than it's worth in diagnostic power.
// ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InjectTraitsSchemaURL")
// defer otelx.End(span, &err)

ss, err := p.r.IdentityTraitsSchemas(ctx)
if err != nil {
return err
}
s, err := ss.GetByID(i.SchemaID)
if err != nil {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf(
`The JSON Schema "%s" for this identity's traits could not be found.`, i.SchemaID))
}
i.SchemaURL = s.SchemaURL(p.r.Config().SelfPublicURL(ctx)).String()
i.SchemaURL = schema.IDToURL(p.r.Config().SelfPublicURL(ctx), i.SchemaID).String()
return nil
}
8 changes: 5 additions & 3 deletions schema/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,16 @@ func (h *Handler) getIdentitySchema(w http.ResponseWriter, r *http.Request, ps h
}

id := ps.ByName("id")
s, err := ss.GetByID(id)

// The first attempt should use no fallback template.
s, err := ss.GetByID(id, "")
if err != nil {
// Maybe it is a base64 encoded ID?
// schema.SchemaURL base64 encodes the schema ID in order for it to work with httprouter if the schema ID contains slashes.
if dec, err := base64.RawURLEncoding.DecodeString(id); err == nil {
id = string(dec)
}

s, err = ss.GetByID(id)
s, err = ss.GetByID(id, h.r.Config().IdentityTraitsSchemaFallback(ctx))
if err != nil {
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrNotFound.WithReasonf("Identity schema `%s` could not be found.", id)))
return
Expand Down
2 changes: 1 addition & 1 deletion schema/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func TestHandler(t *testing.T) {
}

getSchemaById := func(id string) *schema.Schema {
s, err := schemas.GetByID(id)
s, err := schemas.GetByID(id, "")
require.NoError(t, err)
return s
}
Expand Down
31 changes: 26 additions & 5 deletions schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package schema
import (
"context"
"encoding/base64"
"fmt"
"io"
"net/url"
"strings"
Expand All @@ -26,7 +27,7 @@ type IdentityTraitsProvider interface {
IdentityTraitsSchemas(ctx context.Context) (Schemas, error)
}

func (s Schemas) GetByID(id string) (*Schema, error) {
func (s Schemas) GetByID(id string, fallbackTemplate string) (*Schema, error) {
if id == "" {
id = config.DefaultIdentityTraitsSchemaID
}
Expand All @@ -37,6 +38,21 @@ func (s Schemas) GetByID(id string) (*Schema, error) {
}
}

if fallbackTemplate != "" {
source := fmt.Sprintf(fallbackTemplate, id)

parsedURL, err := url.Parse(source)
if err != nil {
return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Unable to parse identity schema fallback tempalte. Please contact the site's administrator."))

Check warning on line 46 in schema/schema.go

View check run for this annotation

Codecov / codecov/patch

schema/schema.go#L46

Added line #L46 was not covered by tests
}

return &Schema{
ID: id,
URL: parsedURL,
RawURL: fmt.Sprintf(fallbackTemplate, id),
}, nil
}

return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Unable to find JSON Schema ID: %s", id))
}

Expand Down Expand Up @@ -98,11 +114,16 @@ func GetKeysInOrder(ctx context.Context, schemaRef string) ([]string, error) {
}

type Schema struct {
ID string `json:"id"`
URL *url.URL `json:"-"`
RawURL string `json:"url"`
ID string `json:"id"`
URL *url.URL `json:"-"`
// RawURL contains the raw URL value as it was passed in the configuration. URL parsing can break base64 encoded URLs.
RawURL string `json:"url"`
}

func (s *Schema) SchemaURL(host *url.URL) *url.URL {
return urlx.AppendPaths(host, SchemasPath, base64.RawURLEncoding.EncodeToString([]byte(s.ID)))
return IDToURL(host, s.ID)
}

func IDToURL(host *url.URL, id string) *url.URL {
return urlx.AppendPaths(host, SchemasPath, base64.RawURLEncoding.EncodeToString([]byte(id)))
}
22 changes: 16 additions & 6 deletions schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,37 +40,47 @@ func TestSchemas_GetByID(t *testing.T) {
}

t.Run("case=get first schema", func(t *testing.T) {
s, err := ss.GetByID("foo")
s, err := ss.GetByID("foo", "")
require.NoError(t, err)
assert.Equal(t, &ss[0], s)
})

t.Run("case=get second schema", func(t *testing.T) {
s, err := ss.GetByID("bar")
s, err := ss.GetByID("bar", "")
require.NoError(t, err)
assert.Equal(t, &ss[1], s)
})

t.Run("case=get third schema", func(t *testing.T) {
s, err := ss.GetByID("foobar")
s, err := ss.GetByID("foobar", "")
require.NoError(t, err)
assert.Equal(t, &ss[2], s)
})

t.Run("case=get default schema", func(t *testing.T) {
s1, err := ss.GetByID("")
s1, err := ss.GetByID("", "")
require.NoError(t, err)
s2, err := ss.GetByID(config.DefaultIdentityTraitsSchemaID)
s2, err := ss.GetByID(config.DefaultIdentityTraitsSchemaID, "")
require.NoError(t, err)
assert.Equal(t, &ss[3], s1)
assert.Equal(t, &ss[3], s2)
})

t.Run("case=should return error on not existing id", func(t *testing.T) {
s, err := ss.GetByID("not existing id")
s, err := ss.GetByID("not existing id", "")
require.Error(t, err)
assert.Equal(t, (*Schema)(nil), s)
})

t.Run("case=should return fallback", func(t *testing.T) {
s, err := ss.GetByID("fallback", "https://%s/")
require.NoError(t, err)
assert.Equal(t, &Schema{
ID: "fallback",
URL: urlx.ParseOrPanic("https://fallback/"),
RawURL: "https://fallback/",
}, s)
})
}

func TestSchemas_List(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/settings/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func (s *ErrorHandler) WriteFlowError(
return
}

schema, err := schemas.GetByID(id.SchemaID)
schema, err := schemas.GetByID(id.SchemaID, s.d.Config().IdentityTraitsSchemaFallback(r.Context()))
if err != nil {
s.forward(w, r, f, err)
return
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/code/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func (s *Strategy) populateChooseMethodFlow(r *http.Request, f flow.Flow) error
if err != nil {
return err
}
iSchema, err := allSchemas.GetByID(sess.Identity.SchemaID)
iSchema, err := allSchemas.GetByID(sess.Identity.SchemaID, s.deps.Config().IdentityTraitsSchemaFallback(r.Context()))
if err != nil {
return err
}
Expand Down
8 changes: 4 additions & 4 deletions selfservice/strategy/profile/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,19 @@ func (s *Strategy) SettingsStrategyID() string {
func (s *Strategy) RegisterSettingsRoutes(public *x.RouterPublic) {}

func (s *Strategy) PopulateSettingsMethod(r *http.Request, id *identity.Identity, f *settings.Flow) error {
schemas, err := s.d.Config().IdentityTraitsSchemas(r.Context())
schemas, err := s.d.IdentityTraitsSchemas(r.Context())
if err != nil {
return err
}

traitsSchema, err := schemas.FindSchemaByID(id.SchemaID)
traitsSchema, err := schemas.GetByID(id.SchemaID, s.d.Config().IdentityTraitsSchemaFallback(r.Context()))
if err != nil {
return err
}

// use a schema compiler that disables identifiers
schemaCompiler := jsonschema.NewCompiler()
nodes, err := container.NodesFromJSONSchema(r.Context(), node.ProfileGroup, traitsSchema.URL, "", schemaCompiler)
nodes, err := container.NodesFromJSONSchema(r.Context(), node.ProfileGroup, traitsSchema.URL.String(), "", schemaCompiler)
if err != nil {
return err
}
Expand Down Expand Up @@ -270,7 +270,7 @@ func (s *Strategy) newSettingsProfileDecoder(ctx context.Context, i *identity.Id
if err != nil {
return nil, err
}
ss, err := schemas.GetByID(i.SchemaID)
ss, err := schemas.GetByID(i.SchemaID, s.d.Config().IdentityTraitsSchemaFallback(ctx))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 6d32f5e

Please sign in to comment.