diff --git a/GANDLF/models/seg_modules/DownsamplingModule.py b/GANDLF/models/seg_modules/DownsamplingModule.py index ab4a54ff3..10eded967 100644 --- a/GANDLF/models/seg_modules/DownsamplingModule.py +++ b/GANDLF/models/seg_modules/DownsamplingModule.py @@ -39,7 +39,7 @@ def __init__( if act_kwargs is None: act_kwargs = {"negative_slope": 1e-2, "inplace": True} - self.in_0 = norm(output_channels, **norm_kwargs) + self.in_0 = norm(input_channels, **norm_kwargs) self.conv0 = conv(input_channels, output_channels, **conv_kwargs) @@ -49,7 +49,7 @@ def forward(self, x): """ Applies a downsampling operation to the input tensor. - [input -- > in --> lrelu --> ConvDS --> output] + [input --> in --> lrelu --> ConvDS --> output] Args: x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width) @@ -57,6 +57,6 @@ def forward(self, x): Returns: torch.Tensor: The output tensor, of shape (batch_size, output_channels, height // 2, width // 2). """ - x = self.act(self.in_0(self.conv0(x))) + x = self.conv0(self.act(self.in_0(x))) return x