diff --git a/examples/seedwandb/cdkt.yaml b/examples/seedwandb/atdkt.yaml
similarity index 87%
rename from examples/seedwandb/cdkt.yaml
rename to examples/seedwandb/atdkt.yaml
index f77c2414..b5688aec 100644
--- a/examples/seedwandb/cdkt.yaml
+++ b/examples/seedwandb/atdkt.yaml
@@ -1,17 +1,17 @@
-program: wandb_cdkt_train.py
+program: wandb_atdkt_train.py
 method: bayes
 metric:
   goal: maximize
   name: validauc
 parameters:
   model_name:
-    values: ["cdkt"]
+    values: ["atdkt"]
   dataset_name:
     values: ["xes"]
   emb_type:
     values: ["qiddelxembhistranscembpredcurc"]
   save_dir:
-    values: ["models/cdkt_tiaocan"]
+    values: ["models/atdkt_tiaocan"]
   emb_size:
     values: [64, 256]
   num_attn_heads:
diff --git a/examples/seedwandb/bakt.yaml b/examples/seedwandb/simplekt.yaml
similarity index 84%
rename from examples/seedwandb/bakt.yaml
rename to examples/seedwandb/simplekt.yaml
index 79ddbe15..12f8c1fe 100644
--- a/examples/seedwandb/bakt.yaml
+++ b/examples/seedwandb/simplekt.yaml
@@ -1,17 +1,17 @@
-program: ./wandb_bakt_train.py
+program: ./wandb_simplekt_train.py
 method: bayes
 metric:
   goal: maximize
   name: validauc
 parameters:
   model_name:
-    values: ["bakt"]
+    values: ["simplekt"]
   dataset_name:
     values: ["xes"]
   emb_type:
     values: ["qid"]
   save_dir:
-    values: ["models/bakt_tiaocan"]
+    values: ["models/simplekt_tiaocan"]
   d_model:
     values: [64, 256]
   d_ff:
diff --git a/examples/wandb_cdkt_train.py b/examples/wandb_atdkt_train.py
similarity index 94%
rename from examples/wandb_cdkt_train.py
rename to examples/wandb_atdkt_train.py
index 08b5103b..242ff782 100644
--- a/examples/wandb_cdkt_train.py
+++ b/examples/wandb_atdkt_train.py
@@ -5,7 +5,7 @@
     parser = argparse.ArgumentParser()
     parser.add_argument("--dataset_name", type=str, default="algebra2005")
-    parser.add_argument("--model_name", type=str, default="cdkt")
+    parser.add_argument("--model_name", type=str, default="atdkt")
     parser.add_argument("--emb_type", type=str, default="qid")
     parser.add_argument("--save_dir", type=str, default="saved_model")
     parser.add_argument("--seed", type=int, default=3407)
diff --git a/examples/wandb_eval.py b/examples/wandb_eval.py
index 32e6545a..39814577 100644
--- a/examples/wandb_eval.py
+++ b/examples/wandb_eval.py
@@ -25,7 +25,7 @@ def main(params):
     trained_params = config["params"]
     model_name, dataset_name, emb_type = trained_params["model_name"], trained_params["dataset_name"], trained_params["emb_type"]
     seq_len = config["train_config"]["seq_len"]
-    if model_name in ["saint", "sakt", "cdkt"]:
+    if model_name in ["saint", "sakt", "atdkt"]:
         model_config["seq_len"] = seq_len
     data_config = config["data_config"]
diff --git a/examples/wandb_predict.py b/examples/wandb_predict.py
index 3345ce37..b9811339 100644
--- a/examples/wandb_predict.py
+++ b/examples/wandb_predict.py
@@ -29,7 +29,7 @@ def main(params):
             del model_config[remove_item]
     trained_params = config["params"]
     model_name, dataset_name, emb_type = trained_params["model_name"], trained_params["dataset_name"], trained_params["emb_type"]
-    if model_name in ["saint", "sakt", "cdkt"]:
+    if model_name in ["saint", "sakt", "atdkt"]:
        train_config = config["train_config"]
        seq_len = train_config["seq_len"]
        model_config["seq_len"] = seq_len
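For context, the two renamed sweep specs above are ordinary wandb Bayesian-search configs, so they are launched the same way as before the rename. A minimal sketch (not part of the patch) of starting the ATDKT sweep through wandb's Python API; the project name is a hypothetical placeholder:

```python
# Minimal sketch: launch the renamed sweep spec via wandb's Python API.
# "pykt-sweeps" is a hypothetical project name; any project works.
import yaml
import wandb

with open("examples/seedwandb/atdkt.yaml") as f:
    sweep_config = yaml.safe_load(f)  # program: wandb_atdkt_train.py, method: bayes, ...

sweep_id = wandb.sweep(sweep_config, project="pykt-sweeps")
print(sweep_id)  # then run agents with: wandb agent <entity>/<project>/<sweep_id>
```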
diff --git a/examples/wandb_bakt_train.py b/examples/wandb_simplekt_train.py
similarity index 95%
rename from examples/wandb_bakt_train.py
rename to examples/wandb_simplekt_train.py
index b488b475..008ab754 100644
--- a/examples/wandb_bakt_train.py
+++ b/examples/wandb_simplekt_train.py
@@ -4,7 +4,7 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--dataset_name", type=str, default="algebra2005")
-    parser.add_argument("--model_name", type=str, default="bakt")
+    parser.add_argument("--model_name", type=str, default="simplekt")
     parser.add_argument("--emb_type", type=str, default="qid")
     parser.add_argument("--save_dir", type=str, default="saved_model")
     # parser.add_argument("--learning_rate", type=float, default=1e-5)
diff --git a/examples/wandb_train.py b/examples/wandb_train.py
index ca2c5185..dc811d80 100644
--- a/examples/wandb_train.py
+++ b/examples/wandb_train.py
@@ -41,7 +41,7 @@ def main(params):
     train_config = config["train_config"]
     if model_name in ["dkvmn","deep_irt", "sakt", "saint","saint++", "akt", "atkt", "lpkt", "skvmn"]:
         train_config["batch_size"] = 64 ## because of OOM
-    if model_name in ["bakt", "bakt_time"]:
+    if model_name in ["simplekt", "bakt_time"]:
         train_config["batch_size"] = 64 ## because of OOM
     if model_name in ["gkt"]:
         train_config["batch_size"] = 16
@@ -88,7 +88,7 @@ def main(params):
     for remove_item in ['use_wandb','learning_rate','add_uuid','l2']:
         if remove_item in model_config:
             del model_config[remove_item]
-    if model_name in ["saint","saint++", "sakt", "cdkt", "bakt", "bakt_time"]:
+    if model_name in ["saint","saint++", "sakt", "atdkt", "simplekt", "bakt_time"]:
         model_config["seq_len"] = seq_len

     debug_print(text = "init_model",fuc_name="main")
""" def __init__(self, file_path, input_type, folds, qtest=False): - super(CDKTDataset, self).__init__() + super(ATDKTDataset, self).__init__() sequence_path = file_path self.input_type = input_type self.qtest = qtest folds = sorted(list(folds)) folds_str = "_" + "_".join([str(_) for _ in folds]) if self.qtest: - processed_data = file_path + folds_str + "_cdkt_qtest.pkl" + processed_data = file_path + folds_str + "_atdkt_qtest.pkl" else: - processed_data = file_path + folds_str + "_cdkt.pkl" + processed_data = file_path + folds_str + "_atdkt.pkl" self.dpath = "/".join(file_path.split("/")[0:-1]) if not os.path.exists(processed_data): @@ -51,8 +51,6 @@ def __init__(self, file_path, input_type, folds, qtest=False): self.dori, self.dqtest = pd.read_pickle(processed_data) else: self.dori = pd.read_pickle(processed_data) - for key in self.dori: - self.dori[key] = self.dori[key]#[:100] print(f"file path: {file_path}, qlen: {len(self.dori['qseqs'])}, clen: {len(self.dori['cseqs'])}, rlen: {len(self.dori['rseqs'])}") def __len__(self): diff --git a/pykt/datasets/init_dataset.py b/pykt/datasets/init_dataset.py index f1dab6ea..a3c94a26 100644 --- a/pykt/datasets/init_dataset.py +++ b/pykt/datasets/init_dataset.py @@ -5,7 +5,7 @@ import numpy as np from .data_loader import KTDataset from .dkt_forget_dataloader import DktForgetDataset -from .cdkt_dataloader import CDKTDataset +from .atdkt_dataloader import ATDKTDataset from .lpkt_dataloader import LPKTDataset from .lpkt_utils import generate_time2idx from .que_data_loader import KTQueDataset @@ -39,12 +39,12 @@ def init_test_datasets(data_config, model_name, batch_size): concept_num=data_config['num_c'], max_concepts=data_config['max_concepts']) test_question_dataset = None test_question_window_dataset= None - elif model_name in ["cdkt"]: - test_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_file"]), data_config["input_type"], {-1}) - test_window_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]), data_config["input_type"], {-1}) + elif model_name in ["atdkt"]: + test_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_file"]), data_config["input_type"], {-1}) + test_window_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]), data_config["input_type"], {-1}) if "test_question_file" in data_config: - test_question_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_file"]), data_config["input_type"], {-1}, True) - test_question_window_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_window_file"]), data_config["input_type"], {-1}, True) + test_question_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_file"]), data_config["input_type"], {-1}, True) + test_question_window_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_window_file"]), data_config["input_type"], {-1}, True) else: test_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_file"]), data_config["input_type"], {-1}) test_window_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]), data_config["input_type"], {-1}) @@ -96,9 +96,9 @@ def init_dataset4train(dataset_name, model_name, data_config, i, batch_size): curtrain = KTQueDataset(os.path.join(data_config["dpath"], data_config["train_valid_file_quelevel"]), input_type=data_config["input_type"], folds=all_folds - {i}, 
diff --git a/pykt/models/cdkt.py b/pykt/models/atdkt.py
similarity index 93%
rename from pykt/models/cdkt.py
rename to pykt/models/atdkt.py
index a2f185a1..3e4e5de0 100644
--- a/pykt/models/cdkt.py
+++ b/pykt/models/atdkt.py
@@ -6,10 +6,11 @@
 device = "cpu" if not torch.cuda.is_available() else "cuda"

-class CDKT(Module):
-    def __init__(self, num_q, num_c, seq_len, emb_size, dropout=0.1, emb_type='qid', num_layers=1, num_attn_heads=5, l1=0.5, l2=0.5, l3=0.5, start=50, emb_path="", pretrain_dim=768):
+class ATDKT(Module):
+    def __init__(self, num_q, num_c, seq_len, emb_size, dropout=0.1, emb_type='qid',
+            num_layers=1, num_attn_heads=5, l1=0.5, l2=0.5, l3=0.5, start=50, emb_path="", pretrain_dim=768):
         super().__init__()
-        self.model_name = "cdkt"
+        self.model_name = "atdkt"
         print(f"qnum: {num_q}, cnum: {num_c}")
         print(f"emb_type: {emb_type}")
         self.num_q = num_q
@@ -34,7 +35,6 @@ def __init__(self, num_q, num_c, seq_len, emb_size, dropout=0.1, emb_type='qid',
         if self.emb_type.find("qemb") != -1:
             self.question_emb = Embedding(self.num_q, self.emb_size)

-        # 加一个预测历史准确率的任务
         self.start = start
         self.hisclasifier = nn.Sequential(
             nn.Linear(self.hidden_size, self.hidden_size//2), nn.ReLU(), nn.Dropout(dropout),
@@ -62,7 +62,6 @@ def __init__(self, num_q, num_c, seq_len, emb_size, dropout=0.1, emb_type='qid',
                 self.concept_emb = Embedding(self.num_c, self.emb_size) # add concept emb
             self.closs = CrossEntropyLoss()

-            # 加一个预测历史准确率的任务
             if self.emb_type.find("his") != -1:
                 self.start = start
                 self.hisclasifier = nn.Sequential(
@@ -122,7 +121,7 @@ def predcurc(self, dcur, q, c, r, xemb, train):
         h = self.dropout_layer(h)
         y = self.out_layer(h)
         y = torch.sigmoid(y)
-        return y, y2, y3, rpreds, qh
+        return y, y2, y3

     def forward(self, dcur, train=False): ## F * xemb
         # print(f"keys: {dcur.keys()}")
@@ -162,10 +161,10 @@ def forward(self, dcur, train=False): ## F * xemb
             y = self.out_layer(h)
             y = torch.sigmoid(y)
         elif emb_type.endswith("predcurc"): # predict current question' current concept
-            y, y2, y3, rpreds, qh = self.predcurc(dcur, q, c, r, xemb, train)
+            y, y2, y3 = self.predcurc(dcur, q, c, r, xemb, train)

         if train:
             return y, y2, y3
         else:
-            return y, rpreds, qh
+            return y
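The return-signature change above ripples into every ATDKT call site: `predcurc`, and hence the eval-time `forward`, now returns only the prediction tensor instead of `(y, rpreds, qh)`. A minimal sketch of the resulting calling convention, assuming the usual pykt batch dict `dcur` of question/concept/response sequences:

```python
import torch
from torch.nn.functional import one_hot

def atdkt_predict(model, dcur, cshft):
    """Evaluation-time call after this patch: forward returns only y."""
    y = model(dcur)  # was: y, rpreds, qh = model(dcur)
    # reduce per-concept probabilities to the probability of the next concept
    return (y * one_hot(cshft.long(), model.num_c)).sum(-1)

def atdkt_train_step(model, dcur):
    """Training-time call is unchanged: main prediction plus two auxiliary losses."""
    y, y2, y3 = model(dcur, train=True)
    return y, y2, y3
```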
["atdkt"]: ''' y = model(dcur) import pickle @@ -79,12 +79,12 @@ def evaluate(model, test_loader, model_name, save_path=""): data = {"y":y,"cshft":cshft,"num_c":model.num_c,"rshft":rshft,"qshft":qshft,"sm":sm} pickle.dump(data,f) ''' - y, rpreds, qh = model(dcur) + y = model(dcur) y = (y * one_hot(cshft.long(), model.num_c)).sum(-1) elif model_name in ["bakt_time"]: y = model(dcur, dgaps) y = y[:,1:] - elif model_name in ["bakt"]: + elif model_name in ["simplekt"]: y = model(dcur) y = y[:,1:] elif model_name in ["dkt", "dkt+"]: @@ -158,7 +158,7 @@ def early_fusion(curhs, model, model_name): que_diff = model.diff_layer(curhs[1])#equ 13 p = torch.sigmoid(3.0*stu_ability-que_diff)#equ 14 p = p.squeeze(-1) - elif model_name in ["akt", "bakt", "bakt_time", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]: + elif model_name in ["akt", "simplekt", "bakt_time", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]: output = model.out(curhs[0]).squeeze(-1) m = nn.Sigmoid() p = m(output) @@ -204,7 +204,7 @@ def effective_fusion(df, model, model_name, fusion_type): curhs, curr = [[], []], [] dcur = {"late_trues": [], "qidxs": [], "questions": [], "concepts": [], "row": [], "concept_preds": []} - hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt", "bakt", "bakt_time", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"] + hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt", "simplekt", "bakt_time", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"] for ui in df: # 一题一题处理 curdf = ui[1] @@ -252,7 +252,7 @@ def group_fusion(dmerge, model, model_name, fusion_type, fout): if cq.shape[1] == 0: cq = cc - hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt", "bakt", "bakt_time", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"] + hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt", "simplekt", "bakt_time", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"] alldfs, drest = [], dict() # not predict infos! 
# print(f"real bz in group fusion: {rs.shape[0]}") @@ -349,7 +349,7 @@ def evaluate_question(model, test_loader, model_name, fusion_type=["early_fusion # dkvmn / akt / saint: give cur -> predict cur # sakt: give past+cur -> predict cur # kqn: give past+cur -> predict cur - hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt", "bakt", "bakt_time", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"] + hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt", "simplekt", "bakt_time", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"] if save_path != "": fout = open(save_path, "w", encoding="utf8") if model_name in hasearly: @@ -398,7 +398,7 @@ def evaluate_question(model, test_loader, model_name, fusion_type=["early_fusion # start_hemb = torch.tensor([-1] * (h.shape[0] * h.shape[2])).reshape(h.shape[0], 1, h.shape[2]).to(device) # print(start_hemb.shape, h.shape) # h = torch.cat((start_hemb, h), dim=1) # add the first hidden emb - elif model_name in ["bakt"]: + elif model_name in ["simplekt"]: y, h = model(dcurori, qtest=True, train=False) y = y[:,1:] elif model_name in ["akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]: @@ -418,8 +418,8 @@ def evaluate_question(model, test_loader, model_name, fusion_type=["early_fusion start_hemb = torch.tensor([-1] * (ek.shape[0] * ek.shape[2])).reshape(ek.shape[0], 1, ek.shape[2]).to(device) ek = torch.cat((start_hemb, ek), dim=1) # add the first hidden emb es = torch.cat((start_hemb, es), dim=1) # add the first hidden emb - elif model_name in ["cdkt"]: - y, _, _ = model(dcurori)#c.long(), r.long(), q.long()) + elif model_name in ["atdkt"]: + y = model(dcurori)#c.long(), r.long(), q.long()) y = (y * one_hot(cshft.long(), model.num_c)).sum(-1) elif model_name in ["dkt", "dkt+"]: y = model(c.long(), r.long()) @@ -774,10 +774,10 @@ def predict_each_group(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid, dgaps[key] = din[key] for key in dcur: dgaps["shft_"+key] = dcur[key] - if model_name in ["cdkt"]: ## need change! + if model_name in ["atdkt"]: ## need change! 
# create input dcurinfos = {"qseqs": qin, "cseqs": cin, "rseqs": rin} - y, _, _ = model(dcurinfos) + y = model(dcurinfos) pred = y[0][-1][cout.item()] elif model_name in ["dkt", "dkt+"]: y = model(cin.long(), rin.long()) @@ -850,7 +850,7 @@ def predict_each_group(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid, y = model(dcurinfos, dgaps) pred = y[0][-1] - elif model_name in ["bakt"]: + elif model_name in ["simplekt"]: if qout != None: curq = torch.tensor([[qout.item()]]).to(device) qinshft = torch.cat((qin[:,1:], curq), axis=1) @@ -1116,12 +1116,12 @@ def predict_each_group2(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid, dgaps[key] = curd[key] for key in curdshft: dgaps["shft_"+key] = curdshft[key] - if model_name in ["cdkt"]: + if model_name in ["atdkt"]: # y = model(curc.long(), curr.long(), curq.long()) # y = (y * one_hot(curcshft.long(), model.num_c)).sum(-1) # create input dcurinfos = {"qseqs": curq, "cseqs": curc, "rseqs": curr} - y, _, _ = model(dcurinfos) + y = model(dcurinfos) y = (y * one_hot(curcshft.long(), model.num_c)).sum(-1) elif model_name in ["dkt", "dkt+"]: y = model(curc.long(), curr.long()) @@ -1164,7 +1164,7 @@ def predict_each_group2(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid, # print(f"dgaps: {dgaps.keys()}") y = model(dcurinfos, dgaps) y = y[:,1:] - elif model_name in ["bakt"]: + elif model_name in ["simplekt"]: dcurinfos = {"qseqs": curq, "cseqs": curc, "rseqs": curr, "shft_qseqs":curqshft,"shft_cseqs":curcshft,"shft_rseqs":currshft} # print(f"finald: {finald.keys()}") diff --git a/pykt/models/init_model.py b/pykt/models/init_model.py index 2d801c1e..a575b3a8 100644 --- a/pykt/models/init_model.py +++ b/pykt/models/init_model.py @@ -19,8 +19,8 @@ from .skvmn import SKVMN from .hawkes import HawkesKT from .iekt import IEKT -from .cdkt import CDKT -from .bakt import BAKT +from .atdkt import ATDKT +from .simplekt import simpleKT from .bakt_time import BAKTTime from .qdkt import QDKT from .qikt import QIKT @@ -92,12 +92,12 @@ def init_model(model_name, model_config, data_config, emb_type): elif model_name == "qikt": model = QIKT(num_q=data_config['num_q'], num_c=data_config['num_c'], max_concepts=data_config['max_concepts'], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"],device=device).to(device) - elif model_name == "cdkt": - model = CDKT(data_config["num_q"], data_config["num_c"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device) + elif model_name == "atdkt": + model = ATDKT(data_config["num_q"], data_config["num_c"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device) elif model_name == "bakt_time": model = BAKTTime(data_config["num_c"], data_config["num_q"], data_config["num_rgap"], data_config["num_sgap"], data_config["num_pcount"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device) - elif model_name == "bakt": - model = BAKT(data_config["num_c"], data_config["num_q"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device) + elif model_name == "simplekt": + model = simpleKT(data_config["num_c"], data_config["num_q"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device) else: print("The wrong model name was used...") return None diff --git a/pykt/models/bakt.py b/pykt/models/simplekt.py similarity index 66% rename from pykt/models/bakt.py rename to pykt/models/simplekt.py index f394268a..3c2df150 100644 --- a/pykt/models/bakt.py +++ b/pykt/models/simplekt.py @@ -18,7 +18,7 @@ 
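With the dispatch above, downstream code only needs the new model-name strings; the constructor signatures are unchanged. A minimal sketch (not part of the patch) of building both renamed models through `init_model`; the `data_config` values and `model_config` keys are illustrative placeholders drawn from the sweep YAMLs:

```python
from pykt.models.init_model import init_model

# Placeholder data_config; real values come from the dataset's config JSON.
data_config = {"num_q": 100, "num_c": 50, "emb_path": ""}

# model_config keys below are illustrative, mirroring the sweep YAMLs above.
atdkt = init_model("atdkt",
                   {"seq_len": 200, "emb_size": 64, "dropout": 0.1},
                   data_config, emb_type="qid")
simplekt = init_model("simplekt",
                      {"seq_len": 200, "d_model": 64, "d_ff": 256,
                       "n_blocks": 2, "dropout": 0.1},
                      data_config, emb_type="qid")
```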
diff --git a/pykt/models/bakt.py b/pykt/models/simplekt.py
similarity index 66%
rename from pykt/models/bakt.py
rename to pykt/models/simplekt.py
index f394268a..3c2df150 100644
--- a/pykt/models/bakt.py
+++ b/pykt/models/simplekt.py
@@ -18,7 +18,7 @@
 class Dim(IntEnum):
     seq = 1
     feature = 2

-class BAKT(nn.Module):
+class simpleKT(nn.Module):
     def __init__(self, n_question, n_pid,
             d_model, n_blocks, dropout, d_ff=256,
             loss1=0.5, loss2=0.5, loss3=0.5, start=50, num_layers=2, nheads=4, seq_len=200,
@@ -32,7 +32,7 @@ def __init__(self, n_question, n_pid,
             d_ff : dimension for fully conntected net inside the basic block
             kq_same: if key query same, kq_same=1, else = 0
         """
-        self.model_name = "bakt"
+        self.model_name = "simplekt"
         print(f"model_name: {self.model_name}, emb_type: {emb_type}")
         self.n_question = n_question
         self.dropout = dropout
@@ -70,38 +70,6 @@ def __init__(self, n_question, n_pid,
             ), nn.Dropout(self.dropout),
             nn.Linear(final_fc_dim2, 1)
         )
-
-        if self.emb_type.endswith("predcurc"): # predict cur question' cur concept
-            self.l1 = loss1
-            self.l2 = loss2
-            self.l3 = loss3
-            num_layers = num_layers
-            self.emb_size, self.hidden_size = d_model, d_model
-            self.num_q, self.num_c = n_pid, n_question
-
-            if self.num_q > 0:
-                self.question_emb = Embedding(self.num_q, self.emb_size) # 1.2
-            if self.emb_type.find("trans") != -1:
-                self.nhead = nheads
-                # d_model = self.hidden_size# * 2
-                encoder_layer = TransformerEncoderLayer(d_model, nhead=self.nhead)
-                encoder_norm = LayerNorm(d_model)
-                self.trans = TransformerEncoder(encoder_layer, num_layers=num_layers, norm=encoder_norm)
-            elif self.emb_type.find("lstm") != -1:
-                self.qlstm = LSTM(self.emb_size, self.hidden_size, batch_first=True)
-            # self.qdrop = Dropout(dropout)
-            self.qclasifier = Linear(self.hidden_size, self.num_c)
-            if self.emb_type.find("cemb") != -1:
-                self.concept_emb = Embedding(self.num_c, self.emb_size) # add concept emb
-            self.closs = CrossEntropyLoss()
-            # 加一个预测历史准确率的任务
-            if self.emb_type.find("his") != -1:
-                self.start = start
-                self.hisclasifier = nn.Sequential(
-                    # nn.Linear(self.hidden_size*2, self.hidden_size), nn.ELU(), nn.Dropout(dropout),
-                    nn.Linear(self.hidden_size, self.hidden_size//2), nn.ELU(), nn.Dropout(dropout),
-                    nn.Linear(self.hidden_size//2, 1))
-                self.hisloss = nn.MSELoss()

         self.reset()
@@ -126,128 +94,6 @@ def get_attn_pad_mask(self, sm):
         pad_attn_mask = pad_attn_mask.expand(batch_size, l, l)
         return pad_attn_mask.repeat(self.nhead, 1, 1)

-    def predcurc(self, qemb, cemb, xemb, dcur, train):
-        y2 = 0
-        sm, c, cshft = dcur["smasks"], dcur["cseqs"], dcur["shft_cseqs"]
-        padsm = torch.ones(sm.shape[0], 1).to(device)
-        sm = torch.cat([padsm, sm], dim=-1)
-        c = torch.cat([c[:,0:1], cshft], dim=-1)
-        chistory = xemb
-        if self.num_q > 0:
-            catemb = qemb + chistory
-        else:
-            catemb = chistory
-        if self.separate_qa:
-            catemb += cemb
-        # if self.emb_type.find("cemb") != -1: akt本身就加了cemb
-        #     catemb += cemb
-
-        if self.emb_type.find("trans") != -1:
-            mask = ut_mask(seq_len = catemb.shape[1])
-            qh = self.trans(catemb.transpose(0,1), mask).transpose(0,1)
-        else:
-            qh, _ = self.qlstm(catemb)
-        if train:
-            start = 0
-            cpreds = self.qclasifier(qh[:,start:,:])
-            flag = sm[:,start:]==1
-            y2 = self.closs(cpreds[flag], c[:,start:][flag])
-
-        xemb = xemb + qh# + cemb
-        if self.separate_qa:
-            xemb = xemb + cemb
-        if self.emb_type.find("qemb") != -1:
-            xemb = xemb+qemb
-
-        return y2, xemb
-
-    def predcurc2(self, qemb, cemb, xemb, dcur, train):
-        y2 = 0
-        sm, c, cshft = dcur["smasks"], dcur["cseqs"], dcur["shft_cseqs"]
-        padsm = torch.ones(sm.shape[0], 1).to(device)
-        sm = torch.cat([padsm, sm], dim=-1)
-        c = torch.cat([c[:,0:1], cshft], dim=-1)
-        chistory = cemb
-        if self.num_q > 0:
-            catemb = qemb + chistory
-        else:
-            catemb = chistory
-
-        if self.emb_type.find("trans") != -1:
-            mask = ut_mask(seq_len = catemb.shape[1])
-            qh = self.trans(catemb.transpose(0,1), mask).transpose(0,1)
-        else:
-            qh, _ = self.qlstm(catemb)
-        if train:
-            start = 0
-            cpreds = self.qclasifier(qh[:,start:,:])
-            flag = sm[:,start:]==1
-            y2 = self.closs(cpreds[flag], c[:,start:][flag])
-
-        # xemb = xemb+qh
-        # if self.separate_qa:
-        #     xemb = xemb+cemb
-        cemb = cemb + qh
-        xemb = xemb+qh
-        if self.emb_type.find("qemb") != -1:
-            cemb = cemb+qemb
-            xemb = xemb+qemb
-
-        return y2, cemb, xemb
-
-    def changecemb(self, qemb, cemb):
-        catemb = cemb
-        if self.emb_type.find("qemb") != -1:
-            catemb += qemb
-        if self.emb_type.find("trans") != -1:
-            mask = ut_mask(seq_len = catemb.shape[1])
-            qh = self.trans(catemb.transpose(0,1), mask).transpose(0,1)
-        else:
-            qh, _ = self.qlstm(catemb)
-
-        cemb = cemb + qh
-        if self.emb_type.find("qemb") != -1:
-            cemb = cemb+qemb
-
-        return cemb
-
-    def afterpredcurc(self, h, dcur):
-        y2 = 0
-        sm, c, cshft = dcur["smasks"], dcur["cseqs"], dcur["shft_cseqs"]
-        padsm = torch.ones(sm.shape[0], 1).to(device)
-        sm = torch.cat([padsm, sm], dim=-1)
-        c = torch.cat([c[:,0:1], cshft], dim=-1)
-
-        start = 1
-        cpreds = self.qclasifier(h[:,start:,:])
-        flag = sm[:,start:]==1
-        y2 = self.closs(cpreds[flag], c[:,start:][flag])
-
-        return y2
-
-    def predhis(self, h, dcur):
-        sm = dcur["smasks"]
-        padsm = torch.ones(sm.shape[0], 1).to(device)
-        sm = torch.cat([padsm, sm], dim=-1)
-
-        # predict history correctness rates
-        start = self.start
-        rpreds = torch.sigmoid(self.hisclasifier(h)[:,start:,:]).squeeze(-1)
-        rsm = sm[:,start:]
-        rflag = rsm==1
-        # rtrues = torch.cat([dcur["historycorrs"][:,0:1], dcur["shft_historycorrs"]], dim=-1)[:,start:]
-        padr = torch.zeros(h.shape[0], 1).to(device)
-        rtrues = torch.cat([padr, dcur["historycorrs"]], dim=-1)[:,start:]
-        # rtrues = dcur["historycorrs"][:,start:]
-        # rtrues = dcur["totalcorrs"][:,start:]
-        # print(f"rpreds: {rpreds.shape}, rtrues: {rtrues.shape}")
-        y3 = self.hisloss(rpreds[rflag], rtrues[rflag])
-
-        # h = self.dropout_layer(h)
-        # y = torch.sigmoid(self.out_layer(h))
-        return y3
-
     def forward(self, dcur, qtest=False, train=False):
         q, c, r = dcur["qseqs"].long(), dcur["cseqs"].long(), dcur["rseqs"].long()
         qshft, cshft, rshft = dcur["shft_qseqs"].long(), dcur["shft_cseqs"].long(), dcur["shft_rseqs"].long()
@@ -289,33 +135,6 @@ def forward(self, dcur, qtest=False, train=False):
             output = self.out(concat_q).squeeze(-1)
             m = nn.Sigmoid()
             preds = m(output)
-        elif emb_type.endswith("predcurc"): # predict current question' current concept
-            # predict concept
-            qemb = self.question_emb(pid_data)
-
-            # predcurc(self, qemb, cemb, xemb, dcur, train):
-            cemb = q_embed_data
-            if emb_type.find("noxemb") != -1:
-                y2, q_embed_data, qa_embed_data = self.predcurc2(qemb, cemb, qa_embed_data, dcur, train)
-            else:
-                y2, qa_embed_data = self.predcurc(qemb, cemb, qa_embed_data, dcur, train)
-
-            # q_embed_data = self.changecemb(qemb, cemb)
-
-            # predict response
-            d_output = self.model(q_embed_data, qa_embed_data)
-            # if emb_type.find("after") != -1:
-            #     curh = self.model(q_embed_data+qemb, qa_embed_data+qemb)
-            #     y2 = self.afterpredcurc(curh, dcur)
-            if emb_type.find("his") != -1:
-                y3 = self.predhis(d_output, dcur)
-
-            concat_q = torch.cat([d_output, q_embed_data], dim=-1)
-            # if emb_type.find("his") != -1:
-            #     y3 = self.predhis(concat_q, dcur)
-            output = self.out(concat_q).squeeze(-1)
-            m = nn.Sigmoid()
-            preds = m(output)

         if train:
             return preds, y2, y3
@@ -338,7 +157,7 @@ def __init__(self, n_question, n_blocks, d_model, d_feature,
         self.d_model = d_model
         self.model_type = model_type

-        if model_type in {'bakt'}:
+        if model_type in {'simplekt'}:
             self.blocks_2 = nn.ModuleList([
                 TransformerLayer(d_model=d_model, d_feature=d_model // n_heads,
                                  d_ff=d_ff, dropout=dropout, n_heads=n_heads, kq_same=kq_same)
@@ -547,25 +366,3 @@ def __init__(self, d_model, max_len=512):
     def forward(self, x):
         return self.weight[:, :x.size(Dim.seq), :]  # ( 1,seq, Feature)
-
-class timeGap(nn.Module):
-    def __init__(self, num_rgap, num_sgap, num_pcount, emb_size) -> None:
-        super().__init__()
-        self.rgap_eye = torch.eye(num_rgap)
-        self.sgap_eye = torch.eye(num_sgap)
-        self.pcount_eye = torch.eye(num_pcount)
-
-        input_size = num_rgap + num_sgap + num_pcount
-
-        self.time_emb = nn.Linear(input_size, emb_size, bias=False)
-
-    def forward(self, rgap, sgap, pcount):
-        rgap = self.rgap_eye[rgap].to(device)
-        sgap = self.sgap_eye[sgap].to(device)
-        pcount = self.pcount_eye[pcount].to(device)
-
-        tg = torch.cat((rgap, sgap, pcount), -1)
-        tg_emb = self.time_emb(tg)
-
-        return tg_emb
-
diff --git a/pykt/models/train_model.py b/pykt/models/train_model.py
index 8a492ca3..383504d3 100644
--- a/pykt/models/train_model.py
+++ b/pykt/models/train_model.py
@@ -15,7 +15,7 @@
 def cal_loss(model, ys, r, rshft, sm, preloss=[]):
     model_name = model.model_name

-    if model_name in ["cdkt", "bakt", "bakt_time"]:
+    if model_name in ["atdkt", "simplekt", "bakt_time"]:
         y = torch.masked_select(ys[0], sm)
         t = torch.masked_select(rshft, sm)
         # print(f"loss1: {y.shape}")
@@ -81,14 +81,14 @@ def model_forward(model, data):
         cr = torch.cat((r[:,0:1], rshft), dim=1)
     if model_name in ["hawkes"]:
         ct = torch.cat((t[:,0:1], tshft), dim=1)
-    if model_name in ["cdkt"]:
+    if model_name in ["atdkt"]:
         # is_repeat = dcur["is_repeat"]
         y, y2, y3 = model(dcur, train=True)
         if model.emb_type.find("bkt") == -1 and model.emb_type.find("addcshft") == -1:
             y = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
         # y2 = (y2 * one_hot(cshft.long(), model.num_c)).sum(-1)
         ys = [y, y2, y3] # first: yshft
-    elif model_name in ["bakt"]:
+    elif model_name in ["simplekt"]:
         y, y2, y3 = model(dcur, train=True)
         ys = [y[:,1:], y2, y3]
     elif model_name in ["bakt_time"]:
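For reference, the masked-loss path that `cal_loss` applies to the three renamed models keeps predictions and targets aligned through the sequence mask before computing the loss. A minimal sketch, assuming binary cross-entropy as elsewhere in pykt's training loop:

```python
# Sketch of the masked loss pattern from cal_loss (ys[0]: next-step
# predictions aligned with rshft; sm: boolean mask of valid steps).
import torch
from torch.nn.functional import binary_cross_entropy

def masked_bce(y_pred, rshft, sm):
    y = torch.masked_select(y_pred, sm)  # keep only real (non-padded) steps
    t = torch.masked_select(rshft, sm)
    return binary_cross_entropy(y.double(), t.double())
```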