@@ -129,15 +129,31 @@ func NativeTypes(args ...any) cel.EnvOption {
129
129
type NativeTypesOption func (* nativeTypeOptions ) error
130
130
131
131
type nativeTypeOptions struct {
132
- // parseStructTags controls if CEL should support struct field renames, by parsing
133
- // struct field tags.
134
- parseStructTags bool
132
+ // structTagToParse controls if CEL should support struct field renames, by parsing
133
+ // struct field tags. This must be set to the tag to parse, such as "cel" or "json".
134
+ structTagToParse string
135
135
}
136
136
137
137
// ParseStructTags configures if native types field names should be overridable by CEL struct tags.
138
+ // This is equivalent to ParseStructTag("cel")
138
139
func ParseStructTags (enabled bool ) NativeTypesOption {
139
140
return func (ntp * nativeTypeOptions ) error {
140
- ntp .parseStructTags = true
141
+ if enabled {
142
+ ntp .structTagToParse = "cel"
143
+ } else {
144
+ ntp .structTagToParse = ""
145
+ }
146
+ return nil
147
+ }
148
+ }
149
+
150
+ // ParseStructTag configures the struct tag to parse. The 0th item in the tag is used as the name of the CEL field.
151
+ // For example:
152
+ // If the tag to parse is "cel" and the struct field has tag cel:"foo", the CEL struct field will be "foo".
153
+ // If the tag to parse is "json" and the struct field has tag json:"foo,omitempty", the CEL struct field will be "foo".
154
+ func ParseStructTag (tag string ) NativeTypesOption {
155
+ return func (ntp * nativeTypeOptions ) error {
156
+ ntp .structTagToParse = tag
141
157
return nil
142
158
}
143
159
}
@@ -147,15 +163,15 @@ func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, p
147
163
for _ , refType := range refTypes {
148
164
switch rt := refType .(type ) {
149
165
case reflect.Type :
150
- result , err := newNativeTypes (tpOptions .parseStructTags , rt )
166
+ result , err := newNativeTypes (tpOptions .structTagToParse , rt )
151
167
if err != nil {
152
168
return nil , err
153
169
}
154
170
for idx := range result {
155
171
nativeTypes [result [idx ].TypeName ()] = result [idx ]
156
172
}
157
173
case reflect.Value :
158
- result , err := newNativeTypes (tpOptions .parseStructTags , rt .Type ())
174
+ result , err := newNativeTypes (tpOptions .structTagToParse , rt .Type ())
159
175
if err != nil {
160
176
return nil , err
161
177
}
@@ -208,13 +224,24 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool
208
224
return tp .baseProvider .FindStructType (typeName )
209
225
}
210
226
211
- func toFieldName (parseStructTag bool , f reflect.StructField ) string {
212
- if ! parseStructTag {
227
+ func toFieldName (structTagToParse string , f reflect.StructField ) string {
228
+ if structTagToParse == "" {
213
229
return f .Name
214
230
}
215
231
216
- if name , found := f .Tag .Lookup ("cel" ); found {
217
- return name
232
+ tag , found := f .Tag .Lookup (structTagToParse )
233
+ if found {
234
+ splits := strings .Split (tag , "," )
235
+ if len (splits ) > 0 {
236
+ // We make the assumption that the leftmost entry in the tag is the name.
237
+ // This seems to be true for most tags that have the concept of a name/key, such as:
238
+ // https://pkg.go.dev/encoding/xml#Marshal
239
+ // https://pkg.go.dev/encoding/json#Marshal
240
+ // https://pkg.go.dev/go.mongodb.org/mongo-driver/bson#hdr-Structs
241
+ // https://pkg.go.dev/gopkg.in/yaml.v2#Marshal
242
+ name := splits [0 ]
243
+ return name
244
+ }
218
245
}
219
246
220
247
return f .Name
@@ -228,7 +255,7 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
228
255
fieldCount := t .refType .NumField ()
229
256
fields := make ([]string , fieldCount )
230
257
for i := 0 ; i < fieldCount ; i ++ {
231
- fields [i ] = toFieldName (tp .options .parseStructTags , t .refType .Field (i ))
258
+ fields [i ] = toFieldName (tp .options .structTagToParse , t .refType .Field (i ))
232
259
}
233
260
return fields , true
234
261
}
@@ -238,22 +265,6 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
238
265
return tp .baseProvider .FindStructFieldNames (typeName )
239
266
}
240
267
241
- // valueFieldByName retrieves the corresponding reflect.Value field for the given field name, by
242
- // searching for a matching field tag value or field name.
243
- func valueFieldByName (parseStructTags bool , target reflect.Value , fieldName string ) reflect.Value {
244
- if ! parseStructTags {
245
- return target .FieldByName (fieldName )
246
- }
247
-
248
- for i := 0 ; i < target .Type ().NumField (); i ++ {
249
- f := target .Type ().Field (i )
250
- if toFieldName (parseStructTags , f ) == fieldName {
251
- return target .FieldByIndex (f .Index )
252
- }
253
- }
254
- return reflect.Value {}
255
- }
256
-
257
268
// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
258
269
// type then proxies to the composed types.Provider
259
270
func (tp * nativeTypeProvider ) FindStructFieldType (typeName , fieldName string ) (* types.FieldType , bool ) {
@@ -273,12 +284,12 @@ func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*
273
284
Type : celType ,
274
285
IsSet : func (obj any ) bool {
275
286
refVal := reflect .Indirect (reflect .ValueOf (obj ))
276
- refField := valueFieldByName ( tp . options . parseStructTags , refVal , fieldName )
287
+ refField := refVal . FieldByName ( refField . Name )
277
288
return ! refField .IsZero ()
278
289
},
279
290
GetFrom : func (obj any ) (any , error ) {
280
291
refVal := reflect .Indirect (reflect .ValueOf (obj ))
281
- refField := valueFieldByName ( tp . options . parseStructTags , refVal , fieldName )
292
+ refField := refVal . FieldByName ( refField . Name )
282
293
return getFieldValue (refField ), nil
283
294
},
284
295
}, true
@@ -404,7 +415,7 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) {
404
415
}
405
416
406
417
func (tp * nativeTypeProvider ) newNativeObject (val any , refValue reflect.Value ) ref.Val {
407
- valType , err := newNativeType (tp .options .parseStructTags , refValue .Type ())
418
+ valType , err := newNativeType (tp .options .structTagToParse , refValue .Type ())
408
419
if err != nil {
409
420
return types .NewErr (err .Error ())
410
421
}
@@ -456,7 +467,7 @@ func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
456
467
if ! fieldValue .IsValid () || fieldValue .IsZero () {
457
468
continue
458
469
}
459
- fieldName := toFieldName (o .valType .parseStructTags , fieldType )
470
+ fieldName := toFieldName (o .valType .structTagToParse , fieldType )
460
471
fieldCELVal := o .NativeToValue (fieldValue .Interface ())
461
472
fieldJSONVal , err := fieldCELVal .ConvertToNative (jsonValueType )
462
473
if err != nil {
@@ -554,8 +565,8 @@ func (o *nativeObj) Value() any {
554
565
return o .val
555
566
}
556
567
557
- func newNativeTypes (parseStructTags bool , rawType reflect.Type ) ([]* nativeType , error ) {
558
- nt , err := newNativeType (parseStructTags , rawType )
568
+ func newNativeTypes (structTagToParse string , rawType reflect.Type ) ([]* nativeType , error ) {
569
+ nt , err := newNativeType (structTagToParse , rawType )
559
570
if err != nil {
560
571
return nil , err
561
572
}
@@ -574,7 +585,7 @@ func newNativeTypes(parseStructTags bool, rawType reflect.Type) ([]*nativeType,
574
585
return
575
586
}
576
587
alreadySeen [t .String ()] = struct {}{}
577
- nt , ntErr := newNativeType (parseStructTags , t )
588
+ nt , ntErr := newNativeType (structTagToParse , t )
578
589
if ntErr != nil {
579
590
err = ntErr
580
591
return
@@ -594,7 +605,7 @@ var (
594
605
errDuplicatedFieldName = errors .New ("field name already exists in struct" )
595
606
)
596
607
597
- func newNativeType (parseStructTags bool , rawType reflect.Type ) (* nativeType , error ) {
608
+ func newNativeType (structTagToParse string , rawType reflect.Type ) (* nativeType , error ) {
598
609
refType := rawType
599
610
if refType .Kind () == reflect .Pointer {
600
611
refType = refType .Elem ()
@@ -604,12 +615,12 @@ func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, err
604
615
}
605
616
606
617
// Since naming collisions can only happen with struct tag parsing, we only check for them if it is enabled.
607
- if parseStructTags {
618
+ if structTagToParse != "" {
608
619
fieldNames := make (map [string ]struct {})
609
620
610
621
for idx := 0 ; idx < refType .NumField (); idx ++ {
611
622
field := refType .Field (idx )
612
- fieldName := toFieldName (parseStructTags , field )
623
+ fieldName := toFieldName (structTagToParse , field )
613
624
614
625
if _ , found := fieldNames [fieldName ]; found {
615
626
return nil , fmt .Errorf ("invalid field name `%s` in struct `%s`: %w" , fieldName , refType .Name (), errDuplicatedFieldName )
@@ -620,16 +631,16 @@ func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, err
620
631
}
621
632
622
633
return & nativeType {
623
- typeName : fmt .Sprintf ("%s.%s" , simplePkgAlias (refType .PkgPath ()), refType .Name ()),
624
- refType : refType ,
625
- parseStructTags : parseStructTags ,
634
+ typeName : fmt .Sprintf ("%s.%s" , simplePkgAlias (refType .PkgPath ()), refType .Name ()),
635
+ refType : refType ,
636
+ structTagToParse : structTagToParse ,
626
637
}, nil
627
638
}
628
639
629
640
type nativeType struct {
630
- typeName string
631
- refType reflect.Type
632
- parseStructTags bool
641
+ typeName string
642
+ refType reflect.Type
643
+ structTagToParse string
633
644
}
634
645
635
646
// ConvertToNative implements ref.Val.ConvertToNative.
@@ -680,13 +691,13 @@ func (t *nativeType) Value() any {
680
691
// fieldByName returns the corresponding reflect.StructField for the give name either by matching
681
692
// field tag or field name.
682
693
func (t * nativeType ) fieldByName (fieldName string ) (reflect.StructField , bool ) {
683
- if ! t . parseStructTags {
694
+ if t . structTagToParse == "" {
684
695
return t .refType .FieldByName (fieldName )
685
696
}
686
697
687
698
for i := 0 ; i < t .refType .NumField (); i ++ {
688
699
f := t .refType .Field (i )
689
- if toFieldName (t .parseStructTags , f ) == fieldName {
700
+ if toFieldName (t .structTagToParse , f ) == fieldName {
690
701
return f , true
691
702
}
692
703
}
0 commit comments