Skip to content

Commit

Permalink
Fix list and map equality to preserve the logical ANDing of elements …
Browse files Browse the repository at this point in the history
…and to consider typing before erroring (#449)
  • Loading branch information
TristonianJones authored Sep 18, 2021
1 parent e24e354 commit e7c178e
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 182 deletions.
22 changes: 18 additions & 4 deletions common/types/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,20 @@ func (l *baseList) Equal(other ref.Val) ref.Val {
if l.Size() != otherList.Size() {
return False
}
var maybeErr ref.Val
for i := IntZero; i < l.Size().(Int); i++ {
thisElem := l.Get(i)
otherElem := otherList.Get(i)
elemEq := thisElem.Equal(otherElem)
if elemEq != True {
return elemEq
if elemEq == False {
return False
}
if maybeErr == nil && IsUnknownOrError(elemEq) {
maybeErr = elemEq
}
}
if maybeErr != nil {
return maybeErr
}
return True
}
Expand Down Expand Up @@ -347,13 +354,20 @@ func (l *concatList) Equal(other ref.Val) ref.Val {
if l.Size() != otherList.Size() {
return False
}
var maybeErr ref.Val
for i := IntZero; i < l.Size().(Int); i++ {
thisElem := l.Get(i)
otherElem := otherList.Get(i)
elemEq := thisElem.Equal(otherElem)
if elemEq != True {
return elemEq
if elemEq == False {
return False
}
if maybeErr == nil && IsUnknownOrError(elemEq) {
maybeErr = elemEq
}
}
if maybeErr != nil {
return maybeErr
}
return True
}
Expand Down
52 changes: 35 additions & 17 deletions common/types/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestBaseListContains_NonBool(t *testing.T) {
t.Error("List contains succeeded with wrong type")
}
if !reflect.DeepEqual(list.Contains(Unknown{1}), Unknown{1}) {
t.Error("List ")
t.Error("list.Contains(unknown) did not return unknown input")
}
}

Expand Down Expand Up @@ -132,10 +132,25 @@ func TestBaseListConvertToType(t *testing.T) {
}

func TestBaseListEqual(t *testing.T) {
listA := NewDynamicList(newTestRegistry(t), []string{"h", "e", "l", "l", "o"})
listB := NewDynamicList(newTestRegistry(t), []string{"h", "e", "l", "p", "!"})
reg := newTestRegistry(t)
listA := NewDynamicList(reg, []string{"h", "e", "l", "l", "o"})
if listA.Equal(listA) != True {
t.Error("listA.Equal(listA) did not return true.")
}
listB := NewDynamicList(reg, []string{"h", "e", "l", "p", "!"})
if listA.Equal(listB) != False {
t.Error("Lists with different contents returned equal.")
t.Error("listA.Equal(listB) did not return false.")
}
listC := reg.NativeToValue([]interface{}{"h", "e", "l", "l", String("o")})
if listA.Equal(listC) != True {
t.Error("listA.Equal(listC) did not return true.")
}
listD := reg.NativeToValue([]interface{}{"h", "e", 1, "p", "!"})
if listA.Equal(listD) != False {
t.Error("listA.Equal(listD) did not return true")
}
if !IsError(listB.Equal(listD)) {
t.Error("listA.Equal(listD) did not error on single element type difference")
}
}

Expand Down Expand Up @@ -320,27 +335,30 @@ func TestConcatListContains_NonBool(t *testing.T) {
}
}

func TestConcatListValue_Equal(t *testing.T) {
func TestConcatListEqual(t *testing.T) {
reg := newTestRegistry(t)
listA := NewDynamicList(reg, []float32{1.0, 2.0})
listB := NewDynamicList(reg, []float64{3.0})
list := listA.Add(listB)
// Note the internal type of list raw and concat list are slightly different.
listRaw := NewDynamicList(reg, []interface{}{
float32(1.0), float64(2.0), float64(3.0)})
if listRaw.Equal(list) != True ||
list.Equal(listRaw) != True {
t.Errorf("Concat list and raw list were not equal, got '%v', expected '%v'",
list.Value(),
listRaw.Value())
listRaw := NewDynamicList(reg, []interface{}{float32(1.0), float64(2.0), float64(3.0)})
if listRaw.Equal(list) != True || list.Equal(listRaw) != True {
t.Errorf("listRaw.Equal(list) not true, got '%v', expected '%v'", list.Value(), listRaw.Value())
}
if list.Equal(listA) == True || listRaw.Equal(listA) == True {
t.Error("lists of unequal length considered equal")
}
listC := reg.NativeToValue([]interface{}{1.0, 3.0, 2.0})
if list.Equal(listC) != False {
t.Errorf("list.Equal(listC) got %v, wanted false", list.Equal(listC))
}
if list.Equal(listA) == True ||
listRaw.Equal(listA) == True {
t.Errorf("Lists of unequal length considered equal")
listD := reg.NativeToValue([]interface{}{1, 2.0, 3.0})
if !IsError(list.Equal(listD)) {
t.Errorf("list.Equal(listD) got %v, wanted error", list.Equal(listD))
}
}

func TestConcatListValue_Get(t *testing.T) {
func TestConcatListGet(t *testing.T) {
reg := newTestRegistry(t)
listA := NewDynamicList(reg, []float32{1.0, 2.0})
listB := NewDynamicList(reg, []float64{3.0})
Expand All @@ -358,7 +376,7 @@ func TestConcatListValue_Get(t *testing.T) {
}
}

func TestConcatListValue_Iterator(t *testing.T) {
func TestConcatListIterator(t *testing.T) {
reg := newTestRegistry(t)
listA := NewDynamicList(reg, []float32{1.0, 2.0})
listB := NewDynamicList(reg, []float64{3.0})
Expand Down
77 changes: 50 additions & 27 deletions common/types/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ func (m *baseMap) Equal(other ref.Val) ref.Val {
return False
}
it := m.Iterator()
var maybeErr ref.Val
for it.HasNext() == True {
key := it.Next()
thisVal, _ := m.Find(key)
Expand All @@ -265,12 +266,21 @@ func (m *baseMap) Equal(other ref.Val) ref.Val {
if otherVal == nil {
return False
}
return MaybeNoSuchOverloadErr(otherVal)
if maybeErr == nil {
maybeErr = MaybeNoSuchOverloadErr(otherVal)
}
continue
}
valEq := thisVal.Equal(otherVal)
if valEq != True {
return valEq
if valEq == False {
return False
}
if maybeErr == nil && IsUnknownOrError(valEq) {
maybeErr = valEq
}
}
if maybeErr != nil {
return maybeErr
}
return True
}
Expand Down Expand Up @@ -320,7 +330,7 @@ type jsonStructAccessor struct {
func (a *jsonStructAccessor) Find(key ref.Val) (ref.Val, bool) {
strKey, ok := key.(String)
if !ok {
return MaybeNoSuchOverloadErr(key), false
return ValOrErr(key, "unsupported key type: %v", key.Type()), false
}
keyVal, found := a.st[string(strKey)]
if !found {
Expand Down Expand Up @@ -369,29 +379,32 @@ func (a *reflectMapAccessor) Find(key ref.Val) (ref.Val, bool) {
if IsUnknownOrError(key) {
return MaybeNoSuchOverloadErr(key), false
}
if a.refValue.Len() == 0 {
return nil, false
}
k, err := key.ConvertToNative(a.keyType)
if err != nil {
return &Err{err}, false
}
var refKey reflect.Value
switch k := k.(type) {
case reflect.Value:
refKey = k
default:
refKey = reflect.ValueOf(k)
return NewErr("unsupported key type: %v", key.Type()), false
}
refKey := reflect.ValueOf(k)
val := a.refValue.MapIndex(refKey)
if !val.IsValid() {
return nil, false
if val.IsValid() {
return a.NativeToValue(val.Interface()), true
}
mapIt := a.refValue.MapRange()
for mapIt.Next() {
if refKey.Kind() == mapIt.Key().Kind() {
return nil, false
}
}
return a.NativeToValue(val.Interface()), true
return NewErr("unsupported key type: %v", key.Type()), false
}

// Iterator creates a Golang reflection based traits.Iterator.
func (a *reflectMapAccessor) Iterator() traits.Iterator {
return &mapIterator{
TypeAdapter: a.TypeAdapter,
mapKeys: a.refValue.MapKeys(),
mapKeys: a.refValue.MapRange(),
len: a.refValue.Len(),
}
}
Expand All @@ -412,15 +425,26 @@ func (a *refValMapAccessor) Find(key ref.Val) (ref.Val, bool) {
if IsUnknownOrError(key) {
return key, false
}
if len(a.mapVal) == 0 {
return nil, false
}
keyVal, found := a.mapVal[key]
return keyVal, found
if found {
return keyVal, true
}
for k := range a.mapVal {
if k.Type().TypeName() == key.Type().TypeName() {
return nil, false
}
}
return NewErr("unsupported key type: %v", key.Type()), found
}

// Iterator produces a new traits.Iterator which iterates over the map keys via Golang reflection.
func (a *refValMapAccessor) Iterator() traits.Iterator {
return &mapIterator{
TypeAdapter: DefaultTypeAdapter,
mapKeys: reflect.ValueOf(a.mapVal).MapKeys(),
mapKeys: reflect.ValueOf(a.mapVal).MapRange(),
len: len(a.mapVal),
}
}
Expand All @@ -441,7 +465,7 @@ type stringMapAccessor struct {
func (a *stringMapAccessor) Find(key ref.Val) (ref.Val, bool) {
strKey, ok := key.(String)
if !ok {
return MaybeNoSuchOverloadErr(key), false
return ValOrErr(key, "unsupported key type: %v", key.Type()), false
}
keyVal, found := a.mapVal[string(strKey)]
if !found {
Expand Down Expand Up @@ -485,7 +509,7 @@ type stringIfaceMapAccessor struct {
func (a *stringIfaceMapAccessor) Find(key ref.Val) (ref.Val, bool) {
strKey, ok := key.(String)
if !ok {
return MaybeNoSuchOverloadErr(key), false
return ValOrErr(key, "unsupported key type: %v", key.Type()), false
}
keyVal, found := a.mapVal[string(strKey)]
if !found {
Expand Down Expand Up @@ -616,10 +640,10 @@ func (m *protoMap) ConvertToType(typeVal ref.Type) ref.Val {

// Equal implements the ref.Val interface method.
func (m *protoMap) Equal(other ref.Val) ref.Val {
if MapType != other.Type() {
otherMap, ok := other.(traits.Mapper)
if !ok {
return MaybeNoSuchOverloadErr(other)
}
otherMap := other.(traits.Mapper)
if m.value.Map.Len() != int(otherMap.Size().(Int)) {
return False
}
Expand Down Expand Up @@ -659,7 +683,7 @@ func (m *protoMap) Find(key ref.Val) (ref.Val, bool) {
// Convert the input key to the expected protobuf key type.
ntvKey, err := key.ConvertToNative(m.value.KeyType.ReflectType())
if err != nil {
return &Err{err}, false
return NewErr("unsupported key type: %v", key.Type()), false
}
// Use protoreflection to get the key value.
val := m.value.Get(protoreflect.ValueOf(ntvKey).MapKey())
Expand Down Expand Up @@ -718,7 +742,7 @@ func (m *protoMap) Value() interface{} {
type mapIterator struct {
*baseIterator
ref.TypeAdapter
mapKeys []reflect.Value
mapKeys *reflect.MapIter
cursor int
len int
}
Expand All @@ -730,10 +754,9 @@ func (it *mapIterator) HasNext() ref.Val {

// Next implements the traits.Iterator interface method.
func (it *mapIterator) Next() ref.Val {
if it.HasNext() == True {
index := it.cursor
if it.HasNext() == True && it.mapKeys.Next() {
it.cursor++
refKey := it.mapKeys[index]
refKey := it.mapKeys.Key()
return it.NativeToValue(refKey.Interface())
}
return nil
Expand Down
Loading

0 comments on commit e7c178e

Please sign in to comment.