Skip to content

Commit

Permalink
revise DKT+ regularization
Browse files Browse the repository at this point in the history
  • Loading branch information
badranX committed Dec 3, 2023
1 parent 04f2622 commit 5a4919f
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions edustudio/model/KT/dkt_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ class DKT_plus(DKT):
default_cfg = {
'lambda_r': 0.01,
'lambda_w1': 0.003,
'lambda_w2': 3.0
'lambda_w2': 3.0,
'reg_all_KCs': True
}

def __init__(self, cfg):
Expand All @@ -31,7 +32,10 @@ def get_main_loss(self, **kwargs):
loss_main = F.binary_cross_entropy(input=y_next, target=gt_next)
loss_r = self.modeltpl_cfg['lambda_r'] * F.binary_cross_entropy(input=y_curr, target=gt_curr)

diff = (pred_shft - pred)[kwargs['mask_seq'][:, 1:] == 1]
if self.modeltpl_cfg['reg_all_KCs']:
diff = y[:, 1:] - y[:, :-1]
else:
diff = (pred_shft - pred)[kwargs['mask_seq'][:, 1:] == 1]
loss_w1 = torch.norm(diff, 1) / len(diff)
loss_w1 = self.modeltpl_cfg['lambda_w1'] * loss_w1 / self.n_item
loss_w2 = torch.norm(diff, 2) / len(diff)
Expand Down

0 comments on commit 5a4919f

Please sign in to comment.