Skip to content

Commit

Permalink
Add unparsing support for macro calls (#458)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Oct 19, 2021
1 parent 92119b8 commit 7f2b87a
Show file tree
Hide file tree
Showing 5 changed files with 310 additions and 177 deletions.
12 changes: 0 additions & 12 deletions common/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ type Source interface {
// and second line, or EOF if there is only one line of source.
LineOffsets() []int32

// Macro calls returns the macro calls map containing the original
// expression from a macro replacement, keyed by Id.
MacroCalls() map[int64]*exprpb.Expr

// LocationOffset translates a Location to an offset.
// Given the line and column of the Location returns the
// Location's character offset in the Source, and a bool
Expand All @@ -69,7 +65,6 @@ type sourceImpl struct {
description string
lineOffsets []int32
idOffsets map[int64]int32
macroCalls map[int64]*exprpb.Expr
}

var _ runes.Buffer = &sourceImpl{}
Expand Down Expand Up @@ -98,7 +93,6 @@ func NewStringSource(contents string, description string) Source {
description: description,
lineOffsets: offsets,
idOffsets: map[int64]int32{},
macroCalls: map[int64]*exprpb.Expr{},
}
}

Expand All @@ -109,7 +103,6 @@ func NewInfoSource(info *exprpb.SourceInfo) Source {
description: info.GetLocation(),
lineOffsets: info.GetLineOffsets(),
idOffsets: info.GetPositions(),
macroCalls: info.GetMacroCalls(),
}
}

Expand All @@ -128,11 +121,6 @@ func (s *sourceImpl) LineOffsets() []int32 {
return s.lineOffsets
}

// MacroCalls implements the Source interface method.
func (s *sourceImpl) MacroCalls() map[int64]*exprpb.Expr {
return s.macroCalls
}

// LocationOffset implements the Source interface method.
func (s *sourceImpl) LocationOffset(location Location) (int32, bool) {
if lineOffset, found := s.findLineOffset(location.Line()); found {
Expand Down
76 changes: 42 additions & 34 deletions parser/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,18 @@ import (
)

type parserHelper struct {
source common.Source
nextID int64
positions map[int64]int32
source common.Source
nextID int64
positions map[int64]int32
macroCalls map[int64]*exprpb.Expr
}

func newParserHelper(source common.Source) *parserHelper {
return &parserHelper{
source: source,
nextID: 1,
positions: make(map[int64]int32),
source: source,
nextID: 1,
positions: make(map[int64]int32),
macroCalls: make(map[int64]*exprpb.Expr),
}
}

Expand All @@ -42,7 +44,7 @@ func (p *parserHelper) getSourceInfo() *exprpb.SourceInfo {
Location: p.source.Description(),
Positions: p.positions,
LineOffsets: p.source.LineOffsets(),
MacroCalls: p.source.MacroCalls()}
MacroCalls: p.macroCalls}
}

func (p *parserHelper) newLiteral(ctx interface{}, value *exprpb.Constant) *exprpb.Expr {
Expand Down Expand Up @@ -211,27 +213,34 @@ func (p *parserHelper) getLocation(id int64) common.Location {
// buildMacroCallArg iterates the expression and returns a new expression
// where all macros have been replaced by their IDs in MacroCalls
func (p *parserHelper) buildMacroCallArg(expr *exprpb.Expr) *exprpb.Expr {
resultExpr := &exprpb.Expr{Id: expr.GetId()}
if _, found := p.source.MacroCalls()[expr.GetId()]; found {
return resultExpr
if _, found := p.macroCalls[expr.GetId()]; found {
return &exprpb.Expr{Id: expr.GetId()}
}

switch expr.ExprKind.(type) {
case *exprpb.Expr_CallExpr:
resultExpr.ExprKind = &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: expr.GetCallExpr().GetFunction(),
},
}
resultExpr.GetCallExpr().Args = make([]*exprpb.Expr, len(expr.GetCallExpr().GetArgs()))
// Iterate the AST from `expr` recursively looking for macros. Because we are at most
// starting from the top level macro, this recursion is bounded by the size of the AST. This
// means that the depth check on the AST during parsing will catch recursion overflows
// before we get to here.
macroTarget := expr.GetCallExpr().GetTarget()
if macroTarget != nil {
macroTarget = p.buildMacroCallArg(macroTarget)
}
macroArgs := make([]*exprpb.Expr, len(expr.GetCallExpr().GetArgs()))
for index, arg := range expr.GetCallExpr().GetArgs() {
resultExpr.GetCallExpr().GetArgs()[index] = p.buildMacroCallArg(arg)
macroArgs[index] = p.buildMacroCallArg(arg)
}
return &exprpb.Expr{
Id: expr.GetId(),
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Target: macroTarget,
Function: expr.GetCallExpr().GetFunction(),
Args: macroArgs,
},
},
}
return resultExpr
}

return expr
Expand All @@ -240,28 +249,27 @@ func (p *parserHelper) buildMacroCallArg(expr *exprpb.Expr) *exprpb.Expr {
// addMacroCall adds the macro the the MacroCalls map in source info. If a macro has args/subargs/target
// that are macros, their ID will be stored instead for later self-lookups.
func (p *parserHelper) addMacroCall(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) {
expr := &exprpb.Expr{
Id: exprID,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: function,
},
},
}

macroTarget := target
if target != nil {
if _, found := p.source.MacroCalls()[target.GetId()]; found {
expr.GetCallExpr().Target = &exprpb.Expr{Id: target.GetId()}
} else {
expr.GetCallExpr().Target = target
if _, found := p.macroCalls[target.GetId()]; found {
macroTarget = &exprpb.Expr{Id: target.GetId()}
}
}

expr.GetCallExpr().Args = make([]*exprpb.Expr, len(args))
macroArgs := make([]*exprpb.Expr, len(args))
for index, arg := range args {
expr.GetCallExpr().GetArgs()[index] = p.buildMacroCallArg(arg)
macroArgs[index] = p.buildMacroCallArg(arg)
}

p.macroCalls[exprID] = &exprpb.Expr{
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Target: macroTarget,
Function: function,
Args: macroArgs,
},
},
}
p.source.MacroCalls()[exprID] = expr
}

// balancer performs tree balancing on operators whose arguments are of equal precedence.
Expand Down
44 changes: 33 additions & 11 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1465,6 +1465,19 @@ var testCases = []testInfo{
z^#8:*expr.Expr_IdentExpr#.a^#9:*expr.Expr_SelectExpr#
)^#10:has#`,
},
{
I: `(has(a.b) || has(c.d)).string()`,
P: `_||_(
a^#2:*expr.Expr_IdentExpr#.b~test-only~^#4:*expr.Expr_SelectExpr#,
c^#6:*expr.Expr_IdentExpr#.d~test-only~^#8:*expr.Expr_SelectExpr#
)^#9:*expr.Expr_CallExpr#.string()^#10:*expr.Expr_CallExpr#`,
M: `has(
c^#6:*expr.Expr_IdentExpr#.d^#7:*expr.Expr_SelectExpr#
)^#8:has#,
has(
a^#2:*expr.Expr_IdentExpr#.b^#3:*expr.Expr_SelectExpr#
)^#4:has#`,
},
}

type testInfo struct {
Expand Down Expand Up @@ -1496,9 +1509,10 @@ func (k *kindAndIDAdorner) GetMetadata(elem interface{}) string {
switch elem.(type) {
case *exprpb.Expr:
e := elem.(*exprpb.Expr)
if k.sourceInfo != nil {
if val, found := k.sourceInfo.MacroCalls[e.GetId()]; found {
return fmt.Sprintf("^#%d:%s#", e.Id, val.GetCallExpr().GetFunction())
macroCalls := k.sourceInfo.GetMacroCalls()
if macroCalls != nil {
if val, found := macroCalls[e.GetId()]; found {
return fmt.Sprintf("^#%d:%s#", e.GetId(), val.GetCallExpr().GetFunction())
}
}
var valType interface{} = e.ExprKind
Expand Down Expand Up @@ -1552,18 +1566,26 @@ func (l *locationAdorner) GetMetadata(elem interface{}) string {
}

func convertMacroCallsToString(source *exprpb.SourceInfo) string {
keys := make([]int64, len(source.GetMacroCalls()))
adornedStrings := make([]string, len(source.GetMacroCalls()))
macroCalls := source.GetMacroCalls()
keys := make([]int64, len(macroCalls))
adornedStrings := make([]string, len(macroCalls))
i := 0
for k := range source.GetMacroCalls() {
for k := range macroCalls {
keys[i] = k
i++
}
// Sort the keys in descending order to create a stable ordering for tests and improve readability.
sort.Slice(keys, func(i, j int) bool { return keys[i] > keys[j] })
i = 0
for _, key := range keys {
adornedStrings[i] = debug.ToAdornedDebugString(source.GetMacroCalls()[int64(key)], &kindAndIDAdorner{sourceInfo: source})
call := macroCalls[int64(key)]
callWithID := &exprpb.Expr{
Id: int64(key),
ExprKind: call.GetExprKind(),
}
adornedStrings[i] = debug.ToAdornedDebugString(
callWithID,
&kindAndIDAdorner{sourceInfo: source})
i++
}
return strings.Join(adornedStrings, ",\n")
Expand Down Expand Up @@ -1591,7 +1613,7 @@ func TestParse(t *testing.T) {
tt.Parallel()

src := common.NewTextSource(tc.I)
expression, errors := p.Parse(src)
parsedExpr, errors := p.Parse(src)
if len(errors.GetErrors()) > 0 {
actualErr := errors.ToDisplayString()
if tc.E == "" {
Expand All @@ -1604,20 +1626,20 @@ func TestParse(t *testing.T) {
tt.Fatalf("Expected error not thrown: '%s'", tc.E)
}
failureDisplayMethod := fmt.Sprintf("Parse(\"%s\")", tc.I)
actualWithKind := debug.ToAdornedDebugString(expression.Expr, &kindAndIDAdorner{})
actualWithKind := debug.ToAdornedDebugString(parsedExpr.GetExpr(), &kindAndIDAdorner{})
if !test.Compare(actualWithKind, tc.P) {
tt.Fatal(test.DiffMessage(fmt.Sprintf("Structure - %s", failureDisplayMethod), actualWithKind, tc.P))
}

if tc.L != "" {
actualWithLocation := debug.ToAdornedDebugString(expression.Expr, &locationAdorner{expression.GetSourceInfo()})
actualWithLocation := debug.ToAdornedDebugString(parsedExpr.GetExpr(), &locationAdorner{parsedExpr.GetSourceInfo()})
if !test.Compare(actualWithLocation, tc.L) {
tt.Fatal(test.DiffMessage(fmt.Sprintf("Location - %s", failureDisplayMethod), actualWithLocation, tc.L))
}
}

if tc.M != "" {
actualAdornedMacroCalls := convertMacroCallsToString(expression.GetSourceInfo())
actualAdornedMacroCalls := convertMacroCallsToString(parsedExpr.GetSourceInfo())
if !test.Compare(actualAdornedMacroCalls, tc.M) {
tt.Fatal(test.DiffMessage(fmt.Sprintf("Macro Calls - %s", failureDisplayMethod), actualAdornedMacroCalls, tc.M))
}
Expand Down
38 changes: 21 additions & 17 deletions parser/unparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package parser

import (
"errors"
"fmt"
"strconv"
"strings"
Expand Down Expand Up @@ -46,19 +47,21 @@ func Unparse(expr *exprpb.Expr, info *exprpb.SourceInfo) (string, error) {

// unparser visits an expression to reconstruct a human-readable string from an AST.
type unparser struct {
str strings.Builder
offset int32
// TODO: use the source info to rescontruct macros into function calls.
str strings.Builder
info *exprpb.SourceInfo
}

func (un *unparser) visit(expr *exprpb.Expr) error {
if expr == nil {
return errors.New("unsupported expression")
}
visited, err := un.visitMaybeMacroCall(expr)
if visited || err != nil {
return err
}
switch expr.ExprKind.(type) {
case *exprpb.Expr_CallExpr:
return un.visitCall(expr)
// TODO: Comprehensions are currently not supported.
case *exprpb.Expr_ComprehensionExpr:
return un.visitComprehension(expr)
case *exprpb.Expr_ConstExpr:
return un.visitConst(expr)
case *exprpb.Expr_IdentExpr:
Expand All @@ -69,8 +72,9 @@ func (un *unparser) visit(expr *exprpb.Expr) error {
return un.visitSelect(expr)
case *exprpb.Expr_StructExpr:
return un.visitStruct(expr)
default:
return fmt.Errorf("unsupported expression: %v", expr)
}
return fmt.Errorf("unsupported expr: %v", expr)
}

func (un *unparser) visitCall(expr *exprpb.Expr) error {
Expand Down Expand Up @@ -220,12 +224,6 @@ func (un *unparser) visitCallUnary(expr *exprpb.Expr) error {
return un.visitMaybeNested(args[0], nested)
}

func (un *unparser) visitComprehension(expr *exprpb.Expr) error {
// TODO: introduce a macro expansion map between the top-level comprehension id and the
// function call that the macro replaces.
return fmt.Errorf("unimplemented : %v", expr)
}

func (un *unparser) visitConst(expr *exprpb.Expr) error {
c := expr.GetConstExpr()
switch c.ConstantKind.(type) {
Expand Down Expand Up @@ -255,7 +253,7 @@ func (un *unparser) visitConst(expr *exprpb.Expr) error {
un.str.WriteString(ui)
un.str.WriteString("u")
default:
return fmt.Errorf("unimplemented : %v", expr)
return fmt.Errorf("unsupported constant: %v", expr)
}
return nil
}
Expand Down Expand Up @@ -357,6 +355,15 @@ func (un *unparser) visitStructMap(expr *exprpb.Expr) error {
return nil
}

func (un *unparser) visitMaybeMacroCall(expr *exprpb.Expr) (bool, error) {
macroCalls := un.info.GetMacroCalls()
call, found := macroCalls[expr.GetId()]
if !found {
return false, nil
}
return true, un.visit(call)
}

func (un *unparser) visitMaybeNested(expr *exprpb.Expr, nested bool) error {
if nested {
un.str.WriteString("(")
Expand Down Expand Up @@ -395,9 +402,6 @@ func isSamePrecedence(op string, expr *exprpb.Expr) bool {
//
// If the expr is not a Call, the result is false.
func isLowerPrecedence(op string, expr *exprpb.Expr) bool {
if expr.GetCallExpr() == nil {
return false
}
c := expr.GetCallExpr()
other := c.GetFunction()
return operators.Precedence(op) < operators.Precedence(other)
Expand Down
Loading

0 comments on commit 7f2b87a

Please sign in to comment.