Skip to content

Commit acf8462

Browse files
authored
Merge pull request #4242 from IoannisPanagiotas/bughuntingknn
Examining the fastrp+knn bug on filtered graphs.
2 parents 56be31e + e4b80c9 commit acf8462

File tree

9 files changed

+161
-5
lines changed

9 files changed

+161
-5
lines changed

algo/src/main/java/org/neo4j/gds/similarity/SimilarityGraphBuilder.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,19 @@ public static MemoryEstimation memoryEstimation(int topK, int topN) {
7171
}
7272

7373
private final NodeMapping nodeMapping;
74+
private final NodeMapping rootNodeMapping;
7475
private final int concurrency;
7576
private final ExecutorService executorService;
7677
private final AllocationTracker allocationTracker;
7778

7879
public SimilarityGraphBuilder(
7980
NodeMapping nodeMapping,
81+
NodeMapping rootNodeMapping,
8082
int concurrency,
8183
ExecutorService executorService,
8284
AllocationTracker allocationTracker
8385
) {
86+
this.rootNodeMapping = rootNodeMapping;
8487
this.concurrency = concurrency;
8588
this.executorService = executorService;
8689
this.allocationTracker = allocationTracker;
@@ -100,7 +103,7 @@ public Graph build(Stream<SimilarityResult> stream) {
100103
ParallelUtil.parallelStreamConsume(stream, concurrency, relationshipsBuilder::addFromInternal);
101104

102105
return GraphFactory.create(
103-
nodeMapping,
106+
rootNodeMapping,
104107
relationshipsBuilder.build(),
105108
allocationTracker
106109
);

algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarity.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ public SimilarityGraphResult computeToGraph() {
142142
} else {
143143
Stream<SimilarityResult> similarities = computeToStream();
144144
similarityGraph = new SimilarityGraphBuilder(
145+
graph,
145146
graph,
146147
config.concurrency(),
147148
executorService,

algo/src/test/java/org/neo4j/gds/similarity/nodesim/SimilarityGraphBuilderTest.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
package org.neo4j.gds.similarity.nodesim;
2121

2222
import org.junit.jupiter.api.Test;
23-
import org.neo4j.gds.similarity.SimilarityGraphBuilder;
24-
import org.neo4j.gds.similarity.SimilarityResult;
2523
import org.neo4j.gds.api.Graph;
2624
import org.neo4j.gds.core.concurrency.Pools;
2725
import org.neo4j.gds.core.huge.HugeGraph;
@@ -31,6 +29,8 @@
3129
import org.neo4j.gds.extension.GdlGraph;
3230
import org.neo4j.gds.extension.Inject;
3331
import org.neo4j.gds.extension.TestGraph;
32+
import org.neo4j.gds.similarity.SimilarityGraphBuilder;
33+
import org.neo4j.gds.similarity.SimilarityResult;
3434

3535
import java.util.stream.Stream;
3636

@@ -68,6 +68,7 @@ void testConstructionFromHugeGraph() {
6868
assertEquals(HugeGraph.class, unlabelledGraph.innerGraph().getClass());
6969

7070
SimilarityGraphBuilder similarityGraphBuilder = new SimilarityGraphBuilder(
71+
unlabelledGraph,
7172
unlabelledGraph,
7273
1,
7374
Pools.DEFAULT,
@@ -91,6 +92,7 @@ void testConstructionFromUnionGraph() {
9192
assertEquals(UnionGraph.class, graph.innerGraph().getClass());
9293

9394
SimilarityGraphBuilder similarityGraphBuilder = new SimilarityGraphBuilder(
95+
graph,
9496
graph,
9597
1,
9698
Pools.DEFAULT,

proc/similarity/src/main/java/org/neo4j/gds/similarity/knn/KnnMutateProc.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ protected Stream<SimilarityMutateResult> mutate(ComputationResult<Knn, Knn.Resul
118118
try (ProgressTimer ignored = ProgressTimer.start(mutateMillis::addAndGet)) {
119119
similarityGraphResult = computeToGraph(
120120
computationResult.graph(),
121+
computationResult.graphStore(),
121122
algorithm.nodeCount(),
122123
config.concurrency(),
123124
Objects.requireNonNull(computationResult.result()),

proc/similarity/src/main/java/org/neo4j/gds/similarity/knn/KnnStatsProc.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ public Stream<SimilarityStatsResult> stats(AlgoBaseProc.ComputationResult<Knn, K
115115
try (ProgressTimer ignored = resultBuilder.timePostProcessing()) {
116116
SimilarityGraphResult similarityGraphResult = computeToGraph(
117117
computationResult.graph(),
118+
computationResult.graphStore(),
118119
algorithm.nodeCount(),
119120
config.concurrency(),
120121
result,

proc/similarity/src/main/java/org/neo4j/gds/similarity/knn/KnnWriteProc.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.neo4j.gds.AlgorithmFactory;
2323
import org.neo4j.gds.api.Graph;
24+
import org.neo4j.gds.api.GraphStore;
2425
import org.neo4j.gds.config.GraphCreateConfig;
2526
import org.neo4j.gds.core.CypherMapWrapper;
2627
import org.neo4j.gds.results.MemoryEstimateResult;
@@ -88,6 +89,7 @@ protected SimilarityGraphResult similarityGraphResult(ComputationResult<Knn, Knn
8889
KnnWriteConfig config = computationResult.config();
8990
return computeToGraph(
9091
computationResult.graph(),
92+
computationResult.graphStore(),
9193
algorithm.nodeCount(),
9294
config.concurrency(),
9395
Objects.requireNonNull(computationResult.result()),
@@ -97,13 +99,15 @@ protected SimilarityGraphResult similarityGraphResult(ComputationResult<Knn, Knn
9799

98100
static SimilarityGraphResult computeToGraph(
99101
Graph graph,
102+
GraphStore graphStore,
100103
long nodeCount,
101104
int concurrency,
102105
Knn.Result result,
103106
KnnContext context
104107
) {
105108
Graph similarityGraph = new SimilarityGraphBuilder(
106109
graph,
110+
graphStore.nodes(),
107111
concurrency,
108112
context.executor(),
109113
context.allocationTracker()

proc/similarity/src/test/java/org/neo4j/gds/similarity/knn/KnnMutateProcTest.java

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,21 @@
2828
import org.neo4j.gds.Orientation;
2929
import org.neo4j.gds.StoreLoaderBuilder;
3030
import org.neo4j.gds.api.DefaultValue;
31+
import org.neo4j.gds.api.Graph;
3132
import org.neo4j.gds.api.nodeproperties.ValueType;
3233
import org.neo4j.gds.core.CypherMapWrapper;
34+
import org.neo4j.gds.core.loading.GraphStoreCatalog;
3335

36+
import java.util.List;
3437
import java.util.Map;
3538
import java.util.Optional;
3639

3740
import static org.hamcrest.CoreMatchers.equalTo;
3841
import static org.hamcrest.MatcherAssert.assertThat;
3942
import static org.hamcrest.Matchers.lessThan;
4043
import static org.junit.jupiter.api.Assertions.assertEquals;
44+
import static org.neo4j.gds.TestSupport.assertGraphEquals;
45+
import static org.neo4j.gds.TestSupport.fromGdl;
4146

4247
class KnnMutateProcTest extends KnnProcTest<KnnMutateConfig>
4348
implements MutateRelationshipWithPropertyTest<Knn, KnnMutateConfig, Knn.Result> {
@@ -149,7 +154,7 @@ void shouldMutateResults() {
149154
@Override
150155
@Test
151156
@Disabled("This test does not work for KNN")
152-
public void testGraphMutationOnFilteredGraph() { }
157+
public void testGraphMutationOnFilteredGraph() {}
153158

154159
@Test
155160
void shouldMutateUniqueRelationships() {
@@ -200,4 +205,50 @@ public void setupStoreLoader(StoreLoaderBuilder storeLoaderBuilder, Map<String,
200205
);
201206
}
202207
}
208+
209+
@Test
210+
void shouldMutateWithFilteredNodes() {
211+
String nodeCreateQuery =
212+
"CREATE " +
213+
" (alice:Person {age: 24})" +
214+
" ,(carol:Person {age: 24})" +
215+
" ,(eve:Person {age: 67})" +
216+
" ,(dave:Foo {age: 48})" +
217+
" ,(bob:Foo {age: 48})";
218+
219+
runQuery(nodeCreateQuery);
220+
221+
String createQuery = GdsCypher.call()
222+
.withNodeLabel("Person")
223+
.withNodeLabel("Foo")
224+
.withNodeProperty("age")
225+
.withAnyRelationshipType()
226+
.graphCreate("graph")
227+
.yields();
228+
runQuery(createQuery);
229+
230+
String relationshipType = "SIMILAR";
231+
String relationshipProperty = "score";
232+
233+
String algoQuery = GdsCypher.call()
234+
.explicitCreation("graph")
235+
.algo("gds.beta.knn")
236+
.mutateMode()
237+
.addParameter("nodeLabels", List.of("Foo"))
238+
.addParameter("nodeWeightProperty", "age")
239+
.addParameter("mutateRelationshipType", relationshipType)
240+
.addParameter("mutateProperty", relationshipProperty).yields();
241+
runQuery(algoQuery);
242+
243+
Graph mutatedGraph = GraphStoreCatalog.get(getUsername(), db.databaseId(), "graph").graphStore().getUnion();
244+
245+
assertGraphEquals(
246+
fromGdl(
247+
nodeCreateQuery +
248+
"(dave)-[{score: 1.0}]->(bob)" +
249+
"(bob)-[{score: 1.0}]->(dave)"
250+
),
251+
mutatedGraph
252+
);
253+
}
203254
}

proc/similarity/src/test/java/org/neo4j/gds/similarity/knn/KnnStreamProcTest.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.junit.jupiter.api.Test;
2323
import org.neo4j.gds.AlgoBaseProc;
24+
import org.neo4j.gds.GdsCypher;
2425
import org.neo4j.gds.core.CypherMapWrapper;
2526
import org.neo4j.gds.similarity.SimilarityResult;
2627

@@ -62,4 +63,38 @@ void shouldStreamResults() {
6263
Map.of("node1", 2L, "node2", 1L, "similarity", 0.25)
6364
));
6465
}
66+
67+
@Test
68+
void shouldStreamWithFilteredNodes() {
69+
String nodeCreateQuery =
70+
"CREATE " +
71+
" (alice:Person {age: 24})" +
72+
" ,(carol:Person {age: 24})" +
73+
" ,(eve:Person {age: 67})" +
74+
" ,(dave:Foo {age: 48})" +
75+
" ,(bob:Foo {age: 48})";
76+
77+
runQuery(nodeCreateQuery);
78+
79+
String createQuery = GdsCypher.call()
80+
.withNodeLabel("Person")
81+
.withNodeLabel("Foo")
82+
.withNodeProperty("age")
83+
.withAnyRelationshipType()
84+
.graphCreate("graph")
85+
.yields();
86+
runQuery(createQuery);
87+
88+
String algoQuery = GdsCypher.call()
89+
.explicitCreation("graph")
90+
.algo("gds.beta.knn")
91+
.streamMode()
92+
.addParameter("nodeLabels", List.of("Foo"))
93+
.addParameter("nodeWeightProperty", "age")
94+
.yields("node1", "node2", "similarity");
95+
assertCypherResult(algoQuery, List.of(
96+
Map.of("node1", 6L, "node2", 7L, "similarity", 1.0),
97+
Map.of("node1", 7L, "node2", 6L, "similarity", 1.0)
98+
));
99+
}
65100
}

proc/similarity/src/test/java/org/neo4j/gds/similarity/knn/KnnWriteProcTest.java

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
import org.neo4j.gds.StoreLoaderBuilder;
3131
import org.neo4j.gds.WriteRelationshipWithPropertyTest;
3232
import org.neo4j.gds.api.DefaultValue;
33+
import org.neo4j.gds.api.Graph;
34+
import org.neo4j.gds.core.Aggregation;
3335
import org.neo4j.gds.core.CypherMapWrapper;
3436
import org.neo4j.gds.core.loading.GraphStoreCatalog;
3537
import org.neo4j.gds.core.utils.progress.GlobalTaskStore;
@@ -38,6 +40,7 @@
3840
import org.neo4j.gds.test.config.ConcurrencyConfigProcTest;
3941

4042
import java.util.Collection;
43+
import java.util.List;
4144
import java.util.Map;
4245
import java.util.Optional;
4346
import java.util.stream.Stream;
@@ -202,7 +205,10 @@ void testProgressTracking() {
202205

203206
var taskStore = new GlobalTaskStore();
204207

205-
pathProc.taskRegistryFactory = () -> new NonReleasingTaskRegistry(new TaskRegistry(getUsername(), taskStore));
208+
pathProc.taskRegistryFactory = () -> new NonReleasingTaskRegistry(new TaskRegistry(
209+
getUsername(),
210+
taskStore
211+
));
206212

207213
pathProc.write("undirectedGraph", createMinimalConfig(CypherMapWrapper.empty()).toMap());
208214

@@ -213,6 +219,58 @@ void testProgressTracking() {
213219
});
214220
}
215221

222+
@Test
223+
void shouldWriteWithFilteredNodes() {
224+
runQuery("CREATE (alice:Person {name: 'Alice', age: 24})" +
225+
"CREATE (carol:Person {name: 'Carol', age: 24})" +
226+
"CREATE (eve:Person {name: 'Eve', age: 67})" +
227+
"CREATE (dave:Foo {name: 'Dave', age: 48})" +
228+
"CREATE (bob:Foo {name: 'Bob', age: 48})");
229+
230+
String createQuery = GdsCypher.call()
231+
.withNodeLabel("Person")
232+
.withNodeLabel("Foo")
233+
.withNodeProperty("age")
234+
.withAnyRelationshipType()
235+
.graphCreate("graph")
236+
.yields();
237+
runQuery(createQuery);
238+
239+
String relationshipType = "SIMILAR";
240+
String relationshipProperty = "score";
241+
242+
String algoQuery = GdsCypher.call()
243+
.explicitCreation("graph")
244+
.algo("gds.beta.knn")
245+
.writeMode()
246+
.addParameter("nodeLabels", List.of("Foo"))
247+
.addParameter("nodeWeightProperty", "age")
248+
.addParameter("writeRelationshipType", relationshipType)
249+
.addParameter("writeProperty", relationshipProperty).yields();
250+
runQuery(algoQuery);
251+
252+
Graph knnGraph = new StoreLoaderBuilder()
253+
.api(db)
254+
.addNodeLabel("Person")
255+
.addNodeLabel("Foo")
256+
.addRelationshipType(relationshipType)
257+
.addRelationshipProperty(relationshipProperty, relationshipProperty, DefaultValue.DEFAULT, Aggregation.NONE)
258+
.build()
259+
.graph();
260+
261+
assertGraphEquals(
262+
fromGdl("(alice:Person)" +
263+
"(carol:Person)" +
264+
"(eve:Person)" +
265+
"(dave:Foo)" +
266+
"(bob:Foo)" +
267+
"(dave)-[{score: 1.0}]->(bob)" +
268+
"(bob)-[{score: 1.0}]->(dave)"
269+
),
270+
knnGraph
271+
);
272+
}
273+
216274
@Override
217275
public String writeRelationshipType() {
218276
return "KNN_REL";

0 commit comments

Comments
 (0)