@@ -78,14 +78,26 @@ def convert_to_jax_arrays(pytorch_data: Dataset) -> tuple[jnp.ndarray, jnp.ndarr
 
 
 def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
-    """Compute cross-entropy loss."""
+    """
+    Compute cross-entropy loss.
+
+    :param logits: Logits predicted by the model.
+    :param labels: Ground truth labels as an array of integers.
+    :return: The cross-entropy loss.
+    """
     return jnp.mean(
         optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10))
     )
 
 
 def compute_metrics(logits: jnp.ndarray, labels: jnp.ndarray) -> dict[str, jnp.ndarray]:
-    """Compute loss and accuracy metrics."""
+    """
+    Compute loss and accuracy metrics.
+
+    :param logits: Logits predicted by the model.
+    :param labels: Ground truth labels as an array of integers.
+    :return: A dictionary containing 'loss' and 'accuracy' as keys.
+    """
     loss = cross_entropy_loss(logits, labels)
     _accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
     return {"loss": loss, "accuracy": _accuracy}
@@ -101,7 +113,13 @@ class MLP(nn.Module):
 
     @nn.compact
     def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
-        """Forward pass of the MLP."""
+        """
+        Forward pass of the MLP.
+
+        :param x: Input data.
+        :param training: Whether the model is in training mode (default is True).
+        :return: Output logits of the network.
+        """
         x = nn.Dense(self.hidden_size)(x)
         if self.use_batchnorm:
             x = nn.BatchNorm(use_running_average=not training)(x)
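The full MLP definition sits outside this hunk, so the following is only a rough, hypothetical stand-in showing how a flax.linen module with the documented __call__(x, training) signature is initialised and applied; the real MLP also has use_batchnorm and dropout behaviour that this sketch does not reproduce.

# Hypothetical stand-in module; not the MLP defined in this script.
import jax
import jax.numpy as jnp
from flax import linen as nn


class TinyMLP(nn.Module):
    hidden_size: int = 64

    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
        x = nn.Dense(self.hidden_size)(x)
        x = nn.relu(x)
        return nn.Dense(10)(x)  # 10 MNIST classes, matching the loss above


rng = jax.random.PRNGKey(0)
dummy = jnp.ones((1, 784))                       # flattened 28x28 MNIST image
variables = TinyMLP().init(rng, dummy, training=False)
logits = TinyMLP().apply(variables, dummy, training=False)
print(logits.shape)                              # (1, 10)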
@@ -128,7 +146,15 @@ class Metrics(NamedTuple):
 def create_train_state(
     rng: jnp.ndarray, _model: nn.Module, learning_rate: float, weight_decay: float
 ) -> TrainState:
-    """Create and initialise the train state."""
+    """
+    Create and initialise the train state.
+
+    :param rng: Random number generator key.
+    :param _model: The model to initialise.
+    :param learning_rate: Learning rate for the optimiser.
+    :param weight_decay: Weight decay for the optimiser.
+    :return: The initialised TrainState.
+    """
     dropout_rng, params_rng = jax.random.split(rng)
     params = _model.init(
         {"params": params_rng, "dropout": dropout_rng},
@@ -149,10 +175,23 @@ def create_train_state(
 def train_step(
     state: TrainState, batch_data: jnp.ndarray, batch_labels: jnp.ndarray
 ) -> tuple[TrainState, dict[str, jnp.ndarray]]:
-    """Perform a single training step."""
+    """
+    Perform a single training step.
+
+    :param state: The current state of the model and optimiser.
+    :param batch_data: Batch of input data.
+    :param batch_labels: Batch of ground truth labels.
+    :return: Updated TrainState and a dictionary of metrics (loss and accuracy).
+    """
     dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
     def loss_fn(params):
+        """
+        Compute the cross-entropy loss for the given batch.
+
+        :param params: Model parameters.
+        :return: Tuple containing the loss and a tuple of (logits, updated model state).
+        """
         variables = {"params": params, "batch_stats": state.batch_stats}
         logits, new_model_state = state.apply_fn(
             variables,
@@ -179,7 +218,14 @@ def loss_fn(params):
 def eval_step(
     state: TrainState, batch_data: jnp.ndarray, batch_labels: jnp.ndarray
 ) -> dict[str, jnp.ndarray]:
-    """Perform a single evaluation step."""
+    """
+    Perform a single evaluation step.
+
+    :param state: The current state of the model.
+    :param batch_data: Batch of input data.
+    :param batch_labels: Batch of ground truth labels.
+    :return: A dictionary of evaluation metrics (loss and accuracy).
+    """
     variables = {"params": state.params, "batch_stats": state.batch_stats}
     logits = state.apply_fn(
         variables, batch_data, training=False, rngs={"dropout": state.dropout_rng}
@@ -193,7 +239,15 @@ def train_epoch(
     train_labels: jnp.ndarray,
     batch_size: int,
 ) -> tuple[TrainState, dict[str, float]]:
-    """Train for one epoch and return updated state and metrics."""
+    """
+    Train for one epoch and return updated state and metrics.
+
+    :param state: The current state of the model and optimiser.
+    :param train_data: Training input data.
+    :param train_labels: Training labels.
+    :param batch_size: Size of each training batch.
+    :return: Updated TrainState and a dictionary containing 'loss' and 'accuracy'.
+    """
     num_batches = train_data.shape[0] // batch_size
     total_loss, total_accuracy = 0.0, 0.0
 
@@ -215,7 +269,15 @@ def train_epoch(
 def evaluate(
     state: TrainState, _data: jnp.ndarray, labels: jnp.ndarray, batch_size: int
 ) -> dict[str, float]:
-    """Evaluate the model on given data and return metrics."""
+    """
+    Evaluate the model on given data and return metrics.
+
+    :param state: The current state of the model.
+    :param _data: Input data for evaluation.
+    :param labels: Ground truth labels for evaluation.
+    :param batch_size: Size of each evaluation batch.
+    :return: A dictionary containing 'loss' and 'accuracy' metrics.
+    """
     num_batches = _data.shape[0] // batch_size
     total_loss, total_accuracy = 0.0, 0.0
 
@@ -245,7 +307,22 @@ def train_and_evaluate(
     rng: jnp.ndarray,
     config: dict[str, Any],
 ) -> dict[str, float]:
-    """Train and evaluate the model with early stopping."""
+    """
+    Train and evaluate the model with early stopping.
+
+    :param train_set: The training dataset containing features and labels.
+    :param test_set: The test dataset containing features and labels.
+    :param _model: The model to be trained.
+    :param rng: Random number generator key for parameter initialisation and dropout.
+    :param config: A dictionary of training configuration parameters, including:
+        - "learning_rate": Learning rate for the optimiser.
+        - "weight_decay": Weight decay for the optimiser.
+        - "batch_size": Number of samples per training batch.
+        - "epochs": Total number of training epochs.
+        - "patience": Early stopping patience.
+        - "min_delta": Minimum change in accuracy to qualify as improvement.
+    :return: A dictionary containing the final test loss and accuracy after training.
+    """
     state = create_train_state(
         rng, _model, config["learning_rate"], config["weight_decay"]
     )
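For reference, the config keys listed in the docstring above read more easily as an actual dictionary; the values below are placeholders, not the settings used in this benchmark.

# Illustrative config for train_and_evaluate; key names follow the docstring
# above, but the numeric values are placeholders only.
config = {
    "learning_rate": 1e-3,   # optimiser step size
    "weight_decay": 1e-4,    # regularisation strength used by the optimiser
    "batch_size": 128,       # samples per training batch
    "epochs": 50,            # upper bound on training epochs
    "patience": 5,           # epochs without improvement before early stopping
    "min_delta": 1e-3,       # minimum accuracy gain counted as an improvement
}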
@@ -303,9 +380,9 @@ def pca(x: jnp.ndarray, n_components: int = 16) -> jnp.ndarray:
     This function computes the principal components of the input data matrix
     and returns the projected data.
 
-    :param x: The input data matrix of shape (n_samples, n_features)
-    :param n_components: The number of principal components to return
-    :return: The projected data of shape (n_samples, n_components)
+    :param x: The input data matrix of shape (n_samples, n_features).
+    :param n_components: The number of principal components to return.
+    :return: The projected data of shape (n_samples, n_components).
     """
     # Center the data
     x_centred = x - jnp.mean(x, axis=0)
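For readers unfamiliar with the projection described in the pca docstring, here is a small self-contained sketch of the same idea (centre the data, take the leading right singular vectors, project); the real pca() in this file may differ in implementation details.

# Minimal PCA sketch consistent with the documented shapes; pca_sketch is a
# hypothetical helper, not the pca() defined in this script.
import jax
import jax.numpy as jnp


def pca_sketch(x: jnp.ndarray, n_components: int = 16) -> jnp.ndarray:
    x_centred = x - jnp.mean(x, axis=0)
    # Right singular vectors give the principal directions.
    _, _, vt = jnp.linalg.svd(x_centred, full_matrices=False)
    return x_centred @ vt[:n_components].T


x = jax.random.normal(jax.random.PRNGKey(0), (100, 32))  # (n_samples, n_features)
print(pca_sketch(x, n_components=8).shape)                # (100, 8)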
@@ -330,7 +407,12 @@ def pca(x: jnp.ndarray, n_components: int = 16) -> jnp.ndarray:
 
 
 def prepare_datasets() -> tuple:
-    """Prepare and return training and test datasets."""
+    """
+    Prepare and return training and test datasets.
+
+    :return: A tuple containing training and test datasets in JAX arrays:
+        (train_data_jax, train_targets_jax, test_data_jax, test_targets_jax).
+    """
     transform = transforms.Compose(
         [transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))]
     )
@@ -436,7 +518,24 @@ def _get_rp_solver(_size: int) -> tuple[str, MapReduce]:
 
 
 def train_model(data_bundle: dict, key, config) -> dict:
-    """Train the model and return the results."""
+    """
+    Train the model and return the results.
+
+    :param data_bundle: A dictionary containing the following keys:
+        - "data": Training input data.
+        - "targets": Training labels.
+        - "test_data": Test input data.
+        - "test_targets": Test labels.
+    :param key: Random number generator key for model initialisation and dropout.
+    :param config: A dictionary of training configuration parameters, including:
+        - "learning_rate": Learning rate for the optimiser.
+        - "weight_decay": Weight decay for the optimiser.
+        - "batch_size": Number of samples per training batch.
+        - "epochs": Total number of training epochs.
+        - "patience": Early stopping patience.
+        - "min_delta": Minimum change in accuracy to qualify as improvement.
+    :return: A dictionary containing the final test loss and accuracy after training.
+    """
     model = MLP(hidden_size=64)
 
     # Access the values from the data_bundle dictionary
@@ -457,15 +556,56 @@ def train_model(data_bundle: dict, key, config) -> dict:
 
 
 def save_results(results: dict) -> None:
-    """Save results to JSON."""
+    """
+    Save benchmark results to a JSON file for algorithm performance visualisation.
+
+    :param results: A dictionary of results structured as follows:
+        {
+            "algorithm_name": {
+                "coreset_size_1": {
+                    "run_1": accuracy_value,
+                    "run_2": accuracy_value,
+                    ...
+                },
+                "coreset_size_2": {
+                    "run_1": accuracy_value,
+                    "run_2": accuracy_value,
+                    ...
+                },
+                ...
+            },
+            "another_algorithm_name": {
+                "coreset_size_1": {
+                    "run_1": accuracy_value,
+                    "run_2": accuracy_value,
+                    ...
+                },
+                ...
+            },
+            ...
+        }
+        Each algorithm contains coreset sizes as keys, with values being
+        dictionaries of accuracy results from different runs.
+    """
     with open("mnist_benchmark_results.json", "w", encoding="utf-8") as f:
         json.dump(results, f, indent=4)
 
     print("Data has been saved to 'benchmark_results.json'")
 
 
 def main() -> None:
-    """Perform the benchmark."""
+    """
+    Perform the benchmark for multiple solvers, coreset sizes, and random seeds.
+
+    The function follows these steps:
+    1. Prepare and load the MNIST datasets (training and test).
+    2. Perform dimensionality reduction on the training data using PCA.
+    3. Initialise solvers for data reduction.
+    4. For each solver and coreset size, reduce the dataset and train the model
+       on the reduced set.
+    5. Train the model and evaluate its performance on the test set.
+    6. Save the results, which include test accuracy for each solver and coreset size.
+    """
     (train_data_jax, train_targets_jax, test_data_jax, test_targets_jax) = (
         prepare_datasets()
     )
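To make the nested layout documented for save_results concrete, here is a made-up example of a results dictionary with that shape; the algorithm names, coreset sizes and accuracies are placeholders only.

# Placeholder results dictionary matching the structure documented above;
# not actual benchmark output.
results = {
    "algorithm_a": {
        "25": {"run_1": 0.71, "run_2": 0.69},
        "100": {"run_1": 0.84, "run_2": 0.86},
    },
    "algorithm_b": {
        "25": {"run_1": 0.75, "run_2": 0.77},
        "100": {"run_1": 0.88, "run_2": 0.87},
    },
}
# save_results(results) would then write mnist_benchmark_results.json.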