diff --git a/go.mod b/go.mod index 40e47ac..c4e6b68 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.21 require ( github.com/fxamacker/cbor/v2 v2.5.0 - go.arcalot.io/assert v1.7.0 + go.arcalot.io/assert v1.8.0 go.arcalot.io/log/v2 v2.1.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index b261836..28ca8e4 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADi github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= -go.arcalot.io/assert v1.7.0 h1:PTLyeisNMUKpM9wXRDxResanBhuGOYO1xFK3v5b3FSw= -go.arcalot.io/assert v1.7.0/go.mod h1:nNmWPoNUHFyrPkNrD2aASm5yPuAfiWdB/4X7Lw3ykHk= +go.arcalot.io/assert v1.8.0 h1:hGcHMPncQXwQvjj7MbyOu2gg8VIBB00crUJZpeQOjxs= +go.arcalot.io/assert v1.8.0/go.mod h1:nNmWPoNUHFyrPkNrD2aASm5yPuAfiWdB/4X7Lw3ykHk= go.arcalot.io/log/v2 v2.1.0 h1:lNO931hJ82LgS6WcCFCxpLWXQXPFhOkz6PyAJ/augq4= go.arcalot.io/log/v2 v2.1.0/go.mod h1:PNWOSkkPmgS2OMlWTIlB/WqOw0yaBvDYd8ENAP80H4k= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/schema/oneof.go b/schema/oneof.go index 9a71d66..24c201f 100644 --- a/schema/oneof.go +++ b/schema/oneof.go @@ -2,6 +2,7 @@ package schema import ( "fmt" + "maps" "reflect" "strings" ) @@ -18,6 +19,8 @@ type OneOfSchema[KeyType int64 | string] struct { interfaceType reflect.Type TypesValue map[KeyType]Object `json:"types"` DiscriminatorFieldNameValue string `json:"discriminator_field_name"` + // whether or not the discriminator is inlined in the underlying objects' schema + DiscriminatorInlined bool `json:"discriminator_inlined"` } func (o OneOfSchema[KeyType]) TypeID() TypeID { @@ -44,6 +47,11 @@ func (o OneOfSchema[KeyType]) ApplyScope(scope Scope) { for _, t := range o.TypesValue { t.ApplyScope(scope) } + // scope must be applied before we can access the subtypes' properties + err := o.validateSubtypeDiscriminatorInlineFields() + if err != nil { + panic(err) + } } func (o OneOfSchema[KeyType]) ReflectedType() reflect.Type { @@ -56,11 +64,14 @@ func (o OneOfSchema[KeyType]) ReflectedType() reflect.Type { //nolint:funlen func (o OneOfSchema[KeyType]) UnserializeType(data any) (result any, err error) { + if data == nil { + return nil, fmt.Errorf("bug: data is nil in OneOfSchema UnserializeType") + } reflectedValue := reflect.ValueOf(data) if reflectedValue.Kind() != reflect.Map { return result, &ConstraintError{ Message: fmt.Sprintf( - "Invalid type for one-of type: '%s'. Expected map.", + "Invalid type for one-of type: %q. Expected map.", reflect.TypeOf(data).Name(), ), } @@ -77,6 +88,7 @@ func (o OneOfSchema[KeyType]) UnserializeType(data any) (result any, err error) if err != nil { return result, err } + typedData := make(map[string]any, reflectedValue.Len()) for _, k := range reflectedValue.MapKeys() { v := reflectedValue.MapIndex(k) @@ -102,25 +114,24 @@ func (o OneOfSchema[KeyType]) UnserializeType(data any) (result any, err error) } return result, &ConstraintError{ Message: fmt.Sprintf( - "Invalid value for '%s', expected one of: %s", + "Invalid value for %q, expected one of: %s", o.DiscriminatorFieldNameValue, strings.Join(validDiscriminators, ", "), ), } } - if _, ok := selectedType.Properties()[o.DiscriminatorFieldNameValue]; !ok { - delete(typedData, o.DiscriminatorFieldNameValue) - } - - unserializedData, err := selectedType.Unserialize(typedData) + cloneData := o.deleteDiscriminator(typedData) + unserializedData, err := selectedType.Unserialize(cloneData) if err != nil { return result, err } - if o.interfaceType == nil { - return unserializedData, nil + unserializedMap, ok := unserializedData.(map[string]any) + if ok { + unserializedMap[o.DiscriminatorFieldNameValue] = discriminator + return unserializedMap, nil } - return saveConvertTo(unserializedData, o.interfaceType) + return saveConvertTo(unserializedData, o.ReflectedType()) } func (o OneOfSchema[KeyType]) ValidateType(data any) error { @@ -128,6 +139,10 @@ func (o OneOfSchema[KeyType]) ValidateType(data any) error { if err != nil { return err } + dataMap, ok := data.(map[string]any) + if ok { + data = o.deleteDiscriminator(dataMap) + } if err := underlyingType.Validate(data); err != nil { return ConstraintErrorAddPathSegment(err, fmt.Sprintf("{oneof[%v]}", discriminatorValue)) } @@ -139,6 +154,10 @@ func (o OneOfSchema[KeyType]) SerializeType(data any) (any, error) { if err != nil { return nil, err } + dataMap, ok := data.(map[string]any) + if ok { + data = o.deleteDiscriminator(dataMap) + } serializedData, err := underlyingType.Serialize(data) if err != nil { return nil, err @@ -162,7 +181,8 @@ func (o OneOfSchema[KeyType]) ValidateCompatibility(typeOrData any) error { // If not, verify it as data. inputAsMap, ok := typeOrData.(map[string]any) if ok { - return o.validateMap(inputAsMap) + _, _, err := o.validateMap(inputAsMap) + return err } value := reflect.ValueOf(typeOrData) if reflect.Indirect(value).Kind() != reflect.Struct { @@ -217,13 +237,14 @@ func (o OneOfSchema[KeyType]) validateSchema(otherSchema OneOfSchema[KeyType]) e return nil } -func (o OneOfSchema[KeyType]) validateMap(data map[string]any) error { +func (o OneOfSchema[KeyType]) validateMap(data map[string]any) (KeyType, Object, error) { + var nilKey KeyType // Validate that it has the discriminator field. // If it doesn't, fail // If it does, pass the non-discriminator fields into the ValidateCompatibility method for the object selectedTypeID := data[o.DiscriminatorFieldNameValue] if selectedTypeID == nil { - return &ConstraintError{ + return nilKey, nil, &ConstraintError{ Message: fmt.Sprintf( "validation failed for OneOfSchema. Discriminator field '%s' missing", o.DiscriminatorFieldNameValue), } @@ -231,7 +252,7 @@ func (o OneOfSchema[KeyType]) validateMap(data map[string]any) error { // Ensure it's the correct type selectedTypeIDAsserted, ok := selectedTypeID.(KeyType) if !ok { - return &ConstraintError{ + return nilKey, nil, &ConstraintError{ Message: fmt.Sprintf( "validation failed for OneOfSchema. Discriminator field '%v' has invalid type '%T'. Expected %T", o.DiscriminatorFieldNameValue, selectedTypeID, selectedTypeIDAsserted), @@ -240,24 +261,22 @@ func (o OneOfSchema[KeyType]) validateMap(data map[string]any) error { // Find the object that's associated with the selected type selectedSchema := o.TypesValue[selectedTypeIDAsserted] if selectedSchema == nil { - return &ConstraintError{ + return nilKey, nil, &ConstraintError{ Message: fmt.Sprintf( "validation failed for OneOfSchema. Discriminator value '%v' is invalid. Expected one of: %v", selectedTypeIDAsserted, o.getTypeValues()), } } - if selectedSchema.Properties()[o.DiscriminatorFieldNameValue] == nil { // Check to see if the discriminator is part of the sub-object. - delete(data, o.DiscriminatorFieldNameValue) // The discriminator isn't part of the object. - } - err := selectedSchema.ValidateCompatibility(data) + cloneData := o.deleteDiscriminator(data) + err := selectedSchema.ValidateCompatibility(cloneData) if err != nil { - return &ConstraintError{ + return nilKey, nil, &ConstraintError{ Message: fmt.Sprintf( "validation failed for OneOfSchema. Failed to validate as selected schema type '%T' from discriminator value '%v' (%s)", selectedSchema, selectedTypeIDAsserted, err), } } - return nil + return selectedTypeIDAsserted, selectedSchema, nil } func (o OneOfSchema[KeyType]) getTypeValues() []KeyType { @@ -271,10 +290,7 @@ func (o OneOfSchema[KeyType]) getTypeValues() []KeyType { } func (o OneOfSchema[KeyType]) Validate(data any) error { - if o.interfaceType == nil { - return o.ValidateType(data) - } - d, err := saveConvertTo(data, o.interfaceType) + d, err := saveConvertTo(data, o.ReflectedType()) if err != nil { return err } @@ -282,10 +298,7 @@ func (o OneOfSchema[KeyType]) Validate(data any) error { } func (o OneOfSchema[KeyType]) Serialize(data any) (result any, err error) { - if o.interfaceType == nil { - return nil, o.ValidateType(data) - } - d, err := saveConvertTo(data, o.interfaceType) + d, err := saveConvertTo(data, o.ReflectedType()) if err != nil { return nil, err } @@ -328,20 +341,29 @@ func (o OneOfSchema[KeyType]) getTypedDiscriminator(discriminator any) (KeyType, } func (o OneOfSchema[KeyType]) findUnderlyingType(data any) (KeyType, Object, error) { + var nilKey KeyType + reflectedType := reflect.TypeOf(data) if reflectedType.Kind() != reflect.Struct && reflectedType.Kind() != reflect.Map && (reflectedType.Kind() != reflect.Pointer || reflectedType.Elem().Kind() != reflect.Struct) { - var defaultValue KeyType - return defaultValue, nil, &ConstraintError{ + + return nilKey, nil, &ConstraintError{ Message: fmt.Sprintf( - "Invalid type for one-of type: '%s' expected struct or map.", + "Invalid type for one-of type: %q expected struct or map.", reflect.TypeOf(data).Name(), ), } } var foundKey *KeyType + if reflectedType.Kind() == reflect.Map { + myKey, mySchemaObj, err := o.validateMap(data.(map[string]any)) + if err != nil { + return nilKey, nil, err + } + return myKey, mySchemaObj, nil + } for key, ref := range o.TypesValue { underlyingReflectedType := ref.ReflectedType() if underlyingReflectedType == reflectedType { @@ -350,7 +372,6 @@ func (o OneOfSchema[KeyType]) findUnderlyingType(data any) (KeyType, Object, err } } if foundKey == nil { - var defaultValue KeyType dataType := reflect.TypeOf(data) values := make([]string, len(o.TypesValue)) i := 0 @@ -361,7 +382,7 @@ func (o OneOfSchema[KeyType]) findUnderlyingType(data any) (KeyType, Object, err } i++ } - return defaultValue, nil, &ConstraintError{ + return nilKey, nil, &ConstraintError{ Message: fmt.Sprintf( "Invalid type for one-of schema: '%s' (valid types are: %s)", dataType.String(), @@ -371,3 +392,38 @@ func (o OneOfSchema[KeyType]) findUnderlyingType(data any) (KeyType, Object, err } return *foundKey, o.TypesValue[*foundKey], nil } + +// validateSubtypeDiscriminatorInlineFields checks to see if a subtype's +// discriminator field has been written in accordance with the OneOfSchema's +// declaration. +func (o OneOfSchema[KeyType]) validateSubtypeDiscriminatorInlineFields() error { + for key, typeValue := range o.TypesValue { + typeValueDiscriminatorValue, hasDiscriminator := typeValue.Properties()[o.DiscriminatorFieldNameValue] + switch { + case !o.DiscriminatorInlined && hasDiscriminator: + return fmt.Errorf( + "object id %q has conflicting field %q; either remove that field or set inline to true for %T[%T]", + typeValue.ID(), o.DiscriminatorFieldNameValue, o, key) + case o.DiscriminatorInlined && !hasDiscriminator: + return fmt.Errorf( + "object id %q needs discriminator field %q; either add that field or set inline to false for %T[%T]", + typeValue.ID(), o.DiscriminatorFieldNameValue, o, key) + case o.DiscriminatorInlined && hasDiscriminator && + (typeValueDiscriminatorValue.ReflectedType().Kind() != reflect.TypeOf(key).Kind()): + return fmt.Errorf( + "the type of object id %v's discriminator field %q does not match OneOfSchema discriminator type; expected %v got %T", + typeValue.ID(), o.DiscriminatorFieldNameValue, typeValueDiscriminatorValue.TypeID(), key) + } + } + return nil +} + +func (o OneOfSchema[KeyType]) deleteDiscriminator(mymap map[string]any) map[string]any { + // the discriminator is not a property of the subtype + if !o.DiscriminatorInlined { + cloneData := maps.Clone(mymap) + delete(cloneData, o.DiscriminatorFieldNameValue) + return cloneData + } + return mymap +} diff --git a/schema/oneof_int.go b/schema/oneof_int.go index cc8c72c..fa29336 100644 --- a/schema/oneof_int.go +++ b/schema/oneof_int.go @@ -13,11 +13,13 @@ type OneOfInt interface { func NewOneOfIntSchema[ItemsInterface any]( types map[int64]Object, discriminatorFieldName string, + discriminatorInlined bool, ) *OneOfSchema[int64] { var defaultValue ItemsInterface return &OneOfSchema[int64]{ reflect.TypeOf(&defaultValue).Elem(), types, discriminatorFieldName, + discriminatorInlined, } } diff --git a/schema/oneof_int_test.go b/schema/oneof_int_test.go index 04bb1a2..b9e29c3 100644 --- a/schema/oneof_int_test.go +++ b/schema/oneof_int_test.go @@ -19,6 +19,7 @@ var oneOfIntTestObjectAProperties = map[string]*schema.PropertySchema{ 2: schema.NewRefSchema("C", nil), }, "_type", + false, ), nil, true, @@ -39,6 +40,7 @@ var oneOfIntTestObjectAbProperties = map[string]*schema.PropertySchema{ 2: schema.NewRefSchema("C", nil), }, "_difftype", + false, ), nil, true, @@ -59,6 +61,7 @@ var oneOfIntTestObjectAcProperties = map[string]*schema.PropertySchema{ 3: schema.NewRefSchema("C", nil), }, "_type", + false, ), nil, true, @@ -79,6 +82,7 @@ var oneOfIntTestObjectAdProperties = map[string]*schema.PropertySchema{ 2: schema.NewRefSchema("D", nil), }, "_type", + false, ), nil, true, diff --git a/schema/oneof_string.go b/schema/oneof_string.go index fffb80a..52ce02e 100644 --- a/schema/oneof_string.go +++ b/schema/oneof_string.go @@ -13,11 +13,13 @@ type OneOfString interface { func NewOneOfStringSchema[ItemsInterface any]( types map[string]Object, discriminatorFieldName string, + discriminatorInlined bool, ) *OneOfSchema[string] { var defaultValue ItemsInterface return &OneOfSchema[string]{ reflect.TypeOf(&defaultValue).Elem(), types, discriminatorFieldName, + discriminatorInlined, } } diff --git a/schema/oneof_string_test.go b/schema/oneof_string_test.go index 019f74a..9ecc875 100644 --- a/schema/oneof_string_test.go +++ b/schema/oneof_string_test.go @@ -5,12 +5,16 @@ package schema_test import ( "encoding/json" + "fmt" "go.arcalot.io/assert" + "reflect" "testing" "go.flow.arcalot.io/pluginsdk/schema" ) +const discriminatorFieldName = "d_type" + var oneOfStringTestObjectAProperties = map[string]*schema.PropertySchema{ "s": schema.NewPropertySchema( schema.NewOneOfStringSchema[any]( @@ -19,6 +23,7 @@ var oneOfStringTestObjectAProperties = map[string]*schema.PropertySchema{ "C": schema.NewRefSchema("C", nil), }, "_type", + false, ), nil, true, @@ -39,6 +44,7 @@ var oneOfStringTestObjectAbProperties = map[string]*schema.PropertySchema{ "C": schema.NewRefSchema("C", nil), }, "_difftype", + false, ), nil, true, @@ -59,6 +65,7 @@ var oneOfStringTestObjectAcProperties = map[string]*schema.PropertySchema{ "D": schema.NewRefSchema("C", nil), }, "_type", + false, ), nil, true, @@ -79,6 +86,7 @@ var oneOfStringTestObjectAdProperties = map[string]*schema.PropertySchema{ "C": schema.NewRefSchema("D", nil), }, "_type", + false, ), nil, true, @@ -138,16 +146,32 @@ var oneOfStringTestObjectAType = schema.NewScopeSchema( func TestOneOfStringUnserialization(t *testing.T) { data := `{ - "s": { - "_type": "B", - "message": "Hello world!" - } -}` + "s": { + "_type": "B", + "message": "Hello world!" + } + }` var input any assert.NoError(t, json.Unmarshal([]byte(data), &input)) unserializedData, err := oneOfStringTestObjectAType.Unserialize(input) assert.NoError(t, err) assert.Equals(t, unserializedData.(oneOfTestObjectA).S.(oneOfTestObjectB).Message, "Hello world!") + serialized, err := oneOfStringTestObjectAType.Serialize(unserializedData) + assert.NoError(t, err) + unserialized2, err := oneOfStringTestObjectAType.Unserialize(serialized) + assert.NoError(t, err) + assert.Equals(t, unserialized2, unserializedData) + + // Not explicitly using a struct mapped object, but the type is inferred + // by the compiler when the oneOfTestBMappedSchema is in the test suite. + unserializedData, err = oneOfStringTestObjectASchema.Unserialize(input) + assert.NoError(t, err) + assert.Equals(t, unserializedData.(map[string]any)["s"].(oneOfTestObjectB).Message, "Hello world!") + serialized, err = oneOfStringTestObjectASchema.Serialize(unserializedData) + assert.NoError(t, err) + unserialized2, err = oneOfStringTestObjectASchema.Unserialize(serialized) + assert.NoError(t, err) + assert.Equals(t, unserialized2, unserializedData) } func TestOneOfStringCompatibilityValidation(t *testing.T) { @@ -222,3 +246,310 @@ func TestOneOfStringCompatibilityMapValidation(t *testing.T) { assert.NoError(t, oneOfStringTestObjectASchema.ValidateCompatibility(combinedMapAndSchema)) assert.Error(t, oneOfStringTestObjectASchema.ValidateCompatibility(combinedMapAndInvalidSchema)) } + +type inlinedTestObjectA struct { + DType string `json:"d_type"` + OtherFieldA string `json:"other_field_a"` +} + +type inlinedTestObjectB struct { + DType string `json:"d_type"` + OtherFieldB string `json:"other_field_b"` +} + +type nonInlinedTestObjectA struct { + OtherFieldA string `json:"other_field_a"` +} + +type nonInlinedTestObjectB struct { + OtherFieldB string `json:"other_field_b"` +} + +var inlinedTestObjectAProperties = map[string]*schema.PropertySchema{ + discriminatorFieldName: schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), + "other_field_a": schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), +} + +var inlinedTestObjectBProperties = map[string]*schema.PropertySchema{ + discriminatorFieldName: schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), + "other_field_b": schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), +} + +var nonInlinedTestObjectAProperties = map[string]*schema.PropertySchema{ + "other_field_a": schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), +} + +var nonInlinedTestObjectBProperties = map[string]*schema.PropertySchema{ + "other_field_b": schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), +} + +var inlinedTestObjectAMappedSchema = schema.NewStructMappedObjectSchema[inlinedTestObjectA]( + "inlined_A", + inlinedTestObjectAProperties, +) + +var inlinedTestObjectBMappedSchema = schema.NewStructMappedObjectSchema[inlinedTestObjectB]( + "inlined_B", + inlinedTestObjectBProperties, +) + +var nonInlinedTestObjectAMappedSchema = schema.NewStructMappedObjectSchema[nonInlinedTestObjectA]( + "non_inlined_A", + nonInlinedTestObjectAProperties, +) + +var nonInlinedTestObjectBMappedSchema = schema.NewStructMappedObjectSchema[nonInlinedTestObjectB]( + "non_inlined_B", + nonInlinedTestObjectBProperties, +) + +var inlinedTestObjectASchema = schema.NewObjectSchema( + "inlined_A", + inlinedTestObjectAProperties, +) + +var inlinedTestObjectBSchema = schema.NewObjectSchema( + "inlined_B", + inlinedTestObjectBProperties, +) + +var nonInlinedTestObjectASchema = schema.NewObjectSchema( + "non_inlined_A", + nonInlinedTestObjectAProperties, +) + +var nonInlinedTestObjectBSchema = schema.NewObjectSchema( + "non_inlined_B", + nonInlinedTestObjectBProperties, +) + +func TestOneOf_InlinedStructMapped(t *testing.T) { + oneofSchema := schema.NewOneOfStringSchema[any](map[string]schema.Object{ + "A": inlinedTestObjectAMappedSchema, + "B": inlinedTestObjectBMappedSchema, + }, discriminatorFieldName, true) + serializedData := map[string]any{ + discriminatorFieldName: "A", + "other_field_a": "test", + } + // Since this is struct-mapped, unserializedData is a struct. + unserializedData := assert.NoErrorR[any](t)(oneofSchema.Unserialize(serializedData)) + reserializedData := assert.NoErrorR[any](t)(oneofSchema.Serialize(unserializedData)) + assert.Equals[any](t, reserializedData, serializedData) +} + +func TestOneOf_NonInlinedStructMapped(t *testing.T) { + oneofSchema := schema.NewOneOfStringSchema[any](map[string]schema.Object{ + "A": nonInlinedTestObjectAMappedSchema, + "B": nonInlinedTestObjectBMappedSchema, + }, discriminatorFieldName, false) + serializedData := map[string]any{ + discriminatorFieldName: "A", + "other_field_a": "test", + } + // Since this is struct-mapped, unserializedData is a struct. + unserializedData := assert.NoErrorR[any](t)(oneofSchema.Unserialize(serializedData)) + reserializedData := assert.NoErrorR[any](t)(oneofSchema.Serialize(unserializedData)) + assert.Equals[any](t, reserializedData, serializedData) +} + +func TestOneOf_InlinedNonStructMapped(t *testing.T) { + oneofSchema := schema.NewOneOfStringSchema[any](map[string]schema.Object{ + "A": inlinedTestObjectASchema, + "B": inlinedTestObjectBSchema, + }, discriminatorFieldName, true) + serializedData := map[string]any{ + discriminatorFieldName: "A", + "other_field_a": "test", + } + // Since this is not struct-mapped, unserializedData is a map. + unserializedData := assert.NoErrorR[any](t)(oneofSchema.Unserialize(serializedData)) + reserializedData := assert.NoErrorR[any](t)(oneofSchema.Serialize(unserializedData)) + assert.Equals[any](t, reserializedData, serializedData) +} + +func TestOneOf_NonInlinedNonStructMapped(t *testing.T) { + oneofSchema := schema.NewOneOfStringSchema[any](map[string]schema.Object{ + "A": nonInlinedTestObjectASchema, + "B": nonInlinedTestObjectBSchema, + }, discriminatorFieldName, false) + serializedData := map[string]any{ + discriminatorFieldName: "A", + "other_field_a": "test", + } + // Since this is not struct-mapped, unserializedData is a map. + unserializedData := assert.NoErrorR[any](t)(oneofSchema.Unserialize(serializedData)) + reserializedData := assert.NoErrorR[any](t)(oneofSchema.Serialize(unserializedData)) + assert.Equals[any](t, reserializedData, serializedData) + + var input_mismatched_type any = struct{}{} + err := oneofSchema.Validate(input_mismatched_type) + assert.Error(t, err) + assert.Contains(t, err.Error(), "Invalid type for one-of schema") + + var input_invalid_type any = true + error_msg := fmt.Sprintf("Invalid type for one-of type: %q. Expected map.", reflect.TypeOf(input_invalid_type).Kind()) + _, err = oneofSchema.Unserialize(input_invalid_type) + assert.Error(t, err) + assert.Contains(t, err.Error(), error_msg) + error_msg = fmt.Sprintf("Invalid type for one-of type: %q expected struct or map.", reflect.TypeOf(input_invalid_type).Kind()) + err = oneofSchema.Validate(input_invalid_type) + assert.Error(t, err) + assert.Contains(t, err.Error(), error_msg) + + var input_nil any = nil + _, err = oneofSchema.Unserialize(input_nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "bug: data is nil") +} + +type inlinedTestIntDiscriminatorA struct { + DType int `json:"d_type"` + OtherFieldA string `json:"other_field_a"` +} + +var inlinedTestIntDiscriminatorAProperties = map[string]*schema.PropertySchema{ + discriminatorFieldName: schema.NewPropertySchema( + schema.NewIntSchema(nil, nil, nil), + nil, true, nil, nil, nil, + nil, nil, + ), + "other_field_a": schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, true, nil, nil, nil, + nil, nil, + ), +} + +var inlinedTestNoDiscriminatorBProperties = map[string]*schema.PropertySchema{ + "other_field_b": schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, true, nil, nil, nil, + nil, nil, + ), +} + +var inlinedTestIntDiscriminatorAMappedSchema = schema.NewStructMappedObjectSchema[inlinedTestIntDiscriminatorA]( + "inlined_int_A", + inlinedTestIntDiscriminatorAProperties, +) + +var inlinedTestIntDiscriminatorASchema = schema.NewObjectSchema( + "inlined_int_A", + inlinedTestIntDiscriminatorAProperties, +) + +var inlinedTestIntDiscriminatorBSchema = schema.NewObjectSchema( + "inlined_int_B", + inlinedTestNoDiscriminatorBProperties, +) + +func TestOneOf_Error_SubtypeHasInvalidDiscriminatorType(t *testing.T) { + testSchema := schema.NewOneOfStringSchema[any](map[string]schema.Object{ + "A": inlinedTestIntDiscriminatorAMappedSchema, + "B": inlinedTestObjectBMappedSchema, + }, discriminatorFieldName, true) + expMsg := fmt.Sprintf( + "discriminator field %q does not match OneOfSchema discriminator type", + discriminatorFieldName) + + assert.PanicsContains(t, func() { + schema.NewScopeSchema(schema.NewObjectSchema("test", + map[string]*schema.PropertySchema{ + "test": schema.NewPropertySchema( + testSchema, + nil, true, nil, nil, + nil, nil, nil), + })) + }, expMsg) +} + +func TestOneOf_Error_InlineSubtypeMissingDiscriminator(t *testing.T) { + testSchema := schema.NewOneOfIntSchema[any](map[int64]schema.Object{ + 1: inlinedTestIntDiscriminatorASchema, + 2: inlinedTestIntDiscriminatorBSchema, + }, discriminatorFieldName, true) + expMsg := "needs discriminator field" + + assert.PanicsContains(t, func() { + schema.NewScopeSchema(schema.NewObjectSchema("test", + map[string]*schema.PropertySchema{ + "test": schema.NewPropertySchema( + testSchema, + nil, true, nil, nil, + nil, nil, nil), + })) + }, expMsg) +} + +func TestOneOf_Error_SubtypeHasDiscriminator(t *testing.T) { + testSchema := schema.NewOneOfStringSchema[any](map[string]schema.Object{ + "A": inlinedTestIntDiscriminatorASchema, + "B": nonInlinedTestObjectBSchema, + }, discriminatorFieldName, false) + expMsg := "has conflicting field" + + assert.PanicsContains(t, func() { + schema.NewScopeSchema(schema.NewObjectSchema("test", + map[string]*schema.PropertySchema{ + "test": schema.NewPropertySchema( + testSchema, + nil, true, nil, nil, + nil, nil, nil), + })) + }, expMsg) +} diff --git a/schema/oneof_test.go b/schema/oneof_test.go index 00ba273..4ff8ab7 100644 --- a/schema/oneof_test.go +++ b/schema/oneof_test.go @@ -1,9 +1,10 @@ package schema_test import ( - "go.arcalot.io/assert" + "fmt" "testing" + "go.arcalot.io/assert" "go.flow.arcalot.io/pluginsdk/schema" ) @@ -125,3 +126,95 @@ var oneOfTestCMappedSchema = schema.NewStructMappedObjectSchema[oneOfTestObjectC "C", oneOfTestObjectCProperties, ) + +// Test_OneOf_ConstructorBypass tests the behavior of a OneOf object created +// by the Scope ScopeSchema, a scope that contains the schema of a scope +// and an object, through unserialization of data without using a +// New* constructor function, like NewOneOfStringSchema or NewOneOfIntSchema, +// behaves as one would expect from a OneOf object created from a constructor. +func Test_OneOf_ConstructorBypass(t *testing.T) { //nolint:funlen + discriminator_field := "_type" + input_schema := map[string]any{ + "root": "InputParams", + "objects": map[string]any{ + "InputParams": map[string]any{ + "id": "InputParams", + "properties": map[string]any{ + "name": map[string]any{ + "required": true, + "type": map[string]any{ + "discriminator_field_name": discriminator_field, + "type_id": "one_of_string", + "types": map[string]any{ + "fullname": map[string]any{ + "id": "FullName", + "type_id": "ref", + }, + "nick": map[string]any{ + "id": "Nickname", + "type_id": "ref", + }, + }, + }, + }, + }, + }, + "FullName": map[string]any{ + "id": "FullName", + "properties": map[string]any{ + "first_name": map[string]any{ + "required": true, + "type": map[string]any{ + "type_id": "string", + }, + }, + "last_name": map[string]any{ + "required": true, + "type": map[string]any{ + "type_id": "string", + }, + }, + }, + }, + "Nickname": map[string]any{ + "id": "Nickname", + "properties": map[string]any{ + "nick": map[string]any{ + "required": true, + "type": map[string]any{ + "type_id": "string", + }, + }, + }, + }, + }, + } + var input_data_fullname any = map[string]any{ + "name": map[string]any{ + discriminator_field: "fullname", + "first_name": "Arca", + "last_name": "Lot", + }, + } + + scopeAny := assert.NoErrorR[any](t)(schema.DescribeScope().Unserialize(input_schema)) + scopeSchemaTyped := scopeAny.(*schema.ScopeSchema) + scopeSchemaTyped.ApplyScope(scopeSchemaTyped) + assert.NoError(t, scopeSchemaTyped.Validate(input_data_fullname)) + unserialized := assert.NoErrorR[any](t)(scopeSchemaTyped.Unserialize(input_data_fullname)) + serialized := assert.NoErrorR[any](t)(scopeSchemaTyped.Serialize(unserialized)) + unserialized2 := assert.NoErrorR[any](t)(scopeSchemaTyped.Unserialize(serialized)) + assert.Equals(t, unserialized2, unserialized) + + var input_invalid_discriminator_value any = map[string]any{ + "name": map[string]any{ + discriminator_field: "robotname", + "first_name": "Arca", + "last_name": "Lot", + }, + } + error_msg := fmt.Sprintf("Invalid value for %q", discriminator_field) + _, err := scopeSchemaTyped.Unserialize(input_invalid_discriminator_value) + assert.Error(t, err) + assert.Contains(t, err.Error(), error_msg) +} diff --git a/schema/schema_schema.go b/schema/schema_schema.go index fa0bb65..f54dd16 100644 --- a/schema/schema_schema.go +++ b/schema/schema_schema.go @@ -50,6 +50,7 @@ var mapKeyType = NewOneOfStringSchema[any]( ), }, "type_id", + false, ) var displayType = NewDisplayValue( PointerTo("Display"), @@ -194,6 +195,7 @@ var valueType = NewOneOfStringSchema[any]( ), }, "type_id", + false, ) var scopeObject = NewStructMappedObjectSchema[*ScopeSchema]( "Scope", @@ -530,6 +532,20 @@ var basicObjects = []*ObjectSchema{ NewStructMappedObjectSchema[*OneOfSchema[int64]]( "OneOfIntSchema", map[string]*PropertySchema{ + "discriminator_inlined": NewPropertySchema( + NewBoolSchema(), + NewDisplayValue( + PointerTo("Discriminator Inlined"), + PointerTo("whether or not the discriminator is inlined in the underlying objects' schema."), + nil, + ), + true, + nil, + nil, + nil, + PointerTo("false"), + nil, + ), "discriminator_field_name": NewPropertySchema( NewStringSchema(nil, nil, nil), NewDisplayValue( @@ -538,7 +554,7 @@ var basicObjects = []*ObjectSchema{ "field is present on any of the component objects it must also be an int."), nil, ), - false, + true, nil, nil, nil, @@ -555,6 +571,7 @@ var basicObjects = []*ObjectSchema{ string(TypeIDObject): NewRefSchema("Object", nil), }, "type_id", + false, ), nil, nil, @@ -576,6 +593,20 @@ var basicObjects = []*ObjectSchema{ NewStructMappedObjectSchema[*OneOfSchema[string]]( "OneOfStringSchema", map[string]*PropertySchema{ + "discriminator_inlined": NewPropertySchema( + NewBoolSchema(), + NewDisplayValue( + PointerTo("Discriminator Inlined"), + PointerTo("whether or not the discriminator is inlined in the underlying objects' schema."), + nil, + ), + true, + nil, + nil, + nil, + PointerTo("false"), + nil, + ), "discriminator_field_name": NewPropertySchema( NewStringSchema(nil, nil, nil), NewDisplayValue( @@ -584,7 +615,7 @@ var basicObjects = []*ObjectSchema{ "field is present on any of the component objects it must also be an int."), nil, ), - false, + true, nil, nil, nil, @@ -601,6 +632,7 @@ var basicObjects = []*ObjectSchema{ string(TypeIDObject): NewRefSchema("Object", nil), }, "type_id", + false, ), nil, nil,