From 3d74ee3ba9ce6698c4532ee114329c03a59f9ce5 Mon Sep 17 00:00:00 2001
From: georgeyiasemis
Date: Fri, 19 Jan 2024 08:26:23 +0100
Subject: [PATCH] Remove cwn conv from vsharp and unet. Add mod conv

---
 direct/nn/conv/modulated_conv.py | 326 ++++++++++++++++++++++++++++
 direct/nn/get_nn_model_config.py |   4 +-
 direct/nn/unet/config.py         |  11 +-
 direct/nn/unet/unet_2d.py        | 356 ++++++++++++++++++-------------
 direct/nn/unet/unet_3d.py        |  78 +------
 direct/nn/varsplitnet/config.py  |   1 -
 direct/nn/vsharp/config.py       |   5 +-
 direct/nn/vsharp/vsharp.py       |  85 ++++++--
 direct/types.py                  |   1 +
 9 files changed, 620 insertions(+), 247 deletions(-)
 create mode 100644 direct/nn/conv/modulated_conv.py

diff --git a/direct/nn/conv/modulated_conv.py b/direct/nn/conv/modulated_conv.py
new file mode 100644
index 00000000..31ab712b
--- /dev/null
+++ b/direct/nn/conv/modulated_conv.py
@@ -0,0 +1,326 @@
+# Copyright (c) DIRECT Contributors
+
+"""direct.nn.conv.modulated_conv module"""
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from direct.types import DirectEnum, IntOrTuple
+
+__all__ = ["ModConv2d", "ModConv2dBias", "ModConvTranspose2d"]
+
+
+class ModConv2dBias(DirectEnum):
+    LEARNED = "learned"
+    PARAM = "param"
+    NONE = "none"
+
+
+class ModConv2d(nn.Module):
+    r"""Modulated Conv2d module.
+
+    If `modulation`=False and `bias`=ModConv2dBias.PARAM, this is identical to nn.Conv2d:
+
+    .. math::
+
+        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
+        \sum_{k=0}^{C_{\text{in}}-1} \text{weight}(C_{\text{out}_j}, k) * \text{input}(N_i, k)
+
+
+
+    If `modulation`=True, this will compute:
+
+    .. math::
+
+        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
+        \sum_{k=0}^{C_{\text{in}}-1} \text{MLP}(y(N_i))(C_{\text{out}_j}, k) \text{weight}(C_{\text{out}_j}, k)
+        * \text{input}(N_i, k),
+
+    where :math:`*` is a 2D cross-correlation.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: IntOrTuple,
+        modulation: bool = False,
+        stride: IntOrTuple = 1,
+        padding: IntOrTuple = 0,
+        dilation: IntOrTuple = 1,
+        bias: ModConv2dBias = ModConv2dBias.PARAM,
+        aux_in_features: Optional[int] = None,
+        fc_hidden_features: Optional[int] = None,
+        fc_bias: Optional[bool] = True,
+    ):
+        """Inits :class:`ModConv2d`.
+
+        Parameters
+        ----------
+        in_channels : int
+            Number of input channels.
+        out_channels : int
+            Number of output channels.
+        kernel_size : int or tuple of int
+            Size of the convolutional kernel.
+        modulation : bool, optional
+            If True, apply modulation using an MLP on the auxiliary variable `y`, by default False.
+        stride : int or tuple of int, optional
+            Stride of the convolution, by default 1.
+        padding : int or tuple of int, optional
+            Padding added to all sides of the input, by default 0.
+        dilation : int or tuple of int, optional
+            Spacing between kernel elements, by default 1.
+        bias : ModConv2dBias, optional
+            Type of bias, by default ModConv2dBias.PARAM.
+        aux_in_features : int, optional
+            Number of features in the auxiliary input variable `y`, by default None.
+        fc_hidden_features : int, optional
+            Number of hidden features in the modulation MLP, by default None.
+        fc_bias : bool, optional
+            If True, enable bias in the modulation MLP, by default True.
+        """
+        super().__init__()
+
+        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+
+        k = math.sqrt(1 / (in_channels * self.kernel_size[0] * self.kernel_size[1]))
+        self.weight = nn.Parameter(
+            torch.FloatTensor(out_channels, in_channels, self.kernel_size[0], self.kernel_size[1]).uniform_(-k, k)
+        )
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.modulation = modulation
+        self.aux_in_features = aux_in_features
+        self.fc_hidden_features = fc_hidden_features
+        self.fc_bias = fc_bias
+
+        if modulation:
+            if aux_in_features is None:
+                raise ValueError("Value for `aux_in_features` can't be None with `modulation`=True.")
+            if fc_hidden_features is None:
+                raise ValueError("Value for `fc_hidden_features` can't be None with `modulation`=True.")
+            self.fc = nn.Sequential(
+                nn.Linear(aux_in_features, fc_hidden_features, bias=fc_bias),
+                nn.PReLU(),
+                nn.Linear(fc_hidden_features, in_channels * out_channels, bias=fc_bias),
+            )
+
+        if bias == ModConv2dBias.PARAM:
+            self.bias = nn.Parameter(torch.FloatTensor(out_channels).uniform_(-k, k))
+        elif bias == ModConv2dBias.LEARNED:
+            if not modulation:
+                raise ValueError(
+                    "Bias can only be set to ModConv2dBias.LEARNED if `modulation`=True, "
+                    "but modulation is set to False."
+                )
+            self.bias = nn.Sequential(
+                nn.Linear(aux_in_features, fc_hidden_features, bias=fc_bias),
+                nn.PReLU(),
+                nn.Linear(fc_hidden_features, out_channels, bias=fc_bias),
+            )
+        else:
+            self.bias = None
+
+    def __repr__(self):
+        """Representation of :class:`ModConv2d`."""
+        return (
+            f"ModConv2d(in_channels={self.in_channels}, out_channels={self.out_channels}, "
+            f"kernel_size={self.kernel_size}, modulation={self.modulation}, "
+            f"stride={self.stride}, padding={self.padding}, "
+            f"dilation={self.dilation}, bias={self.bias}, aux_in_features={self.aux_in_features}, "
+            f"fc_hidden_features={self.fc_hidden_features}, fc_bias={self.fc_bias})"
+        )
+
+    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Forward pass of :class:`ModConv2d`.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (N, `in_channels`, H, W).
+        y : torch.Tensor, optional
+            Auxiliary variable of shape (N, `aux_in_features`) to be used if `modulation` is set to True. Default: None
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor of shape (N, `out_channels`, H_out, W_out).
+        """
+        if not self.modulation:
+            out = F.conv2d(x, self.weight, stride=self.stride, padding=self.padding, dilation=self.dilation)
+        else:
+            fc_out = self.fc(y).view(x.shape[0], self.out_channels, self.in_channels, 1, 1)
+            out = torch.cat(
+                [
+                    F.conv2d(
+                        x[i : i + 1],
+                        fc_out[i] * self.weight,
+                        stride=self.stride,
+                        padding=self.padding,
+                        dilation=self.dilation,
+                    )
+                    for i in range(x.shape[0])
+                ],
+                0,
+            )
+
+        if self.bias is not None:
+            if isinstance(self.bias, nn.parameter.Parameter):
+                bias = self.bias.view(1, -1, 1, 1)
+            else:
+                bias = self.bias(y).view(x.shape[0], -1, 1, 1)
+            out = out + bias
+
+        return out
+
+
+class ModConvTranspose2d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: IntOrTuple,
+        modulation: bool = False,
+        stride: IntOrTuple = 1,
+        padding: IntOrTuple = 0,
+        dilation: IntOrTuple = 1,
+        bias: ModConv2dBias = ModConv2dBias.PARAM,
+        aux_in_features: Optional[int] = None,
+        fc_hidden_features: Optional[int] = None,
+        fc_bias: Optional[bool] = True,
+    ):
+        """Inits :class:`ModConvTranspose2d`.
+
+        Parameters
+        ----------
+        in_channels : int
+            Number of input channels.
+        out_channels : int
+            Number of output channels.
+        kernel_size : int or tuple of int
+            Size of the convolutional kernel.
+        modulation : bool, optional
+            If True, apply modulation using an MLP on the auxiliary variable `y`, by default False.
+        stride : int or tuple of int, optional
+            Stride of the convolution, by default 1.
+        padding : int or tuple of int, optional
+            Padding added to all sides of the input, by default 0.
+        dilation : int or tuple of int, optional
+            Spacing between kernel elements, by default 1.
+        bias : ModConv2dBias, optional
+            Type of bias, by default ModConv2dBias.PARAM.
+        aux_in_features : int, optional
+            Number of features in the auxiliary input variable `y`, by default None.
+        fc_hidden_features : int, optional
+            Number of hidden features in the modulation MLP, by default None.
+        fc_bias : bool, optional
+            If True, enable bias in the modulation MLP, by default True.
+        """
+        super().__init__()
+
+        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+
+        k = math.sqrt(1 / (in_channels * self.kernel_size[0] * self.kernel_size[1]))
+        self.weight = nn.Parameter(
+            torch.FloatTensor(in_channels, out_channels, self.kernel_size[0], self.kernel_size[1]).uniform_(-k, k)
+        )
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.modulation = modulation
+        self.aux_in_features = aux_in_features
+        self.fc_hidden_features = fc_hidden_features
+        self.fc_bias = fc_bias
+
+        if modulation:
+            if aux_in_features is None:
+                raise ValueError("Value for `aux_in_features` can't be None with `modulation`=True.")
+            if fc_hidden_features is None:
+                raise ValueError("Value for `fc_hidden_features` can't be None with `modulation`=True.")
+            self.fc = nn.Sequential(
+                nn.Linear(aux_in_features, fc_hidden_features, bias=fc_bias),
+                nn.PReLU(),
+                nn.Linear(fc_hidden_features, in_channels * out_channels, bias=fc_bias),
+            )
+
+        if bias == ModConv2dBias.PARAM:
+            self.bias = nn.Parameter(torch.FloatTensor(out_channels).uniform_(-k, k))
+        elif bias == ModConv2dBias.LEARNED:
+            if not modulation:
+                raise ValueError(
+                    "Bias can only be set to ModConv2dBias.LEARNED if `modulation`=True, "
+                    "but modulation is set to False."
+ ) + self.bias = nn.Sequential( + nn.Linear(aux_in_features, fc_hidden_features, bias=fc_bias), + nn.PReLU(), + nn.Linear(fc_hidden_features, out_channels, bias=fc_bias), + ) + else: + self.bias = None + + def __repr__(self): + """Representation of "class:`ModConvTranspose2d`.""" + return ( + f"ModConvTranspose2d(in_channels={self.in_channels}, out_channels={self.out_channels}, " + f"kernel_size={self.kernel_size}, modulation={self.modulation}, " + f"stride={self.stride}, padding={self.padding}, " + f"dilation={self.dilation}, bias={self.bias}, aux_in_features={self.aux_in_features}, " + f"fc_hidden_features={self.fc_hidden_features}, fc_bias={self.fc_bias})" + ) + + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass of :class:`ModConvTranspose2d`. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (N, `in_channels`, H, W). + y : torch.Tensor, optional + Auxiliary variable of shape (N, `aux_in_features`) to be used if `modulation` is set to True. Default: None + + Returns + ------- + torch.Tensor + Output tensor of shape (N, `out_channels`, H_out, W_out). + """ + if not self.modulation: + out = F.conv_transpose2d(x, self.weight, stride=self.stride, padding=self.padding, dilation=self.dilation) + else: + fc_out = self.fc(y).view(x.shape[0], self.in_channels, self.out_channels, 1, 1) + out = torch.cat( + [ + F.conv_transpose2d( + x[i : i + 1], + fc_out[i] * self.weight, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ) + for i in range(x.shape[0]) + ], + 0, + ) + + if self.bias is not None: + if isinstance(self.bias, nn.parameter.Parameter): + bias = self.bias.view(1, -1, 1, 1) + else: + bias = self.bias(y).view(x.shape[0], -1, 1, 1) + out = out + bias + + return out diff --git a/direct/nn/get_nn_model_config.py b/direct/nn/get_nn_model_config.py index 69257699..f2b74977 100644 --- a/direct/nn/get_nn_model_config.py +++ b/direct/nn/get_nn_model_config.py @@ -42,7 +42,9 @@ def _get_model_config( "num_filters": kwargs.get("unet_num_filters", 32), "num_pool_layers": kwargs.get("unet_num_pool_layers", 4), "dropout_probability": kwargs.get("unet_dropout", 0.0), - "cwn_conv": kwargs.get("cwn_conv", False), + "modulation": kwargs.get("modulation", False), + "aux_in_features": kwargs.get("aux_in_features", None), + "fc_hidden_features": kwargs.get("fc_hidden_features", None), } ) elif model_architecture_name == "resnet": diff --git a/direct/nn/unet/config.py b/direct/nn/unet/config.py index e7309d84..9098f469 100644 --- a/direct/nn/unet/config.py +++ b/direct/nn/unet/config.py @@ -12,7 +12,9 @@ class UnetModel2dConfig(ModelConfig): num_filters: int = 16 num_pool_layers: int = 4 dropout_probability: float = 0.0 - cwn_conv: bool = False + modulation: bool = False + aux_in_features: Optional[int] = None + fc_hidden_features: Optional[int] = None class NormUnetModel2dConfig(ModelConfig): @@ -22,7 +24,9 @@ class NormUnetModel2dConfig(ModelConfig): num_pool_layers: int = 4 dropout_probability: float = 0.0 norm_groups: int = 2 - cwn_conv: bool = False + modulation: bool = False + aux_in_features: Optional[int] = None + fc_hidden_features: Optional[int] = None @dataclass @@ -32,7 +36,6 @@ class UnetModel3dConfig(ModelConfig): num_filters: int = 16 num_pool_layers: int = 4 dropout_probability: float = 0.0 - cwn_conv: bool = False class NormUnetModel3dConfig(ModelConfig): @@ -42,7 +45,6 @@ class NormUnetModel3dConfig(ModelConfig): num_pool_layers: int = 4 dropout_probability: float = 0.0 norm_groups: int = 2 - 
cwn_conv: bool = False @dataclass @@ -50,7 +52,6 @@ class Unet2dConfig(ModelConfig): num_filters: int = 16 num_pool_layers: int = 4 dropout_probability: float = 0.0 - cwn_conv: bool = False skip_connection: bool = False normalized: bool = False image_initialization: str = "zero_filled" diff --git a/direct/nn/unet/unet_2d.py b/direct/nn/unet/unet_2d.py index d5d78382..10a5faa7 100644 --- a/direct/nn/unet/unet_2d.py +++ b/direct/nn/unet/unet_2d.py @@ -10,118 +10,101 @@ from torch.nn import functional as F from direct.data import transforms as T -from direct.nn.conv.conv import CWNConv2d, CWNConvTranspose2d +from direct.nn.conv.modulated_conv import ModConv2d, ModConv2dBias, ModConvTranspose2d -class ConvBlock(nn.Module): - """U-Net convolutional block. - - It consists of two convolution layers each followed by instance normalization, LeakyReLU activation and dropout. - """ - - def __init__(self, in_channels: int, out_channels: int, dropout_probability: float): - """Inits ConvBlock. +class ConvModule(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + padding: int, + dropout_probability: float, + modulation: bool = False, + bias: ModConv2dBias = ModConv2dBias.PARAM, + aux_in_features: Optional[int] = None, + fc_hidden_features: Optional[int] = None, + ): + """Inits :class:`ConvModule`. Parameters ---------- - in_channels: int + in_channels : int Number of input channels. - out_channels: int + out_channels : int Number of output channels. + kernel_size : int + Size of the convolutional kernel. + padding : int + Padding added to all sides of the input. dropout_probability: float Dropout probability. + modulation : bool, optional + If True, apply modulation using an MLP on the auxiliary variable `y`, by default False. + bias : ModConv2dBias, optional + Type of bias, by default ModConv2dBias.PARAM. + aux_in_features : int, optional + Number of features in the auxiliary input variable `y`, by default None. + fc_hidden_features : int, optional + Number of hidden features in the modulation MLP, by default None. """ super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.dropout_probability = dropout_probability - - self.layers = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.InstanceNorm2d(out_channels), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Dropout2d(dropout_probability), - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.InstanceNorm2d(out_channels), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Dropout2d(dropout_probability), - ) - - def forward(self, input_data: torch.Tensor) -> torch.Tensor: - """Performs the forward pass of :class:`ConvBlock`. - - Parameters - ---------- - input_data: torch.Tensor - - Returns - ------- - torch.Tensor - """ - return self.layers(input_data) + self.modulation = modulation - def __repr__(self): - """Representation of :class:`ConvBlock`.""" - return ( - f"ConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels}, " - f"dropout_probability={self.dropout_probability})" - ) - - -class TransposeConvBlock(nn.Module): - """U-Net Transpose Convolutional Block. - - It consists of one convolution transpose layers followed by instance normalization and LeakyReLU activation. - """ - - def __init__(self, in_channels: int, out_channels: int): - """Inits :class:`TransposeConvBlock`. - - Parameters - ---------- - in_channels: int - Number of input channels. 
- out_channels: int - Number of output channels. - """ - super().__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - - self.layers = nn.Sequential( - nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, bias=False), - nn.InstanceNorm2d(out_channels), - nn.LeakyReLU(negative_slope=0.2, inplace=True), + self.conv = ModConv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=bias, + modulation=modulation, + aux_in_features=aux_in_features, + fc_hidden_features=fc_hidden_features, ) + self.instance_norm = nn.InstanceNorm2d(out_channels) + self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.dropout = nn.Dropout2d(dropout_probability) - def forward(self, input_data: torch.Tensor) -> torch.Tensor: - """Performs forward pass of :class:`TransposeConvBlock`. + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs the forward pass of :class:`ConvModule`. Parameters ---------- - input_data: torch.Tensor + x : torch.Tensor + y : torch.Tensor Returns ------- torch.Tensor """ - return self.layers(input_data) - - def __repr__(self): - """Representation of "class:`TransposeConvBlock`.""" - return f"ConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels})" + if self.modulation: + x = self.conv(x, y) + else: + x = self.conv(x) + x = self.instance_norm(x) + x = self.leaky_relu(x) + x = self.dropout(x) + return x -class CWNConvBlock(nn.Module): +class ConvBlock(nn.Module): """U-Net convolutional block. It consists of two convolution layers each followed by instance normalization, LeakyReLU activation and dropout. """ - def __init__(self, in_channels: int, out_channels: int, dropout_probability: float): - """Inits ConvBlock. + def __init__( + self, + in_channels: int, + out_channels: int, + dropout_probability: float, + modulation: bool = False, + aux_in_features: Optional[int] = None, + fc_hidden_features: Optional[int] = None, + ): + """Inits :class:`ConvBlock`. Parameters ---------- @@ -131,6 +114,8 @@ def __init__(self, in_channels: int, out_channels: int, dropout_probability: flo Number of output channels. dropout_probability: float Dropout probability. + modulation : bool + If True modulated convolutions will be used. Default: False. 
""" super().__init__() @@ -138,45 +123,66 @@ def __init__(self, in_channels: int, out_channels: int, dropout_probability: flo self.out_channels = out_channels self.dropout_probability = dropout_probability - self.layers = nn.Sequential( - CWNConv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.InstanceNorm2d(out_channels), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Dropout2d(dropout_probability), - CWNConv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.InstanceNorm2d(out_channels), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Dropout2d(dropout_probability), - ) + self.modulation = modulation + self.aux_in_features = aux_in_features + self.fc_hidden_features = fc_hidden_features + + self.layer_1, self.layer_2 = [ + ConvModule( + in_channels if i == 0 else out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=ModConv2dBias.NONE if not self.modulation else ModConv2dBias.LEARNED, + dropout_probability=dropout_probability, + modulation=modulation, + aux_in_features=aux_in_features, + fc_hidden_features=fc_hidden_features, + ) + for i in range(2) + ] - def forward(self, input_data: torch.Tensor) -> torch.Tensor: + def forward(self, input_data: torch.Tensor, aux_data: Optional[torch.Tensor] = None) -> torch.Tensor: """Performs the forward pass of :class:`ConvBlock`. Parameters ---------- - input_data: torch.Tensor + input_data : torch.Tensor + aux_data : torch.Tensor Returns ------- torch.Tensor """ - return self.layers(input_data) + if self.modulation: + out = self.layer_2(self.layer_1(input_data, aux_data), aux_data) + else: + out = self.layer_2(self.layer_1(input_data)) + return out def __repr__(self): """Representation of :class:`ConvBlock`.""" return ( - f"CWNConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels}, " - f"dropout_probability={self.dropout_probability})" + f"ConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels}, " + f"dropout_probability={self.dropout_probability}, modulation={self.modulation}, " + f"aux_in_features={self.aux_in_features}, fc_hidden_features={self.fc_hidden_features})" ) -class CWNTransposeConvBlock(nn.Module): +class TransposeConvBlock(nn.Module): """U-Net Transpose Convolutional Block. It consists of one convolution transpose layers followed by instance normalization and LeakyReLU activation. """ - def __init__(self, in_channels: int, out_channels: int): + def __init__( + self, + in_channels: int, + out_channels: int, + modulation: bool = False, + aux_in_features: Optional[int] = None, + fc_hidden_features: Optional[int] = None, + ): """Inits :class:`TransposeConvBlock`. 
Parameters @@ -191,28 +197,49 @@ def __init__(self, in_channels: int, out_channels: int): self.in_channels = in_channels self.out_channels = out_channels - self.layers = nn.Sequential( - CWNConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, bias=False), - nn.InstanceNorm2d(out_channels), - nn.LeakyReLU(negative_slope=0.2, inplace=True), + self.modulation = modulation + self.aux_in_features = aux_in_features + self.fc_hidden_features = fc_hidden_features + + self.conv = ModConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2, + stride=2, + bias=ModConv2dBias.NONE if not self.modulation else ModConv2dBias.LEARNED, + modulation=modulation, + aux_in_features=aux_in_features, + fc_hidden_features=fc_hidden_features, ) - def forward(self, input_data: torch.Tensor) -> torch.Tensor: - """Performs forward pass of :class:`TransposeConvBlock`. + self.instance_norm = nn.InstanceNorm2d(out_channels) + self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, input_data: torch.Tensor, aux_data: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs the forward pass of :class:`TransposeConvBlock`. Parameters ---------- - input_data: torch.Tensor + input_data : torch.Tensor + aux_data : torch.Tensor Returns ------- torch.Tensor """ - return self.layers(input_data) + if self.modulation: + out = self.conv(input_data, aux_data) + else: + out = self.conv(input_data) + return self.leaky_relu(self.instance_norm(out)) def __repr__(self): """Representation of "class:`TransposeConvBlock`.""" - return f"CWNConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels})" + return ( + f"TransposeConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels}, " + f"modulation={self.modulation}, aux_in_features={self.aux_in_features}, " + f"fc_hidden_features={self.fc_hidden_features})" + ) class UnetModel2d(nn.Module): @@ -231,7 +258,9 @@ def __init__( num_filters: int, num_pool_layers: int, dropout_probability: float, - cwn_conv: bool = False, + modulation: bool = False, + aux_in_features: Optional[int] = None, + fc_hidden_features: Optional[int] = None, ): """Inits :class:`UnetModel2d`. @@ -247,8 +276,6 @@ def __init__( Number of down-sampling and up-sampling layers (depth). dropout_probability: float Dropout probability. - cwn_conv : bool - Apply centered weigh normalization to convolutions. Default: False. 
""" super().__init__() @@ -257,62 +284,82 @@ def __init__( self.num_filters = num_filters self.num_pool_layers = num_pool_layers self.dropout_probability = dropout_probability + self.modulation = modulation - if cwn_conv: - conv_block = CWNConvBlock - transpose_conv_block = CWNTransposeConvBlock - else: - conv_block = ConvBlock - transpose_conv_block = TransposeConvBlock - - self.down_sample_layers = nn.ModuleList([conv_block(in_channels, num_filters, dropout_probability)]) + self.down_sample_layers = nn.ModuleList( + [ConvBlock(in_channels, num_filters, dropout_probability, modulation, aux_in_features, fc_hidden_features)] + ) ch = num_filters for _ in range(num_pool_layers - 1): - self.down_sample_layers += [conv_block(ch, ch * 2, dropout_probability)] + self.down_sample_layers += [ + ConvBlock(ch, ch * 2, dropout_probability, modulation, aux_in_features, fc_hidden_features) + ] ch *= 2 - self.conv = conv_block(ch, ch * 2, dropout_probability) + self.conv = ConvBlock(ch, ch * 2, dropout_probability, modulation, aux_in_features, fc_hidden_features) self.up_conv = nn.ModuleList() self.up_transpose_conv = nn.ModuleList() for _ in range(num_pool_layers - 1): - self.up_transpose_conv += [transpose_conv_block(ch * 2, ch)] - self.up_conv += [conv_block(ch * 2, ch, dropout_probability)] + self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch, modulation, aux_in_features, fc_hidden_features)] + self.up_conv += [ + ConvBlock(ch * 2, ch, dropout_probability, modulation, aux_in_features, fc_hidden_features) + ] ch //= 2 - self.up_transpose_conv += [transpose_conv_block(ch * 2, ch)] - self.up_conv += [ - nn.Sequential( - conv_block(ch * 2, ch, dropout_probability), - (CWNConv2d if cwn_conv else nn.Conv2d)(ch, self.out_channels, kernel_size=1, stride=1), - ) - ] + self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch, modulation, aux_in_features, fc_hidden_features)] + self.up_conv += [ConvBlock(ch * 2, ch, dropout_probability, modulation, aux_in_features, fc_hidden_features)] + + self.conv_out = ModConv2d( + ch, + self.out_channels, + kernel_size=1, + stride=1, + modulation=modulation, + bias=ModConv2dBias.NONE if not self.modulation else ModConv2dBias.LEARNED, + aux_in_features=aux_in_features, + fc_hidden_features=fc_hidden_features, + ) - def forward(self, input_data: torch.Tensor) -> torch.Tensor: + def forward(self, input_data: torch.Tensor, aux_data: Optional[torch.Tensor] = None) -> torch.Tensor: """Performs forward pass of :class:`UnetModel2d`. Parameters ---------- - input_data: torch.Tensor + input_data : torch.Tensor + Input data tensor of shape (N, `in_channels`, H, W). + aux_data : torch.Tensor, optional + Auxiliary data tensor of shape (N, `aux_in_features`) to be used if `modulation` is set to True. + Default: None Returns ------- torch.Tensor + Output data tensor of shape (N, `out_channels`, H, W). 
""" stack = [] output = input_data # Apply down-sampling layers for _, layer in enumerate(self.down_sample_layers): - output = layer(output) + if self.modulation: + output = layer(output, aux_data) + else: + output = layer(output) stack.append(output) output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) - output = self.conv(output) + if self.modulation: + output = self.conv(output, aux_data) + else: + output = self.conv(output) # Apply up-sampling layers for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): downsample_layer = stack.pop() - output = transpose_conv(output) + if self.modulation: + output = transpose_conv(output, aux_data) + else: + output = transpose_conv(output) # Reflect pad on the right/bottom if needed to handle odd input dimensions. padding = [0, 0, 0, 0] @@ -324,7 +371,15 @@ def forward(self, input_data: torch.Tensor) -> torch.Tensor: output = F.pad(output, padding, "reflect") output = torch.cat([output, downsample_layer], dim=1) - output = conv(output) + if self.modulation: + output = conv(output, aux_data) + else: + output = conv(output) + + if self.modulation: + output = self.conv_out(output, aux_data) + else: + output = self.conv_out(output) return output @@ -340,7 +395,9 @@ def __init__( num_pool_layers: int, dropout_probability: float, norm_groups: int = 2, - cwn_conv: bool = False, + modulation: bool = False, + aux_in_features: Optional[int] = None, + fc_hidden_features: Optional[int] = None, ): """Inits :class:`NormUnetModel2d`. @@ -356,8 +413,6 @@ def __init__( Number of down-sampling and up-sampling layers (depth). dropout_probability: float Dropout probability. - cwn_conv : bool - Apply centered weigh normalization to convolutions. Default: False. norm_groups: int, Number of normalization groups. """ @@ -369,9 +424,11 @@ def __init__( num_filters=num_filters, num_pool_layers=num_pool_layers, dropout_probability=dropout_probability, - cwn_conv=cwn_conv, + modulation=modulation, + aux_in_features=aux_in_features, + fc_hidden_features=fc_hidden_features, ) - + self.modulation = modulation self.norm_groups = norm_groups @staticmethod @@ -416,21 +473,27 @@ def unpad( ) -> torch.Tensor: return input_data[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]] - def forward(self, input_data: torch.Tensor) -> torch.Tensor: + def forward(self, input_data: torch.Tensor, aux_data: Optional[torch.Tensor] = None) -> torch.Tensor: """Performs forward pass of :class:`NormUnetModel2d`. Parameters ---------- - input_data: torch.Tensor + input_data : torch.Tensor + Input data tensor of shape (N, `in_channels`, H, W). + aux_data : torch.Tensor, optional + Auxiliary data tensor of shape (N, `aux_in_features`) to be used if `modulation` is set to True. + Default: None Returns ------- torch.Tensor + Output data tensor of shape (N, `out_channels`, H, W). """ output, mean, std = self.norm(input_data, self.norm_groups) output, pad_sizes = self.pad(output) - output = self.unet2d(output) + + output = self.unet2d(output, aux_data) output = self.unpad(output, *pad_sizes) output = self.unnorm(output, mean, std, self.norm_groups) @@ -448,7 +511,6 @@ def __init__( num_filters: int, num_pool_layers: int, dropout_probability: float, - cwn_conv: bool = False, skip_connection: bool = False, normalized: bool = False, image_initialization: str = "zero_filled", @@ -468,8 +530,6 @@ def __init__( Number of pooling layers. dropout_probability: float Dropout probability. - cwn_conv : bool - Apply centered weigh normalization to convolutions. Default: False. 
skip_connection: bool If True, skip connection is used for the output. Default: False. normalized: bool @@ -494,7 +554,6 @@ def __init__( num_filters=num_filters, num_pool_layers=num_pool_layers, dropout_probability=dropout_probability, - cwn_conv=cwn_conv, ) else: self.unet = UnetModel2d( @@ -503,7 +562,6 @@ def __init__( num_filters=num_filters, num_pool_layers=num_pool_layers, dropout_probability=dropout_probability, - cwn_conv=cwn_conv, ) self.forward_operator = forward_operator self.backward_operator = backward_operator diff --git a/direct/nn/unet/unet_3d.py b/direct/nn/unet/unet_3d.py index 5578e71d..7061e352 100644 --- a/direct/nn/unet/unet_3d.py +++ b/direct/nn/unet/unet_3d.py @@ -5,8 +5,6 @@ from torch import nn from torch.nn import functional as F -from direct.nn.conv.conv import CWNConv3d, CWNConvTranspose3d - class ConvBlock3D(nn.Module): """3D U-Net convolutional block.""" @@ -52,56 +50,6 @@ def forward(self, input_data: torch.Tensor) -> torch.Tensor: return self.layers(input_data) -class CWNConvBlock3D(nn.Module): - """U-Net convolutional block for 3D data.""" - - def __init__(self, in_channels: int, out_channels: int, dropout_probability: float): - super().__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - self.dropout_probability = dropout_probability - - self.layers = nn.Sequential( - CWNConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.InstanceNorm3d(out_channels), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Dropout3d(dropout_probability), - CWNConv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.InstanceNorm3d(out_channels), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Dropout3d(dropout_probability), - ) - - def forward(self, input_data: torch.Tensor) -> torch.Tensor: - return self.layers(input_data) - - def __repr__(self): - return ( - f"CWNConvBlock3D(in_channels={self.in_channels}, out_channels={self.out_channels}, " - f"dropout_probability={self.dropout_probability})" - ) - - -class CWNTransposeConvBlock3D(nn.Module): - """U-Net Transpose Convolutional Block for 3D data.""" - - def __init__(self, in_channels: int, out_channels: int): - super().__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - - self.layers = nn.Sequential( - CWNConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2, bias=False), - nn.InstanceNorm3d(out_channels), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - ) - - def forward(self, input_data: torch.Tensor) -> torch.Tensor: - return self.layers(input_data) - - class UnetModel3d(nn.Module): """PyTorch implementation of a 3D U-Net model.""" @@ -112,7 +60,6 @@ def __init__( num_filters: int, num_pool_layers: int, dropout_probability: float, - cwn_conv: bool = False, ): super().__init__() @@ -122,31 +69,24 @@ def __init__( self.num_pool_layers = num_pool_layers self.dropout_probability = dropout_probability - if cwn_conv: - conv_block = CWNConvBlock3D - transpose_conv_block = CWNTransposeConvBlock3D - else: - conv_block = ConvBlock3D - transpose_conv_block = TransposeConvBlock3D - - self.down_sample_layers = nn.ModuleList([conv_block(in_channels, num_filters, dropout_probability)]) + self.down_sample_layers = nn.ModuleList([ConvBlock3D(in_channels, num_filters, dropout_probability)]) ch = num_filters for _ in range(num_pool_layers - 1): - self.down_sample_layers += [conv_block(ch, ch * 2, dropout_probability)] + self.down_sample_layers += [ConvBlock3D(ch, ch * 2, dropout_probability)] ch *= 2 - 
self.conv = conv_block(ch, ch * 2, dropout_probability) + self.conv = ConvBlock3D(ch, ch * 2, dropout_probability) self.up_conv = nn.ModuleList() self.up_transpose_conv = nn.ModuleList() for _ in range(num_pool_layers - 1): - self.up_transpose_conv += [transpose_conv_block(ch * 2, ch)] - self.up_conv += [conv_block(ch * 2, ch, dropout_probability)] + self.up_transpose_conv += [TransposeConvBlock3D(ch * 2, ch)] + self.up_conv += [ConvBlock3D(ch * 2, ch, dropout_probability)] ch //= 2 - self.up_transpose_conv += [transpose_conv_block(ch * 2, ch)] + self.up_transpose_conv += [TransposeConvBlock3D(ch * 2, ch)] self.up_conv += [ nn.Sequential( - conv_block(ch * 2, ch, dropout_probability), + ConvBlock3D(ch * 2, ch, dropout_probability), nn.Conv3d(ch, out_channels, kernel_size=1, stride=1), ) ] @@ -204,7 +144,6 @@ def __init__( num_pool_layers: int, dropout_probability: float, norm_groups: int = 2, - cwn_conv: bool = False, ): """Inits :class:`NormUnetModel3D`. @@ -222,8 +161,6 @@ def __init__( Dropout probability. norm_groups: int, Number of normalization groups. - cwn_conv : bool - Apply centered weight normalization to convolutions. Default: False. """ super().__init__() @@ -233,7 +170,6 @@ def __init__( num_filters=num_filters, num_pool_layers=num_pool_layers, dropout_probability=dropout_probability, - cwn_conv=cwn_conv, ) self.norm_groups = norm_groups diff --git a/direct/nn/varsplitnet/config.py b/direct/nn/varsplitnet/config.py index 5eca4c9e..73b707af 100644 --- a/direct/nn/varsplitnet/config.py +++ b/direct/nn/varsplitnet/config.py @@ -23,7 +23,6 @@ class MRIVarSplitNetConfig(ModelConfig): image_unet_num_filters: int = 32 image_unet_num_pool_layers: int = 4 image_unet_dropout: float = 0.0 - image_unet_cwn_conv: bool = False image_didn_hidden_channels: int = 16 image_didn_num_dubs: int = 6 image_didn_num_convs_recon: int = 9 diff --git a/direct/nn/vsharp/config.py b/direct/nn/vsharp/config.py index b69f2a5d..15ffaa75 100644 --- a/direct/nn/vsharp/config.py +++ b/direct/nn/vsharp/config.py @@ -20,6 +20,9 @@ class VSharpNetConfig(ModelConfig): initializer_dilations: tuple[int, ...] 
= (1, 1, 2, 4) initializer_multiscale: int = 1 initializer_activation: ActivationType = ActivationType.PRELU + conv_modulation: bool = False + aux_in_features: int = 2 + fc_hidden_features: int = 32 image_resnet_hidden_channels: int = 128 image_resnet_num_blocks: int = 15 image_resnet_batchnorm: bool = True @@ -27,7 +30,6 @@ class VSharpNetConfig(ModelConfig): image_unet_num_filters: int = 32 image_unet_num_pool_layers: int = 4 image_unet_dropout: float = 0.0 - image_unet_cwn_conv: bool = False image_didn_hidden_channels: int = 16 image_didn_num_dubs: int = 6 image_didn_num_convs_recon: int = 9 @@ -51,5 +53,4 @@ class VSharpNet3DConfig(ModelConfig): unet_num_filters: int = 32 unet_num_pool_layers: int = 4 unet_dropout: float = 0.0 - unet_cwn_conv: bool = False unet_norm: bool = False diff --git a/direct/nn/vsharp/vsharp.py b/direct/nn/vsharp/vsharp.py index ab6acdbe..e1f975d7 100644 --- a/direct/nn/vsharp/vsharp.py +++ b/direct/nn/vsharp/vsharp.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, Optional import numpy as np import torch @@ -14,6 +14,7 @@ from direct.constants import COMPLEX_SIZE from direct.data.transforms import apply_mask, expand_operator, reduce_operator +from direct.nn.conv.modulated_conv import ModConv2d, ModConv2dBias from direct.nn.get_nn_model_config import ModelName, _get_model_config, _get_relu_activation from direct.nn.types import ActivationType, InitType from direct.nn.unet.unet_3d import NormUnetModel3d, UnetModel3d @@ -30,6 +31,9 @@ def __init__( dilations: tuple[int, ...], multiscale_depth: int = 1, activation: ActivationType = ActivationType.PRELU, + conv_modulation: bool = False, + aux_in_features: Optional[int] = None, + fc_hidden_features: Optional[int] = None, ): """Inits :class:`LagrangeMultipliersInitializer`. @@ -52,29 +56,53 @@ def __init__( self.conv_blocks = nn.ModuleList() tch = in_channels for curr_channels, curr_dilations in zip(channels, dilations): - block = nn.Sequential( - nn.ReplicationPad2d(curr_dilations), - nn.Conv2d(tch, curr_channels, 3, padding=0, dilation=curr_dilations), + block = nn.ModuleList( + [ + nn.ReplicationPad2d(curr_dilations), + ModConv2d( + tch, + curr_channels, + kernel_size=3, + padding=0, + dilation=curr_dilations, + modulation=conv_modulation, + bias=ModConv2dBias.LEARNED if conv_modulation else ModConv2dBias.PARAM, + aux_in_features=aux_in_features, + fc_hidden_features=fc_hidden_features, + ), + ] ) tch = curr_channels self.conv_blocks.append(block) # Define output block tch = np.sum(channels[-multiscale_depth:]) - block = nn.Conv2d(tch, out_channels, 1, padding=0) - self.out_block = nn.Sequential(block) + self.out_block = ModConv2d( + tch, + out_channels, + kernel_size=1, + padding=0, + modulation=conv_modulation, + bias=ModConv2dBias.LEARNED if conv_modulation else ModConv2dBias.PARAM, + aux_in_features=aux_in_features, + fc_hidden_features=fc_hidden_features, + ) self.multiscale_depth = multiscale_depth self.activation = _get_relu_activation(activation) - def forward(self, x: torch.Tensor) -> torch.Tensor: + self.conv_modulation = conv_modulation + + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor: """Forward pass of :class:`LagrangeMultipliersInitializer`. Parameters ---------- x : torch.Tensor Input tensor of shape (batch_size, in_channels, height, width). + y : torch.Tensor, optional + Auxiliary tensor of shape (batch_size, aux_in_features). Default: None. 
Returns ------- @@ -84,14 +112,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: features = [] for block in self.conv_blocks: - x = F.relu(block(x), inplace=True) + x = block[0](x) + if self.conv_modulation: + x = F.relu(block[1](x, y), inplace=True) + else: + x = F.relu(block[1](x), inplace=True) if self.multiscale_depth > 1: features.append(x) if self.multiscale_depth > 1: x = torch.cat(features[-self.multiscale_depth :], dim=1) - return self.activation(self.out_block(x)) + if self.conv_modulation: + return self.activation(self.out_block(x, y)) + else: + return self.activation(self.out_block(x)) class VSharpNet(nn.Module): @@ -150,6 +185,9 @@ def __init__( initializer_multiscale: int = 1, initializer_activation: ActivationType = ActivationType.PRELU, auxiliary_steps: int = 0, + conv_modulation: bool = False, + aux_in_features: Optional[int] = None, + fc_hidden_features: Optional[int] = None, **kwargs, ): """Inits :class:`VSharpNet`. @@ -202,6 +240,9 @@ def __init__( image_model_architecture, in_channels=COMPLEX_SIZE * 3, out_channels=COMPLEX_SIZE, + modulation=conv_modulation, + aux_in_features=aux_in_features, + fc_hidden_features=fc_hidden_features, **{k.replace("image_", ""): v for (k, v) in kwargs.items() if "image_" in k}, ) @@ -216,6 +257,9 @@ def __init__( dilations=initializer_dilations, multiscale_depth=initializer_multiscale, activation=initializer_activation, + conv_modulation=conv_modulation, + aux_in_features=aux_in_features, + fc_hidden_features=fc_hidden_features, ) self.learning_rate_eta = nn.Parameter(torch.ones(num_steps_dc_gd, requires_grad=True)) @@ -246,11 +290,14 @@ def __init__( self._complex_dim = -1 self._spatial_dims = (2, 3) + self.conv_modulation = conv_modulation + def forward( self, masked_kspace: torch.Tensor, sensitivity_map: torch.Tensor, sampling_mask: torch.Tensor, + auxiliary_data: Optional[torch.Tensor] = None, ) -> list[torch.Tensor]: """Computes forward pass of :class:`VSharpNet`. 
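Note on the new `auxiliary_data` argument: when `conv_modulation=True`, this tensor of shape (N, `aux_in_features`) is handed to both the initializer and the denoiser blocks (next hunk). A minimal calling sketch follows; it assumes the constructor arguments not shown in this diff keep the existing VSharpNet signature, and the reading of the two auxiliary features as acceleration and center fraction is purely illustrative (the patch fixes only `aux_in_features=2` as the config default, not what the features mean).

    import torch
    from direct.data.transforms import fft2, ifft2
    from direct.nn.vsharp.vsharp import VSharpNet

    model = VSharpNet(
        forward_operator=fft2,
        backward_operator=ifft2,
        num_steps=8,
        num_steps_dc_gd=6,
        conv_modulation=True,
        aux_in_features=2,      # VSharpNetConfig default
        fc_hidden_features=32,  # VSharpNetConfig default
    )

    masked_kspace = torch.randn(1, 4, 128, 128, 2)    # (N, coil, H, W, complex=2)
    sensitivity_map = torch.randn(1, 4, 128, 128, 2)
    sampling_mask = torch.ones(1, 1, 128, 128, 1, dtype=torch.bool)
    auxiliary_data = torch.tensor([[4.0, 0.08]])      # hypothetical (acceleration, center fraction)

    # Returns a list of intermediate/final image estimates, each of shape (N, H, W, 2).
    out = model(masked_kspace, sensitivity_map, sampling_mask, auxiliary_data)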
@@ -279,15 +326,19 @@ def forward( z = x.clone() - u = self.initializer(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + if self.conv_modulation: + u = self.initializer(x.permute(0, 3, 1, 2), auxiliary_data).permute(0, 2, 3, 1) + else: + u = self.initializer(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) for admm_step in range(self.num_steps): - z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0]( - torch.cat( - [z, x, u / self.rho[admm_step]], - dim=self._complex_dim, - ).permute(0, 3, 1, 2) - ).permute(0, 2, 3, 1) + denoiser_input = [torch.cat([z, x, u / self.rho[admm_step]], dim=self._complex_dim).permute(0, 3, 1, 2)] + if self.conv_modulation: + denoiser_input.append(auxiliary_data) + + z = self.denoiser_blocks[admm_step if self.no_parameter_sharing else 0](*denoiser_input).permute( + 0, 2, 3, 1 + ) for dc_gd_step in range(self.num_steps_dc_gd): dc = apply_mask( @@ -404,7 +455,6 @@ def __init__( unet_num_filters: int = 32, unet_num_pool_layers: int = 4, unet_dropout: float = 0.0, - unet_cwn_conv: bool = False, unet_norm: bool = False, **kwargs, ): @@ -467,7 +517,6 @@ def __init__( num_filters=unet_num_filters, num_pool_layers=unet_num_pool_layers, dropout_probability=unet_dropout, - cwn_conv=unet_cwn_conv, ) ) diff --git a/direct/types.py b/direct/types.py index e3bda208..d92c69cf 100644 --- a/direct/types.py +++ b/direct/types.py @@ -15,6 +15,7 @@ DictOrDictConfig = Union[dict, DictConfig] Number = Union[float, int] +IntOrTuple = Union[int, tuple] PathOrString = Union[pathlib.Path, str] FileOrUrl = NewType("FileOrUrl", PathOrString) HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler]
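Usage note: once the patch is applied, the new module can be smoke-tested in isolation. With `modulation=False` (the default), :class:`ModConv2d` behaves like a plain `nn.Conv2d`; with `modulation=True` it additionally consumes the auxiliary tensor `y`, from which an MLP derives per-sample kernel scalings (and, with `ModConv2dBias.LEARNED`, the bias). A minimal sketch using only the definitions added above:

    import torch
    from direct.nn.conv.modulated_conv import ModConv2d, ModConv2dBias

    x = torch.randn(2, 8, 32, 32)  # (N, in_channels, H, W)

    # Unmodulated: parameter weight and bias, same behavior as nn.Conv2d.
    conv = ModConv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
    assert conv(x).shape == (2, 16, 32, 32)

    # Modulated: kernel (and bias) conditioned on y via the internal MLP.
    mod_conv = ModConv2d(
        in_channels=8,
        out_channels=16,
        kernel_size=3,
        padding=1,
        modulation=True,
        bias=ModConv2dBias.LEARNED,
        aux_in_features=2,
        fc_hidden_features=32,
    )
    y = torch.randn(2, 2)  # (N, aux_in_features)
    assert mod_conv(x, y).shape == (2, 16, 32, 32)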