Skip to content

Commit

Permalink
fix DKVMN loss labels-predictions mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
badranX committed Dec 6, 2023
1 parent ba3cbb6 commit d0f771c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions edustudio/model/KT/dkvmn.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def forward(self, exer_seq, label_seq, **kwargs):
@torch.no_grad()
def predict(self, **kwargs):
y_pd = self(**kwargs)
y_pd = y_pd[:, :-1]
y_pd = y_pd[:, 1:]
y_pd = y_pd[kwargs['mask_seq'][:, 1:] == 1]
y_gt = None
if kwargs.get('label_seq', None) is not None:
Expand All @@ -97,7 +97,7 @@ def predict(self, **kwargs):

def get_main_loss(self, **kwargs):
y_pd = self(**kwargs)
y_pd = y_pd[:, :-1]
y_pd = y_pd[:, 1:]
y_pd = y_pd[kwargs['mask_seq'][:, 1:] == 1]
y_gt = kwargs['label_seq'][:, 1:]
y_gt = y_gt[kwargs['mask_seq'][:, 1:] == 1]
Expand Down

0 comments on commit d0f771c

Please sign in to comment.