diff --git a/ascent/models/nnunet_module.py b/ascent/models/nnunet_module.py index 1d3fddf..1f1cbe4 100644 --- a/ascent/models/nnunet_module.py +++ b/ascent/models/nnunet_module.py @@ -494,10 +494,14 @@ def predict( if len(image.shape) == 5: if len(self.patch_size) == 3: # Pad the last dimension to avoid 3D segmentation border artifacts + extra_pad = 0 + while image.shape[-1] <= 6: + image = pad(image, (1, 1, 0, 0, 0, 0), mode="reflect") + extra_pad += 1 image = pad(image, (6, 6, 0, 0, 0, 0), mode="reflect") pred = self.predict_3D_3Dconv_tiled(image, apply_softmax) # Inverse the padding after prediction - return pred[..., 6:-6] + return pred[..., (6 + extra_pad) : (-6 - extra_pad)] elif len(self.patch_size) == 2: return self.predict_3D_2Dconv_tiled(image, apply_softmax) else: