Skip to content

Commit af7941f

Browse files
tjvandal (Thomas Vandal) and jhamman
authored
Lazily generate batches (#112)
* caching only selectors in batch generation
* Remove comment (Co-authored-by: Joe Hamman <[email protected]>)
* Remove comment (Co-authored-by: Joe Hamman <[email protected]>)
* change loop (Co-authored-by: Joe Hamman <[email protected]>)

Co-authored-by: Thomas Vandal <[email protected]>
Co-authored-by: Joe Hamman <[email protected]>
1 parent 4d8e2c8 commit af7941f

File tree

1 file changed

+28
-21
lines changed

1 file changed

+28
-21
lines changed

xbatcher/generators.py

+28-21
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def _iterate_through_dataset(ds, dims, overlap={}):
4040

4141
for slices in itertools.product(*dim_slices):
4242
selector = {key: slice for key, slice in zip(dims, slices)}
43-
yield ds.isel(**selector)
43+
yield selector
4444

4545

4646
def _drop_input_dims(ds, input_dims, suffix="_input"):
@@ -120,13 +120,11 @@ def __init__(
120120
self.batch_dims = OrderedDict(batch_dims)
121121
self.concat_input_dims = concat_input_dims
122122
self.preload_batch = preload_batch
123-
124123
self._batches: Dict[int, Any] = self._gen_batches() # dict cache for batches
125-
# in the future, we can make this a lru cache or similar thing (cachey?)
126124

127125
def __iter__(self) -> Iterator[xr.Dataset]:
128-
for batch in self._batches.values():
129-
yield batch
126+
for idx in self._batches:
127+
yield self[idx]
130128

131129
def __len__(self) -> int:
132130
return len(self._batches)
@@ -142,7 +140,25 @@ def __getitem__(self, idx: int) -> xr.Dataset:
142140
idx = list(self._batches)[idx]
143141

144142
if idx in self._batches:
145-
return self._batches[idx]
143+
144+
if self.concat_input_dims:
145+
new_dim_suffix = "_input"
146+
all_dsets = [
147+
_drop_input_dims(
148+
self.ds.isel(**ds_input_select),
149+
list(self.input_dims),
150+
suffix=new_dim_suffix,
151+
)
152+
for ds_input_select in self._batches[idx]
153+
]
154+
dsc = xr.concat(all_dsets, dim="input_batch")
155+
new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
156+
return _maybe_stack_batch_dims(dsc, new_input_dims)
157+
else:
158+
159+
return _maybe_stack_batch_dims(
160+
self.ds.isel(**self._batches[idx]), list(self.input_dims)
161+
)
146162
else:
147163
raise IndexError("list index out of range")
148164

@@ -151,26 +167,17 @@ def _gen_batches(self) -> dict:
151167
# going the eager route for now is allowing me to fill out the loader api
152168
# but it is likely to perform poorly.
153169
batches = []
154-
for ds_batch in self._iterate_batch_dims(self.ds):
170+
for ds_batch_selector in self._iterate_batch_dims(self.ds):
171+
ds_batch = self.ds.isel(**ds_batch_selector)
155172
if self.preload_batch:
156173
ds_batch.load()
174+
157175
input_generator = self._iterate_input_dims(ds_batch)
176+
158177
if self.concat_input_dims:
159-
new_dim_suffix = "_input"
160-
all_dsets = [
161-
_drop_input_dims(
162-
ds_input, list(self.input_dims), suffix=new_dim_suffix
163-
)
164-
for ds_input in input_generator
165-
]
166-
dsc = xr.concat(all_dsets, dim="input_batch")
167-
new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
168-
batches.append(_maybe_stack_batch_dims(dsc, new_input_dims))
178+
batches.append(list(input_generator))
169179
else:
170-
for ds_input in input_generator:
171-
batches.append(
172-
_maybe_stack_batch_dims(ds_input, list(self.input_dims))
173-
)
180+
batches += list(input_generator)
174181

175182
return dict(zip(range(len(batches)), batches))
176183

0 commit comments

Comments (0)