API compatibility of DynUNet and MedNeXt in deep supervision mode #8315

Open
elitap opened this issue Jan 28, 2025 · 4 comments · May be fixed by #8316

elitap (Contributor) commented Jan 28, 2025

Is your feature request related to a problem? Please describe.
If deep supervision is set to true, the DynUNet model returns a single tensor whose second dimension contains the upscaled intermediate outputs of the model, whereas the new MedNeXt implementation returns a tuple of tensors with the intermediate outputs at their original, lower resolutions.

Describe the solution you'd like
For compatibility purposes and ease of use, I would suggest adopting the DynUNet behavior in MedNeXt as well. Having the intermediate outputs all at the same output resolution makes them easier to use for loss calculations. In addition, DynUNet has been stable for some time, whereas MedNeXt has not.

Describe alternatives you've considered
Obviously, the DynUNet behavior could be changed to act identically to MedNeXt's, which might be slightly more memory efficient, but that would certainly break more existing code, and it is less convenient if references need to be downscaled for loss calculations.
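
For illustration, a minimal sketch of the two return conventions during training with deep supervision enabled. Shapes, the number of heads, and the ordering of the tuple are illustrative only and depend on the network configuration; this is not taken from the MONAI source.

    import torch

    # DynUNet convention: intermediate heads are upsampled to the final prediction's
    # size and stacked along dim 1, giving one tensor of shape (N, heads + 1, C, *spatial).
    dynunet_out = torch.randn(2, 3, 4, 64, 64, 64)  # final + 2 DS heads, 4 classes, 64^3 voxels

    # MedNeXt / SegResNetDS convention: a tuple (or list) of separate tensors, with the
    # intermediate outputs kept at their native, lower decoder resolutions.
    mednext_out = (
        torch.randn(2, 4, 64, 64, 64),
        torch.randn(2, 4, 32, 32, 32),
        torch.randn(2, 4, 16, 16, 16),
    )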

elitap linked a pull request on Jan 28, 2025 that will close this issue (#8316).
surajpaib (Contributor) commented:

The tuple behavior also exists for SegResNetDS, where deep supervision is available. The MedNeXt implementation follows it because it was developed to be consistent with SegResNetDS.

Only the tuple option seems to make sense to me here. While upscaling is an option, I don't think it should be the default. What if the developer requires the intermediate maps to be operated on in their original dimensions? And as you pointed out, this is the more memory-efficient option too.

I do agree that there should be consistency. Given that, should DynUNet be ported to the SegResNetDS and MedNeXt behaviour?

elitap (Contributor, Author) commented Jan 28, 2025

Hm, I didn't know about the SegResNetDS. I only ran into this issue because I was replacing a DynUNet with the MedNeXt (btw, thanks for the implementation :) ).

I agree the tuple version makes more sense, and the user should decide how to handle the different scales. However, I found the DynUNet version quite handy to use in loss calls, and I didn't want to break too much existing code. I think there is at least one tutorial using the DynUNet somewhere around. But I will look into adapting the DynUNet tomorrow.

The authors of the DynUNet cite "a restriction of TorchScript" in their argument description as the reason for the tensor version with an additional dimension. Do you know what they are referring to?
And out of interest, what does your code for the loss calculation look like when using the tuple version? Do you upscale the output or downscale the reference?

surajpaib (Contributor) commented:

I use the loss here, which does upsample as you mentioned

from __future__ import annotations

from typing import Union

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.utils import pytorch_after


class DeepSupervisionLoss(_Loss):
    """
    Wrapper class around the main loss function to accept a list of tensors returned from a deeply
    supervised networks. The final loss is computed as the sum of weighted losses for each of deep supervision levels.
    """

    def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: list[float] | None = None) -> None:
        """
        Args:
            loss: main loss instance, e.g DiceLoss().
            weight_mode: {``"same"``, ``"exp"``, ``"two"``}
                Specifies the weights calculation for each image level. Defaults to ``"exp"``.
                - ``"same"``: all weights are equal to 1.
                - ``"exp"``: exponentially decreasing weights by a power of 2: 1, 0.5, 0.25, 0.125, etc.
                - ``"two"``: equal smaller weights for lower levels: 1, 0.5, 0.5, 0.5, 0.5, etc.
            weights: a list of weights to apply to each deeply supervised sub-loss, if provided, this will be used
                regardless of the weight_mode
        """
        super().__init__()
        self.loss = loss
        self.weight_mode = weight_mode
        self.weights = weights
        self.interp_mode = "nearest-exact" if pytorch_after(1, 11) else "nearest"

    def get_weights(self, levels: int = 1) -> list[float]:
        """
        Calculates weights for a given number of scale levels.
        """
        levels = max(1, levels)
        if self.weights is not None and len(self.weights) >= levels:
            weights = self.weights[:levels]
        elif self.weight_mode == "same":
            weights = [1.0] * levels
        elif self.weight_mode == "exp":
            weights = [max(0.5**l, 0.0625) for l in range(levels)]
        elif self.weight_mode == "two":
            weights = [1.0 if l == 0 else 0.5 for l in range(levels)]
        else:
            weights = [1.0] * levels
        return weights

    def get_loss(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Calculates a loss output accounting for differences in shapes,
        downsizing targets if necessary (using nearest neighbor interpolation).
        Generally downsizing occurs for all levels, except for the first (level==0).
        """
        if input.shape[2:] != target.shape[2:]:
            target = F.interpolate(target, size=input.shape[2:], mode=self.interp_mode)
        return self.loss(input, target)  # type: ignore[no-any-return]

    def forward(self, input: Union[None, torch.Tensor, list[torch.Tensor]], target: torch.Tensor) -> torch.Tensor:
        if isinstance(input, (list, tuple)):
            weights = self.get_weights(levels=len(input))
            loss = torch.tensor(0, dtype=torch.float, device=target.device)
            for l in range(len(input)):
                loss += weights[l] * self.get_loss(input[l].float(), target)
            return loss
        if input is None:
            raise ValueError("input shouldn't be None.")
        return self.loss(input.float(), target)  # type: ignore[no-any-return]

While the end result is upsampling of the network intermediates, I do think this should be performed within the loss calculation module as opposed to the network itself.
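
For reference, a usage sketch of the wrapper above with a tuple-style network output. Shapes are illustrative, and DiceLoss and DeepSupervisionLoss are assumed to be importable from monai.losses as in recent MONAI releases.

    import torch
    from monai.losses import DeepSupervisionLoss, DiceLoss

    # Wrap the main loss; lower levels get exponentially decreasing weights.
    criterion = DeepSupervisionLoss(DiceLoss(sigmoid=True), weight_mode="exp")

    # Tuple-style deep-supervision output at decreasing resolutions (illustrative).
    outputs = [
        torch.randn(2, 1, 64, 64, 64),
        torch.randn(2, 1, 32, 32, 32),
        torch.randn(2, 1, 16, 16, 16),
    ]
    target = torch.randint(0, 2, (2, 1, 64, 64, 64)).float()

    # get_loss() resamples the target to each output's spatial size (nearest neighbour)
    # before calling the wrapped loss, and forward() sums the weighted per-level losses.
    loss = criterion(outputs, target)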

For the TorchScript question, I am not very familiar with the DynUNet structure. Could you point me to where this is mentioned?

elitap (Contributor, Author) commented Jan 29, 2025

It's mentioned here:

In order to unify the return type (the restriction of TorchScript), all intermediate
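
For context, this is the convention the quote refers to: in training mode with deep supervision, DynUNet returns the intermediate heads upsampled to the final prediction's size and stacked into one tensor, which is then typically unbound along dim 1 for the loss. A rough sketch, with illustrative shapes and weights:

    import torch

    # DynUNet-style stacked output: (N, num_ds_heads + 1, C, H, W, D)
    stacked = torch.randn(2, 3, 4, 64, 64, 64)

    # Recover the individual full-resolution predictions and weight each level
    # (weights are illustrative; criterion and target would come from the training loop).
    preds = torch.unbind(stacked, dim=1)
    weights = [1.0, 0.5, 0.25]
    # loss = sum(w * criterion(p, target) for w, p in zip(weights, preds))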
