# version of the model
version: 1
# description of the training task
task:
  # number of classes for the model to classify; class 0 is the background class
  num_classes: 7
  # mapping from each class id to its corresponding class name
  id_to_name:
    1: "apo-ferritin"
    2: "beta-amylase"
    3: "beta-galactosidase"
    4: "ribosome"
    5: "thyroglobulin"
    6: "virus-like-particle"
# training seed for reproducible training
seed: 42
# "batch size": number of crops to sample from each tomogram, e.g. 24 means each tomogram gives 24 crops for the model to train on
# num_samples: 24
# size of each side of the cropped cube
patch_size: 128
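# e.g. with num_samples: 24 and patch_size: 128, each tomogram yields 24 cubic crops of 128x128x128 voxels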
# experiments to use in the training process
experiments: ["TS_6_4", "TS_6_6", "TS_69_2", "TS_73_6", "TS_86_3", "TS_99_9", "TS_5_4"]
# this section is only used during the pretraining step
pretrain:
  # validation set ids for the pretraining step
  val_ids: ['TS_5_4', 'TS_6_4', 'TS_69_2', 'TS_6_6', 'TS_73_6', 'TS_86_3', 'TS_99_9']
  val_data_dir: ./data/train
  # synthetic dataset ids for the pretraining step
  synthetic_ids: ["TS_0", "TS_1", "TS_10", "TS_11", "TS_12", "TS_13"]
  synthetic_data_dir: ./data/synthetic-data
# model definition; the arguments follow https://docs.monai.io/en/stable/networks.html#unet
model:
  spatial_dims: 3
  in_channels: 1
  out_channels: 7
  # channels: [32, 64, 128, 128]
  # strides: [2, 2, 1]
  # num_res_units: 1  # increased residual units
  # dropout: 0.3
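  # a minimal construction sketch, assuming MONAI is installed (channels and
  # strides are required by monai.networks.nets.UNet, so the commented values
  # above would need to be supplied):
  #   from monai.networks.nets import UNet
  #   model = UNet(spatial_dims=3, in_channels=1, out_channels=7,
  #                channels=(32, 64, 128, 128), strides=(2, 2, 1))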
# training hyperparameters
train:
  # max number of training epochs
  # num_epochs: 128
  # number of tomograms to sample per batch; the final number of crops for each gradient step is batch_size * num_samples
  batch_size: 1
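  # e.g. batch_size: 1 with num_samples: 24 gives 1 * 24 = 24 crops per gradient step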
  # number of data loader workers
  num_workers: 0
  # validation set ids for the training process
  # val_ids: ['TS_5_4']
  # data directory to read from during training
  data_dir: ./data/train
  # if provided, training resumes from this checkpoint
  pretrain_ckpt: null
validation:
  # validation dataloader batch size and number of workers
  # batch_size: 24
  num_workers: 0
optimizer:
  # AdamW learning rate
  lr: 1e-3
  # patience for early stopping and the ReduceLROnPlateau scheduler
  # patience: 30
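  # a minimal wiring sketch, assuming plain PyTorch (the variable names are
  # hypothetical, not part of this repo):
  #   optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg["optimizer"]["lr"]))
  #   scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=30)
  # note: some YAML 1.1 loaders such as PyYAML parse 1e-3 as a string, hence the
  # float() cast above; writing 1.0e-3 sidesteps the issue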
# where to write logs to
logging:
  dir: ./logs
# where to write checkpoints to
checkpoints:
  dir: ./checkpoints
# inference config
inference:
  # path to the checkpoint of the tiny-unet soup
  tiny_unet: ./checkpoints/tiny-unet/tiny_unet_soup.pth
  # path to the checkpoint of the medium-unet soup
  medium_unet: ./checkpoints/medium-unet/medium_unet_soup.pth
  # path to the checkpoint of the large-unet soup
  large_unet: ./sumo/checkpoint_finetune_denoised/checkpoint_finetune_denoised_soup.pth
  # path to the checkpoint of the large-unet soup for the other tomograms
  large_unet_other_tomo: ./sumo/checkpoint_finetune_all/checkpoint_finetune_all_soup.pth
  # device id on which to load the models and perform the inference
  device_id: 0
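# a minimal loading sketch, assuming PyYAML (the loader this repo actually uses
# is not shown in this file):
#   import yaml
#   with open("config.yml") as f:
#       cfg = yaml.safe_load(f)
#   assert cfg["task"]["num_classes"] == 7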