Commit 09b6077
:docs: fixed docstrings
qh681248 committed Oct 24, 2024
1 parent 1df2237 commit 09b6077
Showing 4 changed files with 198 additions and 27 deletions.
benchmark/blobs_benchmark.py (27 changes: 17 additions & 10 deletions)

@@ -17,7 +17,7 @@
 The benchmarking process follows these steps:
 1. Generate a synthetic dataset of 1000 two-dimensional points using
-   `sklearn.datasets.make_blobs`.
+   :func:`sklearn.datasets.make_blobs`.
 2. Generate coresets of varying sizes: 10, 50, 100, and 200 points using different
    coreset algorithms.
 3. Compute two metrics to evaluate the coresets' quality:
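
Step 1 corresponds to the call made later in this file's main(); a minimal sketch, assuming only that the benchmark converts the NumPy output to a JAX array:

import jax.numpy as jnp
from sklearn.datasets import make_blobs

# 1000 two-dimensional points in 10 clusters (same call as in main() below)
x, *_ = make_blobs(n_samples=1000, n_features=2, centers=10, random_state=45)
x_jax = jnp.asarray(x)  # JAX array for the coreset solvers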
@@ -30,7 +30,7 @@

 import json
 import time
-from typing import Any, Tuple
+from typing import Any

 import jax
 import jax.numpy as jnp
@@ -71,7 +71,13 @@ def setup_kernel(x: np.ndarray) -> SquaredExponentialKernel:
 def setup_stein_kernel(
     sq_exp_kernel: SquaredExponentialKernel, dataset: Data
 ) -> SteinKernel:
-    """Set up Stein Kernel."""
+    """
+    Set up a Stein Kernel for Stein Thinning.
+
+    :param sq_exp_kernel: A SquaredExponential base kernel for the Stein Kernel.
+    :param dataset: Dataset for score matching.
+    :return: A SteinKernel object.
+    """
     sliced_score_matcher = SlicedScoreMatching(
         jax.random.PRNGKey(45),
         jax.random.rademacher,
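
A hypothetical end-to-end use of setup_kernel and setup_stein_kernel; wrapping the array as Data(...) is an assumption about coreax's Data API, not code from this file:

import jax.numpy as jnp

sq_exp_kernel = setup_kernel(x)  # SquaredExponentialKernel derived from the data
dataset = Data(jnp.asarray(x))   # assumed Data construction
stein_kernel = setup_stein_kernel(sq_exp_kernel, dataset)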
@@ -90,7 +96,7 @@ def setup_solvers(
     coreset_size: int,
     sq_exp_kernel: SquaredExponentialKernel,
     stein_kernel: SteinKernel,
-) -> list[Tuple[str, Any]]:
+) -> list[tuple[str, Any]]:
     """
     Set up and return a list of solver configurations for reducing a dataset.
@@ -175,7 +181,7 @@ def compute_solver_metrics(


 def compute_metrics(
-    solvers: list[Tuple[str, Any]],
+    solvers: list[tuple[str, Any]],
     dataset: Data,
     mmd_metric: MMD,
     ksd_metric: KSD,
@@ -205,11 +211,12 @@ def compute_metrics(

 def main() -> None:
     """
-    Perform a benchmark comparing different coreset algorithms on a synthetic dataset.
+    Benchmark different coreset algorithms on a synthetic dataset.

-    Generate a synthetic dataset using `sklearn.datasets.make_blobs`,set up various
-    solvers, generate coreset of different sizes, and compute performance metrics
-    (MMD and KSD) for each solver at different. Then, save the results into a JSON file.
+    Compare the performance of different coreset algorithms using a synthetic dataset,
+    generated using :func:`sklearn.datasets.make_blobs`. We set up various solvers,
+    generate coresets of multiple sizes, and compute performance metrics (MMD and KSD)
+    for each solver at each coreset size. Results are saved to a JSON file.
     """
     # Generate data
     x, *_ = make_blobs(n_samples=1000, n_features=2, centers=10, random_state=45)
@@ -243,7 +250,7 @@ def main() -> None:
         all_results[size] = results

     # Save results to JSON file
-    with open("coreset_comparison_results.json", "w", encoding="utf-8") as f:
+    with open("blobs_benchmark_results.json", "w", encoding="utf-8") as f:
         json.dump(all_results, f, indent=2)
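
The renamed output file can be read back for inspection; a minimal sketch using the new filename:

import json

with open("blobs_benchmark_results.json", encoding="utf-8") as f:
    all_results = json.load(f)  # keyed by coreset size, then by solver name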


benchmark/blobs_benchmark_visualiser.py (24 changes: 24 additions & 0 deletions)

@@ -22,6 +22,30 @@ def print_metrics_table(data: dict, sample_size: str) -> None:
"""
Print a table for the given sample size with methods as rows and metrics as columns.
:param data: A dictionary where keys are sample sizes (as strings) and values are
dictionaries containing the metrics for each algorithm. Each method's
dictionary contains the following keys:
- 'unweighted_mmd': Unweighted maximum mean discrepancy (MMD).
- 'unweighted_ksd': Unweighted kernel Stein discrepancy (KSD).
- 'weighted_mmd': Weighted maximum mean discrepancy (MMD).
- 'weighted_ksd': Weighted kernel Stein discrepancy (KSD).
- 'time': Time taken to compute the coreset and metrics (in seconds).
Example format:
{
'100': {
'KernelHerding': {
'unweighted_mmd': 0.12345678,
'unweighted_ksd': 0.23456789,
'weighted_mmd': 0.34567890,
'weighted_ksd': 0.45678901,
'time': 0.123
},
'Algorithm B': { ... },
...
},
'1000': { ... },
...
}
:param sample_size: The sample size for which to print the table.
"""
# Define header
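
A hypothetical call, combining the documented structure with the JSON file produced by blobs_benchmark.py:

import json

with open("blobs_benchmark_results.json", encoding="utf-8") as f:
    data = json.load(f)

print_metrics_table(data, sample_size="100")  # one table for the 100-point coresets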
benchmark/mnist_benchmark.py (172 changes: 156 additions & 16 deletions)

@@ -78,14 +78,26 @@ def convert_to_jax_arrays(pytorch_data: Dataset) -> tuple[jnp.ndarray, jnp.ndarray]:


 def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
-    """Compute cross-entropy loss."""
+    """
+    Compute cross-entropy loss.
+
+    :param logits: Logits predicted by the model.
+    :param labels: Ground truth labels as an array of integers.
+    :return: The cross-entropy loss.
+    """
     return jnp.mean(
         optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10))
     )


 def compute_metrics(logits: jnp.ndarray, labels: jnp.ndarray) -> dict[str, jnp.ndarray]:
-    """Compute loss and accuracy metrics."""
+    """
+    Compute loss and accuracy metrics.
+
+    :param logits: Logits predicted by the model.
+    :param labels: Ground truth labels as an array of integers.
+    :return: A dictionary containing 'loss' and 'accuracy' as keys.
+    """
     loss = cross_entropy_loss(logits, labels)
     _accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
     return {"loss": loss, "accuracy": _accuracy}
@@ -101,7 +113,13 @@ class MLP(nn.Module):

     @nn.compact
     def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
-        """Forward pass of the MLP."""
+        """
+        Forward pass of the MLP.
+
+        :param x: Input data.
+        :param training: Whether the model is in training mode (default is True).
+        :return: Output logits of the network.
+        """
         x = nn.Dense(self.hidden_size)(x)
         if self.use_batchnorm:
             x = nn.BatchNorm(use_running_average=not training)(x)
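
A hypothetical instantiation; hidden_size=64 matches train_model later in this file, and the 784-dimensional input assumes flattened 28x28 MNIST images:

model = MLP(hidden_size=64)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)), training=False)
logits = model.apply(variables, jnp.ones((1, 784)), training=False)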
@@ -128,7 +146,15 @@ class Metrics(NamedTuple):
 def create_train_state(
     rng: jnp.ndarray, _model: nn.Module, learning_rate: float, weight_decay: float
 ) -> TrainState:
-    """Create and initialise the train state."""
+    """
+    Create and initialise the train state.
+
+    :param rng: Random number generator key.
+    :param _model: The model to initialise.
+    :param learning_rate: Learning rate for the optimiser.
+    :param weight_decay: Weight decay for the optimiser.
+    :return: The initialised TrainState.
+    """
     dropout_rng, params_rng = jax.random.split(rng)
     params = _model.init(
         {"params": params_rng, "dropout": dropout_rng},
@@ -149,10 +175,23 @@ def create_train_state(
 def train_step(
     state: TrainState, batch_data: jnp.ndarray, batch_labels: jnp.ndarray
 ) -> tuple[TrainState, dict[str, jnp.ndarray]]:
-    """Perform a single training step."""
+    """
+    Perform a single training step.
+
+    :param state: The current state of the model and optimiser.
+    :param batch_data: Batch of input data.
+    :param batch_labels: Batch of ground truth labels.
+    :return: Updated TrainState and a dictionary of metrics (loss and accuracy).
+    """
     dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

     def loss_fn(params):
+        """
+        Compute the cross-entropy loss for the given batch.
+
+        :param params: Model parameters.
+        :return: Tuple containing the loss and a tuple of (logits, updated model state).
+        """
         variables = {"params": params, "batch_stats": state.batch_stats}
         logits, new_model_state = state.apply_fn(
             variables,
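
The (loss, (logits, new_model_state)) return shape is the usual Flax pattern for jax.value_and_grad with has_aux=True. The rest of train_step is not shown in this diff; a typical continuation, sketched as an assumption rather than this file's exact code:

# Sketch only: a common way to finish a Flax train step
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, (logits, new_model_state)), grads = grad_fn(state.params)
new_state = state.apply_gradients(grads=grads)
new_state = new_state.replace(
    batch_stats=new_model_state["batch_stats"], dropout_rng=new_dropout_rng
)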
@@ -179,7 +218,14 @@ def loss_fn(params):
 def eval_step(
     state: TrainState, batch_data: jnp.ndarray, batch_labels: jnp.ndarray
 ) -> dict[str, jnp.ndarray]:
-    """Perform a single evaluation step."""
+    """
+    Perform a single evaluation step.
+
+    :param state: The current state of the model.
+    :param batch_data: Batch of input data.
+    :param batch_labels: Batch of ground truth labels.
+    :return: A dictionary of evaluation metrics (loss and accuracy).
+    """
     variables = {"params": state.params, "batch_stats": state.batch_stats}
     logits = state.apply_fn(
         variables, batch_data, training=False, rngs={"dropout": state.dropout_rng}
@@ -193,7 +239,15 @@ def train_epoch(
     train_labels: jnp.ndarray,
     batch_size: int,
 ) -> tuple[TrainState, dict[str, float]]:
-    """Train for one epoch and return updated state and metrics."""
+    """
+    Train for one epoch and return updated state and metrics.
+
+    :param state: The current state of the model and optimiser.
+    :param train_data: Training input data.
+    :param train_labels: Training labels.
+    :param batch_size: Size of each training batch.
+    :return: Updated TrainState and a dictionary containing 'loss' and 'accuracy'.
+    """
     num_batches = train_data.shape[0] // batch_size
     total_loss, total_accuracy = 0.0, 0.0

@@ -215,7 +269,15 @@ def train_epoch(
 def evaluate(
     state: TrainState, _data: jnp.ndarray, labels: jnp.ndarray, batch_size: int
 ) -> dict[str, float]:
-    """Evaluate the model on given data and return metrics."""
+    """
+    Evaluate the model on given data and return metrics.
+
+    :param state: The current state of the model.
+    :param _data: Input data for evaluation.
+    :param labels: Ground truth labels for evaluation.
+    :param batch_size: Size of each evaluation batch.
+    :return: A dictionary containing 'loss' and 'accuracy' metrics.
+    """
     num_batches = _data.shape[0] // batch_size
     total_loss, total_accuracy = 0.0, 0.0

@@ -245,7 +307,22 @@ def train_and_evaluate(
     rng: jnp.ndarray,
     config: dict[str, Any],
 ) -> dict[str, float]:
-    """Train and evaluate the model with early stopping."""
+    """
+    Train and evaluate the model with early stopping.
+
+    :param train_set: The training dataset containing features and labels.
+    :param test_set: The test dataset containing features and labels.
+    :param _model: The model to be trained.
+    :param rng: Random number generator key for parameter initialisation and dropout.
+    :param config: A dictionary of training configuration parameters, including:
+        - "learning_rate": Learning rate for the optimiser.
+        - "weight_decay": Weight decay for the optimiser.
+        - "batch_size": Number of samples per training batch.
+        - "epochs": Total number of training epochs.
+        - "patience": Early stopping patience.
+        - "min_delta": Minimum change in accuracy to qualify as improvement.
+    :return: A dictionary containing the final test loss and accuracy after training.
+    """
     state = create_train_state(
         rng, _model, config["learning_rate"], config["weight_decay"]
     )
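
A hypothetical config matching the documented keys; the values are illustrative assumptions, not the benchmark's actual settings:

config = {
    "learning_rate": 1e-3,  # optimiser learning rate
    "weight_decay": 1e-4,   # optimiser weight decay
    "batch_size": 128,      # samples per training batch
    "epochs": 50,           # maximum number of training epochs
    "patience": 5,          # early-stopping patience, in epochs
    "min_delta": 1e-3,      # minimum accuracy improvement to reset patience
}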
@@ -303,9 +380,9 @@ def pca(x: jnp.ndarray, n_components: int = 16) -> jnp.ndarray:
     This function computes the principal components of the input data matrix
     and returns the projected data.

-    :param x: The input data matrix of shape (n_samples, n_features)
-    :param n_components: The number of principal components to return
-    :return: The projected data of shape (n_samples, n_components)
+    :param x: The input data matrix of shape (n_samples, n_features).
+    :param n_components: The number of principal components to return.
+    :return: The projected data of shape (n_samples, n_components).
     """
     # Center the data
     x_centred = x - jnp.mean(x, axis=0)
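
The diff truncates the body after the centring step; one standard way to finish an SVD-based PCA, sketched as an assumption about the rest of the implementation:

# Sketch: project onto the top n_components right singular vectors
_, _, vt = jnp.linalg.svd(x_centred, full_matrices=False)
x_projected = x_centred @ vt[:n_components].T  # (n_samples, n_components)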
@@ -330,7 +407,12 @@ def pca(x: jnp.ndarray, n_components: int = 16) -> jnp.ndarray:


 def prepare_datasets() -> tuple:
-    """Prepare and return training and test datasets."""
+    """
+    Prepare and return training and test datasets.
+
+    :return: A tuple containing training and test datasets in JAX arrays:
+        (train_data_jax, train_targets_jax, test_data_jax, test_targets_jax).
+    """
     transform = transforms.Compose(
         [transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))]
     )
@@ -436,7 +518,24 @@ def _get_rp_solver(_size: int) -> tuple[str, MapReduce]:


 def train_model(data_bundle: dict, key, config) -> dict:
-    """Train the model and return the results."""
+    """
+    Train the model and return the results.
+
+    :param data_bundle: A dictionary containing the following keys:
+        - "data": Training input data.
+        - "targets": Training labels.
+        - "test_data": Test input data.
+        - "test_targets": Test labels.
+    :param key: Random number generator key for model initialisation and dropout.
+    :param config: A dictionary of training configuration parameters, including:
+        - "learning_rate": Learning rate for the optimiser.
+        - "weight_decay": Weight decay for the optimiser.
+        - "batch_size": Number of samples per training batch.
+        - "epochs": Total number of training epochs.
+        - "patience": Early stopping patience.
+        - "min_delta": Minimum change in accuracy to qualify as improvement.
+    :return: A dictionary containing the final test loss and accuracy after training.
+    """
     model = MLP(hidden_size=64)

     # Access the values from the data_bundle dictionary
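
A hypothetical bundle matching the documented keys, reusing the array names produced by prepare_datasets() in main():

data_bundle = {
    "data": train_data_jax,
    "targets": train_targets_jax,
    "test_data": test_data_jax,
    "test_targets": test_targets_jax,
}
results = train_model(data_bundle, jax.random.PRNGKey(0), config)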
@@ -457,15 +556,56 @@ def train_model(data_bundle: dict, key, config) -> dict:


 def save_results(results: dict) -> None:
-    """Save results to JSON."""
+    """
+    Save benchmark results to a JSON file for algorithm performance visualisation.
+
+    :param results: A dictionary of results structured as follows:
+        {
+            "algorithm_name": {
+                "coreset_size_1": {
+                    "run_1": accuracy_value,
+                    "run_2": accuracy_value,
+                    ...
+                },
+                "coreset_size_2": {
+                    "run_1": accuracy_value,
+                    "run_2": accuracy_value,
+                    ...
+                },
+                ...
+            },
+            "another_algorithm_name": {
+                "coreset_size_1": {
+                    "run_1": accuracy_value,
+                    "run_2": accuracy_value,
+                    ...
+                },
+                ...
+            },
+            ...
+        }
+        Each algorithm contains coreset sizes as keys, with values being
+        dictionaries of accuracy results from different runs.
+    """
     with open("mnist_benchmark_results.json", "w", encoding="utf-8") as f:
         json.dump(results, f, indent=4)

     print("Data has been saved to 'mnist_benchmark_results.json'")


 def main() -> None:
-    """Perform the benchmark."""
+    """
+    Perform the benchmark for multiple solvers, coreset sizes, and random seeds.
+
+    The function follows these steps:
+    1. Prepare and load the MNIST datasets (training and test).
+    2. Perform dimensionality reduction on the training data using PCA.
+    3. Initialise solvers for data reduction.
+    4. For each solver and coreset size, reduce the dataset to a coreset.
+    5. Train the model on the coreset and evaluate its performance on the test set.
+    6. Save the results, which include test accuracy for each solver and coreset size.
+    """
     (train_data_jax, train_targets_jax, test_data_jax, test_targets_jax) = (
         prepare_datasets()
     )
benchmark/mnist_benchmark_visualiser.py (2 changes: 1 addition & 1 deletion)

@@ -21,7 +21,7 @@


 def main() -> None:
-    """Load benchmark results and visualize the algorithm performance."""
+    """Load benchmark results and visualise the algorithm performance."""
     with open("mnist_benchmark_results.json", "r", encoding="utf-8") as file:
         # Load the JSON data into a Python object
         data_by_solver = json.load(file)
