Skip to content

Commit

Permalink
change atdkt and simplekt name (#90)
Browse files Browse the repository at this point in the history
Co-authored-by: Liu-lqq <[email protected]>
  • Loading branch information
pykt-team and Liu-lqq authored Feb 10, 2023
1 parent 173f807 commit 4db3fbe
Show file tree
Hide file tree
Showing 14 changed files with 60 additions and 266 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
program: wandb_cdkt_train.py
program: wandb_atdkt_train.py
method: bayes
metric:
goal: maximize
name: validauc
parameters:
model_name:
values: ["cdkt"]
values: ["atdkt"]
dataset_name:
values: ["xes"]
emb_type:
values: ["qiddelxembhistranscembpredcurc"]
save_dir:
values: ["models/cdkt_tiaocan"]
values: ["models/atdkt_tiaocan"]
emb_size:
values: [64, 256]
num_attn_heads:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
program: ./wandb_bakt_train.py
program: ./wandb_simplekt_train.py
method: bayes
metric:
goal: maximize
name: validauc
parameters:
model_name:
values: ["bakt"]
values: ["simplekt"]
dataset_name:
values: ["xes"]
emb_type:
values: ["qid"]
save_dir:
values: ["models/bakt_tiaocan"]
values: ["models/simplekt_tiaocan"]
d_model:
values: [64, 256]
d_ff:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
parser = argparse.ArgumentParser()

parser.add_argument("--dataset_name", type=str, default="algebra2005")
parser.add_argument("--model_name", type=str, default="cdkt")
parser.add_argument("--model_name", type=str, default="atdkt")
parser.add_argument("--emb_type", type=str, default="qid")
parser.add_argument("--save_dir", type=str, default="saved_model")
parser.add_argument("--seed", type=int, default=3407)
Expand Down
2 changes: 1 addition & 1 deletion examples/wandb_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def main(params):
trained_params = config["params"]
model_name, dataset_name, emb_type = trained_params["model_name"], trained_params["dataset_name"], trained_params["emb_type"]
seq_len = config["train_config"]["seq_len"]
if model_name in ["saint", "sakt", "cdkt"]:
if model_name in ["saint", "sakt", "atdkt"]:
model_config["seq_len"] = seq_len
data_config = config["data_config"]

Expand Down
2 changes: 1 addition & 1 deletion examples/wandb_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def main(params):
del model_config[remove_item]
trained_params = config["params"]
model_name, dataset_name, emb_type = trained_params["model_name"], trained_params["dataset_name"], trained_params["emb_type"]
if model_name in ["saint", "sakt", "cdkt"]:
if model_name in ["saint", "sakt", "atdkt"]:
train_config = config["train_config"]
seq_len = train_config["seq_len"]
model_config["seq_len"] = seq_len
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="algebra2005")
parser.add_argument("--model_name", type=str, default="bakt")
parser.add_argument("--model_name", type=str, default="simplekt")
parser.add_argument("--emb_type", type=str, default="qid")
parser.add_argument("--save_dir", type=str, default="saved_model")
# parser.add_argument("--learning_rate", type=float, default=1e-5)
Expand Down
4 changes: 2 additions & 2 deletions examples/wandb_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def main(params):
train_config = config["train_config"]
if model_name in ["dkvmn","deep_irt", "sakt", "saint","saint++", "akt", "atkt", "lpkt", "skvmn"]:
train_config["batch_size"] = 64 ## because of OOM
if model_name in ["bakt", "bakt_time"]:
if model_name in ["simplekt", "bakt_time"]:
train_config["batch_size"] = 64 ## because of OOM
if model_name in ["gkt"]:
train_config["batch_size"] = 16
Expand Down Expand Up @@ -88,7 +88,7 @@ def main(params):
for remove_item in ['use_wandb','learning_rate','add_uuid','l2']:
if remove_item in model_config:
del model_config[remove_item]
if model_name in ["saint","saint++", "sakt", "cdkt", "bakt", "bakt_time"]:
if model_name in ["saint","saint++", "sakt", "atdkt", "simplekt", "bakt_time"]:
model_config["seq_len"] = seq_len

debug_print(text = "init_model",fuc_name="main")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch import FloatTensor, LongTensor
import numpy as np

class CDKTDataset(Dataset):
class ATDKTDataset(Dataset):
"""Dataset for KT
can use to init dataset for: (for models except dkt_forget)
train data, valid data
Expand All @@ -24,16 +24,16 @@ class CDKTDataset(Dataset):
qtest (bool, optional): is question evaluation or not. Defaults to False.
"""
def __init__(self, file_path, input_type, folds, qtest=False):
super(CDKTDataset, self).__init__()
super(ATDKTDataset, self).__init__()
sequence_path = file_path
self.input_type = input_type
self.qtest = qtest
folds = sorted(list(folds))
folds_str = "_" + "_".join([str(_) for _ in folds])
if self.qtest:
processed_data = file_path + folds_str + "_cdkt_qtest.pkl"
processed_data = file_path + folds_str + "_atdkt_qtest.pkl"
else:
processed_data = file_path + folds_str + "_cdkt.pkl"
processed_data = file_path + folds_str + "_atdkt.pkl"
self.dpath = "/".join(file_path.split("/")[0:-1])

if not os.path.exists(processed_data):
Expand All @@ -51,8 +51,6 @@ def __init__(self, file_path, input_type, folds, qtest=False):
self.dori, self.dqtest = pd.read_pickle(processed_data)
else:
self.dori = pd.read_pickle(processed_data)
for key in self.dori:
self.dori[key] = self.dori[key]#[:100]
print(f"file path: {file_path}, qlen: {len(self.dori['qseqs'])}, clen: {len(self.dori['cseqs'])}, rlen: {len(self.dori['rseqs'])}")

def __len__(self):
Expand Down
18 changes: 9 additions & 9 deletions pykt/datasets/init_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from .data_loader import KTDataset
from .dkt_forget_dataloader import DktForgetDataset
from .cdkt_dataloader import CDKTDataset
from .atdkt_dataloader import ATDKTDataset
from .lpkt_dataloader import LPKTDataset
from .lpkt_utils import generate_time2idx
from .que_data_loader import KTQueDataset
Expand Down Expand Up @@ -39,12 +39,12 @@ def init_test_datasets(data_config, model_name, batch_size):
concept_num=data_config['num_c'], max_concepts=data_config['max_concepts'])
test_question_dataset = None
test_question_window_dataset= None
elif model_name in ["cdkt"]:
test_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_file"]), data_config["input_type"], {-1})
test_window_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]), data_config["input_type"], {-1})
elif model_name in ["atdkt"]:
test_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_file"]), data_config["input_type"], {-1})
test_window_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]), data_config["input_type"], {-1})
if "test_question_file" in data_config:
test_question_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_file"]), data_config["input_type"], {-1}, True)
test_question_window_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_window_file"]), data_config["input_type"], {-1}, True)
test_question_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_file"]), data_config["input_type"], {-1}, True)
test_question_window_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_window_file"]), data_config["input_type"], {-1}, True)
else:
test_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_file"]), data_config["input_type"], {-1})
test_window_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]), data_config["input_type"], {-1})
Expand Down Expand Up @@ -96,9 +96,9 @@ def init_dataset4train(dataset_name, model_name, data_config, i, batch_size):
curtrain = KTQueDataset(os.path.join(data_config["dpath"], data_config["train_valid_file_quelevel"]),
input_type=data_config["input_type"], folds=all_folds - {i},
concept_num=data_config['num_c'], max_concepts=data_config['max_concepts'])
elif model_name in ["cdkt"]:
curvalid = CDKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], {i})
curtrain = CDKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], all_folds - {i})
elif model_name in ["atdkt"]:
curvalid = ATDKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], {i})
curtrain = ATDKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], all_folds - {i})
else:
curvalid = KTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], {i})
curtrain = KTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], all_folds - {i})
Expand Down
15 changes: 7 additions & 8 deletions pykt/models/cdkt.py → pykt/models/atdkt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

device = "cpu" if not torch.cuda.is_available() else "cuda"

class CDKT(Module):
def __init__(self, num_q, num_c, seq_len, emb_size, dropout=0.1, emb_type='qid', num_layers=1, num_attn_heads=5, l1=0.5, l2=0.5, l3=0.5, start=50, emb_path="", pretrain_dim=768):
class ATDKT(Module):
def __init__(self, num_q, num_c, seq_len, emb_size, dropout=0.1, emb_type='qid',
num_layers=1, num_attn_heads=5, l1=0.5, l2=0.5, l3=0.5, start=50, emb_path="", pretrain_dim=768):
super().__init__()
self.model_name = "cdkt"
self.model_name = "atdkt"
print(f"qnum: {num_q}, cnum: {num_c}")
print(f"emb_type: {emb_type}")
self.num_q = num_q
Expand All @@ -34,7 +35,6 @@ def __init__(self, num_q, num_c, seq_len, emb_size, dropout=0.1, emb_type='qid',
if self.emb_type.find("qemb") != -1:
self.question_emb = Embedding(self.num_q, self.emb_size)

# 加一个预测历史准确率的任务
self.start = start
self.hisclasifier = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size//2), nn.ReLU(), nn.Dropout(dropout),
Expand Down Expand Up @@ -62,7 +62,6 @@ def __init__(self, num_q, num_c, seq_len, emb_size, dropout=0.1, emb_type='qid',
self.concept_emb = Embedding(self.num_c, self.emb_size) # add concept emb

self.closs = CrossEntropyLoss()
# 加一个预测历史准确率的任务
if self.emb_type.find("his") != -1:
self.start = start
self.hisclasifier = nn.Sequential(
Expand Down Expand Up @@ -122,7 +121,7 @@ def predcurc(self, dcur, q, c, r, xemb, train):
h = self.dropout_layer(h)
y = self.out_layer(h)
y = torch.sigmoid(y)
return y, y2, y3, rpreds, qh
return y, y2, y3

def forward(self, dcur, train=False): ## F * xemb
# print(f"keys: {dcur.keys()}")
Expand Down Expand Up @@ -162,10 +161,10 @@ def forward(self, dcur, train=False): ## F * xemb
y = self.out_layer(h)
y = torch.sigmoid(y)
elif emb_type.endswith("predcurc"): # predict current question' current concept
y, y2, y3, rpreds, qh = self.predcurc(dcur, q, c, r, xemb, train)
y, y2, y3 = self.predcurc(dcur, q, c, r, xemb, train)

if train:
return y, y2, y3
else:
return y, rpreds, qh
return y

32 changes: 16 additions & 16 deletions pykt/models/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,20 @@ def evaluate(model, test_loader, model_name, save_path=""):
cq = torch.cat((q[:,0:1], qshft), dim=1)
cc = torch.cat((c[:,0:1], cshft), dim=1)
cr = torch.cat((r[:,0:1], rshft), dim=1)
if model_name in ["cdkt"]:
if model_name in ["atdkt"]:
'''
y = model(dcur)
import pickle
with open(f"{test_mini_index}_result.pkl",'wb') as f:
data = {"y":y,"cshft":cshft,"num_c":model.num_c,"rshft":rshft,"qshft":qshft,"sm":sm}
pickle.dump(data,f)
'''
y, rpreds, qh = model(dcur)
y = model(dcur)
y = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
elif model_name in ["bakt_time"]:
y = model(dcur, dgaps)
y = y[:,1:]
elif model_name in ["bakt"]:
elif model_name in ["simplekt"]:
y = model(dcur)
y = y[:,1:]
elif model_name in ["dkt", "dkt+"]:
Expand Down Expand Up @@ -158,7 +158,7 @@ def early_fusion(curhs, model, model_name):
que_diff = model.diff_layer(curhs[1])#equ 13
p = torch.sigmoid(3.0*stu_ability-que_diff)#equ 14
p = p.squeeze(-1)
elif model_name in ["akt", "bakt", "bakt_time", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
elif model_name in ["akt", "simplekt", "bakt_time", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
output = model.out(curhs[0]).squeeze(-1)
m = nn.Sigmoid()
p = m(output)
Expand Down Expand Up @@ -204,7 +204,7 @@ def effective_fusion(df, model, model_name, fusion_type):

curhs, curr = [[], []], []
dcur = {"late_trues": [], "qidxs": [], "questions": [], "concepts": [], "row": [], "concept_preds": []}
hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt", "bakt", "bakt_time", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"]
hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt", "simplekt", "bakt_time", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"]
for ui in df:
# 一题一题处理
curdf = ui[1]
Expand Down Expand Up @@ -252,7 +252,7 @@ def group_fusion(dmerge, model, model_name, fusion_type, fout):
if cq.shape[1] == 0:
cq = cc

hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt", "bakt", "bakt_time", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"]
hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt", "simplekt", "bakt_time", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"]

alldfs, drest = [], dict() # not predict infos!
# print(f"real bz in group fusion: {rs.shape[0]}")
Expand Down Expand Up @@ -349,7 +349,7 @@ def evaluate_question(model, test_loader, model_name, fusion_type=["early_fusion
# dkvmn / akt / saint: give cur -> predict cur
# sakt: give past+cur -> predict cur
# kqn: give past+cur -> predict cur
hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt", "bakt", "bakt_time", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"]
hasearly = ["dkvmn","deep_irt", "skvmn", "kqn", "akt", "simplekt", "bakt_time", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx", "lpkt"]
if save_path != "":
fout = open(save_path, "w", encoding="utf8")
if model_name in hasearly:
Expand Down Expand Up @@ -398,7 +398,7 @@ def evaluate_question(model, test_loader, model_name, fusion_type=["early_fusion
# start_hemb = torch.tensor([-1] * (h.shape[0] * h.shape[2])).reshape(h.shape[0], 1, h.shape[2]).to(device)
# print(start_hemb.shape, h.shape)
# h = torch.cat((start_hemb, h), dim=1) # add the first hidden emb
elif model_name in ["bakt"]:
elif model_name in ["simplekt"]:
y, h = model(dcurori, qtest=True, train=False)
y = y[:,1:]
elif model_name in ["akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
Expand All @@ -418,8 +418,8 @@ def evaluate_question(model, test_loader, model_name, fusion_type=["early_fusion
start_hemb = torch.tensor([-1] * (ek.shape[0] * ek.shape[2])).reshape(ek.shape[0], 1, ek.shape[2]).to(device)
ek = torch.cat((start_hemb, ek), dim=1) # add the first hidden emb
es = torch.cat((start_hemb, es), dim=1) # add the first hidden emb
elif model_name in ["cdkt"]:
y, _, _ = model(dcurori)#c.long(), r.long(), q.long())
elif model_name in ["atdkt"]:
y = model(dcurori)#c.long(), r.long(), q.long())
y = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
elif model_name in ["dkt", "dkt+"]:
y = model(c.long(), r.long())
Expand Down Expand Up @@ -774,10 +774,10 @@ def predict_each_group(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid,
dgaps[key] = din[key]
for key in dcur:
dgaps["shft_"+key] = dcur[key]
if model_name in ["cdkt"]: ## need change!
if model_name in ["atdkt"]: ## need change!
# create input
dcurinfos = {"qseqs": qin, "cseqs": cin, "rseqs": rin}
y, _, _ = model(dcurinfos)
y = model(dcurinfos)
pred = y[0][-1][cout.item()]
elif model_name in ["dkt", "dkt+"]:
y = model(cin.long(), rin.long())
Expand Down Expand Up @@ -850,7 +850,7 @@ def predict_each_group(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid,

y = model(dcurinfos, dgaps)
pred = y[0][-1]
elif model_name in ["bakt"]:
elif model_name in ["simplekt"]:
if qout != None:
curq = torch.tensor([[qout.item()]]).to(device)
qinshft = torch.cat((qin[:,1:], curq), axis=1)
Expand Down Expand Up @@ -1116,12 +1116,12 @@ def predict_each_group2(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid,
dgaps[key] = curd[key]
for key in curdshft:
dgaps["shft_"+key] = curdshft[key]
if model_name in ["cdkt"]:
if model_name in ["atdkt"]:
# y = model(curc.long(), curr.long(), curq.long())
# y = (y * one_hot(curcshft.long(), model.num_c)).sum(-1)
# create input
dcurinfos = {"qseqs": curq, "cseqs": curc, "rseqs": curr}
y, _, _ = model(dcurinfos)
y = model(dcurinfos)
y = (y * one_hot(curcshft.long(), model.num_c)).sum(-1)
elif model_name in ["dkt", "dkt+"]:
y = model(curc.long(), curr.long())
Expand Down Expand Up @@ -1164,7 +1164,7 @@ def predict_each_group2(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid,
# print(f"dgaps: {dgaps.keys()}")
y = model(dcurinfos, dgaps)
y = y[:,1:]
elif model_name in ["bakt"]:
elif model_name in ["simplekt"]:
dcurinfos = {"qseqs": curq, "cseqs": curc, "rseqs": curr,
"shft_qseqs":curqshft,"shft_cseqs":curcshft,"shft_rseqs":currshft}
# print(f"finald: {finald.keys()}")
Expand Down
Loading

0 comments on commit 4db3fbe

Please sign in to comment.