Commit 5312998

Author: Max Jones

Setup benchmarks (#64)
* Setup benchmarks using asv
* Add asv to dev-requirements.txt
* Add info about benchmarking to a contributing guide
* Add benchmark for concatenating input dims
* Add 4-d case to benchmarks
1 parent: 3d16162

9 files changed: +409 -0 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,9 @@ nosetests.xml
 coverage.xml
 *,cover

+# asv environments
+.asv
+
 # Translations
 *.mo
 *.pot

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ repos:
       - id: end-of-file-fixer
       - id: check-docstring-first
       - id: check-json
+        exclude: "asv_bench/asv.conf.json"
       - id: check-yaml
       - id: double-quote-string-fixer

(The `exclude` is needed because `asv.conf.json` uses JavaScript-style comments, which are not strict JSON, so the `check-json` hook would otherwise reject it.)

CONTRIBUTING.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Xbatcher's contributor guidelines [can be found in the online documentation](https://xbatcher.readthedocs.io/en/latest/contributing.html).

asv_bench/asv.conf.json

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
{
    // The version of the config file format. Do not change, unless
    // you know what you are doing.
    "version": 1,

    // The name of the project being benchmarked
    "project": "xbatcher",

    // The project's homepage
    "project_url": "https://xbatcher.readthedocs.io/",

    // The URL or local path of the source code repository for the
    // project being benchmarked
    "repo": "..",

    // The Python project's subdirectory in your repo. If missing or
    // the empty string, the project is assumed to be located at the root
    // of the repository.
    // "repo_subdir": "",

    // Customizable commands for building, installing, and
    // uninstalling the project. See asv.conf.json documentation.
    //
    // "install_command": ["in-dir={env_dir} python -mpip install {wheel_file}"],
    // "uninstall_command": ["return-code=any python -mpip uninstall -y {project}"],
    // "build_command": [
    //     "python setup.py build",
    //     "PIP_NO_BUILD_ISOLATION=false python -mpip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}"
    // ],

    // List of branches to benchmark. If not provided, defaults to "master"
    // (for git) or "default" (for mercurial).
    "branches": ["main"], // for git

    // The DVCS being used. If not set, it will be automatically
    // determined from "repo" by looking at the protocol in the URL
    // (if remote), or by looking for special directories, such as
    // ".git" (if local).
    "dvcs": "git",

    // The tool to use to create environments. May be "conda",
    // "virtualenv" or other value depending on the plugins in use.
    // If missing or the empty string, the tool will be automatically
    // determined by looking for tools on the PATH environment
    // variable.
    "environment_type": "conda",

    // timeout in seconds for installing any dependencies in environment
    // defaults to 10 min
    "install_timeout": 600,

    // the base URL to show a commit for the project.
    // "show_commit_url": "http://github.com/pangeo-data/xbatcher/commit/",

    // The Pythons you'd like to test against. If not provided, defaults
    // to the current version of Python used to run `asv`.
    // "pythons": ["3.8"],

    // The list of conda channel names to be searched for benchmark
    // dependency packages in the specified order
    "conda_channels": ["conda-forge"],

    // A conda environment file that is used for environment creation.
    // "conda_environment_file": "environment.yml",

    // The matrix of dependencies to test. Each key of the "req"
    // requirements dictionary is the name of a package (in PyPI) and
    // the values are version numbers. An empty list or empty string
    // indicates to just test against the default (latest)
    // version. null indicates that the package is to not be
    // installed. If the package to be tested is only available from
    // PyPi, and the 'environment_type' is conda, then you can preface
    // the package name by 'pip+', and the package will be installed
    // via pip (with all the conda available packages installed first,
    // followed by the pip installed packages).
    //
    // The ``@env`` and ``@env_nobuild`` keys contain the matrix of
    // environment variables to pass to build and benchmark commands.
    // An environment will be created for every combination of the
    // cartesian product of the "@env" variables in this matrix.
    // Variables in "@env_nobuild" will be passed to every environment
    // during the benchmark phase, but will not trigger creation of
    // new environments. A value of ``null`` means that the variable
    // will not be set for the current combination.
    //
    // "matrix": {
    //     "req": {
    //         "numpy": ["1.6", "1.7"],
    //         "six": ["", null], // test with and without six installed
    //         "pip+emcee": [""] // emcee is only available for install with pip.
    //     },
    //     "env": {"ENV_VAR_1": ["val1", "val2"]},
    //     "env_nobuild": {"ENV_VAR_2": ["val3", null]},
    // },
    // "matrix": {
    //     "xarray": [""],
    //     "numpy": [""],
    //     "dask": [""],
    // },

    // Combinations of libraries/python versions can be excluded/included
    // from the set to test. Each entry is a dictionary containing additional
    // key-value pairs to include/exclude.
    //
    // An exclude entry excludes entries where all values match. The
    // values are regexps that should match the whole string.
    //
    // An include entry adds an environment. Only the packages listed
    // are installed. The 'python' key is required. The exclude rules
    // do not apply to includes.
    //
    // In addition to package names, the following keys are available:
    //
    // - python
    //     Python version, as in the *pythons* variable above.
    // - environment_type
    //     Environment type, as above.
    // - sys_platform
    //     Platform, as in sys.platform. Possible values for the common
    //     cases: 'linux2', 'win32', 'cygwin', 'darwin'.
    // - req
    //     Required packages
    // - env
    //     Environment variables
    // - env_nobuild
    //     Non-build environment variables
    //
    // "exclude": [
    //     {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows
    //     {"environment_type": "conda", "req": {"six": null}}, // don't run without six on conda
    //     {"env": {"ENV_VAR_1": "val2"}}, // skip val2 for ENV_VAR_1
    // ],
    //
    // "include": [
    //     // additional env for python2.7
    //     {"python": "2.7", "req": {"numpy": "1.8"}, "env_nobuild": {"FOO": "123"}},
    //     // additional env if run on windows+conda
    //     {"platform": "win32", "environment_type": "conda", "python": "2.7", "req": {"libpython": ""}},
    // ],

    // The directory (relative to the current directory) that benchmarks are
    // stored in. If not provided, defaults to "benchmarks"
    "benchmark_dir": "benchmarks",

    // The directory (relative to the current directory) to cache the Python
    // environments in. If not provided, defaults to "env"
    "env_dir": ".asv/env",

    // The directory (relative to the current directory) that raw benchmark
    // results are stored in. If not provided, defaults to "results".
    "results_dir": ".asv/results",

    // The directory (relative to the current directory) that the html tree
    // should be written to. If not provided, defaults to "html".
    "html_dir": ".asv/html"

    // The number of characters to retain in the commit hashes.
    // "hash_length": 8,

    // `asv` will cache results of the recent builds in each
    // environment, making them faster to install next time. This is
    // the number of builds to keep, per environment.
    // "build_cache_size": 2,

    // The commits after which the regression search in `asv publish`
    // should start looking for regressions. Dictionary whose keys are
    // regexps matching to benchmark names, and values corresponding to
    // the commit (exclusive) after which to start looking for
    // regressions. The default is to start from the first commit
    // with results. If the commit is `null`, regression detection is
    // skipped for the matching benchmark.
    //
    // "regressions_first_commits": {
    //     "some_benchmark": "352cdf", // Consider regressions only after this commit
    //     "another_benchmark": null, // Skip regression detection altogether
    // },

    // The thresholds for relative change in results, after which `asv
    // publish` starts reporting regressions. Dictionary of the same
    // form as in ``regressions_first_commits``, with values
    // indicating the thresholds. If multiple entries match, the
    // maximum is taken. If no entry matches, the default is 5%.
    //
    // "regressions_thresholds": {
    //     "some_benchmark": 0.01, // Threshold of 1%
    //     "another_benchmark": 0.5, // Threshold of 50%
    // },
}
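
For orientation, a minimal sketch of how asv picks up benchmarks from the `benchmark_dir` configured above. The module `asv_bench/benchmarks/example.py` and all names in it are hypothetical, not part of this commit; asv collects methods whose names start with `time_`, runs `setup` before each measurement, and excludes the setup cost from the reported result.

# asv_bench/benchmarks/example.py -- hypothetical file for illustration only
import numpy as np


class Suite:
    def setup(self):
        # Runs before each measurement; excluded from the reported time.
        self.data = np.random.rand(1_000_000)

    def time_sum(self):
        # asv reports the wall-clock time of this call.
        self.data.sum()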

asv_bench/benchmarks/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
def parameterized(names, params):
    """
    Copied from xarray benchmarks:
    https://github.com/pydata/xarray/blob/main/asv_bench/benchmarks/__init__.py#L9-L15
    """

    def decorator(func):
        func.param_names = names
        func.params = params
        return func

    return decorator
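
A usage sketch (the benchmark below is hypothetical, not part of this commit): the decorator only attaches the `params` and `param_names` attributes that asv inspects, so a decorated benchmark runs once per parameter value, with that value passed in as an argument.

class Example:
    # asv runs time_fill twice, once with n=10 and once with n=100,
    # and reports a separate timing for each parameter value.
    @parameterized(['n'], [10, 100])
    def time_fill(self, n):
        [0] * n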

asv_bench/benchmarks/benchmarks.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
import numpy as np
import torch
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import IterableDataset, MapDataset

from . import parameterized


class Base:
    def setup(self, *args, **kwargs):
        shape = (10, 50, 100)
        self.ds_3d = xr.Dataset(
            {
                'foo': (['time', 'y', 'x'], np.random.rand(*shape)),
            },
            {
                'x': (['x'], np.arange(shape[-1])),
                'y': (['y'], np.arange(shape[-2])),
            },
        )

        shape_4d = (10, 50, 100, 3)
        self.ds_4d = xr.Dataset(
            {
                'foo': (['time', 'y', 'x', 'b'], np.random.rand(*shape_4d)),
            },
            {
                'x': (['x'], np.arange(shape_4d[-2])),
                'y': (['y'], np.arange(shape_4d[-3])),
                'b': (['b'], np.arange(shape_4d[-1])),
            },
        )

        self.ds_xy = xr.Dataset(
            {
                'x': (
                    ['sample', 'feature'],
                    np.random.random((shape[-1], shape[0])),
                ),
                'y': (['sample'], np.random.random(shape[-1])),
            },
        )


class Generator(Base):
    @parameterized(['preload_batch'], ([True, False]))
    def time_batch_preload(self, preload_batch):
        """
        Construct a generator on a chunked Dataset with and without preloading
        batches.
        """
        ds_dask = self.ds_xy.chunk({'sample': 2})
        BatchGenerator(
            ds_dask, input_dims={'sample': 2}, preload_batch=preload_batch
        )

    @parameterized(
        ['input_dims', 'batch_dims', 'input_overlap'],
        (
            [{'x': 5}, {'x': 10}, {'x': 5, 'y': 5}, {'x': 10, 'y': 5}],
            [{}, {'x': 20}, {'x': 30}],
            [{}, {'x': 1}, {'x': 2}],
        ),
    )
    def time_batch_input(self, input_dims, batch_dims, input_overlap):
        """
        Benchmark simple batch generation case.
        """
        BatchGenerator(
            self.ds_3d,
            input_dims=input_dims,
            batch_dims=batch_dims,
            input_overlap=input_overlap,
        )

    @parameterized(
        ['input_dims', 'concat_input_dims'],
        ([{'x': 5}, {'x': 10}, {'x': 5, 'y': 5}], [True, False]),
    )
    def time_batch_concat(self, input_dims, concat_input_dims):
        """
        Construct a generator on a Dataset with and without concatenating
        chunks specified by ``input_dims`` into the batch dimension.
        """
        BatchGenerator(
            self.ds_3d,
            input_dims=input_dims,
            concat_input_dims=concat_input_dims,
        )

    @parameterized(
        ['input_dims', 'batch_dims', 'concat_input_dims'],
        (
            [{'x': 5}, {'x': 5, 'y': 5}],
            [{}, {'x': 10}, {'x': 10, 'y': 10}],
            [True, False],
        ),
    )
    def time_batch_concat_4d(self, input_dims, batch_dims, concat_input_dims):
        """
        Construct a generator on a 4-d Dataset with and without concatenating
        chunks specified by ``input_dims`` into the batch dimension.
        """
        BatchGenerator(
            self.ds_4d,
            input_dims=input_dims,
            batch_dims=batch_dims,
            concat_input_dims=concat_input_dims,
        )


class Accessor(Base):
    @parameterized(
        ['input_dims'],
        ([{'x': 2}, {'x': 4}, {'x': 2, 'y': 2}, {'x': 4, 'y': 2}]),
    )
    def time_accessor_input_dim(self, input_dims):
        """
        Benchmark simple batch generation using the xarray accessor;
        equivalent to a subset of ``time_batch_input()``.
        """
        self.ds_3d.batch.generator(input_dims=input_dims)


class TorchLoader(Base):
    def setup(self, *args, **kwargs):
        super().setup(**kwargs)
        self.x_gen = BatchGenerator(self.ds_xy['x'], {'sample': 10})
        self.y_gen = BatchGenerator(self.ds_xy['y'], {'sample': 10})

    def time_map_dataset(self):
        """
        Benchmark MapDataset integration with the torch DataLoader.
        """
        dataset = MapDataset(self.x_gen, self.y_gen)
        loader = torch.utils.data.DataLoader(dataset)
        next(iter(loader))  # pull one batch through the loader

    def time_iterable_dataset(self):
        """
        Benchmark IterableDataset integration with the torch DataLoader.
        """
        dataset = IterableDataset(self.x_gen, self.y_gen)
        loader = torch.utils.data.DataLoader(dataset)
        next(iter(loader))  # pull one batch through the loader
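
To make the benchmarked pattern concrete, here is a minimal sketch (not part of this commit) of what these suites exercise, assuming only the `BatchGenerator` API and `.batch` accessor used above:

import numpy as np
import xarray as xr

from xbatcher import BatchGenerator

ds = xr.Dataset({'foo': (['time', 'y', 'x'], np.random.rand(10, 50, 100))})

# Each batch is a Dataset cut down to the requested window sizes;
# input_overlap makes consecutive windows share two points along x.
gen = BatchGenerator(ds, input_dims={'x': 10, 'y': 5}, input_overlap={'x': 2})
batch = next(iter(gen))

# The same generator is also available through the xarray accessor:
gen = ds.batch.generator(input_dims={'x': 10})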

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -3,4 +3,5 @@ torch
 coverage
 pytest-cov
 adlfs
+asv
 -r requirements.txt
