diff --git a/edustudio/model/KT/dkvmn.py b/edustudio/model/KT/dkvmn.py index 0aad740..16505ea 100644 --- a/edustudio/model/KT/dkvmn.py +++ b/edustudio/model/KT/dkvmn.py @@ -84,7 +84,7 @@ def forward(self, exer_seq, label_seq, **kwargs): @torch.no_grad() def predict(self, **kwargs): y_pd = self(**kwargs) - y_pd = y_pd[:, :-1] + y_pd = y_pd[:, 1:] y_pd = y_pd[kwargs['mask_seq'][:, 1:] == 1] y_gt = None if kwargs.get('label_seq', None) is not None: @@ -97,7 +97,7 @@ def predict(self, **kwargs): def get_main_loss(self, **kwargs): y_pd = self(**kwargs) - y_pd = y_pd[:, :-1] + y_pd = y_pd[:, 1:] y_pd = y_pd[kwargs['mask_seq'][:, 1:] == 1] y_gt = kwargs['label_seq'][:, 1:] y_gt = y_gt[kwargs['mask_seq'][:, 1:] == 1]