From 272da26041ee849c4c12fe2e335b99119dd693b9 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Tue, 24 Sep 2024 11:39:37 +0100 Subject: [PATCH 01/17] #779 At coreax.solvers.composite, changed the return statement of the _jit_tree function to return indices as well, changed the reduce method to keep track of indices, added a line plt.show() to pounce.py --- coreax/solvers/composite.py | 58 +++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/coreax/solvers/composite.py b/coreax/solvers/composite.py index d8c8b297..b5ac0f88 100644 --- a/coreax/solvers/composite.py +++ b/coreax/solvers/composite.py @@ -120,24 +120,57 @@ def __check_init__(self): @override def reduce( - self, dataset: _Data, solver_state: Optional[_State] = None + self, dataset: _Data, solver_state: Optional[_State] = None ) -> tuple[_Coreset, _State]: # There is no obvious way to use state information here. del solver_state - def _reduce_coreset(data: _Data) -> tuple[_Coreset, _State]: + def _reduce_coreset(data: _Data, _indices=None) -> tuple[_Coreset, _State, _Data]: if len(data) <= self.leaf_size: - return self.base_solver.reduce(data) - partitioned_dataset = _jit_tree(data, self.leaf_size, self.tree_type) - coreset_ensemble, _ = jax.vmap(self.base_solver.reduce)(partitioned_dataset) + coreset, state = self.base_solver.reduce(data) + if _indices is not None: + _indices = _indices[coreset.unweighted_indices] + return coreset, state, _indices + + def wrapper(row: _Data) -> tuple[_Data, _Data]: + """ + Apply the reduce method of the base solver on a dataset and + return the data and unweighted indices of the coreset. + + It is a wrapper to process a single partition (row) of the result of _jit_tree + that works with the vmap + """ + x, _ = self.base_solver.reduce(row) + return x.coreset, x.unweighted_indices + + def get_indices(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + """ + Perform advanced indexing on array 'a' using indices 'b'. + Returns a new array with elements of 'a' at positions specified by 'b'. 
+ """ + return a[b] + + # First partition the data + partitioned_dataset, partitioned_indices = _jit_tree(data, self.leaf_size, self.tree_type) + # Then apply base solver to each partition and keep track of indices with respect to partitions + coreset_ensemble, ensemble_indices = jax.vmap(wrapper)(partitioned_dataset) + # Calculate the indices with respect to the data (one passed to _reduce_coreset) + concatenated_indices = jax.vmap(get_indices)(partitioned_indices, ensemble_indices) + # flatten the indices + concatenated_indices = jnp.ravel(concatenated_indices) _coreset = jtu.tree_map(jnp.concatenate, coreset_ensemble) - return _reduce_coreset(_coreset.coreset) - coreset_wrong_pre_coreset_data, output_solver_state = _reduce_coreset(dataset) - coreset = eqx.tree_at( - lambda x: x.pre_coreset_data, coreset_wrong_pre_coreset_data, dataset - ) - return coreset, output_solver_state + if _indices is not None: + final_indices = _indices[concatenated_indices] + else: + final_indices = concatenated_indices + return _reduce_coreset(_coreset, final_indices) + + coreset, output_solver_state, _indices = _reduce_coreset(dataset) + del coreset + final_coreset = Coresubset(_indices, dataset) + + return final_coreset, output_solver_state def _jit_tree(dataset: _Data, leaf_size: int, tree_type: type[BinaryTree]) -> _Data: @@ -183,4 +216,5 @@ def _binary_tree(_input_data: Data) -> np.ndarray: return node_indices.reshape(n_leaves, -1).astype(np.int32) indices = jax.pure_callback(_binary_tree, result_shape, padded_dataset) - return dataset[indices] + return dataset[indices], indices # (Now it returns both data and the indices) + From d9045a4cad99f920a4df02dfd4a825fcc0c0edc1 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Tue, 24 Sep 2024 11:42:15 +0100 Subject: [PATCH 02/17] #779 At coreax.solvers.composite, changed the return statement of the _jit_tree function to return indices as well, changed the reduce method to keep track of indices, added a line plt.show() to pounce.py --- examples/pounce.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/pounce.py b/examples/pounce.py index 0cc9195e..3ad944f7 100644 --- a/examples/pounce.py +++ b/examples/pounce.py @@ -175,6 +175,7 @@ def main( plt.xlabel("Frame") plt.ylabel("Chosen") plt.tight_layout() + plt.show() if out_path is not None: plt.savefig(out_path / "pounce_frames.png") plt.close() From d87afde772b03bbc50056d80e84fc476ea6a3b87 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:45:50 +0100 Subject: [PATCH 03/17] Added a test in unit/test_solvers.py that checks if MapReduce's reduce method returns coreset with indices bigger than the leafsize as was present in #779. 
This test fails when that bug is present and passes when it is not present --- tests/unit/test_solvers.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py index 733ada99..302172e4 100644 --- a/tests/unit/test_solvers.py +++ b/tests/unit/test_solvers.py @@ -850,6 +850,34 @@ def test_base_solver( solver_factory.keywords["base_solver"] = base_solver solver_factory() + def test_mapreduce_diverse_selection(self): + """Check if MapReduce returns indices from multiple partitions.""" + dataset_size = 40 + data_dim = 5 + coreset_size = 6 + leaf_size = 12 + + key = jr.PRNGKey(0) + dataset = jr.normal(key, shape=(dataset_size, data_dim)) + + kernel = SquaredExponentialKernel() + base_solver = KernelHerding(coreset_size=coreset_size, kernel=kernel) + + solver = MapReduce(base_solver=base_solver, leaf_size=leaf_size) + coreset, _ = solver.reduce(Data(dataset)) + selected_indices = coreset.nodes.data + + # Check if there are indices beyond the first few + assert jnp.any( + selected_indices >= coreset_size + ), "MapReduce should select points beyond the first few" + + # Check if there are indices from different partitions + partitions_represented = jnp.unique(selected_indices // leaf_size) + assert ( + len(partitions_represented) > 1 + ), "MapReduce should select points from multiple partitions" + class TestCaratheodoryRecombination(RecombinationSolverTest): """Tests for :class:`coreax.solvers.recombination.CaratheodoryRecombination`.""" From 08f666ac0670dfdd8496ad56b7d63cb81a82cafa Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:50:45 +0100 Subject: [PATCH 04/17] Added an if statement on MapReduce.reduce method, it now only assigns calculated index if base solver returns a Coresubset object and not just a Coreset object. --- coreax/solvers/composite.py | 51 ++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/coreax/solvers/composite.py b/coreax/solvers/composite.py index b5ac0f88..9e7c07fb 100644 --- a/coreax/solvers/composite.py +++ b/coreax/solvers/composite.py @@ -120,42 +120,44 @@ def __check_init__(self): @override def reduce( - self, dataset: _Data, solver_state: Optional[_State] = None + self, dataset: _Data, solver_state: Optional[_State] = None ) -> tuple[_Coreset, _State]: # There is no obvious way to use state information here. del solver_state - def _reduce_coreset(data: _Data, _indices=None) -> tuple[_Coreset, _State, _Data]: + def _reduce_coreset( + data: _Data, _indices=None + ) -> (tuple)[_Coreset, _State, _Data]: if len(data) <= self.leaf_size: coreset, state = self.base_solver.reduce(data) if _indices is not None: - _indices = _indices[coreset.unweighted_indices] + _indices = _indices[coreset.nodes.data] return coreset, state, _indices def wrapper(row: _Data) -> tuple[_Data, _Data]: """ - Apply the reduce method of the base solver on a dataset and - return the data and unweighted indices of the coreset. + Apply the reduce method of the base solver on a row. 
- It is a wrapper to process a single partition (row) of the result of _jit_tree - that works with the vmap + It is a wrapper to process a single partition (row) + of the result of _jit_tree that works with the vmap """ x, _ = self.base_solver.reduce(row) - return x.coreset, x.unweighted_indices + return x.coreset, x.nodes.data def get_indices(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: - """ - Perform advanced indexing on array 'a' using indices 'b'. - Returns a new array with elements of 'a' at positions specified by 'b'. - """ return a[b] # First partition the data - partitioned_dataset, partitioned_indices = _jit_tree(data, self.leaf_size, self.tree_type) - # Then apply base solver to each partition and keep track of indices with respect to partitions + partitioned_dataset, partitioned_indices = _jit_tree( + data, self.leaf_size, self.tree_type + ) + # Then apply base solver to each partition and + # keep track of indices with respect to partitions coreset_ensemble, ensemble_indices = jax.vmap(wrapper)(partitioned_dataset) - # Calculate the indices with respect to the data (one passed to _reduce_coreset) - concatenated_indices = jax.vmap(get_indices)(partitioned_indices, ensemble_indices) + # Calculate the indices with respect to the data (_reduce_coreset) + concatenated_indices = jax.vmap(get_indices)( + partitioned_indices, ensemble_indices + ) # flatten the indices concatenated_indices = jnp.ravel(concatenated_indices) _coreset = jtu.tree_map(jnp.concatenate, coreset_ensemble) @@ -166,11 +168,15 @@ def get_indices(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: final_indices = concatenated_indices return _reduce_coreset(_coreset, final_indices) - coreset, output_solver_state, _indices = _reduce_coreset(dataset) - del coreset - final_coreset = Coresubset(_indices, dataset) - - return final_coreset, output_solver_state + (coreset_wrong_pre_coreset_data, output_solver_state, _indices) = ( + _reduce_coreset(dataset) + ) + coreset = eqx.tree_at( + lambda x: x.pre_coreset_data, coreset_wrong_pre_coreset_data, dataset + ) + if isinstance(coreset, Coresubset): + coreset = eqx.tree_at(lambda x: x.nodes.data, coreset, _indices) + return coreset, output_solver_state def _jit_tree(dataset: _Data, leaf_size: int, tree_type: type[BinaryTree]) -> _Data: @@ -216,5 +222,4 @@ def _binary_tree(_input_data: Data) -> np.ndarray: return node_indices.reshape(n_leaves, -1).astype(np.int32) indices = jax.pure_callback(_binary_tree, result_shape, padded_dataset) - return dataset[indices], indices # (Now it returns both data and the indices) - + return dataset[indices], indices From 7d41a4b0f183e455a667b7fdaa91f3bc7909a65e Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Thu, 26 Sep 2024 09:29:28 +0100 Subject: [PATCH 05/17] replaced mapreduce by map_reduce in test_map_reduce_diverse_selection for precommit --- tests/unit/test_solvers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py index 302172e4..b0ef1894 100644 --- a/tests/unit/test_solvers.py +++ b/tests/unit/test_solvers.py @@ -850,7 +850,7 @@ def test_base_solver( solver_factory.keywords["base_solver"] = base_solver solver_factory() - def test_mapreduce_diverse_selection(self): + def test_map_reduce_diverse_selection(self): """Check if MapReduce returns indices from multiple partitions.""" dataset_size = 40 data_dim = 5 From 48a1037491385e9d3517efba6e7b94f4c6847c35 Mon Sep 17 00:00:00 2001 From: qh681248 
<181246904+qh681248@users.noreply.github.com> Date: Thu, 26 Sep 2024 12:56:44 +0100 Subject: [PATCH 06/17] In coreax/solvers/composite.py, the reduce method updates indices only when _indices is not none. _jit_tree now returns jnp.arange(len(dataset))[indices] instead of indices so that it is within bounds. --- coreax/solvers/composite.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/coreax/solvers/composite.py b/coreax/solvers/composite.py index 9e7c07fb..90487e03 100644 --- a/coreax/solvers/composite.py +++ b/coreax/solvers/composite.py @@ -174,8 +174,9 @@ def get_indices(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: coreset = eqx.tree_at( lambda x: x.pre_coreset_data, coreset_wrong_pre_coreset_data, dataset ) - if isinstance(coreset, Coresubset): - coreset = eqx.tree_at(lambda x: x.nodes.data, coreset, _indices) + if _indices is not None: + if isinstance(coreset, Coresubset): + coreset = eqx.tree_at(lambda x: x.nodes.data, coreset, _indices) return coreset, output_solver_state @@ -222,4 +223,11 @@ def _binary_tree(_input_data: Data) -> np.ndarray: return node_indices.reshape(n_leaves, -1).astype(np.int32) indices = jax.pure_callback(_binary_tree, result_shape, padded_dataset) - return dataset[indices], indices + return dataset[indices], jnp.arange(len(dataset))[indices] + + +# If you are dividing a dataset of size 20 to three partitions of size 8, +# then your indices will exceed the len(dataset) - 1, +# at the moment the last entry in dataset is chosen for every index +# that is out of bounds, it might be worth considering +# if choosing a datapoint randomly makes more sense From 1bd8501c1dcc32171b9ce835d582f4854373f6bb Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Fri, 27 Sep 2024 16:41:50 +0100 Subject: [PATCH 07/17] Removed the line plt.show() in examples/pounce.py --- examples/pounce.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/pounce.py b/examples/pounce.py index 3ad944f7..0cc9195e 100644 --- a/examples/pounce.py +++ b/examples/pounce.py @@ -175,7 +175,6 @@ def main( plt.xlabel("Frame") plt.ylabel("Chosen") plt.tight_layout() - plt.show() if out_path is not None: plt.savefig(out_path / "pounce_frames.png") plt.close() From 082ada7bf150563a1757f71442a3aa2c428f5e83 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Thu, 3 Oct 2024 14:12:19 +0100 Subject: [PATCH 08/17] Made requested changes in the PR #790 (fixed type hints, comments etc.) 
--- benchmark/mnist_benchmark.py | 385 +++++++++++++++++++++++++++++++++++ coreax/solvers/composite.py | 43 ++-- 2 files changed, 402 insertions(+), 26 deletions(-) create mode 100644 benchmark/mnist_benchmark.py diff --git a/benchmark/mnist_benchmark.py b/benchmark/mnist_benchmark.py new file mode 100644 index 00000000..d6e892e7 --- /dev/null +++ b/benchmark/mnist_benchmark.py @@ -0,0 +1,385 @@ +import time +from typing import Any, Dict, Tuple + +import numpy as np +import torch +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import Dataset, DataLoader + +import jax +import jax.numpy as jnp +import optax +from flax import linen as nn +from flax.training import train_state + +from coreax import Data + +from coreax.kernels import ( + SquaredExponentialKernel, + PCIMQKernel, + SteinKernel, + median_heuristic +) + +from coreax.solvers import ( + KernelHerding, + RandomSample, + RPCholesky, + SteinThinning, + GreedyKernelPoints, + MapReduce, + PaddingInvariantSolver +) + +from coreax.metrics import ( + MMD, + KSD +) + +from coreax.score_matching import KernelDensityMatching +from matplotlib import pyplot as plt + + +key = jax.random.PRNGKey(0) # for reproducibility +n_components = 16 #for PCA + +transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))]) +train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) +test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) +# Convert PyTorch dataset to JAX arrays +def convert_to_jax_arrays(dataset: Dataset) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Convert a PyTorch dataset to JAX arrays. + + This function takes a PyTorch dataset, loads all the data at once using a DataLoader, + and converts it to JAX arrays. It's designed to work with datasets that can fit into memory. + + Args: + dataset (Dataset): A PyTorch dataset to be converted. + + Returns: + Tuple[jnp.ndarray, jnp.ndarray]: A tuple containing two JAX arrays: + - The first array contains the data. + - The second array contains the targets (labels). 
+ """ + data_loader = DataLoader(dataset, batch_size=len(dataset)) + data, targets = next(iter(data_loader)) # Load all data at once + data_jax = jnp.array(data.numpy()) # Convert to NumPy first, then JAX array + targets_jax = jnp.array(targets.numpy()) + return data_jax, targets_jax + +# Usage remains the same +train_data_jax, train_targets_jax = convert_to_jax_arrays(train_dataset) +test_data_jax, test_targets_jax = convert_to_jax_arrays(test_dataset) + +def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray: + """Compute cross-entropy loss.""" + return jnp.mean(optax.softmax_cross_entropy( + logits, + jax.nn.one_hot(labels, num_classes=10) + )) + + +def compute_metrics( + logits: jnp.ndarray, + labels: jnp.ndarray +) -> Dict[str, jnp.ndarray]: + """Compute loss and accuracy metrics.""" + loss = cross_entropy_loss(logits, labels) + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) + return {'loss': loss, 'accuracy': accuracy} + + +class MLP(nn.Module): + """Multi-layer perceptron with optional batch normalization and dropout.""" + + hidden_size: int + output_size: int = 10 + use_batchnorm: bool = True + dropout_rate: float = 0.5 + + @nn.compact + def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray: + """Forward pass of the MLP.""" + x = nn.Dense(self.hidden_size)(x) + if self.use_batchnorm: + x = nn.BatchNorm(use_running_average=not training)(x) + x = nn.relu(x) + x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x) + x = nn.Dense(self.output_size)(x) + return x + + +class TrainState(train_state.TrainState): + """Custom train state with batch statistics and dropout RNG.""" + + batch_stats: Any = None + dropout_rng: jnp.ndarray = None + + +def create_train_state( + rng: jnp.ndarray, + model: nn.Module, + learning_rate: float, + weight_decay: float +) -> TrainState: + """Create and initialize the train state.""" + dropout_rng, params_rng = jax.random.split(rng) + params = model.init( + {'params': params_rng, 'dropout': dropout_rng}, + jnp.ones([1, 784]), + training=False + ) + tx = optax.adamw(learning_rate, weight_decay=weight_decay) + return TrainState.create( + apply_fn=model.apply, + params=params['params'], + tx=tx, + batch_stats=params['batch_stats'], + dropout_rng=dropout_rng + ) + + +@jax.jit +def train_step( + state: TrainState, + batch_data: jnp.ndarray, + batch_labels: jnp.ndarray +) -> Tuple[TrainState, Dict[str, jnp.ndarray]]: + """Perform a single training step.""" + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) + + def loss_fn(params): + variables = {'params': params, 'batch_stats': state.batch_stats} + logits, new_model_state = state.apply_fn( + variables, + batch_data, + training=True, + mutable=['batch_stats'], + rngs={'dropout': dropout_rng} + ) + loss = cross_entropy_loss(logits, batch_labels) + return loss, (logits, new_model_state) + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, (logits, new_model_state)), grads = grad_fn(state.params) + state = state.apply_gradients( + grads=grads, + batch_stats=new_model_state['batch_stats'], + dropout_rng=new_dropout_rng + ) + metrics = compute_metrics(logits, batch_labels) + return state, metrics + + +@jax.jit +def eval_step( + state: TrainState, + batch_data: jnp.ndarray, + batch_labels: jnp.ndarray +) -> Dict[str, jnp.ndarray]: + """Perform a single evaluation step.""" + variables = {'params': state.params, 'batch_stats': state.batch_stats} + logits = state.apply_fn( + variables, + batch_data, + training=False, + 
rngs={'dropout': state.dropout_rng} + ) + return compute_metrics(logits, batch_labels) + + +def train_and_evaluate( + train_data_jax: jnp.ndarray, + train_labels_jax: jnp.ndarray, + test_data_jax: jnp.ndarray, + test_labels_jax: jnp.ndarray, + model: nn.Module, + rng: jnp.ndarray, + epochs: int = 10, + batch_size: int = 64, + learning_rate: float = 1e-3, + weight_decay: float = 1e-5, + patience: int = 5, + min_delta: float = 0.001 +) -> Dict[str, float]: + """Train and evaluate the model with early stopping.""" + state = create_train_state(rng, model, learning_rate, weight_decay) + num_train_batches = train_data_jax.shape[0] // batch_size + num_test_batches = test_data_jax.shape[0] // batch_size + + best_accuracy = 0 + patience_counter = 0 + best_state = None + + for epoch in range(epochs): + # Shuffle data + rng, input_rng = jax.random.split(rng) + perm = jax.random.permutation(input_rng, train_data_jax.shape[0]) + train_data_shuffled = train_data_jax[perm] + train_labels_shuffled = train_labels_jax[perm] + + # Training loop + for batch_idx in range(num_train_batches): + start_idx = batch_idx * batch_size + end_idx = (batch_idx + 1) * batch_size + batch_data = train_data_shuffled[start_idx:end_idx] + batch_labels = train_labels_shuffled[start_idx:end_idx] + state, metrics = train_step(state, batch_data, batch_labels) + + if epoch % 8 == 0: + print( + f"Epoch {epoch}, " + f"Loss: {metrics['loss']:.4f}, " + f"Accuracy: {metrics['accuracy']:.4f}" + ) + + # Evaluation loop + total_metrics = {'loss': 0.0, 'accuracy': 0.0} + for batch_idx in range(num_test_batches): + start_idx = batch_idx * batch_size + end_idx = (batch_idx + 1) * batch_size + batch_data = test_data_jax[start_idx:end_idx] + batch_labels = test_labels_jax[start_idx:end_idx] + metrics = eval_step(state, batch_data, batch_labels) + total_metrics['loss'] += metrics['loss'] + total_metrics['accuracy'] += metrics['accuracy'] + + avg_loss = total_metrics['loss'] / num_test_batches + avg_accuracy = total_metrics['accuracy'] / num_test_batches + if epoch % 8 == 0: + print( + f"Epoch {epoch}, " + f"Test Loss: {avg_loss:.4f}, " + f"Test Accuracy: {avg_accuracy:.4f}" + ) + + # Early stopping logic + if avg_accuracy > best_accuracy + min_delta: + best_accuracy = avg_accuracy + patience_counter = 0 + best_state = state + else: + patience_counter += 1 + + if patience_counter >= patience: + print(f"Early stopping triggered at epoch {epoch}") + break + + # If training completed without early stopping, use the final state + if best_state is None: + best_state = state + + # Final evaluation using the best state + total_metrics = {'loss': 0.0, 'accuracy': 0.0} + for batch_idx in range(num_test_batches): + start_idx = batch_idx * batch_size + end_idx = (batch_idx + 1) * batch_size + batch_data = test_data_jax[start_idx:end_idx] + batch_labels = test_labels_jax[start_idx:end_idx] + metrics = eval_step(best_state, batch_data, batch_labels) + total_metrics['loss'] += metrics['loss'] + total_metrics['accuracy'] += metrics['accuracy'] + + final_avg_loss = total_metrics['loss'] / num_test_batches + final_avg_accuracy = total_metrics['accuracy'] / num_test_batches + print( + f"Final Test Loss: {final_avg_loss:.4f}, " + f"Final Test Accuracy: {final_avg_accuracy:.4f}" + ) + + return { + 'final_test_loss': final_avg_loss, + 'final_test_accuracy': final_avg_accuracy + } + + +def pca(X, n_components): + # Center the data + X_centered = X - jnp.mean(X, axis=0) + + # Compute the covariance matrix + cov_matrix = jnp.cov(X_centered.T) + + # Compute eigenvalues 
and eigenvectors + eigenvalues, eigenvectors = jnp.linalg.eigh(cov_matrix) + + # Sort eigenvectors by descending eigenvalues + idx = jnp.argsort(eigenvalues)[::-1] + eigenvectors = eigenvectors[:, idx] + + # Select top n_components eigenvectors + components = eigenvectors[:, :n_components] + + # Project the data onto the new subspace + X_pca = jnp.dot(X_centered, components) + + return X_pca + +# Perform PCA +train_data_pca = pca(train_data_jax, n_components) +dataset = Data(train_data_pca) + +results = [] + +#Set up different solvers +# Set up kernel using median heuristic +num_samples_length_scale = min(300, 1000) +random_seed = 45 +generator = np.random.default_rng(random_seed) +idx = generator.choice(300, num_samples_length_scale, replace=False) +length_scale = median_heuristic(train_data_pca[idx]) +kernel = SquaredExponentialKernel(length_scale=length_scale) + +# Generate small dataset for ScoreMatching for SteinKernel +indices = jax.random.choice(key, train_data_pca.shape[0], shape=(1000,), replace=False) +small_dataset = train_data_pca[indices] + +def _get_herding_solver(coreset_size): + herding_solver = KernelHerding(coreset_size, kernel, block_size=64) + + return 'KernelHerding', MapReduce(herding_solver, leaf_size= 2 * coreset_size) + +def _get_stein_solver(coreset_size): + score_function = KernelDensityMatching(length_scale=length_scale).match(small_dataset) + stein_kernel = SteinKernel(kernel, score_function) + stein_solver = SteinThinning(coreset_size=coreset_size, kernel=stein_kernel, block_size=64) + return 'SteinThinning', MapReduce(stein_solver, leaf_size= 2 * coreset_size) + +def _get_random_solver(coreset_size): + random_solver = RandomSample(coreset_size, key) + return "RandomSample", random_solver + +def _get_rp_solver(coreset_size): + rp_solver = RPCholesky(coreset_size=coreset_size, kernel=kernel, random_key=key) + return "RPCholesky", rp_solver + +getters = [_get_stein_solver, _get_rp_solver, _get_random_solver, _get_herding_solver] +for getter in getters: + for size in [25, 26]: + name, solver = getter(size) + subset, _ = solver.reduce(Data(small_dataset)) + print(name, subset) + + indices = subset.nodes.data + + data = train_data_jax[indices] + targets = train_targets_jax[indices] + + if size <= 100: + batch_size = 8 + else: + batch_size = 64 + + model = MLP(hidden_size=64) + + result = train_and_evaluate(data, targets, test_data_jax, test_targets_jax, model, key, + epochs=100, batch_size=batch_size) + print(result) + results.append((name, size, result['final_test_accuracy'])) + +print('results', results) + + + diff --git a/coreax/solvers/composite.py b/coreax/solvers/composite.py index 90487e03..b77a3089 100644 --- a/coreax/solvers/composite.py +++ b/coreax/solvers/composite.py @@ -16,6 +16,7 @@ import math import warnings +from types import NoneType from typing import Generic, Optional, TypeVar, Union import equinox as eqx @@ -23,18 +24,20 @@ import jax.numpy as jnp import jax.tree_util as jtu import numpy as np +from jax import Array from sklearn.neighbors import BallTree, KDTree from typing_extensions import TypeAlias, override from coreax.coreset import Coreset, Coresubset from coreax.data import Data from coreax.solvers.base import ExplicitSizeSolver, PaddingInvariantSolver, Solver -from coreax.util import tree_zero_pad_leading_axis +from coreax.util import ArrayLike, tree_zero_pad_leading_axis BinaryTree: TypeAlias = Union[KDTree, BallTree] _Data = TypeVar("_Data", bound=Data) _Coreset = TypeVar("_Coreset", Coreset, Coresubset) _State = TypeVar("_State") 
+_Indices = TypeVar("_Indices", ArrayLike, NoneType) class CompositeSolver( @@ -126,39 +129,37 @@ def reduce( del solver_state def _reduce_coreset( - data: _Data, _indices=None - ) -> (tuple)[_Coreset, _State, _Data]: + data: _Data, _indices: Optional[_Indices] = None + ) -> tuple[_Coreset, _State, _Indices]: if len(data) <= self.leaf_size: coreset, state = self.base_solver.reduce(data) if _indices is not None: _indices = _indices[coreset.nodes.data] return coreset, state, _indices - def wrapper(row: _Data) -> tuple[_Data, _Data]: + def wrapper(partition: _Data) -> tuple[_Data, Array]: """ - Apply the reduce method of the base solver on a row. + Apply the `reduce` method of the base solver on a partition. - It is a wrapper to process a single partition (row) - of the result of _jit_tree that works with the vmap + This is a wrapper for `reduce()` for processing a single partition. + The data is partitioned with `_jit_tree()`. + The reduction is performed on each partition via `v`map()`. """ - x, _ = self.base_solver.reduce(row) + x, _ = self.base_solver.reduce(partition) return x.coreset, x.nodes.data def get_indices(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: return a[b] - # First partition the data partitioned_dataset, partitioned_indices = _jit_tree( data, self.leaf_size, self.tree_type ) - # Then apply base solver to each partition and - # keep track of indices with respect to partitions + # Reduce each partition and get indices from each coreset_ensemble, ensemble_indices = jax.vmap(wrapper)(partitioned_dataset) - # Calculate the indices with respect to the data (_reduce_coreset) + # Calculate the indices with respect to the original data concatenated_indices = jax.vmap(get_indices)( partitioned_indices, ensemble_indices ) - # flatten the indices concatenated_indices = jnp.ravel(concatenated_indices) _coreset = jtu.tree_map(jnp.concatenate, coreset_ensemble) @@ -168,12 +169,9 @@ def get_indices(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: final_indices = concatenated_indices return _reduce_coreset(_coreset, final_indices) - (coreset_wrong_pre_coreset_data, output_solver_state, _indices) = ( - _reduce_coreset(dataset) - ) - coreset = eqx.tree_at( - lambda x: x.pre_coreset_data, coreset_wrong_pre_coreset_data, dataset - ) + (coreset, output_solver_state, _indices) = _reduce_coreset(dataset) + # Replace the pre-coreset data by the original dataset + coreset = eqx.tree_at(lambda x: x.pre_coreset_data, coreset, dataset) if _indices is not None: if isinstance(coreset, Coresubset): coreset = eqx.tree_at(lambda x: x.nodes.data, coreset, _indices) @@ -224,10 +222,3 @@ def _binary_tree(_input_data: Data) -> np.ndarray: indices = jax.pure_callback(_binary_tree, result_shape, padded_dataset) return dataset[indices], jnp.arange(len(dataset))[indices] - - -# If you are dividing a dataset of size 20 to three partitions of size 8, -# then your indices will exceed the len(dataset) - 1, -# at the moment the last entry in dataset is chosen for every index -# that is out of bounds, it might be worth considering -# if choosing a datapoint randomly makes more sense From e5f5122d92483e96cc343b56b94556c880a2bbb7 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Thu, 3 Oct 2024 14:36:06 +0100 Subject: [PATCH 09/17] Removed changed NoneType to None (NoneType is not compatible with python 3.9) --- coreax/solvers/composite.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/coreax/solvers/composite.py b/coreax/solvers/composite.py index 
b77a3089..a51b1d69 100644 --- a/coreax/solvers/composite.py +++ b/coreax/solvers/composite.py @@ -16,7 +16,6 @@ import math import warnings -from types import NoneType from typing import Generic, Optional, TypeVar, Union import equinox as eqx @@ -37,7 +36,7 @@ _Data = TypeVar("_Data", bound=Data) _Coreset = TypeVar("_Coreset", Coreset, Coresubset) _State = TypeVar("_State") -_Indices = TypeVar("_Indices", ArrayLike, NoneType) +_Indices = TypeVar("_Indices", ArrayLike, None) class CompositeSolver( From da827a7ad0b2950d4629b36f662dcfe5fef92165 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Thu, 3 Oct 2024 14:45:43 +0100 Subject: [PATCH 10/17] removed a folder that wasn't supposed to be added --- benchmark/mnist_benchmark.py | 385 ----------------------------------- 1 file changed, 385 deletions(-) delete mode 100644 benchmark/mnist_benchmark.py diff --git a/benchmark/mnist_benchmark.py b/benchmark/mnist_benchmark.py deleted file mode 100644 index d6e892e7..00000000 --- a/benchmark/mnist_benchmark.py +++ /dev/null @@ -1,385 +0,0 @@ -import time -from typing import Any, Dict, Tuple - -import numpy as np -import torch -import torchvision -import torchvision.transforms as transforms -from torch.utils.data import Dataset, DataLoader - -import jax -import jax.numpy as jnp -import optax -from flax import linen as nn -from flax.training import train_state - -from coreax import Data - -from coreax.kernels import ( - SquaredExponentialKernel, - PCIMQKernel, - SteinKernel, - median_heuristic -) - -from coreax.solvers import ( - KernelHerding, - RandomSample, - RPCholesky, - SteinThinning, - GreedyKernelPoints, - MapReduce, - PaddingInvariantSolver -) - -from coreax.metrics import ( - MMD, - KSD -) - -from coreax.score_matching import KernelDensityMatching -from matplotlib import pyplot as plt - - -key = jax.random.PRNGKey(0) # for reproducibility -n_components = 16 #for PCA - -transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))]) -train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) -test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) -# Convert PyTorch dataset to JAX arrays -def convert_to_jax_arrays(dataset: Dataset) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - Convert a PyTorch dataset to JAX arrays. - - This function takes a PyTorch dataset, loads all the data at once using a DataLoader, - and converts it to JAX arrays. It's designed to work with datasets that can fit into memory. - - Args: - dataset (Dataset): A PyTorch dataset to be converted. - - Returns: - Tuple[jnp.ndarray, jnp.ndarray]: A tuple containing two JAX arrays: - - The first array contains the data. - - The second array contains the targets (labels). 
- """ - data_loader = DataLoader(dataset, batch_size=len(dataset)) - data, targets = next(iter(data_loader)) # Load all data at once - data_jax = jnp.array(data.numpy()) # Convert to NumPy first, then JAX array - targets_jax = jnp.array(targets.numpy()) - return data_jax, targets_jax - -# Usage remains the same -train_data_jax, train_targets_jax = convert_to_jax_arrays(train_dataset) -test_data_jax, test_targets_jax = convert_to_jax_arrays(test_dataset) - -def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray: - """Compute cross-entropy loss.""" - return jnp.mean(optax.softmax_cross_entropy( - logits, - jax.nn.one_hot(labels, num_classes=10) - )) - - -def compute_metrics( - logits: jnp.ndarray, - labels: jnp.ndarray -) -> Dict[str, jnp.ndarray]: - """Compute loss and accuracy metrics.""" - loss = cross_entropy_loss(logits, labels) - accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) - return {'loss': loss, 'accuracy': accuracy} - - -class MLP(nn.Module): - """Multi-layer perceptron with optional batch normalization and dropout.""" - - hidden_size: int - output_size: int = 10 - use_batchnorm: bool = True - dropout_rate: float = 0.5 - - @nn.compact - def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray: - """Forward pass of the MLP.""" - x = nn.Dense(self.hidden_size)(x) - if self.use_batchnorm: - x = nn.BatchNorm(use_running_average=not training)(x) - x = nn.relu(x) - x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x) - x = nn.Dense(self.output_size)(x) - return x - - -class TrainState(train_state.TrainState): - """Custom train state with batch statistics and dropout RNG.""" - - batch_stats: Any = None - dropout_rng: jnp.ndarray = None - - -def create_train_state( - rng: jnp.ndarray, - model: nn.Module, - learning_rate: float, - weight_decay: float -) -> TrainState: - """Create and initialize the train state.""" - dropout_rng, params_rng = jax.random.split(rng) - params = model.init( - {'params': params_rng, 'dropout': dropout_rng}, - jnp.ones([1, 784]), - training=False - ) - tx = optax.adamw(learning_rate, weight_decay=weight_decay) - return TrainState.create( - apply_fn=model.apply, - params=params['params'], - tx=tx, - batch_stats=params['batch_stats'], - dropout_rng=dropout_rng - ) - - -@jax.jit -def train_step( - state: TrainState, - batch_data: jnp.ndarray, - batch_labels: jnp.ndarray -) -> Tuple[TrainState, Dict[str, jnp.ndarray]]: - """Perform a single training step.""" - dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) - - def loss_fn(params): - variables = {'params': params, 'batch_stats': state.batch_stats} - logits, new_model_state = state.apply_fn( - variables, - batch_data, - training=True, - mutable=['batch_stats'], - rngs={'dropout': dropout_rng} - ) - loss = cross_entropy_loss(logits, batch_labels) - return loss, (logits, new_model_state) - - grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (loss, (logits, new_model_state)), grads = grad_fn(state.params) - state = state.apply_gradients( - grads=grads, - batch_stats=new_model_state['batch_stats'], - dropout_rng=new_dropout_rng - ) - metrics = compute_metrics(logits, batch_labels) - return state, metrics - - -@jax.jit -def eval_step( - state: TrainState, - batch_data: jnp.ndarray, - batch_labels: jnp.ndarray -) -> Dict[str, jnp.ndarray]: - """Perform a single evaluation step.""" - variables = {'params': state.params, 'batch_stats': state.batch_stats} - logits = state.apply_fn( - variables, - batch_data, - training=False, - 
rngs={'dropout': state.dropout_rng} - ) - return compute_metrics(logits, batch_labels) - - -def train_and_evaluate( - train_data_jax: jnp.ndarray, - train_labels_jax: jnp.ndarray, - test_data_jax: jnp.ndarray, - test_labels_jax: jnp.ndarray, - model: nn.Module, - rng: jnp.ndarray, - epochs: int = 10, - batch_size: int = 64, - learning_rate: float = 1e-3, - weight_decay: float = 1e-5, - patience: int = 5, - min_delta: float = 0.001 -) -> Dict[str, float]: - """Train and evaluate the model with early stopping.""" - state = create_train_state(rng, model, learning_rate, weight_decay) - num_train_batches = train_data_jax.shape[0] // batch_size - num_test_batches = test_data_jax.shape[0] // batch_size - - best_accuracy = 0 - patience_counter = 0 - best_state = None - - for epoch in range(epochs): - # Shuffle data - rng, input_rng = jax.random.split(rng) - perm = jax.random.permutation(input_rng, train_data_jax.shape[0]) - train_data_shuffled = train_data_jax[perm] - train_labels_shuffled = train_labels_jax[perm] - - # Training loop - for batch_idx in range(num_train_batches): - start_idx = batch_idx * batch_size - end_idx = (batch_idx + 1) * batch_size - batch_data = train_data_shuffled[start_idx:end_idx] - batch_labels = train_labels_shuffled[start_idx:end_idx] - state, metrics = train_step(state, batch_data, batch_labels) - - if epoch % 8 == 0: - print( - f"Epoch {epoch}, " - f"Loss: {metrics['loss']:.4f}, " - f"Accuracy: {metrics['accuracy']:.4f}" - ) - - # Evaluation loop - total_metrics = {'loss': 0.0, 'accuracy': 0.0} - for batch_idx in range(num_test_batches): - start_idx = batch_idx * batch_size - end_idx = (batch_idx + 1) * batch_size - batch_data = test_data_jax[start_idx:end_idx] - batch_labels = test_labels_jax[start_idx:end_idx] - metrics = eval_step(state, batch_data, batch_labels) - total_metrics['loss'] += metrics['loss'] - total_metrics['accuracy'] += metrics['accuracy'] - - avg_loss = total_metrics['loss'] / num_test_batches - avg_accuracy = total_metrics['accuracy'] / num_test_batches - if epoch % 8 == 0: - print( - f"Epoch {epoch}, " - f"Test Loss: {avg_loss:.4f}, " - f"Test Accuracy: {avg_accuracy:.4f}" - ) - - # Early stopping logic - if avg_accuracy > best_accuracy + min_delta: - best_accuracy = avg_accuracy - patience_counter = 0 - best_state = state - else: - patience_counter += 1 - - if patience_counter >= patience: - print(f"Early stopping triggered at epoch {epoch}") - break - - # If training completed without early stopping, use the final state - if best_state is None: - best_state = state - - # Final evaluation using the best state - total_metrics = {'loss': 0.0, 'accuracy': 0.0} - for batch_idx in range(num_test_batches): - start_idx = batch_idx * batch_size - end_idx = (batch_idx + 1) * batch_size - batch_data = test_data_jax[start_idx:end_idx] - batch_labels = test_labels_jax[start_idx:end_idx] - metrics = eval_step(best_state, batch_data, batch_labels) - total_metrics['loss'] += metrics['loss'] - total_metrics['accuracy'] += metrics['accuracy'] - - final_avg_loss = total_metrics['loss'] / num_test_batches - final_avg_accuracy = total_metrics['accuracy'] / num_test_batches - print( - f"Final Test Loss: {final_avg_loss:.4f}, " - f"Final Test Accuracy: {final_avg_accuracy:.4f}" - ) - - return { - 'final_test_loss': final_avg_loss, - 'final_test_accuracy': final_avg_accuracy - } - - -def pca(X, n_components): - # Center the data - X_centered = X - jnp.mean(X, axis=0) - - # Compute the covariance matrix - cov_matrix = jnp.cov(X_centered.T) - - # Compute eigenvalues 
and eigenvectors - eigenvalues, eigenvectors = jnp.linalg.eigh(cov_matrix) - - # Sort eigenvectors by descending eigenvalues - idx = jnp.argsort(eigenvalues)[::-1] - eigenvectors = eigenvectors[:, idx] - - # Select top n_components eigenvectors - components = eigenvectors[:, :n_components] - - # Project the data onto the new subspace - X_pca = jnp.dot(X_centered, components) - - return X_pca - -# Perform PCA -train_data_pca = pca(train_data_jax, n_components) -dataset = Data(train_data_pca) - -results = [] - -#Set up different solvers -# Set up kernel using median heuristic -num_samples_length_scale = min(300, 1000) -random_seed = 45 -generator = np.random.default_rng(random_seed) -idx = generator.choice(300, num_samples_length_scale, replace=False) -length_scale = median_heuristic(train_data_pca[idx]) -kernel = SquaredExponentialKernel(length_scale=length_scale) - -# Generate small dataset for ScoreMatching for SteinKernel -indices = jax.random.choice(key, train_data_pca.shape[0], shape=(1000,), replace=False) -small_dataset = train_data_pca[indices] - -def _get_herding_solver(coreset_size): - herding_solver = KernelHerding(coreset_size, kernel, block_size=64) - - return 'KernelHerding', MapReduce(herding_solver, leaf_size= 2 * coreset_size) - -def _get_stein_solver(coreset_size): - score_function = KernelDensityMatching(length_scale=length_scale).match(small_dataset) - stein_kernel = SteinKernel(kernel, score_function) - stein_solver = SteinThinning(coreset_size=coreset_size, kernel=stein_kernel, block_size=64) - return 'SteinThinning', MapReduce(stein_solver, leaf_size= 2 * coreset_size) - -def _get_random_solver(coreset_size): - random_solver = RandomSample(coreset_size, key) - return "RandomSample", random_solver - -def _get_rp_solver(coreset_size): - rp_solver = RPCholesky(coreset_size=coreset_size, kernel=kernel, random_key=key) - return "RPCholesky", rp_solver - -getters = [_get_stein_solver, _get_rp_solver, _get_random_solver, _get_herding_solver] -for getter in getters: - for size in [25, 26]: - name, solver = getter(size) - subset, _ = solver.reduce(Data(small_dataset)) - print(name, subset) - - indices = subset.nodes.data - - data = train_data_jax[indices] - targets = train_targets_jax[indices] - - if size <= 100: - batch_size = 8 - else: - batch_size = 64 - - model = MLP(hidden_size=64) - - result = train_and_evaluate(data, targets, test_data_jax, test_targets_jax, model, key, - epochs=100, batch_size=batch_size) - print(result) - results.append((name, size, result['final_test_accuracy'])) - -print('results', results) - - - From c187818fd534e945d0642152814594740a2932b8 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Fri, 4 Oct 2024 10:21:25 +0100 Subject: [PATCH 11/17] On composite.py changed return statement of _jit_tree from dataset[index] to padded_dataset[index] removed unnessary comments and changed variable names to be more descriptive --- coreax/solvers/composite.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/coreax/solvers/composite.py b/coreax/solvers/composite.py index a51b1d69..0924dc03 100644 --- a/coreax/solvers/composite.py +++ b/coreax/solvers/composite.py @@ -147,16 +147,13 @@ def wrapper(partition: _Data) -> tuple[_Data, Array]: x, _ = self.base_solver.reduce(partition) return x.coreset, x.nodes.data - def get_indices(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: - return a[b] - partitioned_dataset, partitioned_indices = _jit_tree( data, self.leaf_size, self.tree_type ) # Reduce 
each partition and get indices from each coreset_ensemble, ensemble_indices = jax.vmap(wrapper)(partitioned_dataset) # Calculate the indices with respect to the original data - concatenated_indices = jax.vmap(get_indices)( + concatenated_indices = jax.vmap(lambda x, index: x[index])( partitioned_indices, ensemble_indices ) concatenated_indices = jnp.ravel(concatenated_indices) @@ -169,7 +166,7 @@ def get_indices(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: return _reduce_coreset(_coreset, final_indices) (coreset, output_solver_state, _indices) = _reduce_coreset(dataset) - # Replace the pre-coreset data by the original dataset + # Correct the pre-coreset data and the indices coreset = eqx.tree_at(lambda x: x.pre_coreset_data, coreset, dataset) if _indices is not None: if isinstance(coreset, Coresubset): @@ -177,7 +174,9 @@ def get_indices(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: return coreset, output_solver_state -def _jit_tree(dataset: _Data, leaf_size: int, tree_type: type[BinaryTree]) -> _Data: +def _jit_tree( + dataset: _Data, leaf_size: int, tree_type: type[BinaryTree] +) -> tuple[_Data, _Indices]: """ Return JIT compatible BinaryTree partitioning of 'dataset'. @@ -220,4 +219,4 @@ def _binary_tree(_input_data: Data) -> np.ndarray: return node_indices.reshape(n_leaves, -1).astype(np.int32) indices = jax.pure_callback(_binary_tree, result_shape, padded_dataset) - return dataset[indices], jnp.arange(len(dataset))[indices] + return padded_dataset[indices], jnp.arange(len(dataset))[indices] From a4f762cb892ad06bb86ca8d51985962f1d8bacc0 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Tue, 8 Oct 2024 12:33:01 +0100 Subject: [PATCH 12/17] Added MapReduce bugfix to `CHANGELOG.md` --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d03a183b..6622239b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,7 +23,7 @@ disable tqdm progress bar terminal output. Defaults to disabled (`False`). ### Fixed - +- `MapReduce` in `coreax.solvers.composite.py` now keeps track of the indices. ### Changed From bbae67deac9ab4410725b6421459a7f298a52709 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Tue, 8 Oct 2024 12:34:38 +0100 Subject: [PATCH 13/17] Added double backticks on comments in `composite.py` when referring to ``vmap()`` --- coreax/solvers/composite.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/coreax/solvers/composite.py b/coreax/solvers/composite.py index 0924dc03..931e0ee0 100644 --- a/coreax/solvers/composite.py +++ b/coreax/solvers/composite.py @@ -142,7 +142,7 @@ def wrapper(partition: _Data) -> tuple[_Data, Array]: This is a wrapper for `reduce()` for processing a single partition. The data is partitioned with `_jit_tree()`. - The reduction is performed on each partition via `v`map()`. + The reduction is performed on each partition via ``vmap()``. 
""" x, _ = self.base_solver.reduce(partition) return x.coreset, x.nodes.data @@ -165,9 +165,9 @@ def wrapper(partition: _Data) -> tuple[_Data, Array]: final_indices = concatenated_indices return _reduce_coreset(_coreset, final_indices) - (coreset, output_solver_state, _indices) = _reduce_coreset(dataset) + (pre_coreset, output_solver_state, _indices) = _reduce_coreset(dataset) # Correct the pre-coreset data and the indices - coreset = eqx.tree_at(lambda x: x.pre_coreset_data, coreset, dataset) + coreset = eqx.tree_at(lambda x: x.pre_coreset_data, pre_coreset, dataset) if _indices is not None: if isinstance(coreset, Coresubset): coreset = eqx.tree_at(lambda x: x.nodes.data, coreset, _indices) From cf537eddf1d562a95ff67a2ea2216a6f84a13eb8 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Tue, 8 Oct 2024 12:35:50 +0100 Subject: [PATCH 14/17] Removed a redundant comment on a test on `TestMapReduce` class --- tests/unit/test_solvers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py index b0ef1894..e3408c72 100644 --- a/tests/unit/test_solvers.py +++ b/tests/unit/test_solvers.py @@ -867,7 +867,6 @@ def test_map_reduce_diverse_selection(self): coreset, _ = solver.reduce(Data(dataset)) selected_indices = coreset.nodes.data - # Check if there are indices beyond the first few assert jnp.any( selected_indices >= coreset_size ), "MapReduce should select points beyond the first few" From c894e07fcf51f2f97aaca342ca0c9b001d27529b Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Wed, 9 Oct 2024 16:27:29 +0100 Subject: [PATCH 15/17] Added analytic test for `MapReduce` --- tests/unit/test_solvers.py | 132 +++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py index e3408c72..8a85e619 100644 --- a/tests/unit/test_solvers.py +++ b/tests/unit/test_solvers.py @@ -877,6 +877,138 @@ def test_map_reduce_diverse_selection(self): len(partitions_represented) > 1 ), "MapReduce should select points from multiple partitions" + def test_map_reduce_analytic(self): + r""" + Test ``MapReduce`` on an analytical example, enforcing a unique coreset. + + In this example, We start with the original dataset + :math:`[10, 20, 30, 210, 40, 60, 180, 90, 150, 70, 120, + 200, 50, 140, 80, 170, 100, 190, 110, 160, 130]`. + + Suppose we want a subset size of 3, and we want maximum leaf size of 6. + + We can see that we have a dataset of size 21. The partitioning scheme + only allows for :math:`n` partitions where :math:`n` is a power of 2. + Therefore, we can partition into: + + 1. 1 partition of size 21 + 2. 2 partitions of size :math:`\lceil 10.5 \rceil = 11` each (with one padded 0) + 3. 4 partitions of size :math:`\lceil 5.25 \rceil = 6` each (with 3 padded 0's) + 4. 8 partitions of size :math:`\lceil 2.625 \rceil = 3` each (with 3 padded 0's) + + Since we set the maximum leaf size :math:`m = 6`, we choose the largest + partition size that is less than or equal to 6. Thus, we have 4 partitions + each of size 6. + + We first pad the dataset by adding 3 zeros, then we arrange the + data in ascending order (Make sure to ask if this is another bug, + I don't know how exactly binary_tree manages to reorder my dataset every time) + This results in + the following 4 partitions (see how data is in ascending order): + + 1. :math:`[0, 0, 0, 10, 20, 30]` + 2. :math:`[40, 50, 60, 70, 80, 90]` + 3. 
:math:`[100, 110, 120, 130, 140, 150]` + 4. :math:`[160, 170, 180, 190, 200, 210]` + + Now we want to reduce each partition with our ``interleaved_base_solver`` + which is designed to choose first, last, second, second-last, third, + third-last elements etc. until the coreset of correct size is formed. + Hence, we obtain: + + 1. :math:`[0, 30, 0]` + 2. :math:`[40, 90, 50]` + 3. :math:`[100, 150, 110]` + 4. :math:`[160, 210, 170]` + + Concatenating we obtain + :math:`[0, 30, 0, 40, 90, 50, 100, 150, 110, 160, 210, 170]`. + Now we repeat the same process, we check how many partitions + (has to be power of 2) we want to divide this new data of size 12 into, + our new options for partitioning are: + + 1. 1 partition of size 12 + 2. 2 partitions of size 6 + 3. 4 partitions of size 3 + 4. 8 partitions of size 1.5 (rounded up to 2) + + Given our maximum leaf size :math:`m = 6`, we choose the largest partition size + that is less than or equal to 6. Therefore, we select 2 partitions of size 6. + This time no padding is necessary. The two partitions resulting from this step + are (note that it is again in ascending order): + + 1. :math:`[0, 0, 30, 40, 50, 90]` + 2. :math:`[100, 110, 150, 160, 170, 210]` + + Applying our ``interleaved_base_solver`` with `coreset_size` 3 on + each partition, we obtain: + + 1. :math:`[0, 90, 0]` + 2. :math:`[100, 210, 110]` + + Now, we concatenate the two subsets and repeat the process to + obtain only one partition: + + 1. Concatenated subset: :math:`[0, 90, 0, 100, 210, 110]` + + Note that the size of the dataset is 6 so no partitioning is necessary! + + Applying the ``interleaved_base_solver`` one last time we obtain + the final coreset of: + :math:`[0, 110, 90]`. This is what we will test in this test + """ + interleaved_base_solver = MagicMock(_ExplicitPaddingInvariantSolver) + interleaved_base_solver.coreset_size = 3 + + def interleaved_mock_reduce( + dataset: Data, solver_state: None = None + ) -> tuple[Coreset[Data], None]: + half_size = interleaved_base_solver.coreset_size // 2 + indices = jnp.arange(interleaved_base_solver.coreset_size) + forward_indices = indices[:half_size] + backward_indices = -(indices[:half_size] + 1) + interleaved_indices = jnp.stack( + [forward_indices, backward_indices], axis=1 + ).ravel() + + if interleaved_base_solver.coreset_size % 2 != 0: + interleaved_indices = jnp.append(interleaved_indices, half_size) + return Coreset(dataset[interleaved_indices], dataset), solver_state + + interleaved_base_solver.reduce = interleaved_mock_reduce + + original_data = Data( + [ + 10, + 20, + 30, + 210, + 40, + 60, + 180, + 90, + 150, + 70, + 120, + 200, + 50, + 140, + 80, + 170, + 100, + 190, + 110, + 160, + 130, + ] + ) + expected_coreset_data = Data([0, 110, 90]) + + coreset, _ = MapReduce(base_solver=interleaved_base_solver, leaf_size=6).reduce( + original_data + ) + assert eqx.tree_equal(coreset.coreset.data == expected_coreset_data.data) + class TestCaratheodoryRecombination(RecombinationSolverTest): """Tests for :class:`coreax.solvers.recombination.CaratheodoryRecombination`.""" From a5fb1bed11036bf051c6541560d1af652709fcab Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Mon, 14 Oct 2024 11:47:17 +0100 Subject: [PATCH 16/17] docs: make suggested changes in the docstring Refs: #799 --- tests/unit/test_solvers.py | 73 +++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 36 deletions(-) diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py index 8a85e619..86453c19 
100644 --- a/tests/unit/test_solvers.py +++ b/tests/unit/test_solvers.py @@ -881,7 +881,7 @@ def test_map_reduce_analytic(self): r""" Test ``MapReduce`` on an analytical example, enforcing a unique coreset. - In this example, We start with the original dataset + In this example, we start with the original dataset :math:`[10, 20, 30, 210, 40, 60, 180, 90, 150, 70, 120, 200, 50, 140, 80, 170, 100, 190, 110, 160, 130]`. @@ -900,11 +900,8 @@ def test_map_reduce_analytic(self): partition size that is less than or equal to 6. Thus, we have 4 partitions each of size 6. - We first pad the dataset by adding 3 zeros, then we arrange the - data in ascending order (Make sure to ask if this is another bug, - I don't know how exactly binary_tree manages to reorder my dataset every time) - This results in - the following 4 partitions (see how data is in ascending order): + This results in the following 4 partitions (see how + data is in ascending order): 1. :math:`[0, 0, 0, 10, 20, 30]` 2. :math:`[40, 50, 60, 70, 80, 90]` @@ -923,9 +920,10 @@ def test_map_reduce_analytic(self): Concatenating we obtain :math:`[0, 30, 0, 40, 90, 50, 100, 150, 110, 160, 210, 170]`. - Now we repeat the same process, we check how many partitions - (has to be power of 2) we want to divide this new data of size 12 into, - our new options for partitioning are: + We repeat the process, checking how many partitions we want + to divide this intermediate dataset (of size 12) into. + Recall, this number of partitions must be a power of 2. + Our options are:: 1. 1 partition of size 12 2. 2 partitions of size 6 @@ -951,11 +949,12 @@ def test_map_reduce_analytic(self): 1. Concatenated subset: :math:`[0, 90, 0, 100, 210, 110]` - Note that the size of the dataset is 6 so no partitioning is necessary! + Note that the size of the dataset is 6, + therefore, no more partitioning is necessary. Applying the ``interleaved_base_solver`` one last time we obtain - the final coreset of: - :math:`[0, 110, 90]`. This is what we will test in this test + the final coreset: + :math:`[0, 110, 90]`. """ interleaved_base_solver = MagicMock(_ExplicitPaddingInvariantSolver) interleaved_base_solver.coreset_size = 3 @@ -978,31 +977,33 @@ def interleaved_mock_reduce( interleaved_base_solver.reduce = interleaved_mock_reduce original_data = Data( - [ - 10, - 20, - 30, - 210, - 40, - 60, - 180, - 90, - 150, - 70, - 120, - 200, - 50, - 140, - 80, - 170, - 100, - 190, - 110, - 160, - 130, - ] + jnp.array( + [ + 10, + 20, + 30, + 210, + 40, + 60, + 180, + 90, + 150, + 70, + 120, + 200, + 50, + 140, + 80, + 170, + 100, + 190, + 110, + 160, + 130, + ] + ) ) - expected_coreset_data = Data([0, 110, 90]) + expected_coreset_data = Data(jnp.array([0, 110, 90])) coreset, _ = MapReduce(base_solver=interleaved_base_solver, leaf_size=6).reduce( original_data From b48bda45427acac695a96da5fd775dd2cd9c8476 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Thu, 17 Oct 2024 11:05:13 +0100 Subject: [PATCH 17/17] docs: make suggested changes in the docstring Refs: #799 --- tests/unit/test_solvers.py | 43 ++++++++------------------------------ 1 file changed, 9 insertions(+), 34 deletions(-) diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py index 86453c19..78720280 100644 --- a/tests/unit/test_solvers.py +++ b/tests/unit/test_solvers.py @@ -920,10 +920,9 @@ def test_map_reduce_analytic(self): Concatenating we obtain :math:`[0, 30, 0, 40, 90, 50, 100, 150, 110, 160, 210, 170]`. 
- We repeat the process, checking how many partitions we want - to divide this intermediate dataset (of size 12) into. - Recall, this number of partitions must be a power of 2. - Our options are:: + We repeat the process, checking how many partitions we want to divide this + intermediate dataset (of size 12) into. Recall, this number of partitions must + be a power of 2. Our options are: 1. 1 partition of size 12 2. 2 partitions of size 6 @@ -952,9 +951,8 @@ def test_map_reduce_analytic(self): Note that the size of the dataset is 6, therefore, no more partitioning is necessary. - Applying the ``interleaved_base_solver`` one last time we obtain - the final coreset: - :math:`[0, 110, 90]`. + Applying ``interleaved_base_solver`` one last time we obtain the final coreset: + :math:`[0, 110, 90]`. """ interleaved_base_solver = MagicMock(_ExplicitPaddingInvariantSolver) interleaved_base_solver.coreset_size = 3 @@ -976,33 +974,10 @@ def interleaved_mock_reduce( interleaved_base_solver.reduce = interleaved_mock_reduce - original_data = Data( - jnp.array( - [ - 10, - 20, - 30, - 210, - 40, - 60, - 180, - 90, - 150, - 70, - 120, - 200, - 50, - 140, - 80, - 170, - 100, - 190, - 110, - 160, - 130, - ] - ) - ) + original_data = Data(jnp.array([ + 10, 20, 30, 210, 40, 60, 180, 90, 150, 70, 120, + 200, 50, 140, 80, 170, 100, 190, 110, 160, 130 + ])) expected_coreset_data = Data(jnp.array([0, 110, 90])) coreset, _ = MapReduce(base_solver=interleaved_base_solver, leaf_size=6).reduce(