-
Notifications
You must be signed in to change notification settings - Fork 49
/
Copy pathconfig.py
69 lines (54 loc) · 2 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# -*- coding: utf-8 -*-
# @Time : 2020-02-26 17:54
# @Author : Zonas
# @Email : [email protected]
# @File : config.py
"""
"""
import os
class UNetConfig:
    """Hyper-parameters and data paths for training a UNet-family segmentation model.

    Every constructor argument is stored unchanged as an attribute of the same
    name, except ``lr_decay_milestones``: its ``None`` default is replaced by a
    fresh ``[20, 50]`` list. The checkpoint directory is created on construction.

    Args:
        epochs: Number of epochs.
        batch_size: Batch size.
        validation: Percent of the data that is used as validation (0-100).
        out_threshold: Output threshold; presumably used to binarize predicted
            probabilities -- confirm against the caller.
        optimizer: Optimizer name (default ``'SGD'``); interpreted by the caller.
        lr: Learning rate.
        lr_decay_milestones: Learning-rate decay milestones, presumably epochs
            consumed by a step scheduler together with ``lr_decay_gamma`` --
            confirm. ``None`` (the default) means a new ``[20, 50]`` list.
        lr_decay_gamma: Learning-rate decay factor.
        weight_decay: Weight-decay setting passed on to the optimizer (presumably).
        momentum: Momentum setting for the optimizer (presumably).
        nesterov: Nesterov flag for the optimizer (presumably).
        n_channels: Number of channels in input images.
        n_classes: Number of classes in the segmentation.
        scale: Downscaling factor of the images.
        load: Load model from a .pth file; ``False`` disables loading. Whether a
            path is expected otherwise is up to the caller -- confirm.
        save_cp: Flag presumably controlling whether checkpoints are saved.
        model: Model architecture name (default ``'NestedUNet'``).
        bilinear: Model option; presumably selects bilinear upsampling -- confirm.
        deepsupervision: Model option; presumably enables deep-supervision
            outputs -- confirm.
        images_dir: Input image directory (default ``'./data/images'``).
        masks_dir: Mask directory (default ``'./data/masks'``).
        checkpoints_dir: Checkpoint directory (default ``'./data/checkpoints'``);
            created, including parents, if it does not exist.
    """

    def __init__(self,
                 epochs=100,                  # Number of epochs
                 batch_size=2,                # Batch size
                 validation=10.0,             # Percent of the data used as validation (0-100)
                 out_threshold=0.5,
                 optimizer='SGD',
                 lr=0.0001,                   # learning rate
                 lr_decay_milestones=None,    # None -> fresh [20, 50]
                 lr_decay_gamma=0.9,
                 weight_decay=1e-8,
                 momentum=0.9,
                 nesterov=True,
                 n_channels=3,                # Number of channels in input images
                 n_classes=3,                 # Number of classes in the segmentation
                 scale=1,                     # Downscaling factor of the images
                 load=False,                  # Load model from a .pth file
                 save_cp=True,
                 model='NestedUNet',
                 bilinear=True,
                 deepsupervision=True,
                 images_dir='./data/images',
                 masks_dir='./data/masks',
                 checkpoints_dir='./data/checkpoints',
                 ):
        super().__init__()
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.checkpoints_dir = checkpoints_dir

        self.epochs = epochs
        self.batch_size = batch_size
        self.validation = validation
        self.out_threshold = out_threshold

        self.optimizer = optimizer
        self.lr = lr
        if lr_decay_milestones is None:
            # Build a new list per instance: a list literal as the default would
            # be shared by every config, so mutating one would change them all.
            lr_decay_milestones = [20, 50]
        self.lr_decay_milestones = lr_decay_milestones
        self.lr_decay_gamma = lr_decay_gamma
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.nesterov = nesterov

        self.n_channels = n_channels
        self.n_classes = n_classes
        self.scale = scale
        self.load = load
        self.save_cp = save_cp
        self.model = model
        self.bilinear = bilinear
        self.deepsupervision = deepsupervision

        # Side effect kept from the original: make sure the checkpoint
        # directory exists before any training code tries to write to it.
        os.makedirs(self.checkpoints_dir, exist_ok=True)