forked from Project-MONAI/model-zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add classification template (Project-MONAI#533)
Part of tutorial#1456 ### Description Add a classification template ### Status **Ready** ### Please ensure all the checkboxes: <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Codeformat tests passed locally by running `./runtests.sh --codeformat`. - [ ] In-line docstrings updated. - [ ] Update `version` and `changelog` in `metadata.json` if changing an existing bundle. - [ ] Please ensure the naming rules in config files meet our requirements (please refer to: `CONTRIBUTING.md`). - [ ] Ensure versions of packages such as `monai`, `pytorch` and `numpy` are correct in `metadata.json`. - [ ] Descriptions should be consistent with the content, such as `eval_metrics` of the provided weights and TorchScript modules. - [ ] Files larger than 25MB are excluded and replaced by providing download links in `large_file.yml`. - [ ] Avoid using path that contains personal information within config files (such as use `/home/your_name/` for `"bundle_root"`). --------- Signed-off-by: KumoLiu <[email protected]>
- Loading branch information
Showing
12 changed files
with
827 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2023 MONAI Consortium | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# This implements the workflow for applying the network to a directory of images and measuring network performance with metrics. | ||
|
||
# these transforms are used for inference to load and regularise inputs | ||
transforms: | ||
- _target_: AsDiscreted | ||
keys: ['@pred', '@label'] | ||
argmax: [true, false] | ||
to_onehot: '@num_classes' | ||
- _target_: ToTensord | ||
keys: ['@pred', '@label'] | ||
device: '@device' | ||
|
||
postprocessing: | ||
_target_: Compose | ||
transforms: $@transforms | ||
|
||
# inference handlers to load checkpoint, gather statistics | ||
val_handlers: | ||
- _target_: CheckpointLoader | ||
_disabled_: $not os.path.exists(@ckpt_path) | ||
load_path: '@ckpt_path' | ||
load_dict: | ||
model: '@network' | ||
- _target_: StatsHandler | ||
name: null # use engine.logger as the Logger object to log to | ||
output_transform: '$lambda x: None' | ||
- _target_: MetricsSaver | ||
save_dir: '@output_dir' | ||
metrics: ['val_accuracy'] | ||
metric_details: ['val_accuracy'] | ||
batch_transform: "$lambda x: [xx['image'].meta for xx in x]" | ||
summary_ops: "*" | ||
|
||
initialize: | ||
- "$monai.utils.set_determinism(seed=123)" | ||
- "$setattr(torch.backends.cudnn, 'benchmark', True)" | ||
run: | ||
- [email protected]() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# This implements the workflow for applying the network to a directory of images and measuring network performance with metrics. | ||
|
||
imports: | ||
- $import os | ||
- $import json | ||
- $import torch | ||
- $import glob | ||
|
||
# pull out some constants from MONAI | ||
image: $monai.utils.CommonKeys.IMAGE | ||
label: $monai.utils.CommonKeys.LABEL | ||
pred: $monai.utils.CommonKeys.PRED | ||
|
||
# hyperparameters for you to modify on the command line | ||
batch_size: 1 # number of images per batch | ||
num_workers: 0 # number of workers to generate batches with | ||
num_classes: 4 # number of classes in training data which network should predict | ||
device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') | ||
|
||
# define various paths | ||
bundle_root: . # root directory of the bundle | ||
ckpt_path: $@bundle_root + '/models/model.pt' # checkpoint to load before starting | ||
dataset_dir: $@bundle_root + '/data/test_data' # where data is coming from | ||
|
||
# network definition, this could be parameterised by pre-defined values or on the command line | ||
network_def: | ||
_target_: DenseNet121 | ||
spatial_dims: 2 | ||
in_channels: 1 | ||
out_channels: '@num_classes' | ||
network: $@network_def.to(@device) | ||
|
||
# list all niftis in the input directory | ||
test_json: "$@bundle_root+'/data/test_samples.json'" | ||
test_fp: "$open(@test_json,'r', encoding='utf8')" | ||
# load json file | ||
test_dict: "$json.load(@test_fp)" | ||
|
||
# these transforms are used for inference to load and regularise inputs | ||
transforms: | ||
- _target_: LoadImaged | ||
keys: '@image' | ||
- _target_: EnsureChannelFirstd | ||
keys: '@image' | ||
- _target_: ScaleIntensityd | ||
keys: '@image' | ||
|
||
preprocessing: | ||
_target_: Compose | ||
transforms: $@transforms | ||
|
||
dataset: | ||
_target_: Dataset | ||
data: '@test_dict' | ||
transform: '@preprocessing' | ||
|
||
dataloader: | ||
_target_: ThreadDataLoader # generate data ansynchronously from inference | ||
dataset: '@dataset' | ||
batch_size: '@batch_size' | ||
num_workers: '@num_workers' | ||
|
||
# should be replaced with other inferer types if training process is different for your network | ||
inferer: | ||
_target_: SimpleInferer | ||
|
||
# transform to apply to data from network to be suitable for validation | ||
postprocessing: | ||
_target_: Compose | ||
transforms: | ||
- _target_: Activationsd | ||
keys: '@pred' | ||
softmax: true | ||
- _target_: AsDiscreted | ||
keys: ['@pred', '@label'] | ||
argmax: [true, false] | ||
to_onehot: '@num_classes' | ||
- _target_: ToTensord | ||
keys: ['@pred', '@label'] | ||
device: '@device' | ||
|
||
# inference handlers to load checkpoint, gather statistics | ||
val_handlers: | ||
- _target_: CheckpointLoader | ||
_disabled_: $not os.path.exists(@ckpt_path) | ||
load_path: '@ckpt_path' | ||
load_dict: | ||
model: '@network' | ||
- _target_: StatsHandler | ||
name: null # use engine.logger as the Logger object to log to | ||
output_transform: '$lambda x: None' | ||
|
||
# engine for running inference, ties together objects defined above and has metric definitions | ||
evaluator: | ||
_target_: SupervisedEvaluator | ||
device: '@device' | ||
val_data_loader: '@dataloader' | ||
network: '@network' | ||
inferer: '@inferer' | ||
postprocessing: '@postprocessing' | ||
key_val_metric: | ||
val_accuracy: | ||
_target_: ignite.metrics.Accuracy | ||
output_transform: $monai.handlers.from_engine([@pred, @label]) | ||
additional_metrics: | ||
val_f1: # can have other metrics | ||
_target_: ConfusionMatrix | ||
metric_name: 'f1 score' | ||
output_transform: $monai.handlers.from_engine([@pred, @label]) | ||
val_handlers: '@val_handlers' | ||
|
||
initialize: | ||
- "$setattr(torch.backends.cudnn, 'benchmark', True)" | ||
run: | ||
- "[email protected]()" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
[loggers] | ||
keys=root | ||
|
||
[handlers] | ||
keys=consoleHandler | ||
|
||
[formatters] | ||
keys=fullFormatter | ||
|
||
[logger_root] | ||
level=INFO | ||
handlers=consoleHandler | ||
|
||
[handler_consoleHandler] | ||
class=StreamHandler | ||
level=INFO | ||
formatter=fullFormatter | ||
args=(sys.stdout,) | ||
|
||
[formatter_fullFormatter] | ||
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
{ | ||
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json", | ||
"version": "0.0.1", | ||
"changelog": { | ||
"0.0.1": "Initial version" | ||
}, | ||
"monai_version": "1.3.0", | ||
"pytorch_version": "2.0.1", | ||
"numpy_version": "1.24.4", | ||
"optional_packages_version": { | ||
"pytorch-ignite": "0.4.12" | ||
}, | ||
"name": "Classification Template", | ||
"task": "Classification Template in 2D images", | ||
"description": "This is a template bundle for classifying in 2D, take this as a basis for your own bundles.", | ||
"authors": "Yun Liu", | ||
"copyright": "Copyright (c) 2023 MONAI Consortium", | ||
"network_data_format": { | ||
"inputs": { | ||
"image": { | ||
"type": "image", | ||
"format": "magnitude", | ||
"modality": "none", | ||
"num_channels": 1, | ||
"spatial_shape": [ | ||
128, | ||
128 | ||
], | ||
"dtype": "float32", | ||
"value_range": [], | ||
"is_patch_data": false, | ||
"channel_def": { | ||
"0": "image" | ||
} | ||
} | ||
}, | ||
"outputs": { | ||
"pred": { | ||
"type": "probabilities", | ||
"format": "classes", | ||
"num_channels": 4, | ||
"spatial_shape": [ | ||
1, | ||
4 | ||
], | ||
"dtype": "float32", | ||
"value_range": [ | ||
0, | ||
1, | ||
2, | ||
3 | ||
], | ||
"is_patch_data": false, | ||
"channel_def": { | ||
"0": "background", | ||
"1": "circle", | ||
"2": "triangle", | ||
"3": "rectangle" | ||
} | ||
} | ||
} | ||
} | ||
} |
37 changes: 37 additions & 0 deletions
37
models/classification_template/configs/multi_gpu_train.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# This file contains the changes to implement DDP training with the train.yaml config. | ||
|
||
device: "$torch.device('cuda:' + os.environ['LOCAL_RANK'])" # assumes GPU # matches rank # | ||
|
||
# wrap the network in a DistributedDataParallel instance, moving it to the chosen device for this process | ||
network: | ||
_target_: torch.nn.parallel.DistributedDataParallel | ||
module: $@network_def.to(@device) | ||
device_ids: ['@device'] | ||
find_unused_parameters: true | ||
|
||
train_sampler: | ||
_target_: DistributedSampler | ||
dataset: '@train_dataset' | ||
even_divisible: true | ||
shuffle: true | ||
|
||
train_dataloader#sampler: '@train_sampler' | ||
train_dataloader#shuffle: false | ||
|
||
val_sampler: | ||
_target_: DistributedSampler | ||
dataset: '@val_dataset' | ||
even_divisible: false | ||
shuffle: false | ||
|
||
val_dataloader#sampler: '@val_sampler' | ||
|
||
initialize: | ||
- $import torch.distributed as dist | ||
- $dist.init_process_group(backend='nccl') | ||
- $torch.cuda.set_device(@device) | ||
- $monai.utils.set_determinism(seed=123) # may want to choose a different seed or not do this here | ||
run: | ||
- '[email protected]()' | ||
finalize: | ||
- '$dist.is_initialized() and dist.destroy_process_group()' |
Oops, something went wrong.