
Use .sizes instead of .dims for xr.Dataset/xr.DataArray compatibility #71


Merged · 11 commits · Aug 19, 2022
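For context, a quick sketch of the incompatibility this PR fixes (hypothetical data, not taken from the diff): `Dataset.dims` is a mapping from dimension name to length, but `DataArray.dims` is a tuple of names, so name-based indexing only works on the Dataset. `.sizes` is a name-to-length mapping on both types:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((3, 4)), dims=('x', 'y'), name='foo')
ds = da.to_dataset()

print(ds.dims)  # mapping: Frozen({'x': 3, 'y': 4})
print(da.dims)  # tuple: ('x', 'y')

print(ds.dims['x'])  # 3
# da.dims['x'] raises TypeError: tuple indices must be integers, not str

# .sizes behaves identically on both types:
print(ds.sizes['x'], da.sizes['x'])  # 3 3
```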
16 changes: 10 additions & 6 deletions xbatcher/accessors.py
@@ -7,13 +7,13 @@
@xr.register_dataset_accessor('batch')
class BatchAccessor:
def __init__(self, xarray_obj):
'''
"""
Batch accessor returning a BatchGenerator object via the `generator` method
'''
"""
self._obj = xarray_obj

def generator(self, *args, **kwargs):
'''
"""
Return a BatchGenerator via the batch accessor

Parameters
@@ -22,7 +22,7 @@ def generator(self, *args, **kwargs):
Positional arguments to pass to the `BatchGenerator` constructor.
**kwargs : dict
Keyword arguments to pass to the `BatchGenerator` constructor.
'''
"""
return BatchGenerator(self._obj, *args, **kwargs)


@@ -38,7 +38,11 @@ def to_tensor(self):
return torch.tensor(self._obj.data)

def to_named_tensor(self):
"""Convert this DataArray to a torch.Tensor with named dimensions"""
"""
Convert this DataArray to a torch.Tensor with named dimensions.

See https://pytorch.org/docs/stable/named_tensor.html
"""
import torch

return torch.tensor(self._obj.data, names=self._obj.dims)
return torch.tensor(self._obj.data, names=tuple(self._obj.sizes))
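A small sketch of the change above (assumed example data): iterating a DataArray's `.sizes` mapping yields the dimension names in order, so `tuple(self._obj.sizes)` reproduces the old `self._obj.dims` tuple while going through the type-agnostic accessor:

```python
import numpy as np
import torch
import xarray as xr

da = xr.DataArray(np.zeros((2, 3)), dims=('x', 'y'))

# Iterating the sizes mapping yields names in dimension order
assert tuple(da.sizes) == da.dims == ('x', 'y')

# Named tensors are an experimental PyTorch feature; this emits a UserWarning
t = torch.tensor(da.data, names=tuple(da.sizes))
print(t.names)  # ('x', 'y')
```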
14 changes: 3 additions & 11 deletions xbatcher/generators.py
@@ -7,14 +7,6 @@
import xarray as xr


def _as_xarray_dataset(ds):
# maybe coerce to xarray dataset
if isinstance(ds, xr.Dataset):
return ds
else:
return ds.to_dataset()


def _slices(dimsize, size, overlap=0):
# return a list of slices to chop up a single dimension
if overlap >= size:
@@ -34,7 +26,7 @@
def _iterate_through_dataset(ds, dims, overlap={}):
dim_slices = []
for dim in dims:
dimsize = ds.dims[dim]
dimsize = ds.sizes[dim]
size = dims[dim]
olap = overlap.get(dim, 0)
if size > dimsize:
@@ -66,7 +58,7 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):


def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name='sample'):
batch_dims = [d for d in ds.dims if d not in input_dims]
batch_dims = [d for d in ds.sizes if d not in input_dims]
if len(batch_dims) < 2:
return ds
ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
@@ -121,7 +113,7 @@ def __init__(
preload_batch: bool = True,
):

self.ds = _as_xarray_dataset(ds)
self.ds = ds
# should be a dict
self.input_dims = OrderedDict(input_dims)
self.input_overlap = input_overlap
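With the `_as_xarray_dataset` coercion helper removed, `BatchGenerator` passes a `DataArray` through untouched and yields `DataArray` batches. A minimal usage sketch (hypothetical data):

```python
import numpy as np
import xarray as xr
from xbatcher import BatchGenerator

da = xr.DataArray(np.random.rand(10), dims='x', name='foo')
bg = BatchGenerator(da, input_dims={'x': 2})

for batch in bg:
    # Batches are now DataArrays, not coerced Datasets
    assert isinstance(batch, xr.DataArray)
    assert batch.sizes['x'] == 2
```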
30 changes: 4 additions & 26 deletions xbatcher/loaders/keras.py
@@ -1,7 +1,6 @@
from typing import Any, Callable, Optional, Tuple

import tensorflow as tf
import xarray as xr

# Notes:
# This module includes one Keras dataset, which can be provided to model.fit().
@@ -18,9 +17,8 @@ def __init__(
*,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
dim: str = 'new_dim',
) -> None:
'''
"""
Keras Dataset adapter for Xbatcher

Parameters
@@ -31,38 +29,18 @@
A function/transform that takes in an array and returns a transformed version.
target_transform : callable, optional
A function/transform that takes in the target and transforms it.
dim : str, 'new_dim'
Name of dim to pass to :func:`xarray.concat` as the dimension
to concatenate all variables along.
'''
"""
self.X_generator = X_generator
self.y_generator = y_generator
self.transform = transform
self.target_transform = target_transform
self.concat_dim = dim

def __len__(self) -> int:
return len(self.X_generator)

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
X_batch = tf.convert_to_tensor(
xr.concat(
(
self.X_generator[idx][key]
for key in list(self.X_generator[idx].keys())
),
self.concat_dim,
).data
)
y_batch = tf.convert_to_tensor(
xr.concat(
(
self.y_generator[idx][key]
for key in list(self.y_generator[idx].keys())
),
self.concat_dim,
).data
)
X_batch = tf.convert_to_tensor(self.X_generator[idx].data)
y_batch = tf.convert_to_tensor(self.y_generator[idx].data)

# TODO: Should the transformations be applied before tensor conversion?
if self.transform:
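Since each batch is now a single DataArray, its `.data` converts straight to a tensor with no per-variable concatenation. A hedged usage sketch of the Keras adapter (assuming the class shown in this diff is the module's `CustomTFDataset`, with hypothetical fixture data):

```python
import numpy as np
import tensorflow as tf
import xarray as xr

import xbatcher
from xbatcher.loaders.keras import CustomTFDataset

x = xr.DataArray(np.random.rand(100, 5), dims=('sample', 'feature'), name='x')
y = xr.DataArray(np.random.rand(100), dims='sample', name='y')

x_gen = xbatcher.BatchGenerator(x, {'sample': 10})
y_gen = xbatcher.BatchGenerator(y, {'sample': 10})

dataset = CustomTFDataset(x_gen, y_gen)
x_batch, y_batch = dataset[0]
assert tf.is_tensor(x_batch)
assert x_batch.shape == (10, 5) and y_batch.shape == (10,)
```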
6 changes: 3 additions & 3 deletions xbatcher/loaders/torch.py
@@ -54,8 +54,8 @@ def __getitem__(self, idx) -> Tuple[Any, Any]:

# TODO: figure out the dataset -> array workflow
# currently hardcoding a variable name
X_batch = self.X_generator[idx]['x'].torch.to_tensor()
y_batch = self.y_generator[idx]['y'].torch.to_tensor()
X_batch = self.X_generator[idx].torch.to_tensor()
y_batch = self.y_generator[idx].torch.to_tensor()
Comment on lines +57 to +58
@weiji14 (Member, Author) · Aug 15, 2022:

This change works because the unit tests in test_torch_loaders.py are actually testing xarray.DataArray inputs only, and not xarray.Dataset. Ideally there would be unit tests for both xr.DataArray and xr.Dataset inputs, but this might expand the scope of the Pull Request a bit too much 😅

Reply from a project member:

Before this PR, the tests used xarray.DataArray inputs to the batch generator but xarray.Dataset inputs to the dataloaders, since the batches were coerced into datasets. So I think expecting xarray.DataArray inputs would be an additional breaking change.

For the unit tests, I opened #83 to keep track of improvements for subsequent PRs.

Reply from a project member:

What are your thoughts on backwards compatibility here @weiji14? My impression is that the hardcoded xarray.Dataset variable names severely restrict the utility of the data loader. So I think this is a worthwhile change, since we'd eventually be forced to break backwards compatibility anyway for flexible variable names, and working from an xarray.DataArray implementation is better. But we could add an if/else block for Dataset vs. DataArray if it's necessary to maintain the past behavior.

@weiji14 (Member, Author):

Yes, I agree that we should try to be backwards compatible and support both xr.DataArray and xr.Dataset inputs. If you want, I can either:

  1. Add the if/else block here, but write the comprehensive unit tests covering the xr.Dataset/xr.DataArray cases in a separate PR, or
  2. Do both the if/else block and the unit tests in a follow-up PR, to keep this PR small.

Reply from a project member:

+1 for option 2
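For illustration, a rough sketch of the if/else branch discussed above (a hypothetical helper, not part of this PR; the `variable_name` parameter stands in for the previously hardcoded 'x'/'y' keys):

```python
import xarray as xr

def _batch_to_tensor(batch, variable_name=None):
    # Hypothetical helper: support both input types for backwards compatibility
    if isinstance(batch, xr.Dataset):
        # Dataset input: select one variable (previously hardcoded as 'x'/'y')
        return batch[variable_name].torch.to_tensor()
    # DataArray input: convert directly via the accessor in accessors.py
    return batch.torch.to_tensor()
```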


if self.transform:
X_batch = self.transform(X_batch)
@@ -85,4 +85,4 @@ def __init__(

def __iter__(self):
for xb, yb in zip(self.X_generator, self.y_generator):
yield (xb['x'].torch.to_tensor(), yb['y'].torch.to_tensor())
yield (xb.torch.to_tensor(), yb.torch.to_tensor())
6 changes: 3 additions & 3 deletions xbatcher/tests/test_generators.py
@@ -34,11 +34,11 @@ def sample_ds_3d():
return ds


def test_constructor_coerces_to_dataset():
def test_constructor_dataarray():
da = xr.DataArray(np.random.rand(10), dims='x', name='foo')
bg = BatchGenerator(da, input_dims={'x': 2})
assert isinstance(bg.ds, xr.Dataset)
assert bg.ds.equals(da.to_dataset())
assert isinstance(bg.ds, xr.DataArray)
assert bg.ds.equals(da)


@pytest.mark.parametrize('bsize', [5, 6])
10 changes: 6 additions & 4 deletions xbatcher/tests/test_keras_loaders.py
@@ -24,7 +24,7 @@ def ds_xy():
return ds


def test_custom_dataset(ds_xy):
def test_custom_dataarray(ds_xy):

x = ds_xy['x']
y = ds_xy['y']
@@ -36,15 +36,16 @@

# test __getitem__
x_batch, y_batch = dataset[0]
assert len(x_batch) == len(y_batch)
assert x_batch.shape == (10, 5)
assert y_batch.shape == (10,)
assert tf.is_tensor(x_batch)
assert tf.is_tensor(y_batch)

# test __len__
assert len(dataset) == len(x_gen)


def test_custom_dataset_with_transform(ds_xy):
def test_custom_dataarray_with_transform(ds_xy):

x = ds_xy['x']
y = ds_xy['y']
@@ -62,7 +63,8 @@ def y_transform(batch):
x_gen, y_gen, transform=x_transform, target_transform=y_transform
)
x_batch, y_batch = dataset[0]
assert len(x_batch) == len(y_batch)
assert x_batch.shape == (10, 5)
assert y_batch.shape == (10,)
assert tf.is_tensor(x_batch)
assert tf.is_tensor(y_batch)
assert tf.experimental.numpy.all(x_batch == 1)
27 changes: 16 additions & 11 deletions xbatcher/tests/test_torch_loaders.py
@@ -36,12 +36,14 @@ def test_map_dataset(ds_xy):

# test __getitem__
x_batch, y_batch = dataset[0]
assert len(x_batch) == len(y_batch)
assert x_batch.shape == (10, 5)
assert y_batch.shape == (10,)
assert isinstance(x_batch, torch.Tensor)

idx = torch.tensor([0])
x_batch, y_batch = dataset[idx]
assert len(x_batch) == len(y_batch)
assert x_batch.shape == (10, 5)
assert y_batch.shape == (10,)
assert isinstance(x_batch, torch.Tensor)

with pytest.raises(NotImplementedError):
@@ -55,13 +57,14 @@ def test_map_dataset(ds_xy):
loader = torch.utils.data.DataLoader(dataset)

for x_batch, y_batch in loader:
assert len(x_batch) == len(y_batch)
assert x_batch.shape == (1, 10, 5)
assert y_batch.shape == (1, 10)
assert isinstance(x_batch, torch.Tensor)

# TODO: why does pytorch add an extra dimension (length 1) to x_batch
assert x_gen[-1]['x'].shape == x_batch.shape[1:]
# TODO: also need to revisit the variable extraction bits here
assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
assert x_gen[-1].shape == x_batch.shape[1:]
# TODO: add test for xarray.Dataset
assert np.array_equal(x_gen[-1], x_batch[0, :, :])
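As an aside on the TODO above: `torch.utils.data.DataLoader` defaults to `batch_size=1` and collates each map-style sample into a batch, which is what prepends the length-1 dimension seen here. A hedged illustration, reusing the `dataset` constructed in this test:

```python
import torch

# Each sample from the map-style dataset already has shape (10, 5);
# DataLoader(batch_size=1) stacks samples, prepending a batch axis.
loader = torch.utils.data.DataLoader(dataset)
x_batch, y_batch = next(iter(loader))
assert x_batch.shape == (1, 10, 5)

# Passing batch_size=None disables automatic batching and the extra axis.
```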


def test_map_dataset_with_transform(ds_xy):
@@ -82,7 +85,8 @@ def y_transform(batch):
x_gen, y_gen, transform=x_transform, target_transform=y_transform
)
x_batch, y_batch = dataset[0]
assert len(x_batch) == len(y_batch)
assert x_batch.shape == (10, 5)
assert y_batch.shape == (10,)
assert isinstance(x_batch, torch.Tensor)
assert (x_batch == 1).all()
assert (y_batch == -1).all()
@@ -102,10 +106,11 @@ def test_iterable_dataset(ds_xy):
loader = torch.utils.data.DataLoader(dataset)

for x_batch, y_batch in loader:
assert len(x_batch) == len(y_batch)
assert x_batch.shape == (1, 10, 5)
assert y_batch.shape == (1, 10)
assert isinstance(x_batch, torch.Tensor)

# TODO: why does pytorch add an extra dimension (length 1) to x_batch
assert x_gen[-1]['x'].shape == x_batch.shape[1:]
# TODO: also need to revisit the variable extraction bits here
assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
assert x_gen[-1].shape == x_batch.shape[1:]
# TODO: add test for xarray.Dataset
assert np.array_equal(x_gen[-1], x_batch[0, :, :])