From 58833799d841169e0dd6f460c53441292552924f Mon Sep 17 00:00:00 2001 From: Alvaro Aleman Date: Thu, 1 Feb 2024 18:55:08 -0500 Subject: [PATCH] ext.NativeTypes: Recursively add sub-types (#892) This change extends the `NativeTypes` provider to not only add the passed-in type but also all of its sub-types in order to simplify using it in the context of nested structs. --- ext/native.go | 48 ++++++++++++++++++++++++++++++++++++++++++---- ext/native_test.go | 17 ++++++++++++++-- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/ext/native.go b/ext/native.go index 0c2cd52f..d1b78777 100644 --- a/ext/native.go +++ b/ext/native.go @@ -96,17 +96,21 @@ func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTy for _, refType := range refTypes { switch rt := refType.(type) { case reflect.Type: - t, err := newNativeType(rt) + result, err := newNativeTypes(rt) if err != nil { return nil, err } - nativeTypes[t.TypeName()] = t + for idx := range result { + nativeTypes[result[idx].TypeName()] = result[idx] + } case reflect.Value: - t, err := newNativeType(rt.Type()) + result, err := newNativeTypes(rt.Type()) if err != nil { return nil, err } - nativeTypes[t.TypeName()] = t + for idx := range result { + nativeTypes[result[idx].TypeName()] = result[idx] + } default: return nil, fmt.Errorf("unsupported native type: %v (%T) must be reflect.Type or reflect.Value", rt, rt) } @@ -465,6 +469,42 @@ func (o *nativeObj) Value() any { return o.val } +func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) { + nt, err := newNativeType(rawType) + if err != nil { + return nil, err + } + result := []*nativeType{nt} + + alreadySeen := make(map[string]struct{}) + var iterateStructMembers func(reflect.Type) + iterateStructMembers = func(t reflect.Type) { + if k := t.Kind(); k == reflect.Pointer || k == reflect.Slice || k == reflect.Array || k == reflect.Map { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return + } + if _, seen := alreadySeen[t.String()]; seen { + return + } + alreadySeen[t.String()] = struct{}{} + nt, ntErr := newNativeType(t) + if ntErr != nil { + err = ntErr + return + } + result = append(result, nt) + + for idx := 0; idx < t.NumField(); idx++ { + iterateStructMembers(t.Field(idx).Type) + } + } + iterateStructMembers(rawType) + + return result, err +} + func newNativeType(rawType reflect.Type) (*nativeType, error) { refType := rawType if refType.Kind() == reflect.Pointer { diff --git a/ext/native_test.go b/ext/native_test.go index ead7bc1c..2641ad72 100644 --- a/ext/native_test.go +++ b/ext/native_test.go @@ -63,6 +63,8 @@ func TestNativeTypes(t *testing.T) { }, ], MapVal: {'map-key': ext.TestAllTypes{BoolVal: true}}, + CustomSliceVal: [ext.TestNestedSliceType{Value: 'none'}], + CustomMapVal: {'even': ext.TestMapVal{Value: 'more'}}, }`, out: &TestAllTypes{ NestedVal: &TestNestedType{NestedMapVal: map[int64]bool{1: false}}, @@ -83,7 +85,9 @@ func TestNativeTypes(t *testing.T) { NestedMapVal: map[int64]bool{42: true}, }, }, - MapVal: map[string]TestAllTypes{"map-key": {BoolVal: true}}, + MapVal: map[string]TestAllTypes{"map-key": {BoolVal: true}}, + CustomSliceVal: []TestNestedSliceType{{Value: "none"}}, + CustomMapVal: map[string]TestMapVal{"even": {Value: "more"}}, }, }, { @@ -645,7 +649,6 @@ func testNativeEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { envOpts = append(envOpts, opts...) envOpts = append(envOpts, NativeTypes( - reflect.TypeOf(&TestNestedType{}), reflect.ValueOf(&TestAllTypes{}), ), ) @@ -687,6 +690,8 @@ type TestAllTypes struct { ListVal []*TestNestedType MapVal map[string]TestAllTypes PbVal *proto3pb.TestAllTypes + CustomSliceVal []TestNestedSliceType + CustomMapVal map[string]TestMapVal // channel types are not supported UnsupportedVal chan string @@ -696,3 +701,11 @@ type TestAllTypes struct { // unexported types can be found but not set or accessed privateVal map[string]string } + +type TestNestedSliceType struct { + Value string +} + +type TestMapVal struct { + Value string +}