Skip to content

Commit

Permalink
Improved support for nested rules (#991)
Browse files Browse the repository at this point in the history
* Improved support for nested rules
* Add nil safety to cel.Ast and additional rule tests
  • Loading branch information
TristonianJones authored Aug 2, 2024
1 parent 5bcdb8b commit 3545aac
Show file tree
Hide file tree
Showing 15 changed files with 537 additions and 63 deletions.
22 changes: 11 additions & 11 deletions cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ type Ast struct {

// NativeRep converts the AST to a Go-native representation.
func (ast *Ast) NativeRep() *celast.AST {
if ast == nil {
return nil
}
return ast.impl
}

Expand All @@ -55,24 +58,21 @@ func (ast *Ast) Expr() *exprpb.Expr {
if ast == nil {
return nil
}
pbExpr, _ := celast.ExprToProto(ast.impl.Expr())
pbExpr, _ := celast.ExprToProto(ast.NativeRep().Expr())
return pbExpr
}

// IsChecked returns whether the Ast value has been successfully type-checked.
func (ast *Ast) IsChecked() bool {
if ast == nil {
return false
}
return ast.impl.IsChecked()
return ast.NativeRep().IsChecked()
}

// SourceInfo returns character offset and newline position information about expression elements.
func (ast *Ast) SourceInfo() *exprpb.SourceInfo {
if ast == nil {
return nil
}
pbInfo, _ := celast.SourceInfoToProto(ast.impl.SourceInfo())
pbInfo, _ := celast.SourceInfoToProto(ast.NativeRep().SourceInfo())
return pbInfo
}

Expand All @@ -95,7 +95,7 @@ func (ast *Ast) OutputType() *Type {
if ast == nil {
return types.ErrorType
}
return ast.impl.GetType(ast.impl.Expr().ID())
return ast.NativeRep().GetType(ast.NativeRep().Expr().ID())
}

// Source returns a view of the input used to create the Ast. This source may be complete or
Expand Down Expand Up @@ -218,12 +218,12 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
if err != nil {
errs := common.NewErrors(ast.Source())
errs.ReportError(common.NoLocation, err.Error())
return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
return nil, NewIssuesWithSourceInfo(errs, ast.NativeRep().SourceInfo())
}

checked, errs := checker.Check(ast.impl, ast.Source(), chk)
checked, errs := checker.Check(ast.NativeRep(), ast.Source(), chk)
if len(errs.GetErrors()) > 0 {
return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
return nil, NewIssuesWithSourceInfo(errs, ast.NativeRep().SourceInfo())
}
// Manually create the Ast to ensure that the Ast source information (which may be more
// detailed than the information provided by Check), is returned to the caller.
Expand All @@ -244,7 +244,7 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
}
}
// Apply additional validators on the type-checked result.
iss := NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo())
iss := NewIssuesWithSourceInfo(errs, ast.NativeRep().SourceInfo())
for _, v := range e.validators {
v.Validate(e, vConfig, checked, iss)
}
Expand Down
1 change: 1 addition & 0 deletions policy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ go_library(
"//common/decls:go_default_library",
"//common/operators:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//ext:go_default_library",
"@in_gopkg_yaml_v3//:go_default_library",
],
Expand Down
89 changes: 75 additions & 14 deletions policy/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,20 @@ import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)

// CompiledRule represents the variables and match blocks associated with a rule block.
type CompiledRule struct {
exprID int64
id *ValueString
variables []*CompiledVariable
matches []*CompiledMatch
}

// SourceID returns the source metadata identifier associated with the compiled rule.
func (r *CompiledRule) SourceID() int64 {
return r.ID().ID
return r.exprID
}

// ID returns the expression id associated with the rule.
Expand Down Expand Up @@ -63,17 +65,33 @@ func (r *CompiledRule) OutputType() *cel.Type {
return cel.DynType
}

// HasOptionalOutput returns whether the rule returns a concrete or optional value.
// The rule may return an optional value if all match expressions under the rule are conditional.
func (r *CompiledRule) HasOptionalOutput() bool {
optionalOutput := false
for _, m := range r.Matches() {
if m.NestedRule() != nil && m.NestedRule().HasOptionalOutput() {
return true
}
if m.ConditionIsLiteral(types.True) {
return false
}
optionalOutput = true
}
return optionalOutput
}

// CompiledVariable represents the variable name, expression, and associated type-check declaration.
type CompiledVariable struct {
id int64
exprID int64
name string
expr *cel.Ast
varDecl *decls.VariableDecl
}

// SourceID returns the source metadata identifier associated with the variable.
func (v *CompiledVariable) SourceID() int64 {
return v.id
return v.exprID
}

// Name returns the variable name.
Expand All @@ -94,17 +112,29 @@ func (v *CompiledVariable) Declaration() *decls.VariableDecl {
// CompiledMatch represents a match block which has an optional condition (true, by default) as well
// as an output or a nested rule (one or the other, but not both).
type CompiledMatch struct {
exprID int64
cond *cel.Ast
output *OutputValue
nestedRule *CompiledRule
}

// SourceID returns the source identifier associated with the compiled match.
func (m *CompiledMatch) SourceID() int64 {
return m.exprID
}

// Condition returns the compiled predicate expression which must evaluate to true before the output
// or subrule is entered.
func (m *CompiledMatch) Condition() *cel.Ast {
return m.cond
}

// ConditionIsLiteral indicates whether the condition for the match is a literal with a given value.
func (m *CompiledMatch) ConditionIsLiteral(val ref.Val) bool {
c := m.cond.NativeRep().Expr()
return c.Kind() == ast.LiteralKind && c.AsLiteral().Equal(val) == types.True
}

// Output returns the compiled output expression associated with the match block, if set.
func (m *CompiledMatch) Output() *OutputValue {
return m.output
Expand All @@ -128,13 +158,13 @@ func (m *CompiledMatch) OutputType() *cel.Type {

// OutputValue represents the output expression associated with a match block.
type OutputValue struct {
id int64
expr *cel.Ast
exprID int64
expr *cel.Ast
}

// ID returns the expression id associated with the output expression.
func (o *OutputValue) ID() int64 {
return o.id
// SourceID returns the expression id associated with the output expression.
func (o *OutputValue) SourceID() int64 {
return o.exprID
}

// Expr returns the compiled expression associated with the output.
Expand Down Expand Up @@ -229,7 +259,7 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com
iss.ReportErrorAtID(v.Expression().ID, "invalid variable declaration")
}
compiledVar := &CompiledVariable{
id: v.name.ID,
exprID: v.name.ID,
name: v.name.Value,
expr: varAST,
varDecl: varDecl,
Expand Down Expand Up @@ -261,10 +291,11 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com
outAST, outIss := ruleEnv.CompileSource(outSrc)
iss = iss.Append(outIss)
compiledMatches = append(compiledMatches, &CompiledMatch{
cond: condAST,
exprID: m.exprID,
cond: condAST,
output: &OutputValue{
id: m.Output().ID,
expr: outAST,
exprID: m.Output().ID,
expr: outAST,
},
})
continue
Expand All @@ -273,6 +304,7 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com
nestedRule, ruleIss := c.compileRule(m.Rule(), ruleEnv, iss)
iss = iss.Append(ruleIss)
compiledMatches = append(compiledMatches, &CompiledMatch{
exprID: m.exprID,
cond: condAST,
nestedRule: nestedRule,
})
Expand All @@ -285,13 +317,20 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*Com
}
}

// Validate that all branches in the rule are reachable
rule := &CompiledRule{
exprID: r.exprID,
id: r.id,
variables: compiledVars,
matches: compiledMatches,
}

// Note: Consider supporting configurable policy validators that take the policy, rule, and issues
// Validate type agreement between the different match outputs
c.checkMatchOutputTypesAgree(rule, iss)
// Validate that all branches in the policy are reachable
c.checkUnreachableCode(rule, iss)

return rule, iss
}

Expand All @@ -309,13 +348,35 @@ func (c *compiler) checkMatchOutputTypesAgree(rule *CompiledRule, iss *cel.Issue
if matchOutputType.TypeName() == "error" {
continue
}
if !outputType.IsAssignableType(matchOutputType) {
iss.ReportErrorAtID(m.Output().ID(), "incompatible output types: %s not assignable to %s", outputType, matchOutputType)
// Handle assignability as the output type is assignable to the match output or vice versa.
// During composition, this is roughly how the type-checker will handle the type agreement check.
if !(outputType.IsAssignableType(matchOutputType) || matchOutputType.IsAssignableType(outputType)) {
iss.ReportErrorAtID(m.Output().SourceID(), "incompatible output types: %s not assignable to %s", outputType, matchOutputType)
return
}
}
}

func (c *compiler) checkUnreachableCode(rule *CompiledRule, iss *cel.Issues) {
ruleHasOptional := rule.HasOptionalOutput()
compiledMatches := rule.Matches()
matchCount := len(compiledMatches)
for i := matchCount - 1; i >= 0; i-- {
m := compiledMatches[i]
triviallyTrue := m.ConditionIsLiteral(types.True)

if triviallyTrue && !ruleHasOptional && i != matchCount-1 {
if m.Output() != nil {
iss.ReportErrorAtID(m.SourceID(), "match creates unreachable outputs")
}
if m.NestedRule() != nil {
iss.ReportErrorAtID(m.NestedRule().SourceID(), "rule creates unreachable outputs")
}
break
}
}
}

func (c *compiler) relSource(pstr ValueString) *RelativeSource {
line := 0
col := 1
Expand Down
61 changes: 55 additions & 6 deletions policy/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,49 @@ func TestCompileError(t *testing.T) {
}
}

func TestCompiledRuleHasOptionalOutput(t *testing.T) {
env, err := cel.NewEnv()
if err != nil {
t.Fatalf("cel.NewEnv() failed: %v", err)
}
tests := []struct {
rule *CompiledRule
optional bool
}{
{rule: &CompiledRule{}, optional: false},
{
rule: &CompiledRule{
matches: []*CompiledMatch{{}},
},
optional: true,
},
{
rule: &CompiledRule{
matches: []*CompiledMatch{{}},
},
optional: true,
},
{
rule: &CompiledRule{
matches: []*CompiledMatch{{cond: mustCompileExpr(t, env, "true")}},
},
optional: false,
},
{
rule: &CompiledRule{
matches: []*CompiledMatch{{cond: mustCompileExpr(t, env, "1 < 0")}},
},
optional: true,
},
}
for _, tst := range tests {
got := tst.rule.HasOptionalOutput()
if got != tst.optional {
t.Errorf("rule.HasOptionalOutput() got %v, wanted, %v", got, tst.optional)
}
}
}

func BenchmarkCompile(b *testing.B) {
for _, tst := range policyTests {
r := newRunner(b, tst.name, tst.expr, tst.parseOpts, tst.envOpts...)
Expand Down Expand Up @@ -70,7 +113,17 @@ type runner struct {
prg cel.Program
}

func mustCompileExpr(t testing.TB, env *cel.Env, expr string) *cel.Ast {
t.Helper()
out, iss := env.Compile(expr)
if iss.Err() != nil {
t.Fatalf("env.Compile(%s) failed: %v", expr, iss.Err())
}
return out
}

func compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) {
t.Helper()
config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", name))
srcFile := readPolicy(t, fmt.Sprintf("testdata/%s/policy.yaml", name))
parser, err := NewParser(parseOpts...)
Expand Down Expand Up @@ -158,12 +211,8 @@ func (r *runner) run(t *testing.T) {
} else if testOut.Equal(optOut.GetValue()) != types.True {
t.Errorf("policy eval got %v, wanted %v", out, testOut)
}
} else if boolOut, ok := out.(types.Bool); ok {
if testOut.Equal(boolOut) != types.True {
t.Errorf("policy eval got %v, wanted %v", boolOut, testOut)
}
} else {
t.Errorf("unexpected policy output type %v", out)
} else if testOut.Equal(out) != types.True {
t.Errorf("policy eval got %v, wanted %v", out, testOut)
}
})
}
Expand Down
Loading

0 comments on commit 3545aac

Please sign in to comment.