From 8e5f64606cbf233ddd448d61950942f1b25042dd Mon Sep 17 00:00:00 2001 From: Michael Ding Date: Fri, 7 Jun 2024 15:04:35 +0800 Subject: [PATCH] fix dataset for pairwise task --- dooc/datasets.py | 2 +- tests/test_datasets.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dooc/datasets.py b/dooc/datasets.py index 37789bc..59517da 100644 --- a/dooc/datasets.py +++ b/dooc/datasets.py @@ -128,4 +128,4 @@ def __call__( mut_x, smi_src, smi_tgt, rout = super().__call__(muts, lsmiles, lvalues, seq_len) out = torch.zeros(rout.size(0), dtype=torch.long, device=self.device) out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1 - return mut_x, smi_src, smi_tgt, out.unsqueeze(-1) + return mut_x, smi_src, smi_tgt, out diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 32afd63..f24088f 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -60,4 +60,4 @@ def test_MutSmisPairwise(smi_tkz): assert smi_src.shape == (2, 2, 200) assert smi_tgt.shape == (2, 2, 200) assert mut_x.shape == (2, 52) - assert out.shape == (2, 1) + assert out.shape == (2,)