diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index 612be61..a257567 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -54,14 +54,22 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
     return out
 
 
-def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
+def _maybe_stack_batch_dims(
+    ds, input_dims, squeeze_batch_dim, stacked_dim_name='sample'
+):
     batch_dims = [d for d in ds.dims if d not in input_dims]
-    if len(batch_dims) < 2:
+    if len(batch_dims) == 0:
+        if squeeze_batch_dim:
+            return ds
+        else:
+            return ds.expand_dims(stacked_dim_name, 0)
+    elif len(batch_dims) == 1:
         return ds
-    ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
-    # ensure correct order
-    dim_order = (stacked_dim_name,) + tuple(input_dims)
-    return ds_stack.transpose(*dim_order)
+    else:
+        ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
+        # ensure correct order
+        dim_order = (stacked_dim_name,) + tuple(input_dims)
+        return ds_stack.transpose(*dim_order)
 
 
 class BatchGenerator:
@@ -90,6 +98,10 @@ class BatchGenerator:
     preload_batch : bool, optional
         If ``True``, each batch will be loaded into memory before reshaping /
         processing, triggering any dask arrays to be computed.
+    squeeze_batch_dim : bool, optional
+        If ``False`` and all dims are input dims, each batch's dataset will have a
+        ``sample`` dimension of size 1 prepended to its arrays. This is useful for
+        interoperability with Keras / TensorFlow.
 
     Yields
     ------
@@ -105,6 +117,7 @@ def __init__(
         batch_dims={},
         concat_input_dims=False,
         preload_batch=True,
+        squeeze_batch_dim=True,
     ):
 
         self.ds = _as_xarray_dataset(ds)
@@ -114,6 +127,7 @@
         self.batch_dims = OrderedDict(batch_dims)
         self.concat_input_dims = concat_input_dims
         self.preload_batch = preload_batch
+        self.squeeze_batch_dim = squeeze_batch_dim
 
     def __iter__(self):
         for ds_batch in self._iterate_batch_dims(self.ds):
@@ -132,11 +146,13 @@ def __iter__(self):
                 new_input_dims = [
                     dim + new_dim_suffix for dim in self.input_dims
                 ]
-                yield _maybe_stack_batch_dims(dsc, new_input_dims)
+                yield _maybe_stack_batch_dims(
+                    dsc, new_input_dims, self.squeeze_batch_dim
+                )
             else:
                 for ds_input in input_generator:
                     yield _maybe_stack_batch_dims(
-                        ds_input, list(self.input_dims)
+                        ds_input, list(self.input_dims), self.squeeze_batch_dim
                     )
 
     def _iterate_batch_dims(self, ds):
diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py
index 38acae9..1b99083 100644
--- a/xbatcher/tests/test_generators.py
+++ b/xbatcher/tests/test_generators.py
@@ -169,6 +169,63 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
     )
 
 
+@pytest.mark.parametrize('bsize', [10, 20])
+def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize):
+    bg = BatchGenerator(
+        sample_ds_1d,
+        input_dims={'x': bsize},
+        squeeze_batch_dim=False,
+    )
+    for ds_batch in bg:
+        assert list(ds_batch['foo'].shape) == [1, bsize]
+
+    bg2 = BatchGenerator(
+        sample_ds_1d,
+        input_dims={'x': bsize},
+        squeeze_batch_dim=True,
+    )
+    for ds_batch in bg2:
+        assert list(ds_batch['foo'].shape) == [bsize]
+
+
+@pytest.mark.parametrize('bsize', [5, 10])
+def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize):
+    bg = BatchGenerator(
+        sample_ds_3d,
+        input_dims={'y': bsize, 'x': bsize},
+        squeeze_batch_dim=False,
+    )
+    for ds_batch in bg:
+        assert list(ds_batch['foo'].shape) == [10, bsize, bsize]
+
+    bg2 = BatchGenerator(
+        sample_ds_3d,
+        input_dims={'y': bsize, 'x': bsize},
+        squeeze_batch_dim=True,
+    )
+    for ds_batch in bg2:
+        assert list(ds_batch['foo'].shape) == [10, bsize, bsize]
+
+
+@pytest.mark.parametrize('bsize', [5, 10])
+def test_batch_3d_squeeze_batch_dim2(sample_ds_3d, bsize):
+    bg = BatchGenerator(
+        sample_ds_3d,
+        input_dims={'x': bsize},
+        squeeze_batch_dim=False,
+    )
+    for ds_batch in bg:
+        assert list(ds_batch['foo'].shape) == [500, bsize]
+
+    bg2 = BatchGenerator(
+        sample_ds_3d,
+        input_dims={'x': bsize},
+        squeeze_batch_dim=True,
+    )
+    for ds_batch in bg2:
+        assert list(ds_batch['foo'].shape) == [500, bsize]
+
+
 def test_preload_batch_false(sample_ds_1d):
     sample_ds_1d_dask = sample_ds_1d.chunk({'x': 2})
     bg = BatchGenerator(
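
For context, a minimal usage sketch of the new keyword follows. It is an illustration, not part of the patch: it assumes this branch is installed, that BatchGenerator remains importable from the top-level xbatcher namespace, and it builds a throwaway 1-D dataset rather than reusing the test fixtures.

    import numpy as np
    import xarray as xr
    from xbatcher import BatchGenerator

    # Illustrative 1-D dataset (not one of the test fixtures).
    ds = xr.Dataset({'foo': (['x'], np.random.rand(100))})

    # All dims are input dims, so no batch dims remain after slicing.
    # Default behaviour (squeeze_batch_dim=True): batches keep shape (10,).
    for batch in BatchGenerator(ds, input_dims={'x': 10}):
        assert batch['foo'].shape == (10,)

    # squeeze_batch_dim=False prepends a size-1 'sample' dimension, matching
    # the leading batch axis that Keras / TensorFlow models expect.
    for batch in BatchGenerator(ds, input_dims={'x': 10}, squeeze_batch_dim=False):
        assert batch['foo'].shape == (1, 10)

Defaulting squeeze_batch_dim to True preserves the previous behaviour, where a batch with no remaining batch dimensions was returned unchanged.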