From 3f3fe1ea31c0d1dd3d8a1c4ca966d5dd2e2f0e2d Mon Sep 17 00:00:00 2001 From: Robert Pollak Date: Tue, 23 Mar 2021 17:11:50 +0100 Subject: [PATCH] Pin (fix) points in fit_transform Add an optional pinned_data argument that implements the brute-force pinning suggested in "embedding constraints" (https://github.com/lmcinnes/umap/issues/432#issuecomment-633145366). The feature requested in "Possibility to fix points in the low embedding" (https://github.com/lmcinnes/umap/issues/606) is not implemented by this patch. --- umap/layouts.py | 9 +++++++++ umap/umap_.py | 12 ++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/umap/layouts.py b/umap/layouts.py index c3d79fc7..b7dee44b 100644 --- a/umap/layouts.py +++ b/umap/layouts.py @@ -225,6 +225,7 @@ def optimize_layout_euclidean( verbose=False, densmap=False, densmap_kwds={}, + pinned_data=None, ): """Improve an embedding using stochastic gradient descent to minimize the fuzzy set cross entropy between the 1-skeletons of the high dimensional @@ -341,6 +342,14 @@ def optimize_layout_euclidean( dens_re_mean = 0 dens_re_cov = 0 + # Brute-force pinning as in + # "embedding constraints" (https://github.com/lmcinnes/umap/issues/432#issuecomment-633145366 ), + # while "Possibility to fix points in the low embedding" (https://github.com/lmcinnes/umap/issues/606 ) + # is not done. 
+ if pinned_data is not None: + for p in pinned_data: + head_embedding[p[0]] = p[1] + optimize_fn( head_embedding, tail_embedding, diff --git a/umap/umap_.py b/umap/umap_.py index 52c3a280..7b20663b 100644 --- a/umap/umap_.py +++ b/umap/umap_.py @@ -941,6 +941,7 @@ def simplicial_set_embedding( euclidean_output=True, parallel=False, verbose=False, + pinned_data=None, ): """Perform a fuzzy simplicial set embedding, using a specified initialisation method and then minimizing the fuzzy set cross entropy @@ -1160,6 +1161,7 @@ def simplicial_set_embedding( verbose=verbose, densmap=densmap, densmap_kwds=densmap_kwds, + pinned_data=pinned_data, ) else: embedding = optimize_layout_generic( @@ -2151,7 +2153,7 @@ def __sub__(self, other): return result - def fit(self, X, y=None): + def fit(self, X, y=None, pinned_data=None): """Fit X into an embedded space. Optionally use y for supervised dimension reduction. @@ -2584,6 +2586,7 @@ def fit(self, X, y=None): n_epochs, init, random_state, # JH why raw data? + pinned_data=pinned_data, ) # Assign any points that are fully disconnected from our manifold(s) to have embedding # coordinates of np.nan. These will be filtered by our plotting functions automatically. @@ -2608,7 +2611,7 @@ def fit(self, X, y=None): return self - def _fit_embed_data(self, X, n_epochs, init, random_state): + def _fit_embed_data(self, X, n_epochs, init, random_state, pinned_data=None): """A method wrapper for simplicial_set_embedding that can be replaced by subclasses. """ @@ -2634,9 +2637,10 @@ def _fit_embed_data(self, X, n_epochs, init, random_state): self.output_metric in ("euclidean", "l2"), self.random_state is None, self.verbose, + pinned_data=pinned_data, ) - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, pinned_data=None): """Fit X into an embedded space and return that transformed output. 
@@ -2666,7 +2670,7 @@ def fit_transform(self, X, y=None): r_emb: array, shape (n_samples) Local radii of data points in the embedding (log-transformed). """ - self.fit(X, y) + self.fit(X, y, pinned_data=pinned_data) if self.transform_mode == "embedding": if self.output_dens: return self.embedding_, self.rad_orig_, self.rad_emb_