Skip to content

Commit 47e7545

Browse files
jjaderbergDarthMaxknutwalker
committed
Extract constants and handle zero union
Co-Authored-By: Max Kießling <[email protected]> Co-Authored-By: Paul Horn <[email protected]>
1 parent bc05d0c commit 47e7545

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

alpha/alpha-proc/src/main/java/org/neo4j/graphalgo/similarity/SimilaritiesFunc.java

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,21 @@
3131
import org.neo4j.procedure.UserFunction;
3232
import org.neo4j.values.storable.Values;
3333

34-
import java.util.Collections;
3534
import java.util.Comparator;
3635
import java.util.HashSet;
3736
import java.util.List;
3837
import java.util.Map;
38+
import java.util.function.Predicate;
3939

4040
import static org.neo4j.graphalgo.impl.similarity.SimilarityVectorAggregator.CATEGORY_KEY;
4141
import static org.neo4j.graphalgo.impl.similarity.SimilarityVectorAggregator.WEIGHT_KEY;
4242
import static org.neo4j.graphalgo.impl.utils.NumberUtils.getDoubleValue;
4343

4444
public class SimilaritiesFunc {
4545

46+
public static final Predicate<Number> IS_NULL = Predicate.isEqual(null);
47+
public static final Comparator<Number> NUMBER_COMPARATOR = new NumberComparator();
48+
4649
@UserFunction("gds.alpha.similarity.jaccard")
4750
@Description("RETURN gds.alpha.similarity.jaccard(vector1, vector2) - Given two collection vectors, calculate Jaccard similarity")
4851
public double jaccardSimilarity(@Name("vector1") List<Number> vector1, @Name("vector2") List<Number> vector2) {
@@ -190,12 +193,10 @@ public double overlapSimilarity(@Name("vector1") List<Number> vector1, @Name("ve
190193
* @return The jaccard score, the intersection divided by the union of the input lists.
191194
*/
192195
private double jaccard(List<Number> vector1, List<Number> vector2) {
193-
Comparator<Number> numberComparator = new NumberComparator();
194-
List<Number> nullList = Collections.singletonList(null);
195-
vector1.removeAll(nullList);
196-
vector2.removeAll(nullList);
197-
vector1.sort(numberComparator);
198-
vector2.sort(numberComparator);
196+
vector1.removeIf(IS_NULL);
197+
vector2.removeIf(IS_NULL);
198+
vector1.sort(NUMBER_COMPARATOR);
199+
vector2.sort(NUMBER_COMPARATOR);
199200

200201
int index1 = 0;
201202
int index2 = 0;
@@ -206,7 +207,7 @@ private double jaccard(List<Number> vector1, List<Number> vector2) {
206207
while (index1 < vector1.size() && index2 < vector2.size()) {
207208
Number val1 = vector1.get(index1);
208209
Number val2 = vector2.get(index2);
209-
int compare = numberComparator.compare(val1, val2);
210+
int compare = NUMBER_COMPARATOR.compare(val1, val2);
210211

211212
if (compare == 0) {
212213
intersection++;
@@ -225,7 +226,7 @@ private double jaccard(List<Number> vector1, List<Number> vector2) {
225226
// the remainder, if any, is never shared so add to the union
226227
union += (vector1.size() - index1) + (vector2.size() - index2);
227228

228-
return intersection / union;
229+
return union == 0 ? 1 : intersection / union;
229230
}
230231

231232
static class NumberComparator implements Comparator<Number> {

alpha/alpha-proc/src/test/java/org/neo4j/graphalgo/similarity/SimilaritiesFuncTest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,14 @@ static Stream<Arguments> listsWithDuplicates() {
283283
new ArrayList<Number>(Arrays.asList(1, 2, 2)),
284284
new ArrayList<Number>(Arrays.asList(2, 2, 3)),
285285
2/4D
286+
), Arguments.of(
287+
new ArrayList<Number>(Arrays.asList(null, 2, 2)),
288+
new ArrayList<Number>(Arrays.asList(2, 2, null, null)),
289+
1D
290+
), Arguments.of(
291+
new ArrayList<Number>(),
292+
new ArrayList<Number>(),
293+
1D
286294
)
287295
);
288296
}

0 commit comments

Comments
 (0)