Skip to content

Commit

Permalink
AST: Allow children of commutative nodes to be reordered cost-wise.
Browse files Browse the repository at this point in the history
This adds several things:

 * Commutative functions are marked as such.
 * AST functions now have an associated cost depending on their computational complexity.
 * Evaluation will now reorder the children of a commutative node to execute the cheapest first.

This would have the benefit, thanks to the lazy evaluation of AST nodes (circuit breaking), to
improve overall performance for those scenarii.
  • Loading branch information
Antoine Popineau committed Jan 16, 2025
1 parent d0d3c04 commit 47f3847
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 15 deletions.
1 change: 1 addition & 0 deletions models/ast/ast_aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ var FuncAggregatorAttributes = FuncAttributes{
DebugName: "FUNC_AGGREGATOR",
AstName: "Aggregator",
NamedArguments: []string{"tableName", "fieldName", "aggregator", "filters", "label"},
Cost: 50,
}
15 changes: 11 additions & 4 deletions models/ast/ast_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ type FuncAttributes struct {
// considering for every one of them if evaluation should continue or not. For the result value of one child, the
// function returns whether evaluation of subsequent children should continue (true) or not (false).
LazyChildEvaluation func(NodeEvaluation) bool
// Commutative indicates this function can treat its children as a commutative list of arguments, and that
// they can be reordered without changing the outcome of the function.
Commutative bool
// Cost modelizes the computation cost of a given node, the default being zero.
Cost int
}

// If number of arguments -1 the function can take any number of arguments
Expand Down Expand Up @@ -143,14 +148,16 @@ var FuncAttributesMap = map[Function]FuncAttributes{
AstName: "Not",
},
FUNC_AND: {
DebugName: "FUNC_AND",
AstName: "And",
DebugName: "FUNC_AND",
AstName: "And",
Commutative: true,
// Boolean AND returns false if any child node evaluates to false
LazyChildEvaluation: shortCircuitIfFalse,
},
FUNC_OR: {
DebugName: "FUNC_OR",
AstName: "Or",
DebugName: "FUNC_OR",
AstName: "Or",
Commutative: true,
// Boolean OR returns true if any child nodes evluates to true
LazyChildEvaluation: shortCircuitIfTrue,
},
Expand Down
19 changes: 19 additions & 0 deletions models/ast/ast_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
)

type Node struct {
Index int

// A node is a constant xOR a function
Function Function
Constant any
Expand Down Expand Up @@ -49,3 +51,20 @@ func (node Node) ReadConstantNamedChildString(name string) (string, error) {
}
return value, nil
}

// Cost calculates the weights of an AST subtree to reorder, when the parent is commutative,
// nodes to prioritize faster ones.
func (node Node) Cost() int {
selfCost := 0
childCost := 0

if attrs, err := node.Function.Attributes(); err == nil {
selfCost = attrs.Cost
}

for _, ch := range node.Children {
childCost += ch.Cost()
}

return selfCost + childCost
}
7 changes: 7 additions & 0 deletions models/ast/ast_node_evaluation.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ import (
)

type NodeEvaluation struct {
// Index of the initial node in the AST tree, used to reorder the results as they were.
// This should become obsolete when each node has a unique ID.
Index int
// Skipped indicates whether this node was evaluated at all or not. A `true` values means the
// engine determined the result of this node would not impact the overall decision's outcome.
Skipped bool

Function Function
ReturnValue any
Errors []error
Expand Down
2 changes: 2 additions & 0 deletions models/ast/node_evaluation_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type NodeEvaluationDto struct {
Errors []EvaluationErrorDto `json:"errors"`
Children []NodeEvaluationDto `json:"children,omitempty"`
NamedChildren map[string]NodeEvaluationDto `json:"named_children,omitempty"`
Skipped bool `json:"skipped"`
}

func AdaptNodeEvaluationDto(evaluation NodeEvaluation) NodeEvaluationDto {
Expand All @@ -29,6 +30,7 @@ func AdaptNodeEvaluationDto(evaluation NodeEvaluation) NodeEvaluationDto {
Errors: pure_utils.Map(evaluation.Errors, AdaptEvaluationErrorDto),
Children: pure_utils.Map(evaluation.Children, AdaptNodeEvaluationDto),
NamedChildren: pure_utils.MapValues(evaluation.NamedChildren, AdaptNodeEvaluationDto),
Skipped: evaluation.Skipped,
}
}

Expand Down
6 changes: 5 additions & 1 deletion usecases/ast_eval/evaluate_ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node
// Early exit for constant, because it should have no children.
if node.Function == ast.FUNC_CONSTANT {
return ast.NodeEvaluation{
Index: node.Index,
Function: node.Function,
ReturnValue: node.Constant,
Errors: []error{},
Expand All @@ -37,10 +38,13 @@ func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node
return
}

weightedNodes := NewWeightedNodes(environment, node, node.Children)

// eval each child
evaluation := ast.NodeEvaluation{
Index: node.Index,
Function: node.Function,
Children: pure_utils.MapWhile(node.Children, evalChild),
Children: weightedNodes.Reorder(pure_utils.MapWhile(weightedNodes.Sorted(), evalChild)),
NamedChildren: pure_utils.MapValuesWhile(node.NamedChildren, evalChild),
}

Expand Down
38 changes: 35 additions & 3 deletions usecases/ast_eval/evaluate_ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/checkmarble/marble-backend/models/ast"
"github.com/checkmarble/marble-backend/usecases/ast_eval/evaluate"
"github.com/checkmarble/marble-backend/utils"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -94,7 +95,7 @@ func NewAstOrFalse() ast.Node {
}

func TestLazyAnd(t *testing.T) {
environment := NewAstEvaluationEnvironment()
environment := NewAstEvaluationEnvironment().WithoutCostOptimizations()

for _, value := range []bool{true, false} {
root := ast.Node{Function: ast.FUNC_AND}.
Expand All @@ -117,7 +118,7 @@ func TestLazyAnd(t *testing.T) {
}

func TestLazyOr(t *testing.T) {
environment := NewAstEvaluationEnvironment()
environment := NewAstEvaluationEnvironment().WithoutCostOptimizations()

for _, value := range []bool{true, false} {
root := ast.Node{Function: ast.FUNC_OR}.
Expand Down Expand Up @@ -154,7 +155,7 @@ func TestLazyBooleanNulls(t *testing.T) {
{ast.FUNC_AND, utils.Ptr(false), nil, utils.Ptr(false)},
}

environment := NewAstEvaluationEnvironment()
environment := NewAstEvaluationEnvironment().WithoutCostOptimizations()

for _, tt := range tts {
root := ast.Node{Function: tt.fn}
Expand All @@ -178,3 +179,34 @@ func TestLazyBooleanNulls(t *testing.T) {
}
}
}

const TEST_FUNC_COSTLY = -10

type costlyNode struct{}

func (costlyNode) Evaluate(ctx context.Context, arguments ast.Arguments) (any, []error) {
return evaluate.MakeEvaluateResult(false)
}

func TestAggregatesOrderedLast(t *testing.T) {
ast.FuncAttributesMap[TEST_FUNC_COSTLY] = ast.FuncAttributes{
Cost: 1000,
}

defer delete(ast.FuncAttributesMap, TEST_FUNC_COSTLY)

environment := NewAstEvaluationEnvironment()
environment.AddEvaluator(TEST_FUNC_COSTLY, costlyNode{})

root := ast.Node{Function: ast.FUNC_OR}.
AddChild(ast.Node{Function: TEST_FUNC_COSTLY}).
AddChild(ast.Node{Constant: true})

evaluation, ok := EvaluateAst(context.TODO(), environment, root)

assert.True(t, ok)
assert.Equal(t, ast.NodeEvaluation{Index: 0, Skipped: true, ReturnValue: nil}, evaluation.Children[0])
assert.Equal(t, false, evaluation.Children[1].Skipped)
assert.Equal(t, true, evaluation.Children[1].ReturnValue)
assert.Equal(t, true, evaluation.ReturnValue)
}
18 changes: 16 additions & 2 deletions usecases/ast_eval/evaluate_environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import (
)

type AstEvaluationEnvironment struct {
availableFunctions map[ast.Function]evaluate.Evaluator
disableCircuitBreaking bool
availableFunctions map[ast.Function]evaluate.Evaluator
disableCostOptimizations bool
disableCircuitBreaking bool
}

func (environment *AstEvaluationEnvironment) AddEvaluator(function ast.Function, evaluator evaluate.Evaluator) {
Expand All @@ -28,12 +29,25 @@ func (environment *AstEvaluationEnvironment) GetEvaluator(function ast.Function)
return nil, errors.New(fmt.Sprintf("function '%s' is not available", function.DebugString()))
}

func (environment AstEvaluationEnvironment) WithoutOptimizations() AstEvaluationEnvironment {
environment.disableCostOptimizations = true
environment.disableCircuitBreaking = true

return environment
}

func (environment AstEvaluationEnvironment) WithoutCircuitBreaking() AstEvaluationEnvironment {
environment.disableCircuitBreaking = true

return environment
}

func (environment AstEvaluationEnvironment) WithoutCostOptimizations() AstEvaluationEnvironment {
environment.disableCostOptimizations = true

return environment
}

func NewAstEvaluationEnvironment() AstEvaluationEnvironment {
environment := AstEvaluationEnvironment{
availableFunctions: make(map[ast.Function]evaluate.Evaluator),
Expand Down
73 changes: 73 additions & 0 deletions usecases/ast_eval/weighted_nodes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package ast_eval

import (
"cmp"
"slices"

"github.com/checkmarble/marble-backend/models/ast"
)

// Weighted nodes manages a flat list of nodes and offers an interface to process
// them sorted by node cost. A lower-cost node will be executed earlier when the
// parent is commutative.
//
// The parent is passed to the constructor, so that if it is not commutative, this
// is basically a no-op.
type WeightedNodes struct {
enabled bool
original []ast.Node
}

func NewWeightedNodes(env AstEvaluationEnvironment, parent ast.Node, nodes []ast.Node) WeightedNodes {
enabled := false

if !env.disableCostOptimizations {
if fattrs, err := parent.Function.Attributes(); err == nil {
enabled = fattrs.Commutative
}
}

if enabled {
for idx := range nodes {
nodes[idx].Index = idx
}
}

return WeightedNodes{
enabled: enabled,
original: nodes,
}
}

func (wn WeightedNodes) Sorted() []ast.Node {
if !wn.enabled {
return wn.original
}

return slices.SortedFunc(slices.Values(wn.original), func(lhs, rhs ast.Node) int {
return cmp.Compare(lhs.Cost(), rhs.Cost())
})
}

func (wn WeightedNodes) Reorder(results []ast.NodeEvaluation) []ast.NodeEvaluation {
if !wn.enabled {
return results
}

output := make([]ast.NodeEvaluation, len(wn.original))

for idx := range wn.original {
output[idx] = ast.NodeEvaluation{
Index: idx,
Skipped: true,
ReturnValue: nil,
}
}

for _, result := range results {
output[result.Index] = result
output[result.Index].Skipped = false
}

return output
}
3 changes: 1 addition & 2 deletions usecases/scenarios/scenario_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,7 @@ func (validator *AstValidatorImpl) MakeDryRunEnvironment(ctx context.Context,
ClientObject: clientObject,
DataModel: dataModel,
DatabaseAccessReturnFakeValue: true,
}).
WithoutCircuitBreaking()
}).WithoutOptimizations()

return env, nil
}
6 changes: 3 additions & 3 deletions usecases/scenarios/scenario_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestValidateScenarioIterationImpl_Validate(t *testing.T) {
validator := AstValidatorImpl{
DataModelRepository: mdmr,
AstEvaluationEnvironmentFactory: func(params ast_eval.EvaluationEnvironmentFactoryParams) ast_eval.AstEvaluationEnvironment {
return ast_eval.NewAstEvaluationEnvironment().WithoutCircuitBreaking()
return ast_eval.NewAstEvaluationEnvironment().WithoutOptimizations()
},
ExecutorFactory: executorFactory,
}
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestValidateScenarioIterationImpl_Validate_notBool(t *testing.T) {
validator := AstValidatorImpl{
DataModelRepository: mdmr,
AstEvaluationEnvironmentFactory: func(params ast_eval.EvaluationEnvironmentFactoryParams) ast_eval.AstEvaluationEnvironment {
return ast_eval.NewAstEvaluationEnvironment().WithoutCircuitBreaking()
return ast_eval.NewAstEvaluationEnvironment().WithoutOptimizations()
},
ExecutorFactory: executorFactory,
}
Expand Down Expand Up @@ -283,7 +283,7 @@ func TestValidationShouldBypassCircuitBreaking(t *testing.T) {
validator := AstValidatorImpl{
DataModelRepository: mdmr,
AstEvaluationEnvironmentFactory: func(params ast_eval.EvaluationEnvironmentFactoryParams) ast_eval.AstEvaluationEnvironment {
return ast_eval.NewAstEvaluationEnvironment().WithoutCircuitBreaking()
return ast_eval.NewAstEvaluationEnvironment().WithoutOptimizations()
},
ExecutorFactory: executorFactory,
}
Expand Down

0 comments on commit 47f3847

Please sign in to comment.