Skip to content

Commit

Permalink
Bug fixes for type-santization and abstract types. (#460)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Oct 19, 2021
1 parent 7f2b87a commit 5cd4381
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 276 deletions.
11 changes: 8 additions & 3 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ func (c *checker) check(e *exprpb.Expr) {
case *exprpb.Expr_ComprehensionExpr:
c.checkComprehension(e)
default:
panic(fmt.Sprintf("Unrecognized ast type: %v", reflect.TypeOf(e)))
c.errors.ReportError(
c.location(e), "Unrecognized ast type: %v", reflect.TypeOf(e))
}
}

Expand Down Expand Up @@ -572,7 +573,9 @@ func (c *checker) lookupFieldType(l common.Location, messageType string, fieldNa

func (c *checker) setType(e *exprpb.Expr, t *exprpb.Type) {
if old, found := c.types[e.Id]; found && !proto.Equal(old, t) {
panic(fmt.Sprintf("(Incompatible) Type already exists for expression: %v(%d) old:%v, new:%v", e, e.Id, old, t))
c.errors.ReportError(c.location(e),
"(Incompatible) Type already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, t)
return
}
c.types[e.Id] = t
}
Expand All @@ -583,7 +586,9 @@ func (c *checker) getType(e *exprpb.Expr) *exprpb.Type {

func (c *checker) setReference(e *exprpb.Expr, r *exprpb.Reference) {
if old, found := c.references[e.Id]; found && !proto.Equal(old, r) {
panic(fmt.Sprintf("Reference already exists for expression: %v(%d) old:%v, new:%v", e, e.Id, old, r))
c.errors.ReportError(c.location(e),
"Reference already exists for expression: %v(%d) old:%v, new:%v", e, e.Id, old, r)
return
}
c.references[e.Id] = r
}
Expand Down
108 changes: 108 additions & 0 deletions checker/checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1564,6 +1564,114 @@ _&&_(_==_(list~type(list(dyn))^list,
)~string^base64_encode_string`,
Type: decls.String,
},
{
I: `{}`,
R: `{}~map(dyn, dyn)`,
Type: decls.NewMapType(decls.Dyn, decls.Dyn),
},
{
I: `set([1, 2, 3])`,
R: `
set(
[
1~int,
2~int,
3~int
]~list(int)
)~abstract_type:{name:"set" parameter_types:{primitive:INT64}}^set_list`,
Env: env{
functions: []*exprpb.Decl{
decls.NewFunction("set",
decls.NewParameterizedOverload(
"set_list", []*exprpb.Type{
decls.NewListType(decls.NewTypeParamType("T")),
}, decls.NewAbstractType("set", decls.NewTypeParamType("T")),
[]string{"T"})),
},
},
Type: decls.NewAbstractType("set", decls.Int),
},
{
I: `set([1, 2]) == set([2, 1])`,
R: `
_==_(
set([1~int, 2~int]~list(int))~abstract_type:{name:"set" parameter_types:{primitive:INT64}}^set_list,
set([2~int, 1~int]~list(int))~abstract_type:{name:"set" parameter_types:{primitive:INT64}}^set_list
)~bool^equals`,
Env: env{
functions: []*exprpb.Decl{
decls.NewFunction("set",
decls.NewParameterizedOverload(
"set_list", []*exprpb.Type{
decls.NewListType(decls.NewTypeParamType("T")),
}, decls.NewAbstractType("set", decls.NewTypeParamType("T")),
[]string{"T"})),
},
},
Type: decls.Bool,
},
{
I: `set([1, 2]) == x`,
R: `
_==_(
set([1~int, 2~int]~list(int))~abstract_type:{name:"set" parameter_types:{primitive:INT64}}^set_list,
x~abstract_type:{name:"set" parameter_types:{primitive:INT64}}^x
)~bool^equals`,
Env: env{
idents: []*exprpb.Decl{
decls.NewVar("x", decls.NewAbstractType("set", decls.NewTypeParamType("T"))),
},
functions: []*exprpb.Decl{
decls.NewFunction("set",
decls.NewParameterizedOverload(
"set_list", []*exprpb.Type{
decls.NewListType(decls.NewTypeParamType("T")),
}, decls.NewAbstractType("set", decls.NewTypeParamType("T")),
[]string{"T"})),
},
},
Type: decls.Bool,
},
{
I: `int{}`,
Error: `
ERROR: <input>:1:4: 'int' is not a message type
| int{}
| ...^
`,
},
{
I: `Msg{}`,
Error: `
ERROR: <input>:1:4: undeclared reference to 'Msg' (in container '')
| Msg{}
| ...^
`,
},
{
I: `fun()`,
Error: `
ERROR: <input>:1:4: undeclared reference to 'fun' (in container '')
| fun()
| ...^
`,
},
{
I: `'string'.fun()`,
Error: `
ERROR: <input>:1:13: undeclared reference to 'fun' (in container '')
| 'string'.fun()
| ............^
`,
},
{
I: `[].length`,
Error: `
ERROR: <input>:1:3: type 'list_type:{elem_type:{type_param:"_var0"}}' does not support field selection
| [].length
| ..^
`,
},
}

var testEnvs = map[string]env{
Expand Down
22 changes: 7 additions & 15 deletions checker/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,34 +231,26 @@ func sanitizeFunction(decl *exprpb.Decl) *exprpb.Decl {
}

// Sanitize all of the overloads if any overload requires an update to its type references.
overloads := make([]*exprpb.Decl_FunctionDecl_Overload, 0, len(fn.GetOverloads()))
overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(fn.GetOverloads()))
for i, o := range fn.GetOverloads() {
var sanitized bool
rt := o.GetResultType()
if isObjectWellKnownType(rt) {
rt = getObjectWellKnownType(rt)
sanitized = true
}
params := make([]*exprpb.Type, 0, len(o.GetParams()))
params := make([]*exprpb.Type, len(o.GetParams()))
copy(params, o.GetParams())
for j, p := range params {
if isObjectWellKnownType(p) {
params[j] = getObjectWellKnownType(p)
sanitized = true
}
}
// If sanitized, replace the overload definition.
if sanitized {
if o.IsInstanceFunction {
overloads[i] =
decls.NewInstanceOverload(o.GetOverloadId(), params, rt)
} else {
overloads[i] =
decls.NewOverload(o.GetOverloadId(), params, rt)
}
if o.IsInstanceFunction {
overloads[i] =
decls.NewInstanceOverload(o.GetOverloadId(), params, rt)
} else {
// Otherwise, preserve the original overload.
overloads[i] = o
overloads[i] =
decls.NewOverload(o.GetOverloadId(), params, rt)
}
}
return decls.NewFunction(decl.GetName(), overloads...)
Expand Down
25 changes: 25 additions & 0 deletions checker/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
Expand Down Expand Up @@ -64,6 +65,30 @@ func TestOverlappingOverload(t *testing.T) {
}
}

func TestSanitizedOverload(t *testing.T) {
env := NewStandardEnv(containers.DefaultContainer, newTestRegistry(t))
err := env.Add(decls.NewFunction(operators.Add,
decls.NewOverload("timestamp_add_int",
[]*exprpb.Type{decls.NewObjectType("google.protobuf.Timestamp"), decls.Int},
decls.NewObjectType("google.protobuf.Timestamp"))))
if err != nil {
t.Errorf("env.Add('timestamp_add_int') failed: %v", err)
}
}

func TestSanitizedInstanceOverload(t *testing.T) {
env := NewStandardEnv(containers.DefaultContainer, newTestRegistry(t))
err := env.Add(decls.NewFunction("orDefault",
decls.NewInstanceOverload("int_ordefault_int",
[]*exprpb.Type{
decls.NewObjectType("google.protobuf.Int32Value"),
decls.NewObjectType("google.protobuf.Int32Value")},
decls.Int)))
if err != nil {
t.Errorf("env.Add('int_ordefault_int') failed: %v", err)
}
}

func newTestRegistry(t *testing.T) ref.TypeRegistry {
t.Helper()
reg, err := types.NewRegistry()
Expand Down
23 changes: 0 additions & 23 deletions checker/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,34 +41,11 @@ func (e *typeErrors) undefinedField(l common.Location, field string) {
e.ReportError(l, "undefined field '%s'", field)
}

func (e *typeErrors) fieldDoesNotSupportPresenceCheck(l common.Location, field string) {
e.ReportError(l, "field '%s' does not support presence check", field)
}

func (e *typeErrors) overlappingOverload(l common.Location, name string, overloadID1 string, f1 *exprpb.Type,
overloadID2 string, f2 *exprpb.Type) {
e.ReportError(l, "overlapping overload for name '%s' (type '%s' with overloadId: '%s' cannot be distinguished from '%s' with "+
"overloadId: '%s')", name, FormatCheckedType(f1), overloadID1, FormatCheckedType(f2), overloadID2)
}

func (e *typeErrors) overlappingMacro(l common.Location, name string, args int) {
e.ReportError(l, "overload for name '%s' with %d argument(s) overlaps with predefined macro",
name, args)
}

func (e *typeErrors) noMatchingOverload(l common.Location, name string, args []*exprpb.Type, isInstance bool) {
signature := formatFunction(nil, args, isInstance)
e.ReportError(l, "found no matching overload for '%s' applied to '%s'", name, signature)
}

func (e *typeErrors) aggregateTypeMismatch(l common.Location, aggregate *exprpb.Type, member *exprpb.Type) {
e.ReportError(
l,
"type '%s' does not match previous type '%s' in aggregate. Use 'dyn(x)' to make the aggregate dynamic.",
FormatCheckedType(member),
FormatCheckedType(aggregate))
}

func (e *typeErrors) notAType(l common.Location, t *exprpb.Type) {
e.ReportError(l, "'%s(%v)' is not a type", FormatCheckedType(t), t)
}
Expand Down
13 changes: 0 additions & 13 deletions checker/mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
package checker

import (
"fmt"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

Expand Down Expand Up @@ -49,14 +47,3 @@ func (m *mapping) copy() *mapping {
}
return c
}

func (m *mapping) String() string {
result := "{"

for k, v := range m.mapping {
result += fmt.Sprintf("%v => %v ", k, v)
}

result += "}"
return result
}
30 changes: 14 additions & 16 deletions checker/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const (
kindWellKnown
kindWrapper
kindNull
kindAbstract // TODO: Update the checker protos to include abstract
kindAbstract
kindType
kindList
kindMap
Expand Down Expand Up @@ -190,11 +190,11 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
if kind2 == kindTypeParam {
if t2Sub, found := m.find(t2); found {
// If the types are compatible, pick the more general type and return true
if internalIsAssignable(m, t1, t2Sub) {
m.add(t2, mostGeneral(t1, t2Sub))
return true
if !internalIsAssignable(m, t1, t2Sub) {
return false
}
return false
m.add(t2, mostGeneral(t1, t2Sub))
return true
}
if notReferencedIn(m, t2, t1) {
m.add(t2, t1)
Expand All @@ -207,11 +207,11 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
// become if we generalize type unification.
if t1Sub, found := m.find(t1); found {
// If the types are compatible, pick the more general type and return true
if internalIsAssignable(m, t1Sub, t2) {
m.add(t1, mostGeneral(t1Sub, t2))
return true
if !internalIsAssignable(m, t1Sub, t2) {
return false
}
return false
m.add(t1, mostGeneral(t1Sub, t2))
return true
}
if notReferencedIn(m, t1, t2) {
m.add(t1, t2)
Expand Down Expand Up @@ -242,15 +242,11 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
switch kind1 {
// ERROR, TYPE_PARAM, and DYN handled above.
case kindAbstract:
return internalIsAssignableAbstractType(m,
t1.GetAbstractType(), t2.GetAbstractType())
return internalIsAssignableAbstractType(m, t1.GetAbstractType(), t2.GetAbstractType())
case kindFunction:
return internalIsAssignableFunction(m,
t1.GetFunction(), t2.GetFunction())
return internalIsAssignableFunction(m, t1.GetFunction(), t2.GetFunction())
case kindList:
return internalIsAssignable(m,
t1.GetListType().GetElemType(),
t2.GetListType().GetElemType())
return internalIsAssignable(m, t1.GetListType().GetElemType(), t2.GetListType().GetElemType())
case kindMap:
return internalIsAssignableMap(m, t1.GetMapType(), t2.GetMapType())
case kindObject:
Expand Down Expand Up @@ -389,6 +385,8 @@ func kindOf(t *exprpb.Type) int {
return kindObject
case *exprpb.Type_TypeParam:
return kindTypeParam
case *exprpb.Type_AbstractType_:
return kindAbstract
}
return kindUnknown
}
Expand Down
12 changes: 0 additions & 12 deletions parser/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,3 @@ type parseErrors struct {
func (e *parseErrors) syntaxError(l common.Location, message string) {
e.ReportError(l, fmt.Sprintf("Syntax error: %s", message))
}

func (e *parseErrors) invalidHasArgument(l common.Location) {
e.ReportError(l, "Argument to the function 'has' must be a field selection")
}

func (e *parseErrors) argumentIsNotIdent(l common.Location) {
e.ReportError(l, "Argument must be a simple name")
}

func (e *parseErrors) notAQualifiedName(l common.Location) {
e.ReportError(l, "Expected a qualified name")
}
Loading

0 comments on commit 5cd4381

Please sign in to comment.