Add choice for SUM modulation
georgeyiasemis committed Feb 5, 2024
1 parent 2ec7e73 commit dcd6439
Showing 6 changed files with 236 additions and 126 deletions.
276 changes: 167 additions & 109 deletions direct/nn/conv/modulated_conv.py

Large diffs are not rendered by default.
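(Not part of the commit.) Since the body of this diff is not rendered, the exact SUM implementation cannot be read off this page. As a rough, self-contained sketch, the new parameters suggest a convolution whose effective kernel is a data-dependent weighted sum over `num_weights` candidate kernels, with mixing coefficients predicted from an auxiliary vector of size `aux_in_features`. The class name, the softmax over the coefficients, and all shapes below are assumptions for illustration only, not the implementation in modulated_conv.py:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SumModulatedConv2dSketch(nn.Module):
    """Illustrative sketch only: effective kernel = weighted sum of candidate kernels."""

    def __init__(self, in_channels, out_channels, kernel_size, aux_in_features, num_weights):
        super().__init__()
        # Bank of `num_weights` candidate kernels, shape (W, out, in, k, k).
        self.weight = nn.Parameter(
            0.02 * torch.randn(num_weights, out_channels, in_channels, kernel_size, kernel_size)
        )
        self.bias = nn.Parameter(torch.zeros(out_channels))
        # Predicts one mixing coefficient per candidate kernel from the auxiliary vector.
        self.fc = nn.Linear(aux_in_features, num_weights)

    def forward(self, x, aux):
        # x: (B, C_in, H, W), aux: (B, aux_in_features)
        coeffs = torch.softmax(self.fc(aux), dim=-1)  # (B, num_weights)
        out = []
        for b in range(x.shape[0]):
            # Per-sample effective kernel: sum_w coeffs[b, w] * weight[w].
            kernel = (coeffs[b].view(-1, 1, 1, 1, 1) * self.weight).sum(dim=0)
            out.append(F.conv2d(x[b : b + 1], kernel, self.bias, padding="same"))
        return torch.cat(out, dim=0)

# Usage: conv = SumModulatedConv2dSketch(2, 16, 3, aux_in_features=8, num_weights=4)
#        y = conv(torch.randn(4, 2, 32, 32), torch.randn(4, 8))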

1 change: 1 addition & 0 deletions direct/nn/get_nn_model_config.py
@@ -48,6 +48,7 @@ def _get_model_config(
"fc_hidden_features": kwargs.get("fc_hidden_features", None),
"fc_groups": kwargs.get("fc_groups", 1),
"fc_activation": kwargs.get("fc_activation", None),
"num_weights": kwargs.get("num_weights", None),
"modulation_at_input": kwargs.get("modulation_at_input", False),
}
)
2 changes: 2 additions & 0 deletions direct/nn/unet/config.py
@@ -19,6 +19,7 @@ class UnetModel2dConfig(ModelConfig):
fc_hidden_features: Optional[int] = None
fc_groups: int = 1
fc_activation: ModConvActivation = ModConvActivation.SIGMOID
num_weights: Optional[int] = None


class NormUnetModel2dConfig(ModelConfig):
@@ -33,6 +34,7 @@ class NormUnetModel2dConfig(ModelConfig):
fc_hidden_features: Optional[int] = None
fc_groups: int = 1
fc_activation: ModConvActivation = ModConvActivation.SIGMOID
num_weights: Optional[int] = None


@dataclass
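(Not part of the commit.) A hedged sketch of how the modulation-related fields of these Unet configs could be filled in once `num_weights` exists. Field names come from the dataclasses above; the values, and the string forms of the enums, are illustrative assumptions, and required fields defined outside this hunk (channels, pooling depth, etc.) are omitted:

# Illustrative only: modulation-related portion of a UnetModel2d config.
unet_modulation_cfg = {
    "modulation": "SUM",          # assumed string form of ModConvType.SUM
    "aux_in_features": 8,
    "fc_hidden_features": 64,
    "fc_activation": "SIGMOID",   # assumed string form of ModConvActivation.SIGMOID
    "num_weights": 4,             # new field introduced by this commit
    # "fc_groups" is ignored for SUM modulation, per the docstrings below.
}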
62 changes: 49 additions & 13 deletions direct/nn/unet/unet_2d.py
@@ -27,6 +27,7 @@ def __init__(
fc_hidden_features: Optional[int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
):
"""Inits :class:`ConvModule`.
@@ -55,11 +56,13 @@ def __init__(
Number of hidden features in the modulation MLP unit. Ignored if `modulation` is ModConvType.NONE.
Default: None.
fc_groups : int, optional
Number of MLP groups for the modulation MLP unit. Ignored if `modulation` is ModConvType.NONE.
Default: 1.
Number of MLP groups for the modulation MLP unit. Ignored if `modulation` is ModConvType.NONE
or ModConvType.SUM. Default: 1.
fc_activation : ModConvActivation
Activation function to be applied in the modulation MLP unit. Ignored if `modulation` is ModConvType.NONE.
Default: ModConvActivation.SIGMOID.
num_weights : int, optional
Number of weights to use in case modulation is ModConvType.SUM. Default: None.
"""
super().__init__()

@@ -76,6 +79,7 @@ def __init__(
fc_hidden_features=fc_hidden_features,
fc_groups=fc_groups,
fc_activation=fc_activation,
num_weights=num_weights,
)
self.instance_norm = nn.InstanceNorm2d(out_channels)
self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
@@ -119,6 +123,7 @@ def __init__(
fc_hidden_features: Optional[int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
):
"""Inits :class:`ConvBlock`.
@@ -139,11 +144,13 @@
Number of hidden features in the modulation MLP units. Ignored if `modulation` is ModConvType.NONE.
Default: None.
fc_groups : int, optional
Number of MLP groups for the modulation MLP unit. Ignored if `modulation` is ModConvType.NONE.
Default: 1.
Number of MLP groups for the modulation MLP unit. Ignored if `modulation` is ModConvType.NONE
or ModConvType.SUM. Default: 1.
fc_activation : ModConvActivation
Activation function to be applied in the MLP units. Ignored if `modulation` is ModConvType.NONE.
Default: ModConvActivation.SIGMOID.
num_weights : int, optional
Number of weights to use in case modulation is ModConvType.SUM. Default: None.
"""
super().__init__()

@@ -156,6 +163,7 @@
self.fc_hidden_features = fc_hidden_features
self.fc_groups = fc_groups
self.fc_activation = fc_activation
self.num_weights = num_weights

self.layer_1, self.layer_2 = [
ConvModule(
@@ -170,6 +178,7 @@ def __init__(
fc_hidden_features=fc_hidden_features,
fc_groups=fc_groups,
fc_activation=fc_activation,
num_weights=num_weights,
)
for i in range(2)
]
@@ -198,7 +207,7 @@ def __repr__(self):
f"ConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels}, "
f"dropout_probability={self.dropout_probability}, modulation={self.modulation}, "
f"aux_in_features={self.aux_in_features}, fc_hidden_features={self.fc_hidden_features},"
f"fc_groups={self.fc_groups}, fc_activation={self.fc_activation})"
f"fc_groups={self.fc_groups}, fc_activation={self.fc_activation}, num_weights={self.num_weights})"
)


@@ -217,6 +226,7 @@ def __init__(
fc_hidden_features: Optional[int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
):
"""Inits :class:`TransposeConvBlock`.
@@ -240,6 +250,8 @@ def __init__(
fc_activation : ModConvActivation
Activation function to be applied in the MLP units. Ignored if `modulation` is ModConvType.NONE.
Default: ModConvActivation.SIGMOID.
num_weights : int, optional
Number of weights to use in case modulation is ModConvType.SUM. Default: None.
"""
super().__init__()

@@ -251,6 +263,7 @@ def __init__(
self.fc_hidden_features = fc_hidden_features
self.fc_groups = fc_groups
self.fc_activation = fc_activation
self.num_weights = num_weights

self.conv = ModConvTranspose2d(
in_channels=in_channels,
@@ -263,6 +276,7 @@ def __init__(
fc_hidden_features=fc_hidden_features,
fc_groups=fc_groups,
fc_activation=fc_activation,
num_weights=num_weights,
)

self.instance_norm = nn.InstanceNorm2d(out_channels)
@@ -292,7 +306,7 @@ def __repr__(self):
f"TransposeConvBlock(in_channels={self.in_channels}, out_channels={self.out_channels}, "
f"modulation={self.modulation}, aux_in_features={self.aux_in_features}, "
f"fc_hidden_features={self.fc_hidden_features}, fc_groups={self.fc_groups}, "
f"fc_activation={self.fc_activation})"
f"fc_activation={self.fc_activation}, num_weights={self.num_weights})"
)


@@ -319,6 +333,7 @@ def __init__(
fc_hidden_features: Optional[int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
modulation_at_input: bool = False,
):
"""Inits :class:`UnetModel2d`.
@@ -344,11 +359,13 @@ def __init__(
Number of hidden features in the modulated convolutions. Ignored if `modulation` is ModConvType.None.
Default: None.
fc_groups : int, optional
Number of groups in the modulated convolutions. Ignored if `modulation` is ModConvType.NONE.
Default: 1.
Number of groups in the modulated convolutions. Ignored if `modulation` is ModConvType.NONE
or ModConvType.SUM. Default: 1.
fc_activation : ModConvActivation
Activation function to be applied in the MLP units for modulated convolutions.
Ignored if `modulation` is ModConvType.None. Default: ModConvActivation.SIGMOID.
num_weights : int, optional
Number of weights to use in case modulation is ModConvType.SUM. Default: None.
modulation_at_input : bool, optional
If True, apply modulation only at the initial convolutional layer. Default: False.
"""
@@ -372,6 +389,7 @@ def __init__(
fc_hidden_features,
fc_groups,
fc_activation,
num_weights,
)
]
)
@@ -391,19 +409,28 @@ def __init__(
fc_hidden_features,
fc_groups,
fc_activation,
num_weights,
)
]
ch *= 2
self.conv = ConvBlock(
ch, ch * 2, dropout_probability, modulation, aux_in_features, fc_hidden_features, fc_groups, fc_activation
ch,
ch * 2,
dropout_probability,
modulation,
aux_in_features,
fc_hidden_features,
fc_groups,
fc_activation,
num_weights,
)

self.up_conv = nn.ModuleList()
self.up_transpose_conv = nn.ModuleList()
for _ in range(num_pool_layers - 1):
self.up_transpose_conv += [
TransposeConvBlock(
ch * 2, ch, modulation, aux_in_features, fc_hidden_features, fc_groups, fc_activation
ch * 2, ch, modulation, aux_in_features, fc_hidden_features, fc_groups, fc_activation, num_weights
)
]
self.up_conv += [
@@ -416,12 +443,15 @@ def __init__(
fc_hidden_features,
fc_groups,
fc_activation,
num_weights,
)
]
ch //= 2

self.up_transpose_conv += [
TransposeConvBlock(ch * 2, ch, modulation, aux_in_features, fc_hidden_features, fc_groups, fc_activation)
TransposeConvBlock(
ch * 2, ch, modulation, aux_in_features, fc_hidden_features, fc_groups, fc_activation, num_weights
)
]
self.up_conv += [
ConvBlock(
Expand All @@ -433,6 +463,7 @@ def __init__(
fc_hidden_features,
fc_groups,
fc_activation,
num_weights,
)
]

@@ -447,6 +478,7 @@ def __init__(
fc_hidden_features=fc_hidden_features,
fc_groups=fc_groups,
fc_activation=fc_activation,
num_weights=num_weights,
)

def forward(self, input_data: torch.Tensor, aux_data: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -529,6 +561,7 @@ def __init__(
fc_hidden_features: Optional[int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
modulation_at_input: bool = False,
):
"""Inits :class:`NormUnetModel2d`.
@@ -556,11 +589,13 @@ def __init__(
Number of hidden features in the modulated convolutions. Ignored if `modulation` is ModConvType.None.
Default: None.
fc_groups : int, optional
Number of groups in the modulated convolutions. Ignored if `modulation` is ModConvType.NONE.
Default: 1.
Number of groups in the modulated convolutions. Ignored if `modulation` is ModConvType.NONE
or ModConvType.SUM. Default: 1.
fc_activation : ModConvActivation
Activation function to be applied in the MLP units for modulated convolutions.
Ignored if `modulation` is ModConvType.None. Default: ModConvActivation.SIGMOID.
num_weights : int, optional
Number of weights to use in case modulation is ModConvType.SUM. Default: None.
modulation_at_input : bool, optional
If True, apply modulation only at the initial convolutional layer. Default: False.
"""
@@ -578,6 +613,7 @@ def __init__(
fc_groups=fc_groups,
fc_activation=fc_activation,
modulation_at_input=modulation_at_input,
num_weights=num_weights,
)
self.modulation = modulation
self.norm_groups = norm_groups
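(Not part of the commit.) A usage sketch of a SUM-modulated UnetModel2d. The leading constructor arguments (in_channels, out_channels, num_filters, num_pool_layers, dropout_probability) and the enum import path are assumed from the published DIRECT API rather than shown in this hunk; only the modulation keywords and the (input_data, aux_data) forward signature are taken from the diff above:

import torch

from direct.nn.conv.modulated_conv import ModConvActivation, ModConvType  # assumed import path
from direct.nn.unet.unet_2d import UnetModel2d

model = UnetModel2d(
    in_channels=2,                        # assumed leading arguments, not shown in this hunk
    out_channels=2,
    num_filters=16,
    num_pool_layers=3,
    dropout_probability=0.0,
    modulation=ModConvType.SUM,           # enable SUM modulation
    aux_in_features=8,                    # size of the auxiliary vector
    fc_hidden_features=32,
    fc_activation=ModConvActivation.SIGMOID,
    num_weights=4,                        # number of candidate kernels to sum over
)

image = torch.randn(1, 2, 128, 128)
aux = torch.randn(1, 8)
output = model(image, aux)                # forward(input_data, aux_data), per the diff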
1 change: 1 addition & 0 deletions direct/nn/vsharp/config.py
@@ -27,6 +27,7 @@ class VSharpNetConfig(ModelConfig):
fc_hidden_features: Optional[int] = None
fc_groups: int = 1
fc_activation: ModConvActivation = ModConvActivation.SIGMOID
num_weights: Optional[int] = None
modulation_at_input: bool = False
image_resnet_hidden_channels: int = 128
image_resnet_num_blocks: int = 15
20 changes: 16 additions & 4 deletions direct/nn/vsharp/vsharp.py
@@ -36,6 +36,7 @@ def __init__(
fc_hidden_features: Optional[int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
modulation_at_input: bool = False,
):
"""Inits :class:`LagrangeMultipliersInitializer`.
Expand Down Expand Up @@ -65,6 +66,8 @@ def __init__(
fc_activation : ModConvActivation
Activation function to be applied in the MLP units for modulated convolutions.
Ignored if `modulation` is ModConvType.None. Default: ModConvActivation.SIGMOID.
num_weights : int, optional
Number of weights to use in case `modulation` is ModConvType.SUM. Default: None.
modulation_at_input : bool, optional
If True, apply modulation only at the initial convolutional layer.
Ignored if `modulation` is ModConvType.None. Default: False.
@@ -95,6 +98,7 @@ def __init__(
fc_hidden_features=fc_hidden_features,
fc_groups=fc_groups,
fc_activation=fc_activation,
num_weights=num_weights,
),
]
)
@@ -117,6 +121,7 @@ def __init__(
fc_hidden_features=fc_hidden_features,
fc_groups=fc_groups,
fc_activation=fc_activation,
num_weights=num_weights,
)

self.multiscale_depth = multiscale_depth
@@ -221,6 +226,7 @@ def __init__(
fc_hidden_features: Optional[int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
modulation_at_input: bool = False,
**kwargs,
):
@@ -258,15 +264,19 @@ def __init__(
conv_modulation : ModConvType
If not ModConvType.None, modulated convolutions will be used. Default: ModConvType.None.
aux_in_features : int, optional
Number of features in the auxiliary input variable `y`. Ignored if `modulation` is ModConvType.None. Default: None.
Number of features in the auxiliary input variable `y`. Ignored if `conv_modulation` is ModConvType.None.
Default: None.
fc_hidden_features : int, optional
Number of hidden features in the modulated convolutions. Ignored if `modulation` is ModConvType.None. Default: None.
Number of hidden features in the modulated convolutions. Ignored if `conv_modulation` is ModConvType.None.
Default: None.
fc_activation : ModConvActivation
Activation function to be applied in the MLP units for modulated convolutions.
Ignored if `modulation` is ModConvType.None. Default: ModConvActivation.SIGMOID.
Ignored if `conv_modulation` is ModConvType.None. Default: ModConvActivation.SIGMOID.
num_weights : int, optional
Number of weights to use in case `conv_modulation` is ModConvType.SUM. Default: None.
modulation_at_input : bool, optional
If True, apply modulation only at the initial convolutional layers of learned initializer and denoisers.
Ignored if `modulation` is ModConvType.None. Default: False.
Ignored if `conv_modulation` is ModConvType.None. Default: False.
**kwargs: Additional keyword arguments.
"""
# pylint: disable=too-many-locals
@@ -291,6 +301,7 @@ def __init__(
fc_hidden_features=fc_hidden_features,
fc_groups=fc_groups,
fc_activation=fc_activation,
num_weights=num_weights,
modulation_at_input=modulation_at_input,
**{k.replace("image_", ""): v for (k, v) in kwargs.items() if "image_" in k},
)
@@ -311,6 +322,7 @@ def __init__(
fc_hidden_features=fc_hidden_features,
fc_groups=fc_groups,
fc_activation=fc_activation,
num_weights=num_weights,
modulation_at_input=modulation_at_input,
)

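(Not part of the commit.) The same choice is exposed to VSharpNet through `conv_modulation` and `num_weights`, which are forwarded to the learned initializer and the denoiser Unets. A hedged sketch of only the modulation-related keyword arguments; the rest of the constructor (operators, step counts, resnet settings, etc.) is omitted, and the enum import path is assumed as before:

from direct.nn.conv.modulated_conv import ModConvActivation, ModConvType  # assumed import path

# Illustrative values only; pass alongside VSharpNet's other required arguments.
vsharp_modulation_kwargs = dict(
    conv_modulation=ModConvType.SUM,
    aux_in_features=8,
    fc_hidden_features=32,
    fc_activation=ModConvActivation.SIGMOID,
    num_weights=4,
    modulation_at_input=True,   # modulate only the initial conv layers of initializer and denoisers
)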
