train.py
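"""Training entry point: 5-fold stratified cross-validation for mango
classification (with optional defect classification, per the config).
Each fold builds its own datasets and loaders, constructs the model
(optionally initialized from pretrained weights), then hands the
optimizer, loss, and LR scheduler to Trainer.train_loop.
"""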
import os
import time
import logging
import argparse
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from utils.config import get_config
from utils.util import get_logger, set_random_seed, load_state_dict
from utils.dataflow import get_transforms, get_dataloader, merge_train_dev_data
from utils.mango_dataset import MangoDataset
from utils.optim import get_optimizer, get_lr_scheduler, Loss
from utils.trainer import Trainer
from utils.model import Model
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg", type=str, help="path to the config file", required=True)
    # Note: argparse's type=bool treats any non-empty string (including "False")
    # as True, so parse the flag value explicitly instead.
    parser.add_argument("--load-pretrained", type=lambda s: s.lower() in ("true", "1", "yes"),
                        help="load pretrained weights", default=True)
    parser.add_argument("--title", type=str, help="experiment title", required=True)
    args = parser.parse_args()

    CONFIG = get_config(args.cfg)
    if CONFIG.cuda:
        device = torch.device("cuda" if (torch.cuda.is_available() and CONFIG.ngpu > 0) else "cpu")
    else:
        device = torch.device("cpu")

    set_random_seed(CONFIG.seed)
    get_logger(CONFIG.log_dir)
    logging.info("=================================== Experiment title : {} Start ===========================".format(args.title))
    # Merge the official train and dev splits, then build 5 stratified folds
    # on the class-label column (column index 5).
    train_data_info = merge_train_dev_data(CONFIG.dataset_dir)
    folds = StratifiedKFold(n_splits=5, shuffle=True).split(
        np.arange(train_data_info.shape[0]), train_data_info.iloc[:, 5])

    train_transform, val_transform, test_transform = get_transforms(CONFIG)
    defect_classes = CONFIG.defect_classes if CONFIG.defect_classification else []
    for fold, (trn_idx, val_idx) in enumerate(folds):
        logging.info("Fold : {}".format(fold))
        fold_train_data = train_data_info.iloc[trn_idx]
        fold_val_data = train_data_info.iloc[val_idx]

        train_data = MangoDataset(fold_train_data, CONFIG.dataset_dir, CONFIG.labels_name,
                                  transforms=train_transform, defect_classes=defect_classes)
        val_data = MangoDataset(fold_val_data, CONFIG.dataset_dir, CONFIG.labels_name,
                                transforms=val_transform)
        # No held-out test split during cross-validation: the validation fold
        # doubles as the test set.
        train_loader, val_loader, test_loader = get_dataloader(train_data, val_data, val_data, CONFIG)
        model = Model(input_size=CONFIG.input_size, classes=CONFIG.classes, se=True,
                      activation="hswish", l_cfgs_name=CONFIG.model, seg_state=CONFIG.seg_state)
        if args.load_pretrained:
            # strict=False lets the backbone weights load even when the
            # classification heads differ from the pretrained checkpoint.
            pretrained_dict = load_state_dict(CONFIG.model_pretrained, use_ema=CONFIG.ema)
            model.load_state_dict(pretrained_dict, strict=False)
            logging.info("Load pretrained from {} to {}".format(CONFIG.model_pretrained, CONFIG.model))

        if device.type == "cuda" and CONFIG.ngpu >= 1:
            model = model.to(device)
            model = nn.DataParallel(model, list(range(CONFIG.ngpu)))
        optimizer = get_optimizer(model.parameters(), CONFIG.optim_state)
        criterion = Loss(device, CONFIG)
        scheduler = get_lr_scheduler(optimizer, len(train_loader), CONFIG)

        start_time = time.time()
        trainer = Trainer(criterion, optimizer, scheduler, device, CONFIG)
        trainer.train_loop(train_loader, test_loader, model, fold)
        logging.info("Total training time : {:.2f}".format(time.time() - start_time))

    logging.info("=================================== Experiment title : {} End ===========================".format(args.title))