Skip to content

Commit

Permalink
Merge pull request #16 from badranX/dkvmn_based
Browse files Browse the repository at this point in the history
Fix DKVMN based loss, labels-predictions mismatch
  • Loading branch information
kervias authored Dec 7, 2023
2 parents ba3cbb6 + 7c11254 commit 18d7312
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions edustudio/model/KT/deep_irt.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def predict(self, **kwargs):
dict: The predictions of the model and the real situation
"""
y_pd = self(**kwargs)
y_pd = y_pd[:, :-1]
y_pd = y_pd[:, 1:]
y_pd = y_pd[kwargs['mask_seq'][:, 1:] == 1]
y_gt = None
if kwargs.get('label_seq', None) is not None:
Expand All @@ -141,7 +141,7 @@ def get_main_loss(self, **kwargs):
dict: {'loss_main': loss_value}
"""
y_pd = self(**kwargs)
y_pd = y_pd[:, :-1]
y_pd = y_pd[:, 1:]
y_pd = y_pd[kwargs['mask_seq'][:, 1:] == 1]
y_gt = kwargs['label_seq'][:, 1:]
y_gt = y_gt[kwargs['mask_seq'][:, 1:] == 1]
Expand Down
4 changes: 2 additions & 2 deletions edustudio/model/KT/dkvmn.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def forward(self, exer_seq, label_seq, **kwargs):
@torch.no_grad()
def predict(self, **kwargs):
y_pd = self(**kwargs)
y_pd = y_pd[:, :-1]
y_pd = y_pd[:, 1:]
y_pd = y_pd[kwargs['mask_seq'][:, 1:] == 1]
y_gt = None
if kwargs.get('label_seq', None) is not None:
Expand All @@ -97,7 +97,7 @@ def predict(self, **kwargs):

def get_main_loss(self, **kwargs):
y_pd = self(**kwargs)
y_pd = y_pd[:, :-1]
y_pd = y_pd[:, 1:]
y_pd = y_pd[kwargs['mask_seq'][:, 1:] == 1]
y_gt = kwargs['label_seq'][:, 1:]
y_gt = y_gt[kwargs['mask_seq'][:, 1:] == 1]
Expand Down

0 comments on commit 18d7312

Please sign in to comment.