@@ -39,6 +39,15 @@ def test_map_dataset(ds_xy):
39
39
assert len (x_batch ) == len (y_batch )
40
40
assert isinstance (x_batch , torch .Tensor )
41
41
42
+ idx = torch .tensor ([0 ])
43
+ x_batch , y_batch = dataset [idx ]
44
+ assert len (x_batch ) == len (y_batch )
45
+ assert isinstance (x_batch , torch .Tensor )
46
+
47
+ with pytest .raises (NotImplementedError ):
48
+ idx = torch .tensor ([0 , 1 ])
49
+ x_batch , y_batch = dataset [idx ]
50
+
42
51
# test __len__
43
52
assert len (dataset ) == len (x_gen )
44
53
@@ -55,6 +64,30 @@ def test_map_dataset(ds_xy):
55
64
assert np .array_equal (x_gen [- 1 ]['x' ], x_batch [0 , :, :])
56
65
57
66
67
+ def test_map_dataset_with_transform (ds_xy ):
68
+
69
+ x = ds_xy ['x' ]
70
+ y = ds_xy ['y' ]
71
+
72
+ x_gen = BatchGenerator (x , {'sample' : 10 })
73
+ y_gen = BatchGenerator (y , {'sample' : 10 })
74
+
75
+ def x_transform (batch ):
76
+ return batch * 0 + 1
77
+
78
+ def y_transform (batch ):
79
+ return batch * 0 - 1
80
+
81
+ dataset = MapDataset (
82
+ x_gen , y_gen , transform = x_transform , target_transform = y_transform
83
+ )
84
+ x_batch , y_batch = dataset [0 ]
85
+ assert len (x_batch ) == len (y_batch )
86
+ assert isinstance (x_batch , torch .Tensor )
87
+ assert (x_batch == 1 ).all ()
88
+ assert (y_batch == - 1 ).all ()
89
+
90
+
58
91
def test_iterable_dataset (ds_xy ):
59
92
60
93
x = ds_xy ['x' ]
0 commit comments