diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py
index 44c2e84..38d247c 100644
--- a/xbatcher/accessors.py
+++ b/xbatcher/accessors.py
@@ -7,13 +7,13 @@
 @xr.register_dataset_accessor('batch')
 class BatchAccessor:
     def __init__(self, xarray_obj):
-        '''
+        """
         Batch accessor returning a BatchGenerator object via the `generator method`
-        '''
+        """
         self._obj = xarray_obj

     def generator(self, *args, **kwargs):
-        '''
+        """
         Return a BatchGenerator via the batch accessor

         Parameters
@@ -22,7 +22,7 @@ def generator(self, *args, **kwargs):
             Positional arguments to pass to the `BatchGenerator` constructor.
         **kwargs : dict
             Keyword arguments to pass to the `BatchGenerator` constructor.
-        '''
+        """
         return BatchGenerator(self._obj, *args, **kwargs)


@@ -38,7 +38,11 @@ def to_tensor(self):
         return torch.tensor(self._obj.data)

     def to_named_tensor(self):
-        """Convert this DataArray to a torch.Tensor with named dimensions"""
+        """
+        Convert this DataArray to a torch.Tensor with named dimensions.
+
+        See https://pytorch.org/docs/stable/named_tensor.html
+        """
         import torch

-        return torch.tensor(self._obj.data, names=self._obj.dims)
+        return torch.tensor(self._obj.data, names=tuple(self._obj.sizes))
diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index e9bcf65..144aff3 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -7,14 +7,6 @@
 import xarray as xr


-def _as_xarray_dataset(ds):
-    # maybe coerce to xarray dataset
-    if isinstance(ds, xr.Dataset):
-        return ds
-    else:
-        return ds.to_dataset()
-
-
 def _slices(dimsize, size, overlap=0):
     # return a list of slices to chop up a single dimension
     if overlap >= size:
@@ -34,7 +26,7 @@ def _slices(dimsize, size, overlap=0):
 def _iterate_through_dataset(ds, dims, overlap={}):
     dim_slices = []
     for dim in dims:
-        dimsize = ds.dims[dim]
+        dimsize = ds.sizes[dim]
         size = dims[dim]
         olap = overlap.get(dim, 0)
         if size > dimsize:
@@ -66,7 +58,7 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):


 def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
-    batch_dims = [d for d in ds.dims if d not in input_dims]
+    batch_dims = [d for d in ds.sizes if d not in input_dims]
     if len(batch_dims) < 2:
         return ds
     ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
@@ -121,7 +113,7 @@ def __init__(
         preload_batch: bool = True,
     ):

-        self.ds = _as_xarray_dataset(ds)
+        self.ds = ds
         # should be a dict
         self.input_dims = OrderedDict(input_dims)
         self.input_overlap = input_overlap
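Taken together, the accessor and generator changes mean `BatchGenerator` no longer coerces a `DataArray` to a `Dataset`: batches come back as the same type that went in, and dimension lookups go through `sizes`, which behaves uniformly for both types. Below is a minimal sketch of the resulting workflow, not part of the patch; the array values and dimension names are made up, and it assumes the `torch` accessor is registered for DataArrays as the loader modules imply.

```python
import numpy as np
import xarray as xr

from xbatcher import BatchGenerator

# With _as_xarray_dataset removed, a DataArray stays a DataArray.
da = xr.DataArray(np.random.rand(100, 5), dims=('sample', 'feature'), name='foo')
bg = BatchGenerator(da, input_dims={'sample': 10})

batch = bg[0]  # an xr.DataArray of shape (10, 5), not a Dataset
assert isinstance(batch, xr.DataArray)

# to_named_tensor now derives names from sizes, so the resulting
# torch.Tensor carries ('sample', 'feature') as named dimensions.
named = batch.torch.to_named_tensor()
print(named.names)
```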
diff --git a/xbatcher/loaders/keras.py b/xbatcher/loaders/keras.py
index 2c7345c..b9f2816 100644
--- a/xbatcher/loaders/keras.py
+++ b/xbatcher/loaders/keras.py
@@ -1,7 +1,6 @@
 from typing import Any, Callable, Optional, Tuple

 import tensorflow as tf
-import xarray as xr

 # Notes:
 # This module includes one Keras dataset, which can be provided to model.fit().
@@ -18,9 +17,8 @@ def __init__(
         *,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
-        dim: str = 'new_dim',
     ) -> None:
-        '''
+        """
         Keras Dataset adapter for Xbatcher

         Parameters
@@ -31,38 +29,18 @@ def __init__(
             A function/transform that takes in an array and returns a transformed version.
         target_transform : callable, optional
             A function/transform that takes in the target and transforms it.
-        dim : str, 'new_dim'
-            Name of dim to pass to :func:`xarray.concat` as the dimension
-            to concatenate all variables along.
-        '''
+        """
         self.X_generator = X_generator
         self.y_generator = y_generator
         self.transform = transform
         self.target_transform = target_transform
-        self.concat_dim = dim

     def __len__(self) -> int:
         return len(self.X_generator)

     def __getitem__(self, idx: int) -> Tuple[Any, Any]:
-        X_batch = tf.convert_to_tensor(
-            xr.concat(
-                (
-                    self.X_generator[idx][key]
-                    for key in list(self.X_generator[idx].keys())
-                ),
-                self.concat_dim,
-            ).data
-        )
-        y_batch = tf.convert_to_tensor(
-            xr.concat(
-                (
-                    self.y_generator[idx][key]
-                    for key in list(self.y_generator[idx].keys())
-                ),
-                self.concat_dim,
-            ).data
-        )
+        X_batch = tf.convert_to_tensor(self.X_generator[idx].data)
+        y_batch = tf.convert_to_tensor(self.y_generator[idx].data)

         # TODO: Should the transformations be applied before tensor conversion?
         if self.transform:
diff --git a/xbatcher/loaders/torch.py b/xbatcher/loaders/torch.py
index 6fa427d..f54c93c 100644
--- a/xbatcher/loaders/torch.py
+++ b/xbatcher/loaders/torch.py
@@ -54,8 +54,8 @@ def __getitem__(self, idx) -> Tuple[Any, Any]:

         # TODO: figure out the dataset -> array workflow
         # currently hardcoding a variable name
-        X_batch = self.X_generator[idx]['x'].torch.to_tensor()
-        y_batch = self.y_generator[idx]['y'].torch.to_tensor()
+        X_batch = self.X_generator[idx].torch.to_tensor()
+        y_batch = self.y_generator[idx].torch.to_tensor()

         if self.transform:
             X_batch = self.transform(X_batch)
@@ -85,4 +85,4 @@ def __init__(

     def __iter__(self):
         for xb, yb in zip(self.X_generator, self.y_generator):
-            yield (xb['x'].torch.to_tensor(), yb['y'].torch.to_tensor())
+            yield (xb.torch.to_tensor(), yb.torch.to_tensor())
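Because the generators now hand back `DataArray` batches directly, both loader adapters convert a batch in one step instead of extracting hardcoded variable names (`batch['x']`) or stitching variables together with `xarray.concat`. A sketch of the simplified PyTorch path follows; it assumes the `MapDataset` class name from `xbatcher.loaders.torch`, which is not shown in the hunks above.

```python
import numpy as np
import torch
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import MapDataset

x = xr.DataArray(np.random.rand(100, 5), dims=('sample', 'feature'), name='x')
y = xr.DataArray(np.random.rand(100), dims=('sample',), name='y')

x_gen = BatchGenerator(x, input_dims={'sample': 10})
y_gen = BatchGenerator(y, input_dims={'sample': 10})

# __getitem__ now calls batch.torch.to_tensor() on the whole DataArray,
# so no per-variable lookup is needed.
dataset = MapDataset(x_gen, y_gen)
x_batch, y_batch = dataset[0]
assert x_batch.shape == (10, 5) and y_batch.shape == (10,)
assert isinstance(x_batch, torch.Tensor)
```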
diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py
index e22c7be..84a0e26 100644
--- a/xbatcher/tests/test_generators.py
+++ b/xbatcher/tests/test_generators.py
@@ -34,11 +34,11 @@ def sample_ds_3d():
     return ds


-def test_constructor_coerces_to_dataset():
+def test_constructor_dataarray():
     da = xr.DataArray(np.random.rand(10), dims='x', name='foo')
     bg = BatchGenerator(da, input_dims={'x': 2})
-    assert isinstance(bg.ds, xr.Dataset)
-    assert bg.ds.equals(da.to_dataset())
+    assert isinstance(bg.ds, xr.DataArray)
+    assert bg.ds.equals(da)


 @pytest.mark.parametrize('bsize', [5, 6])
diff --git a/xbatcher/tests/test_keras_loaders.py b/xbatcher/tests/test_keras_loaders.py
index 6c2a9d1..579ba2f 100644
--- a/xbatcher/tests/test_keras_loaders.py
+++ b/xbatcher/tests/test_keras_loaders.py
@@ -24,7 +24,7 @@ def ds_xy():
     return ds


-def test_custom_dataset(ds_xy):
+def test_custom_dataarray(ds_xy):
     x = ds_xy['x']
     y = ds_xy['y']

@@ -36,7 +36,8 @@ def test_custom_dataset(ds_xy):

     # test __getitem__
     x_batch, y_batch = dataset[0]
-    assert len(x_batch) == len(y_batch)
+    assert x_batch.shape == (10, 5)
+    assert y_batch.shape == (10,)
     assert tf.is_tensor(x_batch)
     assert tf.is_tensor(y_batch)

@@ -44,7 +45,7 @@ def test_custom_dataset(ds_xy):
     assert len(dataset) == len(x_gen)


-def test_custom_dataset_with_transform(ds_xy):
+def test_custom_dataarray_with_transform(ds_xy):
     x = ds_xy['x']
     y = ds_xy['y']

@@ -62,7 +63,8 @@ def y_transform(batch):
         x_gen, y_gen, transform=x_transform, target_transform=y_transform
     )
     x_batch, y_batch = dataset[0]
-    assert len(x_batch) == len(y_batch)
+    assert x_batch.shape == (10, 5)
+    assert y_batch.shape == (10,)
     assert tf.is_tensor(x_batch)
     assert tf.is_tensor(y_batch)
     assert tf.experimental.numpy.all(x_batch == 1)
diff --git a/xbatcher/tests/test_torch_loaders.py b/xbatcher/tests/test_torch_loaders.py
index 4e4a412..bbde292 100644
--- a/xbatcher/tests/test_torch_loaders.py
+++ b/xbatcher/tests/test_torch_loaders.py
@@ -36,12 +36,14 @@ def test_map_dataset(ds_xy):

     # test __getitem__
     x_batch, y_batch = dataset[0]
-    assert len(x_batch) == len(y_batch)
+    assert x_batch.shape == (10, 5)
+    assert y_batch.shape == (10,)
     assert isinstance(x_batch, torch.Tensor)

     idx = torch.tensor([0])
     x_batch, y_batch = dataset[idx]
-    assert len(x_batch) == len(y_batch)
+    assert x_batch.shape == (10, 5)
+    assert y_batch.shape == (10,)
     assert isinstance(x_batch, torch.Tensor)

     with pytest.raises(NotImplementedError):
@@ -55,13 +57,14 @@ def test_map_dataset(ds_xy):
     loader = torch.utils.data.DataLoader(dataset)

     for x_batch, y_batch in loader:
-        assert len(x_batch) == len(y_batch)
+        assert x_batch.shape == (1, 10, 5)
+        assert y_batch.shape == (1, 10)
         assert isinstance(x_batch, torch.Tensor)

     # TODO: why does pytorch add an extra dimension (length 1) to x_batch
-    assert x_gen[-1]['x'].shape == x_batch.shape[1:]
-    # TODO: also need to revisit the variable extraction bits here
-    assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
+    assert x_gen[-1].shape == x_batch.shape[1:]
+    # TODO: add test for xarray.Dataset
+    assert np.array_equal(x_gen[-1], x_batch[0, :, :])


 def test_map_dataset_with_transform(ds_xy):
@@ -82,7 +85,8 @@ def y_transform(batch):
         x_gen, y_gen, transform=x_transform, target_transform=y_transform
     )
     x_batch, y_batch = dataset[0]
-    assert len(x_batch) == len(y_batch)
+    assert x_batch.shape == (10, 5)
+    assert y_batch.shape == (10,)
     assert isinstance(x_batch, torch.Tensor)
     assert (x_batch == 1).all()
     assert (y_batch == -1).all()
@@ -102,10 +106,11 @@ def test_iterable_dataset(ds_xy):
     loader = torch.utils.data.DataLoader(dataset)

     for x_batch, y_batch in loader:
-        assert len(x_batch) == len(y_batch)
+        assert x_batch.shape == (1, 10, 5)
+        assert y_batch.shape == (1, 10)
         assert isinstance(x_batch, torch.Tensor)

     # TODO: why does pytorch add an extra dimension (length 1) to x_batch
-    assert x_gen[-1]['x'].shape == x_batch.shape[1:]
-    # TODO: also need to revisit the variable extraction bits here
-    assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
+    assert x_gen[-1].shape == x_batch.shape[1:]
+    # TODO: add test for xarray.Dataset
+    assert np.array_equal(x_gen[-1], x_batch[0, :, :])
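The Keras adapter follows the same pattern as the updated tests above: with `dim`/`concat_dim` gone, each DataArray batch converts directly via `.data`. A sketch mirroring those tests, assuming the `CustomTFDataset` class name from `xbatcher.loaders.keras` (the class definition itself is outside this diff):

```python
import numpy as np
import tensorflow as tf
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.keras import CustomTFDataset

x = xr.DataArray(np.random.rand(100, 5), dims=('sample', 'feature'), name='x')
y = xr.DataArray(np.random.rand(100), dims=('sample',), name='y')

x_gen = BatchGenerator(x, input_dims={'sample': 10})
y_gen = BatchGenerator(y, input_dims={'sample': 10})

# No concat_dim argument anymore: the adapter calls
# tf.convert_to_tensor(batch.data) on each DataArray batch.
dataset = CustomTFDataset(x_gen, y_gen)
x_batch, y_batch = dataset[0]
assert x_batch.shape == (10, 5) and y_batch.shape == (10,)
assert tf.is_tensor(x_batch)
```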