-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
- Loading branch information
1 parent
36f5c4b
commit 51134d3
Showing
10 changed files
with
56 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,11 +47,11 @@ transforms: | |
keys: '@image' | ||
- _target_: ScaleIntensityd | ||
keys: '@image' | ||
preprocessing: | ||
|
||
preprocessing: | ||
_target_: Compose | ||
transforms: $@transforms | ||
|
||
dataset: | ||
_target_: Dataset | ||
data: '@data_dicts' | ||
|
@@ -62,10 +62,10 @@ dataloader: | |
dataset: '@dataset' | ||
batch_size: '@batch_size' | ||
num_workers: '@num_workers' | ||
|
||
# should be replaced with other inferer types if training process is different for your network | ||
inferer: | ||
_target_: SimpleInferer | ||
_target_: SimpleInferer | ||
|
||
# transform to apply to data from network to be suitable for loss function and validation | ||
postprocessing: | ||
|
@@ -86,8 +86,8 @@ postprocessing: | |
output_dtype: $None | ||
output_postfix: '' | ||
resample: false | ||
separate_folder: true | ||
separate_folder: true | ||
|
||
# inference handlers to load checkpoint, gather statistics | ||
handlers: | ||
- _target_: CheckpointLoader | ||
|
@@ -98,7 +98,7 @@ handlers: | |
- _target_: StatsHandler | ||
name: null # use engine.logger as the Logger object to log to | ||
output_transform: '$lambda x: None' | ||
|
||
# engine for running inference, ties together objects defined above and has metric definitions | ||
evaluator: | ||
_target_: SupervisedEvaluator | ||
|
@@ -109,5 +109,5 @@ evaluator: | |
postprocessing: '@postprocessing' | ||
val_handlers: '@handlers' | ||
|
||
run: | ||
run: | ||
- [email protected]() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,4 +56,4 @@ | |
} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') | |
# define various paths | ||
bundle_root: . # root directory of the bundle | ||
ckpt_path: $@bundle_root + '/models/model.pt' # checkpoint to load before starting | ||
dataset_dir: $@bundle_root + '/test_data' # where data is coming from | ||
dataset_dir: $@bundle_root + '/test_data' # where data is coming from | ||
output_dir: './outputs' # directory to store images to if save_pred is true | ||
|
||
# network definition, this could be parameterised by pre-defined values or on the command line | ||
|
@@ -53,11 +53,11 @@ transforms: | |
keys: '@both_keys' | ||
- _target_: ScaleIntensityd | ||
keys: '@image' | ||
preprocessing: | ||
|
||
preprocessing: | ||
_target_: Compose | ||
transforms: $@transforms | ||
|
||
dataset: | ||
_target_: Dataset | ||
data: '@data_dicts' | ||
|
@@ -68,10 +68,10 @@ dataloader: | |
dataset: '@dataset' | ||
batch_size: '@batch_size' | ||
num_workers: '@num_workers' | ||
|
||
# should be replaced with other inferer types if training process is different for your network | ||
inferer: | ||
_target_: SimpleInferer | ||
_target_: SimpleInferer | ||
|
||
# transform to apply to data from network to be suitable for loss function and validation | ||
postprocessing: | ||
|
@@ -93,8 +93,8 @@ postprocessing: | |
output_dtype: $None | ||
output_postfix: '' | ||
resample: false | ||
separate_folder: true | ||
separate_folder: true | ||
|
||
# inference handlers to load checkpoint, gather statistics | ||
handlers: | ||
- _target_: CheckpointLoader | ||
|
@@ -105,7 +105,7 @@ handlers: | |
- _target_: StatsHandler | ||
name: null # use engine.logger as the Logger object to log to | ||
output_transform: '$lambda x: None' | ||
|
||
# engine for running inference, ties together objects defined above and has metric definitions | ||
evaluator: | ||
_target_: SupervisedEvaluator | ||
|
@@ -119,7 +119,7 @@ evaluator: | |
include_background: false | ||
output_transform: $monai.handlers.from_engine([@pred, @label]) | ||
val_handlers: '@handlers' | ||
run: | ||
|
||
run: | ||
- [email protected]() | ||
- '$print(''Per-image Dice:\n'',@evaluator.state.metric_details[''val_mean_dice''].cpu().numpy())' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
# This config file implements the training workflow. It can be combined with multi_gpu_train.yaml to use DDP for | ||
# This config file implements the training workflow. It can be combined with multi_gpu_train.yaml to use DDP for | ||
# multi-GPU runs. Many definitions in this file are duplicated across other files for compatibility with MONAI | ||
# Label, eg. network_def, but ideally these would be in a common.yaml file used in conjunction with this one | ||
# or the other config files for testing or inference. | ||
# or the other config files for testing or inference. | ||
|
||
imports: | ||
- $import os | ||
|
@@ -34,7 +34,7 @@ device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') | |
# define various paths | ||
bundle_root: . # root directory of the bundle | ||
ckpt_path: $@bundle_root + '/models/model.pt' # checkpoint to load before starting | ||
dataset_dir: $@bundle_root + '/train_data' # where data is coming from | ||
dataset_dir: $@bundle_root + '/train_data' # where data is coming from | ||
results_dir: $@bundle_root + '/results' # where results are being stored to | ||
# a new output directory is chosen using a timestamp for every invocation | ||
output_dir: '$datetime.datetime.now().strftime(@results_dir + ''/output_%y%m%d_%H%M%S'')' | ||
|
@@ -65,7 +65,7 @@ base_transforms: | |
image_only: true | ||
- _target_: EnsureChannelFirstd | ||
keys: '@both_keys' | ||
|
||
# these are the random and regularising transforms used only for training | ||
train_transforms: | ||
- _target_: RandAxisFlipd | ||
|
@@ -80,34 +80,34 @@ train_transforms: | |
std: 0.05 | ||
- _target_: ScaleIntensityd | ||
keys: '@image' | ||
|
||
# these are used for validation data so no randomness | ||
val_transforms: | ||
- _target_: ScaleIntensityd | ||
keys: '@image' | ||
|
||
# define the Compose objects for training and validation | ||
|
||
preprocessing: | ||
preprocessing: | ||
_target_: Compose | ||
transforms: $@base_transforms + @train_transforms | ||
val_preprocessing: | ||
|
||
val_preprocessing: | ||
_target_: Compose | ||
transforms: $@base_transforms + @val_transforms | ||
|
||
# define the datasets for training and validation | ||
|
||
train_dataset: | ||
_target_: Dataset | ||
data: '@train_sub' | ||
transform: '@preprocessing' | ||
|
||
val_dataset: | ||
_target_: Dataset | ||
data: '@val_sub' | ||
transform: '@val_preprocessing' | ||
|
||
# define the dataloaders for training and validation | ||
|
||
train_dataloader: | ||
|
@@ -116,30 +116,30 @@ train_dataloader: | |
batch_size: '@batch_size' | ||
repeats: '@num_substeps' | ||
num_workers: '@num_workers' | ||
|
||
val_dataloader: | ||
_target_: DataLoader # faster transforms probably won't benefit from threading | ||
dataset: '@val_dataset' | ||
batch_size: '@batch_size' | ||
num_workers: '@num_workers' | ||
|
||
# Simple Dice loss configured for multi-class segmentation, for binary segmentation | ||
# use include_background==True and sigmoid==True instead of these values | ||
lossfn: | ||
_target_: DiceLoss | ||
include_background: true # if your segmentations are relatively small it might help for this to be false | ||
to_onehot_y: true | ||
softmax: true | ||
|
||
# hyperparameters could be added for other arguments of this class | ||
optimizer: | ||
_target_: torch.optim.Adam | ||
params: [email protected]() | ||
lr: '@learning_rate' | ||
|
||
# should be replaced with other inferer types if training process is different for your network | ||
inferer: | ||
_target_: SimpleInferer | ||
_target_: SimpleInferer | ||
|
||
# transform to apply to data from network to be suitable for loss function and validation | ||
postprocessing: | ||
|
@@ -170,7 +170,7 @@ val_handlers: | |
epoch_level: false | ||
save_key_metric: true | ||
key_metric_name: val_mean_dice # save the checkpoint when this value improves | ||
|
||
# engine for running validation, ties together objects defined above and has metric definitions | ||
evaluator: | ||
_target_: SupervisedEvaluator | ||
|
@@ -192,12 +192,12 @@ evaluator: | |
_target_: MeanAbsoluteError | ||
output_transform: $monai.handlers.from_engine([@pred, @label]) | ||
val_handlers: '@val_handlers' | ||
|
||
# gathers the loss and validation values for each iteration, referred to by CheckpointSaver so defined separately | ||
metriclogger: | ||
metriclogger: | ||
_target_: MetricLogger | ||
evaluator: '@evaluator' | ||
evaluator: '@evaluator' | ||
|
||
handlers: | ||
- '@metriclogger' | ||
- _target_: CheckpointLoader | ||
|
@@ -224,7 +224,7 @@ handlers: | |
output_transform: $monai.handlers.from_engine(['loss'], first=True) # log loss value | ||
- _target_: LogfileHandler # log outputs from the training engine | ||
output_dir: '@output_dir' | ||
|
||
# engine for training, ties values defined above together into the main engine for the training process | ||
trainer: | ||
_target_: SupervisedTrainer | ||
|
@@ -238,6 +238,6 @@ trainer: | |
postprocessing: '@postprocessing' | ||
key_train_metric: null | ||
train_handlers: '@handlers' | ||
run: | ||
|
||
run: | ||
- [email protected]() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,4 +16,3 @@ python -m monai.bundle run \ | |
--config_file "$BUNDLE/configs/inference.yaml" \ | ||
--bundle_root "$BUNDLE" \ | ||
$@ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,4 +16,3 @@ python -m monai.bundle run \ | |
--config_file "$BUNDLE/configs/test.yaml" \ | ||
--bundle_root "$BUNDLE" \ | ||
$@ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,4 +16,3 @@ python -m monai.bundle run \ | |
--config_file "$BUNDLE/configs/train.yaml" \ | ||
--bundle_root "$BUNDLE" \ | ||
$@ | ||
|