diff --git a/monai/networks/nets/mednext.py b/monai/networks/nets/mednext.py index 427572ba60..767eb65e0b 100644 --- a/monai/networks/nets/mednext.py +++ b/monai/networks/nets/mednext.py @@ -19,6 +19,7 @@ import torch import torch.nn as nn +from torch.nn.functional import interpolate from monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock @@ -57,7 +58,16 @@ class MedNeXt(nn.Module): decoder_expansion_ratio: expansion ratio for decoder blocks. Defaults to 2. bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2. kernel_size: kernel size for convolutions. Defaults to 7. - deep_supervision: whether to use deep supervision. Defaults to False. + deep_supervision: whether to use deep supervision. Defaults to ``False``. + If ``True``, in training mode, the forward function will output not only the final feature map + (from the `out_0` block), but also the feature maps that come from the intermediate up sample layers. + In order to unify the return type, all intermediate feature maps are interpolated into the same size + as the final feature map and stacked together (with a new dimension inserted at axis 1) into one single tensor. + For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and + (1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps + will be interpolated into (1, 2, 32, 24), and the stacked tensor will have the shape (1, 3, 2, 32, 24). + When calculating the loss, you can use torch.unbind to get all feature maps and compute the loss + one by one with the ground truth, then do a weighted average for all losses to achieve the final loss. use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False. blocks_down: number of blocks in each encoder stage. Defaults to [2, 2, 2, 2]. blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2. 
@@ -260,7 +270,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]: # Return output(s) if self.do_ds and self.training: - return (x, *ds_outputs[::-1]) + out_all = [x] + for feature_map in ds_outputs[::-1]: + out_all.append(interpolate(feature_map, x.shape[2:])) + return torch.stack(out_all, dim=1) else: return x diff --git a/tests/test_mednext.py b/tests/test_mednext.py index b4ba4f9939..4c715d9282 100644 --- a/tests/test_mednext.py +++ b/tests/test_mednext.py @@ -75,8 +75,10 @@ def test_shape(self, input_param, input_shape, expected_shape): with eval_mode(net): result = net(torch.randn(input_shape).to(device)) if input_param["deep_supervision"] and net.training: - assert isinstance(result, tuple) - self.assertEqual(result[0].shape, expected_shape, msg=str(input_param)) + assert isinstance(result, torch.Tensor) + result = torch.unbind(result, dim=1) + for r in result: + self.assertEqual(r.shape, expected_shape, msg=str(input_param)) else: self.assertEqual(result.shape, expected_shape, msg=str(input_param)) @@ -87,8 +89,10 @@ def test_shape2(self, input_param, input_shape, expected_shape): net.train() result = net(torch.randn(input_shape).to(device)) if input_param["deep_supervision"]: - assert isinstance(result, tuple) - self.assertEqual(result[0].shape, expected_shape, msg=str(input_param)) + assert isinstance(result, torch.Tensor) + result = torch.unbind(result, dim=1) + for r in result: + self.assertEqual(r.shape, expected_shape, msg=str(input_param)) else: assert isinstance(result, torch.Tensor) self.assertEqual(result.shape, expected_shape, msg=str(input_param))