-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain_img
executable file
·137 lines (109 loc) · 5.69 KB
/
train_img
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python
import os
import numpy as np
import torchvision.models.segmentation
import torch
from torch.utils.data import DataLoader
from argparse import ArgumentParser
import datasets
from tqdm import tqdm
import segmentation_models_pytorch as smp
def create_model(architecture, n_inputs, n_outputs, pretrained=True):
    """Build a torchvision segmentation model with custom channel counts.

    The first backbone convolution is replaced so the network accepts
    ``n_inputs`` input channels, and the last classifier convolution(s) are
    replaced so it emits ``n_outputs`` output channels. The replaced layers
    are newly constructed ``torch.nn.Conv2d`` modules, so they start from
    fresh initialisation even when ``pretrained`` is true.

    Args:
        architecture: Name of a constructor in
            ``torchvision.models.segmentation``; must be one of the names
            listed in ``supported`` below.
        n_inputs: Number of input channels for the first convolution.
        n_outputs: Number of output channels of the final classifier layer(s).
        pretrained: Forwarded unchanged to the torchvision constructor.

    Returns:
        The modified model.

    Raises:
        ValueError: If ``architecture`` is not a supported name.
    """
    supported = ('fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101',
                 'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large')
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would let arbitrary names through.
    if architecture not in supported:
        raise ValueError('Unsupported architecture %r; expected one of: %s'
                         % (architecture, ', '.join(supported)))

    print('Creating model %s with %i inputs and %i outputs' % (architecture, n_inputs, n_outputs))
    # Attribute lookup instead of eval(): same result for the supported
    # names without evaluating a string as code.
    constructor = getattr(torchvision.models.segmentation, architecture)
    model = constructor(pretrained=pretrained)

    # '<head>_<encoder>', e.g. 'deeplabv3_mobilenet_v3_large' ->
    # ('deeplabv3', 'mobilenet_v3_large').
    arch, _, encoder = architecture.partition('_')

    # Change input layer to accept n_inputs
    if encoder == 'mobilenet_v3_large':
        model.backbone['0'][0] = torch.nn.Conv2d(n_inputs, 16,
                                                 kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    else:
        model.backbone['conv1'] = torch.nn.Conv2d(n_inputs, 64,
                                                  kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

    # Change final layer to output n classes
    if arch == 'lraspp':
        model.classifier.low_classifier = torch.nn.Conv2d(40, n_outputs, kernel_size=(1, 1), stride=(1, 1))
        model.classifier.high_classifier = torch.nn.Conv2d(128, n_outputs, kernel_size=(1, 1), stride=(1, 1))
    elif arch == 'fcn':
        model.classifier[-1] = torch.nn.Conv2d(512, n_outputs, kernel_size=(1, 1), stride=(1, 1))
    elif arch == 'deeplabv3':
        model.classifier[-1] = torch.nn.Conv2d(256, n_outputs, kernel_size=(1, 1), stride=(1, 1))

    return model
# Command-line configuration for the training run.
parser = ArgumentParser()
parser.add_argument('--lr', type=float, default=1e-5,
                    help='learning rate')
# parser.add_argument('--dataset', type=str, default='TraversabilityImages51')
parser.add_argument('--dataset', type=str, default='Rellis3DImages',
                    help='name of the dataset class in the datasets module')
parser.add_argument('--architecture', type=str, default='fcn_resnet50',
                    help='torchvision segmentation architecture name')
parser.add_argument('--batch_size', type=int, default=1,
                    help='training batch size')
# type=int so sizes given on the command line are integers like the default;
# without it argparse would hand over strings.
parser.add_argument('--img_size', nargs='+', type=int, default=(320, 512),
                    help='size forwarded to the dataset as crop_size')
parser.add_argument('--n_epochs', type=int, default=10,
                    help='number of training epochs')
# os.cpu_count() may return None; fall back to 1 so integer arithmetic on
# n_workers cannot fail.
parser.add_argument('--n_workers', type=int, default=os.cpu_count() or 1,
                    help='worker budget shared by the data loaders')
args = parser.parse_args()
# argparse yields a list when the option is given; normalise to a tuple like the default.
args.img_size = tuple(args.img_size)
# Resolve the dataset class by name. getattr() replaces eval() so text from the
# command line is never executed as code; an unknown name still raises
# AttributeError.
Dataset = getattr(datasets, args.dataset)

if args.dataset == 'TraversabilityImages':
    # This dataset is built without a split argument and divided 80/20 here,
    # with a fixed seed so the partition is reproducible.
    dataset = Dataset(crop_size=args.img_size)
    length = len(dataset)
    n_train = int(0.8 * length)
    # Give the remainder to validation: int(0.8 * n) + int(0.2 * n) can be
    # smaller than n, and random_split requires the lengths to sum to n.
    n_valid = length - n_train
    train_dataset, valid_dataset = torch.utils.data.random_split(
        dataset,
        [n_train, n_valid],
        generator=torch.Generator().manual_seed(42))
else:
    train_dataset = Dataset(crop_size=args.img_size, split='train')
    valid_dataset = Dataset(crop_size=args.img_size, split='val')

# Each loader gets half of the worker budget; validation runs one sample at a
# time and in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.n_workers // 2)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=args.n_workers // 2)
# --------------Load and set model and optimizer-------------------------------------
# Prefer the GPU whenever CUDA is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TraversabilityImages uses a fixed class count; any other dataset reports its
# classes through `class_values`.
n_classes = 3 if args.dataset == 'TraversabilityImages' else len(train_dataset.class_values)

# Input channel count is read from the first training sample
# (presumably a channels-first image - confirm against the dataset).
n_inputs = train_dataset[0][0].shape[0]

# The network is trained from scratch (pretrained=False) on the chosen device.
model = create_model(
    args.architecture,
    n_inputs,
    n_classes,
    pretrained=False,
).to(device)

# Adam over all model parameters with the learning rate from the command line.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# ----------------Train--------------------------------------------------------------------------
# Dice/F1 score - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
# IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index
# NOTE(review): the loss is configured in 'multilabel' mode while the metric
# applies a 'softmax2d' activation - confirm both match the label format the
# dataset yields.
criterion_fn = smp.losses.DiceLoss(mode='multilabel', from_logits=True, ignore_index=255)
metric_fn = smp.utils.metrics.IoU(threshold=0.5, activation='softmax2d')

# np.inf instead of np.Inf: the capitalised alias was removed in NumPy 2.0.
max_metric = -np.inf

for e in tqdm(range(args.n_epochs)):
    # train epoch
    model = model.train()
    # Remains None when the loader yields no batches, so the epoch summary
    # below can report NaN instead of failing on an undefined name.
    loss = None
    # Wrapping the loader directly lets tqdm show the total number of batches.
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)

        pred = model(images)['out']  # make prediction

        optimizer.zero_grad()
        loss = criterion_fn(pred, labels.long())  # Calculate loss
        loss.backward()  # Backpropagate loss
        optimizer.step()  # Apply gradient descent change to weight

    # validation epoch
    metrics = []
    model = model.eval()
    for images, labels in tqdm(valid_loader):
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            pred = model(images)['out']  # make prediction

        metric1 = metric_fn(pred, labels)
        metrics.append(metric1.cpu().numpy())
    # Mean IoU over all validation batches.
    metric = np.mean(metrics)

    # save better model
    if max_metric < metric:  # Save model weight
        max_metric = metric
        name = '%s_lr_%g_bs_%d_epoch_%d_%s_iou_%.2f.pth' % \
               (args.architecture,
                args.lr, args.batch_size, e, args.dataset, float(max_metric))
        print("Saving Model:", name)
        # Stores the whole model object (not only a state_dict) next to this script.
        torch.save(model, os.path.join(os.path.dirname(__file__), name))

    print("Epoch: %i" % e)
    # Loss of the last training batch of this epoch.
    train_loss = float('nan') if loss is None else loss.item()
    print('Train loss: %f' % train_loss)
    print('Validation metric: %.3f' % metric)

    # One-off step decay: at epoch index 60 the learning rate of the first
    # parameter group is divided by 10 (only reached when n_epochs exceeds 60).
    if e == 60:
        optimizer.param_groups[0]['lr'] /= 10.0
        print('Decrease decoder learning rate to %f !' % optimizer.param_groups[0]['lr'])