Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AST: Allow children of commutative nodes to be reordered cost-wise. #800

Merged
merged 4 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions models/ast/ast_aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ var FuncAggregatorAttributes = FuncAttributes{
DebugName: "FUNC_AGGREGATOR",
AstName: "Aggregator",
NamedArguments: []string{"tableName", "fieldName", "aggregator", "filters", "label"},
Cost: 50,
}
1 change: 1 addition & 0 deletions models/ast/ast_custom_list_attr.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ var AttributeFuncCustomListAccess = struct {
NamedArguments: []string{
"customListId",
},
Cost: 30,
},
ArgumentCustomListId: "customListId",
}
Expand Down
17 changes: 13 additions & 4 deletions models/ast/ast_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ type FuncAttributes struct {
// considering for every one of them if evaluation should continue or not. For the result value of one child, the
// function returns whether evaluation of subsequent children should continue (true) or not (false).
LazyChildEvaluation func(NodeEvaluation) bool
// Commutative indicates this function can treat its children as a commutative list of arguments, and that
// they can be reordered without changing the outcome of the function.
Commutative bool
// Cost modelizes the computation cost of a given node, the default being zero.
Cost int
}

// If number of arguments -1 the function can take any number of arguments
Expand Down Expand Up @@ -143,14 +148,16 @@ var FuncAttributesMap = map[Function]FuncAttributes{
AstName: "Not",
},
FUNC_AND: {
DebugName: "FUNC_AND",
AstName: "And",
DebugName: "FUNC_AND",
AstName: "And",
Commutative: true,
// Boolean AND returns false if any child node evaluates to false
LazyChildEvaluation: shortCircuitIfFalse,
},
FUNC_OR: {
DebugName: "FUNC_OR",
AstName: "Or",
DebugName: "FUNC_OR",
AstName: "Or",
Commutative: true,
// Boolean OR returns true if any child nodes evluates to true
LazyChildEvaluation: shortCircuitIfTrue,
},
Expand All @@ -170,6 +177,7 @@ var FuncAttributesMap = map[Function]FuncAttributes{
FUNC_PAYLOAD: {
DebugName: "FUNC_PAYLOAD",
AstName: "Payload",
Cost: 30,
},
FUNC_DB_ACCESS: AttributeFuncDbAccess.FuncAttributes,
FUNC_CUSTOM_LIST_ACCESS: AttributeFuncCustomListAccess.FuncAttributes,
Expand Down Expand Up @@ -282,6 +290,7 @@ var AttributeFuncDbAccess = struct {
NamedArguments: []string{
"tableName", "fieldName", "path",
},
Cost: 30,
},
ArgumentTableName: "tableName",
ArgumentFieldName: "fieldName",
Expand Down
22 changes: 22 additions & 0 deletions models/ast/ast_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
)

type Node struct {
Index int

// A node is a constant xOR a function
Function Function
Constant any
Expand Down Expand Up @@ -49,3 +51,23 @@ func (node Node) ReadConstantNamedChildString(name string) (string, error) {
}
return value, nil
}

// Cost calculates the weights of an AST subtree to reorder, when the parent is commutative,
// nodes to prioritize faster ones.
func (node Node) Cost() int {
selfCost := 0
childCost := 0

if attrs, err := node.Function.Attributes(); err == nil {
selfCost = attrs.Cost
}

for _, ch := range node.Children {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also gotta do the same on name children (for keeping generality at least, not sure if the case can be seen currently)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't named children already order-agnostic? They're both stored in a map, which does not have a reproducible order, and I assume are used out-of-order?

Will look into it tomorrow.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes they are, however for cost computation the named child nodes should also have their cost taken into account

childCost += ch.Cost()
}
for _, ch := range node.NamedChildren {
childCost += ch.Cost()
}

return selfCost + childCost
}
8 changes: 8 additions & 0 deletions models/ast/ast_node_evaluation.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@ import (
)

type NodeEvaluation struct {
// Index of the initial node winhin its level of the AST tree, used to
// reorder the results as they were. This should become obsolete when each
// node has a unique ID.
Index int
// Skipped indicates whether this node was evaluated at all or not. A `true` values means the
// engine determined the result of this node would not impact the overall decision's outcome.
Skipped bool
Pascal-Delange marked this conversation as resolved.
Show resolved Hide resolved

Function Function
ReturnValue any
Errors []error
Expand Down
26 changes: 26 additions & 0 deletions models/ast/ast_node_weight_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package ast

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestNodeWeights(t *testing.T) {
tts := []struct {
n Node
c int
}{
{Node{Function: FUNC_AND, Children: []Node{{Function: FUNC_DB_ACCESS}, {Function: FUNC_PAYLOAD}}}, 60},
{Node{Function: FUNC_AND, Children: []Node{{Function: FUNC_DB_ACCESS}, {
Function: FUNC_ADD, Children: []Node{{
Function: FUNC_AGGREGATOR,
Children: []Node{{Function: FUNC_CUSTOM_LIST_ACCESS}, {Function: FUNC_PAYLOAD}},
}},
}}}, 140},
}

for _, tt := range tts {
assert.Equal(t, tt.c, tt.n.Cost())
}
}
2 changes: 2 additions & 0 deletions models/ast/node_evaluation_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type NodeEvaluationDto struct {
Errors []EvaluationErrorDto `json:"errors"`
Children []NodeEvaluationDto `json:"children,omitempty"`
NamedChildren map[string]NodeEvaluationDto `json:"named_children,omitempty"`
Skipped bool `json:"skipped"`
}

func AdaptNodeEvaluationDto(evaluation NodeEvaluation) NodeEvaluationDto {
Expand All @@ -29,6 +30,7 @@ func AdaptNodeEvaluationDto(evaluation NodeEvaluation) NodeEvaluationDto {
Errors: pure_utils.Map(evaluation.Errors, AdaptEvaluationErrorDto),
Children: pure_utils.Map(evaluation.Children, AdaptNodeEvaluationDto),
NamedChildren: pure_utils.MapValues(evaluation.NamedChildren, AdaptNodeEvaluationDto),
Skipped: evaluation.Skipped,
}
}

Expand Down
6 changes: 5 additions & 1 deletion usecases/ast_eval/evaluate_ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node
// Early exit for constant, because it should have no children.
if node.Function == ast.FUNC_CONSTANT {
return ast.NodeEvaluation{
Index: node.Index,
Function: node.Function,
ReturnValue: node.Constant,
Errors: []error{},
Expand All @@ -37,10 +38,13 @@ func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node
return
}

weightedNodes := NewWeightedNodes(environment, node, node.Children)

// eval each child
evaluation := ast.NodeEvaluation{
Index: node.Index,
Function: node.Function,
Children: pure_utils.MapWhile(node.Children, evalChild),
Children: weightedNodes.Reorder(pure_utils.MapWhile(weightedNodes.Sorted(), evalChild)),
NamedChildren: pure_utils.MapValuesWhile(node.NamedChildren, evalChild),
}

Expand Down
38 changes: 35 additions & 3 deletions usecases/ast_eval/evaluate_ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/checkmarble/marble-backend/models/ast"
"github.com/checkmarble/marble-backend/usecases/ast_eval/evaluate"
"github.com/checkmarble/marble-backend/utils"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -94,7 +95,7 @@ func NewAstOrFalse() ast.Node {
}

func TestLazyAnd(t *testing.T) {
environment := NewAstEvaluationEnvironment()
environment := NewAstEvaluationEnvironment().WithoutCostOptimizations()

for _, value := range []bool{true, false} {
root := ast.Node{Function: ast.FUNC_AND}.
Expand All @@ -117,7 +118,7 @@ func TestLazyAnd(t *testing.T) {
}

func TestLazyOr(t *testing.T) {
environment := NewAstEvaluationEnvironment()
environment := NewAstEvaluationEnvironment().WithoutCostOptimizations()

for _, value := range []bool{true, false} {
root := ast.Node{Function: ast.FUNC_OR}.
Expand Down Expand Up @@ -154,7 +155,7 @@ func TestLazyBooleanNulls(t *testing.T) {
{ast.FUNC_AND, utils.Ptr(false), nil, utils.Ptr(false)},
}

environment := NewAstEvaluationEnvironment()
environment := NewAstEvaluationEnvironment().WithoutCostOptimizations()

for _, tt := range tts {
root := ast.Node{Function: tt.fn}
Expand All @@ -178,3 +179,34 @@ func TestLazyBooleanNulls(t *testing.T) {
}
}
}

const TEST_FUNC_COSTLY = -10

type costlyNode struct{}

func (costlyNode) Evaluate(ctx context.Context, arguments ast.Arguments) (any, []error) {
return evaluate.MakeEvaluateResult(false)
}

func TestAggregatesOrderedLast(t *testing.T) {
ast.FuncAttributesMap[TEST_FUNC_COSTLY] = ast.FuncAttributes{
Cost: 1000,
}

defer delete(ast.FuncAttributesMap, TEST_FUNC_COSTLY)

environment := NewAstEvaluationEnvironment()
environment.AddEvaluator(TEST_FUNC_COSTLY, costlyNode{})

root := ast.Node{Function: ast.FUNC_OR}.
AddChild(ast.Node{Function: TEST_FUNC_COSTLY}).
AddChild(ast.Node{Constant: true})

evaluation, ok := EvaluateAst(context.TODO(), environment, root)

assert.True(t, ok)
assert.Equal(t, ast.NodeEvaluation{Index: 0, Skipped: true, ReturnValue: nil}, evaluation.Children[0])
assert.Equal(t, false, evaluation.Children[1].Skipped)
assert.Equal(t, true, evaluation.Children[1].ReturnValue)
assert.Equal(t, true, evaluation.ReturnValue)
}
18 changes: 16 additions & 2 deletions usecases/ast_eval/evaluate_environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import (
)

type AstEvaluationEnvironment struct {
availableFunctions map[ast.Function]evaluate.Evaluator
disableCircuitBreaking bool
availableFunctions map[ast.Function]evaluate.Evaluator
disableCostOptimizations bool
disableCircuitBreaking bool
}

func (environment *AstEvaluationEnvironment) AddEvaluator(function ast.Function, evaluator evaluate.Evaluator) {
Expand All @@ -28,12 +29,25 @@ func (environment *AstEvaluationEnvironment) GetEvaluator(function ast.Function)
return nil, errors.New(fmt.Sprintf("function '%s' is not available", function.DebugString()))
}

func (environment AstEvaluationEnvironment) WithoutOptimizations() AstEvaluationEnvironment {
environment.disableCostOptimizations = true
environment.disableCircuitBreaking = true

return environment
}

func (environment AstEvaluationEnvironment) WithoutCircuitBreaking() AstEvaluationEnvironment {
environment.disableCircuitBreaking = true

return environment
}

func (environment AstEvaluationEnvironment) WithoutCostOptimizations() AstEvaluationEnvironment {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checking my understanding here: this is necessary so we don't have to reorder error results that we return to the frontend, correct?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Without* methods are mostly for testing (when you're testing one of the two optimizations, so the other one does not bother the test), and for validation of scenario (so optimizations do not change the given scenario).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that I think about it, was your comment on the right line?

If it was about the whole re-ordering / de-re-ordering instead, yes, it is done to ensure the tree results are sent in the order the frontend expects it to be, since it relies on that specific structure to match results with their node.

That would not be necessary if every node had a unique ID, but since that is not the case, we need this.

environment.disableCostOptimizations = true

return environment
}

func NewAstEvaluationEnvironment() AstEvaluationEnvironment {
environment := AstEvaluationEnvironment{
availableFunctions: make(map[ast.Function]evaluate.Evaluator),
Expand Down
73 changes: 73 additions & 0 deletions usecases/ast_eval/weighted_nodes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package ast_eval

import (
"cmp"
"slices"

"github.com/checkmarble/marble-backend/models/ast"
)

// Weighted nodes manages a flat list of nodes and offers an interface to process
// them sorted by node cost. A lower-cost node will be executed earlier when the
// parent is commutative.
//
// The parent is passed to the constructor, so that if it is not commutative, this
// is basically a no-op.
type WeightedNodes struct {
enabled bool
original []ast.Node
}

func NewWeightedNodes(env AstEvaluationEnvironment, parent ast.Node, nodes []ast.Node) WeightedNodes {
enabled := false

if !env.disableCostOptimizations {
if fattrs, err := parent.Function.Attributes(); err == nil {
enabled = fattrs.Commutative
}
}

if enabled {
for idx := range nodes {
nodes[idx].Index = idx
}
}

return WeightedNodes{
enabled: enabled,
original: nodes,
}
}

func (wn WeightedNodes) Sorted() []ast.Node {
if !wn.enabled {
return wn.original
}

return slices.SortedFunc(slices.Values(wn.original), func(lhs, rhs ast.Node) int {
return cmp.Compare(lhs.Cost(), rhs.Cost())
})
}

func (wn WeightedNodes) Reorder(results []ast.NodeEvaluation) []ast.NodeEvaluation {
if !wn.enabled {
return results
}

output := make([]ast.NodeEvaluation, len(wn.original))

for idx := range wn.original {
output[idx] = ast.NodeEvaluation{
Index: idx,
Skipped: true,
ReturnValue: nil,
}
}

for _, result := range results {
output[result.Index] = result
output[result.Index].Skipped = false
}

return output
}
3 changes: 1 addition & 2 deletions usecases/scenarios/scenario_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,7 @@ func (validator *AstValidatorImpl) MakeDryRunEnvironment(ctx context.Context,
ClientObject: clientObject,
DataModel: dataModel,
DatabaseAccessReturnFakeValue: true,
}).
WithoutCircuitBreaking()
}).WithoutOptimizations()

return env, nil
}
6 changes: 3 additions & 3 deletions usecases/scenarios/scenario_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestValidateScenarioIterationImpl_Validate(t *testing.T) {
validator := AstValidatorImpl{
DataModelRepository: mdmr,
AstEvaluationEnvironmentFactory: func(params ast_eval.EvaluationEnvironmentFactoryParams) ast_eval.AstEvaluationEnvironment {
return ast_eval.NewAstEvaluationEnvironment().WithoutCircuitBreaking()
return ast_eval.NewAstEvaluationEnvironment().WithoutOptimizations()
},
ExecutorFactory: executorFactory,
}
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestValidateScenarioIterationImpl_Validate_notBool(t *testing.T) {
validator := AstValidatorImpl{
DataModelRepository: mdmr,
AstEvaluationEnvironmentFactory: func(params ast_eval.EvaluationEnvironmentFactoryParams) ast_eval.AstEvaluationEnvironment {
return ast_eval.NewAstEvaluationEnvironment().WithoutCircuitBreaking()
return ast_eval.NewAstEvaluationEnvironment().WithoutOptimizations()
},
ExecutorFactory: executorFactory,
}
Expand Down Expand Up @@ -283,7 +283,7 @@ func TestValidationShouldBypassCircuitBreaking(t *testing.T) {
validator := AstValidatorImpl{
DataModelRepository: mdmr,
AstEvaluationEnvironmentFactory: func(params ast_eval.EvaluationEnvironmentFactoryParams) ast_eval.AstEvaluationEnvironment {
return ast_eval.NewAstEvaluationEnvironment().WithoutCircuitBreaking()
return ast_eval.NewAstEvaluationEnvironment().WithoutOptimizations()
},
ExecutorFactory: executorFactory,
}
Expand Down
Loading