
Commit 3af1306

Author: Joe Hamman

Merge pull request #25 from jhamman/loader/torch

Add pytorch dataloader

2 parents 802bbd5 + 8bcd870, commit 3af1306

14 files changed (+341 -19 lines)

.pre-commit-config.yaml

Lines changed: 14 additions & 1 deletion
@@ -14,7 +14,7 @@ repos:
       - id: double-quote-string-fixer
 
   - repo: https://github.com/psf/black
-    rev: 21.12b0
+    rev: 22.1.0
     hooks:
       - id: black
         args: ["--line-length", "80", "--skip-string-normalization"]
@@ -37,3 +37,16 @@ repos:
     hooks:
       - id: prettier
         language_version: system
+
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v0.931
+    hooks:
+      - id: mypy
+        additional_dependencies: [
+          # Type stubs
+          types-setuptools,
+          types-pkg_resources,
+          # Dependencies that are typed
+          numpy,
+          xarray,
+        ]

conftest.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+# type: ignore
 import pytest
 
 

dev-requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
 pytest
+torch
+coverage
 pytest-cov
 adlfs
 -r requirements.txt

doc/api.rst

Lines changed: 14 additions & 6 deletions
@@ -5,12 +5,6 @@ API reference
 
 This page provides an auto-generated summary of Xbatcher's API.
 
-Core
-====
-
-.. autoclass:: xbatcher.BatchGenerator
-   :members:
-
 Dataset.batch and DataArray.batch
 =================================
 
@@ -22,3 +16,17 @@ Dataset.batch and DataArray.batch
 
    Dataset.batch.generator
    DataArray.batch.generator
+
+Core
+====
+
+.. autoclass:: xbatcher.BatchGenerator
+   :members:
+
+Dataloaders
+===========
+.. autoclass:: xbatcher.loaders.torch.MapDataset
+   :members:
+
+.. autoclass:: xbatcher.loaders.torch.IterableDataset
+   :members:

doc/conf.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@
 # All configuration values have a default; values that are commented out
 # serve to show the default.
 
+# type: ignore
+
 import os
 import sys
 

setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ select = B,C,E,F,W,T4,B9
 
 [isort]
 known_first_party=xbatcher
-known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,xarray
+known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,torch,xarray
 multi_line_output=3
 include_trailing_comma=True
 force_grid_wrap=0

setup.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 #!/usr/bin/env python
+# type: ignore
 import os
 
 from setuptools import find_packages, setup

xbatcher/accessors.py

Lines changed: 18 additions & 0 deletions
@@ -24,3 +24,21 @@ def generator(self, *args, **kwargs):
         Keyword arguments to pass to the `BatchGenerator` constructor.
         '''
         return BatchGenerator(self._obj, *args, **kwargs)
+
+
+@xr.register_dataarray_accessor('torch')
+class TorchAccessor:
+    def __init__(self, xarray_obj):
+        self._obj = xarray_obj
+
+    def to_tensor(self):
+        """Convert this DataArray to a torch.Tensor"""
+        import torch
+
+        return torch.tensor(self._obj.data)
+
+    def to_named_tensor(self):
+        """Convert this DataArray to a torch.Tensor with named dimensions"""
+        import torch
+
+        return torch.tensor(self._obj.data, names=self._obj.dims)
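
A minimal usage sketch of the new accessor (the example DataArray below is illustrative, not part of the commit):

    import numpy as np
    import xarray as xr
    import xbatcher.accessors  # noqa: F401 -- importing this module registers the 'torch' accessor

    da = xr.DataArray(np.random.rand(4, 3), dims=('x', 'y'), name='foo')
    t = da.torch.to_tensor()         # plain torch.Tensor with shape (4, 3)
    nt = da.torch.to_named_tensor()  # same data, with named dims ('x', 'y')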

xbatcher/generators.py

Lines changed: 45 additions & 11 deletions
@@ -2,6 +2,7 @@
 
 import itertools
 from collections import OrderedDict
+from typing import Any, Dict, Hashable, Iterator
 
 import xarray as xr
 
@@ -99,12 +100,12 @@ class BatchGenerator:
 
     def __init__(
         self,
-        ds,
-        input_dims,
-        input_overlap={},
-        batch_dims={},
-        concat_input_dims=False,
-        preload_batch=True,
+        ds: xr.Dataset,
+        input_dims: Dict[Hashable, int],
+        input_overlap: Dict[Hashable, int] = {},
+        batch_dims: Dict[Hashable, int] = {},
+        concat_input_dims: bool = False,
+        preload_batch: bool = True,
     ):
 
         self.ds = _as_xarray_dataset(ds)
@@ -115,7 +116,38 @@ def __init__(
         self.concat_input_dims = concat_input_dims
         self.preload_batch = preload_batch
 
-    def __iter__(self):
+        self._batches: Dict[
+            int, Any
+        ] = self._gen_batches()  # dict cache for batches
+        # in the future, we can make this a lru cache or similar thing (cachey?)
+
+    def __iter__(self) -> Iterator[xr.Dataset]:
+        for batch in self._batches.values():
+            yield batch
+
+    def __len__(self) -> int:
+        return len(self._batches)
+
+    def __getitem__(self, idx: int) -> xr.Dataset:
+
+        if not isinstance(idx, int):
+            raise NotImplementedError(
+                f'{type(self).__name__}.__getitem__ currently requires a single integer key'
+            )
+
+        if idx < 0:
+            idx = list(self._batches)[idx]
+
+        if idx in self._batches:
+            return self._batches[idx]
+        else:
+            raise IndexError('list index out of range')
+
+    def _gen_batches(self) -> dict:
+        # in the future, we will want to do the batch generation lazily
+        # going the eager route for now is allowing me to fill out the loader api
+        # but it is likely to perform poorly.
+        batches = []
         for ds_batch in self._iterate_batch_dims(self.ds):
             if self.preload_batch:
                 ds_batch.load()
@@ -130,15 +162,17 @@ def __iter__(self):
                 ]
                 dsc = xr.concat(all_dsets, dim='input_batch')
                 new_input_dims = [
-                    dim + new_dim_suffix for dim in self.input_dims
+                    str(dim) + new_dim_suffix for dim in self.input_dims
                 ]
-                yield _maybe_stack_batch_dims(dsc, new_input_dims)
+                batches.append(_maybe_stack_batch_dims(dsc, new_input_dims))
             else:
                 for ds_input in input_generator:
-                    yield _maybe_stack_batch_dims(
-                        ds_input, list(self.input_dims)
+                    batches.append(
+                        _maybe_stack_batch_dims(ds_input, list(self.input_dims))
                     )
 
+        return dict(zip(range(len(batches)), batches))
+
     def _iterate_batch_dims(self, ds):
         return _iterate_through_dataset(ds, self.batch_dims)
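
A short sketch of what the new indexable interface enables (the toy dataset is illustrative, not from the commit):

    import numpy as np
    import xarray as xr
    from xbatcher import BatchGenerator

    ds = xr.Dataset({'foo': (('x',), np.arange(100))})
    bg = BatchGenerator(ds, input_dims={'x': 10})

    assert len(bg) == 10  # batches are generated eagerly and cached in a dict
    first = bg[0]         # an xr.Dataset with dims {'x': 10}
    last = bg[-1]         # negative indices resolve against the cached keys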

xbatcher/loaders/__init__.py

Whitespace-only changes.

xbatcher/loaders/torch.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+from typing import Any, Callable, Optional, Tuple
+
+import torch
+
+# Notes:
+# This module includes two PyTorch datasets.
+#  - The MapDataset provides an indexable interface
+#  - The IterableDataset provides a simple iterable interface
+# Both can be provided as arguments to the Torch DataLoader
+# Assumptions made:
+#  - Each dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset)
+# TODOs:
+#  - sort out xarray -> numpy pattern. Currently there is a hardcoded variable name for x/y
+#  - need to test with additional dataset parameters (e.g. transforms)
+
+
+class MapDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        X_generator,
+        y_generator,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        '''
+        PyTorch Dataset adapter for Xbatcher
+
+        Parameters
+        ----------
+        X_generator : xbatcher.BatchGenerator
+        y_generator : xbatcher.BatchGenerator
+        transform : callable, optional
+            A function/transform that takes in an array and returns a transformed version.
+        target_transform : callable, optional
+            A function/transform that takes in the target and transforms it.
+        '''
+        self.X_generator = X_generator
+        self.y_generator = y_generator
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __len__(self) -> int:
+        return len(self.X_generator)
+
+    def __getitem__(self, idx) -> Tuple[Any, Any]:
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+            if len(idx) == 1:
+                idx = idx[0]
+            else:
+                raise NotImplementedError(
+                    f'{type(self).__name__}.__getitem__ currently requires a single integer key'
+                )
+
+        # TODO: figure out the dataset -> array workflow
+        # currently hardcoding a variable name
+        X_batch = self.X_generator[idx]['x'].torch.to_tensor()
+        y_batch = self.y_generator[idx]['y'].torch.to_tensor()
+
+        if self.transform:
+            X_batch = self.transform(X_batch)
+
+        if self.target_transform:
+            y_batch = self.target_transform(y_batch)
+        return X_batch, y_batch
+
+
+class IterableDataset(torch.utils.data.IterableDataset):
+    def __init__(
+        self,
+        X_generator,
+        y_generator,
+    ) -> None:
+        '''
+        PyTorch Dataset adapter for Xbatcher
+
+        Parameters
+        ----------
+        X_generator : xbatcher.BatchGenerator
+        y_generator : xbatcher.BatchGenerator
+        '''
+
+        self.X_generator = X_generator
+        self.y_generator = y_generator
+
+    def __iter__(self):
+        for xb, yb in zip(self.X_generator, self.y_generator):
+            yield (xb['x'].torch.to_tensor(), yb['y'].torch.to_tensor())
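
A minimal sketch of wiring these adapters into a torch DataLoader (the sample dataset and generators are illustrative; the variable names 'x' and 'y' match the hardcoded assumption noted in the module comments above):

    import numpy as np
    import torch
    import xarray as xr
    from xbatcher import BatchGenerator
    from xbatcher.loaders.torch import MapDataset

    ds = xr.Dataset(
        {
            'x': (('sample',), np.random.rand(100)),
            'y': (('sample',), np.random.rand(100)),
        }
    )
    X_gen = BatchGenerator(ds[['x']], input_dims={'sample': 10})
    y_gen = BatchGenerator(ds[['y']], input_dims={'sample': 10})

    dataset = MapDataset(X_gen, y_gen)
    # batch_size=None disables automatic batching, so each xbatcher batch
    # passes through the DataLoader unchanged
    loader = torch.utils.data.DataLoader(dataset, batch_size=None)
    for X_batch, y_batch in loader:
        pass  # each is a torch.Tensor of shape (10,)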

xbatcher/tests/test_accessors.py

Lines changed: 22 additions & 0 deletions
@@ -38,3 +38,25 @@ def test_batch_accessor_da(sample_ds_3d):
     assert isinstance(bg_acc, BatchGenerator)
     for batch_class, batch_acc in zip(bg_class, bg_acc):
         assert batch_class.equals(batch_acc)
+
+
+def test_torch_to_tensor(sample_ds_3d):
+    torch = pytest.importorskip('torch')
+
+    da = sample_ds_3d['foo']
+    t = da.torch.to_tensor()
+    assert isinstance(t, torch.Tensor)
+    assert t.names == (None, None, None)
+    assert t.shape == da.shape
+    np.testing.assert_array_equal(t, da.values)
+
+
+def test_torch_to_named_tensor(sample_ds_3d):
+    torch = pytest.importorskip('torch')
+
+    da = sample_ds_3d['foo']
+    t = da.torch.to_named_tensor()
+    assert isinstance(t, torch.Tensor)
+    assert t.names == da.dims
+    assert t.shape == da.shape
+    np.testing.assert_array_equal(t, da.values)

xbatcher/tests/test_generators.py

Lines changed: 22 additions & 0 deletions
@@ -41,6 +41,28 @@ def test_constructor_coerces_to_dataset():
     assert bg.ds.equals(da.to_dataset())
 
 
+@pytest.mark.parametrize('bsize', [5, 6])
+def test_batcher_lenth(sample_ds_1d, bsize):
+    bg = BatchGenerator(sample_ds_1d, input_dims={'x': bsize})
+    assert len(bg) == sample_ds_1d.dims['x'] // bsize
+
+
+def test_batcher_getitem(sample_ds_1d):
+    bg = BatchGenerator(sample_ds_1d, input_dims={'x': 10})
+
+    # first batch
+    assert bg[0].dims['x'] == 10
+    # last batch
+    assert bg[-1].dims['x'] == 10
+    # raises IndexError for out of range index
+    with pytest.raises(IndexError, match=r'list index out of range'):
+        bg[9999999]
+
+    # raises NotImplementedError for iterable index
+    with pytest.raises(NotImplementedError):
+        bg[[1, 2, 3]]
+
+
 # TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension
 # Should we enforce that each batch size always has to be the same
 @pytest.mark.parametrize('bsize', [5, 10])
