@@ -40,7 +40,7 @@ def _iterate_through_dataset(ds, dims, overlap={}):
 
     for slices in itertools.product(*dim_slices):
         selector = {key: slice for key, slice in zip(dims, slices)}
-        yield ds.isel(**selector)
+        yield selector
 
 
 def _drop_input_dims(ds, input_dims, suffix="_input"):
@@ -120,13 +120,11 @@ def __init__(
         self.batch_dims = OrderedDict(batch_dims)
         self.concat_input_dims = concat_input_dims
         self.preload_batch = preload_batch
-
         self._batches: Dict[int, Any] = self._gen_batches()  # dict cache for batches
-        # in the future, we can make this a lru cache or similar thing (cachey?)
 
     def __iter__(self) -> Iterator[xr.Dataset]:
-        for batch in self._batches.values():
-            yield batch
+        for idx in self._batches:
+            yield self[idx]
 
     def __len__(self) -> int:
         return len(self._batches)
@@ -142,7 +140,25 @@ def __getitem__(self, idx: int) -> xr.Dataset:
         idx = list(self._batches)[idx]
 
         if idx in self._batches:
-            return self._batches[idx]
+
+            if self.concat_input_dims:
+                new_dim_suffix = "_input"
+                all_dsets = [
+                    _drop_input_dims(
+                        self.ds.isel(**ds_input_select),
+                        list(self.input_dims),
+                        suffix=new_dim_suffix,
+                    )
+                    for ds_input_select in self._batches[idx]
+                ]
+                dsc = xr.concat(all_dsets, dim="input_batch")
+                new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
+                return _maybe_stack_batch_dims(dsc, new_input_dims)
+            else:
+
+                return _maybe_stack_batch_dims(
+                    self.ds.isel(**self._batches[idx]), list(self.input_dims)
+                )
         else:
             raise IndexError("list index out of range")
 
@@ -151,26 +167,17 @@ def _gen_batches(self) -> dict:
         # going the eager route for now is allowing me to fill out the loader api
         # but it is likely to perform poorly.
         batches = []
-        for ds_batch in self._iterate_batch_dims(self.ds):
+        for ds_batch_selector in self._iterate_batch_dims(self.ds):
+            ds_batch = self.ds.isel(**ds_batch_selector)
             if self.preload_batch:
                 ds_batch.load()
+
             input_generator = self._iterate_input_dims(ds_batch)
+
             if self.concat_input_dims:
-                new_dim_suffix = "_input"
-                all_dsets = [
-                    _drop_input_dims(
-                        ds_input, list(self.input_dims), suffix=new_dim_suffix
-                    )
-                    for ds_input in input_generator
-                ]
-                dsc = xr.concat(all_dsets, dim="input_batch")
-                new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
-                batches.append(_maybe_stack_batch_dims(dsc, new_input_dims))
+                batches.append(list(input_generator))
             else:
-                for ds_input in input_generator:
-                    batches.append(
-                        _maybe_stack_batch_dims(ds_input, list(self.input_dims))
-                    )
+                batches += list(input_generator)
 
 
         return dict(zip(range(len(batches)), batches))
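
For context, a minimal standalone sketch of the pattern this commit adopts: instead of materializing every batch eagerly, only lightweight {dim: slice} selectors are cached, and ds.isel(**selector) is deferred until a batch is actually requested. The helper name make_selectors and the toy dataset below are illustrative assumptions, not part of the library.

```python
# Illustrative sketch only: mirrors the selector-caching idea from the commit
# above; make_selectors and the toy dataset are hypothetical, not library code.
import itertools

import numpy as np
import xarray as xr


def make_selectors(ds, dims):
    # Build {dim: slice} selectors covering ds in contiguous, non-overlapping blocks.
    dim_slices = [
        [slice(start, start + length) for start in range(0, ds.sizes[dim], length)]
        for dim, length in dims.items()
    ]
    for slices in itertools.product(*dim_slices):
        yield dict(zip(dims, slices))


ds = xr.Dataset({"temp": (("time",), np.arange(10.0))})
selectors = list(make_selectors(ds, {"time": 4}))  # cheap to store up front
batch = ds.isel(**selectors[0])                    # data is only sliced on access
assert batch.sizes["time"] == 4
```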