train.py
import os
import argparse
import json
import pickle
import torch
from torch.utils.data import DataLoader, random_split
from torch.optim import SGD, Adam
from data_loaders.assist2009 import ASSIST2009
from data_loaders.assist2015 import ASSIST2015
from data_loaders.algebra2005 import Algebra2005
from data_loaders.statics2011 import Statics2011
from models.dkt import DKT
from models.dkt_plus import DKTPlus
from models.dkvmn import DKVMN
from models.sakt import SAKT
from models.gkt import PAM, MHA
from models.utils import collate_fn
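
# Train one of the supported knowledge-tracing models on the chosen dataset,
# saving configs, checkpoints, and metrics under ckpts/<model_name>/<dataset_name>.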
def main(model_name, dataset_name):
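    # Create the checkpoint directory hierarchy ckpts/<model_name>/<dataset_name>.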
    if not os.path.isdir("ckpts"):
        os.mkdir("ckpts")
    ckpt_path = os.path.join("ckpts", model_name)
    if not os.path.isdir(ckpt_path):
        os.mkdir(ckpt_path)
    ckpt_path = os.path.join(ckpt_path, dataset_name)
    if not os.path.isdir(ckpt_path):
        os.mkdir(ckpt_path)
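
    # Load hyperparameters from config.json, which maps each model name
    # (e.g. "dkt") to its model kwargs and holds a shared "train_config"
    # section with the keys unpacked below.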
with open("config.json") as f:
config = json.load(f)
model_config = config[model_name]
train_config = config["train_config"]
batch_size = train_config["batch_size"]
num_epochs = train_config["num_epochs"]
train_ratio = train_config["train_ratio"]
learning_rate = train_config["learning_rate"]
optimizer = train_config["optimizer"] # can be [sgd, adam]
seq_len = train_config["seq_len"]
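
    # Instantiate the requested dataset; each data loader receives the
    # target sequence length seq_len.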
if dataset_name == "ASSIST2009":
dataset = ASSIST2009(seq_len)
elif dataset_name == "ASSIST2015":
dataset = ASSIST2015(seq_len)
elif dataset_name == "Algebra2005":
dataset = Algebra2005(seq_len)
elif dataset_name == "Statics2011":
dataset = Statics2011(seq_len)
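
    # Prefer GPU when available.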
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
with open(os.path.join(ckpt_path, "model_config.json"), "w") as f:
json.dump(model_config, f, indent=4)
with open(os.path.join(ckpt_path, "train_config.json"), "w") as f:
json.dump(train_config, f, indent=4)
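
    # Instantiate the requested model; dataset.num_q is the number of distinct
    # questions/skills, and the remaining kwargs come from config.json.
    # GKT additionally dispatches on its "method" setting.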
if model_name == "dkt":
model = DKT(dataset.num_q, **model_config).to(device)
elif model_name == "dkt+":
model = DKTPlus(dataset.num_q, **model_config).to(device)
elif model_name == "dkvmn":
model = DKVMN(dataset.num_q, **model_config).to(device)
elif model_name == "sakt":
model = SAKT(dataset.num_q, **model_config).to(device)
elif model_name == "gkt":
if model_config["method"] == "PAM":
model = PAM(dataset.num_q, **model_config).to(device)
elif model_config["method"] == "MHA":
model = MHA(dataset.num_q, **model_config).to(device)
else:
print("The wrong model name was used...")
return
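
    # Split the dataset into train and test subsets by train_ratio.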
    train_size = int(len(dataset) * train_ratio)
    test_size = len(dataset) - train_size

    train_dataset, test_dataset = random_split(
        dataset, [train_size, test_size]
    )
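
    # Reuse a previously saved split if one exists so results are comparable
    # across runs; otherwise persist the freshly drawn indices.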
    if os.path.exists(os.path.join(dataset.dataset_dir, "train_indices.pkl")):
        with open(
            os.path.join(dataset.dataset_dir, "train_indices.pkl"), "rb"
        ) as f:
            train_dataset.indices = pickle.load(f)
        with open(
            os.path.join(dataset.dataset_dir, "test_indices.pkl"), "rb"
        ) as f:
            test_dataset.indices = pickle.load(f)
    else:
        with open(
            os.path.join(dataset.dataset_dir, "train_indices.pkl"), "wb"
        ) as f:
            pickle.dump(train_dataset.indices, f)
        with open(
            os.path.join(dataset.dataset_dir, "test_indices.pkl"), "wb"
        ) as f:
            pickle.dump(test_dataset.indices, f)
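
    # Build the loaders; the test loader uses the whole test set as a
    # single batch.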
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        collate_fn=collate_fn
    )
    test_loader = DataLoader(
        test_dataset, batch_size=test_size, shuffle=True,
        collate_fn=collate_fn
    )
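
    # Build the optimizer named in train_config.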
if optimizer == "sgd":
opt = SGD(model.parameters(), learning_rate, momentum=0.9)
elif optimizer == "adam":
opt = Adam(model.parameters(), learning_rate)
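
    # Train, then persist the per-epoch metrics reported by the model.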
    aucs, loss_means = model.train_model(
        train_loader, test_loader, num_epochs, opt, ckpt_path
    )

    with open(os.path.join(ckpt_path, "aucs.pkl"), "wb") as f:
        pickle.dump(aucs, f)
    with open(os.path.join(ckpt_path, "loss_means.pkl"), "wb") as f:
        pickle.dump(loss_means, f)
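
# Command-line entry point: parse the model and dataset names, then train.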
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="dkt",
        help="The name of the model to train. "
             "The possible models are in [dkt, dkt+, dkvmn, sakt, gkt]. "
             "The default model is dkt."
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="ASSIST2009",
        help="The name of the dataset to use in training. "
             "The possible datasets are in "
             "[ASSIST2009, ASSIST2015, Algebra2005, Statics2011]. "
             "The default dataset is ASSIST2009."
    )
    args = parser.parse_args()

    main(args.model_name, args.dataset_name)
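
# Example invocation (a sketch, assuming config.json and the dataset files
# expected by the data loaders are available in this repository):
#
#     python train.py --model_name dkt --dataset_name ASSIST2009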