Feature/benchmarking #802
base: main
Conversation
I haven't got my head around everything yet, but this is good progress and you're clear to continue the job on benchmarking. I may have chucked on quite a few comments, although nothing is major.
CHANGELOG.md
Outdated
Add benchmarking to change log.
I don't think there are any sensible unit tests to write for this module.
benchmark/mnist_benchmark.py
Outdated
"""
data_loader = DataLoader(pytorch_data, batch_size=len(pytorch_data))
_data, _targets = next(iter(data_loader))  # Load all data at once
I'm not familiar with DataLoader. next(iter(...)) would normally only grab the first item of something. Does this fetch all the data?
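For context (illustration only; this snippet is not from the PR): because batch_size is set to the dataset length, the DataLoader yields a single batch containing every sample, so next(iter(...)) does fetch all the data at once. A minimal self-contained sketch:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))
loader = DataLoader(dataset, batch_size=len(dataset))
data, targets = next(iter(loader))  # Single batch holding the whole dataset
print(data.shape, targets.shape)  # torch.Size([10, 1]) torch.Size([10])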
This is a pretty involved file. I would look at adding some unit tests for the medium-sized functions.
Seconding @tp832944 's comment - please add some unit tests
Note: Upon fixing RPCholesky (once #787 and related fixes are done), we expect the benchmarking metrics to change significantly.
@qh681248 Please find some further requested changes below.
This is an intermediate review. I'm having trouble running mnist_benchmark.py; once I've resolved the issue, I will continue reviewing this PR.
benchmark/blobs_benchmark.py
Outdated
def main() -> None:
    """
    Perform a benchmark comparing different coreset algorithms on a synthetic dataset.
Thanks for adding more information to this docstring, in response to a previous review comment.
It is still grammatically clumsy.
"Perform a benchmark..." is redundant, since 'benchmark' is a verb.
"...for each solver at different." isn't a correct ending to a sentence. Is the word 'sizes' missing?
Suggested rephrasing:
Benchmark different algorithms against a synthetic dataset.
Compare the performance of different coreset algorithms using a synthetic dataset,
generated using :func:`sklearn.datasets.make_blobs`. We set up various solvers,
generate coresets of multiple sizes, and compute performance metrics (MMD and KSD)
for each solver at each coreset size. Results are saved to a JSON file.
benchmark/blobs_benchmark.py
Outdated
The benchmarking process follows these steps:
1. Generate a synthetic dataset of 1000 two-dimensional points using
   `sklearn.datasets.make_blobs`.
Add :func: before the sklearn reference, so the documentation links correctly to the external docs.
`sklearn.datasets.make_blobs`. (current)
:func:`sklearn.datasets.make_blobs`. (suggested)
benchmark/blobs_benchmark.py
Outdated
import json
import time
from typing import Any, Tuple
Remove Tuple, and use the built-in tuple in the type annotations below instead (also requested in #802 (comment)).
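For illustration (the function name and body below are hypothetical, not taken from the PR), the requested style uses the built-in tuple in annotations, which needs no typing import on Python 3.9+:

import numpy as np

def make_points(n_samples: int) -> tuple[np.ndarray, np.ndarray]:
    """Return random 2D points and integer labels (illustrative only)."""
    rng = np.random.default_rng(0)
    return rng.normal(size=(n_samples, 2)), rng.integers(0, 3, size=n_samples)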
benchmark/blobs_benchmark.py
Outdated
all_results[size] = results

# Save results to JSON file
with open("coreset_comparison_results.json", "w", encoding="utf-8") as f:
Why is this filename different from the one in blobs_benchmark_visualiser.py?
Thanks for catching that, I have changed it to the correct file name now.
benchmark/blobs_benchmark.py
Outdated
def setup_stein_kernel(
    sq_exp_kernel: SquaredExponentialKernel, dataset: Data
) -> SteinKernel:
    """Set up Stein Kernel."""
Document the parameters and return value for this function:
:param sq_exp_kernel: ...
:param dataset: ...
:return: ...
benchmark/mnist_benchmark.py
Outdated
def save_results(results: dict) -> None:
    """Save results to JSON."""
Document parameters!
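A minimal sketch of what a documented version could look like (the output filename and wording here are assumptions, not taken from the PR):

import json

def save_results(results: dict) -> None:
    """
    Save benchmark results to a JSON file.

    :param results: Dictionary of benchmark results to serialise, keyed by solver name.
    """
    with open("results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)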
import jax.numpy as jnp
import numpy as np
import optax
import torchvision
@qh681248 please add torch and torchvision; this file can't be run without them. Please add them to pyproject.toml.
.cspell/library_terms.txt
Outdated
@@ -109,6 +109,7 @@ rcond
rect
rerunfailures
rightarrow
rngs
This is in custom_misc.txt too; it should only be needed in one of these .cspell files.
def main() -> None:
    """Load benchmark results and visualize the algorithm performance."""
Use British spelling: visualize -> visualise
Seconding @tp832944 's comment - please add some unit tests
@qh681248 Some further questions/requested changes. Thanks in advance for responding
pyproject.toml
Outdated
@@ -34,6 +34,8 @@ dependencies = [
    "scikit-learn",
    "tqdm",
    "typing-extensions",
    "pytorch",
"pytorch" isn't the package name for PyTorch; it's "torch". You can learn this by trying to run uv sync with this erroneous pyproject.toml file.
@@ -123,6 +124,7 @@ tensordot
texttt
toctree
tomli
torchvision
Consider adding torch to this file too.
Is there a particular reason we're committing these results files? I would have assumed they didn't need to be added to the repo, as they can be recreated locally
This one takes a lot of runtime and computing power to generate, so we thought it was worth committing (especially if we run it on a powerful computer to get the results).
Is there a particular reason we're committing these results files? I would have assumed they didn't need to be added to the repo, as they can be recreated locally
I don't think this one needs to be added as it can be recreated locally.
benchmark/blobs_benchmark.py
Outdated
# Save results to JSON file
with open("blobs_benchmark_results.json", "w", encoding="utf-8") as f:
Add more control over where this results file is saved. Suggested change:
base_dir = os.path.dirname(os.path.abspath(__file__))  # Gets the directory of the current script

# Save results to JSON file
with open(os.path.join(base_dir, "blobs_benchmark_results.json"), "w", encoding="utf-8") as f:
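(Resolving the path relative to __file__ means the results file always lands next to the script, regardless of the working directory the benchmark is launched from.)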
def main() -> None:
    """Load benchmark results and visualise the algorithm performance."""
    with open("mnist_benchmark_results.json", "r", encoding="utf-8") as file:
Similarly for blobs, enforce the location of the .json file. Suggested change:
base_dir = os.path.dirname(os.path.abspath(__file__))  # Gets the directory of the current script

# Load results from JSON file
with open(os.path.join(base_dir, "mnist_benchmark_results.json"), "r", encoding="utf-8") as file:
PR Type
Description
Added two files: mnist_benchmark.py and blobs_benchmark.py.
How Has This Been Tested?
Test A: (Write your answer here.)
Test B: (Write your answer here.)
Test C: (Write your answer here.)
Does this PR introduce a breaking change?
(Write your answer here.)
Screenshots
(Write your answer here.)
Checklist before requesting a review