Skip to content

Commit

Permalink
Allow configurable tag name overrides in native types (#1009)
Browse files Browse the repository at this point in the history
Allowing tags other than "cel" enables more flexibility, particularly
when the types in question aren't owned by those configuring the CEL
environment.

The main use-case for this is enabling an experience in CEL that matches
the shape of an object in JSON (or YAML). This includes:
 - Property casing: 'Property' in Go would likely be 'property' in
   JSON/YAML.
 - Embedded structures: metav1.ObjectType `json:"metadata,omitempty"`
   is common on Kuberentes resources, allowing users to request CEL use
   the JSON tags enables access via the ".metadata" property even though
   there is no such property on the actual Native type.
  • Loading branch information
matthchr authored Sep 9, 2024
1 parent 1fa4f15 commit d561d0a
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 63 deletions.
101 changes: 56 additions & 45 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,31 @@ func NativeTypes(args ...any) cel.EnvOption {
type NativeTypesOption func(*nativeTypeOptions) error

type nativeTypeOptions struct {
// parseStructTags controls if CEL should support struct field renames, by parsing
// struct field tags.
parseStructTags bool
// structTagToParse controls if CEL should support struct field renames, by parsing
// struct field tags. This must be set to the tag to parse, such as "cel" or "json".
structTagToParse string
}

// ParseStructTags configures if native types field names should be overridable by CEL struct tags.
// This is equivalent to ParseStructTag("cel")
func ParseStructTags(enabled bool) NativeTypesOption {
return func(ntp *nativeTypeOptions) error {
ntp.parseStructTags = true
if enabled {
ntp.structTagToParse = "cel"
} else {
ntp.structTagToParse = ""
}
return nil
}
}

// ParseStructTag configures the struct tag to parse. The 0th item in the tag is used as the name of the CEL field.
// For example:
// If the tag to parse is "cel" and the struct field has tag cel:"foo", the CEL struct field will be "foo".
// If the tag to parse is "json" and the struct field has tag json:"foo,omitempty", the CEL struct field will be "foo".
func ParseStructTag(tag string) NativeTypesOption {
return func(ntp *nativeTypeOptions) error {
ntp.structTagToParse = tag
return nil
}
}
Expand All @@ -147,15 +163,15 @@ func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, p
for _, refType := range refTypes {
switch rt := refType.(type) {
case reflect.Type:
result, err := newNativeTypes(tpOptions.parseStructTags, rt)
result, err := newNativeTypes(tpOptions.structTagToParse, rt)
if err != nil {
return nil, err
}
for idx := range result {
nativeTypes[result[idx].TypeName()] = result[idx]
}
case reflect.Value:
result, err := newNativeTypes(tpOptions.parseStructTags, rt.Type())
result, err := newNativeTypes(tpOptions.structTagToParse, rt.Type())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -208,13 +224,24 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool
return tp.baseProvider.FindStructType(typeName)
}

func toFieldName(parseStructTag bool, f reflect.StructField) string {
if !parseStructTag {
func toFieldName(structTagToParse string, f reflect.StructField) string {
if structTagToParse == "" {
return f.Name
}

if name, found := f.Tag.Lookup("cel"); found {
return name
tag, found := f.Tag.Lookup(structTagToParse)
if found {
splits := strings.Split(tag, ",")
if len(splits) > 0 {
// We make the assumption that the leftmost entry in the tag is the name.
// This seems to be true for most tags that have the concept of a name/key, such as:
// https://pkg.go.dev/encoding/xml#Marshal
// https://pkg.go.dev/encoding/json#Marshal
// https://pkg.go.dev/go.mongodb.org/mongo-driver/bson#hdr-Structs
// https://pkg.go.dev/gopkg.in/yaml.v2#Marshal
name := splits[0]
return name
}
}

return f.Name
Expand All @@ -228,7 +255,7 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
fieldCount := t.refType.NumField()
fields := make([]string, fieldCount)
for i := 0; i < fieldCount; i++ {
fields[i] = toFieldName(tp.options.parseStructTags, t.refType.Field(i))
fields[i] = toFieldName(tp.options.structTagToParse, t.refType.Field(i))
}
return fields, true
}
Expand All @@ -238,22 +265,6 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
return tp.baseProvider.FindStructFieldNames(typeName)
}

// valueFieldByName retrieves the corresponding reflect.Value field for the given field name, by
// searching for a matching field tag value or field name.
func valueFieldByName(parseStructTags bool, target reflect.Value, fieldName string) reflect.Value {
if !parseStructTags {
return target.FieldByName(fieldName)
}

for i := 0; i < target.Type().NumField(); i++ {
f := target.Type().Field(i)
if toFieldName(parseStructTags, f) == fieldName {
return target.FieldByIndex(f.Index)
}
}
return reflect.Value{}
}

// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed types.Provider
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
Expand All @@ -273,12 +284,12 @@ func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*
Type: celType,
IsSet: func(obj any) bool {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName)
refField := refVal.FieldByName(refField.Name)
return !refField.IsZero()
},
GetFrom: func(obj any) (any, error) {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName)
refField := refVal.FieldByName(refField.Name)
return getFieldValue(refField), nil
},
}, true
Expand Down Expand Up @@ -404,7 +415,7 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) {
}

func (tp *nativeTypeProvider) newNativeObject(val any, refValue reflect.Value) ref.Val {
valType, err := newNativeType(tp.options.parseStructTags, refValue.Type())
valType, err := newNativeType(tp.options.structTagToParse, refValue.Type())
if err != nil {
return types.NewErr(err.Error())
}
Expand Down Expand Up @@ -456,7 +467,7 @@ func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
if !fieldValue.IsValid() || fieldValue.IsZero() {
continue
}
fieldName := toFieldName(o.valType.parseStructTags, fieldType)
fieldName := toFieldName(o.valType.structTagToParse, fieldType)
fieldCELVal := o.NativeToValue(fieldValue.Interface())
fieldJSONVal, err := fieldCELVal.ConvertToNative(jsonValueType)
if err != nil {
Expand Down Expand Up @@ -554,8 +565,8 @@ func (o *nativeObj) Value() any {
return o.val
}

func newNativeTypes(parseStructTags bool, rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(parseStructTags, rawType)
func newNativeTypes(structTagToParse string, rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(structTagToParse, rawType)
if err != nil {
return nil, err
}
Expand All @@ -574,7 +585,7 @@ func newNativeTypes(parseStructTags bool, rawType reflect.Type) ([]*nativeType,
return
}
alreadySeen[t.String()] = struct{}{}
nt, ntErr := newNativeType(parseStructTags, t)
nt, ntErr := newNativeType(structTagToParse, t)
if ntErr != nil {
err = ntErr
return
Expand All @@ -594,7 +605,7 @@ var (
errDuplicatedFieldName = errors.New("field name already exists in struct")
)

func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, error) {
func newNativeType(structTagToParse string, rawType reflect.Type) (*nativeType, error) {
refType := rawType
if refType.Kind() == reflect.Pointer {
refType = refType.Elem()
Expand All @@ -604,12 +615,12 @@ func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, err
}

// Since naming collisions can only happen with struct tag parsing, we only check for them if it is enabled.
if parseStructTags {
if structTagToParse != "" {
fieldNames := make(map[string]struct{})

for idx := 0; idx < refType.NumField(); idx++ {
field := refType.Field(idx)
fieldName := toFieldName(parseStructTags, field)
fieldName := toFieldName(structTagToParse, field)

if _, found := fieldNames[fieldName]; found {
return nil, fmt.Errorf("invalid field name `%s` in struct `%s`: %w", fieldName, refType.Name(), errDuplicatedFieldName)
Expand All @@ -620,16 +631,16 @@ func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, err
}

return &nativeType{
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
parseStructTags: parseStructTags,
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
structTagToParse: structTagToParse,
}, nil
}

type nativeType struct {
typeName string
refType reflect.Type
parseStructTags bool
typeName string
refType reflect.Type
structTagToParse string
}

// ConvertToNative implements ref.Val.ConvertToNative.
Expand Down Expand Up @@ -680,13 +691,13 @@ func (t *nativeType) Value() any {
// fieldByName returns the corresponding reflect.StructField for the give name either by matching
// field tag or field name.
func (t *nativeType) fieldByName(fieldName string) (reflect.StructField, bool) {
if !t.parseStructTags {
if t.structTagToParse == "" {
return t.refType.FieldByName(fieldName)
}

for i := 0; i < t.refType.NumField(); i++ {
f := t.refType.Field(i)
if toFieldName(t.parseStructTags, f) == fieldName {
if toFieldName(t.structTagToParse, f) == fieldName {
return f, true
}
}
Expand Down
Loading

0 comments on commit d561d0a

Please sign in to comment.