Skip to content

Commit

Permalink
Merge pull request lmcinnes#621 from timsainb/fix_protobuf_limit
Browse files Browse the repository at this point in the history
when dataset is >2GB, switch to numpy function to sample from dataset…
  • Loading branch information
lmcinnes authored Mar 16, 2021
2 parents 246c4d0 + 1cfe90a commit f86c922
Showing 1 changed file with 25 additions and 10 deletions.
35 changes: 25 additions & 10 deletions umap/parametric_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,6 @@ def _fit_embed_data(self, X, n_epochs, init, random_state):
self.head = tf.constant(tf.expand_dims(head.astype(np.int64), 0))
self.tail = tf.constant(tf.expand_dims(tail.astype(np.int64), 0))

a, b = next(iter(edge_dataset))
# breakme

if self.parametric_embedding:
init_embedding = None
else:
Expand Down Expand Up @@ -877,23 +874,38 @@ def construct_edge_dataset(
Whether the decoder is parametric or non-parametric
"""

def gather_index(index):
return X[index]

# if X is > 2Gb in size, we need to use a different, slower method for
# batching data.
gather_indices_in_python = True if X.nbytes * 1e-9 > 2 else False

def gather_X(edge_to, edge_from):
edge_to_batch = tf.gather(X, edge_to)
edge_from_batch = tf.gather(X, edge_from)
outputs = {"umap": 0}
# gather data from indexes (edges) in either numpy of tf, depending on array size
if gather_indices_in_python:
edge_to_batch = tf.py_function(gather_index, [edge_to], [tf.float32])[0]
edge_from_batch = tf.py_function(gather_index, [edge_from], [tf.float32])[0]
else:
edge_to_batch = tf.gather(X, edge_to)
edge_from_batch = tf.gather(X, edge_from)
return edge_to_batch, edge_from_batch

def get_outputs(edge_to_batch, edge_from_batch):
outputs = {"umap": tf.repeat(0, batch_size)}
if global_correlation_loss_weight > 0:
outputs["global_correlation"] = edge_to_batch

if parametric_reconstruction:
# add reconstruction to iterator output
# edge_out = tf.concat([edge_to_batch, edge_from_batch], axis=0)
outputs["reconstruction"] = edge_to_batch

return (edge_to_batch, edge_from_batch), outputs

def make_sham_generator():
"""
The sham generator is used to
The sham generator is a placeholder when all data is already intrinsic to
the model, but keras wants some input data. Used for non-parametric
embedding.
"""

def sham_generator():
Expand Down Expand Up @@ -932,10 +944,13 @@ def sham_generator():
)
edge_dataset = edge_dataset.repeat()
edge_dataset = edge_dataset.shuffle(10000)
edge_dataset = edge_dataset.batch(batch_size, drop_remainder=True)
edge_dataset = edge_dataset.map(
gather_X, num_parallel_calls=tf.data.experimental.AUTOTUNE
)
edge_dataset = edge_dataset.batch(batch_size, drop_remainder=True)
edge_dataset = edge_dataset.map(
get_outputs, num_parallel_calls=tf.data.experimental.AUTOTUNE
)
edge_dataset = edge_dataset.prefetch(10)
else:
# nonparametric embedding uses a sham dataset
Expand Down

0 comments on commit f86c922

Please sign in to comment.