Skip to content

Commit

Permalink
Merge pull request #13 from badranX/qikt
Browse files Browse the repository at this point in the history
fix QIKT response data leakage
  • Loading branch information
kervias authored Nov 25, 2023
2 parents e92b7d1 + c5460c3 commit ea1f688
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions edustudio/model/KT/qikt.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,9 @@ def forward(self, exer_seq, label_seq, cpt_seq, cpt_seq_mask, **kwargs):
dim=2) / cpt_seq_mask.sum(dim=2).unsqueeze(2).repeat(1, 1, self.modeltpl_cfg['emb_size'])


e_tmp = torch.cat((emb_q, emb_c), dim=2)
mask_e = label_seq.unsqueeze(-1).repeat(1, 1, e_tmp.shape[-1]).to(torch.float)
emb_qca = torch.cat((mask_e*e_tmp, (1-mask_e)*e_tmp), dim=-1)

mask_c = label_seq.unsqueeze(-1).repeat(1, 1, emb_c.shape[-1]).to(torch.float)
emb_qc = torch.cat((mask_c * emb_c, (1 - mask_c) * emb_c), dim=-1)
emb_qc = torch.cat((emb_q, emb_c), dim=2)
mask_e = label_seq.unsqueeze(-1).repeat(1, 1, emb_qc.shape[-1]).to(torch.float)
emb_qca = torch.cat((mask_e*emb_qc, (1-mask_e)*emb_qc), dim=-1)

emb_qc_shift = emb_qc[:, 1:, :]
emb_qca_current = emb_qca[:, :-1, :]
Expand Down

0 comments on commit ea1f688

Please sign in to comment.