Skip to content

Commit

Permalink
Add tests for quick shift and remove unused functions
Browse files (browse the repository at this point in the history)
  • Loading branch information
GardevoirX committed Jun 15, 2024
1 parent ca3cec8 commit 7ffa8ef
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 55 deletions.
10 changes: 5 additions & 5 deletions src/skmatter/clustering/_quick_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ def __init__(
):
if (dist_cutoff_sq is None) and (gabriel_shell is None):
raise ValueError("Either dist_cutoff or gabriel_depth must be set.")

Check warning on line 109 in src/skmatter/clustering/_quick_shift.py

View check run for this annotation

Codecov / codecov/patch

src/skmatter/clustering/_quick_shift.py#L109

Added line #L109 was not covered by tests
self.dist_cutoff2 = dist_cutoff_sq
self.dist_cutoff_sq = dist_cutoff_sq
self.gabriel_shell = gabriel_shell
self.scale = scale
if self.dist_cutoff2 is not None:
self.dist_cutoff2 *= self.scale**2
if self.dist_cutoff_sq is not None:
self.dist_cutoff_sq *= self.scale**2
self.metric_params = (
metric_params if metric_params is not None else {"cell_length": None}
)
Expand Down Expand Up @@ -142,7 +142,7 @@ def fit(self, X, y=None, samples_weight=None):
)
dist_matrix = self.metric(X, X)
np.fill_diagonal(dist_matrix, np.inf)
if self.dist_cutoff2 is None:
if self.dist_cutoff_sq is None:
gabrial = _get_gabriel_graph(dist_matrix)
idmindist = np.argmin(dist_matrix, axis=1)
idxroot = np.full(dist_matrix.shape[0], -1, dtype=int)
Expand All @@ -163,7 +163,7 @@ def fit(self, X, y=None, samples_weight=None):
idmindist[current],
samples_weight,
dist_matrix,
self.dist_cutoff2[current],
self.dist_cutoff_sq[current],
)
if idxroot[idxroot[current]] != -1:
# Found a path to a root
Expand Down
46 changes: 2 additions & 44 deletions src/skmatter/metrics/_pairwise.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
from typing import Union

import numpy as np
from sklearn.metrics.pairwise import (
_euclidean_distances,
check_array,
check_pairwise_arrays,
)
from sklearn.metrics.pairwise import _euclidean_distances, check_pairwise_arrays


def periodic_pairwise_euclidean_distances(
X,
Y=None,
*,
Y_norm_squared=None,
squared=False,
X_norm_squared=None,
cell_length=None,
):
r"""
Expand Down Expand Up @@ -48,16 +42,6 @@ def periodic_pairwise_euclidean_distances(
default=None
An array where each row is a sample and each column is a component.
If `None`, method uses `Y=X`.
Y_norm_squared : array-like of shape (n_samples_Y,) or (n_samples_Y, 1) \
or (1, n_samples_Y), default=None
Pre-computed dot-products of vectors in Y (e.g., `(Y**2).sum(axis=1)`)
May be ignored in some cases, see the note below.
squared : bool, default=False
Return squared Euclidean distances.
X_norm_squared : array-like of shape (n_samples_X,) or (n_samples_X, 1) \
or (1, n_samples_X), default=None
Pre-computed dot-products of vectors in X (e.g., `(X**2).sum(axis=1)`)
May be ignored in some cases, see the note below.
cell_length : array-like of shape (n_components,), default=None
The side length of rectangular cell used for periodic boundary conditions.
`None` for non-periodic boundary conditions.
Expand Down Expand Up @@ -90,34 +74,8 @@ def periodic_pairwise_euclidean_distances(
_check_dimension(X, cell_length)
X, Y = check_pairwise_arrays(X, Y)

if X_norm_squared is not None:
X_norm_squared = check_array(X_norm_squared, ensure_2d=False)
original_shape = X_norm_squared.shape
if X_norm_squared.shape == (X.shape[0],):
X_norm_squared = X_norm_squared.reshape(-1, 1)
if X_norm_squared.shape == (1, X.shape[0]):
X_norm_squared = X_norm_squared.T
if X_norm_squared.shape != (X.shape[0], 1):
raise ValueError(
f"Incompatible dimensions for X of shape {X.shape} and "
f"X_norm_squared of shape {original_shape}."
)

if Y_norm_squared is not None:
Y_norm_squared = check_array(Y_norm_squared, ensure_2d=False)
original_shape = Y_norm_squared.shape
if Y_norm_squared.shape == (Y.shape[0],):
Y_norm_squared = Y_norm_squared.reshape(1, -1)
if Y_norm_squared.shape == (Y.shape[0], 1):
Y_norm_squared = Y_norm_squared.T
if Y_norm_squared.shape != (1, Y.shape[0]):
raise ValueError(
f"Incompatible dimensions for Y of shape {Y.shape} and "
f"Y_norm_squared of shape {original_shape}."
)

if cell_length is None:
return _euclidean_distances(X, Y, X_norm_squared, Y_norm_squared, squared)
return _euclidean_distances(X, Y, squared=squared)
else:
return _periodic_euclidean_distances(X, Y, squared=squared, cell=cell_length)

Expand Down
25 changes: 19 additions & 6 deletions tests/test_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,28 @@ def setUpClass(cls) -> None:
-2.61132267,
]
)
cls.labels_ = np.array([0, 0, 0, 5, 5, 5])
cls.cluster_centers_idx_ = np.array([0, 5])
cls.qs_labels_ = np.array([0, 0, 0, 5, 5, 5])
cls.qs_cluster_centers_idx_ = np.array([0, 5])
cls.gabriel_labels_ = np.array([5, 5, 5, 5, 5, 5])
cls.gabriel_cluster_centers_idx_ = np.array([5])
cls.cell = [3, 3]
cls.gabriel_shell = 1

def test_fit(self):
model = QuickShift(self.cuts)
def test_fit_qs(self):
model = QuickShift(dist_cutoff_sq=self.cuts)
model.fit(self.points, samples_weight=self.weights)
self.assertTrue(np.all(model.labels_ == self.labels_))
self.assertTrue(np.all(model.cluster_centers_idx_ == self.cluster_centers_idx_))
self.assertTrue(np.all(model.labels_ == self.qs_labels_))
self.assertTrue(
np.all(model.cluster_centers_idx_ == self.qs_cluster_centers_idx_)
)

def test_fit_garbriel(self):
model = QuickShift(gabriel_shell=self.gabriel_shell)
model.fit(self.points, samples_weight=self.weights)
self.assertTrue(np.all(model.labels_ == self.gabriel_labels_))
self.assertTrue(
np.all(model.cluster_centers_idx_ == self.gabriel_cluster_centers_idx_)
)

def test_dimension_check(self):
model = QuickShift(self.cuts, metric_params={"cell_length": self.cell})
Expand Down

0 comments on commit 7ffa8ef

Please sign in to comment.