
Commit 89746a7

Iterate through batch generator in benchmarks (#140)
* Fix TorchLoader benchmarks
* Iterate through batch generator in benchmarks
1 parent f93af88 commit 89746a7
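
The change boils down to the pattern sketched here: if a ``BatchGenerator`` yields batches lazily as it is iterated, then constructing it alone times very little, so the benchmarks now loop over the generator so that producing every batch falls inside the timed section. The dataset below is illustrative and not taken from the benchmark file:

import numpy as np
import xarray as xr

from xbatcher import BatchGenerator

# Illustrative 3-D dataset; shape and dimension names are made up for this sketch.
ds = xr.Dataset(
    {"foo": (["time", "y", "x"], np.random.rand(10, 50, 100))},
    {"x": np.arange(100), "y": np.arange(50)},
)

# Construction alone may defer most of the work ...
bg = BatchGenerator(ds, input_dims={"x": 10, "y": 10})

# ... so the benchmarks iterate, forcing every batch to be generated.
for batch in bg:
    pass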

1 file changed: +25 -13 lines changed

asv_bench/benchmarks/benchmarks.py

Lines changed: 25 additions & 13 deletions

@@ -24,16 +24,16 @@ def setup(self, *args, **kwargs):
         shape_4d = (10, 50, 100, 3)
         self.ds_4d = xr.Dataset(
             {
-                "foo": (["time", "y", "x", "b"], np.random.rand(*shape_4d)),
+                "foo": (["time", "y", "x", "z"], np.random.rand(*shape_4d)),
             },
             {
                 "x": (["x"], np.arange(shape_4d[-2])),
                 "y": (["y"], np.arange(shape_4d[-3])),
-                "b": (["b"], np.arange(shape_4d[-1])),
+                "z": (["z"], np.arange(shape_4d[-1])),
             },
         )
 
-        self.ds_xy = xr.Dataset(
+        self.ds_2d = xr.Dataset(
             {
                 "x": (
                     ["sample", "feature"],
@@ -51,8 +51,12 @@ def time_batch_preload(self, preload_batch):
         Construct a generator on a chunked DataSet with and without preloading
         batches.
         """
-        ds_dask = self.ds_xy.chunk({"sample": 2})
-        BatchGenerator(ds_dask, input_dims={"sample": 2}, preload_batch=preload_batch)
+        ds_dask = self.ds_2d.chunk({"sample": 2})
+        bg = BatchGenerator(
+            ds_dask, input_dims={"sample": 2}, preload_batch=preload_batch
+        )
+        for batch in bg:
+            pass
 
     @parameterized(
         ["input_dims", "batch_dims", "input_overlap"],
@@ -66,12 +70,14 @@ def time_batch_input(self, input_dims, batch_dims, input_overlap):
         """
         Benchmark simple batch generation case.
         """
-        BatchGenerator(
+        bg = BatchGenerator(
             self.ds_3d,
             input_dims=input_dims,
             batch_dims=batch_dims,
             input_overlap=input_overlap,
         )
+        for batch in bg:
+            pass
 
     @parameterized(
         ["input_dims", "concat_input_dims"],
@@ -82,11 +88,13 @@ def time_batch_concat(self, input_dims, concat_input_dims):
         Construct a generator on a DataSet with and without concatenating
         chunks specified by ``input_dims`` into the batch dimension.
         """
-        BatchGenerator(
+        bg = BatchGenerator(
             self.ds_3d,
             input_dims=input_dims,
             concat_input_dims=concat_input_dims,
         )
+        for batch in bg:
+            pass
 
     @parameterized(
         ["input_dims", "batch_dims", "concat_input_dims"],
@@ -101,12 +109,14 @@ def time_batch_concat_4d(self, input_dims, batch_dims, concat_input_dims):
         Construct a generator on a DataSet with and without concatenating
         chunks specified by ``input_dims`` into the batch dimension.
         """
-        BatchGenerator(
+        bg = BatchGenerator(
             self.ds_4d,
             input_dims=input_dims,
             batch_dims=batch_dims,
             concat_input_dims=concat_input_dims,
        )
+        for batch in bg:
+            pass
 
 
 class Accessor(Base):
@@ -119,27 +129,29 @@ def time_accessor_input_dim(self, input_dims):
         Benchmark simple batch generation case using xarray accessor
         Equivalent to subset of ``time_batch_input()``.
         """
-        self.ds_3d.batch.generator(input_dims=input_dims)
+        bg = self.ds_3d.batch.generator(input_dims=input_dims)
+        for batch in bg:
+            pass
 
 
 class TorchLoader(Base):
     def setup(self, *args, **kwargs):
         super().setup(**kwargs)
-        self.x_gen = BatchGenerator(self.ds_xy["x"], {"sample": 10})
-        self.y_gen = BatchGenerator(self.ds_xy["y"], {"sample": 10})
+        self.x_gen = BatchGenerator(self.ds_2d["x"], {"sample": 10})
+        self.y_gen = BatchGenerator(self.ds_2d["y"], {"sample": 10})
 
     def time_map_dataset(self):
         """
         Benchmark MapDataset integration with torch DataLoader.
         """
         dataset = MapDataset(self.x_gen, self.y_gen)
         loader = torch.utils.data.DataLoader(dataset)
-        iter(loader).next()
+        next(iter(loader))
 
     def time_iterable_dataset(self):
         """
         Benchmark IterableDataset integration with torch DataLoader.
         """
         dataset = IterableDataset(self.x_gen, self.y_gen)
         loader = torch.utils.data.DataLoader(dataset)
-        iter(loader).next()
+        next(iter(loader))
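
The TorchLoader fix at the end of the diff addresses a Python 3 issue: iterators implement ``__next__`` rather than a ``.next()`` method, so ``iter(loader).next()`` raises ``AttributeError``, and the built-in ``next()`` is the portable way to pull one batch. A minimal sketch of the corrected pattern; the DataArray is made up, and the import path for ``MapDataset`` is assumed from xbatcher's torch loaders since the benchmark's imports are outside this diff:

import numpy as np
import torch
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import MapDataset

# Made-up 2-D (sample, feature) array standing in for the benchmark's ds_2d["x"].
da = xr.DataArray(np.random.rand(100, 5), dims=["sample", "feature"])

x_gen = BatchGenerator(da, {"sample": 10})
y_gen = BatchGenerator(da, {"sample": 10})

dataset = MapDataset(x_gen, y_gen)
loader = torch.utils.data.DataLoader(dataset)

# Pull a single batch; iter(loader).next() would fail on Python 3.
next(iter(loader))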
