-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_wsm.py
120 lines (106 loc) · 4.12 KB
/
train_wsm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# MIT License
# Copyright (c) 2024 Batista Lab (Yale University)
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from pathlib import Path
import lightning as L
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
import DirectMultiStep.helpers as helpers
from DirectMultiStep.Models.Architecture import VanillaTransformerConfig
from DirectMultiStep.Models.Configure import determine_device, prepare_model
from DirectMultiStep.Models.Training import PLTraining
# ----- Paths and run bookkeeping ---------------------------------------------
_script_root = Path(__file__).resolve().parent
data_path = _script_root / "Data" / "Processed"
train_path = _script_root / "Data" / "Training"
run_name = "sm_run_name"

# ----- Optimization hyperparameters ------------------------------------------
batch_size = 32
lr = 3e-4
steps_per_epoch = 30299  # optimizer steps per epoch; used to size warmup/decay below
max_epochs = 12 + 12
n_devices = 4

L.seed_everything(42)
torch.set_float32_matmul_precision("high")

# Two preset model sizes; pick which one is unpacked into the config below.
model_10m = {"n_layers": 6, "ff_mult": 3, "hid_dim": 256}
model_60m = {"n_layers": 8, "ff_mult": 4, "hid_dim": 512}

model_config = VanillaTransformerConfig(
    input_dim=53,
    output_dim=53,
    input_max_length=145 + 135,
    output_max_length=1074 + 1,
    pad_index=52,
    attn_bias=False,
    ff_activation="gelu",
    **model_10m,
    # **model_60m,
)
# enc and dec configs may be different
model = prepare_model(enc_config=model_config, dec_config=model_config)
if __name__ == "__main__":
    # Token indices reserved in the character dictionary: 51 = mask, 52 = pad.
    mask_idx, pad_idx = 51, 52

    # Load pickled train/val datasets plus the token-metadata YAML.
    ds_train, ds_val = helpers.prepare_datasets(
        train_data_path=data_path / "all_dataset_nperms=3.pkl",
        val_data_path=data_path / "n1_dataset_nperms=1.pkl",
        metadata_path=data_path / "character_dictionary.yaml",
    )

    # Shared DataLoader settings; only shuffling differs between the splits.
    # NOTE(review): num_workers=120 assumes a very large host — confirm it
    # matches the training machine's core count.
    loader_kwargs = dict(batch_size=batch_size, num_workers=120, pin_memory=True)
    dl_train = torch.utils.data.DataLoader(dataset=ds_train, shuffle=True, **loader_kwargs)
    dl_val = torch.utils.data.DataLoader(dataset=ds_val, shuffle=False, **loader_kwargs)

    # Padding positions contribute nothing to the loss.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, reduction="mean")
    lightning_model = PLTraining(
        model=model,
        pad_idx=pad_idx,
        mask_idx=mask_idx,
        criterion=criterion,
        lr=lr,
        batch_size=batch_size,
        warmup_steps=int(steps_per_epoch * 0.1),  # LR warmup over 10% of one epoch
        decay_steps=steps_per_epoch * 20,
        decay_factor=0.1,
    )

    device = determine_device()
    # Keep the single best checkpoint by validation loss, plus the latest one.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss", dirpath=train_path / run_name, save_last=True, save_top_k=1
    )
    model_summary = RichModelSummary(max_depth=2)
    trainer = L.Trainer(
        default_root_dir=train_path / run_name,
        max_epochs=max_epochs,
        accelerator=device,
        devices=n_devices,
        num_nodes=1,
        strategy="fsdp",  # if using CUDA
        callbacks=[checkpoint_callback, model_summary],
        gradient_clip_val=1.0,
        gradient_clip_algorithm="value",
    )

    # Resume from the latest checkpoint when one exists. Trainer.fit defaults
    # to ckpt_path=None (train from scratch), so one call covers both cases.
    latest_ckpt = helpers.find_checkpoint(train_path, run_name)
    if latest_ckpt is not None:
        print(f"Loading model from {latest_ckpt}")
    trainer.fit(lightning_model, dl_train, dl_val, ckpt_path=latest_ckpt)