Skip to content

Commit

Permalink
fix deepIRT loss labels-predictions mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
badranX committed Dec 6, 2023
1 parent d0f771c commit 7c11254
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions edustudio/model/KT/deep_irt.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def predict(self, **kwargs):
dict: The predictions of the model and the real situation
"""
y_pd = self(**kwargs)
y_pd = y_pd[:, :-1]
y_pd = y_pd[:, 1:]
y_pd = y_pd[kwargs['mask_seq'][:, 1:] == 1]
y_gt = None
if kwargs.get('label_seq', None) is not None:
Expand All @@ -141,7 +141,7 @@ def get_main_loss(self, **kwargs):
dict: {'loss_main': loss_value}
"""
y_pd = self(**kwargs)
y_pd = y_pd[:, :-1]
y_pd = y_pd[:, 1:]
y_pd = y_pd[kwargs['mask_seq'][:, 1:] == 1]
y_gt = kwargs['label_seq'][:, 1:]
y_gt = y_gt[kwargs['mask_seq'][:, 1:] == 1]
Expand Down

0 comments on commit 7c11254

Please sign in to comment.