31
31
import org .neo4j .procedure .UserFunction ;
32
32
import org .neo4j .values .storable .Values ;
33
33
34
- import java .util .Collections ;
35
34
import java .util .Comparator ;
36
35
import java .util .HashSet ;
37
36
import java .util .List ;
38
37
import java .util .Map ;
38
+ import java .util .function .Predicate ;
39
39
40
40
import static org .neo4j .graphalgo .impl .similarity .SimilarityVectorAggregator .CATEGORY_KEY ;
41
41
import static org .neo4j .graphalgo .impl .similarity .SimilarityVectorAggregator .WEIGHT_KEY ;
42
42
import static org .neo4j .graphalgo .impl .utils .NumberUtils .getDoubleValue ;
43
43
44
44
public class SimilaritiesFunc {
45
45
46
+ public static final Predicate <Number > IS_NULL = Predicate .isEqual (null );
47
+ public static final Comparator <Number > NUMBER_COMPARATOR = new NumberComparator ();
48
+
46
49
@ UserFunction ("gds.alpha.similarity.jaccard" )
47
50
@ Description ("RETURN gds.alpha.similarity.jaccard(vector1, vector2) - Given two collection vectors, calculate Jaccard similarity" )
48
51
public double jaccardSimilarity (@ Name ("vector1" ) List <Number > vector1 , @ Name ("vector2" ) List <Number > vector2 ) {
@@ -190,12 +193,10 @@ public double overlapSimilarity(@Name("vector1") List<Number> vector1, @Name("ve
190
193
* @return The jaccard score, the intersection divided by the union of the input lists, or 1 when the union is empty.
191
194
*/
192
195
private double jaccard (List <Number > vector1 , List <Number > vector2 ) {
193
- Comparator <Number > numberComparator = new NumberComparator ();
194
- List <Number > nullList = Collections .singletonList (null );
195
- vector1 .removeAll (nullList );
196
- vector2 .removeAll (nullList );
197
- vector1 .sort (numberComparator );
198
- vector2 .sort (numberComparator );
196
+ vector1 .removeIf (IS_NULL );
197
+ vector2 .removeIf (IS_NULL );
198
+ vector1 .sort (NUMBER_COMPARATOR );
199
+ vector2 .sort (NUMBER_COMPARATOR );
199
200
200
201
int index1 = 0 ;
201
202
int index2 = 0 ;
@@ -206,7 +207,7 @@ private double jaccard(List<Number> vector1, List<Number> vector2) {
206
207
while (index1 < vector1 .size () && index2 < vector2 .size ()) {
207
208
Number val1 = vector1 .get (index1 );
208
209
Number val2 = vector2 .get (index2 );
209
- int compare = numberComparator .compare (val1 , val2 );
210
+ int compare = NUMBER_COMPARATOR .compare (val1 , val2 );
210
211
211
212
if (compare == 0 ) {
212
213
intersection ++;
@@ -225,7 +226,7 @@ private double jaccard(List<Number> vector1, List<Number> vector2) {
225
226
// the remainder, if any, is never shared so add to the union
226
227
union += (vector1 .size () - index1 ) + (vector2 .size () - index2 );
227
228
228
- return intersection / union ;
229
+ return union == 0 ? 1 : intersection / union ;
229
230
}
230
231
231
232
static class NumberComparator implements Comparator <Number > {
0 commit comments