Skip to content

Commit

Permalink
Fix points in fit_transform
Browse files Browse the repository at this point in the history
Brute-force pinning as in "embedding constraints"
(lmcinnes#432 (comment)),
while "Possibility to fix points in the low embedding"
(lmcinnes#606) is not done.
  • Loading branch information
jondo committed Mar 23, 2021
1 parent f86c922 commit 3f3fe1e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
9 changes: 9 additions & 0 deletions umap/layouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def optimize_layout_euclidean(
verbose=False,
densmap=False,
densmap_kwds={},
pinned_data=None,
):
"""Improve an embedding using stochastic gradient descent to minimize the
fuzzy set cross entropy between the 1-skeletons of the high dimensional
Expand Down Expand Up @@ -341,6 +342,14 @@ def optimize_layout_euclidean(
dens_re_mean = 0
dens_re_cov = 0

# Brute-force pinning as in
# "embedding constraints" (https://github.com/lmcinnes/umap/issues/432#issuecomment-633145366 ),
# while "Possibility to fix points in the low embedding" (https://github.com/lmcinnes/umap/issues/606 )
# is not done.
if pinned_data is not None:
for p in pinned_data:
head_embedding[p[0]] = p[1]

optimize_fn(
head_embedding,
tail_embedding,
Expand Down
12 changes: 8 additions & 4 deletions umap/umap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,7 @@ def simplicial_set_embedding(
euclidean_output=True,
parallel=False,
verbose=False,
pinned_data=None,
):
"""Perform a fuzzy simplicial set embedding, using a specified
initialisation method and then minimizing the fuzzy set cross entropy
Expand Down Expand Up @@ -1160,6 +1161,7 @@ def simplicial_set_embedding(
verbose=verbose,
densmap=densmap,
densmap_kwds=densmap_kwds,
pinned_data=pinned_data,
)
else:
embedding = optimize_layout_generic(
Expand Down Expand Up @@ -2151,7 +2153,7 @@ def __sub__(self, other):

return result

def fit(self, X, y=None):
def fit(self, X, y=None, pinned_data=None):
"""Fit X into an embedded space.
Optionally use y for supervised dimension reduction.
Expand Down Expand Up @@ -2584,6 +2586,7 @@ def fit(self, X, y=None):
n_epochs,
init,
random_state, # JH why raw data?
pinned_data=pinned_data,
)
# Assign any points that are fully disconnected from our manifold(s) to have embedding
# coordinates of np.nan. These will be filtered by our plotting functions automatically.
Expand All @@ -2608,7 +2611,7 @@ def fit(self, X, y=None):

return self

def _fit_embed_data(self, X, n_epochs, init, random_state):
def _fit_embed_data(self, X, n_epochs, init, random_state, pinned_data=None):
"""A method wrapper for simplicial_set_embedding that can be
replaced by subclasses.
"""
Expand All @@ -2634,9 +2637,10 @@ def _fit_embed_data(self, X, n_epochs, init, random_state):
self.output_metric in ("euclidean", "l2"),
self.random_state is None,
self.verbose,
pinned_data=pinned_data,
)

def fit_transform(self, X, y=None):
def fit_transform(self, X, y=None, pinned_data=None):
"""Fit X into an embedded space and return that transformed
output.
Expand Down Expand Up @@ -2666,7 +2670,7 @@ def fit_transform(self, X, y=None):
r_emb: array, shape (n_samples)
Local radii of data points in the embedding (log-transformed).
"""
self.fit(X, y)
self.fit(X, y, pinned_data=pinned_data)
if self.transform_mode == "embedding":
if self.output_dens:
return self.embedding_, self.rad_orig_, self.rad_emb_
Expand Down

0 comments on commit 3f3fe1e

Please sign in to comment.