Implement masking to control how embedded points are updated #620

Open · wants to merge 1 commit into master
307 changes: 307 additions & 0 deletions umap/layouts.py
@@ -181,6 +181,135 @@ def _optimize_layout_euclidean_single_epoch(
)


def _optimize_layout_euclidean_masked_single_epoch(
head_embedding,
tail_embedding,
head,
tail,
mask,
n_vertices,
epochs_per_sample,
a,
b,
rng_state,
gamma,
dim,
move_other,
alpha,
epochs_per_negative_sample,
epoch_of_next_negative_sample,
epoch_of_next_sample,
n,
densmap_flag,
dens_phi_sum,
dens_re_sum,
dens_re_cov,
dens_re_std,
dens_re_mean,
dens_lambda,
dens_R,
dens_mu,
dens_mu_tot,
):
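    # Identical to _optimize_layout_euclidean_single_epoch, except that every
    # coordinate update below is scaled by the corresponding entry of `mask`.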
for i in numba.prange(epochs_per_sample.shape[0]):
if epoch_of_next_sample[i] <= n:
j = head[i]
k = tail[i]

current = head_embedding[j]
other = tail_embedding[k]

current_mask = mask[j]
other_mask = mask[k]

dist_squared = rdist(current, other)

if densmap_flag:
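                # densMAP local-density correction terms, identical to the
                # unmasked kernel; the mask is applied only at the coordinate
                # updates below.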
phi = 1.0 / (1.0 + a * pow(dist_squared, b))
dphi_term = (
a * b * pow(dist_squared, b - 1) / (1.0 + a * pow(dist_squared, b))
)

q_jk = phi / dens_phi_sum[k]
q_kj = phi / dens_phi_sum[j]

drk = q_jk * (
(1.0 - b * (1 - phi)) / np.exp(dens_re_sum[k]) + dphi_term
)
drj = q_kj * (
(1.0 - b * (1 - phi)) / np.exp(dens_re_sum[j]) + dphi_term
)

re_std_sq = dens_re_std * dens_re_std
weight_k = (
dens_R[k]
- dens_re_cov * (dens_re_sum[k] - dens_re_mean) / re_std_sq
)
weight_j = (
dens_R[j]
- dens_re_cov * (dens_re_sum[j] - dens_re_mean) / re_std_sq
)

grad_cor_coeff = (
dens_lambda
* dens_mu_tot
* (weight_k * drk + weight_j * drj)
/ (dens_mu[i] * dens_re_std)
/ n_vertices
)

if dist_squared > 0.0:
grad_coeff = -2.0 * a * b * pow(dist_squared, b - 1.0)
grad_coeff /= a * pow(dist_squared, b) + 1.0
else:
grad_coeff = 0.0

for d in range(dim):
grad_d = clip(grad_coeff * (current[d] - other[d]))

if densmap_flag:
grad_d += clip(2 * grad_cor_coeff * (current[d] - other[d]))

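                # Scale the attractive update by each point's mask weight:
                # 0 pins the point in place, 1 applies the full step.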
current[d] += current_mask * grad_d * alpha
if move_other:
                    other[d] -= other_mask * grad_d * alpha

epoch_of_next_sample[i] += epochs_per_sample[i]

n_neg_samples = int(
(n - epoch_of_next_negative_sample[i]) / epochs_per_negative_sample[i]
)

for p in range(n_neg_samples):
k = tau_rand_int(rng_state) % n_vertices

other = tail_embedding[k]

dist_squared = rdist(current, other)

if dist_squared > 0.0:
grad_coeff = 2.0 * gamma * b
grad_coeff /= (0.001 + dist_squared) * (
a * pow(dist_squared, b) + 1
)
elif j == k:
continue
else:
grad_coeff = 0.0

for d in range(dim):
if grad_coeff > 0.0:
grad_d = clip(grad_coeff * (current[d] - other[d]))
else:
grad_d = 4.0
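                    # Repulsive (negative-sample) updates move only the head
                    # point, again scaled by its mask weight.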
current[d] += current_mask * grad_d * alpha

epoch_of_next_negative_sample[i] += (
n_neg_samples * epochs_per_negative_sample[i]
)



def _optimize_layout_euclidean_densmap_epoch_init(
head_embedding, tail_embedding, head, tail, a, b, re_sum, phi_sum,
):
@@ -379,6 +508,184 @@ def optimize_layout_euclidean(
return head_embedding


def optimize_layout_euclidean_masked(
head_embedding,
tail_embedding,
head,
tail,
mask,
n_epochs,
n_vertices,
epochs_per_sample,
a,
b,
rng_state,
gamma=1.0,
initial_alpha=1.0,
negative_sample_rate=5.0,
parallel=False,
verbose=False,
densmap=False,
densmap_kwds={},
):
"""Improve an embedding using stochastic gradient descent to minimize the
fuzzy set cross entropy between the 1-skeletons of the high dimensional
and low dimensional fuzzy simplicial sets. In practice this is done by
sampling edges based on their membership strength (with the (1-p) terms
    coming from negative sampling similar to word2vec). Unlike
    optimize_layout_euclidean, every coordinate update is scaled by a
    per-sample mask weight.

    Parameters
----------
head_embedding: array of shape (n_samples, n_components)
The initial embedding to be improved by SGD.
tail_embedding: array of shape (source_samples, n_components)
The reference embedding of embedded points. If not embedding new
previously unseen points with respect to an existing embedding this
is simply the head_embedding (again); otherwise it provides the
existing embedding to embed with respect to.
head: array of shape (n_1_simplices)
The indices of the heads of 1-simplices with non-zero membership.
tail: array of shape (n_1_simplices)
The indices of the tails of 1-simplices with non-zero membership.
    mask: array of shape (n_samples)
        Per-sample weights in [0, 1] controlling how much each point is
        updated. A value of 0 pins the point in place, a value of 1 updates
        it normally, and in-between values scale the update for fine-tuning.
n_epochs: int
The number of training epochs to use in optimization.
n_vertices: int
The number of vertices (0-simplices) in the dataset.
    epochs_per_sample: array of shape (n_1_simplices)
A float value of the number of epochs per 1-simplex. 1-simplices with
weaker membership strength will have more epochs between being sampled.
a: float
Parameter of differentiable approximation of right adjoint functor
b: float
Parameter of differentiable approximation of right adjoint functor
rng_state: array of int64, shape (3,)
The internal state of the rng
gamma: float (optional, default 1.0)
Weight to apply to negative samples.
initial_alpha: float (optional, default 1.0)
Initial learning rate for the SGD.
negative_sample_rate: int (optional, default 5)
Number of negative samples to use per positive sample.
parallel: bool (optional, default False)
Whether to run the computation using numba parallel.
Running in parallel is non-deterministic, and is not used
if a random seed has been set, to ensure reproducibility.
verbose: bool (optional, default False)
Whether to report information on the current progress of the algorithm.
densmap: bool (optional, default False)
Whether to use the density-augmented densMAP objective
densmap_kwds: dict (optional, default {})
Auxiliary data for densMAP

    Returns
-------
embedding: array of shape (n_samples, n_components)
The optimized embedding.
"""

dim = head_embedding.shape[1]
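    # Tail points are moved only when head and tail refer to the same point
    # set, i.e. when optimizing an embedding against itself.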
move_other = head_embedding.shape[0] == tail_embedding.shape[0]
alpha = initial_alpha

epochs_per_negative_sample = epochs_per_sample / negative_sample_rate
epoch_of_next_negative_sample = epochs_per_negative_sample.copy()
epoch_of_next_sample = epochs_per_sample.copy()

optimize_fn = numba.njit(
_optimize_layout_euclidean_masked_single_epoch, fastmath=True, parallel=parallel
)

if densmap:
dens_init_fn = numba.njit(
_optimize_layout_euclidean_densmap_epoch_init,
fastmath=True,
parallel=parallel,
)

dens_mu_tot = np.sum(densmap_kwds["mu_sum"]) / 2
dens_lambda = densmap_kwds["lambda"]
dens_R = densmap_kwds["R"]
dens_mu = densmap_kwds["mu"]
dens_phi_sum = np.zeros(n_vertices, dtype=np.float32)
dens_re_sum = np.zeros(n_vertices, dtype=np.float32)
dens_var_shift = densmap_kwds["var_shift"]
else:
dens_mu_tot = 0
dens_lambda = 0
dens_R = np.zeros(1, dtype=np.float32)
dens_mu = np.zeros(1, dtype=np.float32)
dens_phi_sum = np.zeros(1, dtype=np.float32)
dens_re_sum = np.zeros(1, dtype=np.float32)

for n in range(n_epochs):
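        # densMAP terms are enabled only for the final `frac` portion of the
        # optimization schedule.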

densmap_flag = (
densmap
and (densmap_kwds["lambda"] > 0)
and (((n + 1) / float(n_epochs)) > (1 - densmap_kwds["frac"]))
)

if densmap_flag:
dens_init_fn(
head_embedding,
tail_embedding,
head,
tail,
a,
b,
dens_re_sum,
dens_phi_sum,
)

dens_re_std = np.sqrt(np.var(dens_re_sum) + dens_var_shift)
dens_re_mean = np.mean(dens_re_sum)
dens_re_cov = np.dot(dens_re_sum, dens_R) / (n_vertices - 1)
else:
dens_re_std = 0
dens_re_mean = 0
dens_re_cov = 0

optimize_fn(
head_embedding,
tail_embedding,
head,
tail,
mask,
n_vertices,
epochs_per_sample,
a,
b,
rng_state,
gamma,
dim,
move_other,
alpha,
epochs_per_negative_sample,
epoch_of_next_negative_sample,
epoch_of_next_sample,
n,
densmap_flag,
dens_phi_sum,
dens_re_sum,
dens_re_cov,
dens_re_std,
dens_re_mean,
dens_lambda,
dens_R,
dens_mu,
dens_mu_tot,
)

alpha = initial_alpha * (1.0 - (float(n) / float(n_epochs)))

if verbose and n % int(n_epochs / 10) == 0:
print("\tcompleted ", n, " / ", n_epochs, "epochs")

return head_embedding


@numba.njit(fastmath=True)
def optimize_layout_generic(
head_embedding,
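For intuition, here is a minimal, self-contained sketch of the masked update rule introduced above (illustrative only: it uses random stand-in gradients rather than UMAP's attractive and repulsive forces, and calls nothing from this diff):

import numpy as np

rng = np.random.default_rng(42)
embedding = rng.normal(size=(6, 2)).astype(np.float32)

# mask[i] scales every update applied to point i: 0 pins it, 1 moves it
# normally, and intermediate values damp its movement.
mask = np.array([0.0, 0.0, 0.5, 1.0, 1.0, 1.0], dtype=np.float32)
pinned_before = embedding[mask == 0.0].copy()

alpha = 0.01
for _ in range(100):
    grad = rng.normal(size=embedding.shape).astype(np.float32)  # stand-in gradient
    embedding += mask[:, None] * grad * alpha  # the masked update rule

# Points with mask == 0.0 are exactly where they started.
assert np.allclose(embedding[mask == 0.0], pinned_before)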
8 changes: 7 additions & 1 deletion umap/parametric_umap.py
@@ -358,7 +358,7 @@ def _compile_model(self):
run_eagerly=self.run_eagerly,
)

def _fit_embed_data(self, X, n_epochs, init, random_state):
def _fit_embed_data(self, X, n_epochs, init, random_state, pin_mask):

if self.metric == "precomputed":
X = self._X
@@ -371,6 +371,12 @@ def _fit_embed_data(self, X, n_epochs, init, random_state):
if len(self.dims) > 1:
X = np.reshape(X, [len(X)] + list(self.dims))

        if pin_mask is not None:
            warn(
                "Pinning is not yet supported by Parametric UMAP. "
                "Ignoring the pin_mask."
            )

if self.parametric_reconstruction and (np.max(X) > 1.0 or np.min(X) < 0.0):
warn(
"Data should be scaled to the range 0-1 for cross-entropy reconstruction loss."
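Assuming the rest of this PR threads pin_mask through the public fit API (this diff only shows it reaching _fit_embed_data), end-to-end usage might look like the sketch below; the pin_mask keyword on fit_transform is an assumption, not a confirmed interface:

import numpy as np
from umap import UMAP

X = np.random.rand(200, 16).astype(np.float32)

# Hypothetical usage: freeze the first 50 points at their initial
# coordinates and optimize the rest normally. The `pin_mask` keyword is
# assumed from the _fit_embed_data change above.
pin_mask = np.ones(len(X), dtype=np.float32)
pin_mask[:50] = 0.0

embedding = UMAP().fit_transform(X, pin_mask=pin_mask)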