Skip to content

Commit

Permalink
fix QIKT response data leakage
Browse files Browse the repository at this point in the history
  • Loading branch information
badranX committed Nov 24, 2023
1 parent e92b7d1 commit c5460c3
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions edustudio/model/KT/qikt.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,9 @@ def forward(self, exer_seq, label_seq, cpt_seq, cpt_seq_mask, **kwargs):
dim=2) / cpt_seq_mask.sum(dim=2).unsqueeze(2).repeat(1, 1, self.modeltpl_cfg['emb_size'])


e_tmp = torch.cat((emb_q, emb_c), dim=2)
mask_e = label_seq.unsqueeze(-1).repeat(1, 1, e_tmp.shape[-1]).to(torch.float)
emb_qca = torch.cat((mask_e*e_tmp, (1-mask_e)*e_tmp), dim=-1)

mask_c = label_seq.unsqueeze(-1).repeat(1, 1, emb_c.shape[-1]).to(torch.float)
emb_qc = torch.cat((mask_c * emb_c, (1 - mask_c) * emb_c), dim=-1)
emb_qc = torch.cat((emb_q, emb_c), dim=2)
mask_e = label_seq.unsqueeze(-1).repeat(1, 1, emb_qc.shape[-1]).to(torch.float)
emb_qca = torch.cat((mask_e*emb_qc, (1-mask_e)*emb_qc), dim=-1)

emb_qc_shift = emb_qc[:, 1:, :]
emb_qca_current = emb_qca[:, :-1, :]
Expand Down

0 comments on commit c5460c3

Please sign in to comment.