Skip to content

Commit f13ea02

Browse files
authored
BED-5464 - Additional CySQL Fixes (#1189)
* fix: BP-1396, BED-5464: support additional query components; rework translate with to support better group by logic * chore: PFC * chore: cleanup
1 parent 6c06125 commit f13ea02

File tree

24 files changed

+684
-442
lines changed

24 files changed

+684
-442
lines changed

packages/go/cypher/models/pgsql/test/translation_cases/multipart.sql

+21-1
Large diffs are not rendered by default.

packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql

+4-4
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,13 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from
169169
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where not exists (select 1 from edge e0 where e0.start_id = (s0.n0).id or e0.end_id = (s0.n0).id);
170170

171171
-- case: match (s) where not (s)-[]->()-[]->() return s
172-
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, edge e0 join node n0 on (s0.n0).id = e0.start_id join node n1 on n1.id = e0.end_id), s2 as (select s1.e0 as e0, (e1.id, e1.start_id, e1.end_id, e1.kind_id, e1.properties)::edgecomposite as e1, s1.n0 as n0, s1.n1 as n1, (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 from s1 join edge e1 on (s1.n1).id = e1.start_id join node n2 on n2.id = e1.end_id) select count(*) > 0 from s2);
172+
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0 join edge e0 on (s0.n0).id = e0.start_id join node n1 on n1.id = e0.end_id), s2 as (select s1.e0 as e0, (e1.id, e1.start_id, e1.end_id, e1.kind_id, e1.properties)::edgecomposite as e1, s1.n0 as n0, s1.n1 as n1, (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 from s1 join edge e1 on (s1.n1).id = e1.start_id join node n2 on n2.id = e1.end_id) select count(*) > 0 from s2);
173173

174174
-- case: match (s) where not (s)-[{prop: 'a'}]-({name: 'n3'}) return s
175-
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on (s0.n0).id = e0.end_id or (s0.n0).id = e0.start_id join node n1 on n1.id = e0.end_id or n1.id = e0.start_id where e0.properties ->> 'prop' = 'a') select count(*) > 0 from s1);
175+
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, edge e0 join node n0 on (s0.n0).id = e0.end_id or (s0.n0).id = e0.start_id join node n1 on n1.id = e0.end_id or n1.id = e0.start_id where e0.properties ->> 'prop' = 'a') select count(*) > 0 from s1);
176176

177177
-- case: match (s) where not (s)<-[{prop: 'a'}]-({name: 'n3'}) return s
178-
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, edge e0 join node n0 on (s0.n0).id = e0.end_id join node n1 on n1.id = e0.start_id where n1.properties ->> 'name' = 'n3' and e0.properties ->> 'prop' = 'a') select count(*) > 0 from s1);
178+
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0 join edge e0 on (s0.n0).id = e0.end_id join node n1 on n1.id = e0.start_id where n1.properties ->> 'name' = 'n3' and e0.properties ->> 'prop' = 'a') select count(*) > 0 from s1);
179179

180180
-- case: match (n:NodeKind1) where n.distinguishedname = toUpper('admin') return n
181181
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.properties ->> 'distinguishedname' = upper('admin')::text and n0.kind_ids operator (pg_catalog.&&) array [1]::int2[]) select s0.n0 as n from s0;
@@ -190,7 +190,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from
190190
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.properties ->> 'distinguishedname' like '%' || upper('admin')::text and n0.kind_ids operator (pg_catalog.&&) array [1]::int2[]) select s0.n0 as n from s0;
191191

192192
-- case: match (s) where not (s)-[{prop: 'a'}]->({name: 'n3'}) return s
193-
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, edge e0 join node n0 on (s0.n0).id = e0.start_id join node n1 on n1.id = e0.end_id where n1.properties ->> 'name' = 'n3' and e0.properties ->> 'prop' = 'a') select count(*) > 0 from s1);
193+
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0 join edge e0 on (s0.n0).id = e0.start_id join node n1 on n1.id = e0.end_id where n1.properties ->> 'name' = 'n3' and e0.properties ->> 'prop' = 'a') select count(*) > 0 from s1);
194194

195195
-- case: match (s) where not (s)-[]-() return id(s)
196196
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select (s0.n0).id from s0 where not exists (select 1 from edge e0 where e0.start_id = (s0.n0).id or e0.end_id = (s0.n0).id);

packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::e
7474
with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where n1.kind_ids operator (pg_catalog.&&) array [2]::int2[] and e0.kind_id = any (array [3, 4]::int2[]) and n0.kind_ids operator (pg_catalog.&&) array [1]::int2[]) select (s0.n0).properties -> 'name', (s0.n1).properties -> 'name' from s0;
7575

7676
-- case: match (s)-[r:EdgeKind1]->() where (s)-[r {prop: 'a'}]->() return s
77-
with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.properties ->> 'prop' = 'a' and e0.kind_id = any (array [3]::int2[])) select s0.n0 as s from s0 where (with s1 as (select s0.e0 as e0, s0.n0 as n0, s0.n1 as n1, (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 from s0, edge e0 join node n0 on (s0.n0).id = (s0.e0).start_id join node n2 on n2.id = (s0.e0).end_id) select count(*) > 0 from s1);
77+
with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.properties ->> 'prop' = 'a' and e0.kind_id = any (array [3]::int2[])) select s0.n0 as s from s0 where (with s1 as (select s0.e0 as e0, s0.n0 as n0, s0.n1 as n1, (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 from s0 join edge e0 on (s0.n0).id = (s0.e0).start_id join node n2 on n2.id = (s0.e0).end_id) select count(*) > 0 from s1);
7878

7979
-- case: match (s)-[r:EdgeKind1]->(e) where not (s.system_tags contains 'admin_tier_0') and id(e) = 1 return id(s), labels(s), id(r), type(r)
8080
with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where n1.id = 1 and e0.kind_id = any (array [3]::int2[]) and not (coalesce(n0.properties ->> 'system_tags', '')::text like '%admin_tier_0%')) select (s0.n0).id, (s0.n0).kind_ids, (s0.e0).id, (s0.e0).kind_id from s0;

packages/go/cypher/models/pgsql/test/translation_test.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,11 @@ func TestTranslate(t *testing.T) {
5555
)
5656

5757
if updateCases, varSet := os.LookupEnv("CYSQL_UPDATE_CASES"); varSet && strings.ToLower(strings.TrimSpace(updateCases)) == "true" {
58-
if err := UpdateTranslationTestCases(kindMapper); err != nil {
58+
if err := UpdateTranslationTestCases(kindMapper); err != nil {
5959
fmt.Printf("Error updating cases: %v\n", err)
6060
}
6161
}
6262

63-
6463
if testCases, err := ReadTranslationTestCases(); err != nil {
6564
t.Fatal(err)
6665
} else {

packages/go/cypher/models/pgsql/translate/building.go

+29-105
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,32 @@ package translate
1818

1919
import (
2020
"errors"
21+
2122
"github.com/specterops/bloodhound/cypher/models/pgsql"
2223
)
2324

2425
func (s *Translator) buildInlineProjection(part *QueryPart) (pgsql.Select, error) {
25-
var sqlSelect pgsql.Select
26-
27-
if part.projections.Frame != nil {
28-
sqlSelect.From = []pgsql.FromClause{{
29-
Source: part.projections.Frame.Binding.Identifier,
30-
}}
26+
sqlSelect := pgsql.Select{
27+
Where: part.projections.Constraints,
3128
}
3229

33-
if projectionConstraint, err := s.treeTranslator.ConsumeAll(); err != nil {
34-
return sqlSelect, err
35-
} else {
36-
sqlSelect.Where = projectionConstraint.Expression
30+
// If there's a projection frame set, some additional negotiation is required to identify which frame the
31+
// from-statement should be written to. Some of this would be better figured out during the translation
32+
// of the projection where query scope and other components are not yet fully translated.
33+
if part.projections.Frame != nil {
34+
// Look up to see if there are CTE expressions registered. If there are then it is likely
35+
// there was a projection between this CTE and the previous multipart query part
36+
hasCTEs := part.Model.CommonTableExpressions != nil && len(part.Model.CommonTableExpressions.Expressions) > 0
37+
38+
if part.Frame.Previous == nil || hasCTEs {
39+
sqlSelect.From = []pgsql.FromClause{{
40+
Source: part.projections.Frame.Binding.Identifier,
41+
}}
42+
} else {
43+
sqlSelect.From = []pgsql.FromClause{{
44+
Source: part.Frame.Previous.Binding.Identifier,
45+
}}
46+
}
3747
}
3848

3949
for _, projection := range part.projections.Items {
@@ -161,7 +171,7 @@ func (s *Translator) translateTraversalPatternPart(part *PatternPart, isolatedPr
161171
}
162172

163173
func (s *Translator) translateTraversalPatternPartWithoutExpansion(isFirstTraversalStep bool, traversalStep *PatternSegment) error {
164-
if constraints, err := s.patternConstraints(isFirstTraversalStep, nonRecursivePattern, traversalStep); err != nil {
174+
if constraints, err := consumePatternConstraints(isFirstTraversalStep, nonRecursivePattern, traversalStep, s.treeTranslator.IdentifierConstraints); err != nil {
165175
return err
166176
} else {
167177
if isFirstTraversalStep {
@@ -237,7 +247,7 @@ func (s *Translator) translateTraversalPatternPartWithoutExpansion(isFirstTraver
237247
} else {
238248
// Zip through all projected identifiers and update their last projected frame
239249
for _, binding := range boundProjections.Bindings {
240-
binding.LastProjection = traversalStep.Frame
250+
binding.MaterializedBy(traversalStep.Frame)
241251
}
242252

243253
traversalStep.Projection = boundProjections.Items
@@ -246,94 +256,8 @@ func (s *Translator) translateTraversalPatternPartWithoutExpansion(isFirstTraver
246256
return nil
247257
}
248258

249-
type PatternConstraints struct {
250-
LeftNode *Constraint
251-
Edge *Constraint
252-
RightNode *Constraint
253-
}
254-
255-
// OptimizePatternConstraintBalance considers the constraints that apply to a pattern segment's bound identifiers.
256-
//
257-
// If only the right side of the pattern segment is constrained, this could result in an imbalanced expansion where one side
258-
// of the traversal has an extreme disparity in search space.
259-
//
260-
// In cases that match this heuristic, it's beneficial to begin the traversal with the most tightly constrained set
261-
// of nodes. To accomplish this we flip the order of the traversal step.
262-
func (s *PatternConstraints) OptimizePatternConstraintBalance(traversalStep *PatternSegment) {
263-
var (
264-
// If the left node is previously bound (query knows a set of IDs) the left node is considered to sill be constrained
265-
leftNodeHasConstraints = traversalStep.LeftNodeBound || s.LeftNode.Expression != nil
266-
rightNodeHasConstraints = s.RightNode.Expression != nil
267-
)
268-
269-
// (a)-[*..]->(b:Constraint)
270-
// (a)<-[*..]-(b:Constraint)
271-
if !leftNodeHasConstraints && rightNodeHasConstraints {
272-
traversalStep.FlipNodes()
273-
s.FlipNodes()
274-
}
275-
}
276-
277-
func (s *PatternConstraints) FlipNodes() {
278-
oldLeftNode := s.LeftNode
279-
s.LeftNode = s.RightNode
280-
s.RightNode = oldLeftNode
281-
}
282-
283-
const (
284-
recursivePattern = true
285-
nonRecursivePattern = false
286-
)
287-
288-
func (s *Translator) patternConstraints(isFirstTraversalStep, isRecursivePattern bool, traversalStep *PatternSegment) (PatternConstraints, error) {
289-
var (
290-
constraints PatternConstraints
291-
err error
292-
)
293-
294-
// Even if this isn't the first traversal and the node may be already bound, this should result in an empty
295-
// constraint instead of a nil value for `leftNode`
296-
if constraints.LeftNode, err = consumeConstraintsFrom(pgsql.AsIdentifierSet(traversalStep.LeftNode.Identifier), s.treeTranslator.IdentifierConstraints); err != nil {
297-
return constraints, err
298-
}
299-
300-
if isFirstTraversalStep {
301-
// If this is the first traversal step then the left node is just coming into scope
302-
traversalStep.Frame.Export(traversalStep.LeftNode.Identifier)
303-
}
304-
305-
// Track the identifiers visible at this frame to correctly assign the remaining constraints
306-
knownBindings := traversalStep.Frame.Known()
307-
308-
if isRecursivePattern {
309-
// The exclusion below is done at this step in the process since the recursive descent portion of an expansion
310-
// will no longer have a reference to the root node; any dependent interaction between the root and terminal
311-
// nodes would require an additional join. By not consuming the remaining constraints for the root and terminal
312-
// nodes, they become visible up in the outer select of the recursive CTE.
313-
knownBindings.Remove(traversalStep.LeftNode.Identifier)
314-
}
315-
316-
// Export the edge identifier first
317-
traversalStep.Frame.Export(traversalStep.Edge.Identifier)
318-
knownBindings.Add(traversalStep.Edge.Identifier)
319-
320-
if constraints.Edge, err = consumeConstraintsFrom(knownBindings, s.treeTranslator.IdentifierConstraints); err != nil {
321-
return constraints, err
322-
}
323-
324-
// Export the right node identifier last
325-
traversalStep.Frame.Export(traversalStep.RightNode.Identifier)
326-
knownBindings.Add(traversalStep.RightNode.Identifier)
327-
328-
if constraints.RightNode, err = consumeConstraintsFrom(knownBindings, s.treeTranslator.IdentifierConstraints); err != nil {
329-
return constraints, err
330-
}
331-
332-
return constraints, nil
333-
}
334-
335259
func (s *Translator) translateTraversalPatternPartWithExpansion(isFirstTraversalStep bool, traversalStep *PatternSegment) error {
336-
if constraints, err := s.patternConstraints(isFirstTraversalStep, recursivePattern, traversalStep); err != nil {
260+
if constraints, err := consumePatternConstraints(isFirstTraversalStep, recursivePattern, traversalStep, s.treeTranslator.IdentifierConstraints); err != nil {
337261
return err
338262
} else {
339263
// If one side of the expansion has constraints but the other does not this may be an opportunity to reorder the traversal
@@ -392,15 +316,15 @@ func (s *Translator) translateTraversalPatternPartWithExpansion(isFirstTraversal
392316
traversalStep.Expansion.Value.RecursiveConstraints = pgsql.OptionalAnd(traversalStep.Expansion.Value.ExpansionEdgeConstraints, expansionConstraints(expansionFrame.Binding.Identifier, traversalStep.Expansion.Value.MinDepth, traversalStep.Expansion.Value.MaxDepth))
393317

394318
// Remove the previous projections of the root and terminal node to reproject them after expansion
395-
traversalStep.LeftNode.LastProjection = nil
396-
traversalStep.RightNode.LastProjection = nil
319+
traversalStep.LeftNode.Dematerialize()
320+
traversalStep.RightNode.Dematerialize()
397321

398322
if boundProjections, err := buildVisibleProjections(s.query.Scope); err != nil {
399323
return err
400324
} else {
401325
// Zip through all projected identifiers and update their last projected frame
402326
for _, binding := range boundProjections.Bindings {
403-
binding.LastProjection = expansionFrame
327+
binding.MaterializedBy(expansionFrame)
404328
}
405329

406330
traversalStep.Expansion.Value.Projection = boundProjections.Items
@@ -416,7 +340,7 @@ func (s *Translator) translateTraversalPatternPartWithExpansion(isFirstTraversal
416340
} else {
417341
// Zip through all projected identifiers and update their last projected frame
418342
for _, binding := range boundProjections.Bindings {
419-
binding.LastProjection = traversalStep.Frame
343+
binding.MaterializedBy(traversalStep.Frame)
420344
}
421345

422346
traversalStep.Projection = boundProjections.Items
@@ -433,7 +357,7 @@ func (s *Translator) translateNonTraversalPatternPart(part *PatternPart) error {
433357

434358
nextFrame.Export(part.NodeSelect.Binding.Identifier)
435359

436-
if constraint, err := consumeConstraintsFrom(nextFrame.Known(), s.treeTranslator.IdentifierConstraints); err != nil {
360+
if constraint, err := s.treeTranslator.IdentifierConstraints.ConsumeSet(nextFrame.Known()); err != nil {
437361
return err
438362
} else if err := RewriteFrameBindings(s.query.Scope, constraint.Expression); err != nil {
439363
return err
@@ -446,7 +370,7 @@ func (s *Translator) translateNonTraversalPatternPart(part *PatternPart) error {
446370
} else {
447371
// Zip through all projected identifiers and update their last projected frame
448372
for _, binding := range boundProjections.Bindings {
449-
binding.LastProjection = nextFrame
373+
binding.MaterializedBy(nextFrame)
450374
}
451375

452376
part.NodeSelect.Select.Projection = boundProjections.Items

0 commit comments

Comments
 (0)