From 7ffa8ef86e64c08bb5a2e14dd0f0e85795b359fd Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Sat, 15 Jun 2024 14:51:48 +0200 Subject: [PATCH] Add tests for quick shift and remove unused functions --- src/skmatter/clustering/_quick_shift.py | 10 +++--- src/skmatter/metrics/_pairwise.py | 46 ++----------------------- tests/test_clustering.py | 25 ++++++++++---- 3 files changed, 26 insertions(+), 55 deletions(-) diff --git a/src/skmatter/clustering/_quick_shift.py b/src/skmatter/clustering/_quick_shift.py index 27d41a122..92fc4eb6e 100644 --- a/src/skmatter/clustering/_quick_shift.py +++ b/src/skmatter/clustering/_quick_shift.py @@ -107,11 +107,11 @@ def __init__( ): if (dist_cutoff_sq is None) and (gabriel_shell is None): raise ValueError("Either dist_cutoff or gabriel_depth must be set.") - self.dist_cutoff2 = dist_cutoff_sq + self.dist_cutoff_sq = dist_cutoff_sq self.gabriel_shell = gabriel_shell self.scale = scale - if self.dist_cutoff2 is not None: - self.dist_cutoff2 *= self.scale**2 + if self.dist_cutoff_sq is not None: + self.dist_cutoff_sq *= self.scale**2 self.metric_params = ( metric_params if metric_params is not None else {"cell_length": None} ) @@ -142,7 +142,7 @@ def fit(self, X, y=None, samples_weight=None): ) dist_matrix = self.metric(X, X) np.fill_diagonal(dist_matrix, np.inf) - if self.dist_cutoff2 is None: + if self.dist_cutoff_sq is None: gabrial = _get_gabriel_graph(dist_matrix) idmindist = np.argmin(dist_matrix, axis=1) idxroot = np.full(dist_matrix.shape[0], -1, dtype=int) @@ -163,7 +163,7 @@ def fit(self, X, y=None, samples_weight=None): idmindist[current], samples_weight, dist_matrix, - self.dist_cutoff2[current], + self.dist_cutoff_sq[current], ) if idxroot[idxroot[current]] != -1: # Found a path to a root diff --git a/src/skmatter/metrics/_pairwise.py b/src/skmatter/metrics/_pairwise.py index 3af73ab23..4455a7b3a 100644 --- a/src/skmatter/metrics/_pairwise.py +++ b/src/skmatter/metrics/_pairwise.py @@ -1,20 +1,14 @@ from typing import Union import numpy as np -from sklearn.metrics.pairwise import ( - _euclidean_distances, - check_array, - check_pairwise_arrays, -) +from sklearn.metrics.pairwise import _euclidean_distances, check_pairwise_arrays def periodic_pairwise_euclidean_distances( X, Y=None, *, - Y_norm_squared=None, squared=False, - X_norm_squared=None, cell_length=None, ): r""" @@ -48,16 +42,6 @@ def periodic_pairwise_euclidean_distances( default=None An array where each row is a sample and each column is a component. If `None`, method uses `Y=X`. - Y_norm_squared : array-like of shape (n_samples_Y,) or (n_samples_Y, 1) \ - or (1, n_samples_Y), default=None - Pre-computed dot-products of vectors in Y (e.g., `(Y**2).sum(axis=1)`) - May be ignored in some cases, see the note below. - squared : bool, default=False - Return squared Euclidean distances. - X_norm_squared : array-like of shape (n_samples_X,) or (n_samples_X, 1) \ - or (1, n_samples_X), default=None - Pre-computed dot-products of vectors in X (e.g., `(X**2).sum(axis=1)`) - May be ignored in some cases, see the note below. cell_length : array-like of shape (n_components,), default=None The side length of rectangular cell used for periodic boundary conditions. `None` for non-periodic boundary conditions. @@ -90,34 +74,8 @@ def periodic_pairwise_euclidean_distances( _check_dimension(X, cell_length) X, Y = check_pairwise_arrays(X, Y) - if X_norm_squared is not None: - X_norm_squared = check_array(X_norm_squared, ensure_2d=False) - original_shape = X_norm_squared.shape - if X_norm_squared.shape == (X.shape[0],): - X_norm_squared = X_norm_squared.reshape(-1, 1) - if X_norm_squared.shape == (1, X.shape[0]): - X_norm_squared = X_norm_squared.T - if X_norm_squared.shape != (X.shape[0], 1): - raise ValueError( - f"Incompatible dimensions for X of shape {X.shape} and " - f"X_norm_squared of shape {original_shape}." - ) - - if Y_norm_squared is not None: - Y_norm_squared = check_array(Y_norm_squared, ensure_2d=False) - original_shape = Y_norm_squared.shape - if Y_norm_squared.shape == (Y.shape[0],): - Y_norm_squared = Y_norm_squared.reshape(1, -1) - if Y_norm_squared.shape == (Y.shape[0], 1): - Y_norm_squared = Y_norm_squared.T - if Y_norm_squared.shape != (1, Y.shape[0]): - raise ValueError( - f"Incompatible dimensions for Y of shape {Y.shape} and " - f"Y_norm_squared of shape {original_shape}." - ) - if cell_length is None: - return _euclidean_distances(X, Y, X_norm_squared, Y_norm_squared, squared) + return _euclidean_distances(X, Y, squared=squared) else: return _periodic_euclidean_distances(X, Y, squared=squared, cell=cell_length) diff --git a/tests/test_clustering.py b/tests/test_clustering.py index 3b8df19ac..ca4340c0d 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -31,15 +31,28 @@ def setUpClass(cls) -> None: -2.61132267, ] ) - cls.labels_ = np.array([0, 0, 0, 5, 5, 5]) - cls.cluster_centers_idx_ = np.array([0, 5]) + cls.qs_labels_ = np.array([0, 0, 0, 5, 5, 5]) + cls.qs_cluster_centers_idx_ = np.array([0, 5]) + cls.gabriel_labels_ = np.array([5, 5, 5, 5, 5, 5]) + cls.gabriel_cluster_centers_idx_ = np.array([5]) cls.cell = [3, 3] + cls.gabriel_shell = 1 - def test_fit(self): - model = QuickShift(self.cuts) + def test_fit_qs(self): + model = QuickShift(dist_cutoff_sq=self.cuts) model.fit(self.points, samples_weight=self.weights) - self.assertTrue(np.all(model.labels_ == self.labels_)) - self.assertTrue(np.all(model.cluster_centers_idx_ == self.cluster_centers_idx_)) + self.assertTrue(np.all(model.labels_ == self.qs_labels_)) + self.assertTrue( + np.all(model.cluster_centers_idx_ == self.qs_cluster_centers_idx_) + ) + + def test_fit_garbriel(self): + model = QuickShift(gabriel_shell=self.gabriel_shell) + model.fit(self.points, samples_weight=self.weights) + self.assertTrue(np.all(model.labels_ == self.gabriel_labels_)) + self.assertTrue( + np.all(model.cluster_centers_idx_ == self.gabriel_cluster_centers_idx_) + ) def test_dimension_check(self): model = QuickShift(self.cuts, metric_params={"cell_length": self.cell})