diff --git a/dooc/datasets.py b/dooc/datasets.py index 37789bc..59517da 100644 --- a/dooc/datasets.py +++ b/dooc/datasets.py @@ -128,4 +128,4 @@ def __call__( mut_x, smi_src, smi_tgt, rout = super().__call__(muts, lsmiles, lvalues, seq_len) out = torch.zeros(rout.size(0), dtype=torch.long, device=self.device) out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1 - return mut_x, smi_src, smi_tgt, out.unsqueeze(-1) + return mut_x, smi_src, smi_tgt, out diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 32afd63..f24088f 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -60,4 +60,4 @@ def test_MutSmisPairwise(smi_tkz): assert smi_src.shape == (2, 2, 200) assert smi_tgt.shape == (2, 2, 200) assert mut_x.shape == (2, 52) - assert out.shape == (2, 1) + assert out.shape == (2,)