Skip to content

Commit 2bbf2df

Browse files
author
Joseph Hamman
committed
additional test coverage for torch loaders
1 parent 86c8560 commit 2bbf2df

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

xbatcher/loaders/torch.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@ def __len__(self) -> int:
4545
def __getitem__(self, idx) -> Tuple[Any, Any]:
4646
if torch.is_tensor(idx):
4747
idx = idx.tolist()
48-
assert len(idx) == 1
48+
if len(idx) == 1:
49+
idx = idx[0]
50+
else:
51+
raise NotImplementedError(
52+
f'{type(self).__name__}.__getitem__ currently requires a single integer key'
53+
)
4954

5055
# TODO: figure out the dataset -> array workflow
5156
# currently hardcoding a variable name

xbatcher/tests/test_torch_loaders.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ def test_map_dataset(ds_xy):
3939
assert len(x_batch) == len(y_batch)
4040
assert isinstance(x_batch, torch.Tensor)
4141

42+
idx = torch.tensor([0])
43+
x_batch, y_batch = dataset[idx]
44+
assert len(x_batch) == len(y_batch)
45+
assert isinstance(x_batch, torch.Tensor)
46+
47+
with pytest.raises(NotImplementedError):
48+
idx = torch.tensor([0, 1])
49+
x_batch, y_batch = dataset[idx]
50+
4251
# test __len__
4352
assert len(dataset) == len(x_gen)
4453

@@ -55,6 +64,30 @@ def test_map_dataset(ds_xy):
5564
assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
5665

5766

67+
def test_map_dataset_with_transform(ds_xy):
68+
69+
x = ds_xy['x']
70+
y = ds_xy['y']
71+
72+
x_gen = BatchGenerator(x, {'sample': 10})
73+
y_gen = BatchGenerator(y, {'sample': 10})
74+
75+
def x_transform(batch):
76+
return batch * 0 + 1
77+
78+
def y_transform(batch):
79+
return batch * 0 - 1
80+
81+
dataset = MapDataset(
82+
x_gen, y_gen, transform=x_transform, target_transform=y_transform
83+
)
84+
x_batch, y_batch = dataset[0]
85+
assert len(x_batch) == len(y_batch)
86+
assert isinstance(x_batch, torch.Tensor)
87+
assert (x_batch == 1).all()
88+
assert (y_batch == -1).all()
89+
90+
5891
def test_iterable_dataset(ds_xy):
5992

6093
x = ds_xy['x']

0 commit comments

Comments
 (0)