Commit 09b6077

:docs: fixed docstrings

1 parent 1df2237 commit 09b6077

4 files changed: +198 -27 lines

benchmark/blobs_benchmark.py (+17 -10)
@@ -17,7 +17,7 @@

 The benchmarking process follows these steps:
 1. Generate a synthetic dataset of 1000 two-dimensional points using
-   `sklearn.datasets.make_blobs`.
+   :func:`sklearn.datasets.make_blobs`.
 2. Generate coresets of varying sizes: 10, 50, 100, and 200 points using different
    coreset algorithms.
 3. Compute two metrics to evaluate the coresets' quality:
@@ -30,7 +30,7 @@

 import json
 import time
-from typing import Any, Tuple
+from typing import Any

 import jax
 import jax.numpy as jnp
@@ -71,7 +71,13 @@ def setup_kernel(x: np.ndarray) -> SquaredExponentialKernel:
 def setup_stein_kernel(
     sq_exp_kernel: SquaredExponentialKernel, dataset: Data
 ) -> SteinKernel:
-    """Set up Stein Kernel."""
+    """
+    Set up a Stein Kernel for Stein Thinning.
+
+    :param sq_exp_kernel: A SquaredExponential base kernel for the Stein Kernel.
+    :param dataset: Dataset for score matching.
+    :return: A SteinKernel object.
+    """
     sliced_score_matcher = SlicedScoreMatching(
         jax.random.PRNGKey(45),
         jax.random.rademacher,
@@ -90,7 +96,7 @@ def setup_solvers(
     coreset_size: int,
     sq_exp_kernel: SquaredExponentialKernel,
     stein_kernel: SteinKernel,
-) -> list[Tuple[str, Any]]:
+) -> list[tuple[str, Any]]:
     """
     Set up and return a list of solver configurations for reducing a dataset.

@@ -175,7 +181,7 @@ def compute_solver_metrics(


 def compute_metrics(
-    solvers: list[Tuple[str, Any]],
+    solvers: list[tuple[str, Any]],
     dataset: Data,
     mmd_metric: MMD,
     ksd_metric: KSD,
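Aside: the `Tuple` -> `tuple` changes above use the built-in generic syntax from PEP 585, which is only valid as a runtime annotation on Python 3.9+. A minimal sketch (not part of this commit; the function name is made up) of the annotation style the diff moves to:

    from typing import Any

    def make_named_solvers() -> list[tuple[str, Any]]:
        """Return (name, solver) pairs annotated with built-in generics instead of typing.Tuple."""
        return [("example", object())]

    pairs: list[tuple[str, Any]] = make_named_solvers()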
@@ -205,11 +211,12 @@ def compute_metrics(

 def main() -> None:
     """
-    Perform a benchmark comparing different coreset algorithms on a synthetic dataset.
+    Benchmark different coreset algorithms on a synthetic dataset.

-    Generate a synthetic dataset using `sklearn.datasets.make_blobs`,set up various
-    solvers, generate coreset of different sizes, and compute performance metrics
-    (MMD and KSD) for each solver at different. Then, save the results into a JSON file.
+    Compare the performance of different coreset algorithms using a synthetic dataset
+    generated with :func:`sklearn.datasets.make_blobs`. We set up various solvers,
+    generate coresets of multiple sizes, and compute performance metrics (MMD and KSD)
+    for each solver at each coreset size. Results are saved to a JSON file.
     """
     # Generate data
     x, *_ = make_blobs(n_samples=1000, n_features=2, centers=10, random_state=45)
@@ -243,7 +250,7 @@ def main() -> None:
         all_results[size] = results

     # Save results to JSON file
-    with open("coreset_comparison_results.json", "w", encoding="utf-8") as f:
+    with open("blobs_benchmark_results.json", "w", encoding="utf-8") as f:
         json.dump(all_results, f, indent=2)


benchmark/blobs_benchmark_visualiser.py (+24 -0)
@@ -22,6 +22,30 @@ def print_metrics_table(data: dict, sample_size: str) -> None:
     """
     Print a table for the given sample size with methods as rows and metrics as columns.

+    :param data: A dictionary where keys are sample sizes (as strings) and values are
+        dictionaries containing the metrics for each algorithm. Each method's
+        dictionary contains the following keys:
+        - 'unweighted_mmd': Unweighted maximum mean discrepancy (MMD).
+        - 'unweighted_ksd': Unweighted kernel Stein discrepancy (KSD).
+        - 'weighted_mmd': Weighted maximum mean discrepancy (MMD).
+        - 'weighted_ksd': Weighted kernel Stein discrepancy (KSD).
+        - 'time': Time taken to compute the coreset and metrics (in seconds).
+        Example format:
+        {
+            '100': {
+                'KernelHerding': {
+                    'unweighted_mmd': 0.12345678,
+                    'unweighted_ksd': 0.23456789,
+                    'weighted_mmd': 0.34567890,
+                    'weighted_ksd': 0.45678901,
+                    'time': 0.123
+                },
+                'Algorithm B': { ... },
+                ...
+            },
+            '1000': { ... },
+            ...
+        }
     :param sample_size: The sample size for which to print the table.
     """
     # Define header
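A hypothetical usage sketch (not part of this commit): assuming blobs_benchmark.py has already written blobs_benchmark_results.json in the format documented above, and that this file is importable as `blobs_benchmark_visualiser`, the tables could be printed with:

    import json

    from blobs_benchmark_visualiser import print_metrics_table

    with open("blobs_benchmark_results.json", "r", encoding="utf-8") as f:
        data = json.load(f)  # keys: coreset sizes (strings) -> per-method metric dicts

    # Print one table per coreset size present in the results file
    for sample_size in data:
        print_metrics_table(data, sample_size)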

benchmark/mnist_benchmark.py (+156 -16)
@@ -78,14 +78,26 @@ def convert_to_jax_arrays(pytorch_data: Dataset) -> tuple[jnp.ndarray, jnp.ndarr


 def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
-    """Compute cross-entropy loss."""
+    """
+    Compute cross-entropy loss.
+
+    :param logits: Logits predicted by the model.
+    :param labels: Ground truth labels as an array of integers.
+    :return: The cross-entropy loss.
+    """
     return jnp.mean(
         optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10))
     )


 def compute_metrics(logits: jnp.ndarray, labels: jnp.ndarray) -> dict[str, jnp.ndarray]:
-    """Compute loss and accuracy metrics."""
+    """
+    Compute loss and accuracy metrics.
+
+    :param logits: Logits predicted by the model.
+    :param labels: Ground truth labels as an array of integers.
+    :return: A dictionary containing 'loss' and 'accuracy' as keys.
+    """
     loss = cross_entropy_loss(logits, labels)
     _accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
     return {"loss": loss, "accuracy": _accuracy}
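For context, a small self-contained check (not part of this commit) that computes the loss the same way as cross_entropy_loss above; it assumes jax and optax are installed and uses 10 classes as in the benchmark:

    import jax
    import jax.numpy as jnp
    import optax

    logits = jnp.zeros((4, 10))       # uniform predictions for a batch of 4
    labels = jnp.array([0, 1, 2, 3])  # integer class labels

    loss = jnp.mean(
        optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10))
    )
    print(loss)  # approximately log(10) = 2.3026 for uniform logits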
@@ -101,7 +113,13 @@ class MLP(nn.Module):

     @nn.compact
     def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
-        """Forward pass of the MLP."""
+        """
+        Forward pass of the MLP.
+
+        :param x: Input data.
+        :param training: Whether the model is in training mode (default is True).
+        :return: Output logits of the network.
+        """
         x = nn.Dense(self.hidden_size)(x)
         if self.use_batchnorm:
             x = nn.BatchNorm(use_running_average=not training)(x)
@@ -128,7 +146,15 @@ class Metrics(NamedTuple):
 def create_train_state(
     rng: jnp.ndarray, _model: nn.Module, learning_rate: float, weight_decay: float
 ) -> TrainState:
-    """Create and initialise the train state."""
+    """
+    Create and initialise the train state.
+
+    :param rng: Random number generator key.
+    :param _model: The model to initialise.
+    :param learning_rate: Learning rate for the optimiser.
+    :param weight_decay: Weight decay for the optimiser.
+    :return: The initialised TrainState.
+    """
     dropout_rng, params_rng = jax.random.split(rng)
     params = _model.init(
         {"params": params_rng, "dropout": dropout_rng},
@@ -149,10 +175,23 @@ def create_train_state(
 def train_step(
     state: TrainState, batch_data: jnp.ndarray, batch_labels: jnp.ndarray
 ) -> tuple[TrainState, dict[str, jnp.ndarray]]:
-    """Perform a single training step."""
+    """
+    Perform a single training step.
+
+    :param state: The current state of the model and optimiser.
+    :param batch_data: Batch of input data.
+    :param batch_labels: Batch of ground truth labels.
+    :return: Updated TrainState and a dictionary of metrics (loss and accuracy).
+    """
     dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

     def loss_fn(params):
+        """
+        Compute the cross-entropy loss for the given batch.
+
+        :param params: Model parameters.
+        :return: Tuple containing the loss and a tuple of (logits, updated model state).
+        """
         variables = {"params": params, "batch_stats": state.batch_stats}
         logits, new_model_state = state.apply_fn(
             variables,
@@ -179,7 +218,14 @@ def loss_fn(params):
 def eval_step(
     state: TrainState, batch_data: jnp.ndarray, batch_labels: jnp.ndarray
 ) -> dict[str, jnp.ndarray]:
-    """Perform a single evaluation step."""
+    """
+    Perform a single evaluation step.
+
+    :param state: The current state of the model.
+    :param batch_data: Batch of input data.
+    :param batch_labels: Batch of ground truth labels.
+    :return: A dictionary of evaluation metrics (loss and accuracy).
+    """
     variables = {"params": state.params, "batch_stats": state.batch_stats}
     logits = state.apply_fn(
         variables, batch_data, training=False, rngs={"dropout": state.dropout_rng}
@@ -193,7 +239,15 @@ def train_epoch(
     train_labels: jnp.ndarray,
     batch_size: int,
 ) -> tuple[TrainState, dict[str, float]]:
-    """Train for one epoch and return updated state and metrics."""
+    """
+    Train for one epoch and return updated state and metrics.
+
+    :param state: The current state of the model and optimiser.
+    :param train_data: Training input data.
+    :param train_labels: Training labels.
+    :param batch_size: Size of each training batch.
+    :return: Updated TrainState and a dictionary containing 'loss' and 'accuracy'.
+    """
     num_batches = train_data.shape[0] // batch_size
     total_loss, total_accuracy = 0.0, 0.0

@@ -215,7 +269,15 @@ def train_epoch(
 def evaluate(
     state: TrainState, _data: jnp.ndarray, labels: jnp.ndarray, batch_size: int
 ) -> dict[str, float]:
-    """Evaluate the model on given data and return metrics."""
+    """
+    Evaluate the model on given data and return metrics.
+
+    :param state: The current state of the model.
+    :param _data: Input data for evaluation.
+    :param labels: Ground truth labels for evaluation.
+    :param batch_size: Size of each evaluation batch.
+    :return: A dictionary containing 'loss' and 'accuracy' metrics.
+    """
     num_batches = _data.shape[0] // batch_size
     total_loss, total_accuracy = 0.0, 0.0

@@ -245,7 +307,22 @@ def train_and_evaluate(
     rng: jnp.ndarray,
     config: dict[str, Any],
 ) -> dict[str, float]:
-    """Train and evaluate the model with early stopping."""
+    """
+    Train and evaluate the model with early stopping.
+
+    :param train_set: The training dataset containing features and labels.
+    :param test_set: The test dataset containing features and labels.
+    :param _model: The model to be trained.
+    :param rng: Random number generator key for parameter initialisation and dropout.
+    :param config: A dictionary of training configuration parameters, including:
+        - "learning_rate": Learning rate for the optimiser.
+        - "weight_decay": Weight decay for the optimiser.
+        - "batch_size": Number of samples per training batch.
+        - "epochs": Total number of training epochs.
+        - "patience": Early stopping patience.
+        - "min_delta": Minimum change in accuracy to qualify as improvement.
+    :return: A dictionary containing the final test loss and accuracy after training.
+    """
     state = create_train_state(
         rng, _model, config["learning_rate"], config["weight_decay"]
     )
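The concrete values are not visible in this hunk, so the following config is only an illustrative guess at the dictionary shape described by the docstring above (the numbers are placeholders, not the benchmark's actual settings):

    config = {
        "learning_rate": 1e-3,  # optimiser step size
        "weight_decay": 1e-4,   # decoupled weight decay used by the optimiser
        "batch_size": 128,      # samples per training batch
        "epochs": 50,           # maximum number of training epochs
        "patience": 5,          # epochs without improvement before early stopping
        "min_delta": 1e-3,      # minimum accuracy gain that counts as improvement
    }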
@@ -303,9 +380,9 @@ def pca(x: jnp.ndarray, n_components: int = 16) -> jnp.ndarray:
     This function computes the principal components of the input data matrix
     and returns the projected data.

-    :param x: The input data matrix of shape (n_samples, n_features)
-    :param n_components: The number of principal components to return
-    :return: The projected data of shape (n_samples, n_components)
+    :param x: The input data matrix of shape (n_samples, n_features).
+    :param n_components: The number of principal components to return.
+    :return: The projected data of shape (n_samples, n_components).
     """
     # Center the data
     x_centred = x - jnp.mean(x, axis=0)
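The hunk only shows the first line of the implementation, so as an aside, here is a compact sketch (not the file's code) of the same projection with the documented shapes, using an SVD of the centred data:

    import jax.numpy as jnp

    def pca_sketch(x: jnp.ndarray, n_components: int = 16) -> jnp.ndarray:
        """Project x of shape (n_samples, n_features) onto its leading principal axes."""
        x_centred = x - jnp.mean(x, axis=0)
        # Right singular vectors of the centred data are the principal directions
        _, _, vt = jnp.linalg.svd(x_centred, full_matrices=False)
        return x_centred @ vt[:n_components].T  # shape (n_samples, n_components)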
@@ -330,7 +407,12 @@ def pca(x: jnp.ndarray, n_components: int = 16) -> jnp.ndarray:


 def prepare_datasets() -> tuple:
-    """Prepare and return training and test datasets."""
+    """
+    Prepare and return training and test datasets.
+
+    :return: A tuple containing training and test datasets in JAX arrays:
+        (train_data_jax, train_targets_jax, test_data_jax, test_targets_jax).
+    """
     transform = transforms.Compose(
         [transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))]
     )
@@ -436,7 +518,24 @@ def _get_rp_solver(_size: int) -> tuple[str, MapReduce]:


 def train_model(data_bundle: dict, key, config) -> dict:
-    """Train the model and return the results."""
+    """
+    Train the model and return the results.
+
+    :param data_bundle: A dictionary containing the following keys:
+        - "data": Training input data.
+        - "targets": Training labels.
+        - "test_data": Test input data.
+        - "test_targets": Test labels.
+    :param key: Random number generator key for model initialisation and dropout.
+    :param config: A dictionary of training configuration parameters, including:
+        - "learning_rate": Learning rate for the optimiser.
+        - "weight_decay": Weight decay for the optimiser.
+        - "batch_size": Number of samples per training batch.
+        - "epochs": Total number of training epochs.
+        - "patience": Early stopping patience.
+        - "min_delta": Minimum change in accuracy to qualify as improvement.
+    :return: A dictionary containing the final test loss and accuracy after training.
+    """
     model = MLP(hidden_size=64)

     # Access the values from the data_bundle dictionary
@@ -457,15 +556,56 @@ def train_model(data_bundle: dict, key, config) -> dict:


 def save_results(results: dict) -> None:
-    """Save results to JSON."""
+    """
+    Save benchmark results to a JSON file for algorithm performance visualisation.
+
+    :param results: A dictionary of results structured as follows:
+        {
+            "algorithm_name": {
+                "coreset_size_1": {
+                    "run_1": accuracy_value,
+                    "run_2": accuracy_value,
+                    ...
+                },
+                "coreset_size_2": {
+                    "run_1": accuracy_value,
+                    "run_2": accuracy_value,
+                    ...
+                },
+                ...
+            },
+            "another_algorithm_name": {
+                "coreset_size_1": {
+                    "run_1": accuracy_value,
+                    "run_2": accuracy_value,
+                    ...
+                },
+                ...
+            },
+            ...
+        }
+        Each algorithm contains coreset sizes as keys, with values being
+        dictionaries of accuracy results from different runs.
+    """
     with open("mnist_benchmark_results.json", "w", encoding="utf-8") as f:
         json.dump(results, f, indent=4)

     print("Data has been saved to 'benchmark_results.json'")


 def main() -> None:
-    """Perform the benchmark."""
+    """
+    Perform the benchmark for multiple solvers, coreset sizes, and random seeds.
+
+    The function follows these steps:
+    1. Prepare and load the MNIST datasets (training and test).
+    2. Perform dimensionality reduction on the training data using PCA.
+    3. Initialise solvers for data reduction.
+    4. For each solver and coreset size, reduce the dataset and train the model
+       on the reduced set.
+    5. Train the model and evaluate its performance on the test set.
+    6. Save the results, which include test accuracy for each solver and coreset size.
+    """
     (train_data_jax, train_targets_jax, test_data_jax, test_targets_jax) = (
         prepare_datasets()
     )
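A hypothetical example (not part of this commit) of the nested results structure documented for save_results above; the algorithm names, coreset sizes, and accuracy values here are placeholders:

    results = {
        "KernelHerding": {
            "25": {"run_1": 0.91, "run_2": 0.90},
            "100": {"run_1": 0.95, "run_2": 0.94},
        },
        "RandomSample": {
            "25": {"run_1": 0.88, "run_2": 0.89},
            "100": {"run_1": 0.93, "run_2": 0.92},
        },
    }

    save_results(results)  # writes mnist_benchmark_results.json with indent=4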

benchmark/mnist_benchmark_visualiser.py (+1 -1)
@@ -21,7 +21,7 @@


 def main() -> None:
-    """Load benchmark results and visualize the algorithm performance."""
+    """Load benchmark results and visualise the algorithm performance."""
     with open("mnist_benchmark_results.json", "r", encoding="utf-8") as file:
         # Load the JSON data into a Python object
         data_by_solver = json.load(file)
