@@ -24,16 +24,16 @@ def setup(self, *args, **kwargs):
24
24
shape_4d = (10 , 50 , 100 , 3 )
25
25
self .ds_4d = xr .Dataset (
26
26
{
27
- "foo" : (["time" , "y" , "x" , "b " ], np .random .rand (* shape_4d )),
27
+ "foo" : (["time" , "y" , "x" , "z " ], np .random .rand (* shape_4d )),
28
28
},
29
29
{
30
30
"x" : (["x" ], np .arange (shape_4d [- 2 ])),
31
31
"y" : (["y" ], np .arange (shape_4d [- 3 ])),
32
- "b " : (["b " ], np .arange (shape_4d [- 1 ])),
32
+ "z " : (["z " ], np .arange (shape_4d [- 1 ])),
33
33
},
34
34
)
35
35
36
- self .ds_xy = xr .Dataset (
36
+ self .ds_2d = xr .Dataset (
37
37
{
38
38
"x" : (
39
39
["sample" , "feature" ],
@@ -51,8 +51,12 @@ def time_batch_preload(self, preload_batch):
51
51
Construct a generator on a chunked DataSet with and without preloading
52
52
batches.
53
53
"""
54
- ds_dask = self .ds_xy .chunk ({"sample" : 2 })
55
- BatchGenerator (ds_dask , input_dims = {"sample" : 2 }, preload_batch = preload_batch )
54
+ ds_dask = self .ds_2d .chunk ({"sample" : 2 })
55
+ bg = BatchGenerator (
56
+ ds_dask , input_dims = {"sample" : 2 }, preload_batch = preload_batch
57
+ )
58
+ for batch in bg :
59
+ pass
56
60
57
61
@parameterized (
58
62
["input_dims" , "batch_dims" , "input_overlap" ],
@@ -66,12 +70,14 @@ def time_batch_input(self, input_dims, batch_dims, input_overlap):
66
70
"""
67
71
Benchmark simple batch generation case.
68
72
"""
69
- BatchGenerator (
73
+ bg = BatchGenerator (
70
74
self .ds_3d ,
71
75
input_dims = input_dims ,
72
76
batch_dims = batch_dims ,
73
77
input_overlap = input_overlap ,
74
78
)
79
+ for batch in bg :
80
+ pass
75
81
76
82
@parameterized (
77
83
["input_dims" , "concat_input_dims" ],
@@ -82,11 +88,13 @@ def time_batch_concat(self, input_dims, concat_input_dims):
82
88
Construct a generator on a DataSet with and without concatenating
83
89
chunks specified by ``input_dims`` into the batch dimension.
84
90
"""
85
- BatchGenerator (
91
+ bg = BatchGenerator (
86
92
self .ds_3d ,
87
93
input_dims = input_dims ,
88
94
concat_input_dims = concat_input_dims ,
89
95
)
96
+ for batch in bg :
97
+ pass
90
98
91
99
@parameterized (
92
100
["input_dims" , "batch_dims" , "concat_input_dims" ],
@@ -101,12 +109,14 @@ def time_batch_concat_4d(self, input_dims, batch_dims, concat_input_dims):
101
109
Construct a generator on a DataSet with and without concatenating
102
110
chunks specified by ``input_dims`` into the batch dimension.
103
111
"""
104
- BatchGenerator (
112
+ bg = BatchGenerator (
105
113
self .ds_4d ,
106
114
input_dims = input_dims ,
107
115
batch_dims = batch_dims ,
108
116
concat_input_dims = concat_input_dims ,
109
117
)
118
+ for batch in bg :
119
+ pass
110
120
111
121
112
122
class Accessor (Base ):
@@ -119,27 +129,29 @@ def time_accessor_input_dim(self, input_dims):
119
129
Benchmark simple batch generation case using xarray accessor
120
130
Equivalent to subset of ``time_batch_input()``.
121
131
"""
122
- self .ds_3d .batch .generator (input_dims = input_dims )
132
+ bg = self .ds_3d .batch .generator (input_dims = input_dims )
133
+ for batch in bg :
134
+ pass
123
135
124
136
125
137
class TorchLoader (Base ):
126
138
def setup (self , * args , ** kwargs ):
127
139
super ().setup (** kwargs )
128
- self .x_gen = BatchGenerator (self .ds_xy ["x" ], {"sample" : 10 })
129
- self .y_gen = BatchGenerator (self .ds_xy ["y" ], {"sample" : 10 })
140
+ self .x_gen = BatchGenerator (self .ds_2d ["x" ], {"sample" : 10 })
141
+ self .y_gen = BatchGenerator (self .ds_2d ["y" ], {"sample" : 10 })
130
142
131
143
def time_map_dataset (self ):
132
144
"""
133
145
Benchmark MapDataset integration with torch DataLoader.
134
146
"""
135
147
dataset = MapDataset (self .x_gen , self .y_gen )
136
148
loader = torch .utils .data .DataLoader (dataset )
137
- iter (loader ). next ( )
149
+ next ( iter (loader ))
138
150
139
151
def time_iterable_dataset (self ):
140
152
"""
141
153
Benchmark IterableDataset integration with torch DataLoader.
142
154
"""
143
155
dataset = IterableDataset (self .x_gen , self .y_gen )
144
156
loader = torch .utils .data .DataLoader (dataset )
145
- iter (loader ). next ( )
157
+ next ( iter (loader ))
0 commit comments