Skip to content

Commit

Permalink
Update tag-based parsing to use lambda for additional customization (#…
Browse files Browse the repository at this point in the history
…1039)

This allows callers to override a bit more about struct field parsing.
In particular, it's useful when parsing tags to perform escaping of tags
that aren't supported CEL field names.
This allows implementation of something similar to what Kuberentes does
with CEL field names:
https://kubernetes.io/docs/reference/using-api/cel/#escaping.
  • Loading branch information
matthchr authored Oct 15, 2024
1 parent c936b8b commit 2a010f9
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 44 deletions.
97 changes: 58 additions & 39 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,20 +128,46 @@ func NativeTypes(args ...any) cel.EnvOption {
// NativeTypesOption is a functional interface for configuring handling of native types.
type NativeTypesOption func(*nativeTypeOptions) error

// NativeTypesFieldNameHandler is a handler for mapping a reflect.StructField to a CEL field name.
// This can be used to override the default Go struct field to CEL field name mapping.
type NativeTypesFieldNameHandler = func(field reflect.StructField) string

func fieldNameByTag(structTagToParse string) func(field reflect.StructField) string {
return func(field reflect.StructField) string {
tag, found := field.Tag.Lookup(structTagToParse)
if found {
splits := strings.Split(tag, ",")
if len(splits) > 0 {
// We make the assumption that the leftmost entry in the tag is the name.
// This seems to be true for most tags that have the concept of a name/key, such as:
// https://pkg.go.dev/encoding/xml#Marshal
// https://pkg.go.dev/encoding/json#Marshal
// https://pkg.go.dev/go.mongodb.org/mongo-driver/bson#hdr-Structs
// https://pkg.go.dev/gopkg.in/yaml.v2#Marshal
name := splits[0]
return name
}
}

return field.Name
}
}

type nativeTypeOptions struct {
// structTagToParse controls if CEL should support struct field renames, by parsing
// struct field tags. This must be set to the tag to parse, such as "cel" or "json".
structTagToParse string
// fieldNameHandler controls how CEL should perform struct field renames.
// This is most commonly used for switching to parsing based off the struct field tag,
// such as "cel" or "json".
fieldNameHandler NativeTypesFieldNameHandler
}

// ParseStructTags configures if native types field names should be overridable by CEL struct tags.
// This is equivalent to ParseStructTag("cel")
func ParseStructTags(enabled bool) NativeTypesOption {
return func(ntp *nativeTypeOptions) error {
if enabled {
ntp.structTagToParse = "cel"
ntp.fieldNameHandler = fieldNameByTag("cel")
} else {
ntp.structTagToParse = ""
ntp.fieldNameHandler = nil
}
return nil
}
Expand All @@ -153,7 +179,15 @@ func ParseStructTags(enabled bool) NativeTypesOption {
// If the tag to parse is "json" and the struct field has tag json:"foo,omitempty", the CEL struct field will be "foo".
func ParseStructTag(tag string) NativeTypesOption {
return func(ntp *nativeTypeOptions) error {
ntp.structTagToParse = tag
ntp.fieldNameHandler = fieldNameByTag(tag)
return nil
}
}

// ParseStructField configures how to parse Go struct fields. It can be used to customize struct field parsing.
func ParseStructField(handler NativeTypesFieldNameHandler) NativeTypesOption {
return func(ntp *nativeTypeOptions) error {
ntp.fieldNameHandler = handler
return nil
}
}
Expand All @@ -163,15 +197,15 @@ func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, p
for _, refType := range refTypes {
switch rt := refType.(type) {
case reflect.Type:
result, err := newNativeTypes(tpOptions.structTagToParse, rt)
result, err := newNativeTypes(tpOptions.fieldNameHandler, rt)
if err != nil {
return nil, err
}
for idx := range result {
nativeTypes[result[idx].TypeName()] = result[idx]
}
case reflect.Value:
result, err := newNativeTypes(tpOptions.structTagToParse, rt.Type())
result, err := newNativeTypes(tpOptions.fieldNameHandler, rt.Type())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -224,27 +258,12 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool
return tp.baseProvider.FindStructType(typeName)
}

func toFieldName(structTagToParse string, f reflect.StructField) string {
if structTagToParse == "" {
func toFieldName(fieldNameHandler NativeTypesFieldNameHandler, f reflect.StructField) string {
if fieldNameHandler == nil {
return f.Name
}

tag, found := f.Tag.Lookup(structTagToParse)
if found {
splits := strings.Split(tag, ",")
if len(splits) > 0 {
// We make the assumption that the leftmost entry in the tag is the name.
// This seems to be true for most tags that have the concept of a name/key, such as:
// https://pkg.go.dev/encoding/xml#Marshal
// https://pkg.go.dev/encoding/json#Marshal
// https://pkg.go.dev/go.mongodb.org/mongo-driver/bson#hdr-Structs
// https://pkg.go.dev/gopkg.in/yaml.v2#Marshal
name := splits[0]
return name
}
}

return f.Name
return fieldNameHandler(f)
}

// FindStructFieldNames looks up the type definition first from the native types, then from
Expand All @@ -255,7 +274,7 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
fieldCount := t.refType.NumField()
fields := make([]string, fieldCount)
for i := 0; i < fieldCount; i++ {
fields[i] = toFieldName(tp.options.structTagToParse, t.refType.Field(i))
fields[i] = toFieldName(tp.options.fieldNameHandler, t.refType.Field(i))
}
return fields, true
}
Expand Down Expand Up @@ -415,7 +434,7 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) {
}

func (tp *nativeTypeProvider) newNativeObject(val any, refValue reflect.Value) ref.Val {
valType, err := newNativeType(tp.options.structTagToParse, refValue.Type())
valType, err := newNativeType(tp.options.fieldNameHandler, refValue.Type())
if err != nil {
return types.NewErr(err.Error())
}
Expand Down Expand Up @@ -467,7 +486,7 @@ func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
if !fieldValue.IsValid() || fieldValue.IsZero() {
continue
}
fieldName := toFieldName(o.valType.structTagToParse, fieldType)
fieldName := toFieldName(o.valType.fieldNameHandler, fieldType)
fieldCELVal := o.NativeToValue(fieldValue.Interface())
fieldJSONVal, err := fieldCELVal.ConvertToNative(jsonValueType)
if err != nil {
Expand Down Expand Up @@ -565,8 +584,8 @@ func (o *nativeObj) Value() any {
return o.val
}

func newNativeTypes(structTagToParse string, rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(structTagToParse, rawType)
func newNativeTypes(fieldNameHandler NativeTypesFieldNameHandler, rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(fieldNameHandler, rawType)
if err != nil {
return nil, err
}
Expand All @@ -585,7 +604,7 @@ func newNativeTypes(structTagToParse string, rawType reflect.Type) ([]*nativeTyp
return
}
alreadySeen[t.String()] = struct{}{}
nt, ntErr := newNativeType(structTagToParse, t)
nt, ntErr := newNativeType(fieldNameHandler, t)
if ntErr != nil {
err = ntErr
return
Expand All @@ -605,7 +624,7 @@ var (
errDuplicatedFieldName = errors.New("field name already exists in struct")
)

func newNativeType(structTagToParse string, rawType reflect.Type) (*nativeType, error) {
func newNativeType(fieldNameHandler NativeTypesFieldNameHandler, rawType reflect.Type) (*nativeType, error) {
refType := rawType
if refType.Kind() == reflect.Pointer {
refType = refType.Elem()
Expand All @@ -615,12 +634,12 @@ func newNativeType(structTagToParse string, rawType reflect.Type) (*nativeType,
}

// Since naming collisions can only happen with struct tag parsing, we only check for them if it is enabled.
if structTagToParse != "" {
if fieldNameHandler != nil {
fieldNames := make(map[string]struct{})

for idx := 0; idx < refType.NumField(); idx++ {
field := refType.Field(idx)
fieldName := toFieldName(structTagToParse, field)
fieldName := toFieldName(fieldNameHandler, field)

if _, found := fieldNames[fieldName]; found {
return nil, fmt.Errorf("invalid field name `%s` in struct `%s`: %w", fieldName, refType.Name(), errDuplicatedFieldName)
Expand All @@ -633,14 +652,14 @@ func newNativeType(structTagToParse string, rawType reflect.Type) (*nativeType,
return &nativeType{
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
structTagToParse: structTagToParse,
fieldNameHandler: fieldNameHandler,
}, nil
}

type nativeType struct {
typeName string
refType reflect.Type
structTagToParse string
fieldNameHandler NativeTypesFieldNameHandler
}

// ConvertToNative implements ref.Val.ConvertToNative.
Expand Down Expand Up @@ -691,13 +710,13 @@ func (t *nativeType) Value() any {
// fieldByName returns the corresponding reflect.StructField for the give name either by matching
// field tag or field name.
func (t *nativeType) fieldByName(fieldName string) (reflect.StructField, bool) {
if t.structTagToParse == "" {
if t.fieldNameHandler == nil {
return t.refType.FieldByName(fieldName)
}

for i := 0; i < t.refType.NumField(); i++ {
f := t.refType.Field(i)
if toFieldName(t.structTagToParse, f) == fieldName {
if toFieldName(t.fieldNameHandler, f) == fieldName {
return f, true
}
}
Expand Down
11 changes: 6 additions & 5 deletions ext/native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,8 @@ func TestNativeTypeConvertToType(t *testing.T) {
for i, tst := range nativeTests {
tc := tst
t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) {
nt, err := newNativeType(tc.tag, reflect.TypeOf(&TestAllTypes{}))
handler := fieldNameByTag(tc.tag)
nt, err := newNativeType(handler, reflect.TypeOf(&TestAllTypes{}))
if err != nil {
t.Fatalf("newNativeType() failed: %v", err)
}
Expand All @@ -842,7 +843,7 @@ func TestNativeTypeConvertToType(t *testing.T) {
}

func TestNativeTypeConvertToNative(t *testing.T) {
nt, err := newNativeType("cel", reflect.TypeOf(&TestAllTypes{}))
nt, err := newNativeType(fieldNameByTag("cel"), reflect.TypeOf(&TestAllTypes{}))
if err != nil {
t.Fatalf("newNativeType() failed: %v", err)
}
Expand All @@ -853,7 +854,7 @@ func TestNativeTypeConvertToNative(t *testing.T) {
}

func TestNativeTypeHasTrait(t *testing.T) {
nt, err := newNativeType("cel", reflect.TypeOf(&TestAllTypes{}))
nt, err := newNativeType(fieldNameByTag("cel"), reflect.TypeOf(&TestAllTypes{}))
if err != nil {
t.Fatalf("newNativeType() failed: %v", err)
}
Expand All @@ -863,7 +864,7 @@ func TestNativeTypeHasTrait(t *testing.T) {
}

func TestNativeTypeValue(t *testing.T) {
nt, err := newNativeType("cel", reflect.TypeOf(&TestAllTypes{}))
nt, err := newNativeType(fieldNameByTag("cel"), reflect.TypeOf(&TestAllTypes{}))
if err != nil {
t.Fatalf("newNativeType() failed: %v", err)
}
Expand All @@ -873,7 +874,7 @@ func TestNativeTypeValue(t *testing.T) {
}

func TestNativeStructWithMultipleSameFieldNames(t *testing.T) {
_, err := newNativeType("cel", reflect.TypeOf(TestStructWithMultipleSameNames{}))
_, err := newNativeType(fieldNameByTag("cel"), reflect.TypeOf(TestStructWithMultipleSameNames{}))
if err == nil {
t.Fatal("newNativeType() did not fail as expected")
}
Expand Down

0 comments on commit 2a010f9

Please sign in to comment.