diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index b5edff0..fd077d3 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -69,6 +69,7 @@ def __init__(
         batch_dims: Optional[Dict[Hashable, int]] = None,
         concat_input_bins: bool = True,
         preload_batch: bool = True,
+        return_partial: bool = False,
     ):
         if input_overlap is None:
             input_overlap = {}
@@ -79,6 +80,7 @@
         self.batch_dims = dict(batch_dims)
         self.concat_input_dims = concat_input_bins
         self.preload_batch = preload_batch
+        self.return_partial = return_partial
         # Store helpful information based on arguments
         self._duplicate_batch_dims: Dict[Hashable, int] = {
             dim: length
@@ -131,6 +133,7 @@ def _gen_patch_selectors(
             ds,
             dims=self._all_sliced_dims,
             overlap=self.input_overlap,
+            return_partial=self.return_partial,
         )
         return all_slices
 
@@ -272,7 +275,7 @@ def _get_batch_in_range_per_batch(self, batch_multi_index):
         return batch_in_range_per_patch
 
 
-def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[slice]:
+def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0, return_partial: bool = False) -> List[slice]:
     # return a list of slices to chop up a single dimension
     if overlap >= slice_size:
         raise ValueError(
@@ -285,6 +288,8 @@ def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[sli
         end = start + slice_size
         if end <= dim_size:
             slices.append(slice(start, end))
+        elif return_partial:
+            slices.append(slice(start, dim_size))
     return slices
 
 
@@ -293,6 +298,7 @@ def _iterate_through_dimensions(
     *,
     dims: Dict[Hashable, int],
     overlap: Dict[Hashable, int] = {},
+    return_partial: bool = False,
 ) -> Iterator[Dict[Hashable, slice]]:
     dim_slices = []
     for dim in dims:
@@ -307,7 +313,7 @@
                 f"for {dim}"
             )
         dim_slices.append(
-            _gen_slices(dim_size=dim_size, slice_size=slice_size, overlap=slice_overlap)
+            _gen_slices(dim_size=dim_size, slice_size=slice_size, overlap=slice_overlap, return_partial=return_partial)
         )
     for slices in itertools.product(*dim_slices):
         selector = dict(zip(dims, slices))
@@ -374,6 +380,9 @@ class BatchGenerator:
     preload_batch : bool, optional
         If ``True``, each batch will be loaded into memory before reshaping /
         processing, triggering any dask arrays to be computed.
+    return_partial : bool, optional
+        If ``True``, also yield the smaller partial batches produced at the
+        edges of dimensions whose sizes are not evenly divisible by ``input_dims``.
     cache : dict, optional
         Dict-like object to cache batches in (e.g., Zarr DirectoryStore). Note:
         The caching API is experimental and subject to change.
@@ -395,6 +404,7 @@ def __init__(
         batch_dims: Dict[Hashable, int] = {},
         concat_input_dims: bool = False,
         preload_batch: bool = True,
+        return_partial: bool = False,
         cache: Optional[Dict[str, Any]] = None,
         cache_preprocess: Optional[Callable] = None,
     ):
@@ -409,6 +419,7 @@
             batch_dims=batch_dims,
             concat_input_bins=concat_input_dims,
             preload_batch=preload_batch,
+            return_partial=return_partial,
         )
 
     @property
diff --git a/xbatcher/testing.py b/xbatcher/testing.py
index 72953fd..cd56287 100644
--- a/xbatcher/testing.py
+++ b/xbatcher/testing.py
@@ -211,9 +211,12 @@ def _get_nbatches_from_input_dims(generator: BatchGenerator) -> int:
     s : int
         Number of batches expected given ``input_dims`` and ``input_overlap``.
     """
+    # When the generator returns partial batches, every stride along a sliced
+    # dimension yields a batch, so round the expected count up.
+    rounding = np.ceil if generator._batch_selectors.return_partial else np.floor
     nbatches_from_input_dims = np.prod(
         [
-            generator.ds.sizes[dim] // length
+            int(rounding(generator.ds.sizes[dim] / length))
             for dim, length in generator.input_dims.items()
             if generator.input_overlap.get(dim) is None
             and generator.batch_dims.get(dim) is None
@@ -222,8 +225,11 @@
     if generator.input_overlap:
         nbatches_from_input_overlap = np.prod(
             [
-                (generator.ds.sizes[dim] - overlap)
-                // (generator.input_dims[dim] - overlap)
+                int(np.ceil(generator.ds.sizes[dim]
+                            / (generator.input_dims[dim] - overlap)))
+                if generator._batch_selectors.return_partial
+                else (generator.ds.sizes[dim] - overlap)
+                // (generator.input_dims[dim] - overlap)
                 for dim, overlap in generator.input_overlap.items()
             ]
         )
@@ -242,17 +248,22 @@
     generator : xbatcher.BatchGenerator
         The batch generator object.
     """
+
     non_input_batch_dims = _get_non_input_batch_dims(generator)
     duplicate_batch_dims = _get_duplicate_batch_dims(generator)
+
+    # Count the final smaller batch along each dimension by rounding up when
+    # the generator returns partial batches.
+    rounding = np.ceil if generator._batch_selectors.return_partial else np.floor
     nbatches_from_unique_batch_dims = np.prod(
         [
-            generator.ds.sizes[dim] // length
+            int(rounding(generator.ds.sizes[dim] / length))
             for dim, length in non_input_batch_dims.items()
         ]
     )
     nbatches_from_duplicate_batch_dims = np.prod(
         [
-            generator.ds.sizes[dim] // length
+            int(rounding(generator.ds.sizes[dim] / length))
             for dim, length in duplicate_batch_dims.items()
         ]
     )
diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py
index 248dd03..b7cf59a 100644
--- a/xbatcher/tests/test_generators.py
+++ b/xbatcher/tests/test_generators.py
@@ -58,11 +58,16 @@ def test_constructor_dataarray():
 
 
 @pytest.mark.parametrize("input_size", [5, 6])
-def test_generator_length(sample_ds_1d, input_size):
+@pytest.mark.parametrize("return_partial", [True, False])
+def test_generator_length(sample_ds_1d, input_size, return_partial):
     """ "
     Test the length of the batch generator.
     """
-    bg = BatchGenerator(sample_ds_1d, input_dims={"x": input_size})
+    bg = BatchGenerator(
+        sample_ds_1d,
+        input_dims={"x": input_size},
+        return_partial=return_partial,
+    )
     validate_generator_length(bg)
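
Usage sketch (not part of the patch): a minimal, hypothetical illustration of
the new flag, assuming the changes above are applied. With a length-10
dimension and input_dims={"x": 4}, two full batches leave a remainder of 2,
which return_partial=True turns into a third, smaller batch. The toy dataset
below is invented for this example.

    import numpy as np
    import xarray as xr
    from xbatcher import BatchGenerator

    # Toy dataset: a single variable along one dimension of length 10.
    ds = xr.Dataset({"foo": ("x", np.arange(10))})

    # Default behavior: the trailing remainder of 2 is dropped.
    assert len(BatchGenerator(ds, input_dims={"x": 4})) == 2

    # return_partial=True also yields the smaller edge batch.
    bg = BatchGenerator(ds, input_dims={"x": 4}, return_partial=True)
    assert [batch.sizes["x"] for batch in bg] == [4, 4, 2]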