diff --git a/test/test_mnn.py b/test/test_mnn.py index 6d667ab..f1aa8c3 100644 --- a/test/test_mnn.py +++ b/test/test_mnn.py @@ -233,7 +233,7 @@ def test_single_sample_idx(): ) G = build_graph(data, sample_idx=np.repeat(1, len(data))) G2 = build_graph(data) - assert (G.K - G2.K).nnz == 0 + np.testing.assert_array_equal(G.K, G2.K) def test_mnn_with_non_zero_indexed_sample_idx():