diff --git a/.github/workflows/test_ci.yml b/.github/workflows/test_ci.yml index 3422b63..fa44fe3 100644 --- a/.github/workflows/test_ci.yml +++ b/.github/workflows/test_ci.yml @@ -4,7 +4,8 @@ on: branches: - main pull_request: - +env: + go_version: 1.21.6 jobs: golangci-lint: name: golangci-lint @@ -12,6 +13,10 @@ jobs: steps: - name: Checkout uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4 + - uses: actions/setup-go@v4 + with: + go-version: ${{ env.go_version }} + cache: false - name: Run golangci-lint uses: golangci/golangci-lint-action@3a919529898de77ec3da873e3063ca4b10e7f5cc # v3 with: @@ -25,7 +30,7 @@ jobs: - name: Set up Go uses: actions/setup-go@0c52d547c9bc32b1aa3301fd7a9cb496313a4491 # v5 with: - go-version: 1.18 + go-version: ${{ env.go_version }} - name: Set up gotestfmt uses: GoTestTools/gotestfmt-action@v2 - uses: actions/cache@704facf57e6136b1bc63b828d79edcd491f0ee84 # v3 diff --git a/expression_data_test.go b/expression_data_test.go index 3edf85b..2e3accf 100644 --- a/expression_data_test.go +++ b/expression_data_test.go @@ -67,6 +67,16 @@ var testScope = schema.NewScopeSchema( nil, nil, ), + "simple_int_2": schema.NewPropertySchema( + schema.NewIntSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), "simple_any": schema.NewPropertySchema( schema.NewAnySchema(), nil, @@ -77,6 +87,16 @@ var testScope = schema.NewScopeSchema( nil, nil, ), + "simple_bool": schema.NewPropertySchema( + schema.NewBoolSchema(), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), "int_list": schema.NewPropertySchema( schema.NewListSchema( schema.NewIntSchema(nil, nil, nil), diff --git a/expression_dependencies.go b/expression_dependencies.go index cebfdbf..36f15ed 100644 --- a/expression_dependencies.go +++ b/expression_dependencies.go @@ -4,6 +4,7 @@ import ( "fmt" "go.flow.arcalot.io/expressions/internal/ast" "go.flow.arcalot.io/pluginsdk/schema" + "slices" ) // dependencyContext holds the root data for a dependency evaluation in an expression. This is useful so that we @@ -74,6 +75,14 @@ func (c *dependencyContext) dependencies( return &dependencyResult{resolvedType: schema.NewStringSchema(nil, nil, nil)}, nil case *ast.IntLiteral: return &dependencyResult{resolvedType: schema.NewIntSchema(nil, nil, nil)}, nil + case *ast.FloatLiteral: + return &dependencyResult{resolvedType: schema.NewFloatSchema(nil, nil, nil)}, nil + case *ast.BooleanLiteral: + return &dependencyResult{resolvedType: schema.NewBoolSchema()}, nil + case *ast.BinaryOperation: + return c.binaryOperationDependencies(n) + case *ast.UnaryOperation: + return c.unaryOperationDependencies(n) case *ast.FunctionCall: return c.functionDependencies(n) default: @@ -81,6 +90,138 @@ func (c *dependencyContext) dependencies( } } +func (c *dependencyContext) unaryOperationDependencies( + node *ast.UnaryOperation, +) (*dependencyResult, error) { + // Unary operations don't change dependencies, but the type must be validated. + innerResult, err := c.rootDependencies(node.RightNode) + if err != nil { + return nil, err + } + switch node.LeftOperation { + case ast.Subtract: + // Negation expects a numerical type + if innerResult.resolvedType.TypeID() != schema.TypeIDInt && + innerResult.resolvedType.TypeID() != schema.TypeIDFloat { + return nil, + fmt.Errorf("attempted negation operation on non-numeric type %q", + innerResult.resolvedType.TypeID()) + } + case ast.Not: + // 'not' expects a boolean input + if innerResult.resolvedType.TypeID() != schema.TypeIDBool { + return nil, + fmt.Errorf("attempted 'not' operation on non-boolean type %q", + innerResult.resolvedType.TypeID()) + } + default: + return nil, fmt.Errorf("unsupported unary operation: %q", node.LeftOperation) + } + // Negation and 'not' do not change the type or dependencies. + return innerResult, nil +} + +func (c *dependencyContext) binaryOperationDependencies( + node *ast.BinaryOperation, +) (*dependencyResult, error) { + leftResult, err := c.rootDependencies(node.LeftNode) + if err != nil { + return nil, err + } + // Right dependencies, using left type. + rightResult, err := c.rootDependencies(node.RightNode) + if err != nil { + return nil, err + } + var resultType schema.Type + // Validate operations with the resolved type, and compute the return type for the combination. + switch node.Operation { + case ast.Add: + // Add or concatenate + err = validateValidBinaryOpTypes( + node, + leftResult.resolvedType.TypeID(), + rightResult.resolvedType.TypeID(), + []schema.TypeID{schema.TypeIDInt, schema.TypeIDFloat, schema.TypeIDString}, + ) + resultType = leftResult.resolvedType + case ast.Subtract, ast.Multiply, ast.Divide, ast.Modulus, ast.Power: + // Math. Same as type going in. Plus validate that it's numeric. + err = validateValidBinaryOpTypes( + node, + leftResult.resolvedType.TypeID(), + rightResult.resolvedType.TypeID(), + []schema.TypeID{schema.TypeIDInt, schema.TypeIDFloat}, + ) + resultType = leftResult.resolvedType + case ast.And, ast.Or: + // Boolean operations. Bool in and out. + err = validateValidBinaryOpTypes( + node, + leftResult.resolvedType.TypeID(), + rightResult.resolvedType.TypeID(), + []schema.TypeID{schema.TypeIDBool}, + ) + resultType = schema.NewBoolSchema() + case ast.GreaterThan, ast.LessThan, ast.GreaterThanEqualTo, ast.LessThanEqualTo: + // Inequality. Int, float, or string in; bool out. + err = validateValidBinaryOpTypes( + node, + leftResult.resolvedType.TypeID(), + rightResult.resolvedType.TypeID(), + []schema.TypeID{schema.TypeIDInt, schema.TypeIDString, schema.TypeIDFloat}, + ) + resultType = schema.NewBoolSchema() + case ast.EqualTo, ast.NotEqualTo: + // Equality comparison. Any supported type in. Bool out. + err = validateValidBinaryOpTypes( + node, + leftResult.resolvedType.TypeID(), + rightResult.resolvedType.TypeID(), + []schema.TypeID{schema.TypeIDInt, schema.TypeIDString, schema.TypeIDFloat, schema.TypeIDBool}, + ) + resultType = schema.NewBoolSchema() + case ast.Invalid: + panic(fmt.Errorf("attempted to perform invalid operation (binary operation type invalid)")) + default: + panic(fmt.Errorf("bug: binary operation %s missing from dependency evaluation code", node.Operation)) + } + if err != nil { + return nil, err + } + // Combine the left and right dependencies. + finalDependencies := append(leftResult.completedPaths, rightResult.completedPaths...) + return &dependencyResult{ + resolvedType: resultType, + rootPathResult: nil, // Cannot be chained. It's a primitive. + completedPaths: finalDependencies, + }, nil +} + +func validateValidBinaryOpTypes( + node *ast.BinaryOperation, + leftType schema.TypeID, + rightType schema.TypeID, + expectedTypes []schema.TypeID, +) error { + // First validate left and right types are within the expected types. + leftIsValid := slices.Contains(expectedTypes, leftType) + if !leftIsValid { + return fmt.Errorf("invalid type %q from left expression %q for binary operation %q; expected one of %q", + leftType, node.LeftNode.String(), node.Operation, expectedTypes) + } + rightIsValid := slices.Contains(expectedTypes, rightType) + if !rightIsValid { + return fmt.Errorf("invalid type %q from right expression %q for binary operation %q; expected one of %q", + rightType, node.RightNode.String(), node.Operation, expectedTypes) + } + // Next, validate that left and right types match + if leftType != rightType { + return fmt.Errorf("left (%s) and right (%s) types do not match for binary expression %q", leftType, rightType, node.String()) + } + return nil +} + func (c *dependencyContext) functionDependencies(node *ast.FunctionCall) (*dependencyResult, error) { // Get the types and dependencies of all parameters. functionSchema, found := c.functions[node.FuncIdentifier.IdentifierName] diff --git a/expression_dependencies_test.go b/expression_dependencies_test.go index 605e235..6968c2e 100644 --- a/expression_dependencies_test.go +++ b/expression_dependencies_test.go @@ -419,3 +419,153 @@ func TestFunctionDependencyResolution_dynamicTyping(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "unsupported data type") } + +func TestDependencyResolution_MathHomogeneousLiterals(t *testing.T) { + // Test simple literal integer math, same type + expr, err := expressions.New("5 + 5") + assert.NoError(t, err) + pathWithExtra, err := expr.Dependencies(nil, nil, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(pathWithExtra), 0) +} + +func TestDependencyResolution_MathHomogeneousReference(t *testing.T) { + // Test simple reference integer math, same type + expr, err := expressions.New("5 + $.simple_int") + assert.NoError(t, err) + paths, err := expr.Dependencies(testScope, nil, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(paths), 1) + assert.Equals(t, paths[0].String(), "$.simple_int") + // Swapped + expr, err = expressions.New("$.simple_int + 5") + assert.NoError(t, err) + paths, err = expr.Dependencies(testScope, nil, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(paths), 1) + assert.Equals(t, paths[0].String(), "$.simple_int") +} + +func TestDependencyResolution_MathSameDependency(t *testing.T) { + // Test simple reference integer math, with both items having the same dependency + expr, err := expressions.New("$.simple_int + $.simple_int") + assert.NoError(t, err) + paths, err := expr.Dependencies(testScope, nil, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(paths), 1) + assert.Equals(t, paths[0].String(), "$.simple_int") +} + +func TestDependencyResolution_MathDifferentDependencies(t *testing.T) { + // Test math with two references, with different references. + // Tests the duplicate removal logic, and the left and right paths. + expr, err := expressions.New("$.simple_int + $.simple_int_2") + assert.NoError(t, err) + paths, err := expr.Dependencies(testScope, nil, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(paths), 2) + assert.SliceContainsExtractor(t, pathStrExtractor, "$.simple_int", paths) + assert.SliceContainsExtractor(t, pathStrExtractor, "$.simple_int_2", paths) +} + +func TestDependencyResolution_UnaryLiteral(t *testing.T) { + // Test unary operation with literals. + expr, err := expressions.New("-5") + assert.NoError(t, err) + paths, err := expr.Dependencies(testScope, nil, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(paths), 0) +} + +func TestDependencyResolution_UnaryReference(t *testing.T) { + // Test unary operation with references. + expr, err := expressions.New("-$.simple_int") + assert.NoError(t, err) + paths, err := expr.Dependencies(testScope, nil, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(paths), 1) + assert.Equals(t, paths[0].String(), "$.simple_int") +} + +func TestDependencyResolution_TestBinaryComparisonLiterals(t *testing.T) { + // Test comparison with literals + expr, err := expressions.New("4 == 5") + assert.NoError(t, err) + paths, err := expr.Dependencies(testScope, nil, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(paths), 0) +} + +func TestDependencyResolution_TestBinaryComparisonReferences(t *testing.T) { + // Test comparison with references + expr, err := expressions.New("$.simple_int == 5") + assert.NoError(t, err) + paths, err := expr.Dependencies(testScope, nil, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(paths), 1) + assert.Equals(t, paths[0].String(), "$.simple_int") +} + +func TestDependencyResolution_TestBooleanOperationLiterals(t *testing.T) { + // Test boolean operations with literals + expr, err := expressions.New("true && false") + assert.NoError(t, err) + paths, err := expr.Dependencies(testScope, nil, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(paths), 0) +} + +func TestDependencyResolution_TestBooleanOperationReferences(t *testing.T) { + // Test boolean operations with references + expr, err := expressions.New("true && $.simple_bool") + assert.NoError(t, err) + paths, err := expr.Dependencies(testScope, nil, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(paths), 1) + assert.Equals(t, paths[0].String(), "$.simple_bool") +} + +func TestDependencyResolution_TestMixedMathAndFunc(t *testing.T) { + // Test dependencies properly propagated from a function through an operation. + intInFunc, err := schema.NewCallableFunction( + "intToFloat", + []schema.Type{schema.NewIntSchema(nil, nil, nil)}, + schema.NewFloatSchema(nil, nil, nil), + false, + nil, + func(a int64) float64 { + return float64(a) + }, + ) + assert.NoError(t, err) + funcMap := map[string]schema.Function{"intToFloat": intInFunc} + + expr, err := expressions.New("5.0 + intToFloat($.simple_int)") + assert.NoError(t, err) + paths, err := expr.Dependencies(testScope, funcMap, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(paths), 1) + assert.Equals(t, paths[0].String(), "$.simple_int") +} + +func TestDependencyResolution_TestMixedOperations(t *testing.T) { + // Test different types of operations that take different types. + intInFunc, err := schema.NewCallableFunction( + "giveFloat", + []schema.Type{}, + schema.NewFloatSchema(nil, nil, nil), + false, + nil, + func() float64 { + return 5.5 + }, + ) + assert.NoError(t, err) + funcMap := map[string]schema.Function{"giveFloat": intInFunc} + + expr, err := expressions.New("1.0 == (5.0 / giveFloat()) && !true") + assert.NoError(t, err) + paths, err := expr.Dependencies(testScope, funcMap, nil, fullDataRequirements) + assert.NoError(t, err) + assert.Equals(t, len(paths), 0) +} diff --git a/expression_evaluate.go b/expression_evaluate.go index 8abd2e6..18666a5 100644 --- a/expression_evaluate.go +++ b/expression_evaluate.go @@ -3,6 +3,7 @@ package expressions import ( "fmt" "go.flow.arcalot.io/pluginsdk/schema" + "math" "reflect" "go.flow.arcalot.io/expressions/internal/ast" @@ -33,13 +34,163 @@ func (c evaluateContext) evaluate(node ast.Node, data any) (any, error) { case *ast.Identifier: return c.evaluateIdentifier(n, data) case *ast.FunctionCall: - return c.evaluateFuncCall(n, data) + return c.evaluateFuncCall(n) + case *ast.BinaryOperation: + return c.evaluateBinaryOperation(n) + case *ast.UnaryOperation: + return c.evaluateUnaryOperation(n) default: return nil, fmt.Errorf("unsupported node type: %T", n) } } -func (c evaluateContext) evaluateFuncCall(node *ast.FunctionCall, data any) (any, error) { //nolint:unparam +type SupportedNumber interface { + int64 | float64 +} + +func evalNumericalOperation[T SupportedNumber](a, b T, op ast.MathOperationType) (any, error) { + switch op { + case ast.Add: + return a + b, nil + case ast.Subtract: + return a - b, nil + case ast.Multiply: + return a * b, nil + case ast.Divide: + return a / b, nil + case ast.Modulus: + switch any(a).(type) { + case int64: + return int64(a) % int64(b), nil + case float64: + return math.Mod(float64(a), float64(b)), nil + } + return nil, fmt.Errorf("unsupported type for modulus: %T", a) + case ast.Power: + return T(math.Pow(float64(a), float64(b))), nil + case ast.EqualTo: + return a == b, nil + case ast.NotEqualTo: + return a != b, nil + case ast.GreaterThan: + return a > b, nil + case ast.LessThan: + return a < b, nil + case ast.GreaterThanEqualTo: + return a >= b, nil + case ast.LessThanEqualTo: + return a <= b, nil + case ast.And, ast.Or: + return nil, fmt.Errorf("attempted logical operation %s on numeric input %T", op, a) + case ast.Invalid: + panic(fmt.Errorf("invalid operation encountered evaluating numerical operation; this is likely due to a bug in the parser")) + default: + panic(fmt.Errorf("numeric eval missing case for logical operation %s", op)) + } +} + +func evalBooleanOperation(a, b bool, op ast.MathOperationType) (any, error) { + switch op { + case ast.EqualTo: + return a == b, nil + case ast.NotEqualTo: + return a != b, nil + case ast.And: + return a && b, nil + case ast.Or: + return a || b, nil + case ast.Power, ast.Modulus, ast.Divide, ast.Multiply, ast.Subtract, ast.Add, + ast.GreaterThan, ast.LessThan, ast.GreaterThanEqualTo, ast.LessThanEqualTo: + return nil, fmt.Errorf("attempted to perform invalid operation '%s' on boolean", op) + case ast.Invalid: + panic(fmt.Errorf("invalid operation encountered evaluating boolean operation; this is likely due to a bug in the parser")) + default: + panic(fmt.Errorf("boolean eval missing case for logical operation %s", op)) + } +} + +func evalStringOperation(a, b string, op ast.MathOperationType) (any, error) { + switch op { + case ast.Add: + // Concatenate + return a + b, nil + case ast.EqualTo: + return a == b, nil + case ast.NotEqualTo: + return a != b, nil + case ast.GreaterThan: + return a > b, nil + case ast.LessThan: + return a < b, nil + case ast.GreaterThanEqualTo: + return a >= b, nil + case ast.LessThanEqualTo: + return a <= b, nil + case ast.Subtract, ast.Multiply, ast.Divide, ast.Modulus, ast.Power, ast.And, ast.Or: + return nil, fmt.Errorf("string operations do not support operator '%s'", op) + case ast.Invalid: + panic(fmt.Errorf("invalid operation encountered evaluating string operation; this is likely due to a bug in the parser")) + default: + panic(fmt.Errorf("string eval missing case for logical operation %s", op)) + } +} + +func (c evaluateContext) evaluateBinaryOperation(node *ast.BinaryOperation) (any, error) { + leftEval, err := c.evaluate(node.Left(), c.rootData) + if err != nil { + return nil, err + } + rightEval, err := c.evaluate(node.Right(), c.rootData) + if err != nil { + return nil, err + } + rightType := reflect.TypeOf(rightEval) + leftType := reflect.TypeOf(leftEval) + if rightType != leftType { + return nil, fmt.Errorf("left type '%s' and right type '%s' of binary operation '%s' do not match", + leftType, rightType, node.Operation) + } + + switch left := leftEval.(type) { + case int64: + return evalNumericalOperation(left, rightEval.(int64), node.Operation) + case float64: + return evalNumericalOperation(left, rightEval.(float64), node.Operation) + case string: + return evalStringOperation(left, rightEval.(string), node.Operation) + case bool: + return evalBooleanOperation(left, rightEval.(bool), node.Operation) + default: + return nil, fmt.Errorf("unsupported type to perform binary operation on: %T", left) + } +} + +func (c evaluateContext) evaluateUnaryOperation(node *ast.UnaryOperation) (any, error) { + rightEval, err := c.evaluate(node.RightNode, c.rootData) + if err != nil { + return nil, err + } + if node.LeftOperation == ast.Subtract { + switch right := rightEval.(type) { + case int64: + return -right, nil + case float64: + return -right, nil + default: + return nil, fmt.Errorf("unsupported type for arithmetic negation: %T; expected 64-bit int or float", right) + } + } else if node.LeftOperation == ast.Not { + booleanResult, isBool := rightEval.(bool) + if !isBool { + return nil, fmt.Errorf("unsupported type for boolean complement: %T; expected boolean", rightEval) + } + return !booleanResult, nil + } else { + return nil, fmt.Errorf("only unary operators '-' and '!' are currently supported; got '%s'", node.LeftOperation) + } +} + +func (c evaluateContext) evaluateFuncCall(node *ast.FunctionCall) (any, error) { funcID := node.FuncIdentifier functionSchema, found := c.functions[funcID.String()] if !found { @@ -54,7 +205,7 @@ func (c evaluateContext) evaluateFuncCall(node *ast.FunctionCall, data any) (any gotArgs := len(evaluatedArgs) if gotArgs != expectedArgs { return nil, fmt.Errorf( - "function '%s' called with incorrect number of arguments. Expected %d, got %d", + "function '%s' called with incorrect number of arguments; expected %d, got %d", funcID, expectedArgs, gotArgs) } return functionSchema.Call(evaluatedArgs) diff --git a/expression_evaluate_test.go b/expression_evaluate_test.go index ce749fa..6a0325b 100644 --- a/expression_evaluate_test.go +++ b/expression_evaluate_test.go @@ -39,6 +39,16 @@ var strToStrFunc, strToStrFuncErr = schema.NewCallableFunction( return a, nil }, ) +var intToFloatFunc, intToFloatFuncErr = schema.NewCallableFunction( + "intToFloat", + []schema.Type{schema.NewIntSchema(nil, nil, nil)}, + schema.NewFloatSchema(nil, nil, nil), + true, + nil, + func(a int64) (float64, error) { + return float64(a), nil + }, +) var twoIntToIntFunc, twoIntToIntFuncErr = schema.NewCallableFunction( "multiply", @@ -222,6 +232,685 @@ var testData = map[string]struct { false, []string{"test", "test"}, }, + "error-wrong-function-id": { + nil, + map[string]schema.CallableFunction{}, + `wrong()`, + false, + true, + nil, + }, + "error-incorrect-param-count": { + nil, + map[string]schema.CallableFunction{ + "test": voidFunc, + }, + `test("wrong")`, + false, + true, + nil, + }, + "simple-int-addition": { + nil, + nil, + `5 + 5`, + false, + false, + int64(10), + }, + "referenced-int-addition": { + map[string]int64{ + "a": 1, + "b": 2, + }, + nil, + `$.a + $.b`, + false, + false, + int64(3), + }, + "simple-int-subtraction": { + nil, + nil, + `5 - 1`, + false, + false, + int64(4), + }, + "simple-int-multiplication": { + nil, + nil, + `2 * 2`, + false, + false, + int64(4), + }, + "simple-int-division": { + nil, + nil, + `2 / 2`, + false, + false, + int64(1), + }, + "simple-int-mod": { + nil, + nil, + `3 % 2`, + false, + false, + int64(1), + }, + "simple-int-power": { + nil, + nil, + `2 ^ 3`, + false, + false, + int64(8), + }, + "simple-int-equals-same": { + nil, + nil, + `1 == 1`, + false, + false, + true, + }, + "simple-int-equals-different": { + nil, + nil, + `1 == 2`, + false, + false, + false, + }, + "simple-int-not-equals-same": { + nil, + nil, + `1 != 1`, + false, + false, + false, + }, + "simple-int-not-equals-different": { + nil, + nil, + `1 != 2`, + false, + false, + true, + }, + "simple-int-greater-than-false": { + nil, + nil, + `1 > 1`, + false, + false, + false, + }, + "simple-int-greater-than-true": { + nil, + nil, + `2 > 1`, + false, + false, + true, + }, + "simple-int-less-than-false": { + nil, + nil, + `1 < 1`, + false, + false, + false, + }, + "simple-int-less-than-true": { + nil, + nil, + `0 < 1`, + false, + false, + true, + }, + "simple-int-greater-than-equals-true": { + nil, + nil, + `1 >= 1`, + false, + false, + true, + }, + "simple-int-greater-than-equals-false": { + nil, + nil, + `0 >= 1`, + false, + false, + false, + }, + "simple-int-less-than-equals-true": { + nil, + nil, + `1 <= 1`, + false, + false, + true, + }, + "simple-int-less-than-equals-false": { + nil, + nil, + `2 <= 1`, + false, + false, + false, + }, + "simple-float-addition": { + nil, + nil, + `5.0 + 5.0`, + false, + false, + 10.0, + }, + "exponential-form-float-addition": { + nil, + nil, + `5.0E-5 + 4.0e2`, + false, + false, + 400.00005, + }, + "simple-float-subtraction": { + nil, + nil, + `5.0 - 1.0`, + false, + false, + 4.0, + }, + "simple-float-multiplication": { + nil, + nil, + `2.0 * 2.0`, + false, + false, + 4.0, + }, + "simple-float-division": { + nil, + nil, + `2.0 / 2.0`, + false, + false, + 1.0, + }, + "simple-float-mod": { + nil, + nil, + `3.0 % 2.0`, + false, + false, + 1.0, + }, + "simple-float-power": { + nil, + nil, + `2.0 ^ 3.0`, + false, + false, + 8.0, + }, + "simple-float-equals-same": { + nil, + nil, + `1.0 == 1.0`, + false, + false, + true, + }, + "simple-float-equals-different": { + nil, + nil, + `1.0 == 2.0`, + false, + false, + false, + }, + "simple-float-not-equals-same": { + nil, + nil, + `1.0 != 1.0`, + false, + false, + false, + }, + "simple-float-not-equals-different": { + nil, + nil, + `1.0 != 2.0`, + false, + false, + true, + }, + "simple-float-greater-than-false": { + nil, + nil, + `1.0 > 1.0`, + false, + false, + false, + }, + "simple-float-greater-than-true": { + nil, + nil, + `1.01 > 1.0`, + false, + false, + true, + }, + "simple-float-less-than-false": { + nil, + nil, + `1.0 < 1.0`, + false, + false, + false, + }, + "simple-float-less-than-true": { + nil, + nil, + `1.0 < 1.01`, + false, + false, + true, + }, + "simple-float-greater-than-equals-true": { + nil, + nil, + `1.0 >= 1.0`, + false, + false, + true, + }, + "simple-float-greater-than-equals-false": { + nil, + nil, + `0.1 >= 1.0`, + false, + false, + false, + }, + "simple-float-less-than-equals-true": { + nil, + nil, + `1.0 <= 1.0`, + false, + false, + true, + }, + "simple-float-less-than-equals-false": { + nil, + nil, + `1.1 <= 1.0`, + false, + false, + false, + }, + "simple-bool-equals-same": { + nil, + nil, + `false == false`, + false, + false, + true, + }, + "simple-bool-equals-different": { + nil, + nil, + `false == true`, + false, + false, + false, + }, + "simple-bool-not-equals-different": { + nil, + nil, + `false != true`, + false, + false, + true, + }, + "simple-bool-not-equals-same": { + nil, + nil, + `false != false`, + false, + false, + false, + }, + "simple-bool-and-1": { + nil, + nil, + `true && true`, + false, + false, + true, + }, + "simple-bool-and-2": { + nil, + nil, + `true && false`, + false, + false, + false, + }, + "simple-bool-and-3": { + nil, + nil, + `false && true`, + false, + false, + false, + }, + "simple-bool-and-4": { + nil, + nil, + `false && false`, + false, + false, + false, + }, + "simple-bool-or-1": { + nil, + nil, + `true || false`, + false, + false, + true, + }, + "simple-bool-or-2": { + nil, + nil, + `false || false`, + false, + false, + false, + }, + "simple-bool-or-3": { + nil, + nil, + `true || true`, + false, + false, + true, + }, + "simple-bool-or-4": { + nil, + nil, + `false || true`, + false, + false, + true, + }, + "simple-string-concatenation": { + nil, + nil, + `"a" + "b"`, + false, + false, + "ab", + }, + "simple-string-equals-1": { + nil, + nil, + `"a" == "a"`, + false, + false, + true, + }, + "simple-string-equals-2": { + nil, + nil, + `"a" == "A"`, + false, + false, + false, + }, + "simple-string-not-equals-1": { + nil, + nil, + `"a" != "a"`, + false, + false, + false, + }, + "simple-string-not-equals-2": { + nil, + nil, + `"a" != "A"`, + false, + false, + true, + }, + "simple-string-greater-false": { + nil, + nil, + `"a" > "b"`, + false, + false, + false, + }, + "simple-string-greater-true": { + nil, + nil, + `"b" > "a"`, + false, + false, + true, + }, + "simple-string-less-true": { + nil, + nil, + `"a" < "b"`, + false, + false, + true, + }, + "simple-string-less-false": { + nil, + nil, + `"c" < "b"`, + false, + false, + false, + }, + "simple-string-greater-than-equals-true": { + nil, + nil, + `"a" >= "a"`, + false, + false, + true, + }, + "simple-string-greater-than-equals-false": { + nil, + nil, + `"a" >= "b"`, + false, + false, + false, + }, + "simple-string-less-than-equals-true": { + nil, + nil, + `"a" <= "b"`, + false, + false, + true, + }, + "simple-string-less-than-equals-false": { + nil, + nil, + `"c" <= "b"`, + false, + false, + false, + }, + "error-number-and": { + nil, + nil, + `1 && 1`, + false, + true, + nil, + }, + "error-number-or": { + nil, + nil, + `1 || 1`, + false, + true, + nil, + }, + "error-bool-math": { + nil, + nil, + `true + false`, + false, + true, + nil, + }, + "error-bool-comparison": { + nil, + nil, + `true > false`, + false, + true, + nil, + }, + "error-string-math": { + nil, + nil, + `"5" - "6"`, + false, + true, + nil, + }, + "error-string-logic": { + nil, + nil, + `"5" && "6"`, + false, + true, + nil, + }, + "error-mismatched-types": { + nil, + nil, + `5 + 5.0`, + false, + true, + nil, + }, + "function-float-addition": { // An example of how you would convert the type. + nil, + map[string]schema.CallableFunction{ + "intToFloat": intToFloatFunc, + }, + `intToFloat(5) + 5.0`, + false, + false, + 10.0, + }, + "int-negation": { + nil, + nil, + `-5`, + false, + false, + int64(-5), + }, + "double-int-negation": { + nil, + nil, + `--5`, + false, + false, + int64(5), + }, + "double-parenthesized-int-negation": { + nil, + nil, + `-(-5)`, + false, + false, + int64(5), + }, + "triple-parenthesized-int-negation": { + nil, + nil, + `--(-5)`, + false, + false, + int64(-5), + }, + "float-negation": { + nil, + nil, + `-5.0`, + false, + false, + -5.0, + }, + "invalid-negation": { + nil, + nil, + `-true`, + false, + true, + nil, + }, + "negation-and-subtraction": { + nil, + nil, + `5 - -5`, + false, + false, + int64(10), + }, + "simple-not-true": { + nil, + nil, + `!true`, + false, + false, + false, + }, + "simple-not-false": { + nil, + nil, + `!false`, + false, + false, + true, + }, + "invalid-type-not": { + nil, + nil, + `!5`, + false, + true, + nil, + }, + "mixed-not": { + nil, + nil, + `!(5 != 5) && !false`, + false, + false, + true, + }, } func TestEvaluate(t *testing.T) { @@ -230,6 +919,7 @@ func TestEvaluate(t *testing.T) { assert.NoError(t, strToStrFuncErr) assert.NoError(t, twoIntToIntFuncErr) assert.NoError(t, dynamicToListFuncErr) + assert.NoError(t, intToFloatFuncErr) for name, tc := range testData { testCase := tc diff --git a/expression_type_test.go b/expression_type_test.go index cc82f48..55b9a83 100644 --- a/expression_type_test.go +++ b/expression_type_test.go @@ -212,3 +212,225 @@ func TestFunctionTypeResolution_advancedDynamicTyping(t *testing.T) { schema.NewStringSchema(nil, nil, nil), ) } + +func TestTypeResolution_BinaryMathHomogeneousIntLiterals(t *testing.T) { + // Two ints added should give an int + expr, err := expressions.New("5 + 5") + assert.NoError(t, err) + typeResult, err := expr.Type(nil, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewIntSchema(nil, nil, nil)) + expr, err = expressions.New("5 * 5") + assert.NoError(t, err) + typeResult, err = expr.Type(nil, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewIntSchema(nil, nil, nil)) +} + +func TestTypeResolution_BinaryConcatenateStrings(t *testing.T) { + expr, err := expressions.New(`"5" + "5"`) + assert.NoError(t, err) + typeResult, err := expr.Type(nil, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewStringSchema(nil, nil, nil)) +} + +func TestTypeResolution_BinaryMathHomogeneousIntReference(t *testing.T) { + // Two ints added should give an int. One int is a reference. + expr, err := expressions.New("5 + $.simple_int") + assert.NoError(t, err) + typeResult, err := expr.Type(testScope, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewIntSchema(nil, nil, nil)) + expr, err = expressions.New("$.simple_int + 5") + assert.NoError(t, err) + typeResult, err = expr.Type(testScope, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewIntSchema(nil, nil, nil)) +} + +func TestTypeResolution_BinaryMathHomogeneousFloatLiterals(t *testing.T) { + // Two floats added, subtracted, multiplied, and divided should give floats + expr, err := expressions.New("5.0 / 5.0") + assert.NoError(t, err) + typeResult, err := expr.Type(nil, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewFloatSchema(nil, nil, nil)) + expr, err = expressions.New("5.0 + 5.0") + assert.NoError(t, err) + typeResult, err = expr.Type(nil, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewFloatSchema(nil, nil, nil)) + expr, err = expressions.New("5.0 - 5.0") + assert.NoError(t, err) + typeResult, err = expr.Type(nil, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewFloatSchema(nil, nil, nil)) + expr, err = expressions.New("5.0 * 5.0") + assert.NoError(t, err) + typeResult, err = expr.Type(nil, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewFloatSchema(nil, nil, nil)) + expr, err = expressions.New("5.0 % 5.0") + assert.NoError(t, err) + typeResult, err = expr.Type(nil, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewFloatSchema(nil, nil, nil)) + expr, err = expressions.New("5.0 ^ 5.0") + assert.NoError(t, err) + typeResult, err = expr.Type(nil, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewFloatSchema(nil, nil, nil)) +} + +func TestTypeResolution_Error_BinaryHeterogeneousLiterals(t *testing.T) { + // This is designed to hit the type checker code, with an error from a mismatch in type. + expr, err := expressions.New("5 + 5.0") + assert.NoError(t, err) + _, err = expr.Type(nil, nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "types do not match") +} + +func TestTypeResolution_UnaryOperation(t *testing.T) { + // Tests that the unary operator properly passes the type upwards. + expr, err := expressions.New("-5") + assert.NoError(t, err) + typeResult, err := expr.Type(nil, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewIntSchema(nil, nil, nil)) +} + +func TestTypeResolution_TestMixedMathAndFunc(t *testing.T) { + // Testing the combination of a float and a function. + intInFunc, err := schema.NewCallableFunction( + "intToFloat", + []schema.Type{schema.NewIntSchema(nil, nil, nil)}, + schema.NewFloatSchema(nil, nil, nil), + false, + nil, + func(a int64) float64 { + return float64(a) + }, + ) + assert.NoError(t, err) + funcMap := map[string]schema.Function{"intToFloat": intInFunc} + + expr, err := expressions.New("5.0 + intToFloat($.simple_int)") + assert.NoError(t, err) + typeResult, err := expr.Type(testScope, funcMap, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewFloatSchema(nil, nil, nil)) +} + +func TestTypeResolution_Error_NonBoolType(t *testing.T) { + // Non-bool type for operation that requires boolean types + expr, err := expressions.New(`0 && 1`) + assert.NoError(t, err) + _, err = expr.Type(testScope, nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid type") +} + +func TestTypeResolution_TestMixedOperations(t *testing.T) { + // Mixture of operations producing the expected type. + intInFunc, err := schema.NewCallableFunction( + "giveFloat", + []schema.Type{}, + schema.NewFloatSchema(nil, nil, nil), + false, + nil, + func() float64 { + return 5.5 + }, + ) + assert.NoError(t, err) + funcMap := map[string]schema.Function{"giveFloat": intInFunc} + + expr, err := expressions.New("1.0 == (5.0 / giveFloat()) && !true") + assert.NoError(t, err) + typeResult, err := expr.Type(testScope, funcMap, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewBoolSchema()) +} + +func TestDependencyResolution_Error_TestSecondTypeIncorrect(t *testing.T) { + // The binary operation type-checker checks both the left and the right + // Validate that the right type is correctly validated. + expr, err := expressions.New(`5 + true`) + assert.NoError(t, err) + _, err = expr.Type(testScope, nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid type") + assert.Contains(t, err.Error(), "right expression") +} + +func TestDependencyResolution_Error_TestInvalidTypeOnBoolean(t *testing.T) { + // Tests invalid type for relational operator + expr, err := expressions.New("true > false") + assert.NoError(t, err) + _, err = expr.Type(testScope, nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid type") +} + +func TestDependencyResolution_TestSizeComparison(t *testing.T) { + expr, err := expressions.New("5 > 6") + assert.NoError(t, err) + typeResult, err := expr.Type(testScope, nil, nil) + assert.NoError(t, err) + assert.Equals[schema.Type](t, typeResult, schema.NewBoolSchema()) +} + +func TestDependencyResolution_Error_TestInvalidNot(t *testing.T) { + // 'not' expects boolean + expr, err := expressions.New("!5") + assert.NoError(t, err) + _, err = expr.Type(testScope, nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "non-boolean type") +} + +func TestDependencyResolution_Error_TestInvalidNegation(t *testing.T) { + // 'not' expects boolean + expr, err := expressions.New("-true") + assert.NoError(t, err) + _, err = expr.Type(testScope, nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "non-numeric type") +} + +func TestDependencyResolution_Error_TestComparingScopes(t *testing.T) { + // scopes cannot be compared + expr, err := expressions.New("$ > $") + assert.NoError(t, err) + _, err = expr.Type(testScope, nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid type") +} + +func TestDependencyResolution_Error_TestAddingScopes(t *testing.T) { + // scopes cannot be added + expr, err := expressions.New("$ + $") + assert.NoError(t, err) + _, err = expr.Type(testScope, nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid type") +} + +func TestDependencyResolution_Error_TestAddingMaps(t *testing.T) { + // maps cannot be added + expr, err := expressions.New("$.faz + $.faz") + assert.NoError(t, err) + _, err = expr.Type(testScope, nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid type") +} +func TestDependencyResolution_Error_TestAddingLists(t *testing.T) { + // lists cannot be added + expr, err := expressions.New("$.int_list + $.int_list") + assert.NoError(t, err) + _, err = expr.Type(testScope, nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid type") +} diff --git a/go.mod b/go.mod index b4b84df..a18d046 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,7 @@ module go.flow.arcalot.io/expressions -go 1.18 +go 1.21 -require go.arcalot.io/assert v1.6.0 +require go.arcalot.io/assert v1.7.0 -require go.flow.arcalot.io/pluginsdk v0.7.0 +require go.flow.arcalot.io/pluginsdk v0.8.0 diff --git a/go.sum b/go.sum index 410497e..6eb874a 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ -go.arcalot.io/assert v1.6.0 h1:iKA8SZZ1MRblMX5QAwwY5RbpR+VNyp//4IU7vo08Xu0= -go.arcalot.io/assert v1.6.0/go.mod h1:Xy3ScX0p9IMY89gdsgexOKxnmDr0nGHG9dV7p8Uxg7w= -go.flow.arcalot.io/pluginsdk v0.7.0 h1:5oZ9mH5KJwvhUKxPJ0hj1aAq8CZoY9CdJIaLgbRquNo= -go.flow.arcalot.io/pluginsdk v0.7.0/go.mod h1:2s2f//7uOkBjr1QaiWJD/bqDIeLlINJtD1BhiY4aGPM= +go.arcalot.io/assert v1.7.0 h1:PTLyeisNMUKpM9wXRDxResanBhuGOYO1xFK3v5b3FSw= +go.arcalot.io/assert v1.7.0/go.mod h1:nNmWPoNUHFyrPkNrD2aASm5yPuAfiWdB/4X7Lw3ykHk= +go.flow.arcalot.io/pluginsdk v0.8.0 h1:cShsshrR17ZFLcbgi3aZvqexLttcp3JISFNqPUPuDvA= +go.flow.arcalot.io/pluginsdk v0.8.0/go.mod h1:sk7ssInR/T+Gy+RSRr+QhKqZcECFFxMyn1hPQCTZSyU= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/ast/ast.go b/internal/ast/ast.go index a68c3ae..bc6dd8d 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -67,6 +67,41 @@ func (l *IntLiteral) Value() interface{} { return l.IntValue } +// FloatLiteral represents a floating point literal value in the abstract syntax +// tree. +type FloatLiteral struct { + FloatValue float64 +} + +// String returns a string representation of the float contained. +func (l *FloatLiteral) String() string { + // 'f' for full float, instead of exponential format. + // The third arg, prec, is -1 to give an exact output. + // The fourth arg specifies that we're using 64-bit floats. + return strconv.FormatFloat(l.FloatValue, 'f', -1, 64) +} + +// Value returns the float contained. +func (l *FloatLiteral) Value() interface{} { + return l.FloatValue +} + +// BooleanLiteral represents a boolean literal value in the abstract syntax +// tree. true or false +type BooleanLiteral struct { + BooleanValue bool +} + +// String returns a string representation of the boolean contained. +func (l *BooleanLiteral) String() string { + return strconv.FormatBool(l.BooleanValue) +} + +// Value returns the float contained. +func (l *BooleanLiteral) Value() interface{} { + return l.BooleanValue +} + // BracketAccessor represents a part of the abstract syntax tree that is accessing // the value at a key in a map/object, or index of a list. // The format is the value to the left, followed by an open/right square bracket, followed @@ -178,7 +213,7 @@ func (l *ArgumentList) GetChild(index int) (Node, error) { return l.Arguments[index], nil } -// String returns the identifier name. +// String gives a comma-separated list of the arguments func (l *ArgumentList) String() string { if len(l.Arguments) == 0 { return "" @@ -189,3 +224,96 @@ func (l *ArgumentList) String() string { } return result } + +type MathOperationType int + +const ( + Invalid MathOperationType = iota + Add + Subtract + Multiply + Divide + Modulus + Power + EqualTo + NotEqualTo + GreaterThan + LessThan + GreaterThanEqualTo + LessThanEqualTo + And + Or + Not +) + +func (e MathOperationType) String() string { + switch e { + case Invalid: + return "INVALID" + case Add: + return "+" + case Subtract: + return "-" + case Multiply: + return "*" + case Divide: + return "÷" + case Modulus: + return "%" + case Power: + return "^" + case EqualTo: + return "==" + case NotEqualTo: + return "!=" + case GreaterThan: + return ">" + case LessThan: + return "<" + case GreaterThanEqualTo: + return ">=" + case LessThanEqualTo: + return "<=" + case And: + return "&&" + case Or: + return "||" + case Not: + return "!" + default: + return "ENTRY MISSING" + } +} + +type BinaryOperation struct { + LeftNode Node + RightNode Node + Operation MathOperationType +} + +func (b *BinaryOperation) Right() Node { + return b.RightNode +} + +func (b *BinaryOperation) Left() Node { + return b.LeftNode +} + +// String returns the left node, followed by the operator, followed by the right node. +// The left and right nodes are clarified with (), because context that determined order of +// operations, like parentheses, are not explicitly retained in the tree. But the structure +// of the tree represents the evaluation order present in the original expression. +func (b *BinaryOperation) String() string { + return "(" + b.LeftNode.String() + ") " + b.Operation.String() + " (" + b.RightNode.String() + ")" +} + +type UnaryOperation struct { + LeftOperation MathOperationType + RightNode Node +} + +// String returns the operation, followed by the string representation of the right node. +// The wrapped node is surrounded by parentheses to remove ambiguity. +func (b *UnaryOperation) String() string { + return b.LeftOperation.String() + "(" + b.RightNode.String() + ") " +} diff --git a/internal/ast/errors.go b/internal/ast/errors.go index 46b639b..c2eb38d 100644 --- a/internal/ast/errors.go +++ b/internal/ast/errors.go @@ -22,11 +22,13 @@ type InvalidGrammarError struct { } func (e *InvalidGrammarError) Error() string { - errorMsg := fmt.Sprintf("Token \"%s\" of ID %s placed in invalid configuration in %s at line %d:%d.", + errorMsg := fmt.Sprintf("Token %q of ID %q placed in invalid configuration in %q at line %d:%d.", e.FoundToken.Value, e.FoundToken.TokenID, e.FoundToken.Filename, e.FoundToken.Line, e.FoundToken.Column) switch { - case e.ExpectedTokens == nil || len(e.ExpectedTokens) == 0: + case e.ExpectedTokens == nil: errorMsg += " Expected end of expression." + case len(e.ExpectedTokens) == 0: + errorMsg += " Expected any token." case len(e.ExpectedTokens) == 1: errorMsg += fmt.Sprintf(" Expected token \"%v\"", e.ExpectedTokens[0]) default: diff --git a/internal/ast/grammar_test.go b/internal/ast/grammar_test.go index 7bf45b4..ea51572 100644 --- a/internal/ast/grammar_test.go +++ b/internal/ast/grammar_test.go @@ -1,6 +1,7 @@ package ast import ( + "errors" "strings" "testing" @@ -395,6 +396,43 @@ func TestSubExpression(t *testing.T) { assert.Equals(t, parsedRoot, root) } +func TestParseExpression_Error_BracketAfterLiteral(t *testing.T) { + // Test error message for bracket access after literal + expression := "0[0]" + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + _, err = p.ParseExpression() + assert.Error(t, err) + assert.Equals(t, err.Error(), `bracket access cannot follow a literal; got "[" after "0"`) +} + +func TestParseExpression_Error_DotAfterLiteral(t *testing.T) { + // Test error message for dot notation after literal. + expression := "0 .a" // The space is needed for the tokens to be separated for the behavior we're testing. + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + _, err = p.ParseExpression() + assert.Error(t, err) + assert.Equals(t, err.Error(), `dot notation cannot follow a literal; got "." after "0"`) +} + +func TestParseExpression_Error_ParenthesesAfterLiteral(t *testing.T) { + // Test the error message for function call after literal. + expression := "0(0 + 0)" + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + parsedResult, err := p.ParseExpression() + assert.Error(t, err) + assert.Nil(t, parsedResult) + assert.Contains(t, err.Error(), "an opening parentheses cannot follow a literal") +} + func TestEmptyFunctionExpression(t *testing.T) { expression := "funcName()" @@ -561,11 +599,674 @@ func TestExpressionInvalidIdentifier(t *testing.T) { assert.Error(t, err) } +func TestExpression_SimpleAdd(t *testing.T) { + expression := "2 + 2" + + // 2 + 2 as tree + root := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 2}, + RightNode: &IntLiteral{IntValue: 2}, + Operation: Add, + } + + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + parsedResult, err := p.ParseExpression() + + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*BinaryOperation](t, parsedResult) + assert.Equals(t, parsedResult.(*BinaryOperation), root) +} + +func TestExpression_ThreeSub(t *testing.T) { + expression := "1.0 - 2.0 - 3.0" + + // 1.0 - 2.0 - 3.0 as tree + // - + // / \ + // - 3.0 + // / \ + // 1.0 2.0 + level2 := &BinaryOperation{ + LeftNode: &FloatLiteral{FloatValue: 1.0}, + RightNode: &FloatLiteral{FloatValue: 2.0}, + Operation: Subtract, + } + root := &BinaryOperation{ + LeftNode: level2, + RightNode: &FloatLiteral{FloatValue: 3.0}, + Operation: Subtract, + } + + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + parsedResult, err := p.ParseExpression() + + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*BinaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} + +func TestExpression_MixedAddMultiplicationDivision(t *testing.T) { + expression := "7 + 50 * 6 / 10" + + // 7 + 50 * 6 / 10 as tree + // + + // / \ + // ÷ 7 + // / \ + // * 10 + // / \ + // 50 6 + level3 := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 50}, + RightNode: &IntLiteral{IntValue: 6}, + Operation: Multiply, + } + level2 := &BinaryOperation{ + LeftNode: level3, + RightNode: &IntLiteral{IntValue: 10}, + Operation: Divide, + } + root := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 7}, + RightNode: level2, + Operation: Add, + } + + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + parsedResult, err := p.ParseExpression() + + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*BinaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} + +func TestExpression_Power(t *testing.T) { + expression := "1 ^ 4 * 3" + + // 1 ^ 4 * 3 as tree + // * + // / \ + // ^ 3 + // / \ + // 1 4 + level2 := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 1}, + RightNode: &IntLiteral{IntValue: 4}, + Operation: Power, + } + root := &BinaryOperation{ + LeftNode: level2, + RightNode: &IntLiteral{IntValue: 3}, + Operation: Multiply, + } + + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + parsedResult, err := p.ParseExpression() + + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*BinaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} + +func TestExpression_PowerParentheses(t *testing.T) { + expression := "2 ^ (4 * 3)" + + // 2 ^ 4 * 3 as tree + // ^ + // / \ + // 2 * + // / \ + // 4 3 + level2 := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 4}, + RightNode: &IntLiteral{IntValue: 3}, + Operation: Multiply, + } + root := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 2}, + RightNode: level2, + Operation: Power, + } + + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + parsedResult, err := p.ParseExpression() + + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*BinaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} + +func TestExpression_Parentheses(t *testing.T) { + expression := "(4 + 3) * 2" + + // (4 + 3) * 2 as tree + // * + // / \ + // + 2 + // / \ + // 4 3 + level2 := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 4}, + RightNode: &IntLiteral{IntValue: 3}, + Operation: Add, + } + root := &BinaryOperation{ + LeftNode: level2, + RightNode: &IntLiteral{IntValue: 2}, + Operation: Multiply, + } + + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + parsedResult, err := p.ParseExpression() + + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*BinaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} + +func TestExpression_UnaryNegative(t *testing.T) { + expression := "5 + -5" + + // 5 + -5 as tree + // + + // / \ + // 5 - + // | + // 5 + level2 := &UnaryOperation{ + LeftOperation: Subtract, + RightNode: &IntLiteral{IntValue: 5}, + } + root := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 5}, + RightNode: level2, + Operation: Add, + } + + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + parsedResult, err := p.ParseExpression() + + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*BinaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} + +func TestExpression_MultiNegationUnary(t *testing.T) { + expression := `---5` + level3 := &UnaryOperation{ + LeftOperation: Subtract, + RightNode: &IntLiteral{IntValue: 5}, + } + level2 := &UnaryOperation{ + LeftOperation: Subtract, + RightNode: level3, + } + root := &UnaryOperation{ + LeftOperation: Subtract, + RightNode: level2, + } + p, err := InitParser(expression, t.Name()) + assert.NoError(t, err) + parsedResult, err := p.ParseExpression() + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*UnaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} + +func TestExpression_MultiNegationUnaryParentheses(t *testing.T) { + expression := `-(-5)` + level2 := &UnaryOperation{ + LeftOperation: Subtract, + RightNode: &IntLiteral{IntValue: 5}, + } + root := &UnaryOperation{ + LeftOperation: Subtract, + RightNode: level2, + } + p, err := InitParser(expression, t.Name()) + assert.NoError(t, err) + parsedResult, err := p.ParseExpression() + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*UnaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} + +func TestExpression_MultiNotUnary(t *testing.T) { + expression := `!!true` + level2 := &UnaryOperation{ + LeftOperation: Not, + RightNode: &BooleanLiteral{BooleanValue: true}, + } + root := &UnaryOperation{ + LeftOperation: Not, + RightNode: level2, + } + p, err := InitParser(expression, t.Name()) + assert.NoError(t, err) + parsedResult, err := p.ParseExpression() + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*UnaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} + +// In the binary operator grammar tests, not all operators are tested +// in every scenario because not every operator has its own code path. +// The per-operator tests are done in expression_evaluate_test.go + +func TestExpression_SimpleComparison(t *testing.T) { + expression := "2 > 2" + + // 2 > 2 as tree + // > + // / \ + // 2 2 + root := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 2}, + RightNode: &IntLiteral{IntValue: 2}, + Operation: GreaterThan, + } + + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + parsedResult, err := p.ParseExpression() + + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*BinaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} + +func TestExpression_SimpleComparisonTwoToken(t *testing.T) { + expression := "2 >= 2" + + // 2 >= 2 as tree + // >= + // / \ + // 2 2 + root := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 2}, + RightNode: &IntLiteral{IntValue: 2}, + Operation: GreaterThanEqualTo, + } + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + parsedResult, err := p.ParseExpression() + + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*BinaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} + +func TestExpression_ErrIncorrectEquals(t *testing.T) { + // In this test, we ensure that it properly rejects a single equals. A double equals is required. + expression := "2 = 2" + + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + _, err = p.ParseExpression() + + assert.Error(t, err) + var grammarErr *InvalidGrammarError + ok := errors.As(err, &grammarErr) + if !ok { + t.Fatalf("Returned error is not InvalidGrammarError") + } + assert.Equals(t, grammarErr.ExpectedTokens, []TokenID{EqualsToken}) +} + +func TestExpression_MixedComparisons(t *testing.T) { + expression := "0 < 1 + 2" + + // 0 < 1 + 2 as tree + // < + // / \ + // 0 + + // / \ + // 1 2 + level2 := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 1}, + RightNode: &IntLiteral{IntValue: 2}, + Operation: Add, + } + root := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 0}, + RightNode: level2, + Operation: LessThan, + } + + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + parsedResult, err := p.ParseExpression() + + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*BinaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} +func TestExpression_AndLogic(t *testing.T) { + expression := "true && false" + + // true && false as tree + // && + // / \ + // true false + root := &BinaryOperation{ + LeftNode: &BooleanLiteral{BooleanValue: true}, + RightNode: &BooleanLiteral{BooleanValue: false}, + Operation: And, + } + + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + parsedResult, err := p.ParseExpression() + + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*BinaryOperation](t, parsedResult) + assert.Equals[Node](t, parsedResult, root) +} + +func TestExpression_AllTypes(t *testing.T) { + expression := "2 * 3 + 4 > 2 % 5 || $.test && !true" + + // 2 * 3 + 4 > 2 % 5 || $.test && !true as tree + // || + // / \ + // > && + // / \ / \ + // + % $.test ! + // / \ / \ | + // * 4 2 5 true + // / \ + // 2 3 + multiplicationNode := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 2}, + RightNode: &IntLiteral{IntValue: 3}, + Operation: Multiply, + } + addNode := &BinaryOperation{ + LeftNode: multiplicationNode, + RightNode: &IntLiteral{IntValue: 4}, + Operation: Add, + } + modNode := &BinaryOperation{ + LeftNode: &IntLiteral{IntValue: 2}, + RightNode: &IntLiteral{IntValue: 5}, + Operation: Modulus, + } + greaterThanNode := &BinaryOperation{ + LeftNode: addNode, + RightNode: modNode, + Operation: GreaterThan, + } + notNode := &UnaryOperation{ + LeftOperation: Not, + RightNode: &BooleanLiteral{BooleanValue: true}, + } + andNode := &BinaryOperation{ + LeftNode: &Identifier{IdentifierName: "$.test"}, + RightNode: notNode, + Operation: And, + } + root := &BinaryOperation{ + LeftNode: greaterThanNode, + RightNode: andNode, + Operation: Or, + } + + // Create parser + p, err := InitParser(expression, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + parsedResult, err := p.ParseExpression() + + assert.NoError(t, err) + assert.NotNil(t, parsedResult) + + assert.InstanceOf[*BinaryOperation](t, parsedResult) + // For some reason, comparing the raw results was failing falsely. + assert.Equals(t, parsedResult.String(), root.String()) +} + +// Test unexpected tokens +// This is specifically targeted for places where a specific token is always expected, +// which is where .eat is called. +func TestExpression_MismatchedPair(t *testing.T) { + bracketAccessExpr := "$.test[5)" + // Create parser + p, err := InitParser(bracketAccessExpr, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + _, err = p.ParseExpression() + + assert.Error(t, err) + var grammarErr *InvalidGrammarError + ok := errors.As(err, &grammarErr) + if !ok { + t.Fatalf("Returned error is not InvalidGrammarError") + } + assert.Equals(t, grammarErr.ExpectedTokens, []TokenID{BracketAccessDelimiterEndToken}) + + funcExpr := "5 * (5 * 5]" + // Create parser + p, err = InitParser(funcExpr, t.Name()) + + assert.NoError(t, err) + + // Parse and validate + _, err = p.ParseExpression() + + assert.Error(t, err) + ok = errors.As(err, &grammarErr) + if !ok { + t.Fatalf("Returned error is not InvalidGrammarError") + } + assert.Equals(t, grammarErr.ExpectedTokens, []TokenID{ParenthesesEndToken}) +} + func TestExpressionErrorChainLiteral(t *testing.T) { expression := `"a".a` p, err := InitParser(expression, t.Name()) assert.NoError(t, err) _, err = p.ParseExpression() assert.Error(t, err) - assert.Contains(t, err.Error(), "Expected end of expression") + assert.Contains(t, err.Error(), "dot notation cannot follow a literal") +} + +func TestParseArgs_badStart(t *testing.T) { + expression := `))` + p, err := InitParser(expression, t.Name()) + assert.NoError(t, err) + err = p.advanceToken() + assert.NoError(t, err) + _, err = p.parseArgs() + assert.Error(t, err) + var grammarErr *InvalidGrammarError + ok := errors.As(err, &grammarErr) + if !ok { + t.Fatalf("Returned error is not InvalidGrammarError") + } + assert.Equals(t, grammarErr.ExpectedTokens, []TokenID{ParenthesesStartToken}) +} + +func TestParseArgs_badEnd2(t *testing.T) { + // An incomplete argument list with a missing close parentheses. + expression := `(""` + p, err := InitParser(expression, t.Name()) + assert.NoError(t, err) + err = p.advanceToken() + assert.NoError(t, err) + _, err = p.parseArgs() + assert.Error(t, err) + var grammarErr *InvalidGrammarError + ok := errors.As(err, &grammarErr) + if !ok { + t.Fatalf("Returned error is not InvalidGrammarError") + } + assert.Equals(t, grammarErr.ExpectedTokens, []TokenID{ParenthesesEndToken, ListSeparatorToken}) +} + +func TestParseArgs_badSeparator(t *testing.T) { + // Testing a bad argument list + expression := `(""1` + p, err := InitParser(expression, t.Name()) + assert.NoError(t, err) + err = p.advanceToken() + assert.NoError(t, err) + _, err = p.parseArgs() + assert.Error(t, err) + var grammarErr *InvalidGrammarError + ok := errors.As(err, &grammarErr) + if !ok { + t.Fatalf("Returned error is not InvalidGrammarError") + } + assert.Equals(t, grammarErr.ExpectedTokens, []TokenID{ListSeparatorToken, ParenthesesEndToken}) +} + +func TestParseArgs_endedAfterSeparator(t *testing.T) { + // Test separator in place of close parentheses + expression := `("",` + p, err := InitParser(expression, t.Name()) + assert.NoError(t, err) + err = p.advanceToken() + assert.NoError(t, err) + _, err = p.parseArgs() + assert.Error(t, err) + var grammarErr *InvalidGrammarError + ok := errors.As(err, &grammarErr) + if !ok { + t.Fatalf("Returned error is not InvalidGrammarError") + } + assert.Equals(t, grammarErr.ExpectedTokens, []TokenID{ParenthesesEndToken}) +} + +func TestParseArgs_badFirstToken(t *testing.T) { + // Test a missing open parentheses. + // This is testing an edge-case that should not be hit if designed correctly. + expression := `1` + p, err := InitParser(expression, t.Name()) + assert.NoError(t, err) + err = p.advanceToken() + assert.NoError(t, err) + _, err = p.parseArgs() + assert.Error(t, err) + var grammarErr *InvalidGrammarError + ok := errors.As(err, &grammarErr) + if !ok { + t.Fatalf("Returned error is not InvalidGrammarError") + } + assert.Equals(t, grammarErr.ExpectedTokens, []TokenID{ParenthesesStartToken}) +} + +func TestParseString_EscapedStrings(t *testing.T) { + expression := `"a\"b" "a\tb" "a\\b" "a\bb" "a\nb" "a\\nb" '\''` + p, err := InitParser(expression, t.Name()) + assert.NoError(t, err) + err = p.advanceToken() + assert.NoError(t, err) + result, err := p.parseStringLiteral() + assert.NoError(t, err) + assert.Equals(t, result.StrValue, `a"b`) + result, err = p.parseStringLiteral() + assert.NoError(t, err) + assert.Equals(t, result.StrValue, "a\tb") + result, err = p.parseStringLiteral() + assert.NoError(t, err) + assert.Equals(t, result.StrValue, `a\b`) + result, err = p.parseStringLiteral() + assert.NoError(t, err) + assert.Equals(t, result.StrValue, "a\bb") + result, err = p.parseStringLiteral() + assert.NoError(t, err) + assert.Equals(t, result.StrValue, "a\nb") + result, err = p.parseStringLiteral() + assert.NoError(t, err) + assert.Equals(t, result.StrValue, "a\\nb") + result, err = p.parseStringLiteral() + assert.NoError(t, err) + assert.Equals(t, result.StrValue, "'") } diff --git a/internal/ast/recursive_descent_parser.go b/internal/ast/recursive_descent_parser.go index 68fbda3..38d07ea 100644 --- a/internal/ast/recursive_descent_parser.go +++ b/internal/ast/recursive_descent_parser.go @@ -4,19 +4,33 @@ import ( "errors" "fmt" "strconv" + "strings" ) /* -Current grammar: -root_expression ::= root_identifier [expression_access] | literal | function_call -chained_expression := identifier [expression_access] -expression_access ::= map_access | dot_notation -map_access ::= "[" key "]" [chained_expression] -dot_notation ::= "." identifier [chained_expression] -root_identifier ::= identifier | "$" -literal := IntLiteralToken | StringLiteralToken -function_call := identifier "(" [argument_list] ")" -argument_list := argument_list "," root_expression | root_expression +Current grammar in Backus–Naur form: + ::= + ::= [ "|" "|" ] + ::= [ "&" "&" ] + ::= [ "!" ] + ::= [ ] + ::= ">" | "<" | ">" "=" | "<" "=" | "=" "=" | "!" "=" + ::= [ ] + ::= "+" | "-" + ::= [ ] + ::= "*" | "/" | "%" + ::= [ "^" ] + ::= | "(" ")" + ::= ["-"] + ::= | [ ] + := IdentifierToken | + := IdentifierToken "(" [ ] ")" + := [ ] + := | + := "." IdentifierToken + := "[" "]" + := IntLiteralToken | StringLiteralToken | FloatLiteralToken | BooleanLiteralToken + := [ "," ] filtering/querying will be added later if needed. */ @@ -77,11 +91,7 @@ func (p *Parser) parseBracketAccess(expressionToAccess Node) (*BracketAccessor, } // Verify and read in the ] - if p.currentToken == nil || - p.currentToken.TokenID != BracketAccessDelimiterEndToken { - return nil, &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: []TokenID{IdentifierToken}} - } - err = p.advanceToken() + err = p.eat([]TokenID{BracketAccessDelimiterEndToken}) if err != nil { return nil, err } @@ -105,9 +115,58 @@ func (p *Parser) parseIntLiteral() (*IntLiteral, error) { return literal, nil } +func (p *Parser) parseFloatLiteral() (*FloatLiteral, error) { + if p.currentToken.TokenID != FloatLiteralToken { + return nil, &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: []TokenID{FloatLiteralToken}} + } + parsedFloat, err := strconv.ParseFloat(p.currentToken.Value, 64) + if err != nil { + // If this happens, make sure ParseFloat's requirements match the tokenizer's requirements. + return nil, fmt.Errorf("bug: could not parse float %s (%w)", p.currentToken.Value, err) + } + literal := &FloatLiteral{FloatValue: parsedFloat} + err = p.advanceToken() + if err != nil { + return nil, err + } + return literal, nil +} + +func (p *Parser) parseBooleanLiteral() (*BooleanLiteral, error) { + if p.currentToken.TokenID != BooleanLiteralToken { + return nil, &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: []TokenID{BooleanLiteralToken}} + } + parsedBoolean, err := strconv.ParseBool(p.currentToken.Value) + if err != nil { + return nil, err // Should not fail if the parser is set up correctly + } + literal := &BooleanLiteral{BooleanValue: parsedBoolean} + err = p.advanceToken() + if err != nil { + return nil, err + } + return literal, nil +} + +// Common escape characters +var escapeReplacer = strings.NewReplacer( + `\\`, `\`, + `\t`, "\t", + `\n`, "\n", + `\r`, "\r", + `\b`, "\b", + `\"`, `"`, + `\'`, `'`, + `\0`, "\000", +) + func (p *Parser) parseStringLiteral() (*StringLiteral, error) { // The literal token includes the "", so trim the ends off. - literal := &StringLiteral{StrValue: p.currentToken.Value[1 : len(p.currentToken.Value)-1]} + parsedString := p.currentToken.Value[1 : len(p.currentToken.Value)-1] + // Replace escaped characters + parsedString = escapeReplacer.Replace(parsedString) + // Now create the literal itself and advance the token. + literal := &StringLiteral{StrValue: parsedString} err := p.advanceToken() if err != nil { return nil, err @@ -118,10 +177,17 @@ func (p *Parser) parseStringLiteral() (*StringLiteral, error) { func (p *Parser) parseArgs() (*ArgumentList, error) { // Keep parsing expressions until you hit a comma. argNodes := make([]Node, 0) - expectedToken := ArgListStartToken + expectedToken := ParenthesesStartToken for i := 0; ; i++ { + // Check for incomplete scenario. + if p.currentToken == nil && i != 0 { // Reached end too early. + return nil, &InvalidGrammarError{ + FoundToken: p.currentToken, + ExpectedTokens: []TokenID{ParenthesesEndToken, ListSeparatorToken}, + } + } // Validate and go past the first ( on the first iteration, and commas on later iterations. - if i != 0 && p.currentToken.TokenID == ArgListEndToken { + if i != 0 && p.currentToken.TokenID == ParenthesesEndToken { // Advances past the ) err := p.advanceToken() if err != nil { @@ -130,9 +196,14 @@ func (p *Parser) parseArgs() (*ArgumentList, error) { return &ArgumentList{Arguments: argNodes}, nil } else if p.currentToken.TokenID != expectedToken { // The first is preceded by a (, the others are preceded by , + expectedTokens := []TokenID{expectedToken} + if i != 0 { + // Example: after `func(0` the next token should be `)` or `,` + expectedTokens = append(expectedTokens, ParenthesesEndToken) + } return nil, &InvalidGrammarError{ FoundToken: p.currentToken, - ExpectedTokens: []TokenID{expectedToken}, + ExpectedTokens: expectedTokens, } } @@ -141,8 +212,15 @@ func (p *Parser) parseArgs() (*ArgumentList, error) { if err != nil { return nil, err } + // Check for incomplete scenario. + if p.currentToken == nil { // Reached end too early. + return nil, &InvalidGrammarError{ + FoundToken: p.currentToken, + ExpectedTokens: []TokenID{ParenthesesEndToken}, + } + } // Check end condition - if i == 0 && p.currentToken.TokenID == ArgListEndToken { + if i == 0 && p.currentToken.TokenID == ParenthesesEndToken { // Advances past the ) err := p.advanceToken() if err != nil { @@ -198,31 +276,319 @@ func (p *Parser) ParseExpression() (Node, error) { return node, err } -var expStartIdentifierTokens = []TokenID{RootAccessToken, CurrentObjectAccessToken, IdentifierToken} -var literalTokens = []TokenID{StringLiteralToken, IntLiteralToken} -var validStartTokens = append(expStartIdentifierTokens, literalTokens...) - -// parseSubExpression parses all the dot notations, map accesses, and function calls. -func (p *Parser) parseAfterIdentifier(identifier *Identifier) (Node, error) { - var currentNode Node = identifier - // Handle types that cannot be chained first. - if p.currentToken.TokenID == ArgListStartToken { - // Function call - argList, err := p.parseArgs() +func (p *Parser) parseMathOperator() (MathOperationType, error) { + firstToken := p.currentToken.TokenID + err := p.advanceToken() + if err != nil { + return Invalid, err + } + switch firstToken { + case PlusToken: + return Add, nil + case NegationToken: + return Subtract, nil + case AsteriskToken: + return Multiply, nil + case DivideToken: + return Divide, nil + case PowerToken: + return Power, nil + case ModulusToken: + return Modulus, nil + case NotToken, GreaterThanToken, LessThanToken, EqualsToken: + // Need to validate and advance past the following = + if p.currentToken != nil && p.currentToken.TokenID == EqualsToken { + // Equals is next, so return based on the token preceding the = token. + err := p.advanceToken() + if err != nil { + return Invalid, err + } + switch firstToken { + case NotToken: + return NotEqualTo, nil + case GreaterThanToken: + return GreaterThanEqualTo, nil + case LessThanToken: + return LessThanEqualTo, nil + case EqualsToken: + return EqualTo, nil + default: + // If you get here, there is a case missing here that is in the outer switch + panic(fmt.Errorf("illegal code state hit after token %s", firstToken)) + } + } else { + // No token, or non-equals token next, so validate as a single token. + switch firstToken { + case GreaterThanToken: + return GreaterThan, nil + case LessThanToken: + return LessThan, nil + case NotToken: + return Not, nil + case EqualsToken: + // Expected double equals, but got single equals + return Invalid, &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: []TokenID{EqualsToken}} + default: + // If you get here, there is a case missing here that is in the outer switch + panic(fmt.Errorf("illegal code state hit after token %s", firstToken)) + } + } + case AndToken: + if p.currentToken == nil || p.currentToken.TokenID != AndToken { + return Invalid, &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: []TokenID{AndToken}} + } + err := p.advanceToken() + if err != nil { + return Invalid, err + } + return And, nil + case OrToken: + if p.currentToken == nil || p.currentToken.TokenID != OrToken { + return Invalid, &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: []TokenID{OrToken}} + } + err := p.advanceToken() + if err != nil { + return Invalid, err + } + return Or, nil + default: + return Invalid, &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: []TokenID{ + PlusToken, + NegationToken, + AsteriskToken, + DivideToken, + PowerToken, + NotToken, + GreaterThanToken, + LessThanToken, + EqualsToken, + AndToken, + OrToken, + ModulusToken, + }} + } +} + +// parseBinaryExpression parses a binary expression that has one of the supported operators, +// and uses childNodeParser for the left and right of the node. +// If, after parsing the first operand, the operator is not present, then the function returns +// successfully. +func (p *Parser) parseBinaryExpression(supportedOperators []TokenID, childNodeParser func() (Node, error)) (Node, error) { + root, err := childNodeParser() + if err != nil { + return nil, err + } + // Loop to allow non-recursively evaluated repeating compatible operations. + // Necessary for proper order of operations as currently designed. + for p.currentToken != nil && sliceContains(supportedOperators, p.currentToken.TokenID) { + operatorToken, err := p.parseMathOperator() if err != nil { return nil, err } - currentNode = &FunctionCall{ - FuncIdentifier: identifier, - ArgumentInputs: argList, + right, err := childNodeParser() + if err != nil { + return nil, err + } + root = &BinaryOperation{ + LeftNode: root, + RightNode: right, + Operation: operatorToken, } } - for { - switch { - case p.currentToken == nil: - // Reached end - return currentNode, nil - case p.currentToken.TokenID == DotObjectAccessToken: + return root, nil +} + +// parseLeftUnaryExpression parses an expression with the operator on the left, and the rest of the expression +// on the right. If the expected token is not there, it continues recursively with childNodeParser. +func (p *Parser) parseLeftUnaryExpression(supportedOperators []TokenID, childNodeParser func() (Node, error)) (Node, error) { + if p.currentToken == nil { + return nil, &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: []TokenID{}} + } + if sliceContains(supportedOperators, p.currentToken.TokenID) { + operation, err := p.parseMathOperator() + if err != nil { + return nil, err + } + subNode, err := p.parseRootExpression() + if err != nil { + return nil, err + } + return &UnaryOperation{ + LeftOperation: operation, + RightNode: subNode, + }, nil + } + return childNodeParser() +} + +// ORDER OF OPERATIONS +// - negation +// - Parentheses +// - Exponent +// - Multiplication and Division +// - Addition Subtraction +// - Comparisons +// - not +// - and +// - or +// The higher-precedence ones should be deepest in the call tree. So logical `or` should be called first. +// For more details, see the grammar at the top of this file. + +func (p *Parser) parseRootExpression() (Node, error) { + // Currently `or` is the first one to call based on the order of operations specified above, + // and based on the grammar specified at the top of the file. + return p.parseConditionalOr() +} + +func (p *Parser) parseConditionalOr() (Node, error) { + return p.parseBinaryExpression([]TokenID{OrToken}, p.parseConditionalAnd) +} + +func (p *Parser) parseConditionalAnd() (Node, error) { + return p.parseBinaryExpression([]TokenID{AndToken}, p.parseConditionalNot) +} + +func (p *Parser) parseConditionalNot() (Node, error) { + return p.parseLeftUnaryExpression([]TokenID{NotToken}, p.parseComparisonExpression) +} + +func (p *Parser) parseComparisonExpression() (Node, error) { + // The allowed tokens are the FIRST ones associated with a binary comparison. The parseMathOperator func called by + // parseBinaryExpression will handle the second token, if present. + return p.parseBinaryExpression([]TokenID{GreaterThanToken, LessThanToken, NotToken, EqualsToken}, p.parseAdditionSubtraction) +} + +func (p *Parser) parseAdditionSubtraction() (Node, error) { + return p.parseBinaryExpression([]TokenID{PlusToken, NegationToken}, p.parseMultiplicationDivision) +} + +func (p *Parser) parseMultiplicationDivision() (Node, error) { + return p.parseBinaryExpression([]TokenID{AsteriskToken, DivideToken, ModulusToken}, p.parseExponents) +} + +func (p *Parser) parseExponents() (Node, error) { + return p.parseBinaryExpression([]TokenID{PowerToken}, p.parseParentheses) +} + +func (p *Parser) parseParentheses() (Node, error) { + // If parentheses, continue recursing back from the root. + // If not parentheses, recurse down into negation. + if p.currentToken.TokenID != ParenthesesStartToken { + return p.parseNegationOperation() + } + err := p.advanceToken() // Go past the parentheses + if err != nil { + return nil, err + } + node, err := p.parseRootExpression() + if err != nil { + return nil, err + } + err = p.eat([]TokenID{ParenthesesEndToken}) + if err != nil { + return nil, err + } + return node, nil +} + +func (p *Parser) parseNegationOperation() (Node, error) { + return p.parseLeftUnaryExpression([]TokenID{NegationToken}, p.parseValueOrAccessExpression) +} + +var literalTokens = []TokenID{StringLiteralToken, IntLiteralToken, BooleanLiteralToken, FloatLiteralToken} +var identifierTokens = []TokenID{IdentifierToken, RootAccessToken} +var validRootValueOrAccessStartTokens = append(literalTokens, identifierTokens...) +var validValueOrAccessStartTokens = append(validRootValueOrAccessStartTokens, CurrentObjectAccessToken) + +// parseValueOrAccessExpression parses a root expression +func (p *Parser) parseValueOrAccessExpression() (Node, error) { + if p.currentToken == nil || !sliceContains(validValueOrAccessStartTokens, p.currentToken.TokenID) { + return nil, &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: validValueOrAccessStartTokens} + } else if p.atRoot && p.currentToken.TokenID == CurrentObjectAccessToken { + // Can't support @/CurrentObjectAccessToken at root + return nil, &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: validRootValueOrAccessStartTokens} + } + p.atRoot = false // Know when you can reference the current object. + + var literalNode Node + var err error + // A value or access expression can start with a literal, or an identifier. + // If an identifier, it can lead to a chain or a function. + switch p.currentToken.TokenID { + case StringLiteralToken: + literalNode, err = p.parseStringLiteral() + case IntLiteralToken: + literalNode, err = p.parseIntLiteral() + case FloatLiteralToken: + literalNode, err = p.parseFloatLiteral() + case BooleanLiteralToken: + literalNode, err = p.parseBooleanLiteral() + default: + // We have a valid token that isn't a literal. + return p.parseIdentifierOrFunction() + } + // Literal case + if err != nil { + return nil, err + } + // Lookahead validation for nothing incorrect following the literal for better error messages. + if p.currentToken != nil { // Nothing after, so likely valid. + switch p.currentToken.TokenID { + // These are all access start tokens which cannot follow a literal. + case ParenthesesStartToken: + return nil, fmt.Errorf("an opening parentheses cannot follow a literal; got %q after %q", p.currentToken.Value, literalNode.String()) + case DotObjectAccessToken: + return nil, fmt.Errorf("dot notation cannot follow a literal; got %q after %q", p.currentToken.Value, literalNode.String()) + case BracketAccessDelimiterStartToken: + return nil, fmt.Errorf("bracket access cannot follow a literal; got %q after %q", p.currentToken.Value, literalNode.String()) + } + } + return literalNode, nil +} + +// Parses the current identifier, parses the arg list if available, then checks for chainable accesses. +// Expects to be called when the current node is an identifier. +func (p *Parser) parseIdentifierOrFunction() (Node, error) { + firstNode := &Identifier{IdentifierName: p.currentToken.Value} + err := p.advanceToken() + if err != nil { + return nil, err + } + chainableNode, err := p.parseFunctionArgs(firstNode) + if err != nil { + return nil, err + } + if p.currentToken == nil { + // Nothing follows, so stop here + return chainableNode, nil + } + return p.parseChainedAccess(chainableNode) +} + +// parseFunctionArgs parses all parts of a function call that follow the identifier, including the parentheses. +// If a parameter list is not found, it returns the identifier. +func (p *Parser) parseFunctionArgs(precedingNode *Identifier) (Node, error) { + if p.currentToken == nil || p.currentToken.TokenID != ParenthesesStartToken { + // No function call. Return the original input for chaining. + return precedingNode, nil + } + argList, err := p.parseArgs() + if err != nil { + return nil, err + } + return &FunctionCall{ + FuncIdentifier: precedingNode, + ArgumentInputs: argList, + }, nil +} + +// parseChainedAccess parses all the dot notations, map accesses, binary operations, and function calls. +// Must be called after Identifier, FunctionCall, or another chained access node. +func (p *Parser) parseChainedAccess(rootNode Node) (Node, error) { + var currentNode = rootNode + for p.currentToken != nil { + switch p.currentToken.TokenID { + case DotObjectAccessToken: // Dot notation err := p.advanceToken() // Move past the . if err != nil { @@ -233,7 +599,7 @@ func (p *Parser) parseAfterIdentifier(identifier *Identifier) (Node, error) { return nil, err } currentNode = &DotNotation{LeftAccessibleNode: currentNode, RightAccessIdentifier: accessingIdentifier} - case p.currentToken.TokenID == BracketAccessDelimiterStartToken: + case BracketAccessDelimiterStartToken: // Bracket notation parsedMapAccess, err := p.parseBracketAccess(currentNode) if err != nil { @@ -245,46 +611,19 @@ func (p *Parser) parseAfterIdentifier(identifier *Identifier) (Node, error) { return currentNode, nil } } + return currentNode, nil } -// parseRootExpression parses a root expression -func (p *Parser) parseRootExpression() (Node, error) { - if p.currentToken == nil || !sliceContains(validStartTokens, p.currentToken.TokenID) { - return nil, &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: validStartTokens} - } else if p.atRoot && p.currentToken.TokenID == CurrentObjectAccessToken { - // Can't support @ at root - return nil, &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: []TokenID{RootAccessToken, IdentifierToken}} - } - if p.atRoot { - p.atRoot = false // Know when you can reference the current object. - } - - // An expression can start with a literal, or an identifier. If an identifier, it can lead to a chain or a function. - if sliceContains(literalTokens, p.currentToken.TokenID) { - switch literalToken := p.currentToken.TokenID; literalToken { - case StringLiteralToken: - return p.parseStringLiteral() - case IntLiteralToken: - return p.parseIntLiteral() - default: - return nil, fmt.Errorf( - "bug: Literal token type %s is missing from switch in parseUnchainedRootExpression", - literalTokens) - } - } - // The literal case is accounted for, so if it gets here it's an identifier. That can lead to a chain or function call. - var firstIdentifier = &Identifier{IdentifierName: p.currentToken.Value} - err := p.advanceToken() - if err != nil { - return nil, err - } - if p.currentToken != nil { - return p.parseAfterIdentifier(firstIdentifier) - } else { - return firstIdentifier, nil +// eat validates then goes past the given token. +// For use when you know which tokens are required. +func (p *Parser) eat(validTokens []TokenID) error { + if p.currentToken == nil || !sliceContains(validTokens, p.currentToken.TokenID) { + return &InvalidGrammarError{FoundToken: p.currentToken, ExpectedTokens: validTokens} } + return p.advanceToken() } +// sliceContains is here to support versions of go before slices.Contains was added in Go 1.21 func sliceContains(slice []TokenID, value TokenID) bool { for _, val := range slice { if val == value { diff --git a/internal/ast/tokenizer.go b/internal/ast/tokenizer.go index 5a3ccc3..9c89099 100644 --- a/internal/ast/tokenizer.go +++ b/internal/ast/tokenizer.go @@ -19,6 +19,10 @@ const ( StringLiteralToken TokenID = "string" // IntLiteralToken represents an integer token. Must not start with 0. IntLiteralToken TokenID = "int" + // FloatLiteralToken represents a float token. + FloatLiteralToken TokenID = "float" + // BooleanLiteralToken represents true or false. + BooleanLiteralToken TokenID = "boolean" // BracketAccessDelimiterStartToken represents the token before an object // access. The '[' in 'obj["key"]'. //nolint:gosec @@ -27,10 +31,10 @@ const ( // access. The '[' in 'obj["key"]'. //nolint:gosec BracketAccessDelimiterEndToken TokenID = "map-delimiter-end" - // ArgListStartToken represents the start token of a argument list. '(' - ArgListStartToken TokenID = "args-start" - // ArgListEndToken represents the closing of the argument list. ')' - ArgListEndToken TokenID = "args-end" + // ParenthesesStartToken represents the start token of an argument list or a parenthesized expression. '(' + ParenthesesStartToken TokenID = "parentheses-start" + // ParenthesesEndToken represents the closing of the argument list. ')' + ParenthesesEndToken TokenID = "parentheses-end" // DotObjectAccessToken represents the '.' token in 'a.b' (dot notation). DotObjectAccessToken TokenID = "object-access" // RootAccessToken represents the token that identifies accessing the @@ -49,10 +53,28 @@ const ( // NegationToken represents a negation sign '-'. //nolint:gosec NegationToken TokenID = "negation-sign" - // WildcardToken represents a wildcard token '*'. - WildcardToken TokenID = "wildcard" + // AsteriskToken represents a wildcard/multiplication token '*'. + AsteriskToken TokenID = "asterisk" // ListSeparatorToken represents a comma in a parameter list ListSeparatorToken TokenID = "list-separator" //nolint:gosec // not a security credential + // DivideToken represents the forward slash used to specify division. + DivideToken TokenID = "divide" + // GreaterThanToken represents a > symbol. + GreaterThanToken TokenID = "greater-than" + // LessThanToken represents a < symbol. + LessThanToken TokenID = "less-than" + // PlusToken represents a + symbol. + PlusToken TokenID = "plus" + // NotToken represents an ! symbol. + NotToken TokenID = "not" + // PowerToken represents a caret symbol for exponentiation. + PowerToken TokenID = "power" + // ModulusToken represents a percent symbol for remainder. + ModulusToken TokenID = "mod" + // AndToken represents logical-and && + AndToken TokenID = "and" + // OrToken represents logical-or || + OrToken TokenID = "or" // UnknownToken is a placeholder for when there was an error in the token. UnknownToken TokenID = "error" ) @@ -82,22 +104,33 @@ type tokenPattern struct { } var tokenPatterns = []tokenPattern{ - {IntLiteralToken, regexp.MustCompile(`^0$|^[1-9]\d*$`)}, // Note: numbers that start with 0 are identifiers. - {IdentifierToken, regexp.MustCompile(`^\w+$`)}, // Any valid object name - {StringLiteralToken, regexp.MustCompile(`^".*"$|^'.*'$`)}, // "string example" - {BracketAccessDelimiterStartToken, regexp.MustCompile(`^\[$`)}, // the [ in map["key"] - {BracketAccessDelimiterEndToken, regexp.MustCompile(`^]$`)}, // the ] in map["key"] - {ArgListStartToken, regexp.MustCompile(`^\($`)}, // ( - {ArgListEndToken, regexp.MustCompile(`^\)$`)}, // ) - {DotObjectAccessToken, regexp.MustCompile(`^\.$`)}, // . - {RootAccessToken, regexp.MustCompile(`^\$$`)}, // $ - {CurrentObjectAccessToken, regexp.MustCompile(`^@$`)}, // @ - {EqualsToken, regexp.MustCompile(`^=$`)}, // = - {SelectorToken, regexp.MustCompile(`^:$`)}, // : - {FilterToken, regexp.MustCompile(`^\?$`)}, // ? - {NegationToken, regexp.MustCompile(`^-$`)}, // - - {WildcardToken, regexp.MustCompile(`^\*$`)}, // * - {ListSeparatorToken, regexp.MustCompile(`^,$`)}, // , + {BooleanLiteralToken, regexp.MustCompile(`^true|false$`)}, // true or false. Note: This needs to be above IdentifierToken + {FloatLiteralToken, regexp.MustCompile(`^\d+\.\d*(?:[eE][+-]?\d+)?$`)}, // Like an integer, but with a period and digits after. + {IntLiteralToken, regexp.MustCompile(`^(?:0|[1-9]\d*)$`)}, // Note: numbers that start with 0 are identifiers. + {IdentifierToken, regexp.MustCompile(`^\w+$`)}, // Any valid object name + {StringLiteralToken, regexp.MustCompile(`^(?:".*"|'.*')$`)}, // "string example" + {BracketAccessDelimiterStartToken, regexp.MustCompile(`^\[$`)}, // the [ in map["key"] + {BracketAccessDelimiterEndToken, regexp.MustCompile(`^]$`)}, // the ] in map["key"] + {ParenthesesStartToken, regexp.MustCompile(`^\($`)}, // ( + {ParenthesesEndToken, regexp.MustCompile(`^\)$`)}, // ) + {DotObjectAccessToken, regexp.MustCompile(`^\.$`)}, // . + {RootAccessToken, regexp.MustCompile(`^\$$`)}, // $ + {CurrentObjectAccessToken, regexp.MustCompile(`^@$`)}, // @ + {EqualsToken, regexp.MustCompile(`^=$`)}, // = + {SelectorToken, regexp.MustCompile(`^:$`)}, // : + {FilterToken, regexp.MustCompile(`^\?$`)}, // ? + {NegationToken, regexp.MustCompile(`^-$`)}, // - + {AsteriskToken, regexp.MustCompile(`^\*$`)}, // * + {ListSeparatorToken, regexp.MustCompile(`^,$`)}, // , + {DivideToken, regexp.MustCompile(`^/$`)}, // / + {GreaterThanToken, regexp.MustCompile(`^>$`)}, // > + {LessThanToken, regexp.MustCompile(`^<$`)}, // < + {PlusToken, regexp.MustCompile(`^\+$`)}, // + + {NotToken, regexp.MustCompile(`^!$`)}, // ! + {PowerToken, regexp.MustCompile(`^\^$`)}, // ^ + {ModulusToken, regexp.MustCompile(`^%$`)}, // % + {AndToken, regexp.MustCompile(`^&$`)}, // && + {OrToken, regexp.MustCompile(`^\|$`)}, // || } // initTokenizer initializes the tokenizer struct with the given expression. diff --git a/internal/ast/tokenizer_test.go b/internal/ast/tokenizer_test.go index c715c55..22c127b 100644 --- a/internal/ast/tokenizer_test.go +++ b/internal/ast/tokenizer_test.go @@ -27,11 +27,11 @@ func TestTokenizer(t *testing.T) { {"credentials", IdentifierToken, filename, 1, 43}, {"[", BracketAccessDelimiterStartToken, filename, 1, 54}, {"f", IdentifierToken, filename, 1, 55}, - {"(", ArgListStartToken, filename, 1, 56}, + {"(", ParenthesesStartToken, filename, 1, 56}, {"1", IntLiteralToken, filename, 1, 57}, {",", ListSeparatorToken, filename, 1, 58}, {"2", IntLiteralToken, filename, 1, 59}, - {")", ArgListEndToken, filename, 1, 60}, + {")", ParenthesesEndToken, filename, 1, 60}, {"]", BracketAccessDelimiterEndToken, filename, 1, 61}, } for _, expected := range expectedValue { @@ -46,7 +46,7 @@ func TestTokenizer(t *testing.T) { } } -func TestTokenizerWithEscapedStr(t *testing.T) { +func TestTokenizer_TokenizerWithEscapedStr(t *testing.T) { input := `$.output["ab\"|cd"]` tokenizer := initTokenizer(input, filename) expectedValue := []string{"$", ".", "output", "[", `"ab\"|cd"`, "]"} @@ -58,7 +58,35 @@ func TestTokenizerWithEscapedStr(t *testing.T) { } } -func TestWithFilterType(t *testing.T) { +func TestTokenizer_BinaryOperations(t *testing.T) { + input := `5 + 5 / 1 >= 5^5` + tokenizer := initTokenizer(input, filename) + expectedValue := []TokenValue{ + {"5", IntLiteralToken, filename, 1, 1}, + {"+", PlusToken, filename, 1, 3}, + {"5", IntLiteralToken, filename, 1, 5}, + {"/", DivideToken, filename, 1, 7}, + {"1", IntLiteralToken, filename, 1, 9}, + {">", GreaterThanToken, filename, 1, 11}, + {"=", EqualsToken, filename, 1, 12}, + {"5", IntLiteralToken, filename, 1, 14}, + {"^", PowerToken, filename, 1, 15}, + {"5", IntLiteralToken, filename, 1, 16}, + } + for _, expected := range expectedValue { + assert.Equals(t, tokenizer.hasNextToken(), true) + nextToken, err := tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, nextToken.Value, expected.Value) + assert.Equals(t, nextToken.TokenID, expected.TokenID) + assert.Equals(t, nextToken.Filename, expected.Filename) + assert.Equals(t, nextToken.Line, expected.Line) + assert.Equals(t, nextToken.Column, expected.Column) + } + assert.Equals(t, tokenizer.hasNextToken(), false) +} + +func TestTokenizer_WithFilterType(t *testing.T) { input := "$.steps.foo.outputs[\"bar\"][?(@._type=='x')].a" tokenizer := initTokenizer(input, filename) expectedValue := []string{"$", ".", "steps", ".", "foo", ".", "outputs", @@ -71,19 +99,19 @@ func TestWithFilterType(t *testing.T) { } } -func TestInvalidToken(t *testing.T) { - input := "[&" +func TestTokenizer_InvalidToken(t *testing.T) { + input := "[€" tokenizer := initTokenizer(input, filename) assert.Equals(t, tokenizer.hasNextToken(), true) tokenVal, err := tokenizer.getNext() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equals(t, tokenVal.TokenID, BracketAccessDelimiterStartToken) assert.Equals(t, tokenVal.Value, "[") assert.Equals(t, tokenizer.hasNextToken(), true) tokenVal, err = tokenizer.getNext() - assert.NotNil(t, err) + assert.Error(t, err) assert.Equals(t, tokenVal.TokenID, UnknownToken) - assert.Equals(t, tokenVal.Value, "&") + assert.Equals(t, tokenVal.Value, "€") expectedError := &InvalidTokenError{} isCorrectErrType := errors.As(err, &expectedError) if !isCorrectErrType { @@ -92,27 +120,123 @@ func TestInvalidToken(t *testing.T) { assert.Equals(t, expectedError.InvalidToken.Column, 2) assert.Equals(t, expectedError.InvalidToken.Line, 1) assert.Equals(t, expectedError.InvalidToken.Filename, filename) - assert.Equals(t, expectedError.InvalidToken.Value, "&") + assert.Equals(t, expectedError.InvalidToken.Value, "€") } -func TestIntLiteral(t *testing.T) { +func TestTokenizer_IntLiteral(t *testing.T) { input := "70 07" tokenizer := initTokenizer(input, filename) assert.Equals(t, tokenizer.hasNextToken(), true) tokenVal, err := tokenizer.getNext() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equals(t, tokenVal.TokenID, IntLiteralToken) assert.Equals(t, tokenVal.Value, "70") assert.Equals(t, tokenizer.hasNextToken(), true) // Numbers that start with 0 are interpreted as octal by the string tokenizer, // resulting in an error printed to stderr. It doesn't change the behavior. tokenVal, err = tokenizer.getNext() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equals(t, tokenVal.TokenID, IdentifierToken) assert.Equals(t, tokenVal.Value, "07") } -func TestWildcard(t *testing.T) { +func TestTokenizer_FloatLiteral(t *testing.T) { + input := "0.0 40.099 5.0e5 5.0E-5 05.00 5." + tokenizer := initTokenizer(input, filename) + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err := tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, FloatLiteralToken) + assert.Equals(t, tokenVal.Value, "0.0") + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, FloatLiteralToken) + assert.Equals(t, tokenVal.Value, "40.099") + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, FloatLiteralToken) + assert.Equals(t, tokenVal.Value, "5.0e5") + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, FloatLiteralToken) + assert.Equals(t, tokenVal.Value, "5.0E-5") + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, FloatLiteralToken) + assert.Equals(t, tokenVal.Value, "05.00") + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, FloatLiteralToken) + assert.Equals(t, tokenVal.Value, "5.") + assert.Equals(t, tokenizer.hasNextToken(), false) + +} + +func TestTokenizer_BooleanLiterals(t *testing.T) { + input := "true && false || false" + tokenizer := initTokenizer(input, filename) + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err := tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, BooleanLiteralToken) + assert.Equals(t, tokenVal.Value, "true") + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, AndToken) + assert.Equals(t, tokenVal.Value, "&") + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, AndToken) + assert.Equals(t, tokenVal.Value, "&") + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, BooleanLiteralToken) + assert.Equals(t, tokenVal.Value, "false") + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, OrToken) + assert.Equals(t, tokenVal.Value, "|") + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, OrToken) + assert.Equals(t, tokenVal.Value, "|") + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, BooleanLiteralToken) + assert.Equals(t, tokenVal.Value, "false") + assert.Equals(t, tokenizer.hasNextToken(), false) +} + +func TestTokenizer_StringLiteral(t *testing.T) { + input := `"" "a" "a\"b"` + tokenizer := initTokenizer(input, filename) + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err := tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, StringLiteralToken) + assert.Equals(t, tokenVal.Value, `""`) + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, StringLiteralToken) + assert.Equals(t, tokenVal.Value, `"a"`) + assert.Equals(t, tokenizer.hasNextToken(), true) + tokenVal, err = tokenizer.getNext() + assert.NoError(t, err) + assert.Equals(t, tokenVal.TokenID, StringLiteralToken) + assert.Equals(t, tokenVal.Value, `"a\"b"`) + assert.Equals(t, tokenizer.hasNextToken(), false) +} + +func TestTokenizer_Wildcard(t *testing.T) { input := `$.*` tokenizer := initTokenizer(input, filename) expectedValue := []string{"$", ".", "*"}