Skip to content

Commit

Permalink
Merge pull request #15 from edahelsinki/faster_kmeans
Browse files Browse the repository at this point in the history
Faster kmeans
  • Loading branch information
Aggrathon authored Feb 20, 2024
2 parents e26e2aa + 0ac39b3 commit 82b4012
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "slisemap"
version = "1.6.0"
version = "1.6.1"
authors = [
{ name = "Anton Björklund", email = "[email protected]" },
{ name = "Jarmo Mäkelä" },
Expand Down
28 changes: 20 additions & 8 deletions slisemap/slipmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import lzma
import warnings
from copy import copy
from os import PathLike
from timeit import default_timer as timer
Expand All @@ -26,7 +27,7 @@
import seaborn as sns
import torch
from matplotlib.figure import Figure
from sklearn.cluster import KMeans
from sklearn.cluster import KMeans, MiniBatchKMeans

from slisemap.local_models import (
ALocalModel,
Expand Down Expand Up @@ -1062,7 +1063,7 @@ def optimize(

def predict(
self,
Xnew: ToTensor,
X: ToTensor,
weighted: bool = True,
numpy: bool = True,
) -> Union[np.ndarray, torch.Tensor]:
Expand All @@ -1072,23 +1073,23 @@ def predict(
Then the prediction is made with the local model (of the closest prototype).
Args:
Xnew: Data matrix.
X: Data matrix.
weighted: Use a weighted model instead of just the nearest. Defaults to True
numpy: Return the predictions as a `numpy.ndarray` instead of `torch.Tensor`. Defaults to True.
Returns:
Predicted Y:s.
"""
Xnew = self._as_new_X(Xnew)
xnn = torch.cdist(Xnew, self._X).argmin(1)
X = self._as_new_X(X)
xnn = torch.cdist(X, self._X).argmin(1)
if weighted:
Y = self.local_model(Xnew, self._Bp)
Y = self.local_model(X, self._Bp)
D = self.get_D(True, False, numpy=False)[:, xnn]
W = softmax_column_kernel(D)
Y = torch.sum(W[..., None] * Y, 0)
else:
B = self.get_B(False)[xnn, :]
Y = local_predict(Xnew, B, self.local_model)
Y = local_predict(X, B, self.local_model)
return tonp(Y) if numpy else Y

def get_model_clusters(
Expand All @@ -1097,6 +1098,7 @@ def get_model_clusters(
B: Optional[np.ndarray] = None,
Z: Optional[np.ndarray] = None,
random_state: int = 42,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray]:
"""Cluster the local model coefficients using k-means (from scikit-learn).
Expand All @@ -1108,13 +1110,23 @@ def get_model_clusters(
Z: Z matrix. Defaults to `self.get_Z()`.
random_state: random_state for the KMeans clustering. Defaults to 42.
Keyword Args:
**kwargs: Additional arguments to `sklearn.cluster.KMeans` or `sklearn.cluster.MiniBatchKMeans` if `self.n >= 1024`.
Returns:
labels: Vector of cluster labels.
centres: Matrix of cluster centres.
"""
B = B if B is not None else self.get_B()
Z = Z if Z is not None else self.get_Z()
km = KMeans(clusters, random_state=random_state).fit(B)
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
# Some sklearn versions warn about changing defaults for KMeans
kwargs.setdefault("random_state", random_state)
if self.n >= 1024:
km = MiniBatchKMeans(clusters, **kwargs).fit(B)
else:
km = KMeans(clusters, **kwargs).fit(B)
ord = np.argsort([Z[km.labels_ == k, 0].mean() for k in range(clusters)])
return np.argsort(ord)[km.labels_], km.cluster_centers_[ord]

Expand Down
14 changes: 11 additions & 3 deletions slisemap/slisemap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module that contains the `Slisemap` class."""

import lzma
import warnings
from copy import copy
from os import PathLike
from timeit import default_timer as timer
Expand All @@ -22,7 +23,7 @@
import seaborn as sns
import torch
from matplotlib.figure import Figure
from sklearn.cluster import KMeans
from sklearn.cluster import KMeans, MiniBatchKMeans

from slisemap.escape import escape_neighbourhood
from slisemap.local_models import (
Expand Down Expand Up @@ -1284,15 +1285,22 @@ def get_model_clusters(
random_state: random_state for the KMeans clustering. Defaults to 42.
Keyword Args:
**kwargs: Additional arguments to `sklearn.KMeans`.
**kwargs: Additional arguments to `sklearn.cluster.KMeans` or `sklearn.cluster.MiniBatchKMeans` if `self.n >= 1024`.
Returns:
labels: Vector of cluster labels.
centres: Matrix of cluster centres.
"""
B = B if B is not None else self.get_B()
Z = Z if Z is not None else self.get_Z(rotate=True)
km = KMeans(clusters, random_state=random_state, **kwargs).fit(B)
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
# Some sklearn versions warn about changing defaults for KMeans
kwargs.setdefault("random_state", random_state)
if self.n >= 1024:
km = MiniBatchKMeans(clusters, **kwargs).fit(B)
else:
km = KMeans(clusters, **kwargs).fit(B)
ord = np.argsort([Z[km.labels_ == k, 0].mean() for k in range(clusters)])
return np.argsort(ord)[km.labels_], km.cluster_centers_[ord]

Expand Down

0 comments on commit 82b4012

Please sign in to comment.