Skip to content

Commit a648e4d

Browse files
committed
FastRP stream proc returns doubles
1 parent 724c44d commit a648e4d

File tree

3 files changed

+23
-30
lines changed

3 files changed

+23
-30
lines changed

doc/asciidoc/algorithms/fastrp/fastrp.adoc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -392,13 +392,13 @@ YIELD nodeId, embedding
392392
.Results
393393
|===
394394
| nodeId | embedding
395-
| 0 | [0.47740015, -0.66024077, -0.36686954, -1.7089111]
396-
| 1 | [0.798936, -0.49187183, -0.41281945, -1.6314402]
397-
| 2 | [0.47275326, -0.4958715, -0.33404678, -1.7141895]
398-
| 3 | [0.8290714, -0.32604763, -0.3317275, -1.437053]
399-
| 4 | [0.7749264, -0.47732472, 0.067513466, -1.5248264]
400-
| 5 | [0.8408374, -0.37151477, 0.12121138, -1.5309601]
401-
| 6 | [1.0, -0.11054421, -0.36979336, -0.92251444]
395+
| 0 | [0.47740015387535095, -0.6602407693862915, -0.3668695390224457, -1.7089110612869263]
396+
| 1 | [0.7989360094070435, -0.49187183380126953, -0.41281944513320923, -1.6314401626586914]
397+
| 2 | [0.47275325655937195, -0.49587151408195496, -0.33404678106307983, -1.7141895294189453]
398+
| 3 | [0.8290714025497437, -0.3260476291179657, -0.3317275047302246, -1.4370529651641846]
399+
| 4 | [0.7749264240264893, -0.4773247241973877, 0.06751346588134766, -1.5248264074325562]
400+
| 5 | [0.8408374190330505, -0.37151476740837097, 0.12121137976646423, -1.5309600830078125]
401+
| 6 | [1.0, -0.11054421216249466, -0.3697933554649353, -0.9225144386291504]
402402
|===
403403
--
404404

@@ -532,13 +532,13 @@ YIELD nodeId, embedding
532532
.Results
533533
|===
534534
| nodeId | embedding
535-
| 0 | [0.10945529, -0.5032674, 0.46467367, -1.7539861]
536-
| 1 | [0.3639601, -0.39210302, 0.4627158, -1.8294234]
537-
| 2 | [0.123140976, -0.3213111, 0.40100977, -1.4710553]
538-
| 3 | [0.30704635, -0.24944797, 0.39478925, -1.34637]
539-
| 4 | [0.23112302, -0.30148715, 0.5848317, -1.2822187]
540-
| 5 | [0.14497177, -0.23121375, 0.5552002, -1.2605634]
541-
| 6 | [0.51391846, -0.079543315, 0.3690345, -0.91763735]
535+
| 0 | [0.10945528745651245, -0.5032674074172974, 0.4646736681461334, -1.753986120223999]
536+
| 1 | [0.3639600872993469, -0.39210301637649536, 0.4627158045768738, -1.829423427581787]
537+
| 2 | [0.12314097583293915, -0.3213110864162445, 0.40100976824760437, -1.471055269241333]
538+
| 3 | [0.3070463538169861, -0.24944797158241272, 0.3947892487049103, -1.346369981765747]
539+
| 4 | [0.23112301528453827, -0.30148714780807495, 0.584831714630127, -1.2822186946868896]
540+
| 5 | [0.14497177302837372, -0.2312137484550476, 0.5552002191543579, -1.2605633735656738]
541+
| 6 | [0.5139184594154358, -0.07954331487417221, 0.3690344989299774, -0.9176373481750488]
542542
|===
543543
--
544544

proc/embeddings/src/main/java/org/neo4j/gds/embeddings/fastrp/FastRPStreamProc.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,19 +92,14 @@ protected AlgorithmFactory<FastRP, FastRPStreamConfig> algorithmFactory() {
9292
@SuppressWarnings("unused")
9393
public static final class StreamResult {
9494
public final long nodeId;
95-
public final List<Number> embedding;
95+
public final List<Double> embedding;
9696

9797
StreamResult(long nodeId, float[] embedding) {
9898
this.nodeId = nodeId;
99-
this.embedding = arrayToList(embedding);
100-
}
101-
102-
static List<Number> arrayToList(float[] values) {
103-
var floats = new ArrayList<Number>(values.length);
104-
for (float value : values) {
105-
floats.add(value);
99+
this.embedding = new ArrayList<>(embedding.length);
100+
for (var f : embedding) {
101+
this.embedding.add((double) f);
106102
}
107-
return floats;
108103
}
109104
}
110105
}

proc/embeddings/src/test/java/org/neo4j/gds/embeddings/fastrp/FastRPStreamProcTest.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ void shouldComputeNonZeroEmbeddings(List<Float> weights) {
6767
String query = queryBuilder.yields();
6868

6969
runQueryWithRowConsumer(query, row -> {
70-
List<Float> embeddings = (List<Float>) row.get("embedding");
70+
List<Double> embeddings = (List<Double>) row.get("embedding");
7171
assertEquals(embeddingDimension, embeddings.size());
7272
assertFalse(embeddings.stream().allMatch(value -> value == 0.0));
7373
});
@@ -76,7 +76,7 @@ void shouldComputeNonZeroEmbeddings(List<Float> weights) {
7676
@Test
7777
void shouldComputeNonZeroEmbeddingsWhenFirstWeightIsZero() {
7878
int embeddingDimension = 128;
79-
List<Float> weights = List.of(0.0f, 1.0f, 2.0f, 4.0f);
79+
var weights = List.of(0.0D, 1.0D, 2.0D, 4.0D);
8080
GdsCypher.ParametersBuildStage queryBuilder = GdsCypher.call()
8181
.explicitCreation(FASTRP_GRAPH)
8282
.algo("fastRP")
@@ -88,7 +88,7 @@ void shouldComputeNonZeroEmbeddingsWhenFirstWeightIsZero() {
8888
String query = queryBuilder.yields();
8989

9090
runQueryWithRowConsumer(query, row -> {
91-
List<Float> embeddings = (List<Float>) row.get("embedding");
91+
var embeddings = (List<Double>) row.get("embedding");
9292
assertFalse(embeddings.stream().allMatch(value -> value == 0.0));
9393
});
9494
}
@@ -107,10 +107,8 @@ void shouldComputeWithWeight() {
107107
.addParameter("relationshipWeightProperty", "weight")
108108
.yields();
109109

110-
List<List<Float>> embeddings = new ArrayList<>(3);
111-
runQueryWithRowConsumer(query, row -> {
112-
embeddings.add((List<Float>) row.get("embedding"));
113-
});
110+
List<List<Double>> embeddings = new ArrayList<>(3);
111+
runQueryWithRowConsumer(query, row -> embeddings.add((List<Double>) row.get("embedding")));
114112

115113
for (int i = 0; i < 128; i++) {
116114
assertEquals(embeddings.get(1).get(i), embeddings.get(2).get(i) * 2);

0 commit comments

Comments
 (0)