Skip to content

Commit 94ba4d1

Browse files
vnickolovs1ckFlorentinDadamnsch
committed
Fix AIOOB edge case in SparseLongArray
Cherry-pick #4037 Co-authored-by: Martin Junghanns <[email protected]> Co-authored-by: Florentin Dörre <[email protected]> Co-authored-by: Adam Schill Collberg <[email protected]>
1 parent 1f4016a commit 94ba4d1

File tree

9 files changed

+38
-19
lines changed

9 files changed

+38
-19
lines changed

alpha/alpha-proc/src/main/java/org/neo4j/gds/traverse/TraverseProc.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ protected Traverse build(
9898
// target node given; terminate if target is reached
9999
if (!configuration.targetNodes().isEmpty()) {
100100
List<Long> mappedTargets = configuration.targetNodes().stream()
101-
.map(graph::toMappedNodeId)
101+
.map(graph::safeToMappedNodeId)
102102
.collect(Collectors.toList());
103103
exitFunction = (s, t, w) -> mappedTargets.contains(t) ? Traverse.ExitPredicate.Result.BREAK : Traverse.ExitPredicate.Result.FOLLOW;
104104
aggregatorFunction = (s, t, w) -> .0;
@@ -113,7 +113,7 @@ protected Traverse build(
113113
}
114114

115115
validateStartNode(configuration.startNode(), graph);
116-
configuration.targetNodes().stream().forEach(neoId -> validateEndNode(neoId, graph));
116+
configuration.targetNodes().forEach(neoId -> validateEndNode(neoId, graph));
117117

118118
var mappedStartNodeId = graph.toMappedNodeId(configuration.startNode());
119119

alpha/alpha-proc/src/main/java/org/neo4j/gds/utils/InputNodeValidator.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,22 @@
2020
package org.neo4j.gds.utils;
2121

2222
import org.neo4j.gds.api.Graph;
23+
import org.neo4j.gds.api.IdMapping;
2324

2425
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
2526

2627
public final class InputNodeValidator {
2728

2829
public static void validateStartNode(long nodeId, Graph graph) throws IllegalArgumentException {
29-
validateNodeIsLoaded(nodeId, graph.toMappedNodeId(nodeId), "startNode");
30+
validateNodeIsLoaded(nodeId, graph.safeToMappedNodeId(nodeId), "startNode");
3031
}
3132

3233
public static void validateEndNode(long nodeId, Graph graph) throws IllegalArgumentException {
33-
validateNodeIsLoaded(nodeId, graph.toMappedNodeId(nodeId), "endNode");
34+
validateNodeIsLoaded(nodeId, graph.safeToMappedNodeId(nodeId), "endNode");
3435
}
3536

3637
private static void validateNodeIsLoaded(long nodeId, long mappedId, String nodeDescription) throws IllegalArgumentException {
37-
if (mappedId == -1) {
38+
if (mappedId == IdMapping.NOT_FOUND) {
3839
throw new IllegalArgumentException(formatWithLocale(
3940
"%s with id %d was not loaded",
4041
nodeDescription,

alpha/alpha-proc/src/main/java/org/neo4j/gds/walking/RandomWalkProc.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,15 @@ private IntStream idStream(Object start, Graph graph, int limit) {
154154
return cursor.nodeReference();
155155
});
156156
}
157-
return ids.map(graph::toMappedNodeId).mapToInt(Math::toIntExact).onClose(cursor::close);
157+
return ids.map(graph::safeToMappedNodeId).mapToInt(Math::toIntExact).onClose(cursor::close);
158158
} else if (start instanceof Collection) {
159159
return ((Collection<?>) start)
160160
.stream()
161161
.mapToLong(e -> ((Number) e).longValue())
162-
.map(graph::toMappedNodeId)
162+
.map(graph::safeToMappedNodeId)
163163
.mapToInt(Math::toIntExact);
164164
} else if (start instanceof Number) {
165-
return LongStream.of(((Number) start).longValue()).map(graph::toMappedNodeId).mapToInt(Math::toIntExact);
165+
return LongStream.of(((Number) start).longValue()).map(graph::safeToMappedNodeId).mapToInt(Math::toIntExact);
166166
} else {
167167
if (nodeCount < limit) {
168168
return IntStream.range(0, nodeCount).limit(limit);

core/src/main/java/org/neo4j/gds/api/IdMapping.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,27 @@ public interface IdMapping {
3030
*/
3131
long START_NODE_ID = 0;
3232

33+
/**
34+
* Defines the value for unmapped ids
35+
*/
36+
long NOT_FOUND = -1;
37+
3338
/**
3439
* Map original nodeId to inner nodeId
40+
*
41+
* @param nodeId must be smaller or equal to the id returned by {@link IdMapping#highestNeoId}
3542
*/
3643
long toMappedNodeId(long nodeId);
3744

45+
/**
46+
* Map original nodeId to inner nodeId
47+
*
48+
* Returns org.neo4j.gds.api.IdMapping#NOT_FOUND if the nodeId is not mapped.
49+
*/
50+
default long safeToMappedNodeId(long nodeId) {
51+
return highestNeoId() < nodeId ? NOT_FOUND : toMappedNodeId(nodeId);
52+
}
53+
3854
/**
3955
* Map inner nodeId back to original nodeId
4056
*/

core/src/main/java/org/neo4j/gds/api/NodeMapping.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626

2727
public interface NodeMapping extends IdMapping, NodeIterator, BatchNodeIterable {
2828

29-
long NOT_FOUND = -1;
30-
3129
Set<NodeLabel> nodeLabels(long nodeId);
3230

3331
void forEachNodeLabel(long nodeId, NodeLabelConsumer consumer);

core/src/main/java/org/neo4j/gds/core/utils/paged/HugeSparseLongArray.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
*/
2020
package org.neo4j.gds.core.utils.paged;
2121

22-
import org.neo4j.gds.api.NodeMapping;
22+
import org.neo4j.gds.api.IdMapping;
2323
import org.neo4j.gds.core.utils.mem.AllocationTracker;
2424
import org.neo4j.gds.core.utils.mem.MemoryRange;
2525
import org.neo4j.gds.core.utils.mem.MemoryUsage;
@@ -35,7 +35,7 @@
3535

3636
public final class HugeSparseLongArray {
3737

38-
private static final long NOT_FOUND = NodeMapping.NOT_FOUND;
38+
private static final long NOT_FOUND = IdMapping.NOT_FOUND;
3939

4040
private static final int PAGE_SHIFT = 12;
4141
private static final int PAGE_SIZE = 1 << PAGE_SHIFT;

core/src/main/java/org/neo4j/gds/core/utils/paged/SparseLongArray.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import com.carrotsearch.hppc.sorting.IndirectSort;
2323
import org.jetbrains.annotations.TestOnly;
24-
import org.neo4j.gds.api.NodeMapping;
24+
import org.neo4j.gds.api.IdMapping;
2525
import org.neo4j.gds.core.utils.ArrayLayout;
2626
import org.neo4j.gds.core.utils.AscendingLongComparator;
2727
import org.neo4j.gds.core.utils.BitUtil;
@@ -37,7 +37,7 @@
3737

3838
public final class SparseLongArray {
3939

40-
public static final long NOT_FOUND = NodeMapping.NOT_FOUND;
40+
public static final long NOT_FOUND = IdMapping.NOT_FOUND;
4141

4242
public static final int BLOCK_SIZE = 64;
4343
public static final int SUPER_BLOCK_SIZE = BLOCK_SIZE * Long.SIZE;
@@ -125,6 +125,10 @@ public long highestNeoId() {
125125
return highestNeoId;
126126
}
127127

128+
/**
129+
*
130+
* @param originalId must be smaller or equal to highestNeoId
131+
*/
128132
public long toMappedNodeId(long originalId) {
129133
var page = pageId(originalId);
130134
var indexInPage = indexInPage(originalId);

proc/common/src/main/java/org/neo4j/gds/GraphStoreValidation.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package org.neo4j.gds;
2121

2222
import org.neo4j.gds.api.GraphStore;
23-
import org.neo4j.gds.api.NodeMapping;
23+
import org.neo4j.gds.api.IdMapping;
2424
import org.neo4j.gds.config.AlgoBaseConfig;
2525
import org.neo4j.gds.config.ConfigurableSeedConfig;
2626
import org.neo4j.gds.config.FeaturePropertiesConfig;
@@ -185,7 +185,7 @@ private static void validateFeaturesProperties(GraphStore graphStore, FeaturePro
185185
private static void validateSourceNode(GraphStore graphStore, SourceNodeConfig config) {
186186
var sourceNodeId = config.sourceNode();
187187

188-
if (graphStore.nodes().toMappedNodeId(sourceNodeId) == NodeMapping.NOT_FOUND) {
188+
if (graphStore.nodes().safeToMappedNodeId(sourceNodeId) == IdMapping.NOT_FOUND) {
189189
throw new IllegalArgumentException(formatWithLocale(
190190
"Source node does not exist in the in-memory graph: `%d`",
191191
sourceNodeId
@@ -196,7 +196,7 @@ private static void validateSourceNode(GraphStore graphStore, SourceNodeConfig c
196196
private static void validateSourceNodes(GraphStore graphStore, SourceNodesConfig config) {
197197
var nodeMapping = graphStore.nodes();
198198
var missingNodes = config.sourceNodes().stream()
199-
.filter(nodeId -> nodeMapping.toMappedNodeId(nodeId) == NodeMapping.NOT_FOUND)
199+
.filter(nodeId -> nodeMapping.safeToMappedNodeId(nodeId) == IdMapping.NOT_FOUND)
200200
.map(Object::toString)
201201
.collect(Collectors.toList());
202202

@@ -211,7 +211,7 @@ private static void validateSourceNodes(GraphStore graphStore, SourceNodesConfig
211211
private static void validateTargetNode(GraphStore graphStore, TargetNodeConfig config) {
212212
var targetNodeId = config.targetNode();
213213

214-
if (graphStore.nodes().toMappedNodeId(targetNodeId) == NodeMapping.NOT_FOUND) {
214+
if (graphStore.nodes().safeToMappedNodeId(targetNodeId) == IdMapping.NOT_FOUND) {
215215
throw new IllegalArgumentException(formatWithLocale(
216216
"Target node does not exist in the in-memory graph: `%d`",
217217
targetNodeId

proc/common/src/main/java/org/neo4j/gds/functions/NodePropertyFunc.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public Object nodeProperty(
8686
}
8787
}
8888

89-
long internalId = graphStore.nodes().toMappedNodeId(nodeId.longValue());
89+
long internalId = graphStore.nodes().safeToMappedNodeId(nodeId.longValue());
9090

9191
if (internalId == -1) {
9292
throw new IllegalArgumentException(formatWithLocale("Node id %d does not exist.", nodeId.longValue()));

0 commit comments

Comments
 (0)