Refactor models + their Hydra configs (#142)
* Bump torch and torchvision version

To get the latest version of LayerNorm

* Refactor U-Net related models, layers, blocks

Separate U-Net into an encoder and a decoder for more flexible reuse and cleaner code

* Update LightningModule after refactoring U-Net

Fix some bugs in how arguments are passed to the U-Net

* Add Hydra config groups for encoder and decoder

* Update net Hydra configs to include encoder and decoder config groups

* Fix config interpolation bug related to model.net

* Update experiment_planner after refactoring model

Also fix a bug when producing the 3D datamodule config

* Allow partial instantiation of decoder in UNet

* Revert "Update LightningModule after refactoring U-Net"

This reverts commit 6119493.

* Store some attributes in UNet to be compatible with nnUNetLitModule

* Add resolver to get `in_channels` robustly from model/net

* Update PDUNet with reworked UNet

* Add resolver to automatically derive `dim` from `patch_size`

* Add resolver to automatically derive `num_stages` from kernels

* Update planner to avoid printing `dim` & `num_stages` in model config

* Use resolvers to automatically derive `dim` & `num_stages`

* Fix spyrit_net config

* Fix typing and docstrings

* Fix interpolation error in unwrap_2d data config

* Add resolver to automatically determine whether to do batch dice

* Update model configs

* Fix typing + Remove overriding of `soft_dice_kwargs` in 3D model configs

`batch_dice` is now derived directly from the patch size using a resolver

* Remove `dim` and `num_stages` from some forgotten model configs

* Set `print_width` in prettier hook to 99

* Reformat yaml
HangJung97 authored Nov 14, 2023
1 parent bd54c50 commit 4903d9a
Showing 44 changed files with 2,735 additions and 1,331 deletions.
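Several commit messages above mention custom OmegaConf resolvers: deriving `dim` from `patch_size`, `num_stages` from the kernel list, `in_channels` from the net config, and the batch-dice flag. The resolver implementations are not part of the files shown below, so the following is only an illustration of the mechanism; the resolver name is hypothetical, since the `dim` resolver's identifier never appears in this diff.

```python
from omegaconf import OmegaConf

# Hypothetical registration sketch: `dim` is simply the number of spatial
# dimensions of the patch, e.g. [640, 1024] -> 2 and [320, 256, 24] -> 3.
OmegaConf.register_new_resolver(
    "get_dim_from_patch_size",  # assumed name; not shown in this diff
    lambda patch_size: len(patch_size),
)
```

The resolvers that do appear in the configs below (`get_in_channels_from_model_net`, `do_batch_dice`, `get_num_stages_from_kernels`) are sketched next to the files that use them.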
7 changes: 2 additions & 5 deletions .pre-commit-config.yaml
@@ -72,6 +72,7 @@ repos:
     hooks:
       - id: prettier
         types: [yaml]
+        args: ["--print-width=99"]

   # shell scripts linter
   - repo: https://github.com/shellcheck-py/shellcheck-py
@@ -116,8 +117,4 @@ repos:
       - id: nbqa-isort
         args: ["--profile=black"]
       - id: nbqa-flake8
-        args:
-          [
-            "--extend-ignore=E203,E402,E501,F401,F841",
-            "--exclude=logs/*,data/*",
-          ]
+        args: ["--extend-ignore=E203,E402,E501,F401,F841", "--exclude=logs/*,data/*"]
2 changes: 1 addition & 1 deletion ascent/configs/datamodule/nnunet.yaml
@@ -8,7 +8,7 @@ data_dir: ${paths.data_dir}
 batch_size: ???
 dataset_name: ???
 patch_size: ${model.net.patch_size}
-in_channels: ${model.net.in_channels}
+in_channels: ${get_in_channels_from_model_net:${model.net}}
 # Whether to do dummy 2D data augmentation during the training of 3D-UNet
 do_dummy_2D_data_aug: ???
 fold: ${fold}
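The `${get_in_channels_from_model_net:${model.net}}` interpolation replaces `${model.net.in_channels}`, which no longer resolves once `in_channels` moves under `net.encoder`. The resolver's real implementation is not shown in this diff; a minimal sketch, assuming it simply checks both the old and the new location, could look like this:

```python
from omegaconf import DictConfig, OmegaConf


def get_in_channels_from_model_net(net_cfg: DictConfig) -> int:
    """Fetch `in_channels` from a net config, whether flat or split into encoder/decoder."""
    if "in_channels" in net_cfg:  # nets that still keep it at the top level
        return net_cfg.in_channels
    if "encoder" in net_cfg and "in_channels" in net_cfg.encoder:
        return net_cfg.encoder.in_channels  # refactored U-Net style nets
    raise KeyError("No `in_channels` found in the given net config.")


OmegaConf.register_new_resolver(
    "get_in_channels_from_model_net", get_in_channels_from_model_net
)
```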
2 changes: 1 addition & 1 deletion ascent/configs/datamodule/unwrap_2d.yaml
@@ -3,6 +3,6 @@ defaults:

 dataset_name: UNWRAPV2
 patch_size: ${model.net.denoiser.patch_size}
-in_channels: ${model.net.denoiser.in_channels}
+in_channels: ${get_in_channels_from_model_net:${model.net.denoiser}}
 separate_transform: True
 exclude_Dpower: False
10 changes: 6 additions & 4 deletions ascent/configs/model/camus_challenge_2d.yaml
@@ -2,8 +2,10 @@ defaults:
   - nnunet

 net:
-  in_channels: 1
-  num_classes: 4
   patch_size: [640, 1024]
-  kernels: [[3, 1], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
-  strides: [[1, 1], [1, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]]
+  encoder:
+    in_channels: 1
+    kernels: [[3, 1], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
+    strides: [[1, 1], [1, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]]
+  decoder:
+    num_classes: 4
13 changes: 6 additions & 7 deletions ascent/configs/model/cardinal+ted_3d.yaml
@@ -2,11 +2,10 @@ defaults:
   - nnunet

 net:
-  in_channels: 1
-  num_classes: 3
   patch_size: [320, 256, 24]
-  kernels: [[3, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
-  strides: [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 1], [2, 2, 1]]
-
-loss:
-  soft_dice_kwargs: { "batch_dice": False, "smooth": 1e-05, "do_bg": False }
+  encoder:
+    in_channels: 1
+    kernels: [[3, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
+    strides: [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 1], [2, 2, 1]]
+  decoder:
+    num_classes: 3
10 changes: 6 additions & 4 deletions ascent/configs/model/cardinal_2d.yaml
@@ -2,8 +2,10 @@ defaults:
   - nnunet

 net:
-  in_channels: 1
-  num_classes: 3
   patch_size: [640, 512]
-  kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
-  strides: [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]]
+  encoder:
+    in_channels: 1
+    kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
+    strides: [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]]
+  decoder:
+    num_classes: 3
13 changes: 6 additions & 7 deletions ascent/configs/model/cardinal_3d.yaml
@@ -2,11 +2,10 @@ defaults:
   - nnunet

 net:
-  in_channels: 1
-  num_classes: 3
   patch_size: [320, 256, 24]
-  kernels: [[3, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
-  strides: [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 1], [2, 2, 1]]
-
-loss:
-  soft_dice_kwargs: { "batch_dice": False, "smooth": 1e-05, "do_bg": False }
+  encoder:
+    in_channels: 1
+    kernels: [[3, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
+    strides: [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 1], [2, 2, 1]]
+  decoder:
+    num_classes: 3
18 changes: 10 additions & 8 deletions ascent/configs/model/cardinal_convnext_3d.yaml
@@ -5,14 +5,16 @@ defaults:
   - override scheduler: coslr

 net:
-  in_channels: 1
-  num_classes: 3
   patch_size: [320, 256, 24]
-  convnext_kernels: 7
-  decoder_kernels: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
-  strides: [[4, 4, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
-
-loss:
-  soft_dice_kwargs: { "batch_dice": False, "smooth": 1e-05, "do_bg": False }
+  encoder:
+    in_channels: 1
+    stem_kernel: 7
+    kernels: [[3, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
+    strides: [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 1], [2, 2, 1]]
+    num_conv_per_stage: 2
+    num_features_per_stage: [32, 64, 128, 256, 380, 380]
+    expansion_rate: 2
+  decoder:
+    num_classes: 3

 name: ConvNeXt
10 changes: 6 additions & 4 deletions ascent/configs/model/dealias_2d.yaml
@@ -2,8 +2,10 @@ defaults:
   - nnunet

 net:
-  in_channels: 1
-  num_classes: 3
   patch_size: [40, 192]
-  kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
-  strides: [[1, 1], [2, 2], [2, 2], [2, 2], [1, 2]]
+  encoder:
+    in_channels: 1
+    kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
+    strides: [[1, 1], [2, 2], [2, 2], [2, 2], [1, 2]]
+  decoder:
+    num_classes: 3
10 changes: 6 additions & 4 deletions ascent/configs/model/dealiasc_2d.yaml
@@ -2,8 +2,10 @@ defaults:
   - nnunet

 net:
-  in_channels: 2
-  num_classes: 3
   patch_size: [40, 192]
-  kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
-  strides: [[1, 1], [2, 2], [2, 2], [2, 2], [1, 2]]
+  encoder:
+    in_channels: 2
+    kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
+    strides: [[1, 1], [2, 2], [2, 2], [2, 2], [1, 2]]
+  decoder:
+    num_classes: 3
10 changes: 6 additions & 4 deletions ascent/configs/model/dealiasm_2d.yaml
@@ -2,8 +2,10 @@ defaults:
   - nnunet

 net:
-  in_channels: 1
-  num_classes: 3
   patch_size: [40, 192]
-  kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
-  strides: [[1, 1], [2, 2], [2, 2], [2, 2], [1, 2]]
+  encoder:
+    in_channels: 1
+    kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
+    strides: [[1, 1], [2, 2], [2, 2], [2, 2], [1, 2]]
+  decoder:
+    num_classes: 3
13 changes: 7 additions & 6 deletions ascent/configs/model/dealiasreg_2d.yaml
@@ -5,12 +5,13 @@ defaults:
 _target_: ascent.models.nnunet_reg_module.nnUNetRegLitModule

 net:
-  in_channels: 2
-  num_classes: 1
   patch_size: [40, 192]
-  kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
-  strides: [[1, 1], [2, 2], [2, 2], [2, 2], [1, 2]]
-  deep_supervision: False
-  out_seg_bias: True
+  encoder:
+    in_channels: 2
+    kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
+    strides: [[1, 1], [2, 2], [2, 2], [2, 2], [1, 2]]
+  decoder:
+    num_classes: 1
+    deep_supervision: False

 name: nnUNetReg
5 changes: 4 additions & 1 deletion ascent/configs/model/loss/dice_ce.yaml
@@ -1,6 +1,9 @@
 _target_: ascent.utils.loss_functions.dice_loss.DC_and_CE_loss

-soft_dice_kwargs: { "batch_dice": True, "smooth": 1e-5, "do_bg": False }
+soft_dice_kwargs:
+  batch_dice: ${do_batch_dice:${model.net.patch_size}}
+  smooth: 1e-5
+  do_bg: False
 ce_kwargs: {}
 weight_ce: 1
 weight_dice: 1
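`${do_batch_dice:${model.net.patch_size}}` is the resolver that makes the explicit `soft_dice_kwargs` overrides in the 3D model configs unnecessary. Since the old default was `batch_dice: True` and the 3D configs overrode it to `False`, the resolver presumably keys off the patch dimensionality; a sketch under that assumption (not the actual ascent code):

```python
from omegaconf import OmegaConf


def do_batch_dice(patch_size) -> bool:
    """Assumed rule: batch dice for 2D patches, per-sample dice for 3D patches."""
    return len(patch_size) == 2


OmegaConf.register_new_resolver("do_batch_dice", do_batch_dice)
```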
10 changes: 6 additions & 4 deletions ascent/configs/model/myosaiq+emidec_2d.yaml
@@ -2,8 +2,10 @@ defaults:
   - nnunet

 net:
-  in_channels: 1
-  num_classes: 3
   patch_size: [96, 80]
-  kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
-  strides: [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2]]
+  encoder:
+    in_channels: 1
+    kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
+    strides: [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2]]
+  decoder:
+    num_classes: 3
13 changes: 6 additions & 7 deletions ascent/configs/model/myosaiq+emidec_3d.yaml
@@ -2,11 +2,10 @@ defaults:
   - nnunet

 net:
-  in_channels: 1
-  num_classes: 3
   patch_size: [96, 80, 20]
-  kernels: [[3, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
-  strides: [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 1]]
-
-loss:
-  soft_dice_kwargs: { "batch_dice": False, "smooth": 1e-05, "do_bg": False }
+  encoder:
+    in_channels: 1
+    kernels: [[3, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
+    strides: [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 1]]
+  decoder:
+    num_classes: 3
10 changes: 6 additions & 4 deletions ascent/configs/model/myosaiq_2d.yaml
@@ -2,8 +2,10 @@ defaults:
   - nnunet

 net:
-  in_channels: 1
-  num_classes: 5
   patch_size: [96, 80]
-  kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
-  strides: [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2]]
+  encoder:
+    in_channels: 1
+    kernels: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
+    strides: [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2]]
+  decoder:
+    num_classes: 5
13 changes: 6 additions & 7 deletions ascent/configs/model/myosaiq_3d.yaml
@@ -2,11 +2,10 @@ defaults:
   - nnunet

 net:
-  in_channels: 1
-  num_classes: 5
   patch_size: [96, 80, 20]
-  kernels: [[3, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
-  strides: [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 1]]
-
-loss:
-  soft_dice_kwargs: { "batch_dice": False, "smooth": 1e-05, "do_bg": False }
+  encoder:
+    in_channels: 1
+    kernels: [[3, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
+    strides: [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 1]]
+  decoder:
+    num_classes: 5
18 changes: 10 additions & 8 deletions ascent/configs/model/myosaiq_convnext_3d.yaml
@@ -5,14 +5,16 @@ defaults:
   - override scheduler: coslr

 net:
-  in_channels: 1
-  num_classes: 5
   patch_size: [96, 80, 20]
-  convnext_kernels: 7
-  decoder_kernels: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
-  strides: [[4, 4, 1], [2, 1, 1], [2, 2, 2], [2, 2, 2]]
-
-loss:
-  soft_dice_kwargs: { "batch_dice": False, "smooth": 1e-05, "do_bg": False }
+  encoder:
+    in_channels: 1
+    stem_kernel: 7
+    kernels: [[3, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
+    strides: [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 1]]
+    num_conv_per_stage: 2
+    num_features_per_stage: [32, 64, 128, 256, 380]
+    expansion_rate: 2
+  decoder:
+    num_classes: 5

 name: ConvNeXt
20 changes: 3 additions & 17 deletions ascent/configs/model/net/convnext.yaml
@@ -1,17 +1,3 @@
-_target_: ascent.models.components.convnext.ConvNeXt
-
-in_channels: ???
-num_classes: ???
-patch_size: ???
-convnext_kernels: 7
-decoder_kernels: ???
-strides: [[4, 4, 4], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
-depths: [2, 3, 3, 9, 3]
-filters: [32, 96, 192, 384, 768]
-drop_path_rate: 0
-layer_scale_init_value: 1e-6
-encoder_normalization_layer: "layer"
-decoder_normalization_layer: "instance"
-negative_slope: 1e-2
-deep_supervision: True
-out_seg_bias: False
+defaults:
+  - unet
+  - override encoder: convnext
9 changes: 9 additions & 0 deletions ascent/configs/model/net/decoder/unet_decoder.yaml
@@ -0,0 +1,9 @@
_target_: ascent.models.components.decoders.unet_decoder.UNetDecoder

_partial_: true
num_classes: ???
num_conv_per_stage: 2
output_conv_bias: True
deep_supervision: True
attention: False
initialization: null
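`_partial_: true` tells Hydra to instantiate this decoder config as a `functools.partial` rather than a finished object (the "Allow partial instantiation of decoder in UNet" change above), so the parent U-Net can supply whatever depends on the encoder. The sketch below only illustrates that pattern; the constructor arguments and the way the decoder consumes the encoder are assumptions, not the actual ascent API:

```python
import functools

import torch.nn as nn
from hydra.utils import instantiate


class UNet(nn.Module):
    """Illustrative encoder/decoder composition with a partially instantiated decoder."""

    def __init__(self, encoder_cfg, decoder_cfg, patch_size):
        super().__init__()
        # Stored so Lightning modules that expect `net.patch_size` keep working
        # (cf. "Store some attributes in UNet to be compatible with nnUNetLitModule").
        self.patch_size = patch_size
        self.encoder = instantiate(encoder_cfg)
        # Because of `_partial_: true`, this returns a functools.partial ...
        decoder_partial: functools.partial = instantiate(decoder_cfg)
        # ... which is completed with encoder-dependent information (assumed signature).
        self.decoder = decoder_partial(encoder=self.encoder)

    def forward(self, x):
        # The encoder is assumed to return skip connections (cf. `return_skip: True`
        # in the encoder configs below).
        return self.decoder(self.encoder(x))
```

Completing the partial inside the model keeps the YAML free of values only the encoder knows, such as per-stage feature counts and skip-connection shapes.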
23 changes: 23 additions & 0 deletions ascent/configs/model/net/encoder/convnext.yaml
@@ -0,0 +1,23 @@
_target_: ascent.models.components.encoders.convnext.ConvNeXt

in_channels: ???
num_stages: ${get_num_stages_from_kernels:${model.net.encoder.kernels}}
dim: ???
stem_kernel: 7
kernels: ???
strides: ???
num_conv_per_stage: [3, 3, 9, 3]
num_features_per_stage: [96, 192, 384, 768]
conv_bias: True
expansion_rate: 4
stochastic_depth_p: 0
layer_scale_init_value: 1e-6
conv_kwargs: null
norm_layer: "group"
norm_kwargs: null
activation: "gelu"
activation_kwargs: null
drop_block: False
drop_kwargs: None
return_skip: True
initialization: "trunc_normal"
22 changes: 22 additions & 0 deletions ascent/configs/model/net/encoder/unet_encoder.yaml
@@ -0,0 +1,22 @@
_target_: ascent.models.components.encoders.unet_encoder.UNetEncoder

in_channels: ???
num_stages: ${get_num_stages_from_kernels:${model.net.encoder.kernels}}
dim: ???
kernels: ???
strides: ???
start_features: 32
num_conv_per_stage: 2
conv_bias: True
conv_kwargs: null
pooling: "stride"
adaptive_pooling: False
norm_layer: "instance"
norm_kwargs: null
activation: "leakyrelu"
activation_kwargs: { "inplace": True }
drop_block: False
drop_kwargs: None
residual: False
return_skip: True
initialization: "kaiming_normal"