From e517cf50ea3388089d22394f1ff46f2d51d096ce Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 19 Jul 2023 14:38:52 -0700 Subject: [PATCH] Helpers for accessing unknown state information (#781) --- common/types/unknown.go | 36 +++++++++++++++++++ common/types/unknown_test.go | 67 ++++++++++++++++++++++++++++++++++-- 2 files changed, 100 insertions(+), 3 deletions(-) diff --git a/common/types/unknown.go b/common/types/unknown.go index f281a87f..9dd2b257 100644 --- a/common/types/unknown.go +++ b/common/types/unknown.go @@ -18,6 +18,7 @@ import ( "fmt" "math" "reflect" + "sort" "strings" "unicode" @@ -160,6 +161,26 @@ func NewUnknown(id int64, attr *AttributeTrail) *Unknown { } } +// IDs returns the set of unknown expression ids contained by this value. +// +// Numeric identifiers are guaranteed to be in sorted order. +func (u *Unknown) IDs() []int64 { + ids := make(int64Slice, len(u.attributeTrails)) + i := 0 + for id := range u.attributeTrails { + ids[i] = id + i++ + } + ids.Sort() + return ids +} + +// GetAttributeTrails returns the attribute trails, if present, missing for a given expression id. +func (u *Unknown) GetAttributeTrails(id int64) ([]*AttributeTrail, bool) { + trails, found := u.attributeTrails[id] + return trails, found +} + // Contains returns true if the input unknown is a subset of the current unknown. func (u *Unknown) Contains(other *Unknown) bool { for id, otherTrails := range other.attributeTrails { @@ -288,3 +309,18 @@ func MergeUnknowns(unk1, unk2 *Unknown) *Unknown { } return out } + +// int64Slice is an implementation of the sort.Interface +type int64Slice []int64 + +// Len returns the number of elements in the slice. +func (x int64Slice) Len() int { return len(x) } + +// Less indicates whether the value at index i is less than the value at index j. +func (x int64Slice) Less(i, j int) bool { return x[i] < x[j] } + +// Swap swaps the values at indices i and j in place. +func (x int64Slice) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +// Sort is a convenience method: x.Sort() calls Sort(x). +func (x int64Slice) Sort() { sort.Sort(x) } diff --git a/common/types/unknown_test.go b/common/types/unknown_test.go index d0d83d6a..59370895 100644 --- a/common/types/unknown_test.go +++ b/common/types/unknown_test.go @@ -17,6 +17,7 @@ package types import ( "fmt" "math" + "reflect" "strings" "testing" @@ -228,6 +229,68 @@ func TestUnknownContains(t *testing.T) { } } +func TestUnknownIDs(t *testing.T) { + tests := []struct { + unk *Unknown + ids []int64 + attrs []string + }{ + { + unk: NewUnknown(1, nil), + ids: []int64{1}, + attrs: []string{""}, + }, + { + unk: NewUnknown(2, QualifyAttribute[bool](NewAttributeTrail("a"), true)), + ids: []int64{2}, + attrs: []string{"a[true]"}, + }, + { + unk: NewUnknown(3, QualifyAttribute[string](NewAttributeTrail("a"), "b")), + ids: []int64{3}, + attrs: []string{"a.b"}, + }, + { + unk: NewUnknown(4, QualifyAttribute[string](NewAttributeTrail("a"), "c")), + ids: []int64{4}, + attrs: []string{"a.c"}, + }, + { + unk: MergeUnknowns( + NewUnknown(4, QualifyAttribute[string](NewAttributeTrail("a"), "b")), + NewUnknown(3, QualifyAttribute[bool](NewAttributeTrail("a"), true)), + ), + ids: []int64{3, 4}, + attrs: []string{"a[true]", "a.b"}, + }, + } + for i, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + ids := tc.unk.IDs() + if !reflect.DeepEqual(ids, tc.ids) { + t.Errorf("%v.IDs() got %v, wanted %v", tc.unk, ids, tc.ids) + } + attrs := make([]string, len(ids)) + idx := 0 + for _, id := range ids { + trails, found := tc.unk.GetAttributeTrails(id) + if !found { + t.Fatalf("GetAttributeTrails(%d) not found", id) + } + if len(trails) != 1 { + t.Fatalf("GetAttributeTrails(%d) got %d trails, wanted 1", id, len(trails)) + } + attrs[idx] = trails[0].String() + idx++ + } + if !reflect.DeepEqual(attrs, tc.attrs) { + t.Errorf("%v.GetAttributeTrails() got %v, wanted %v", tc.unk, attrs, tc.attrs) + } + }) + } +} + func TestUnknownString(t *testing.T) { tests := []struct { unk *Unknown @@ -340,7 +403,5 @@ func TestMaybeMergeUnknowns(t *testing.T) { func newUnk(t *testing.T, id int64, varName string) *Unknown { t.Helper() - attr := NewAttributeTrail(varName) - unk := NewUnknown(id, attr) - return unk + return NewUnknown(id, NewAttributeTrail(varName)) }