diff --git a/edustudio/model/KT/dkt_plus.py b/edustudio/model/KT/dkt_plus.py index dd38419..57f22f0 100644 --- a/edustudio/model/KT/dkt_plus.py +++ b/edustudio/model/KT/dkt_plus.py @@ -8,7 +8,8 @@ class DKT_plus(DKT): default_cfg = { 'lambda_r': 0.01, 'lambda_w1': 0.003, - 'lambda_w2': 3.0 + 'lambda_w2': 3.0, + 'reg_all_KCs': True } def __init__(self, cfg): @@ -31,7 +32,10 @@ def get_main_loss(self, **kwargs): loss_main = F.binary_cross_entropy(input=y_next, target=gt_next) loss_r = self.modeltpl_cfg['lambda_r'] * F.binary_cross_entropy(input=y_curr, target=gt_curr) - diff = (pred_shft - pred)[kwargs['mask_seq'][:, 1:] == 1] + if self.modeltpl_cfg['reg_all_KCs']: + diff = y[:, 1:] - y[:, :-1] + else: + diff = (pred_shft - pred)[kwargs['mask_seq'][:, 1:] == 1] loss_w1 = torch.norm(diff, 1) / len(diff) loss_w1 = self.modeltpl_cfg['lambda_w1'] * loss_w1 / self.n_item loss_w2 = torch.norm(diff, 2) / len(diff)