train.py
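"""Hydra-driven training entry point for the B2T model.

Example invocations (a sketch, not part of the original file; override names
follow the config keys this script reads: seed, wandb, from_ckpt,
float32_matmul_precision, experiment_name, and the model/trainer groups):

    python train.py seed=42 wandb=true
    python train.py from_ckpt=ckpts/last.ckpt
"""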
import hydra
from omegaconf import DictConfig, OmegaConf
import os
import lightning as L
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
import torch
from src.model import B2T_Model
from src.data import B2T_DataModule
import wandb
import shutil


@hydra.main(version_base="1.3", config_path="config", config_name="config")
def main(config: DictConfig):
    working_dir = os.getcwd()
    original_cwd = hydra.utils.get_original_cwd()

    # snapshot the data pipeline next to the config files Hydra saves
    shutil.copyfile(
        os.path.join(original_cwd, "src/data.py"),
        os.path.join(working_dir, ".hydra/data.py"),
    )

    if config.get("seed"):
        L.seed_everything(config.seed, workers=True)

    # anomaly detection helps track down NaN/inf gradients but slows training
    torch.autograd.set_detect_anomaly(True)

    if config.get("float32_matmul_precision"):
        torch.set_float32_matmul_precision(config.float32_matmul_precision)

    # build the model, optionally warm-starting from a checkpoint
    if config.get("from_ckpt") and config.from_ckpt != 0:
        model = B2T_Model.load_from_checkpoint(
            config.from_ckpt, strict=False, **config.model
        )
    else:
        model = B2T_Model(**config.model)

    # load datamodule (reuses the model config group)
    data_module = B2T_DataModule(**config.model)

    # loggers
    loggers = []
    if config.get("wandb") and config.wandb:
        wdb = WandbLogger(
            project=config.experiment_name,
            settings=wandb.Settings(code_dir=original_cwd),
        )
        loggers.append(wdb)

        # upload Hydra's saved config directory as a W&B artifact
        artifact = wandb.Artifact(name="configs", type="configs")
        artifact.add_dir(local_path="./.hydra")
        wdb.experiment.log_artifact(artifact)

    print(f"The current working directory is {working_dir}")

    tb = TensorBoardLogger(save_dir="./", name="", default_hp_metric=False)
    loggers.append(tb)

    # callbacks: track learning rate and keep the best checkpoints by WER
    lr_monitor = LearningRateMonitor(logging_interval="step")
    checkpoint_callback = ModelCheckpoint(
        monitor="wer",
        mode="min",
        save_top_k=3,
        save_last=True,
        save_weights_only=True,
        dirpath="ckpts",
    )
    callbacks = [lr_monitor, checkpoint_callback]

    trainer = L.Trainer(
        **config.trainer,
        logger=loggers,
        callbacks=callbacks,
    )

    trainer.fit(model, datamodule=data_module)

    # trainer.test(model, datamodule=data_module)
    # trainer.test(ckpt_path="best", datamodule=data_module)

    # if config.get("wandb") and config.wandb:
    #     artifact = wandb.Artifact(name="backup", type="configs")
    #     # artifact.add_file(local_path="./valid.txt")
    #     # artifact.add_file(local_path="./test.txt")
    #     artifact.add_dir(local_path="./.hydra")
    #     wdb.experiment.log_artifact(artifact)


if __name__ == "__main__":
    main()
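
# Evaluating after training (a sketch mirroring the commented-out trainer.test
# calls above; "best" is resolved by the ModelCheckpoint callback in main):
#   trainer.test(model, datamodule=data_module, ckpt_path="best")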