diff --git a/cel/cel_test.go b/cel/cel_test.go index e80a3288..b27bd6d1 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -1195,6 +1195,14 @@ func TestResidualAstMacros(t *testing.T) { } } +func TestResidualAstNil(t *testing.T) { + env := testEnv(t) + ast, err := env.ResidualAst(nil, nil) + if err == nil || !strings.Contains(err.Error(), "unsupported expr") { + t.Errorf("env.ResidualAst() got (%v, %v) wanted unsupported expr error", ast, err) + } +} + func BenchmarkEvalOptions(b *testing.B) { env := testEnv(b, Variable("ai", IntType), @@ -1323,7 +1331,7 @@ func TestEnvExtensionIsolation(t *testing.T) { func TestVariadicLogicalOperators(t *testing.T) { env := testEnv(t, variadicLogicalOperatorASTs()) ast, iss := env.Compile( - `(false || false || false || false || true) && + `(false || false || false || false || true) && (true && true && true && true && false)`) if iss.Err() != nil { t.Fatalf("Compile() failed: %v", iss.Err()) @@ -2293,7 +2301,7 @@ func TestOptionalValuesCompile(t *testing.T) { if iss.Err() != nil { t.Fatalf("%v failed: %v", tc.expr, iss.Err()) } - for id, reference := range ast.impl.ReferenceMap() { + for id, reference := range ast.NativeRep().ReferenceMap() { other, found := tc.references[id] if !found { t.Errorf("Compile(%v) expected reference %d: %v", tc.expr, id, reference) @@ -2955,6 +2963,15 @@ func BenchmarkDynamicDispatch(b *testing.B) { }) } +func TestAstProgramNilValue(t *testing.T) { + var ast *Ast = nil + env := testEnv(t) + prg, err := env.Program(ast) + if err == nil || !strings.Contains(err.Error(), "unsupported expr") { + t.Errorf("env.Program() got (%v,%v) wanted unsupported expr error", prg, err) + } +} + // TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package type testCostEstimator struct { hints map[string]uint64 diff --git a/cel/env.go b/cel/env.go index ab736b77..caee8e8c 100644 --- a/cel/env.go +++ b/cel/env.go @@ -556,7 +556,8 @@ func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) { // TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an // Ast format and then Program again. func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) { - pruned := interpreter.PruneAst(a.impl.Expr(), a.impl.SourceInfo().MacroCalls(), details.State()) + ast := a.NativeRep() + pruned := interpreter.PruneAst(ast.Expr(), ast.SourceInfo().MacroCalls(), details.State()) newAST := &Ast{source: a.Source(), impl: pruned} expr, err := AstToString(newAST) if err != nil { @@ -582,7 +583,7 @@ func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...ch extendedOpts := make([]checker.CostOption, 0, len(e.costOptions)) extendedOpts = append(extendedOpts, opts...) extendedOpts = append(extendedOpts, e.costOptions...) - return checker.Cost(ast.impl, estimator, extendedOpts...) + return checker.Cost(ast.NativeRep(), estimator, extendedOpts...) } // configure applies a series of EnvOptions to the current environment. diff --git a/cel/folding_test.go b/cel/folding_test.go index 52357cc3..3f24f50e 100644 --- a/cel/folding_test.go +++ b/cel/folding_test.go @@ -650,11 +650,11 @@ func TestConstantFoldingNormalizeIDs(t *testing.T) { t.Fatalf("Compile() failed: %v", iss.Err()) } preOpt := newIDCollector() - ast.PostOrderVisit(checked.impl.Expr(), preOpt) + ast.PostOrderVisit(checked.NativeRep().Expr(), preOpt) if !reflect.DeepEqual(preOpt.IDs(), tc.ids) { t.Errorf("Compile() got ids %v, expected %v", preOpt.IDs(), tc.ids) } - for id, call := range checked.impl.SourceInfo().MacroCalls() { + for id, call := range checked.NativeRep().SourceInfo().MacroCalls() { macroText, found := tc.macros[id] if !found { t.Fatalf("Compile() did not find macro %d", id) @@ -682,11 +682,11 @@ func TestConstantFoldingNormalizeIDs(t *testing.T) { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) } postOpt := newIDCollector() - ast.PostOrderVisit(optimized.impl.Expr(), postOpt) + ast.PostOrderVisit(optimized.NativeRep().Expr(), postOpt) if !reflect.DeepEqual(postOpt.IDs(), tc.normalizedIDs) { t.Errorf("Optimize() got ids %v, expected %v", postOpt.IDs(), tc.normalizedIDs) } - for id, call := range optimized.impl.SourceInfo().MacroCalls() { + for id, call := range optimized.NativeRep().SourceInfo().MacroCalls() { macroText, found := tc.normalizedMacros[id] if !found { t.Fatalf("Optimize() did not find macro %d", id) diff --git a/cel/inlining.go b/cel/inlining.go index 78d5bea6..a4530e19 100644 --- a/cel/inlining.go +++ b/cel/inlining.go @@ -60,7 +60,7 @@ func NewInlineVariable(name string, definition *Ast) *InlineVariable { // If the variable occurs more than once, the provided alias will be used to replace the expressions // where the variable name occurs. func NewInlineVariableWithAlias(name, alias string, definition *Ast) *InlineVariable { - return &InlineVariable{name: name, alias: alias, def: definition.impl} + return &InlineVariable{name: name, alias: alias, def: definition.NativeRep()} } // NewInliningOptimizer creates and optimizer which replaces variables with expression definitions. diff --git a/cel/io.go b/cel/io.go index dd010a99..a327c967 100644 --- a/cel/io.go +++ b/cel/io.go @@ -62,7 +62,7 @@ func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) { if !a.IsChecked() { return nil, fmt.Errorf("cannot convert unchecked ast") } - return ast.ToProto(a.impl) + return ast.ToProto(a.NativeRep()) } // ParsedExprToAst converts a parsed expression proto message to an Ast. @@ -99,7 +99,7 @@ func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) { // Note, the conversion may not be an exact replica of the original expression, but will produce // a string that is semantically equivalent and whose textual representation is stable. func AstToString(a *Ast) (string, error) { - return parser.Unparse(a.impl.Expr(), a.impl.SourceInfo()) + return parser.Unparse(a.NativeRep().Expr(), a.NativeRep().SourceInfo()) } // RefValueToValue converts between ref.Val and google.api.expr.v1alpha1.Value. diff --git a/cel/io_test.go b/cel/io_test.go index d1710f4d..7bc34ee9 100644 --- a/cel/io_test.go +++ b/cel/io_test.go @@ -16,6 +16,7 @@ package cel import ( "fmt" + "strings" "testing" "time" @@ -144,6 +145,27 @@ func TestAstToString(t *testing.T) { } } +func TestAstToStringNil(t *testing.T) { + expr, err := AstToString(nil) + if err == nil || !strings.Contains(err.Error(), "unsupported expr") { + t.Errorf("env.AstToString() got (%v, %v) wanted unsupported expr error", expr, err) + } +} + +func TestAstToCheckedExprNil(t *testing.T) { + expr, err := AstToCheckedExpr(nil) + if err == nil || !strings.Contains(err.Error(), "cannot convert unchecked ast") { + t.Errorf("env.AstToCheckedExpr() got (%v, %v) wanted conversion error", expr, err) + } +} + +func TestAstToParsedExprNil(t *testing.T) { + expr, err := AstToParsedExpr(nil) + if err != nil { + t.Errorf("env.AstToParsedExpr() got (%v, %v) wanted conversion error", expr, err) + } +} + func TestCheckedExprToAstConstantExpr(t *testing.T) { stdEnv, err := NewEnv() if err != nil { diff --git a/cel/optimizer.go b/cel/optimizer.go index c149abb7..9a2a97a6 100644 --- a/cel/optimizer.go +++ b/cel/optimizer.go @@ -48,8 +48,8 @@ func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer { // If issues are encountered, the Issues.Err() return value will be non-nil. func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { // Make a copy of the AST to be optimized. - optimized := ast.Copy(a.impl) - ids := newIDGenerator(ast.MaxID(a.impl)) + optimized := ast.Copy(a.NativeRep()) + ids := newIDGenerator(ast.MaxID(a.NativeRep())) // Create the optimizer context, could be pooled in the future. issues := NewIssues(common.NewErrors(a.Source())) @@ -86,7 +86,7 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { if iss.Err() != nil { return nil, iss } - optimized = checked.impl + optimized = checked.NativeRep() } // Return the optimized result. diff --git a/cel/optimizer_test.go b/cel/optimizer_test.go index 2faac17e..f406e7c7 100644 --- a/cel/optimizer_test.go +++ b/cel/optimizer_test.go @@ -16,6 +16,7 @@ package cel_test import ( "sort" + "strings" "testing" "github.com/google/cel-go/cel" @@ -201,6 +202,15 @@ func TestStaticOptimizerNewAST(t *testing.T) { } } +func TestStaticOptimizerNilAST(t *testing.T) { + env := optimizerEnv(t) + opt := cel.NewStaticOptimizer(&identityOptimizer{t: t}) + optAST, iss := opt.Optimize(env, nil) + if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "unexpected unspecified type") { + t.Errorf("opt.Optimize(env, nil) got (%v, %v), wanted unexpected unspecified type", optAST, iss) + } +} + type identityOptimizer struct { t *testing.T } diff --git a/cel/program.go b/cel/program.go index 6f477afc..49bd5378 100644 --- a/cel/program.go +++ b/cel/program.go @@ -100,6 +100,9 @@ type EvalDetails struct { // State of the evaluation, non-nil if the OptTrackState or OptExhaustiveEval is specified // within EvalOptions. func (ed *EvalDetails) State() interpreter.EvalState { + if ed == nil { + return interpreter.NewEvalState() + } return ed.state } diff --git a/common/errors.go b/common/errors.go index 25adc73d..89570683 100644 --- a/common/errors.go +++ b/common/errors.go @@ -30,9 +30,13 @@ type Errors struct { // NewErrors creates a new instance of the Errors type. func NewErrors(source Source) *Errors { + src := source + if src == nil { + src = NewTextSource("") + } return &Errors{ errors: []*Error{}, - source: source, + source: src, maxErrorsToReport: 100, } }