From 590865dcfa925ae45cb4dce2856cf05bd9849734 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 27 Sep 2024 15:42:39 -0700 Subject: [PATCH] fix test_nearest_neighbors_rbc test for haversine distance --- python/cuml/cuml/tests/test_nearest_neighbors.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/cuml/cuml/tests/test_nearest_neighbors.py b/python/cuml/cuml/tests/test_nearest_neighbors.py index 9f5764a7e9..aa612b7763 100644 --- a/python/cuml/cuml/tests/test_nearest_neighbors.py +++ b/python/cuml/cuml/tests/test_nearest_neighbors.py @@ -573,12 +573,14 @@ def test_nearest_neighbors_rbc(distance_dims, n_neighbors, nrows): X[:query_rows, :], n_neighbors=n_neighbors ) - assert len(brute_d[brute_d != rbc_d]) == 0 + cp.testing.assert_allclose(brute_d, rbc_d, atol=1e-3, rtol=1e-3) # All the distances match so allow a couple mismatched indices # through from potential non-determinism in exact matching # distances - assert len(brute_i[brute_i != rbc_i]) <= 3 + assert ( + len(brute_i[brute_i != rbc_i]) <= 3 if distance != "haversine" else 10 + ) @pytest.mark.parametrize("metric", valid_metrics_sparse())