Skip to content

Commit d561d0a

Browse files
authored
Allow configurable tag name overrides in native types (#1009)
Allowing tags other than "cel" enables more flexibility, particularly when the types in question aren't owned by those configuring the CEL environment. The main use-case for this is enabling an experience in CEL that matches the shape of an object in JSON (or YAML). This includes: - Property casing: 'Property' in Go would likely be 'property' in JSON/YAML. - Embedded structures: metav1.ObjectType `json:"metadata,omitempty"` is common on Kuberentes resources, allowing users to request CEL use the JSON tags enables access via the ".metadata" property even though there is no such property on the actual Native type.
1 parent 1fa4f15 commit d561d0a

File tree

2 files changed

+215
-63
lines changed

2 files changed

+215
-63
lines changed

ext/native.go

Lines changed: 56 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,31 @@ func NativeTypes(args ...any) cel.EnvOption {
129129
type NativeTypesOption func(*nativeTypeOptions) error
130130

131131
type nativeTypeOptions struct {
132-
// parseStructTags controls if CEL should support struct field renames, by parsing
133-
// struct field tags.
134-
parseStructTags bool
132+
// structTagToParse controls if CEL should support struct field renames, by parsing
133+
// struct field tags. This must be set to the tag to parse, such as "cel" or "json".
134+
structTagToParse string
135135
}
136136

137137
// ParseStructTags configures if native types field names should be overridable by CEL struct tags.
138+
// This is equivalent to ParseStructTag("cel")
138139
func ParseStructTags(enabled bool) NativeTypesOption {
139140
return func(ntp *nativeTypeOptions) error {
140-
ntp.parseStructTags = true
141+
if enabled {
142+
ntp.structTagToParse = "cel"
143+
} else {
144+
ntp.structTagToParse = ""
145+
}
146+
return nil
147+
}
148+
}
149+
150+
// ParseStructTag configures the struct tag to parse. The 0th item in the tag is used as the name of the CEL field.
151+
// For example:
152+
// If the tag to parse is "cel" and the struct field has tag cel:"foo", the CEL struct field will be "foo".
153+
// If the tag to parse is "json" and the struct field has tag json:"foo,omitempty", the CEL struct field will be "foo".
154+
func ParseStructTag(tag string) NativeTypesOption {
155+
return func(ntp *nativeTypeOptions) error {
156+
ntp.structTagToParse = tag
141157
return nil
142158
}
143159
}
@@ -147,15 +163,15 @@ func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, p
147163
for _, refType := range refTypes {
148164
switch rt := refType.(type) {
149165
case reflect.Type:
150-
result, err := newNativeTypes(tpOptions.parseStructTags, rt)
166+
result, err := newNativeTypes(tpOptions.structTagToParse, rt)
151167
if err != nil {
152168
return nil, err
153169
}
154170
for idx := range result {
155171
nativeTypes[result[idx].TypeName()] = result[idx]
156172
}
157173
case reflect.Value:
158-
result, err := newNativeTypes(tpOptions.parseStructTags, rt.Type())
174+
result, err := newNativeTypes(tpOptions.structTagToParse, rt.Type())
159175
if err != nil {
160176
return nil, err
161177
}
@@ -208,13 +224,24 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool
208224
return tp.baseProvider.FindStructType(typeName)
209225
}
210226

211-
func toFieldName(parseStructTag bool, f reflect.StructField) string {
212-
if !parseStructTag {
227+
func toFieldName(structTagToParse string, f reflect.StructField) string {
228+
if structTagToParse == "" {
213229
return f.Name
214230
}
215231

216-
if name, found := f.Tag.Lookup("cel"); found {
217-
return name
232+
tag, found := f.Tag.Lookup(structTagToParse)
233+
if found {
234+
splits := strings.Split(tag, ",")
235+
if len(splits) > 0 {
236+
// We make the assumption that the leftmost entry in the tag is the name.
237+
// This seems to be true for most tags that have the concept of a name/key, such as:
238+
// https://pkg.go.dev/encoding/xml#Marshal
239+
// https://pkg.go.dev/encoding/json#Marshal
240+
// https://pkg.go.dev/go.mongodb.org/mongo-driver/bson#hdr-Structs
241+
// https://pkg.go.dev/gopkg.in/yaml.v2#Marshal
242+
name := splits[0]
243+
return name
244+
}
218245
}
219246

220247
return f.Name
@@ -228,7 +255,7 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
228255
fieldCount := t.refType.NumField()
229256
fields := make([]string, fieldCount)
230257
for i := 0; i < fieldCount; i++ {
231-
fields[i] = toFieldName(tp.options.parseStructTags, t.refType.Field(i))
258+
fields[i] = toFieldName(tp.options.structTagToParse, t.refType.Field(i))
232259
}
233260
return fields, true
234261
}
@@ -238,22 +265,6 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
238265
return tp.baseProvider.FindStructFieldNames(typeName)
239266
}
240267

241-
// valueFieldByName retrieves the corresponding reflect.Value field for the given field name, by
242-
// searching for a matching field tag value or field name.
243-
func valueFieldByName(parseStructTags bool, target reflect.Value, fieldName string) reflect.Value {
244-
if !parseStructTags {
245-
return target.FieldByName(fieldName)
246-
}
247-
248-
for i := 0; i < target.Type().NumField(); i++ {
249-
f := target.Type().Field(i)
250-
if toFieldName(parseStructTags, f) == fieldName {
251-
return target.FieldByIndex(f.Index)
252-
}
253-
}
254-
return reflect.Value{}
255-
}
256-
257268
// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
258269
// type then proxies to the composed types.Provider
259270
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
@@ -273,12 +284,12 @@ func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*
273284
Type: celType,
274285
IsSet: func(obj any) bool {
275286
refVal := reflect.Indirect(reflect.ValueOf(obj))
276-
refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName)
287+
refField := refVal.FieldByName(refField.Name)
277288
return !refField.IsZero()
278289
},
279290
GetFrom: func(obj any) (any, error) {
280291
refVal := reflect.Indirect(reflect.ValueOf(obj))
281-
refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName)
292+
refField := refVal.FieldByName(refField.Name)
282293
return getFieldValue(refField), nil
283294
},
284295
}, true
@@ -404,7 +415,7 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) {
404415
}
405416

406417
func (tp *nativeTypeProvider) newNativeObject(val any, refValue reflect.Value) ref.Val {
407-
valType, err := newNativeType(tp.options.parseStructTags, refValue.Type())
418+
valType, err := newNativeType(tp.options.structTagToParse, refValue.Type())
408419
if err != nil {
409420
return types.NewErr(err.Error())
410421
}
@@ -456,7 +467,7 @@ func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
456467
if !fieldValue.IsValid() || fieldValue.IsZero() {
457468
continue
458469
}
459-
fieldName := toFieldName(o.valType.parseStructTags, fieldType)
470+
fieldName := toFieldName(o.valType.structTagToParse, fieldType)
460471
fieldCELVal := o.NativeToValue(fieldValue.Interface())
461472
fieldJSONVal, err := fieldCELVal.ConvertToNative(jsonValueType)
462473
if err != nil {
@@ -554,8 +565,8 @@ func (o *nativeObj) Value() any {
554565
return o.val
555566
}
556567

557-
func newNativeTypes(parseStructTags bool, rawType reflect.Type) ([]*nativeType, error) {
558-
nt, err := newNativeType(parseStructTags, rawType)
568+
func newNativeTypes(structTagToParse string, rawType reflect.Type) ([]*nativeType, error) {
569+
nt, err := newNativeType(structTagToParse, rawType)
559570
if err != nil {
560571
return nil, err
561572
}
@@ -574,7 +585,7 @@ func newNativeTypes(parseStructTags bool, rawType reflect.Type) ([]*nativeType,
574585
return
575586
}
576587
alreadySeen[t.String()] = struct{}{}
577-
nt, ntErr := newNativeType(parseStructTags, t)
588+
nt, ntErr := newNativeType(structTagToParse, t)
578589
if ntErr != nil {
579590
err = ntErr
580591
return
@@ -594,7 +605,7 @@ var (
594605
errDuplicatedFieldName = errors.New("field name already exists in struct")
595606
)
596607

597-
func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, error) {
608+
func newNativeType(structTagToParse string, rawType reflect.Type) (*nativeType, error) {
598609
refType := rawType
599610
if refType.Kind() == reflect.Pointer {
600611
refType = refType.Elem()
@@ -604,12 +615,12 @@ func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, err
604615
}
605616

606617
// Since naming collisions can only happen with struct tag parsing, we only check for them if it is enabled.
607-
if parseStructTags {
618+
if structTagToParse != "" {
608619
fieldNames := make(map[string]struct{})
609620

610621
for idx := 0; idx < refType.NumField(); idx++ {
611622
field := refType.Field(idx)
612-
fieldName := toFieldName(parseStructTags, field)
623+
fieldName := toFieldName(structTagToParse, field)
613624

614625
if _, found := fieldNames[fieldName]; found {
615626
return nil, fmt.Errorf("invalid field name `%s` in struct `%s`: %w", fieldName, refType.Name(), errDuplicatedFieldName)
@@ -620,16 +631,16 @@ func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, err
620631
}
621632

622633
return &nativeType{
623-
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
624-
refType: refType,
625-
parseStructTags: parseStructTags,
634+
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
635+
refType: refType,
636+
structTagToParse: structTagToParse,
626637
}, nil
627638
}
628639

629640
type nativeType struct {
630-
typeName string
631-
refType reflect.Type
632-
parseStructTags bool
641+
typeName string
642+
refType reflect.Type
643+
structTagToParse string
633644
}
634645

635646
// ConvertToNative implements ref.Val.ConvertToNative.
@@ -680,13 +691,13 @@ func (t *nativeType) Value() any {
680691
// fieldByName returns the corresponding reflect.StructField for the give name either by matching
681692
// field tag or field name.
682693
func (t *nativeType) fieldByName(fieldName string) (reflect.StructField, bool) {
683-
if !t.parseStructTags {
694+
if t.structTagToParse == "" {
684695
return t.refType.FieldByName(fieldName)
685696
}
686697

687698
for i := 0; i < t.refType.NumField(); i++ {
688699
f := t.refType.Field(i)
689-
if toFieldName(t.parseStructTags, f) == fieldName {
700+
if toFieldName(t.structTagToParse, f) == fieldName {
690701
return f, true
691702
}
692703
}

0 commit comments

Comments
 (0)