Skip to content

Commit

Permalink
Fix expression inlining when working with macros (#853)
Browse files Browse the repository at this point in the history
* Allow constant macro references to remain in the macro call set
* Fix stability issues in the macro rewrites
* Introduce `NativeRep` call on `cel.Ast`
* Additional tests and refactors to simplify testing / extension
  • Loading branch information
TristonianJones authored Nov 10, 2023
1 parent 39c2810 commit 1460938
Show file tree
Hide file tree
Showing 9 changed files with 628 additions and 178 deletions.
3 changes: 3 additions & 0 deletions cel/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ go_library(
"env.go",
"folding.go",
"io.go",
"inlining.go",
"library.go",
"macro.go",
"optimizer.go",
Expand Down Expand Up @@ -60,6 +61,8 @@ go_test(
"env_test.go",
"folding_test.go",
"io_test.go",
"inlining_test.go",
"optimizer_test.go",
"validator_test.go",
],
data = [
Expand Down
5 changes: 5 additions & 0 deletions cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ type Ast struct {
impl *celast.AST
}

// NativeRep converts the AST to a Go-native representation.
func (ast *Ast) NativeRep() *celast.AST {
return ast.impl
}

// Expr returns the proto serializable instance of the parsed/checked expression.
//
// Deprecated: prefer cel.AstToCheckedExpr() or cel.AstToParsedExpr() and call GetExpr()
Expand Down
36 changes: 17 additions & 19 deletions cel/folding.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error())
return
}
e.SetKindCase(adapted)
ctx.UpdateExpr(e, adapted)
}))

return a
Expand All @@ -134,10 +134,8 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
if err != nil {
return err
}
// Clear any macro metadata associated with the fold.
a.SourceInfo().ClearMacroCall(expr.ID())
// Update the fold expression to be a literal.
expr.SetKindCase(ctx.NewLiteral(out))
ctx.UpdateExpr(expr, ctx.NewLiteral(out))
return nil
}

Expand All @@ -159,15 +157,15 @@ func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool {
return false
}
if cond.AsLiteral() == types.True {
expr.SetKindCase(truthy)
ctx.UpdateExpr(expr, truthy)
} else {
expr.SetKindCase(falsy)
ctx.UpdateExpr(expr, falsy)
}
return true
case operators.In:
haystack := args[1]
if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
expr.SetKindCase(ctx.NewLiteral(types.False))
ctx.UpdateExpr(expr, ctx.NewLiteral(types.False))
return true
}
needle := args[0]
Expand All @@ -176,7 +174,7 @@ func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool {
list := haystack.AsList()
for _, e := range list.Elements() {
if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
expr.SetKindCase(ctx.NewLiteral(types.True))
ctx.UpdateExpr(expr, ctx.NewLiteral(types.True))
return true
}
}
Expand All @@ -202,20 +200,20 @@ func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.E
continue
}
if arg.AsLiteral() == shortcircuit {
expr.SetKindCase(arg)
ctx.UpdateExpr(expr, arg)
return true
}
}
if len(newArgs) == 0 {
newArgs = append(newArgs, args[0])
expr.SetKindCase(newArgs[0])
ctx.UpdateExpr(expr, newArgs[0])
return true
}
if len(newArgs) == 1 {
expr.SetKindCase(newArgs[0])
ctx.UpdateExpr(expr, newArgs[0])
return true
}
expr.SetKindCase(ctx.NewCall(function, newArgs...))
ctx.UpdateExpr(expr, ctx.NewCall(function, newArgs...))
return true
}

Expand Down Expand Up @@ -270,10 +268,10 @@ func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) {
newOptIndex-- // Skipping causes the list to get smaller.
continue
}
e.SetKindCase(ctx.NewLiteral(optElemVal.GetValue()))
ctx.UpdateExpr(e, ctx.NewLiteral(optElemVal.GetValue()))
updatedElems = append(updatedElems, e)
}
e.SetKindCase(ctx.NewList(updatedElems, updatedIndices))
ctx.UpdateExpr(e, ctx.NewList(updatedElems, updatedIndices))
}

func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) {
Expand Down Expand Up @@ -303,20 +301,20 @@ func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) {
if err != nil {
ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err)
}
val.SetKindCase(undoOptVal)
ctx.UpdateExpr(val, undoOptVal)
updatedEntries = append(updatedEntries, e)
continue
}
modified = true
if !optElemVal.HasValue() {
continue
}
val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue()))
ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
updatedEntry := ctx.NewMapEntry(key, val, false)
updatedEntries = append(updatedEntries, updatedEntry)
}
if modified {
e.SetKindCase(ctx.NewMap(updatedEntries))
ctx.UpdateExpr(e, ctx.NewMap(updatedEntries))
}
}

Expand All @@ -341,12 +339,12 @@ func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) {
if !optElemVal.HasValue() {
continue
}
val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue()))
ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
updatedField := ctx.NewStructField(field.Name(), val, false)
updatedFields = append(updatedFields, updatedField)
}
if modified {
e.SetKindCase(ctx.NewStruct(s.TypeName(), updatedFields))
ctx.UpdateExpr(e, ctx.NewStruct(s.TypeName(), updatedFields))
}
}

Expand Down
62 changes: 62 additions & 0 deletions cel/folding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,68 @@ func TestConstantFoldingOptimizer(t *testing.T) {
}
}

func TestConstantFoldingOptimizerMacroElimination(t *testing.T) {
tests := []struct {
expr string
folded string
macroCount int
}{
{
expr: `has({}.key)`,
folded: `false`,
},
{
expr: `[1, 2, 3].filter(i, i < 1)`,
folded: `[]`,
},
{
expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b))`,
folded: `true`,
},
{
expr: `has(x.b) && [{}, {"a": 1}, {"b": 2}].exists(i, has(i.b))`,
folded: `has(x.b)`,
macroCount: 1,
},
}
e, err := NewEnv(
OptionalTypes(),
EnableMacroCallTracking(),
Types(&proto3pb.TestAllTypes{}),
Variable("x", DynType))
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
for _, tst := range tests {
tc := tst
t.Run(tc.expr, func(t *testing.T) {
checked, iss := e.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("Compile() failed: %v", iss.Err())
}
folder, err := NewConstantFoldingOptimizer()
if err != nil {
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
}
opt := NewStaticOptimizer(folder)
optimized, iss := opt.Optimize(e, checked)
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
}
folded, err := AstToString(optimized)
if err != nil {
t.Fatalf("AstToString() failed: %v", err)
}
if folded != tc.folded {
t.Errorf("folding got %q, wanted %q", folded, tc.folded)
}
if len(optimized.SourceInfo().GetMacroCalls()) != tc.macroCount {
t.Errorf("folding got %d macros, wanted %d macros", len(optimized.SourceInfo().GetMacroCalls()), tc.macroCount)
}
})
}
}

func TestConstantFoldingOptimizerWithLimit(t *testing.T) {
tests := []struct {
expr string
Expand Down
68 changes: 44 additions & 24 deletions cel/inlining.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,57 +85,76 @@ func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.A
}

// For a single match, do a direct replacement of the expression sub-graph.
if len(matches) == 1 {
opt.inlineExpr(ctx, matches[0], ctx.CopyExpr(inlineVar.Expr()), inlineVar.Type())
continue
}

if !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) {
if len(matches) == 1 || !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) {
for _, match := range matches {
opt.inlineExpr(ctx, match, ctx.CopyExpr(inlineVar.Expr()), inlineVar.Type())
// Copy the inlined AST expr and source info.
copyExpr := copyASTAndMetadata(ctx, inlineVar.def)
opt.inlineExpr(ctx, match, copyExpr, inlineVar.Type())
}
continue
}

// For multiple matches, find the least common ancestor (lca) and insert the
// variable as a cel.bind() macro.
var lca ast.NavigableExpr = nil
ancestors := map[int64]bool{}
var lca ast.NavigableExpr = root
lcaAncestorCount := 0
ancestors := map[int64]int{}
for _, match := range matches {
// Update the identifier matches with the provided alias.
aliasExpr := ctx.NewIdent(inlineVar.Alias())
opt.inlineExpr(ctx, match, aliasExpr, inlineVar.Type())
parent, found := match, true
for found {
_, hasAncestor := ancestors[parent.ID()]
if hasAncestor && (lca == nil || lca.Depth() < parent.Depth()) {
ancestorCount, hasAncestor := ancestors[parent.ID()]
if !hasAncestor {
ancestors[parent.ID()] = 1
parent, found = parent.Parent()
continue
}
if lcaAncestorCount < ancestorCount || (lcaAncestorCount == ancestorCount && lca.Depth() < parent.Depth()) {
lca = parent
lcaAncestorCount = ancestorCount
}
ancestors[parent.ID()] = true
ancestors[parent.ID()] = ancestorCount + 1
parent, found = parent.Parent()
}
aliasExpr := ctx.NewIdent(inlineVar.Alias())
opt.inlineExpr(ctx, match, aliasExpr, inlineVar.Type())
}

// Copy the inlined AST expr and source info.
copyExpr := copyASTAndMetadata(ctx, inlineVar.def)
// Update the least common ancestor by inserting a cel.bind() call to the alias.
inlined := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), inlineVar.Expr(), lca)
inlined, bindMacro := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), copyExpr, lca)
opt.inlineExpr(ctx, lca, inlined, inlineVar.Type())
ctx.sourceInfo.SetMacroCall(lca.ID(), bindMacro)
}
return a
}

// copyASTAndMetadata copies the input AST and propagates the macro metadata into the AST being
// optimized.
func copyASTAndMetadata(ctx *OptimizerContext, a *ast.AST) ast.Expr {
copyExpr, copyInfo := ctx.CopyAST(a)
// Add in the macro calls from the inlined AST
for id, call := range copyInfo.MacroCalls() {
ctx.sourceInfo.SetMacroCall(id, call)
}
return copyExpr
}

// inlineExpr replaces the current expression with the inlined one, unless the location of the inlining
// happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is
// made to determine whether the inlined value can be presence or existence tested.
func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) {
func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) {
switch prev.Kind() {
case ast.SelectKind:
sel := prev.AsSelect()
if !sel.IsTestOnly() {
prev.SetKindCase(inlined)
ctx.UpdateExpr(prev, inlined)
return
}
opt.rewritePresenceExpr(ctx, prev, inlined, inlinedType)
default:
prev.SetKindCase(inlined)
ctx.UpdateExpr(prev, inlined)
}
}

Expand All @@ -146,23 +165,24 @@ func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev, inlined as
func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) {
// If the input inlined expression is not a select expression it won't work with the has()
// macro. Attempt to rewrite the presence test in terms of the typed input, otherwise error.
ctx.sourceInfo.ClearMacroCall(prev.ID())
if inlined.Kind() == ast.SelectKind {
inlinedSel := inlined.AsSelect()
prev.SetKindCase(
ctx.NewPresenceTest(prev.ID(), inlinedSel.Operand(), inlinedSel.FieldName()))
presenceTest, hasMacro := ctx.NewHasMacro(prev.ID(), inlined)
ctx.UpdateExpr(prev, presenceTest)
ctx.sourceInfo.SetMacroCall(prev.ID(), hasMacro)
return
}

ctx.sourceInfo.ClearMacroCall(prev.ID())
if inlinedType.IsAssignableType(NullType) {
prev.SetKindCase(
ctx.UpdateExpr(prev,
ctx.NewCall(operators.NotEquals,
inlined,
ctx.NewLiteral(types.NullValue),
))
return
}
if inlinedType.HasTrait(traits.SizerType) {
prev.SetKindCase(
ctx.UpdateExpr(prev,
ctx.NewCall(operators.NotEquals,
ctx.NewMemberCall(overloads.Size, inlined),
ctx.NewLiteral(types.IntZero),
Expand Down
Loading

0 comments on commit 1460938

Please sign in to comment.