From 5cd4381596ffcf1b04fff1ded42c0a988e59280e Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Tue, 19 Oct 2021 14:38:05 -0700 Subject: [PATCH] Bug fixes for type-sanitization and abstract types. (#460) --- checker/checker.go | 11 +- checker/checker_test.go | 108 +++++++++++++++++ checker/env.go | 22 ++-- checker/env_test.go | 25 ++++ checker/errors.go | 23 ---- checker/mapping.go | 13 -- checker/types.go | 30 +++-- parser/errors.go | 12 -- parser/unescape_test.go | 255 ++++++++++------------------------------ 9 files changed, 223 insertions(+), 276 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 2394d2dc..99f7c5bd 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -112,7 +112,8 @@ func (c *checker) check(e *exprpb.Expr) { case *exprpb.Expr_ComprehensionExpr: c.checkComprehension(e) default: - panic(fmt.Sprintf("Unrecognized ast type: %v", reflect.TypeOf(e))) + c.errors.ReportError( + c.location(e), "Unrecognized ast type: %v", reflect.TypeOf(e)) } } @@ -572,7 +573,9 @@ func (c *checker) lookupFieldType(l common.Location, messageType string, fieldNa func (c *checker) setType(e *exprpb.Expr, t *exprpb.Type) { if old, found := c.types[e.Id]; found && !proto.Equal(old, t) { - panic(fmt.Sprintf("(Incompatible) Type already exists for expression: %v(%d) old:%v, new:%v", e, e.Id, old, t)) + c.errors.ReportError(c.location(e), + "(Incompatible) Type already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, t) + return } c.types[e.Id] = t } @@ -583,7 +586,9 @@ func (c *checker) getType(e *exprpb.Expr) *exprpb.Type { func (c *checker) setReference(e *exprpb.Expr, r *exprpb.Reference) { if old, found := c.references[e.Id]; found && !proto.Equal(old, r) { - panic(fmt.Sprintf("Reference already exists for expression: %v(%d) old:%v, new:%v", e, e.Id, old, r)) + c.errors.ReportError(c.location(e), + "Reference already exists for expression: %v(%d) old:%v, new:%v", e, e.Id, old, r) + return } c.references[e.Id] = r } diff --git 
a/checker/checker_test.go b/checker/checker_test.go index 9913e98a..03a4a315 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -1564,6 +1564,114 @@ _&&_(_==_(list~type(list(dyn))^list, )~string^base64_encode_string`, Type: decls.String, }, + { + I: `{}`, + R: `{}~map(dyn, dyn)`, + Type: decls.NewMapType(decls.Dyn, decls.Dyn), + }, + { + I: `set([1, 2, 3])`, + R: ` + set( + [ + 1~int, + 2~int, + 3~int + ]~list(int) + )~abstract_type:{name:"set" parameter_types:{primitive:INT64}}^set_list`, + Env: env{ + functions: []*exprpb.Decl{ + decls.NewFunction("set", + decls.NewParameterizedOverload( + "set_list", []*exprpb.Type{ + decls.NewListType(decls.NewTypeParamType("T")), + }, decls.NewAbstractType("set", decls.NewTypeParamType("T")), + []string{"T"})), + }, + }, + Type: decls.NewAbstractType("set", decls.Int), + }, + { + I: `set([1, 2]) == set([2, 1])`, + R: ` + _==_( + set([1~int, 2~int]~list(int))~abstract_type:{name:"set" parameter_types:{primitive:INT64}}^set_list, + set([2~int, 1~int]~list(int))~abstract_type:{name:"set" parameter_types:{primitive:INT64}}^set_list + )~bool^equals`, + Env: env{ + functions: []*exprpb.Decl{ + decls.NewFunction("set", + decls.NewParameterizedOverload( + "set_list", []*exprpb.Type{ + decls.NewListType(decls.NewTypeParamType("T")), + }, decls.NewAbstractType("set", decls.NewTypeParamType("T")), + []string{"T"})), + }, + }, + Type: decls.Bool, + }, + { + I: `set([1, 2]) == x`, + R: ` + _==_( + set([1~int, 2~int]~list(int))~abstract_type:{name:"set" parameter_types:{primitive:INT64}}^set_list, + x~abstract_type:{name:"set" parameter_types:{primitive:INT64}}^x + )~bool^equals`, + Env: env{ + idents: []*exprpb.Decl{ + decls.NewVar("x", decls.NewAbstractType("set", decls.NewTypeParamType("T"))), + }, + functions: []*exprpb.Decl{ + decls.NewFunction("set", + decls.NewParameterizedOverload( + "set_list", []*exprpb.Type{ + decls.NewListType(decls.NewTypeParamType("T")), + }, decls.NewAbstractType("set", 
decls.NewTypeParamType("T")), + []string{"T"})), + }, + }, + Type: decls.Bool, + }, + { + I: `int{}`, + Error: ` + ERROR: :1:4: 'int' is not a message type + | int{} + | ...^ + `, + }, + { + I: `Msg{}`, + Error: ` + ERROR: :1:4: undeclared reference to 'Msg' (in container '') + | Msg{} + | ...^ + `, + }, + { + I: `fun()`, + Error: ` + ERROR: :1:4: undeclared reference to 'fun' (in container '') + | fun() + | ...^ + `, + }, + { + I: `'string'.fun()`, + Error: ` + ERROR: :1:13: undeclared reference to 'fun' (in container '') + | 'string'.fun() + | ............^ + `, + }, + { + I: `[].length`, + Error: ` + ERROR: :1:3: type 'list_type:{elem_type:{type_param:"_var0"}}' does not support field selection + | [].length + | ..^ + `, + }, } var testEnvs = map[string]env{ diff --git a/checker/env.go b/checker/env.go index 93cb1061..e8bce7b0 100644 --- a/checker/env.go +++ b/checker/env.go @@ -231,34 +231,26 @@ func sanitizeFunction(decl *exprpb.Decl) *exprpb.Decl { } // Sanitize all of the overloads if any overload requires an update to its type references. - overloads := make([]*exprpb.Decl_FunctionDecl_Overload, 0, len(fn.GetOverloads())) + overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(fn.GetOverloads())) for i, o := range fn.GetOverloads() { - var sanitized bool rt := o.GetResultType() if isObjectWellKnownType(rt) { rt = getObjectWellKnownType(rt) - sanitized = true } - params := make([]*exprpb.Type, 0, len(o.GetParams())) + params := make([]*exprpb.Type, len(o.GetParams())) copy(params, o.GetParams()) for j, p := range params { if isObjectWellKnownType(p) { params[j] = getObjectWellKnownType(p) - sanitized = true } } // If sanitized, replace the overload definition. 
- if sanitized { - if o.IsInstanceFunction { - overloads[i] = - decls.NewInstanceOverload(o.GetOverloadId(), params, rt) - } else { - overloads[i] = - decls.NewOverload(o.GetOverloadId(), params, rt) - } + if o.IsInstanceFunction { + overloads[i] = + decls.NewInstanceOverload(o.GetOverloadId(), params, rt) } else { - // Otherwise, preserve the original overload. - overloads[i] = o + overloads[i] = + decls.NewOverload(o.GetOverloadId(), params, rt) } } return decls.NewFunction(decl.GetName(), overloads...) diff --git a/checker/env_test.go b/checker/env_test.go index 930cd9bb..011b3d72 100644 --- a/checker/env_test.go +++ b/checker/env_test.go @@ -20,6 +20,7 @@ import ( "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/containers" + "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" @@ -64,6 +65,30 @@ func TestOverlappingOverload(t *testing.T) { } } +func TestSanitizedOverload(t *testing.T) { + env := NewStandardEnv(containers.DefaultContainer, newTestRegistry(t)) + err := env.Add(decls.NewFunction(operators.Add, + decls.NewOverload("timestamp_add_int", + []*exprpb.Type{decls.NewObjectType("google.protobuf.Timestamp"), decls.Int}, + decls.NewObjectType("google.protobuf.Timestamp")))) + if err != nil { + t.Errorf("env.Add('timestamp_add_int') failed: %v", err) + } +} + +func TestSanitizedInstanceOverload(t *testing.T) { + env := NewStandardEnv(containers.DefaultContainer, newTestRegistry(t)) + err := env.Add(decls.NewFunction("orDefault", + decls.NewInstanceOverload("int_ordefault_int", + []*exprpb.Type{ + decls.NewObjectType("google.protobuf.Int32Value"), + decls.NewObjectType("google.protobuf.Int32Value")}, + decls.Int))) + if err != nil { + t.Errorf("env.Add('int_ordefault_int') failed: %v", err) + } +} + func newTestRegistry(t *testing.T) ref.TypeRegistry { t.Helper() reg, err := types.NewRegistry() diff --git 
a/checker/errors.go b/checker/errors.go index 33af57cd..06566aec 100644 --- a/checker/errors.go +++ b/checker/errors.go @@ -41,34 +41,11 @@ func (e *typeErrors) undefinedField(l common.Location, field string) { e.ReportError(l, "undefined field '%s'", field) } -func (e *typeErrors) fieldDoesNotSupportPresenceCheck(l common.Location, field string) { - e.ReportError(l, "field '%s' does not support presence check", field) -} - -func (e *typeErrors) overlappingOverload(l common.Location, name string, overloadID1 string, f1 *exprpb.Type, - overloadID2 string, f2 *exprpb.Type) { - e.ReportError(l, "overlapping overload for name '%s' (type '%s' with overloadId: '%s' cannot be distinguished from '%s' with "+ - "overloadId: '%s')", name, FormatCheckedType(f1), overloadID1, FormatCheckedType(f2), overloadID2) -} - -func (e *typeErrors) overlappingMacro(l common.Location, name string, args int) { - e.ReportError(l, "overload for name '%s' with %d argument(s) overlaps with predefined macro", - name, args) -} - func (e *typeErrors) noMatchingOverload(l common.Location, name string, args []*exprpb.Type, isInstance bool) { signature := formatFunction(nil, args, isInstance) e.ReportError(l, "found no matching overload for '%s' applied to '%s'", name, signature) } -func (e *typeErrors) aggregateTypeMismatch(l common.Location, aggregate *exprpb.Type, member *exprpb.Type) { - e.ReportError( - l, - "type '%s' does not match previous type '%s' in aggregate. 
Use 'dyn(x)' to make the aggregate dynamic.", - FormatCheckedType(member), - FormatCheckedType(aggregate)) -} - func (e *typeErrors) notAType(l common.Location, t *exprpb.Type) { e.ReportError(l, "'%s(%v)' is not a type", FormatCheckedType(t), t) } diff --git a/checker/mapping.go b/checker/mapping.go index bd5e412d..fbc55a28 100644 --- a/checker/mapping.go +++ b/checker/mapping.go @@ -15,8 +15,6 @@ package checker import ( - "fmt" - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) @@ -49,14 +47,3 @@ func (m *mapping) copy() *mapping { } return c } - -func (m *mapping) String() string { - result := "{" - - for k, v := range m.mapping { - result += fmt.Sprintf("%v => %v ", k, v) - } - - result += "}" - return result -} diff --git a/checker/types.go b/checker/types.go index 05d30aa1..38c11d41 100644 --- a/checker/types.go +++ b/checker/types.go @@ -34,7 +34,7 @@ const ( kindWellKnown kindWrapper kindNull - kindAbstract // TODO: Update the checker protos to include abstract + kindAbstract kindType kindList kindMap @@ -190,11 +190,11 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool { if kind2 == kindTypeParam { if t2Sub, found := m.find(t2); found { // If the types are compatible, pick the more general type and return true - if internalIsAssignable(m, t1, t2Sub) { - m.add(t2, mostGeneral(t1, t2Sub)) - return true + if !internalIsAssignable(m, t1, t2Sub) { + return false } - return false + m.add(t2, mostGeneral(t1, t2Sub)) + return true } if notReferencedIn(m, t2, t1) { m.add(t2, t1) @@ -207,11 +207,11 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool { // become if we generalize type unification. 
if t1Sub, found := m.find(t1); found { // If the types are compatible, pick the more general type and return true - if internalIsAssignable(m, t1Sub, t2) { - m.add(t1, mostGeneral(t1Sub, t2)) - return true + if !internalIsAssignable(m, t1Sub, t2) { + return false } - return false + m.add(t1, mostGeneral(t1Sub, t2)) + return true } if notReferencedIn(m, t1, t2) { m.add(t1, t2) @@ -242,15 +242,11 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool { switch kind1 { // ERROR, TYPE_PARAM, and DYN handled above. case kindAbstract: - return internalIsAssignableAbstractType(m, - t1.GetAbstractType(), t2.GetAbstractType()) + return internalIsAssignableAbstractType(m, t1.GetAbstractType(), t2.GetAbstractType()) case kindFunction: - return internalIsAssignableFunction(m, - t1.GetFunction(), t2.GetFunction()) + return internalIsAssignableFunction(m, t1.GetFunction(), t2.GetFunction()) case kindList: - return internalIsAssignable(m, - t1.GetListType().GetElemType(), - t2.GetListType().GetElemType()) + return internalIsAssignable(m, t1.GetListType().GetElemType(), t2.GetListType().GetElemType()) case kindMap: return internalIsAssignableMap(m, t1.GetMapType(), t2.GetMapType()) case kindObject: @@ -389,6 +385,8 @@ func kindOf(t *exprpb.Type) int { return kindObject case *exprpb.Type_TypeParam: return kindTypeParam + case *exprpb.Type_AbstractType_: + return kindAbstract } return kindUnknown } diff --git a/parser/errors.go b/parser/errors.go index 140beb95..ce49bb87 100644 --- a/parser/errors.go +++ b/parser/errors.go @@ -28,15 +28,3 @@ type parseErrors struct { func (e *parseErrors) syntaxError(l common.Location, message string) { e.ReportError(l, fmt.Sprintf("Syntax error: %s", message)) } - -func (e *parseErrors) invalidHasArgument(l common.Location) { - e.ReportError(l, "Argument to the function 'has' must be a field selection") -} - -func (e *parseErrors) argumentIsNotIdent(l common.Location) { - e.ReportError(l, "Argument must be a simple name") -} 
- -func (e *parseErrors) notAQualifiedName(l common.Location) { - e.ReportError(l, "Expected a qualified name") -} diff --git a/parser/unescape_test.go b/parser/unescape_test.go index 9c0384ea..24fd6e3a 100644 --- a/parser/unescape_test.go +++ b/parser/unescape_test.go @@ -15,202 +15,69 @@ package parser import ( + "errors" + "strings" "testing" ) -func TestUnescapeSingleQuote(t *testing.T) { - text, err := unescape(`'hello'`, false) - if err != nil { - t.Fatal(err) - } - if text != "hello" { - t.Errorf("Got '%v', wanted '%v", text, "hello") - } -} - -func TestUnescapeDoubleQuote(t *testing.T) { - text, err := unescape(`""`, false) - if err != nil { - t.Fatal(err) - } - if text != `` { - t.Errorf("Got '%v', wanted '%v'", text, ``) - } -} - -func TestUnescapeEscapedQuote(t *testing.T) { - // The argument to unescape is dquote-backslash-dquote-dquote where both - // the backslash and inner double-quote are escaped. - text, err := unescape(`"\\\""`, false) - if err != nil { - t.Fatal(err) - } - if text != `\"` { - t.Errorf("Got '%v', wanted '%v'", text, `\"`) - } -} - -func TestUnescapeEscapedEscape(t *testing.T) { - text, err := unescape(`"\\"`, false) - if err != nil { - t.Fatal(err) - } - if text != `\` { - t.Errorf("Got '%v', wanted '%v'", text, `\`) - } -} - -func TestUnescapeTripleSingleQuote(t *testing.T) { - text, err := unescape(`'''x''x'''`, false) - if err != nil { - t.Fatal(err) - } - if text != `x''x` { - t.Errorf("Got '%v', wanted '%v'", text, `x''x`) - } -} - -func TestUnescapeTripleDoubleQuote(t *testing.T) { - text, err := unescape(`"""x""x"""`, false) - if err != nil { - t.Fatal(err) - } - if text != `x""x` { - t.Errorf("Got '%v', wanted '%v'", text, `x""x`) - } -} - -func TestUnescapeMultiOctalSequence(t *testing.T) { - // Octal 303 -> Code point 195 (Ã) - // Octal 277 -> Code point 191 (¿) - text, err := unescape(`"\303\277"`, false) - if err != nil { - t.Fatal(err) - } - if text != `ÿ` { - t.Errorf("Got '%v', wanted '%v'", text, `ÿ`) - } -} - 
-func TestUnescapeOctalSequence(t *testing.T) { - // Octal 377 -> Code point 255 (ÿ) - text, err := unescape(`"\377"`, false) - if err != nil { - t.Fatal(err) - } - if text != `ÿ` { - t.Errorf("Got '%v', wanted '%v'", text, `ÿ`) - } -} - -func TestUnescapeUnicodeSequence(t *testing.T) { - text, err := unescape(`"\u263A\u263A"`, false) - if err != nil { - t.Fatal(err) - } - if text != `☺☺` { - t.Errorf("Got '%v', wanted '%v'", text, `☺☺`) - } -} - -func TestUnescapeLegalEscapes(t *testing.T) { - text, err := unescape(`"\a\b\f\n\r\t\v\'\"\\\? Legal escapes"`, false) - if err != nil { - t.Fatal(err) - } - if text != "\a\b\f\n\r\t\v'\"\\? Legal escapes" { - t.Errorf("Got '%v', wanted '%v'", text, "\a\b\f\n\r\t\v'\"\\? Legal escapes") - } -} - -func TestUnescapeIllegalEscapes(t *testing.T) { - // The first escape sequences are legal, but the '\>' is not. - text, err := unescape(`"\a\b\f\n\r\t\v\'\"\\\? Illegal escape \>"`, false) - if err == nil { - t.Errorf("Got '%v', expected error", text) - } -} - -func TestUnescapeBytesAscii(t *testing.T) { - bs, err := unescape(`"abc"`, true) - if err != nil { - t.Fatal(err) - } - want := "\x61\x62\x63" - if bs != want { - t.Errorf("Got '%v', wanted '%v'", bs, want) - } -} - -func TestUnescapeBytesUnicode(t *testing.T) { - bs, err := unescape(`"ÿ"`, true) - if err != nil { - t.Fatal(err) - } - want := "\xc3\xbf" - if bs != want { - t.Errorf("Got '%v', wanted '%v'", bs, want) - } -} - -func TestUnescapeBytesOctal(t *testing.T) { - bs, err := unescape(`"\303\277"`, true) - if err != nil { - t.Fatal(err) - } - want := "\xc3\xbf" - if bs != want { - t.Errorf("Got '%v', wanted '%v'", bs, want) - } -} - -func TestUnescapeBytesOctalMax(t *testing.T) { - bs, err := unescape(`"\377"`, true) - if err != nil { - t.Fatal(err) - } - want := "\xff" - if bs != want { - t.Errorf("Got '%v', wanted '%v'", bs, want) - } -} - -func TestUnescapeBytesQuoting(t *testing.T) { - bs, err := unescape(`'''"Kim\t"'''`, true) - if err != nil { - t.Fatal(err) - 
} - want := "\x22\x4b\x69\x6d\x09\x22" - if bs != want { - t.Errorf("Got '%v', wanted '%v'", bs, want) - } -} - -func TestUnescapeBytesHex(t *testing.T) { - bs, err := unescape(`"\xc3\xbf"`, true) - if err != nil { - t.Fatal(err) - } - want := "\xc3\xbf" - if bs != want { - t.Errorf("Got '%v', wanted '%v'", bs, want) - } -} - -func TestUnescapeBytesHexMax(t *testing.T) { - bs, err := unescape(`"\xff"`, true) - if err != nil { - t.Fatal(err) - } - want := "\xff" - if bs != want { - t.Errorf("Got '%v', wanted '%v'", bs, want) - } -} - -func TestUnescapeBytesUnicodeEscape(t *testing.T) { - bs, err := unescape(`"\u00ff"`, true) - if err == nil { - t.Errorf("Got '%v', expected error", bs) +func TestUnescape(t *testing.T) { + tests := []struct { + in string + out interface{} + isBytes bool + }{ + // Simple string unescaping tests. + {in: `'hello'`, out: `hello`}, + {in: `r'hello'`, out: `hello`}, + {in: `""`, out: ``}, + {in: `"\\\""`, out: `\"`}, + {in: `"\\"`, out: `\`}, + {in: `'''x''x'''`, out: `x''x`}, + {in: `"""x""x"""`, out: `x""x`}, + {in: `"\303\277"`, out: `ÿ`}, + {in: `"\377"`, out: `ÿ`}, + {in: `"\u263A\u263A"`, out: `☺☺`}, + {in: `"\a\b\f\n\r\t\v\'\"\\\? Legal escapes"`, out: "\a\b\f\n\r\t\v'\"\\? Legal escapes"}, + // Byte unescaping tests. + {in: `"abc"`, out: "\x61\x62\x63", isBytes: true}, + {in: `"ÿ"`, out: "\xc3\xbf", isBytes: true}, + {in: `"\303\277"`, out: "\xc3\xbf", isBytes: true}, + {in: `"\377"`, out: "\xff", isBytes: true}, + {in: `"\xff"`, out: "\xff", isBytes: true}, + {in: `"\xc3\xbf"`, out: "\xc3\xbf", isBytes: true}, + {in: `'''"Kim\t"'''`, out: "\x22\x4b\x69\x6d\x09\x22", isBytes: true}, + // Escaping errors. + {in: `"\a\b\f\n\r\t\v\'\"\\\? 
Illegal escape \>"`, out: errors.New("unable to unescape string")}, + {in: `"\u00f"`, out: errors.New("unable to unescape string")}, + {in: `"\u00fÿ"`, out: errors.New("unable to unescape string")}, + {in: `"\u00ff"`, out: errors.New("unable to unescape string"), isBytes: true}, + {in: `"\U00ff"`, out: errors.New("unable to unescape string"), isBytes: true}, + {in: `"\26"`, out: errors.New("unable to unescape octal sequence")}, + {in: `"\268"`, out: errors.New("unable to unescape octal sequence")}, + {in: `"\267\"`, out: errors.New(`found '\' as last character`)}, + {in: `'`, out: errors.New("unable to unescape string")}, + {in: `*hello*`, out: errors.New("unable to unescape string")}, + {in: `r'''hello'`, out: errors.New("unable to unescape string")}, + {in: `r"""hello"`, out: errors.New("unable to unescape string")}, + {in: `r"""hello"`, out: errors.New("unable to unescape string")}, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.in, func(t *testing.T) { + got, err := unescape(tc.in, tc.isBytes) + if err != nil { + expect, isErr := tc.out.(error) + if isErr { + if !strings.Contains(err.Error(), expect.Error()) { + t.Errorf("unescape(%s, %v) errored with %v, wanted %v", tc.in, tc.isBytes, err, expect) + } + } else { + t.Fatalf("unescape(%s, %v) failed: %v", tc.in, tc.isBytes, err) + } + } else if got != tc.out { + t.Errorf("unescape(%s, %v) got %v, wanted %v", tc.in, tc.isBytes, got, tc.out) + } + }) } }