# train.py - Training the autoencoder.
import time
from typing import Tuple

import hydra
import pandas as pd
import pytorch_lightning as pl
import torch
import wandb
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from pcnn import DATA_DIR


def get_logger(exp_name, wandb_user):
    """Create a Weights & Biases logger and return it together with its run directory."""
    logger = WandbLogger(
        name=exp_name,
        project='pcnn',
        entity=wandb_user,
        log_model=False,
    )
    # The W&B run directory is reused as the checkpoint directory in train().
    log_dir = logger.experiment.dir
    return logger, log_dir


def train(cfg: DictConfig) -> Tuple[dict, dict]:
    """Train the model, then evaluate on the validation and test sets using the best
    checkpoint obtained during training.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Validation metrics and test metrics of the best checkpoint.
    """
    # Set the seed for the random number generators in pytorch, numpy and python.random.
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    # Disable W&B logging entirely when it is turned off in the config.
    if not cfg.trainer.wandb:
        wandb.init(mode="disabled")

    # Instantiate the datamodule and the model from the Hydra config.
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)()
    model: LightningModule = hydra.utils.instantiate(
        cfg.model, input_dim=datamodule.input_dim, num_classes=datamodule.num_classes)
    model.to(torch.float32)

    logger, log_dir = get_logger(
        exp_name=cfg.name, wandb_user=cfg.trainer.wandb_user)

    # Checkpoint the best model (lowest validation loss) into the W&B run directory.
    checkpoint_cb = ModelCheckpoint(
        dirpath=log_dir,
        monitor='val_loss',
        mode='min',
        verbose=True,
    )
    early_stopping_cb = EarlyStopping(
        monitor='val_loss', patience=cfg.trainer.early_stopping)

    trainer = pl.Trainer(
        accelerator=cfg.trainer.accelerator,
        logger=logger,
        callbacks=[checkpoint_cb, early_stopping_cb],
        max_epochs=cfg.trainer.max_epochs,
        devices=1,
    )
    trainer.fit(model, datamodule=datamodule)

    # Evaluate the best checkpoint on the validation and test sets.
    val_metrics = trainer.test(model, dataloaders=[datamodule.val_dataloader()],
                               ckpt_path="best")
    val_metrics = {k.replace("test", "val"): v for k, v in val_metrics[0].items()}
    test_metrics = trainer.test(model, dataloaders=[datamodule.test_dataloader()],
                                ckpt_path="best")
    test_metrics = test_metrics[0]

    # Persist the final metrics for later aggregation.
    df_val = pd.DataFrame(val_metrics, index=[0])
    df_test = pd.DataFrame(test_metrics, index=[0])
    df = df_val.join(df_test)
    df.to_pickle(f"final_{cfg.dataset_name}_{cfg.model_name}.pkl")

    return val_metrics, test_metrics


@hydra.main(version_base=None, config_path="config/", config_name="main.yaml")
def main(config: DictConfig):
    # Imports can be nested inside @hydra.main to optimize tab completion
    # https://github.com/facebookresearch/hydra/issues/934
    # Train model
    return train(config)


if __name__ == "__main__":
    start_time = time.time()
    main()
    end_time = time.time()
    print(f"Total time taken: {end_time - start_time} seconds")