From 7f2b87a579708543d0a9632d6294461399abf168 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Tue, 19 Oct 2021 14:37:51 -0700 Subject: [PATCH] Add unparsing support for macro calls (#458) --- common/source.go | 12 -- parser/helper.go | 76 +++++----- parser/parser_test.go | 44 ++++-- parser/unparser.go | 38 ++--- parser/unparser_test.go | 317 +++++++++++++++++++++++++++------------- 5 files changed, 310 insertions(+), 177 deletions(-) diff --git a/common/source.go b/common/source.go index 090667af..52377d93 100644 --- a/common/source.go +++ b/common/source.go @@ -39,10 +39,6 @@ type Source interface { // and second line, or EOF if there is only one line of source. LineOffsets() []int32 - // Macro calls returns the macro calls map containing the original - // expression from a macro replacement, keyed by Id. - MacroCalls() map[int64]*exprpb.Expr - // LocationOffset translates a Location to an offset. // Given the line and column of the Location returns the // Location's character offset in the Source, and a bool @@ -69,7 +65,6 @@ type sourceImpl struct { description string lineOffsets []int32 idOffsets map[int64]int32 - macroCalls map[int64]*exprpb.Expr } var _ runes.Buffer = &sourceImpl{} @@ -98,7 +93,6 @@ func NewStringSource(contents string, description string) Source { description: description, lineOffsets: offsets, idOffsets: map[int64]int32{}, - macroCalls: map[int64]*exprpb.Expr{}, } } @@ -109,7 +103,6 @@ func NewInfoSource(info *exprpb.SourceInfo) Source { description: info.GetLocation(), lineOffsets: info.GetLineOffsets(), idOffsets: info.GetPositions(), - macroCalls: info.GetMacroCalls(), } } @@ -128,11 +121,6 @@ func (s *sourceImpl) LineOffsets() []int32 { return s.lineOffsets } -// MacroCalls implements the Source interface method. -func (s *sourceImpl) MacroCalls() map[int64]*exprpb.Expr { - return s.macroCalls -} - // LocationOffset implements the Source interface method. func (s *sourceImpl) LocationOffset(location Location) (int32, bool) { if lineOffset, found := s.findLineOffset(location.Line()); found { diff --git a/parser/helper.go b/parser/helper.go index 7b10ff37..6abb4694 100644 --- a/parser/helper.go +++ b/parser/helper.go @@ -24,16 +24,18 @@ import ( ) type parserHelper struct { - source common.Source - nextID int64 - positions map[int64]int32 + source common.Source + nextID int64 + positions map[int64]int32 + macroCalls map[int64]*exprpb.Expr } func newParserHelper(source common.Source) *parserHelper { return &parserHelper{ - source: source, - nextID: 1, - positions: make(map[int64]int32), + source: source, + nextID: 1, + positions: make(map[int64]int32), + macroCalls: make(map[int64]*exprpb.Expr), } } @@ -42,7 +44,7 @@ func (p *parserHelper) getSourceInfo() *exprpb.SourceInfo { Location: p.source.Description(), Positions: p.positions, LineOffsets: p.source.LineOffsets(), - MacroCalls: p.source.MacroCalls()} + MacroCalls: p.macroCalls} } func (p *parserHelper) newLiteral(ctx interface{}, value *exprpb.Constant) *exprpb.Expr { @@ -211,27 +213,34 @@ func (p *parserHelper) getLocation(id int64) common.Location { // buildMacroCallArg iterates the expression and returns a new expression // where all macros have been replaced by their IDs in MacroCalls func (p *parserHelper) buildMacroCallArg(expr *exprpb.Expr) *exprpb.Expr { - resultExpr := &exprpb.Expr{Id: expr.GetId()} - if _, found := p.source.MacroCalls()[expr.GetId()]; found { - return resultExpr + if _, found := p.macroCalls[expr.GetId()]; found { + return &exprpb.Expr{Id: expr.GetId()} } switch expr.ExprKind.(type) { case *exprpb.Expr_CallExpr: - resultExpr.ExprKind = &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{ - Function: expr.GetCallExpr().GetFunction(), - }, - } - resultExpr.GetCallExpr().Args = make([]*exprpb.Expr, len(expr.GetCallExpr().GetArgs())) // Iterate the AST from `expr` recursively looking for macros. Because we are at most // starting from the top level macro, this recursion is bounded by the size of the AST. This // means that the depth check on the AST during parsing will catch recursion overflows // before we get to here. + macroTarget := expr.GetCallExpr().GetTarget() + if macroTarget != nil { + macroTarget = p.buildMacroCallArg(macroTarget) + } + macroArgs := make([]*exprpb.Expr, len(expr.GetCallExpr().GetArgs())) for index, arg := range expr.GetCallExpr().GetArgs() { - resultExpr.GetCallExpr().GetArgs()[index] = p.buildMacroCallArg(arg) + macroArgs[index] = p.buildMacroCallArg(arg) + } + return &exprpb.Expr{ + Id: expr.GetId(), + ExprKind: &exprpb.Expr_CallExpr{ + CallExpr: &exprpb.Expr_Call{ + Target: macroTarget, + Function: expr.GetCallExpr().GetFunction(), + Args: macroArgs, + }, + }, } - return resultExpr } return expr @@ -240,28 +249,27 @@ func (p *parserHelper) buildMacroCallArg(expr *exprpb.Expr) *exprpb.Expr { // addMacroCall adds the macro the the MacroCalls map in source info. If a macro has args/subargs/target // that are macros, their ID will be stored instead for later self-lookups. func (p *parserHelper) addMacroCall(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) { - expr := &exprpb.Expr{ - Id: exprID, - ExprKind: &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{ - Function: function, - }, - }, - } - + macroTarget := target if target != nil { - if _, found := p.source.MacroCalls()[target.GetId()]; found { - expr.GetCallExpr().Target = &exprpb.Expr{Id: target.GetId()} - } else { - expr.GetCallExpr().Target = target + if _, found := p.macroCalls[target.GetId()]; found { + macroTarget = &exprpb.Expr{Id: target.GetId()} } } - expr.GetCallExpr().Args = make([]*exprpb.Expr, len(args)) + macroArgs := make([]*exprpb.Expr, len(args)) for index, arg := range args { - expr.GetCallExpr().GetArgs()[index] = p.buildMacroCallArg(arg) + macroArgs[index] = p.buildMacroCallArg(arg) + } + + p.macroCalls[exprID] = &exprpb.Expr{ + ExprKind: &exprpb.Expr_CallExpr{ + CallExpr: &exprpb.Expr_Call{ + Target: macroTarget, + Function: function, + Args: macroArgs, + }, + }, } - p.source.MacroCalls()[exprID] = expr } // balancer performs tree balancing on operators whose arguments are of equal precedence. diff --git a/parser/parser_test.go b/parser/parser_test.go index aa376e73..690c453c 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1465,6 +1465,19 @@ var testCases = []testInfo{ z^#8:*expr.Expr_IdentExpr#.a^#9:*expr.Expr_SelectExpr# )^#10:has#`, }, + { + I: `(has(a.b) || has(c.d)).string()`, + P: `_||_( + a^#2:*expr.Expr_IdentExpr#.b~test-only~^#4:*expr.Expr_SelectExpr#, + c^#6:*expr.Expr_IdentExpr#.d~test-only~^#8:*expr.Expr_SelectExpr# + )^#9:*expr.Expr_CallExpr#.string()^#10:*expr.Expr_CallExpr#`, + M: `has( + c^#6:*expr.Expr_IdentExpr#.d^#7:*expr.Expr_SelectExpr# + )^#8:has#, + has( + a^#2:*expr.Expr_IdentExpr#.b^#3:*expr.Expr_SelectExpr# + )^#4:has#`, + }, } type testInfo struct { @@ -1496,9 +1509,10 @@ func (k *kindAndIDAdorner) GetMetadata(elem interface{}) string { switch elem.(type) { case *exprpb.Expr: e := elem.(*exprpb.Expr) - if k.sourceInfo != nil { - if val, found := k.sourceInfo.MacroCalls[e.GetId()]; found { - return fmt.Sprintf("^#%d:%s#", e.Id, val.GetCallExpr().GetFunction()) + macroCalls := k.sourceInfo.GetMacroCalls() + if macroCalls != nil { + if val, found := macroCalls[e.GetId()]; found { + return fmt.Sprintf("^#%d:%s#", e.GetId(), val.GetCallExpr().GetFunction()) } } var valType interface{} = e.ExprKind @@ -1552,10 +1566,11 @@ func (l *locationAdorner) GetMetadata(elem interface{}) string { } func convertMacroCallsToString(source *exprpb.SourceInfo) string { - keys := make([]int64, len(source.GetMacroCalls())) - adornedStrings := make([]string, len(source.GetMacroCalls())) + macroCalls := source.GetMacroCalls() + keys := make([]int64, len(macroCalls)) + adornedStrings := make([]string, len(macroCalls)) i := 0 - for k := range source.GetMacroCalls() { + for k := range macroCalls { keys[i] = k i++ } @@ -1563,7 +1578,14 @@ func convertMacroCallsToString(source *exprpb.SourceInfo) string { sort.Slice(keys, func(i, j int) bool { return keys[i] > keys[j] }) i = 0 for _, key := range keys { - adornedStrings[i] = debug.ToAdornedDebugString(source.GetMacroCalls()[int64(key)], &kindAndIDAdorner{sourceInfo: source}) + call := macroCalls[int64(key)] + callWithID := &exprpb.Expr{ + Id: int64(key), + ExprKind: call.GetExprKind(), + } + adornedStrings[i] = debug.ToAdornedDebugString( + callWithID, + &kindAndIDAdorner{sourceInfo: source}) i++ } return strings.Join(adornedStrings, ",\n") @@ -1591,7 +1613,7 @@ func TestParse(t *testing.T) { tt.Parallel() src := common.NewTextSource(tc.I) - expression, errors := p.Parse(src) + parsedExpr, errors := p.Parse(src) if len(errors.GetErrors()) > 0 { actualErr := errors.ToDisplayString() if tc.E == "" { @@ -1604,20 +1626,20 @@ func TestParse(t *testing.T) { tt.Fatalf("Expected error not thrown: '%s'", tc.E) } failureDisplayMethod := fmt.Sprintf("Parse(\"%s\")", tc.I) - actualWithKind := debug.ToAdornedDebugString(expression.Expr, &kindAndIDAdorner{}) + actualWithKind := debug.ToAdornedDebugString(parsedExpr.GetExpr(), &kindAndIDAdorner{}) if !test.Compare(actualWithKind, tc.P) { tt.Fatal(test.DiffMessage(fmt.Sprintf("Structure - %s", failureDisplayMethod), actualWithKind, tc.P)) } if tc.L != "" { - actualWithLocation := debug.ToAdornedDebugString(expression.Expr, &locationAdorner{expression.GetSourceInfo()}) + actualWithLocation := debug.ToAdornedDebugString(parsedExpr.GetExpr(), &locationAdorner{parsedExpr.GetSourceInfo()}) if !test.Compare(actualWithLocation, tc.L) { tt.Fatal(test.DiffMessage(fmt.Sprintf("Location - %s", failureDisplayMethod), actualWithLocation, tc.L)) } } if tc.M != "" { - actualAdornedMacroCalls := convertMacroCallsToString(expression.GetSourceInfo()) + actualAdornedMacroCalls := convertMacroCallsToString(parsedExpr.GetSourceInfo()) if !test.Compare(actualAdornedMacroCalls, tc.M) { tt.Fatal(test.DiffMessage(fmt.Sprintf("Macro Calls - %s", failureDisplayMethod), actualAdornedMacroCalls, tc.M)) } diff --git a/parser/unparser.go b/parser/unparser.go index 654f4b95..6a610ff7 100644 --- a/parser/unparser.go +++ b/parser/unparser.go @@ -15,6 +15,7 @@ package parser import ( + "errors" "fmt" "strconv" "strings" @@ -46,19 +47,21 @@ func Unparse(expr *exprpb.Expr, info *exprpb.SourceInfo) (string, error) { // unparser visits an expression to reconstruct a human-readable string from an AST. type unparser struct { - str strings.Builder - offset int32 - // TODO: use the source info to rescontruct macros into function calls. + str strings.Builder info *exprpb.SourceInfo } func (un *unparser) visit(expr *exprpb.Expr) error { + if expr == nil { + return errors.New("unsupported expression") + } + visited, err := un.visitMaybeMacroCall(expr) + if visited || err != nil { + return err + } switch expr.ExprKind.(type) { case *exprpb.Expr_CallExpr: return un.visitCall(expr) - // TODO: Comprehensions are currently not supported. - case *exprpb.Expr_ComprehensionExpr: - return un.visitComprehension(expr) case *exprpb.Expr_ConstExpr: return un.visitConst(expr) case *exprpb.Expr_IdentExpr: @@ -69,8 +72,9 @@ func (un *unparser) visit(expr *exprpb.Expr) error { return un.visitSelect(expr) case *exprpb.Expr_StructExpr: return un.visitStruct(expr) + default: + return fmt.Errorf("unsupported expression: %v", expr) } - return fmt.Errorf("unsupported expr: %v", expr) } func (un *unparser) visitCall(expr *exprpb.Expr) error { @@ -220,12 +224,6 @@ func (un *unparser) visitCallUnary(expr *exprpb.Expr) error { return un.visitMaybeNested(args[0], nested) } -func (un *unparser) visitComprehension(expr *exprpb.Expr) error { - // TODO: introduce a macro expansion map between the top-level comprehension id and the - // function call that the macro replaces. - return fmt.Errorf("unimplemented : %v", expr) -} - func (un *unparser) visitConst(expr *exprpb.Expr) error { c := expr.GetConstExpr() switch c.ConstantKind.(type) { @@ -255,7 +253,7 @@ func (un *unparser) visitConst(expr *exprpb.Expr) error { un.str.WriteString(ui) un.str.WriteString("u") default: - return fmt.Errorf("unimplemented : %v", expr) + return fmt.Errorf("unsupported constant: %v", expr) } return nil } @@ -357,6 +355,15 @@ func (un *unparser) visitStructMap(expr *exprpb.Expr) error { return nil } +func (un *unparser) visitMaybeMacroCall(expr *exprpb.Expr) (bool, error) { + macroCalls := un.info.GetMacroCalls() + call, found := macroCalls[expr.GetId()] + if !found { + return false, nil + } + return true, un.visit(call) +} + func (un *unparser) visitMaybeNested(expr *exprpb.Expr, nested bool) error { if nested { un.str.WriteString("(") @@ -395,9 +402,6 @@ func isSamePrecedence(op string, expr *exprpb.Expr) bool { // // If the expr is not a Call, the result is false. func isLowerPrecedence(op string, expr *exprpb.Expr) bool { - if expr.GetCallExpr() == nil { - return false - } c := expr.GetCallExpr() other := c.GetFunction() return operators.Precedence(op) < operators.Precedence(other) diff --git a/parser/unparser_test.go b/parser/unparser_test.go index 734477c1..cd2c6668 100644 --- a/parser/unparser_test.go +++ b/parser/unparser_test.go @@ -15,134 +15,245 @@ package parser import ( + "errors" + "strings" "testing" "github.com/google/cel-go/common" "google.golang.org/protobuf/proto" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) -func TestUnparse_Identical(t *testing.T) { - tests := map[string]string{ - "call_add": `a + b - c`, - "call_and": `a && b && c && d && e`, - "call_and_or": `a || b && (c || d) && e`, - "call_cond": `a ? b : c`, - "call_index": `a[1]["b"]`, - "call_index_eq": `x["a"].single_int32 == 23`, - "call_mul": `a * (b / c) % 0`, - "call_mul_add": `a + b * c`, - "call_mul_add_nested": `(a + b) * c / (d - e)`, - "call_mul_nested": `a * b / c % 0`, - "call_not": `!true`, - "call_neg": `-num`, - "call_or": `a || b || c || d || e`, - "call_neg_mult": `-(1 * 2)`, - "call_neg_add": `-(1 + 2)`, - "calc_distr_paren": `(1 + 2) * 3`, - "calc_distr_noparen": `1 + 2 * 3`, - "cond_tern_simple": `(x > 5) ? (x - 5) : 0`, - "cond_tern_neg_expr": `-((x > 5) ? (x - 5) : 0)`, - "cond_tern_neg_term": `-x ? (x - 5) : 0`, - "func_global": `size(a ? (b ? c : d) : e)`, - "func_member": `a.hello("world")`, - "func_no_arg": `zero()`, - "func_one_arg": `one("a")`, - "func_two_args": `and(d, 32u)`, - "func_var_args": `max(a, b, 100)`, - "func_neq": `x != "a"`, - "func_in": `a in b`, - "list_empty": `[]`, - "list_one": `[1]`, - "list_many": `["hello, world", "goodbye, world", "sure, why not?"]`, - "lit_bytes": `b"\303\203\302\277"`, - "lit_double": `-42.101`, - "lit_false": `false`, - "lit_int": `-405069`, - "lit_null": `null`, - "lit_string": `"hello:\t'world'"`, - "lit_true": `true`, - "lit_uint": `42u`, - "ident": `my_ident`, - "macro_has": `has(hello.world)`, - "map_empty": `{}`, - "map_lit_key": `{"a": a.b.c, b"\142": bytes(a.b.c)}`, - "map_expr_key": `{a: a, b: a.b, c: a.b.c, a ? b : c: false, a || b: true}`, - "msg_empty": `v1alpha1.Expr{}`, - "msg_fields": `v1alpha1.Expr{id: 1, call_expr: v1alpha1.Call_Expr{function: "name"}}`, - "select": `a.b.c`, - "idx_idx_sel": `a[b][c].name`, - "sel_expr_target": `(a + b).name`, - "sel_cond_target": `(a ? b : c).name`, - "idx_cond_target": `(a ? b : c)[0]`, - "cond_conj": `(a1 && a2) ? b : c`, - "cond_disj_conj": `a ? (b1 || b2) : (c1 && c2)`, - "call_cond_target": `(a ? b : c).method(d)`, - "cond_flat": `false && !true || false`, - "cond_paren": `false && (!true || false)`, - "cond_cond": `(false && !true || false) ? 2 : 3`, - "cond_binop": `(x < 5) ? x : 5`, - "cond_binop_binop": `(x > 5) ? (x - 5) : 0`, - "cond_cond_binop": `(x > 5) ? ((x > 10) ? (x - 10) : 5) : 0`, - //"comp_all": `[1, 2, 3].all(x, x > 0)`, - //"comp_exists": `[1, 2, 3].exists(x, x > 0)`, - //"comp_map": `[1, 2, 3].map(x, x >= 2, x * 4)`, - //"comp_exists_one": `[1, 2, 3].exists_one(x, x >= 2)`, +func TestUnparse(t *testing.T) { + tests := []struct { + name string + in string + out interface{} + requiresMacroCalls bool + }{ + {name: "call_add", in: `a + b - c`}, + {name: "call_and", in: `a && b && c && d && e`}, + {name: "call_and_or", in: `a || b && (c || d) && e`}, + {name: "call_cond", in: `a ? b : c`}, + {name: "call_index", in: `a[1]["b"]`}, + {name: "call_index_eq", in: `x["a"].single_int32 == 23`}, + {name: "call_mul", in: `a * (b / c) % 0`}, + {name: "call_mul_add", in: `a + b * c`}, + {name: "call_mul_add_nested", in: `(a + b) * c / (d - e)`}, + {name: "call_mul_nested", in: `a * b / c % 0`}, + {name: "call_not", in: `!true`}, + {name: "call_neg", in: `-num`}, + {name: "call_or", in: `a || b || c || d || e`}, + {name: "call_neg_mult", in: `-(1 * 2)`}, + {name: "call_neg_add", in: `-(1 + 2)`}, + {name: "call_operator_precedence", in: `1 - (2 == -1)`}, + {name: "calc_distr_paren", in: `(1 + 2) * 3`}, + {name: "calc_distr_noparen", in: `1 + 2 * 3`}, + {name: "cond_tern_simple", in: `(x > 5) ? (x - 5) : 0`}, + {name: "cond_tern_neg_expr", in: `-((x > 5) ? (x - 5) : 0)`}, + {name: "cond_tern_neg_term", in: `-x ? (x - 5) : 0`}, + {name: "func_global", in: `size(a ? (b ? c : d) : e)`}, + {name: "func_member", in: `a.hello("world")`}, + {name: "func_no_arg", in: `zero()`}, + {name: "func_one_arg", in: `one("a")`}, + {name: "func_two_args", in: `and(d, 32u)`}, + {name: "func_var_args", in: `max(a, b, 100)`}, + {name: "func_neq", in: `x != "a"`}, + {name: "func_in", in: `a in b`}, + {name: "list_empty", in: `[]`}, + {name: "list_one", in: `[1]`}, + {name: "list_many", in: `["hello, world", "goodbye, world", "sure, why not?"]`}, + {name: "lit_bytes", in: `b"\303\203\302\277"`}, + {name: "lit_double", in: `-42.101`}, + {name: "lit_false", in: `false`}, + {name: "lit_int", in: `-405069`}, + {name: "lit_null", in: `null`}, + {name: "lit_string", in: `"hello:\t'world'"`}, + {name: "lit_string_quote", in: `"hello:\"world\""`}, + {name: "lit_true", in: `true`}, + {name: "lit_uint", in: `42u`}, + {name: "ident", in: `my_ident`}, + {name: "macro_has", in: `has(hello.world)`}, + {name: "map_empty", in: `{}`}, + {name: "map_lit_key", in: `{"a": a.b.c, b"\142": bytes(a.b.c)}`}, + {name: "map_expr_key", in: `{a: a, b: a.b, c: a.b.c, a ? b : c: false, a || b: true}`}, + {name: "msg_empty", in: `v1alpha1.Expr{}`}, + {name: "msg_fields", in: `v1alpha1.Expr{id: 1, call_expr: v1alpha1.Call_Expr{function: "name"}}`}, + {name: "select", in: `a.b.c`}, + {name: "idx_idx_sel", in: `a[b][c].name`}, + {name: "sel_expr_target", in: `(a + b).name`}, + {name: "sel_cond_target", in: `(a ? b : c).name`}, + {name: "idx_cond_target", in: `(a ? b : c)[0]`}, + {name: "cond_conj", in: `(a1 && a2) ? b : c`}, + {name: "cond_disj_conj", in: `a ? (b1 || b2) : (c1 && c2)`}, + {name: "call_cond_target", in: `(a ? b : c).method(d)`}, + {name: "cond_flat", in: `false && !true || false`}, + {name: "cond_paren", in: `false && (!true || false)`}, + {name: "cond_cond", in: `(false && !true || false) ? 2 : 3`}, + {name: "cond_binop", in: `(x < 5) ? x : 5`}, + {name: "cond_binop_binop", in: `(x > 5) ? (x - 5) : 0`}, + {name: "cond_cond_binop", in: `(x > 5) ? ((x > 10) ? (x - 10) : 5) : 0`}, + + // Equivalent expressions form unparse which do not match the originals. + {name: "call_add_equiv", in: `a+b-c`, out: `a + b - c`}, + {name: "call_cond_equiv", in: `a ? b : c`, out: `a ? b : c`}, + {name: "call_index_equiv", in: `a[ 1 ]["b"]`, out: `a[1]["b"]`}, + {name: "call_or_and_equiv", in: `(false && !true) || false`, out: `false && !true || false`}, + {name: "call_not_not_equiv", in: `!!true`, out: `true`}, + {name: "call_cond_equiv", in: `(a || b ? c : d).e`, out: `((a || b) ? c : d).e`}, + {name: "lit_quote_bytes_equiv", in: `b'aaa"bbb'`, out: `b"\141\141\141\042\142\142\142"`}, + {name: "select_equiv", in: `a . b . c`, out: `a.b.c`}, + + // These expressions require macro call tracking to be enabled. + { + name: "comp_all", + in: `[1, 2, 3].all(x, x > 0)`, + requiresMacroCalls: true, + }, + { + name: "comp_exists", + in: `[1, 2, 3].exists(x, x > 0)`, + requiresMacroCalls: true, + }, + { + name: "comp_map", + in: `[1, 2, 3].map(x, x >= 2, x * 4)`, + requiresMacroCalls: true, + }, + { + name: "comp_exists_one", + in: `[1, 2, 3].exists_one(x, x >= 2)`, + requiresMacroCalls: true, + }, + { + name: "comp_nested", + in: `[[1], [2], [3]].map(x, x.filter(y, y > 1))`, + requiresMacroCalls: true, + }, + { + name: "comp_chained", + in: `[1, 2, 3].map(x, x >= 2, x * 4).filter(x, x <= 10)`, + requiresMacroCalls: true, + }, } - for name, in := range tests { - t.Run(name, func(tt *testing.T) { - p, iss := Parse(common.NewTextSource(in)) + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + prsr, err := NewParser( + Macros(AllMacros...), + PopulateMacroCalls(tc.requiresMacroCalls), + ) + if err != nil { + t.Fatalf("NewParser() failed: %v", err) + } + p, iss := prsr.Parse(common.NewTextSource(tc.in)) if len(iss.GetErrors()) > 0 { - tt.Fatal(iss.ToDisplayString()) + t.Fatalf("parser.Parse(%s) failed: %v", tc.in, iss.ToDisplayString()) } out, err := Unparse(p.GetExpr(), p.GetSourceInfo()) if err != nil { - tt.Error(err) + t.Fatalf("Unparse(%s) failed: %v", tc.in, err) + } + var want interface{} = tc.in + if tc.out != nil { + want = tc.out } - if out != in { - tt.Errorf("Got '%s', wanted '%s'", out, in) + if out != want { + t.Errorf("Unparse() got '%s', wanted '%s'", out, want) + } + p2, iss := prsr.Parse(common.NewTextSource(out)) + if len(iss.GetErrors()) > 0 { + t.Fatalf("parser.Parse(%s) roundtrip failed: %v", tc.in, iss.ToDisplayString()) } - p2, _ := Parse(common.NewTextSource(out)) before := p.GetExpr() after := p2.GetExpr() if !proto.Equal(before, after) { - tt.Errorf("Second parse differs from the first. Got '%v', wanted '%v'", - before, after) + t.Errorf("Roundtrip Parse() differs from original. Got '%v', wanted '%v'", before, after) } }) } } -func TestUnparse_Equivalent(t *testing.T) { - tests := map[string][]string{ - "call_add": {`a+b-c`, `a + b - c`}, - "call_cond": {`a ? b : c`, `a ? b : c`}, - "call_index": {`a[ 1 ]["b"]`, `a[1]["b"]`}, - "call_or_and": {`(false && !true) || false`, `false && !true || false`}, - "call_not_not": {`!!true`, `true`}, - "lit_quote_bytes": {`b'aaa"bbb'`, `b"\141\141\141\042\142\142\142"`}, - "select": {`a . b . c`, `a.b.c`}, +func TestUnparseErrors(t *testing.T) { + tests := []struct { + name string + in *exprpb.Expr + err error + }{ + {name: "empty_expr", in: &exprpb.Expr{}, err: errors.New("unsupported expression")}, + { + name: "bad_constant", + in: &exprpb.Expr{ + ExprKind: &exprpb.Expr_ConstExpr{ + ConstExpr: &exprpb.Constant{}, + }, + }, + err: errors.New("unsupported constant"), + }, + { + name: "bad_args", + in: &exprpb.Expr{ + ExprKind: &exprpb.Expr_CallExpr{ + CallExpr: &exprpb.Expr_Call{ + Function: "_&&_", + Args: []*exprpb.Expr{{}, {}}, + }, + }, + }, + err: errors.New("unsupported expression"), + }, + { + name: "bad_struct", + in: &exprpb.Expr{ + ExprKind: &exprpb.Expr_StructExpr{ + StructExpr: &exprpb.Expr_CreateStruct{ + MessageName: "Msg", + Entries: []*exprpb.Expr_CreateStruct_Entry{ + {KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{FieldKey: "field"}}, + }, + }, + }, + }, + err: errors.New("unsupported expression"), + }, + { + name: "bad_map", + in: &exprpb.Expr{ + ExprKind: &exprpb.Expr_StructExpr{ + StructExpr: &exprpb.Expr_CreateStruct{ + Entries: []*exprpb.Expr_CreateStruct_Entry{ + {KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{FieldKey: "field"}}, + }, + }, + }, + }, + err: errors.New("unsupported expression"), + }, + { + name: "bad_index", + in: &exprpb.Expr{ + ExprKind: &exprpb.Expr_CallExpr{ + CallExpr: &exprpb.Expr_Call{ + Function: "_[_]", + Args: []*exprpb.Expr{{}, {}}, + }, + }, + }, + err: errors.New("unsupported expression"), + }, } - for name, in := range tests { - t.Run(name, func(tt *testing.T) { - p, iss := Parse(common.NewTextSource(in[0])) - if len(iss.GetErrors()) > 0 { - tt.Fatal(iss.ToDisplayString()) - } - out, err := Unparse(p.GetExpr(), p.GetSourceInfo()) - if err != nil { - tt.Error(err) + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + out, err := Unparse(tc.in, &exprpb.SourceInfo{}) + if err == nil { + t.Fatalf("Unparse(%v) got %v, wanted error %v", tc.in, out, tc.err) } - if out != in[1] { - tt.Errorf("Got '%s', wanted '%s'", out, in[1]) - } - p2, _ := Parse(common.NewTextSource(out)) - before := p.GetExpr() - after := p2.GetExpr() - if !proto.Equal(before, after) { - tt.Errorf("Second parse differs from the first. Got '%v', wanted '%v'", - before, after) + if !strings.Contains(err.Error(), tc.err.Error()) { + t.Errorf("Unparse(%v) got unexpected error: %v, wanted %v", tc.in, err, tc.err) } }) }