Skip to content

Commit ed45a99

Browse files
authored
Let torch accessor and dataloader handle either xarray.DataArray or xarray.Dataset inputs (#85)
* Let torch accessor support xarray.Dataset objects Convert xarray.Dataset to xarray.DataArray first, so that the `.data` method work to get the underlying array which can be converted to a torch.Tensor. * Add parametrized tests for xarray.DataArray and xarray.Dataset Need to squeeze the extra first dimension in order to preserve the same output shape for xarray.DataArray and xarray.Dataset. * Set batch_size in torch DataLoader to None instead of 1 to fix extra dim Resolve the strange extra dimension of 1, which is because torch.utils.data.DataLoader adds a batch dimension by default. Setting to `batch_size=None` means no extra batch dimension is prepended. * Add DataArray/Dataset parametrized tests for torch accessor
1 parent 714b624 commit ed45a99

File tree

3 files changed

+105
-37
lines changed

3 files changed

+105
-37
lines changed

xbatcher/accessors.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,32 @@ def generator(self, *args, **kwargs):
2727

2828

2929
@xr.register_dataarray_accessor('torch')
30+
@xr.register_dataset_accessor('torch')
3031
class TorchAccessor:
3132
def __init__(self, xarray_obj):
3233
self._obj = xarray_obj
3334

35+
def _as_xarray_dataarray(self, xr_obj):
36+
"""
37+
Convert xarray.Dataset to xarray.DataArray if needed, so that it can
38+
be converted into a torch.Tensor object.
39+
"""
40+
try:
41+
# Convert xr.Dataset to xr.DataArray
42+
dataarray = xr_obj.to_array().squeeze(dim='variable')
43+
except AttributeError: # 'DataArray' object has no attribute 'to_array'
44+
# If object is already an xr.DataArray
45+
dataarray = xr_obj
46+
47+
return dataarray
48+
3449
def to_tensor(self):
3550
"""Convert this DataArray to a torch.Tensor"""
3651
import torch
3752

38-
return torch.tensor(self._obj.data)
53+
dataarray = self._as_xarray_dataarray(xr_obj=self._obj)
54+
55+
return torch.tensor(data=dataarray.data)
3956

4057
def to_named_tensor(self):
4158
"""
@@ -45,4 +62,6 @@ def to_named_tensor(self):
4562
"""
4663
import torch
4764

48-
return torch.tensor(self._obj.data, names=tuple(self._obj.sizes))
65+
dataarray = self._as_xarray_dataarray(xr_obj=self._obj)
66+
67+
return torch.tensor(data=dataarray.data, names=tuple(dataarray.sizes))

xbatcher/tests/test_accessors.py

+29-11
Original file line numberDiff line numberDiff line change
@@ -40,23 +40,41 @@ def test_batch_accessor_da(sample_ds_3d):
4040
assert batch_class.equals(batch_acc)
4141

4242

43-
def test_torch_to_tensor(sample_ds_3d):
43+
@pytest.mark.parametrize(
44+
'foo_var',
45+
[
46+
'foo', # xr.DataArray
47+
['foo'], # xr.Dataset
48+
],
49+
)
50+
def test_torch_to_tensor(sample_ds_3d, foo_var):
4451
torch = pytest.importorskip('torch')
4552

46-
da = sample_ds_3d['foo']
47-
t = da.torch.to_tensor()
53+
foo = sample_ds_3d[foo_var]
54+
t = foo.torch.to_tensor()
4855
assert isinstance(t, torch.Tensor)
4956
assert t.names == (None, None, None)
50-
assert t.shape == da.shape
51-
np.testing.assert_array_equal(t, da.values)
57+
assert t.shape == tuple(foo.sizes.values())
5258

59+
foo_array = foo.to_array().squeeze() if hasattr(foo, 'to_array') else foo
60+
np.testing.assert_array_equal(t, foo_array.values)
5361

54-
def test_torch_to_named_tensor(sample_ds_3d):
62+
63+
@pytest.mark.parametrize(
64+
'foo_var',
65+
[
66+
'foo', # xr.DataArray
67+
['foo'], # xr.Dataset
68+
],
69+
)
70+
def test_torch_to_named_tensor(sample_ds_3d, foo_var):
5571
torch = pytest.importorskip('torch')
5672

57-
da = sample_ds_3d['foo']
58-
t = da.torch.to_named_tensor()
73+
foo = sample_ds_3d[foo_var]
74+
t = foo.torch.to_named_tensor()
5975
assert isinstance(t, torch.Tensor)
60-
assert t.names == da.dims
61-
assert t.shape == da.shape
62-
np.testing.assert_array_equal(t, da.values)
76+
assert t.names == tuple(foo.dims)
77+
assert t.shape == tuple(foo.sizes.values())
78+
79+
foo_array = foo.to_array().squeeze() if hasattr(foo, 'to_array') else foo
80+
np.testing.assert_array_equal(t, foo_array.values)

xbatcher/tests/test_torch_loaders.py

+55-24
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,17 @@ def ds_xy():
2424
return ds
2525

2626

27-
def test_map_dataset(ds_xy):
28-
29-
x = ds_xy['x']
30-
y = ds_xy['y']
27+
@pytest.mark.parametrize(
28+
('x_var', 'y_var'),
29+
[
30+
('x', 'y'), # xr.DataArray
31+
(['x'], ['y']), # xr.Dataset
32+
],
33+
)
34+
def test_map_dataset(ds_xy, x_var, y_var):
35+
36+
x = ds_xy[x_var]
37+
y = ds_xy[y_var]
3138

3239
x_gen = BatchGenerator(x, {'sample': 10})
3340
y_gen = BatchGenerator(y, {'sample': 10})
@@ -54,23 +61,35 @@ def test_map_dataset(ds_xy):
5461
assert len(dataset) == len(x_gen)
5562

5663
# test integration with torch DataLoader
57-
loader = torch.utils.data.DataLoader(dataset)
64+
loader = torch.utils.data.DataLoader(dataset, batch_size=None)
5865

5966
for x_batch, y_batch in loader:
60-
assert x_batch.shape == (1, 10, 5)
61-
assert y_batch.shape == (1, 10)
67+
assert x_batch.shape == (10, 5)
68+
assert y_batch.shape == (10,)
6269
assert isinstance(x_batch, torch.Tensor)
6370

64-
# TODO: why does pytorch add an extra dimension (length 1) to x_batch
65-
assert x_gen[-1].shape == x_batch.shape[1:]
66-
# TODO: add test for xarray.Dataset
67-
assert np.array_equal(x_gen[-1], x_batch[0, :, :])
71+
# Check that array shape of last item in generator is same as the batch image
72+
assert tuple(x_gen[-1].sizes.values()) == x_batch.shape
73+
# Check that array values from last item in generator and batch are the same
74+
gen_array = (
75+
x_gen[-1].to_array().squeeze()
76+
if hasattr(x_gen[-1], 'to_array')
77+
else x_gen[-1]
78+
)
79+
np.testing.assert_array_equal(gen_array, x_batch)
6880

6981

70-
def test_map_dataset_with_transform(ds_xy):
82+
@pytest.mark.parametrize(
83+
('x_var', 'y_var'),
84+
[
85+
('x', 'y'), # xr.DataArray
86+
(['x'], ['y']), # xr.Dataset
87+
],
88+
)
89+
def test_map_dataset_with_transform(ds_xy, x_var, y_var):
7190

72-
x = ds_xy['x']
73-
y = ds_xy['y']
91+
x = ds_xy[x_var]
92+
y = ds_xy[y_var]
7493

7594
x_gen = BatchGenerator(x, {'sample': 10})
7695
y_gen = BatchGenerator(y, {'sample': 10})
@@ -92,25 +111,37 @@ def y_transform(batch):
92111
assert (y_batch == -1).all()
93112

94113

95-
def test_iterable_dataset(ds_xy):
114+
@pytest.mark.parametrize(
115+
('x_var', 'y_var'),
116+
[
117+
('x', 'y'), # xr.DataArray
118+
(['x'], ['y']), # xr.Dataset
119+
],
120+
)
121+
def test_iterable_dataset(ds_xy, x_var, y_var):
96122

97-
x = ds_xy['x']
98-
y = ds_xy['y']
123+
x = ds_xy[x_var]
124+
y = ds_xy[y_var]
99125

100126
x_gen = BatchGenerator(x, {'sample': 10})
101127
y_gen = BatchGenerator(y, {'sample': 10})
102128

103129
dataset = IterableDataset(x_gen, y_gen)
104130

105131
# test integration with torch DataLoader
106-
loader = torch.utils.data.DataLoader(dataset)
132+
loader = torch.utils.data.DataLoader(dataset, batch_size=None)
107133

108134
for x_batch, y_batch in loader:
109-
assert x_batch.shape == (1, 10, 5)
110-
assert y_batch.shape == (1, 10)
135+
assert x_batch.shape == (10, 5)
136+
assert y_batch.shape == (10,)
111137
assert isinstance(x_batch, torch.Tensor)
112138

113-
# TODO: why does pytorch add an extra dimension (length 1) to x_batch
114-
assert x_gen[-1].shape == x_batch.shape[1:]
115-
# TODO: add test for xarray.Dataset
116-
assert np.array_equal(x_gen[-1], x_batch[0, :, :])
139+
# Check that array shape of last item in generator is same as the batch image
140+
assert tuple(x_gen[-1].sizes.values()) == x_batch.shape
141+
# Check that array values from last item in generator and batch are the same
142+
gen_array = (
143+
x_gen[-1].to_array().squeeze()
144+
if hasattr(x_gen[-1], 'to_array')
145+
else x_gen[-1]
146+
)
147+
np.testing.assert_array_equal(gen_array, x_batch)

0 commit comments

Comments
 (0)