diff --git a/driver/config/config.go b/driver/config/config.go index 0d755a11ba63..2b0380677dd8 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -162,6 +162,7 @@ const ( ViperKeySelfServiceVerificationNotifyUnknownRecipients = "selfservice.flows.verification.notify_unknown_recipients" ViperKeyDefaultIdentitySchemaID = "identity.default_schema_id" ViperKeyIdentitySchemas = "identity.schemas" + ViperKeyIdentitySchemaFallbackTemplate = "identity.schema_fallback_template" ViperKeyHasherAlgorithm = "hashers.algorithm" ViperKeyHasherArgon2ConfigMemory = "hashers.argon2.memory" ViperKeyHasherArgon2ConfigIterations = "hashers.argon2.iterations" @@ -356,6 +357,8 @@ func HookStrategyKey(key, strategy string) string { } } +var ErrSchemaNotFound = errors.New("schema not found") + func (s Schemas) FindSchemaByID(id string) (*Schema, error) { for _, sc := range s { if sc.ID == id { @@ -363,7 +366,7 @@ func (s Schemas) FindSchemaByID(id string) (*Schema, error) { } } - return nil, errors.Errorf("unable to find identity schema with id: %s", id) + return nil, errors.Wrapf(ErrSchemaNotFound, "schema with id %s not found", id) } func MustNew(t testing.TB, l *logrusx.Logger, stdOutOrErr io.Writer, opts ...configx.OptionModifier) *Config { @@ -467,7 +470,7 @@ func (p *Config) validateIdentitySchemas(ctx context.Context) error { return err } - ss, err := p.IdentityTraitsSchemas(ctx) + ss, err := p.ConfiguredSchemas(ctx) if err != nil { return err } @@ -571,7 +574,7 @@ func (p *Config) listenOn(ctx context.Context, key string) string { } func (p *Config) DefaultIdentityTraitsSchemaURL(ctx context.Context) (*url.URL, error) { - ss, err := p.IdentityTraitsSchemas(ctx) + ss, err := p.ConfiguredSchemas(ctx) if err != nil { return nil, err } @@ -597,7 +600,10 @@ func (p *Config) OIDCRedirectURIBase(ctx context.Context) *url.URL { return p.GetProvider(ctx).URIF(ViperKeyOIDCBaseRedirectURL, p.SelfPublicURL(ctx)) } -func (p *Config) IdentityTraitsSchemas(ctx context.Context) (ss Schemas, err error) { +// ConfiguredSchemas returns the identity traits schemas. +// +// Please use driver.Registry.IdentityTraitsSchemas() instead. +func (p *Config) ConfiguredSchemas(ctx context.Context) (ss Schemas, err error) { if err = p.GetProvider(ctx).Koanf.Unmarshal(ViperKeyIdentitySchemas, &ss); err != nil { return ss, nil } @@ -605,6 +611,10 @@ func (p *Config) IdentityTraitsSchemas(ctx context.Context) (ss Schemas, err err return ss, nil } +func (p *Config) IdentityTraitsSchemaFallback(ctx context.Context) (fallback string) { + return p.GetProvider(ctx).StringF(ViperKeyIdentitySchemaFallbackTemplate, "") +} + func (p *Config) AdminListenOn(ctx context.Context) string { return p.listenOn(ctx, "admin") } diff --git a/driver/config/config_test.go b/driver/config/config_test.go index 6cb37f100850..c72d4cd9c3b8 100644 --- a/driver/config/config_test.go +++ b/driver/config/config_test.go @@ -168,7 +168,7 @@ func TestViperProvider(t *testing.T) { require.NoError(t, err) assert.Equal(t, "http://test.kratos.ory.sh/default-identity.schema.json", ds.String()) - ss, err := c.IdentityTraitsSchemas(ctx) + ss, err := c.ConfiguredSchemas(ctx) require.NoError(t, err) assert.Equal(t, 2, len(ss)) diff --git a/driver/registry_default_schemas.go b/driver/registry_default_schemas.go index 9f68fbd2d86e..45ea6e028d0d 100644 --- a/driver/registry_default_schemas.go +++ b/driver/registry_default_schemas.go @@ -13,7 +13,7 @@ import ( ) func (m *RegistryDefault) IdentityTraitsSchemas(ctx context.Context) (schema.Schemas, error) { - ms, err := m.Config().IdentityTraitsSchemas(ctx) + ms, err := m.Config().ConfiguredSchemas(ctx) if err != nil { return nil, err } diff --git a/embedx/config.schema.json b/embedx/config.schema.json index 0a94c54dbb23..4c418e0f91fb 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -2502,6 +2502,10 @@ }, "required": ["id", "url"] } + }, + "schema_fallback_template": { + "title": "Identity Schema Fallback Template", + "type": "string" } }, "required": ["schemas"], diff --git a/identity/validator.go b/identity/validator.go index 3bd8f9476fa5..99bd26208348 100644 --- a/identity/validator.go +++ b/identity/validator.go @@ -46,7 +46,7 @@ func (v *Validator) ValidateWithRunner(ctx context.Context, i *Identity, runners return err } - s, err := ss.GetByID(i.SchemaID) + s, err := ss.GetByID(i.SchemaID, v.d.Config().IdentityTraitsSchemaFallback(ctx)) if err != nil { return err } diff --git a/internal/testhelpers/config.go b/internal/testhelpers/config.go index 2a24709c0745..4efae159b712 100644 --- a/internal/testhelpers/config.go +++ b/internal/testhelpers/config.go @@ -42,7 +42,7 @@ func SetDefaultIdentitySchema(conf *config.Config, url string) func() { // It also registeres a test cleanup function, to reset the schemas to the original values, after the test finishes func UseIdentitySchema(t *testing.T, conf *config.Config, url string) (id string) { id = randx.MustString(16, randx.Alpha) - schemas, err := conf.IdentityTraitsSchemas(context.Background()) + schemas, err := conf.ConfiguredSchemas(context.Background()) require.NoError(t, err) conf.MustSet(context.Background(), config.ViperKeyIdentitySchemas, append(schemas, config.Schema{ ID: id, diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 52daaf2fc71b..9b8c12c15378 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -922,17 +922,11 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. return nil, nil, err } - schemaCache := map[string]string{} for k := range is { i := &is[k] - if u, ok := schemaCache[i.SchemaID]; ok { - i.SchemaURL = u - } else { - if err := p.InjectTraitsSchemaURL(ctx, i); err != nil { - return nil, nil, err - } - schemaCache[i.SchemaID] = i.SchemaURL + if err := p.InjectTraitsSchemaURL(ctx, i); err != nil { + return nil, nil, err } if err := i.Validate(); err != nil { @@ -1133,20 +1127,9 @@ func (p *IdentityPersister) validateIdentity(ctx context.Context, i *identity.Id return nil } +// InjectTraitsSchemaURL sets the schema URL on the identity. The schema URL is the one hosted by Ory Kratos, and not the actual +// schema URL, which might very well be a base64 encoded string, or an internal URL. The Schema URL is not used internally, but only exposed when the identity is sent over the REST API. func (p *IdentityPersister) InjectTraitsSchemaURL(ctx context.Context, i *identity.Identity) (err error) { - // This trace is more noisy than it's worth in diagnostic power. - // ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InjectTraitsSchemaURL") - // defer otelx.End(span, &err) - - ss, err := p.r.IdentityTraitsSchemas(ctx) - if err != nil { - return err - } - s, err := ss.GetByID(i.SchemaID) - if err != nil { - return errors.WithStack(herodot.ErrInternalServerError.WithReasonf( - `The JSON Schema "%s" for this identity's traits could not be found.`, i.SchemaID)) - } - i.SchemaURL = s.SchemaURL(p.r.Config().SelfPublicURL(ctx)).String() + i.SchemaURL = schema.IDToURL(p.r.Config().SelfPublicURL(ctx), i.SchemaID).String() return nil } diff --git a/schema/handler.go b/schema/handler.go index e154ae75d699..4abfd78c0d83 100644 --- a/schema/handler.go +++ b/schema/handler.go @@ -115,14 +115,16 @@ func (h *Handler) getIdentitySchema(w http.ResponseWriter, r *http.Request, ps h } id := ps.ByName("id") - s, err := ss.GetByID(id) + + // The first attempt should use no fallback template. + s, err := ss.GetByID(id, "") if err != nil { - // Maybe it is a base64 encoded ID? + // schema.SchemaURL base64 encodes the schema ID in order for it to work with httprouter if the schema ID contains slashes. if dec, err := base64.RawURLEncoding.DecodeString(id); err == nil { id = string(dec) } - s, err = ss.GetByID(id) + s, err = ss.GetByID(id, h.r.Config().IdentityTraitsSchemaFallback(ctx)) if err != nil { h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrNotFound.WithReasonf("Identity schema `%s` could not be found.", id))) return diff --git a/schema/handler_test.go b/schema/handler_test.go index a5bdf06a5a83..d2a625eda32c 100644 --- a/schema/handler_test.go +++ b/schema/handler_test.go @@ -76,7 +76,7 @@ func TestHandler(t *testing.T) { } getSchemaById := func(id string) *schema.Schema { - s, err := schemas.GetByID(id) + s, err := schemas.GetByID(id, "") require.NoError(t, err) return s } diff --git a/schema/schema.go b/schema/schema.go index 69b6bbca7332..7c84e88f0a9b 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -6,6 +6,7 @@ package schema import ( "context" "encoding/base64" + "fmt" "io" "net/url" "strings" @@ -26,7 +27,7 @@ type IdentityTraitsProvider interface { IdentityTraitsSchemas(ctx context.Context) (Schemas, error) } -func (s Schemas) GetByID(id string) (*Schema, error) { +func (s Schemas) GetByID(id string, fallbackTemplate string) (*Schema, error) { if id == "" { id = config.DefaultIdentityTraitsSchemaID } @@ -37,6 +38,21 @@ func (s Schemas) GetByID(id string) (*Schema, error) { } } + if fallbackTemplate != "" { + source := fmt.Sprintf(fallbackTemplate, id) + + parsedURL, err := url.Parse(source) + if err != nil { + return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Unable to parse identity schema fallback tempalte. Please contact the site's administrator.")) + } + + return &Schema{ + ID: id, + URL: parsedURL, + RawURL: fmt.Sprintf(fallbackTemplate, id), + }, nil + } + return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Unable to find JSON Schema ID: %s", id)) } @@ -98,11 +114,16 @@ func GetKeysInOrder(ctx context.Context, schemaRef string) ([]string, error) { } type Schema struct { - ID string `json:"id"` - URL *url.URL `json:"-"` - RawURL string `json:"url"` + ID string `json:"id"` + URL *url.URL `json:"-"` + // RawURL contains the raw URL value as it was passed in the configuration. URL parsing can break base64 encoded URLs. + RawURL string `json:"url"` } func (s *Schema) SchemaURL(host *url.URL) *url.URL { - return urlx.AppendPaths(host, SchemasPath, base64.RawURLEncoding.EncodeToString([]byte(s.ID))) + return IDToURL(host, s.ID) +} + +func IDToURL(host *url.URL, id string) *url.URL { + return urlx.AppendPaths(host, SchemasPath, base64.RawURLEncoding.EncodeToString([]byte(id))) } diff --git a/schema/schema_test.go b/schema/schema_test.go index bb77fdd55a2e..71432676bb31 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -40,37 +40,47 @@ func TestSchemas_GetByID(t *testing.T) { } t.Run("case=get first schema", func(t *testing.T) { - s, err := ss.GetByID("foo") + s, err := ss.GetByID("foo", "") require.NoError(t, err) assert.Equal(t, &ss[0], s) }) t.Run("case=get second schema", func(t *testing.T) { - s, err := ss.GetByID("bar") + s, err := ss.GetByID("bar", "") require.NoError(t, err) assert.Equal(t, &ss[1], s) }) t.Run("case=get third schema", func(t *testing.T) { - s, err := ss.GetByID("foobar") + s, err := ss.GetByID("foobar", "") require.NoError(t, err) assert.Equal(t, &ss[2], s) }) t.Run("case=get default schema", func(t *testing.T) { - s1, err := ss.GetByID("") + s1, err := ss.GetByID("", "") require.NoError(t, err) - s2, err := ss.GetByID(config.DefaultIdentityTraitsSchemaID) + s2, err := ss.GetByID(config.DefaultIdentityTraitsSchemaID, "") require.NoError(t, err) assert.Equal(t, &ss[3], s1) assert.Equal(t, &ss[3], s2) }) t.Run("case=should return error on not existing id", func(t *testing.T) { - s, err := ss.GetByID("not existing id") + s, err := ss.GetByID("not existing id", "") require.Error(t, err) assert.Equal(t, (*Schema)(nil), s) }) + + t.Run("case=should return fallback", func(t *testing.T) { + s, err := ss.GetByID("fallback", "https://%s/") + require.NoError(t, err) + assert.Equal(t, &Schema{ + ID: "fallback", + URL: urlx.ParseOrPanic("https://fallback/"), + RawURL: "https://fallback/", + }, s) + }) } func TestSchemas_List(t *testing.T) { diff --git a/selfservice/flow/settings/error.go b/selfservice/flow/settings/error.go index 2fd190c888e9..96112c1068ba 100644 --- a/selfservice/flow/settings/error.go +++ b/selfservice/flow/settings/error.go @@ -221,7 +221,7 @@ func (s *ErrorHandler) WriteFlowError( return } - schema, err := schemas.GetByID(id.SchemaID) + schema, err := schemas.GetByID(id.SchemaID, s.d.Config().IdentityTraitsSchemaFallback(r.Context())) if err != nil { s.forward(w, r, f, err) return diff --git a/selfservice/strategy/code/strategy.go b/selfservice/strategy/code/strategy.go index fd8993447744..ffba87edf57d 100644 --- a/selfservice/strategy/code/strategy.go +++ b/selfservice/strategy/code/strategy.go @@ -213,7 +213,7 @@ func (s *Strategy) populateChooseMethodFlow(r *http.Request, f flow.Flow) error if err != nil { return err } - iSchema, err := allSchemas.GetByID(sess.Identity.SchemaID) + iSchema, err := allSchemas.GetByID(sess.Identity.SchemaID, s.deps.Config().IdentityTraitsSchemaFallback(r.Context())) if err != nil { return err } diff --git a/selfservice/strategy/profile/strategy.go b/selfservice/strategy/profile/strategy.go index 0942d738840e..095c29d4da84 100644 --- a/selfservice/strategy/profile/strategy.go +++ b/selfservice/strategy/profile/strategy.go @@ -81,19 +81,19 @@ func (s *Strategy) SettingsStrategyID() string { func (s *Strategy) RegisterSettingsRoutes(public *x.RouterPublic) {} func (s *Strategy) PopulateSettingsMethod(r *http.Request, id *identity.Identity, f *settings.Flow) error { - schemas, err := s.d.Config().IdentityTraitsSchemas(r.Context()) + schemas, err := s.d.IdentityTraitsSchemas(r.Context()) if err != nil { return err } - traitsSchema, err := schemas.FindSchemaByID(id.SchemaID) + traitsSchema, err := schemas.GetByID(id.SchemaID, s.d.Config().IdentityTraitsSchemaFallback(r.Context())) if err != nil { return err } // use a schema compiler that disables identifiers schemaCompiler := jsonschema.NewCompiler() - nodes, err := container.NodesFromJSONSchema(r.Context(), node.ProfileGroup, traitsSchema.URL, "", schemaCompiler) + nodes, err := container.NodesFromJSONSchema(r.Context(), node.ProfileGroup, traitsSchema.URL.String(), "", schemaCompiler) if err != nil { return err } @@ -270,7 +270,7 @@ func (s *Strategy) newSettingsProfileDecoder(ctx context.Context, i *identity.Id if err != nil { return nil, err } - ss, err := schemas.GetByID(i.SchemaID) + ss, err := schemas.GetByID(i.SchemaID, s.d.Config().IdentityTraitsSchemaFallback(ctx)) if err != nil { return nil, err }