diff --git a/ext/native.go b/ext/native.go index 7b6f0ba5..36ab4a7a 100644 --- a/ext/native.go +++ b/ext/native.go @@ -128,10 +128,36 @@ func NativeTypes(args ...any) cel.EnvOption { // NativeTypesOption is a functional interface for configuring handling of native types. type NativeTypesOption func(*nativeTypeOptions) error +// NativeTypesFieldNameHandler is a handler for mapping a reflect.StructField to a CEL field name. +// This can be used to override the default Go struct field to CEL field name mapping. +type NativeTypesFieldNameHandler = func(field reflect.StructField) string + +func fieldNameByTag(structTagToParse string) func(field reflect.StructField) string { + return func(field reflect.StructField) string { + tag, found := field.Tag.Lookup(structTagToParse) + if found { + splits := strings.Split(tag, ",") + if len(splits) > 0 { + // We make the assumption that the leftmost entry in the tag is the name. + // This seems to be true for most tags that have the concept of a name/key, such as: + // https://pkg.go.dev/encoding/xml#Marshal + // https://pkg.go.dev/encoding/json#Marshal + // https://pkg.go.dev/go.mongodb.org/mongo-driver/bson#hdr-Structs + // https://pkg.go.dev/gopkg.in/yaml.v2#Marshal + name := splits[0] + return name + } + } + + return field.Name + } +} + type nativeTypeOptions struct { - // structTagToParse controls if CEL should support struct field renames, by parsing - // struct field tags. This must be set to the tag to parse, such as "cel" or "json". - structTagToParse string + // fieldNameHandler controls how CEL should perform struct field renames. + // This is most commonly used for switching to parsing based off the struct field tag, + // such as "cel" or "json". + fieldNameHandler NativeTypesFieldNameHandler } // ParseStructTags configures if native types field names should be overridable by CEL struct tags. @@ -139,9 +165,9 @@ type nativeTypeOptions struct { func ParseStructTags(enabled bool) NativeTypesOption { return func(ntp *nativeTypeOptions) error { if enabled { - ntp.structTagToParse = "cel" + ntp.fieldNameHandler = fieldNameByTag("cel") } else { - ntp.structTagToParse = "" + ntp.fieldNameHandler = nil } return nil } @@ -153,7 +179,15 @@ func ParseStructTags(enabled bool) NativeTypesOption { // If the tag to parse is "json" and the struct field has tag json:"foo,omitempty", the CEL struct field will be "foo". func ParseStructTag(tag string) NativeTypesOption { return func(ntp *nativeTypeOptions) error { - ntp.structTagToParse = tag + ntp.fieldNameHandler = fieldNameByTag(tag) + return nil + } +} + +// ParseStructField configures how to parse Go struct fields. It can be used to customize struct field parsing. +func ParseStructField(handler NativeTypesFieldNameHandler) NativeTypesOption { + return func(ntp *nativeTypeOptions) error { + ntp.fieldNameHandler = handler return nil } } @@ -163,7 +197,7 @@ func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, p for _, refType := range refTypes { switch rt := refType.(type) { case reflect.Type: - result, err := newNativeTypes(tpOptions.structTagToParse, rt) + result, err := newNativeTypes(tpOptions.fieldNameHandler, rt) if err != nil { return nil, err } @@ -171,7 +205,7 @@ func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, p nativeTypes[result[idx].TypeName()] = result[idx] } case reflect.Value: - result, err := newNativeTypes(tpOptions.structTagToParse, rt.Type()) + result, err := newNativeTypes(tpOptions.fieldNameHandler, rt.Type()) if err != nil { return nil, err } @@ -224,27 +258,12 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool return tp.baseProvider.FindStructType(typeName) } -func toFieldName(structTagToParse string, f reflect.StructField) string { - if structTagToParse == "" { +func toFieldName(fieldNameHandler NativeTypesFieldNameHandler, f reflect.StructField) string { + if fieldNameHandler == nil { return f.Name } - tag, found := f.Tag.Lookup(structTagToParse) - if found { - splits := strings.Split(tag, ",") - if len(splits) > 0 { - // We make the assumption that the leftmost entry in the tag is the name. - // This seems to be true for most tags that have the concept of a name/key, such as: - // https://pkg.go.dev/encoding/xml#Marshal - // https://pkg.go.dev/encoding/json#Marshal - // https://pkg.go.dev/go.mongodb.org/mongo-driver/bson#hdr-Structs - // https://pkg.go.dev/gopkg.in/yaml.v2#Marshal - name := splits[0] - return name - } - } - - return f.Name + return fieldNameHandler(f) } // FindStructFieldNames looks up the type definition first from the native types, then from @@ -255,7 +274,7 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b fieldCount := t.refType.NumField() fields := make([]string, fieldCount) for i := 0; i < fieldCount; i++ { - fields[i] = toFieldName(tp.options.structTagToParse, t.refType.Field(i)) + fields[i] = toFieldName(tp.options.fieldNameHandler, t.refType.Field(i)) } return fields, true } @@ -415,7 +434,7 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) { } func (tp *nativeTypeProvider) newNativeObject(val any, refValue reflect.Value) ref.Val { - valType, err := newNativeType(tp.options.structTagToParse, refValue.Type()) + valType, err := newNativeType(tp.options.fieldNameHandler, refValue.Type()) if err != nil { return types.NewErr(err.Error()) } @@ -467,7 +486,7 @@ func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) { if !fieldValue.IsValid() || fieldValue.IsZero() { continue } - fieldName := toFieldName(o.valType.structTagToParse, fieldType) + fieldName := toFieldName(o.valType.fieldNameHandler, fieldType) fieldCELVal := o.NativeToValue(fieldValue.Interface()) fieldJSONVal, err := fieldCELVal.ConvertToNative(jsonValueType) if err != nil { @@ -565,8 +584,8 @@ func (o *nativeObj) Value() any { return o.val } -func newNativeTypes(structTagToParse string, rawType reflect.Type) ([]*nativeType, error) { - nt, err := newNativeType(structTagToParse, rawType) +func newNativeTypes(fieldNameHandler NativeTypesFieldNameHandler, rawType reflect.Type) ([]*nativeType, error) { + nt, err := newNativeType(fieldNameHandler, rawType) if err != nil { return nil, err } @@ -585,7 +604,7 @@ func newNativeTypes(structTagToParse string, rawType reflect.Type) ([]*nativeTyp return } alreadySeen[t.String()] = struct{}{} - nt, ntErr := newNativeType(structTagToParse, t) + nt, ntErr := newNativeType(fieldNameHandler, t) if ntErr != nil { err = ntErr return @@ -605,7 +624,7 @@ var ( errDuplicatedFieldName = errors.New("field name already exists in struct") ) -func newNativeType(structTagToParse string, rawType reflect.Type) (*nativeType, error) { +func newNativeType(fieldNameHandler NativeTypesFieldNameHandler, rawType reflect.Type) (*nativeType, error) { refType := rawType if refType.Kind() == reflect.Pointer { refType = refType.Elem() @@ -615,12 +634,12 @@ func newNativeType(structTagToParse string, rawType reflect.Type) (*nativeType, } // Since naming collisions can only happen with struct tag parsing, we only check for them if it is enabled. - if structTagToParse != "" { + if fieldNameHandler != nil { fieldNames := make(map[string]struct{}) for idx := 0; idx < refType.NumField(); idx++ { field := refType.Field(idx) - fieldName := toFieldName(structTagToParse, field) + fieldName := toFieldName(fieldNameHandler, field) if _, found := fieldNames[fieldName]; found { return nil, fmt.Errorf("invalid field name `%s` in struct `%s`: %w", fieldName, refType.Name(), errDuplicatedFieldName) @@ -633,14 +652,14 @@ func newNativeType(structTagToParse string, rawType reflect.Type) (*nativeType, return &nativeType{ typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()), refType: refType, - structTagToParse: structTagToParse, + fieldNameHandler: fieldNameHandler, }, nil } type nativeType struct { typeName string refType reflect.Type - structTagToParse string + fieldNameHandler NativeTypesFieldNameHandler } // ConvertToNative implements ref.Val.ConvertToNative. @@ -691,13 +710,13 @@ func (t *nativeType) Value() any { // fieldByName returns the corresponding reflect.StructField for the give name either by matching // field tag or field name. func (t *nativeType) fieldByName(fieldName string) (reflect.StructField, bool) { - if t.structTagToParse == "" { + if t.fieldNameHandler == nil { return t.refType.FieldByName(fieldName) } for i := 0; i < t.refType.NumField(); i++ { f := t.refType.Field(i) - if toFieldName(t.structTagToParse, f) == fieldName { + if toFieldName(t.fieldNameHandler, f) == fieldName { return f, true } } diff --git a/ext/native_test.go b/ext/native_test.go index de309137..4a62ec04 100644 --- a/ext/native_test.go +++ b/ext/native_test.go @@ -827,7 +827,8 @@ func TestNativeTypeConvertToType(t *testing.T) { for i, tst := range nativeTests { tc := tst t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { - nt, err := newNativeType(tc.tag, reflect.TypeOf(&TestAllTypes{})) + handler := fieldNameByTag(tc.tag) + nt, err := newNativeType(handler, reflect.TypeOf(&TestAllTypes{})) if err != nil { t.Fatalf("newNativeType() failed: %v", err) } @@ -842,7 +843,7 @@ func TestNativeTypeConvertToType(t *testing.T) { } func TestNativeTypeConvertToNative(t *testing.T) { - nt, err := newNativeType("cel", reflect.TypeOf(&TestAllTypes{})) + nt, err := newNativeType(fieldNameByTag("cel"), reflect.TypeOf(&TestAllTypes{})) if err != nil { t.Fatalf("newNativeType() failed: %v", err) } @@ -853,7 +854,7 @@ func TestNativeTypeConvertToNative(t *testing.T) { } func TestNativeTypeHasTrait(t *testing.T) { - nt, err := newNativeType("cel", reflect.TypeOf(&TestAllTypes{})) + nt, err := newNativeType(fieldNameByTag("cel"), reflect.TypeOf(&TestAllTypes{})) if err != nil { t.Fatalf("newNativeType() failed: %v", err) } @@ -863,7 +864,7 @@ func TestNativeTypeHasTrait(t *testing.T) { } func TestNativeTypeValue(t *testing.T) { - nt, err := newNativeType("cel", reflect.TypeOf(&TestAllTypes{})) + nt, err := newNativeType(fieldNameByTag("cel"), reflect.TypeOf(&TestAllTypes{})) if err != nil { t.Fatalf("newNativeType() failed: %v", err) } @@ -873,7 +874,7 @@ func TestNativeTypeValue(t *testing.T) { } func TestNativeStructWithMultipleSameFieldNames(t *testing.T) { - _, err := newNativeType("cel", reflect.TypeOf(TestStructWithMultipleSameNames{})) + _, err := newNativeType(fieldNameByTag("cel"), reflect.TypeOf(TestStructWithMultipleSameNames{})) if err == nil { t.Fatal("newNativeType() did not fail as expected") }