From 4b8b15b877c9ad67ae7b150477611112aaae03c2 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 30 Sep 2024 11:49:36 -0700 Subject: [PATCH] Interop foldable maps and lists with map mutation helper (#1029) --- common/types/list.go | 27 ++++++++ common/types/list_test.go | 67 +++++++++++++++++--- common/types/map.go | 58 ++++++++++++++++- common/types/map_test.go | 113 +++++++++++++++++++++++++++++++--- common/types/traits/lister.go | 3 + common/types/traits/mapper.go | 9 ++- 6 files changed, 253 insertions(+), 24 deletions(-) diff --git a/common/types/list.go b/common/types/list.go index 3e71e33b..ca47d39f 100644 --- a/common/types/list.go +++ b/common/types/list.go @@ -545,3 +545,30 @@ func IndexOrError(index ref.Val) (int, error) { return -1, fmt.Errorf("unsupported index type '%s' in list", index.Type()) } } + +// ToFoldableList will create a Foldable version of a list suitable for key-value pair iteration. +// +// For values which are already Foldable, this call is a no-op. For all other values, the fold is +// driven via the Size() and Get() calls which means that the folding will function, but take a +// performance hit. +func ToFoldableList(l traits.Lister) traits.Foldable { + if f, ok := l.(traits.Foldable); ok { + return f + } + return interopFoldableList{Lister: l} +} + +type interopFoldableList struct { + traits.Lister +} + +// Fold implements the traits.Foldable interface method and performs an iteration over the +// range of elements of the list. +func (l interopFoldableList) Fold(f traits.Folder) { + sz := l.Size().(Int) + for i := Int(0); i < sz; i++ { + if !f.FoldEntry(i, l.Get(i)) { + break + } + } +} diff --git a/common/types/list_test.go b/common/types/list_test.go index ea92f23e..ba6c498f 100644 --- a/common/types/list_test.go +++ b/common/types/list_test.go @@ -787,14 +787,20 @@ func TestListFold(t *testing.T) { reg := NewEmptyRegistry() for i, tst := range tests { tc := tst - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - f := &testListFolder{foldLimit: tc.foldLimit} - l := reg.NativeToValue(tc.l).(traits.Foldable) - l.Fold(f) - if f.folds != tc.folds { - t.Errorf("m.Fold(f) got %d, wanted %d folds", f.folds, tc.folds) - } - }) + l := reg.NativeToValue(tc.l).(traits.Lister) + foldKinds := map[string]traits.Foldable{ + "modern": ToFoldableList(l), + "legacy": ToFoldableList(proxyLegacyList{proxy: l}), + } + for foldKind, foldable := range foldKinds { + t.Run(fmt.Sprintf("[%d]%s", i, foldKind), func(t *testing.T) { + f := &testListFolder{foldLimit: tc.foldLimit} + foldable.Fold(f) + if f.folds != tc.folds { + t.Errorf("m.Fold(f) got %d, wanted %d folds", f.folds, tc.folds) + } + }) + } } } @@ -813,6 +819,51 @@ func (f *testListFolder) FoldEntry(k, v any) bool { return true } +// proxyLegacyList omits the foldable interfaces associated with all core Lister implementations +type proxyLegacyList struct { + proxy traits.Lister +} + +func (m proxyLegacyList) ConvertToNative(typeDesc reflect.Type) (any, error) { + return m.proxy.ConvertToNative(typeDesc) +} + +func (m proxyLegacyList) ConvertToType(typeValue ref.Type) ref.Val { + return m.proxy.ConvertToType(typeValue) +} + +func (m proxyLegacyList) Equal(other ref.Val) ref.Val { + return m.proxy.Equal(other) +} + +func (m proxyLegacyList) Type() ref.Type { + return m.proxy.Type() +} + +func (m proxyLegacyList) Value() any { + return m.proxy.Value() +} + +func (m proxyLegacyList) Add(other ref.Val) ref.Val { + return m.proxy.Add(other) +} + +func (m proxyLegacyList) Contains(value ref.Val) ref.Val { + return m.proxy.Contains(value) +} + +func (m proxyLegacyList) Get(index ref.Val) ref.Val { + return m.proxy.Get(index) +} + +func (m proxyLegacyList) Iterator() traits.Iterator { + return m.proxy.Iterator() +} + +func (m proxyLegacyList) Size() ref.Val { + return m.proxy.Size() +} + func getElem(t *testing.T, list traits.Indexer, index ref.Val) any { t.Helper() val := list.Get(index) diff --git a/common/types/map.go b/common/types/map.go index bc20239f..89b33f90 100644 --- a/common/types/map.go +++ b/common/types/map.go @@ -336,12 +336,12 @@ type mutableMap struct { // Insert implements the traits.MutableMapper interface method, returning true if the key insertion // succeeds. -func (m *mutableMap) Insert(k, v ref.Val) bool { +func (m *mutableMap) Insert(k, v ref.Val) ref.Val { if _, found := m.mutableValues[k]; found { - return false + return NewErr("insert failed: key %v already exists", k) } m.mutableValues[k] = v - return true + return m } // ToImmutableMap implements the traits.MutableMapper interface method, converting a mutable map @@ -948,3 +948,55 @@ func (it *stringKeyIterator) Next() ref.Val { } return nil } + +// ToFoldableMap will create a Foldable version of a map suitable for key-value pair iteration. +// +// For values which are already Foldable, this call is a no-op. For all other values, the fold +// is driven via the Iterator HasNext() and Next() calls as well as the map's Get() method +// which means that the folding will function, but take a performance hit. +func ToFoldableMap(m traits.Mapper) traits.Foldable { + if f, ok := m.(traits.Foldable); ok { + return f + } + return interopFoldableMap{Mapper: m} +} + +type interopFoldableMap struct { + traits.Mapper +} + +func (m interopFoldableMap) Fold(f traits.Folder) { + it := m.Iterator() + for it.HasNext() == True { + k := it.Next() + if !f.FoldEntry(k, m.Get(k)) { + break + } + } +} + +// InsertMapKeyValue inserts a key, value pair into the target map if the target map does not +// already contain the given key. +// +// If the map is mutable, it is modified in-place per the MutableMapper contract. +// If the map is not mutable, a copy containing the new key, value pair is made. +func InsertMapKeyValue(m traits.Mapper, k, v ref.Val) ref.Val { + if mutable, ok := m.(traits.MutableMapper); ok { + return mutable.Insert(k, v) + } + + // Otherwise perform the slow version of the insertion which makes a copy of the incoming map. + if _, found := m.Find(k); !found { + size := m.Size().(Int) + copy := make(map[ref.Val]ref.Val, size+1) + copy[k] = v + it := m.Iterator() + for it.HasNext() == True { + nextK := it.Next() + nextV := m.Get(nextK) + copy[nextK] = nextV + } + return DefaultTypeAdapter.NativeToValue(copy) + } + return NewErr("insert failed: key %v already exists", k) +} diff --git a/common/types/map_test.go b/common/types/map_test.go index c96939b5..b16422c7 100644 --- a/common/types/map_test.go +++ b/common/types/map_test.go @@ -976,8 +976,8 @@ func TestMutableMap(t *testing.T) { if im.Size() != Int(2) { t.Errorf("m.ToImmutableMap() had size %d, wanted 2", im.Size()) } - if m.Insert(String("goodbye"), String("happy world")) { - t.Error("m.Insert('goodbye', 'happy world') got true, wanted false") + if !IsError(m.Insert(String("goodbye"), String("happy world"))) { + t.Error("m.Insert('goodbye', 'happy world') suceeded, wanted error") } m.Insert(String("well"), String("well")) if im.Size() != Int(2) { @@ -1090,14 +1090,62 @@ func TestMapFold(t *testing.T) { reg := NewEmptyRegistry() for i, tst := range tests { tc := tst - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - f := &testMapFolder{foldLimit: tc.foldLimit} - m := reg.NativeToValue(tc.m).(traits.Foldable) - m.Fold(f) - if f.folds != tc.folds { - t.Errorf("m.Fold(f) got %d, wanted %d folds", f.folds, tc.folds) - } - }) + m := reg.NativeToValue(tc.m).(traits.Mapper) + foldKinds := map[string]traits.Foldable{ + "modern": ToFoldableMap(m), + "legacy": ToFoldableMap(proxyLegacyMap{proxy: m}), + } + for foldKind, foldable := range foldKinds { + t.Run(fmt.Sprintf("[%d]%s", i, foldKind), func(t *testing.T) { + f := &testMapFolder{foldLimit: tc.foldLimit} + foldable.Fold(f) + if f.folds != tc.folds { + t.Errorf("m.Fold(f) got %d, wanted %d folds", f.folds, tc.folds) + } + }) + } + } +} + +func TestInsertMapKeyValue_MutableMapper(t *testing.T) { + m := NewMutableMap(DefaultTypeAdapter, map[ref.Val]ref.Val{String("first"): Int(1)}) + modified := InsertMapKeyValue(m, String("second"), Int(2)) + if IsError(modified) { + t.Fatalf("InsertMapKeyValue() got error: %v, wanted insertion", modified) + } + if modified != m { + t.Fatalf("InsertMapKeyValue() created a new map for a mutable input: %v", modified) + } + im := m.ToImmutableMap() + if _, found := im.Find(String("first")); !found { + t.Errorf("InsertMapKeyValue() did not preserve entry 'first': %v", im) + } + if _, found := im.Find(String("second")); !found { + t.Errorf("InsertMapKeyValue() did not insert entry 'second': %v", im) + } + if !IsError(InsertMapKeyValue(m, String("second"), Int(3))) { + t.Errorf("InsertMapKeyValue('second', 3) modified the map instead of erroring: %v", m) + } +} + +func TestInsertMapKeyValue_Mapper(t *testing.T) { + m := NewRefValMap(DefaultTypeAdapter, map[ref.Val]ref.Val{String("first"): Int(1)}) + modified := InsertMapKeyValue(m, String("second"), Int(2)) + if IsError(modified) { + t.Fatalf("InsertMapKeyValue() got error: %v, wanted insertion", modified) + } + if modified == m { + t.Fatalf("InsertMapKeyValue() modified an immutable input: %v", modified) + } + im := modified.(traits.Mapper) + if _, found := im.Find(String("first")); !found { + t.Errorf("InsertMapKeyValue() did not preserve entry 'first': %v", im) + } + if _, found := im.Find(String("second")); !found { + t.Errorf("InsertMapKeyValue() did not insert entry 'second': %v", im) + } + if !IsError(InsertMapKeyValue(im, String("second"), Int(3))) { + t.Errorf("InsertMapKeyValue('second', 3) modified the map instead of erroring: %v", m) } } @@ -1124,3 +1172,48 @@ func testCreateStruct(t *testing.T, m map[string]any) *structpb.Struct { } return v } + +// proxyLegacyMap omits the foldable interfaces associated with all core Mapper implementations +type proxyLegacyMap struct { + proxy traits.Mapper +} + +func (m proxyLegacyMap) ConvertToNative(typeDesc reflect.Type) (any, error) { + return m.proxy.ConvertToNative(typeDesc) +} + +func (m proxyLegacyMap) ConvertToType(typeValue ref.Type) ref.Val { + return m.proxy.ConvertToType(typeValue) +} + +func (m proxyLegacyMap) Equal(other ref.Val) ref.Val { + return m.proxy.Equal(other) +} + +func (m proxyLegacyMap) Type() ref.Type { + return m.proxy.Type() +} + +func (m proxyLegacyMap) Value() any { + return m.proxy.Value() +} + +func (m proxyLegacyMap) Contains(value ref.Val) ref.Val { + return m.proxy.Contains(value) +} + +func (m proxyLegacyMap) Find(key ref.Val) (ref.Val, bool) { + return m.proxy.Find(key) +} + +func (m proxyLegacyMap) Get(index ref.Val) ref.Val { + return m.proxy.Get(index) +} + +func (m proxyLegacyMap) Iterator() traits.Iterator { + return m.proxy.Iterator() +} + +func (m proxyLegacyMap) Size() ref.Val { + return m.proxy.Size() +} diff --git a/common/types/traits/lister.go b/common/types/traits/lister.go index 5cf2593f..e54781a6 100644 --- a/common/types/traits/lister.go +++ b/common/types/traits/lister.go @@ -27,6 +27,9 @@ type Lister interface { } // MutableLister interface which emits an immutable result after an intermediate computation. +// +// Note, this interface is intended only to be used within Comprehensions where the mutable +// value is not directly observable within the user-authored CEL expression. type MutableLister interface { Lister ToImmutableList() Lister diff --git a/common/types/traits/mapper.go b/common/types/traits/mapper.go index 5f1a66b9..d13333f3 100644 --- a/common/types/traits/mapper.go +++ b/common/types/traits/mapper.go @@ -33,12 +33,15 @@ type Mapper interface { } // MutableMapper interface which emits an immutable result after an intermediate computation. +// +// Note, this interface is intended only to be used within Comprehensions where the mutable +// value is not directly observable within the user-authored CEL expression. type MutableMapper interface { Mapper - // Insert a key, value pair into the map, returning true if key does not already exist in the map - // to indicate the insert is successful. - Insert(k, v ref.Val) bool + // Insert a key, value pair into the map, returning the map if the insert is successful + // and an error if key already exists in the mutable map. + Insert(k, v ref.Val) ref.Val // ToImmutableMap converts a mutable map into an immutable map. ToImmutableMap() Mapper