diff --git a/compliance_test.go b/compliance_test.go index 54cc235..419a9d0 100644 --- a/compliance_test.go +++ b/compliance_test.go @@ -98,6 +98,11 @@ func runSyntaxTestCase(assert *assert.Assertions, given interface{}, testcase Te // an error when we try to evaluate the expression. _, err := Search(testcase.Expression, given) assert.NotNil(err, fmt.Sprintf("Expression: %s", testcase.Expression)) + if er, ok := err.(SyntaxError); !ok { + assert.Fail("unexpected error: %T, %v: %s", err, err, fmt.Sprintf("Expression: %s", testcase.Expression)) + } else { + assert.Equal(testcase.Error, er.Type(), fmt.Sprintf("Expression: %s", testcase.Expression)) + } } func runTestCase(assert *assert.Assertions, given interface{}, testcase TestCase, filename string) { diff --git a/functions.go b/functions.go index e9770e8..eb5284a 100644 --- a/functions.go +++ b/functions.go @@ -328,7 +328,10 @@ func (e *functionEntry) resolveArgs(arguments []interface{}) ([]interface{}, err } if !e.arguments[len(e.arguments)-1].variadic { if len(e.arguments) != len(arguments) { - return nil, errors.New("incorrect number of args") + return nil, SyntaxError{ + typ: ErrInvalidArity, + msg: "incorrect number of args", + } } for i, spec := range e.arguments { userArg := arguments[i] @@ -340,7 +343,10 @@ func (e *functionEntry) resolveArgs(arguments []interface{}) ([]interface{}, err return arguments, nil } if len(arguments) < len(e.arguments) { - return nil, errors.New("invalid arity") + return nil, SyntaxError{ + typ: ErrInvalidArity, + msg: fmt.Sprintf("not enough arguments for function %s", e.name), + } } return arguments, nil } @@ -380,13 +386,19 @@ func (a *argSpec) typeCheck(arg interface{}) error { } } } - return fmt.Errorf("Invalid type for: %v, expected: %#v", arg, a.types) + return SyntaxError{ + typ: ErrInvalidType, + msg: fmt.Sprintf("Invalid type for: %v, expected: %#v", arg, a.types), + } } func (f *functionCaller) CallFunction(name string, arguments []interface{}, intr *treeInterpreter) (interface{}, error) { entry, ok := f.functionTable[name] if !ok { - return nil, errors.New("unknown function: " + name) + return nil, SyntaxError{ + typ: ErrUnknownFunction, + msg: "unknown function: " + name, + } } resolvedArgs, err := entry.resolveArgs(arguments) if err != nil { @@ -548,7 +560,10 @@ func jpfMaxBy(arguments []interface{}) (interface{}, error) { } current, ok := result.(float64) if !ok { - return nil, errors.New("invalid type, must be number") + return nil, SyntaxError{ + typ: ErrInvalidType, + msg: "invalid type, must be number", + } } if current > bestVal { bestVal = current @@ -566,7 +581,10 @@ func jpfMaxBy(arguments []interface{}) (interface{}, error) { } current, ok := result.(string) if !ok { - return nil, errors.New("invalid type, must be string") + return nil, SyntaxError{ + typ: ErrInvalidType, + msg: "invalid type, must be string", + } } if current > bestVal { bestVal = current @@ -575,7 +593,10 @@ func jpfMaxBy(arguments []interface{}) (interface{}, error) { } return bestItem, nil default: - return nil, errors.New("invalid type, must be number of string") + return nil, SyntaxError{ + typ: ErrInvalidType, + msg: "invalid type, must be number of string", + } } } func jpfSum(arguments []interface{}) (interface{}, error) { @@ -643,7 +664,10 @@ func jpfMinBy(arguments []interface{}) (interface{}, error) { } current, ok := result.(float64) if !ok { - return nil, errors.New("invalid type, must be number") + return nil, SyntaxError{ + typ: ErrInvalidType, + msg: "invalid type, must be number", + } } if current < bestVal { bestVal = current @@ -661,7 +685,10 @@ func jpfMinBy(arguments []interface{}) (interface{}, error) { } current, ok := result.(string) if !ok { - return nil, errors.New("invalid type, must be string") + return nil, SyntaxError{ + typ: ErrInvalidType, + msg: "invalid type, must be string", + } } if current < bestVal { bestVal = current @@ -670,7 +697,10 @@ func jpfMinBy(arguments []interface{}) (interface{}, error) { } return bestItem, nil } else { - return nil, errors.New("invalid type, must be number of string") + return nil, SyntaxError{ + typ: ErrInvalidType, + msg: "invalid type, must be number of string", + } } } func jpfType(arguments []interface{}) (interface{}, error) { @@ -749,18 +779,27 @@ func jpfSortBy(arguments []interface{}) (interface{}, error) { sortable := &byExprFloat{intr, node, arr, false} sort.Stable(sortable) if sortable.hasError { - return nil, errors.New("error in sort_by comparison") + return nil, SyntaxError{ + typ: ErrInvalidType, + msg: "error in sort_by comparison", + } } return arr, nil } else if _, ok := start.(string); ok { sortable := &byExprString{intr, node, arr, false} sort.Stable(sortable) if sortable.hasError { - return nil, errors.New("error in sort_by comparison") + return nil, SyntaxError{ + typ: ErrInvalidType, + msg: "error in sort_by comparison", + } } return arr, nil } else { - return nil, errors.New("invalid type, must be number of string") + return nil, SyntaxError{ + typ: ErrInvalidType, + msg: "invalid type, must be number of string", + } } } func jpfJoin(arguments []interface{}) (interface{}, error) { diff --git a/lexer.go b/lexer.go index 817900c..f369d69 100644 --- a/lexer.go +++ b/lexer.go @@ -28,13 +28,27 @@ type Lexer struct { buf bytes.Buffer // Internal buffer used for building up values. } +const ( + ErrSyntax string = "syntax" + ErrInvalidArity string = "invalid-arity" + ErrInvalidType string = "invalid-type" + ErrInvalidValue string = "invalid-value" + ErrUnknownFunction string = "unknown-function" +) + // SyntaxError is the main error used whenever a lexing or parsing error occurs. type SyntaxError struct { + typ string // Error type as defined in the JMESpath specification msg string // Error message displayed to user Expression string // Expression that generated a SyntaxError Offset int // The location in the string where the error occurred } +// Type returns the type of error, as defined in the JMESpath specification. +func (e SyntaxError) Type() string { + return e.typ +} + func (e SyntaxError) Error() string { // In the future, it would be good to underline the specific // location where the error occurred. @@ -208,7 +222,7 @@ loop: } else if _, ok := whiteSpace[r]; ok { // Ignore whitespace } else { - return tokens, lexer.syntaxError(fmt.Sprintf("Unknown char: %s", strconv.QuoteRuneToASCII(r))) + return tokens, lexer.syntaxError(ErrSyntax, fmt.Sprintf("Unknown char: %s", strconv.QuoteRuneToASCII(r))) } } tokens = append(tokens, token{tEOF, "", len(lexer.expression), 0}) @@ -233,6 +247,7 @@ func (lexer *Lexer) consumeUntil(end rune) (string, error) { // Then we hit an EOF so we never reached the closing // delimiter. return "", SyntaxError{ + typ: ErrSyntax, msg: "Unclosed delimiter: " + string(end), Expression: lexer.expression, Offset: len(lexer.expression), @@ -274,6 +289,7 @@ func (lexer *Lexer) consumeRawStringLiteral() (token, error) { // Then we hit an EOF so we never reached the closing // delimiter. return token{}, SyntaxError{ + typ: ErrSyntax, msg: "Unclosed delimiter: '", Expression: lexer.expression, Offset: len(lexer.expression), @@ -293,8 +309,9 @@ func (lexer *Lexer) consumeRawStringLiteral() (token, error) { }, nil } -func (lexer *Lexer) syntaxError(msg string) SyntaxError { +func (lexer *Lexer) syntaxError(typ string, msg string) SyntaxError { return SyntaxError{ + typ: typ, msg: msg, Expression: lexer.expression, Offset: lexer.currentPos - 1, diff --git a/parser.go b/parser.go index 4abc303..3a1c265 100644 --- a/parser.go +++ b/parser.go @@ -136,7 +136,7 @@ func (p *Parser) Parse(expression string) (ASTNode, error) { return ASTNode{}, err } if p.current() != tEOF { - return ASTNode{}, p.syntaxError(fmt.Sprintf( + return ASTNode{}, p.syntaxError(ErrSyntax, fmt.Sprintf( "Unexpected token at the end of the expression: %s", p.current())) } return parsed, nil @@ -195,8 +195,8 @@ func (p *Parser) parseSliceExpression() (ASTNode, error) { parts[index] = &parsedInt p.advance() } else { - return ASTNode{}, p.syntaxError( - "Expected tColon or tNumber" + ", received: " + p.current().String()) + return ASTNode{}, p.syntaxError(ErrSyntax, + "Expected tColon or tNumber"+", received: "+p.current().String()) } current = p.current() } @@ -214,7 +214,7 @@ func (p *Parser) match(tokenType tokType) error { p.advance() return nil } - return p.syntaxError("Expected " + tokenType.String() + ", received: " + p.current().String()) + return p.syntaxError(ErrSyntax, "Expected "+tokenType.String()+", received: "+p.current().String()) } func (p *Parser) led(tokenType tokType, node ASTNode) (ASTNode, error) { @@ -311,7 +311,7 @@ func (p *Parser) led(tokenType tokType, node ASTNode) (ASTNode, error) { children: []ASTNode{node, right}, }, nil } - return ASTNode{}, p.syntaxError("Unexpected token: " + tokenType.String()) + return ASTNode{}, p.syntaxError(ErrSyntax, "Unexpected token: "+tokenType.String()) } func (p *Parser) nud(token token) (ASTNode, error) { @@ -333,7 +333,7 @@ func (p *Parser) nud(token token) (ASTNode, error) { case tQuotedIdentifier: node := ASTNode{nodeType: ASTField, value: token.value} if p.current() == tLparen { - return ASTNode{}, p.syntaxErrorToken("Can't have quoted identifier as function name.", token) + return ASTNode{}, p.syntaxErrorToken(ErrSyntax, "Can't have quoted identifier as function name.", token) } return node, nil case tStar: @@ -407,10 +407,10 @@ func (p *Parser) nud(token token) (ASTNode, error) { } return expression, nil case tEOF: - return ASTNode{}, p.syntaxErrorToken("Incomplete expression", token) + return ASTNode{}, p.syntaxErrorToken(ErrSyntax, "Incomplete expression", token) } - return ASTNode{}, p.syntaxErrorToken("Invalid token: "+token.tokenType.String(), token) + return ASTNode{}, p.syntaxErrorToken(ErrSyntax, "Invalid token: "+token.tokenType.String(), token) } func (p *Parser) parseMultiSelectList() (ASTNode, error) { @@ -445,7 +445,7 @@ func (p *Parser) parseMultiSelectHash() (ASTNode, error) { keyToken := p.lookaheadToken(0) if err := p.match(tUnquotedIdentifier); err != nil { if err := p.match(tQuotedIdentifier); err != nil { - return ASTNode{}, p.syntaxError("Expected tQuotedIdentifier or tUnquotedIdentifier") + return ASTNode{}, p.syntaxError(ErrSyntax, "Expected tQuotedIdentifier or tUnquotedIdentifier") } } keyName := keyToken.value @@ -536,7 +536,7 @@ func (p *Parser) parseDotRHS(bindingPower int) (ASTNode, error) { } return p.parseMultiSelectHash() } - return ASTNode{}, p.syntaxError("Expected identifier, lbracket, or lbrace") + return ASTNode{}, p.syntaxError(ErrSyntax, "Expected identifier, lbracket, or lbrace") } func (p *Parser) parseProjectionRHS(bindingPower int) (ASTNode, error) { @@ -554,7 +554,7 @@ func (p *Parser) parseProjectionRHS(bindingPower int) (ASTNode, error) { } return p.parseDotRHS(bindingPower) } else { - return ASTNode{}, p.syntaxError("Error") + return ASTNode{}, p.syntaxError(ErrSyntax, "Error") } } @@ -583,8 +583,9 @@ func tokensOneOf(elements []tokType, token tokType) bool { return false } -func (p *Parser) syntaxError(msg string) SyntaxError { +func (p *Parser) syntaxError(typ string, msg string) SyntaxError { return SyntaxError{ + typ: typ, msg: msg, Expression: p.expression, Offset: p.lookaheadToken(0).position, @@ -594,8 +595,9 @@ func (p *Parser) syntaxError(msg string) SyntaxError { // Create a SyntaxError based on the provided token. // This differs from syntaxError() which creates a SyntaxError // based on the current lookahead token. -func (p *Parser) syntaxErrorToken(msg string, t token) SyntaxError { +func (p *Parser) syntaxErrorToken(typ string, msg string, t token) SyntaxError { return SyntaxError{ + typ: typ, msg: msg, Expression: p.expression, Offset: t.position, diff --git a/parser_test.go b/parser_test.go index 4c920fe..e8bb676 100644 --- a/parser_test.go +++ b/parser_test.go @@ -11,8 +11,8 @@ var parsingErrorTests = []struct { expression string msg string }{ - {"foo.", "Incopmlete expression"}, - {"[foo", "Incopmlete expression"}, + {"foo.", "Incomplete expression"}, + {"[foo", "Incomplete expression"}, {"]", "Invalid"}, {")", "Invalid"}, {"}", "Invalid"}, diff --git a/util.go b/util.go index ddc1b7d..0ccd167 100644 --- a/util.go +++ b/util.go @@ -1,7 +1,6 @@ package jmespath import ( - "errors" "reflect" ) @@ -84,7 +83,10 @@ func computeSliceParams(length int, parts []sliceParam) ([]int, error) { if !parts[2].Specified { step = 1 } else if parts[2].N == 0 { - return nil, errors.New("Invalid slice, step cannot be 0") + return nil, SyntaxError{ + typ: ErrInvalidValue, + msg: "Invalid slice, step cannot be 0", + } } else { step = parts[2].N }