diff --git a/ext/native.go b/ext/native.go index 35b0f1e3..7b6f0ba5 100644 --- a/ext/native.go +++ b/ext/native.go @@ -129,15 +129,31 @@ func NativeTypes(args ...any) cel.EnvOption { type NativeTypesOption func(*nativeTypeOptions) error type nativeTypeOptions struct { - // parseStructTags controls if CEL should support struct field renames, by parsing - // struct field tags. - parseStructTags bool + // structTagToParse controls if CEL should support struct field renames, by parsing + // struct field tags. This must be set to the tag to parse, such as "cel" or "json". + structTagToParse string } // ParseStructTags configures if native types field names should be overridable by CEL struct tags. +// This is equivalent to ParseStructTag("cel") func ParseStructTags(enabled bool) NativeTypesOption { return func(ntp *nativeTypeOptions) error { - ntp.parseStructTags = true + if enabled { + ntp.structTagToParse = "cel" + } else { + ntp.structTagToParse = "" + } + return nil + } +} + +// ParseStructTag configures the struct tag to parse. The 0th item in the tag is used as the name of the CEL field. +// For example: +// If the tag to parse is "cel" and the struct field has tag cel:"foo", the CEL struct field will be "foo". +// If the tag to parse is "json" and the struct field has tag json:"foo,omitempty", the CEL struct field will be "foo". +func ParseStructTag(tag string) NativeTypesOption { + return func(ntp *nativeTypeOptions) error { + ntp.structTagToParse = tag return nil } } @@ -147,7 +163,7 @@ func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, p for _, refType := range refTypes { switch rt := refType.(type) { case reflect.Type: - result, err := newNativeTypes(tpOptions.parseStructTags, rt) + result, err := newNativeTypes(tpOptions.structTagToParse, rt) if err != nil { return nil, err } @@ -155,7 +171,7 @@ func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, p nativeTypes[result[idx].TypeName()] = result[idx] } case reflect.Value: - result, err := newNativeTypes(tpOptions.parseStructTags, rt.Type()) + result, err := newNativeTypes(tpOptions.structTagToParse, rt.Type()) if err != nil { return nil, err } @@ -208,13 +224,24 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool return tp.baseProvider.FindStructType(typeName) } -func toFieldName(parseStructTag bool, f reflect.StructField) string { - if !parseStructTag { +func toFieldName(structTagToParse string, f reflect.StructField) string { + if structTagToParse == "" { return f.Name } - if name, found := f.Tag.Lookup("cel"); found { - return name + tag, found := f.Tag.Lookup(structTagToParse) + if found { + splits := strings.Split(tag, ",") + if len(splits) > 0 { + // We make the assumption that the leftmost entry in the tag is the name. + // This seems to be true for most tags that have the concept of a name/key, such as: + // https://pkg.go.dev/encoding/xml#Marshal + // https://pkg.go.dev/encoding/json#Marshal + // https://pkg.go.dev/go.mongodb.org/mongo-driver/bson#hdr-Structs + // https://pkg.go.dev/gopkg.in/yaml.v2#Marshal + name := splits[0] + return name + } } return f.Name @@ -228,7 +255,7 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b fieldCount := t.refType.NumField() fields := make([]string, fieldCount) for i := 0; i < fieldCount; i++ { - fields[i] = toFieldName(tp.options.parseStructTags, t.refType.Field(i)) + fields[i] = toFieldName(tp.options.structTagToParse, t.refType.Field(i)) } return fields, true } @@ -238,22 +265,6 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b return tp.baseProvider.FindStructFieldNames(typeName) } -// valueFieldByName retrieves the corresponding reflect.Value field for the given field name, by -// searching for a matching field tag value or field name. -func valueFieldByName(parseStructTags bool, target reflect.Value, fieldName string) reflect.Value { - if !parseStructTags { - return target.FieldByName(fieldName) - } - - for i := 0; i < target.Type().NumField(); i++ { - f := target.Type().Field(i) - if toFieldName(parseStructTags, f) == fieldName { - return target.FieldByIndex(f.Index) - } - } - return reflect.Value{} -} - // FindStructFieldType looks up a native type's field definition, and if the type name is not a native // type then proxies to the composed types.Provider func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) { @@ -273,12 +284,12 @@ func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (* Type: celType, IsSet: func(obj any) bool { refVal := reflect.Indirect(reflect.ValueOf(obj)) - refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName) + refField := refVal.FieldByName(refField.Name) return !refField.IsZero() }, GetFrom: func(obj any) (any, error) { refVal := reflect.Indirect(reflect.ValueOf(obj)) - refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName) + refField := refVal.FieldByName(refField.Name) return getFieldValue(refField), nil }, }, true @@ -404,7 +415,7 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) { } func (tp *nativeTypeProvider) newNativeObject(val any, refValue reflect.Value) ref.Val { - valType, err := newNativeType(tp.options.parseStructTags, refValue.Type()) + valType, err := newNativeType(tp.options.structTagToParse, refValue.Type()) if err != nil { return types.NewErr(err.Error()) } @@ -456,7 +467,7 @@ func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) { if !fieldValue.IsValid() || fieldValue.IsZero() { continue } - fieldName := toFieldName(o.valType.parseStructTags, fieldType) + fieldName := toFieldName(o.valType.structTagToParse, fieldType) fieldCELVal := o.NativeToValue(fieldValue.Interface()) fieldJSONVal, err := fieldCELVal.ConvertToNative(jsonValueType) if err != nil { @@ -554,8 +565,8 @@ func (o *nativeObj) Value() any { return o.val } -func newNativeTypes(parseStructTags bool, rawType reflect.Type) ([]*nativeType, error) { - nt, err := newNativeType(parseStructTags, rawType) +func newNativeTypes(structTagToParse string, rawType reflect.Type) ([]*nativeType, error) { + nt, err := newNativeType(structTagToParse, rawType) if err != nil { return nil, err } @@ -574,7 +585,7 @@ func newNativeTypes(parseStructTags bool, rawType reflect.Type) ([]*nativeType, return } alreadySeen[t.String()] = struct{}{} - nt, ntErr := newNativeType(parseStructTags, t) + nt, ntErr := newNativeType(structTagToParse, t) if ntErr != nil { err = ntErr return @@ -594,7 +605,7 @@ var ( errDuplicatedFieldName = errors.New("field name already exists in struct") ) -func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, error) { +func newNativeType(structTagToParse string, rawType reflect.Type) (*nativeType, error) { refType := rawType if refType.Kind() == reflect.Pointer { refType = refType.Elem() @@ -604,12 +615,12 @@ func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, err } // Since naming collisions can only happen with struct tag parsing, we only check for them if it is enabled. - if parseStructTags { + if structTagToParse != "" { fieldNames := make(map[string]struct{}) for idx := 0; idx < refType.NumField(); idx++ { field := refType.Field(idx) - fieldName := toFieldName(parseStructTags, field) + fieldName := toFieldName(structTagToParse, field) if _, found := fieldNames[fieldName]; found { return nil, fmt.Errorf("invalid field name `%s` in struct `%s`: %w", fieldName, refType.Name(), errDuplicatedFieldName) @@ -620,16 +631,16 @@ func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, err } return &nativeType{ - typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()), - refType: refType, - parseStructTags: parseStructTags, + typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()), + refType: refType, + structTagToParse: structTagToParse, }, nil } type nativeType struct { - typeName string - refType reflect.Type - parseStructTags bool + typeName string + refType reflect.Type + structTagToParse string } // ConvertToNative implements ref.Val.ConvertToNative. @@ -680,13 +691,13 @@ func (t *nativeType) Value() any { // fieldByName returns the corresponding reflect.StructField for the give name either by matching // field tag or field name. func (t *nativeType) fieldByName(fieldName string) (reflect.StructField, bool) { - if !t.parseStructTags { + if t.structTagToParse == "" { return t.refType.FieldByName(fieldName) } for i := 0; i < t.refType.NumField(); i++ { f := t.refType.Field(i) - if toFieldName(t.parseStructTags, f) == fieldName { + if toFieldName(t.structTagToParse, f) == fieldName { return f, true } } diff --git a/ext/native_test.go b/ext/native_test.go index 55e5aa04..de309137 100644 --- a/ext/native_test.go +++ b/ext/native_test.go @@ -33,8 +33,9 @@ import ( "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/test" - proto3pb "github.com/google/cel-go/test/proto3pb" structpb "google.golang.org/protobuf/types/known/structpb" + + proto3pb "github.com/google/cel-go/test/proto3pb" ) func TestNativeTypes(t *testing.T) { @@ -109,6 +110,72 @@ func TestNativeTypes(t *testing.T) { }, envOpts: []any{ParseStructTags(true)}, }, + + { + expr: `ext.TestAllTypes{ + nestedVal: ext.TestNestedType{NestedMapVal: {1: false}}, + boolVal: true, + BytesVal: b'hello', + DurationVal: duration('5s'), + DoubleVal: 1.5, + FloatVal: 2.5, + Int32Val: 10, + Int64Val: 20, + StringVal: 'hello world', + TimestampVal: timestamp('2011-08-06T01:23:45Z'), + Uint32Val: 100u, + Uint64Val: 200u, + ListVal: [ + ext.TestNestedType{ + NestedListVal:['goodbye', 'cruel', 'world'], + NestedMapVal: {42: true}, + custom_name: 'name', + }, + ], + ArrayVal: [ + ext.TestNestedType{ + NestedListVal:['goodbye', 'cruel', 'world'], + NestedMapVal: {42: true}, + custom_name: 'name', + }, + ], + MapVal: {'map-key': ext.TestAllTypes{boolVal: true}}, + CustomSliceVal: [ext.TestNestedSliceType{Value: 'none'}], + CustomMapVal: {'even': ext.TestMapVal{Value: 'more'}}, + CustomName: 'name', + }`, + out: &TestAllTypes{ + NestedVal: &TestNestedType{NestedMapVal: map[int64]bool{1: false}}, + BoolVal: true, + BytesVal: []byte("hello"), + DurationVal: time.Second * 5, + DoubleVal: 1.5, + FloatVal: 2.5, + Int32Val: 10, + Int64Val: 20, + StringVal: "hello world", + TimestampVal: mustParseTime(t, "2011-08-06T01:23:45Z"), + Uint32Val: uint32(100), + Uint64Val: uint64(200), + ListVal: []*TestNestedType{ + { + NestedListVal: []string{"goodbye", "cruel", "world"}, + NestedMapVal: map[int64]bool{42: true}, + NestedCustomName: "name", + }, + }, + ArrayVal: [1]*TestNestedType{{ + NestedListVal: []string{"goodbye", "cruel", "world"}, + NestedMapVal: map[int64]bool{42: true}, + NestedCustomName: "name", + }}, + MapVal: map[string]TestAllTypes{"map-key": {BoolVal: true}}, + CustomSliceVal: []TestNestedSliceType{{Value: "none"}}, + CustomMapVal: map[string]TestMapVal{"even": {Value: "more"}}, + CustomName: "name", + }, + envOpts: []any{ParseStructTag("json")}, + }, { expr: `ext.TestAllTypes{ NestedVal: ext.TestNestedType{NestedMapVal: {1: false}}, @@ -750,20 +817,32 @@ func TestNativeTypesWithOptional(t *testing.T) { } func TestNativeTypeConvertToType(t *testing.T) { - nt, err := newNativeType(true, reflect.TypeOf(&TestAllTypes{})) - if err != nil { - t.Fatalf("newNativeType() failed: %v", err) - } - if nt.ConvertToType(types.TypeType) != types.TypeType { - t.Error("ConvertToType(Type) failed") + var nativeTests = []struct { + tag string + }{ + {tag: "cel"}, + {tag: "json"}, } - if !types.IsError(nt.ConvertToType(types.StringType)) { - t.Errorf("ConvertToType(String) got %v, wanted error", nt.ConvertToType(types.StringType)) + + for i, tst := range nativeTests { + tc := tst + t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + nt, err := newNativeType(tc.tag, reflect.TypeOf(&TestAllTypes{})) + if err != nil { + t.Fatalf("newNativeType() failed: %v", err) + } + if nt.ConvertToType(types.TypeType) != types.TypeType { + t.Error("ConvertToType(Type) failed") + } + if !types.IsError(nt.ConvertToType(types.StringType)) { + t.Errorf("ConvertToType(String) got %v, wanted error", nt.ConvertToType(types.StringType)) + } + }) } } func TestNativeTypeConvertToNative(t *testing.T) { - nt, err := newNativeType(true, reflect.TypeOf(&TestAllTypes{})) + nt, err := newNativeType("cel", reflect.TypeOf(&TestAllTypes{})) if err != nil { t.Fatalf("newNativeType() failed: %v", err) } @@ -774,7 +853,7 @@ func TestNativeTypeConvertToNative(t *testing.T) { } func TestNativeTypeHasTrait(t *testing.T) { - nt, err := newNativeType(true, reflect.TypeOf(&TestAllTypes{})) + nt, err := newNativeType("cel", reflect.TypeOf(&TestAllTypes{})) if err != nil { t.Fatalf("newNativeType() failed: %v", err) } @@ -784,7 +863,7 @@ func TestNativeTypeHasTrait(t *testing.T) { } func TestNativeTypeValue(t *testing.T) { - nt, err := newNativeType(true, reflect.TypeOf(&TestAllTypes{})) + nt, err := newNativeType("cel", reflect.TypeOf(&TestAllTypes{})) if err != nil { t.Fatalf("newNativeType() failed: %v", err) } @@ -793,8 +872,8 @@ func TestNativeTypeValue(t *testing.T) { } } -func TestNativeStructWithMultileSameFieldNames(t *testing.T) { - _, err := newNativeType(true, reflect.TypeOf(TestStructWithMultipleSameNames{})) +func TestNativeStructWithMultipleSameFieldNames(t *testing.T) { + _, err := newNativeType("cel", reflect.TypeOf(TestStructWithMultipleSameNames{})) if err == nil { t.Fatal("newNativeType() did not fail as expected") } @@ -803,6 +882,64 @@ func TestNativeStructWithMultileSameFieldNames(t *testing.T) { } } +func TestNativeStructEmbedded(t *testing.T) { + var nativeTests = []struct { + expr string + in any + }{ + { + expr: `test.embedded.custom_name == "name"`, + in: map[string]any{ + "test": &TestEmbeddedTypes{TestNestedType{NestedCustomName: "name"}}, + }, + }, + } + + envOpts := []cel.EnvOption{ + NativeTypes( + reflect.TypeOf(&TestEmbeddedTypes{}), + reflect.TypeOf(&TestNestedType{}), + ParseStructTag("json"), + ), + cel.Variable("test", cel.ObjectType("ext.TestEmbeddedTypes")), + } + + env, err := cel.NewEnv(envOpts...) + if err != nil { + t.Fatalf("cel.NewEnv(NativeTypes()) failed: %v", err) + } + + for i, tst := range nativeTests { + tc := tst + t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + var asts []*cel.Ast + pAst, iss := env.Parse(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, pAst) + cAst, iss := env.Check(pAst) + if iss.Err() != nil { + t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, cAst) + for _, ast := range asts { + prg, err := env.Program(ast) + if err != nil { + t.Fatal(err) + } + out, _, err := prg.Eval(tc.in) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(out.Value(), true) { + t.Errorf("got %v, wanted true for expr: %s", out.Value(), tc.expr) + } + } + }) + } +} + // testEnv initializes the test environment common to all tests. func testNativeEnv(t *testing.T, opts ...any) *cel.Env { t.Helper() @@ -855,13 +992,13 @@ type TestStructWithMultipleSameNames struct { type TestNestedType struct { NestedListVal []string NestedMapVal map[int64]bool - NestedCustomName string `cel:"custom_name"` + NestedCustomName string `cel:"custom_name" json:"custom_name"` } type TestAllTypes struct { - NestedVal *TestNestedType - NestedStructVal TestNestedType - BoolVal bool + NestedVal *TestNestedType `json:"nestedVal,omitempty"` + NestedStructVal TestNestedType `json:"nestedStructVal"` + BoolVal bool `json:"boolVal"` BytesVal []byte DurationVal time.Duration DoubleVal float64 @@ -897,3 +1034,7 @@ type TestNestedSliceType struct { type TestMapVal struct { Value string } + +type TestEmbeddedTypes struct { + TestNestedType `json:"embedded,omitempty"` +}