train_HRNet.py
from dataset import FileLoader
import glob, os
from torch.utils.data import DataLoader
import torch
from model.HRNet import HighResolutionNet
from model import SemanticSegmentationNet
from lib.config import config, update_config
import torch.backends.cudnn as cudnn
import model
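# Merge the settings from hrnet_config.yaml into the global HRNet config object.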
update_config(config, "hrnet_config.yaml")
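# Reseed each DataLoader worker so augmentations differ across workers and epochs
# while staying reproducible from the main-process torch RNG.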
def worker_init_fn(worker_id):
    # ! To keep the seed chain reproducible, the torch RNG must be used here, not numpy.
    # The torch RNG in the main process regenerates a base seed, which is copied into
    # the DataLoader each time it is created (i.e. at the start of each epoch).
    # The DataLoader then spawns its workers with this seed, and each worker is reseeded here.
    worker_info = torch.utils.data.get_worker_info()
    # To make it more random, simply switch torch.randint to np.random.randint.
    worker_seed = torch.randint(0, 2 ** 32, (1,))[0].cpu().item() + worker_id
    # print('Loader Worker %d Uses RNG Seed: %d' % (worker_id, worker_seed))
    # Retrieve the dataset copy held by this worker process and set the random seed
    # used for its augmentations.
    worker_info.dataset.setup_augmentor(worker_id, worker_seed)
    return
batch_size = {"train": 1, "valid": 1}
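# Pre-extracted CoNSeP .npy patches (540x540_164x164) from the hover_net data pipeline,
# sorted so the train/valid file order is deterministic.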
tr_file_list = glob.glob("../hover_net/dataset/training_data/consep/consep/train/540x540_164x164/*.npy")
ts_file_list = glob.glob("../hover_net/dataset/training_data/consep/consep/valid/540x540_164x164/*.npy")
tr_file_list.sort()
ts_file_list.sort()
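# Wrap the file lists in FileLoader datasets producing 256x256 inputs and masks;
# per-worker augmentation is seeded through worker_init_fn above.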
train_data = FileLoader(tr_file_list,
                        input_shape=(256, 256),
                        mask_shape=(256, 256),
                        mode="train",
                        )
test_data = FileLoader(ts_file_list,
                       input_shape=(256, 256),
                       mask_shape=(256, 256),
                       mode="valid")
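# Training loader shuffles and drops the last incomplete batch; the validation loader
# keeps every sample in order.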
tr_loader = DataLoader(train_data,
                       num_workers=1,
                       batch_size=batch_size["train"],
                       shuffle=True,
                       drop_last=True,
                       worker_init_fn=worker_init_fn,
                       )
ts_loader = DataLoader(test_data,
                       num_workers=1,
                       batch_size=batch_size["valid"],
                       shuffle=False,
                       drop_last=False,
                       worker_init_fn=worker_init_fn,
                       )
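# Optional distributed initialisation, left disabled: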
# if True:
# torch.cuda.set_device(-1)
# torch.distributed.init_process_group(
# backend="nccl", init_method="env://",
# )
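# cuDNN behaviour (benchmark / deterministic / enabled) follows the HRNet config.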
cudnn.benchmark = config.CUDNN.BENCHMARK
cudnn.deterministic = config.CUDNN.DETERMINISTIC
cudnn.enabled = config.CUDNN.ENABLED
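# Build the HRNet segmentation backbone from the config and wrap it in the
# SemanticSegmentationNet training/evaluation wrapper.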
model_base = model.HRNet.get_seg_model(config)
seg_model = SemanticSegmentationNet(model_base=model_base)
best_score = 0.0
for epoch in range(1000):
    seg_model.train_on_loader(tr_loader)
    test_dict = seg_model.test_on_loader(ts_loader)
    if test_dict['test_iou'] >= best_score:
        # Track the best validation IoU and checkpoint the backbone whenever it improves.
        best_score = test_dict['test_iou']
        model_to_save = seg_model.model_base.module if hasattr(seg_model.model_base, "module") else seg_model.model_base
        torch.save(model_to_save.state_dict(), "HRNet.pth")