Feature/benchmarking #802

Open

wants to merge 33 commits into main

Commits (33):
1dffbc6 Initial commit for benchmarking (qh681248, Oct 9, 2024)
54f0daa Replaced `small_dataset` with `train_data_jax` (qh681248, Oct 10, 2024)
055206b Added copyright and documentation (qh681248, Oct 10, 2024)
78cd2e0 added docstring to `__init__.py` (qh681248, Oct 10, 2024)
9f799bf Added documentation on benchmark/__init__.py (qh681248, Oct 10, 2024)
5269d1c Added `main` function on mnist_benchmark.py (qh681248, Oct 10, 2024)
b84ff03 Added main() function on mnist_benchmark.py (qh681248, Oct 11, 2024)
650fcdc Added linestyle to custom_misc.txt (qh681248, Oct 14, 2024)
41f4f09 Wrapped mnist_benchmark_visualiser.py on the main function (qh681248, Oct 14, 2024)
5d0a8e2 feat(benchmark): add main function to all new benchmark scripts (qh681248, Oct 17, 2024)
4180f52 feat(benchmark): add main function to all new benchmark scripts (qh681248, Oct 17, 2024)
9e7891a Merge remote-tracking branch 'remotes/origin/main' into feature/bench… (qh681248, Oct 17, 2024)
e6baad2 feat: add blobs_benchmark_visualiser.py for benchmark result visualis… (qh681248, Oct 17, 2024)
52e7869 feat: Add benchmark/blobs_benchmark_visualiser.py (qh681248, Oct 17, 2024)
db3d3ee feat: renamed json files (qh681248, Oct 17, 2024)
b912461 feat: Add docstrings on blobs_benchmark.py and mnist_benchmark.py exp… (qh681248, Oct 17, 2024)
e94aed0 Merge remote-tracking branch 'remotes/origin/main' into feature/bench… (qh681248, Oct 18, 2024)
1df2237 :build: add pytorch, torchvision to dependencies (qh681248, Oct 23, 2024)
09b6077 :docs: fixed docstrings (qh681248, Oct 24, 2024)
4cef166 chore: Updated uv.lock after adding additional dependencies. (qh681248, Oct 24, 2024)
a86b09f feat: Added unit/test_benchmark.py (qh681248, Oct 25, 2024)
4860981 fix: Fix merge conflict due to uv.lock (qh681248, Oct 25, 2024)
4196727 fix: Fix merge conflict due to uv.lock (qh681248, Oct 25, 2024)
4b36684 Merge remote-tracking branch 'remotes/origin/main' into feature/bench… (qh681248, Oct 25, 2024)
f679a57 ci: remove macos-13 test from continuous integration test (qh681248, Oct 25, 2024)
00f069e refactor: use base directory for json path (qh681248, Oct 25, 2024)
6102918 refactor: use base directory for json path (qh681248, Oct 25, 2024)
606c2b9 refactor: use base directory for json path (qh681248, Oct 25, 2024)
a2f45d3 refactor: use base directory for json path (qh681248, Oct 25, 2024)
28fe378 chore: Added `torch` to `library_terms.txt` (qh681248, Oct 30, 2024)
880711e Merge remote-tracking branch 'remotes/origin/main' into feature/bench… (qh681248, Oct 30, 2024)
0da26bd chore: merge main to branch to resolve merge conflict (qh681248, Oct 30, 2024)
058e816 feat: increase block size to 128 to run on more powerful machine (qh681248, Oct 30, 2024)
7 changes: 7 additions & 0 deletions .cspell/custom_misc.txt
@@ -15,6 +15,7 @@ diag
docstrings
eigendecomposition
elementwise
fontsize
forall
GCHQ
Gramian
@@ -25,10 +26,13 @@ kdtree
kernelised
kernelized
KSD
linestyle
linewidth
mapsto
Matern
maxs
ml.p3.8xlarge
MNIST
ndmin
parsable
PCIMQ
@@ -42,10 +46,13 @@ recomb
refs
regulariser
RKHS
rngs
RPCHOLESKY
sigmas
subseteq
supp
TLDR
typecheck
WMMD
xticks
yerr
3 changes: 3 additions & 0 deletions .cspell/library_terms.txt
@@ -105,6 +105,7 @@ pyqt
pyright
pyroma
pytest
pytorch
pytree
pytrees
quickstart
@@ -127,6 +128,8 @@ tensordot
texttt
toctree
tomli
torch
torchvision
Contributor comment: consider adding torch to this file

tqdm
triu
ttest
6 changes: 0 additions & 6 deletions .github/workflows/unittests.yml
@@ -22,12 +22,6 @@ jobs:
- ubuntu-latest
- windows-latest
- macos-latest
# macos-latest does not support python 3.9
# https://github.com/actions/setup-python/issues/696#issuecomment-1637587760
exclude:
- { python-version: "3.9", os: "macos-latest" }
include:
- { python-version: "3.9", os: "macos-13" }
runs-on: ${{ matrix.os }}
env:
# Set the Python version that `uv` will use for its virtual environment.
15 changes: 15 additions & 0 deletions benchmark/__init__.py
@@ -0,0 +1,15 @@
# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Different benchmarking tests to compare performance of different algorithms."""
262 changes: 262 additions & 0 deletions benchmark/blobs_benchmark.py
Contributor comment: I don't think there are any sensible unit tests to write for this module.

@@ -0,0 +1,262 @@
# © Crown Copyright GCHQ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Benchmark performance of different coreset algorithms on a synthetic dataset.

The benchmarking process follows these steps:
1. Generate a synthetic dataset of 1000 two-dimensional points using
:func:`sklearn.datasets.make_blobs`.
2. Generate coresets of varying sizes: 10, 50, 100, and 200 points using different
coreset algorithms.
3. Compute two metrics to evaluate the coresets' quality:
- Maximum Mean Discrepancy (MMD)
- Kernel Stein Discrepancy (KSD)
4. Optimize weights for the coresets to minimize the MMD score and recompute both
the MMD and KSD metrics.
5. Measure and report the time taken for each step of the benchmarking process.
"""

import json
import os
import time
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
from sklearn.datasets import make_blobs

from coreax import Data, SlicedScoreMatching
from coreax.kernels import (
SquaredExponentialKernel,
SteinKernel,
median_heuristic,
)
from coreax.metrics import KSD, MMD
from coreax.solvers import (
KernelHerding,
RandomSample,
RPCholesky,
SteinThinning,
)
from coreax.weights import MMDWeightsOptimiser


def setup_kernel(x: np.ndarray) -> SquaredExponentialKernel:
"""
Set up a squared exponential kernel using the median heuristic.

:param x: Input data array used to compute the kernel length scale.
:return: A SquaredExponentialKernel with the computed length scale.
"""
num_samples_length_scale = min(300, len(x))
random_seed = 45
generator = np.random.default_rng(random_seed)
idx = generator.choice(len(x), num_samples_length_scale, replace=False)
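# The median heuristic sets the kernel length scale to the median pairwise
# distance over the sampled points, a standard bandwidth rule of thumb.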
length_scale = median_heuristic(x[idx])
return SquaredExponentialKernel(length_scale=length_scale)


def setup_stein_kernel(
sq_exp_kernel: SquaredExponentialKernel, dataset: Data
) -> SteinKernel:
"""
Set up a Stein Kernel for Stein Thinning.

:param sq_exp_kernel: A SquaredExponential base kernel for the Stein Kernel.
:param dataset: Dataset for score matching.
:return: A SteinKernel object.
"""
sliced_score_matcher = SlicedScoreMatching(
jax.random.PRNGKey(45),
jax.random.rademacher,
use_analytic=True,
num_random_vectors=100,
learning_rate=0.001,
num_epochs=50,
)
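# The sliced score matcher learns an approximation to the score function
# (the gradient of the log-density), which the Stein kernel requires.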
return SteinKernel(
sq_exp_kernel,
sliced_score_matcher.match(jnp.asarray(dataset.data)),
)


def setup_solvers(
coreset_size: int,
sq_exp_kernel: SquaredExponentialKernel,
stein_kernel: SteinKernel,
) -> list[tuple[str, Any]]:
"""
Set up and return a list of solver configurations for reducing a dataset.

:param coreset_size: The size of the coresets to be generated by the solvers.
:param sq_exp_kernel: A Squared Exponential kernel for KernelHerding and RPCholesky.
:param stein_kernel: A Stein kernel object used for the SteinThinning solver.

:return: A list of tuples, where each tuple contains the name of the solver
and the corresponding solver object.
"""
random_key = jax.random.PRNGKey(42)
return [
(
"KernelHerding",
KernelHerding(coreset_size=coreset_size, kernel=sq_exp_kernel),
),
(
"RandomSample",
RandomSample(coreset_size=coreset_size, random_key=random_key),
),
(
"RPCholesky",
RPCholesky(
coreset_size=coreset_size,
kernel=sq_exp_kernel,
random_key=random_key,
),
),
(
"SteinThinning",
SteinThinning(
coreset_size=coreset_size,
kernel=stein_kernel,
regularise=False,
),
),
]


def compute_solver_metrics(
solver: Any,
dataset: Data,
mmd_metric: MMD,
ksd_metric: KSD,
weights_optimiser: MMDWeightsOptimiser,
) -> dict[str, float]:
"""
Compute weighted and unweighted MMD and KSD metrics for a given solver.

:param solver: Solver object used to reduce the dataset.
:param dataset: The full dataset from which the coresubset is generated.
:param mmd_metric: MMD metric object to compute MMD.
:param ksd_metric: KSD metric object to compute KSD.
:param weights_optimiser: Optimiser used to compute weights for the coresubset.

:return: A dictionary with unweighted and weighted metrics (MMD, KSD) and
the time taken for the computation.
"""
start_time = time.perf_counter() # Using perf_counter for higher precision timing
coresubset, _ = solver.reduce(dataset)

# Unweighted metrics
unweighted_mmd = float(mmd_metric.compute(dataset, coresubset.coreset))
unweighted_ksd = float(ksd_metric.compute(dataset, coresubset.coreset))

# Weighted metrics
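# solve_weights assigns non-uniform weights to the coreset points, chosen by
# the MMD weights optimiser to better match the full data distribution.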
weighted_coresubset = coresubset.solve_weights(weights_optimiser)
weighted_mmd = float(weighted_coresubset.compute_metric(mmd_metric))
weighted_ksd = float(weighted_coresubset.compute_metric(ksd_metric))

end_time = time.perf_counter()
elapsed_time = end_time - start_time

return {
"unweighted_mmd": unweighted_mmd,
"unweighted_ksd": unweighted_ksd,
"weighted_mmd": weighted_mmd,
"weighted_ksd": weighted_ksd,
"time": elapsed_time,
}


def compute_metrics(
solvers: list[tuple[str, Any]],
dataset: Data,
mmd_metric: MMD,
ksd_metric: KSD,
weights_optimiser: MMDWeightsOptimiser,
) -> dict[str, dict[str, float]]:
"""
Compute the coresubsets and corresponding metrics for each solver in a given list.

:param solvers: A list of tuples containing solver names and their
respective solver objects.
:param dataset: The full dataset from which coresubsets are generated.
:param mmd_metric: The MMD metric object for computing Maximum Mean Discrepancy.
:param ksd_metric: The KSD metric object for computing Kernel Stein Discrepancy.
:param weights_optimiser: The optimiser used to compute weights for each coresubset.

:return: A dictionary where the keys are the solver names, and the values are
dictionaries of computed metrics (unweighted/weighted MMD and KSD, and
computation time).
"""
return {
name: compute_solver_metrics(
solver, dataset, mmd_metric, ksd_metric, weights_optimiser
)
for name, solver in solvers
}


def main() -> None:
"""
Benchmark different coreset algorithms on a synthetic dataset.

Compare the performance of different coreset algorithms using a synthetic dataset,
generated using :func:`sklearn.datasets.make_blobs`. We set up various solvers,
generate coresets of multiple sizes, and compute performance metrics (MMD and KSD)
for each solver at each coreset size. Results are saved to a JSON file.
"""
# Generate data
x, *_ = make_blobs(n_samples=1000, n_features=2, centers=10, random_state=45)
dataset = Data(jnp.array(x))

# Set up kernel
sq_exp_kernel = setup_kernel(x)

# Set up Stein Kernel
stein_kernel = setup_stein_kernel(sq_exp_kernel, dataset)

# Set up metrics
mmd_metric = MMD(kernel=sq_exp_kernel)
ksd_metric = KSD(kernel=sq_exp_kernel)
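# Both metrics measure how well a coreset approximates the full dataset;
# lower values indicate a better approximation.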

# Set up weights optimizer
weights_optimiser = MMDWeightsOptimiser(kernel=sq_exp_kernel)

# Define coreset sizes
coreset_sizes = [10, 50, 100, 200]

all_results = {}

for size in coreset_sizes:
solvers = setup_solvers(size, sq_exp_kernel, stein_kernel)

# Compute metrics
results = compute_metrics(
solvers, dataset, mmd_metric, ksd_metric, weights_optimiser
)
all_results[size] = results

# Save results to JSON file
base_dir = os.path.dirname(os.path.abspath(__file__))
with open(
os.path.join(base_dir, "blobs_benchmark_results.json"), "w", encoding="utf-8"
) as f:
json.dump(all_results, f, indent=2)


if __name__ == "__main__":
main()
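As a usage note (not part of the diff): once `main()` has produced `blobs_benchmark_results.json` next to the script, the results could be inspected with a minimal sketch like the one below. The `print_results` helper is hypothetical; only the JSON field names come from the code above.

import json
import os


def print_results() -> None:
    """Print one line per (coreset size, solver) pair from the saved results."""
    base_dir = os.path.dirname(os.path.abspath(__file__))
    results_path = os.path.join(base_dir, "blobs_benchmark_results.json")
    with open(results_path, encoding="utf-8") as f:
        results = json.load(f)
    # JSON object keys are strings, so the integer coreset sizes round-trip
    # as strings such as "10".
    for size, solver_results in results.items():
        for solver_name, metrics in solver_results.items():
            print(
                f"size={size:>4} {solver_name:<13} "
                f"unweighted_mmd={metrics['unweighted_mmd']:.4f} "
                f"weighted_mmd={metrics['weighted_mmd']:.4f} "
                f"time={metrics['time']:.2f}s"
            )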