Commit

Option for multilayer mlp

georgeyiasemis committed Feb 9, 2024
1 parent dcd6439 commit e4f4a57
Showing 4 changed files with 61 additions and 35 deletions.
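For orientation, a minimal standalone sketch of what this commit enables: `fc_hidden_features` may now be a tuple of ints, and each entry adds one hidden nn.Linear/nn.PReLU pair to the modulation MLP. Names are borrowed from the diff below; the example sizes are assumed.

import torch
from torch import nn

# Assumed example sizes; in the diff these come from ModConv2d.__init__.
aux_in_features = 2
fc_hidden_features = (16, 32)  # the new option: a tuple instead of a single int
mod_out_features = 64          # `num_weights` in the committed code
fc_bias = True

if isinstance(fc_hidden_features, int):  # ints are normalized to 1-tuples
    fc_hidden_features = (fc_hidden_features,)
fc_hidden_features = fc_hidden_features + (mod_out_features,)  # (16, 32, 64)

fc = [nn.Linear(aux_in_features, fc_hidden_features[0], bias=fc_bias), nn.PReLU()]
for i in range(len(fc_hidden_features) - 1):
    fc.append(nn.Linear(fc_hidden_features[i], fc_hidden_features[i + 1]))
    fc.append(nn.PReLU())
fc_head = nn.Sequential(*fc, nn.Sigmoid())  # the ModConvActivation.SIGMOID branch

print(fc_head(torch.randn(8, aux_in_features)).shape)  # torch.Size([8, 64])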
63 changes: 43 additions & 20 deletions direct/nn/conv/modulated_conv.py
@@ -71,7 +71,7 @@ def __init__(
dilation: IntOrTuple = 1,
bias: ModConv2dBias = ModConv2dBias.PARAM,
aux_in_features: Optional[int] = None,
- fc_hidden_features: Optional[int] = None,
+ fc_hidden_features: Optional[tuple[int] | int] = None,
fc_bias: Optional[bool] = True,
fc_groups: int | None = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
@@ -101,7 +101,7 @@ def __init__(
Type of bias, by default ModConv2dBias.PARAM.
aux_in_features : int, optional
Number of features in the auxiliary input variable `y`, by default None.
- fc_hidden_features : int, optional
+ fc_hidden_features : int or tuple of int, optional
Number of hidden features in the modulation MLP, by default None.
fc_bias : bool, optional
If True, enable bias in the modulation MLP, by default True.
@@ -141,6 +141,8 @@ def __init__(
raise ValueError(
f"Value for `fc_hidden_features` can't be None with `modulation` not set to ModConvType.NONE."
)
+ if isinstance(fc_hidden_features, int):
+     fc_hidden_features = (fc_hidden_features,)
if fc_groups is None:
raise ValueError(f"Value for `fc_groups` can't be None with `modulation` not set to ModConvType.NONE.")
if fc_groups < 1:
@@ -167,10 +169,14 @@
)
mod_out_features = num_weights

+ fc_hidden_features = fc_hidden_features + (mod_out_features,)
+
+ fc = [nn.Linear(aux_in_features, fc_hidden_features[0], bias=fc_bias), nn.PReLU()]
+ for i in range(0, len(fc_hidden_features) - 1):
+     fc.append(nn.Linear(fc_hidden_features[i], fc_hidden_features[i + 1]))
+     fc.append(nn.PReLU())
self.fc = nn.Sequential(
-     nn.Linear(aux_in_features, fc_hidden_features, bias=fc_bias),
-     nn.PReLU(),
-     nn.Linear(fc_hidden_features, mod_out_features, bias=fc_bias),
+     *fc,
*(
(nn.Sigmoid(),)
if fc_activation == ModConvActivation.SIGMOID
@@ -198,11 +204,16 @@ def __init__(
f"Bias can only be set to ModConv2dBias.LEARNED if `modulation` is not ModConvType.NONE, "
f"but modulation is set to ModConvType.NONE."
)
- self.bias = nn.Sequential(
-     nn.Linear(aux_in_features, fc_hidden_features, bias=fc_bias),
-     nn.PReLU(),
-     nn.Linear(fc_hidden_features, out_channels, bias=fc_bias),
- )
+ bias = [nn.Linear(aux_in_features, fc_hidden_features[0], bias=fc_bias)]
+ for i in range(0, len(fc_hidden_features) - 1):
+     bias.append(nn.PReLU())
+     bias.append(
+         nn.Linear(
+             fc_hidden_features[i],
+             fc_hidden_features[i + 1] if i != (len(fc_hidden_features) - 2) else out_channels,
+         )
+     )
+ self.bias = nn.Sequential(*bias)
else:
self.bias = None

@@ -331,7 +342,7 @@ def __init__(
dilation: IntOrTuple = 1,
bias: ModConv2dBias = ModConv2dBias.PARAM,
aux_in_features: Optional[int] = None,
- fc_hidden_features: Optional[int] = None,
+ fc_hidden_features: Optional[tuple[int] | int] = None,
fc_bias: Optional[bool] = True,
fc_groups: int | None = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
@@ -361,7 +372,7 @@ def __init__(
Type of bias, by default ModConv2dBias.PARAM.
aux_in_features : int, optional
Number of features in the auxiliary input variable `y`, by default None.
- fc_hidden_features : int, optional
+ fc_hidden_features : int or tuple of int, optional
Number of hidden features in the modulation MLP, by default None.
fc_bias : bool, optional
If True, enable bias in the modulation MLP, by default True.
@@ -401,6 +412,8 @@ def __init__(
raise ValueError(
f"Value for `fc_hidden_features` can't be None with `modulation` not set to ModConvType.NONE."
)
+ if isinstance(fc_hidden_features, int):
+     fc_hidden_features = (fc_hidden_features,)
if fc_groups is None:
raise ValueError(f"Value for `fc_groups` can't be None with `modulation` not set to ModConvType.NONE.")
if fc_groups < 1:
@@ -427,10 +440,15 @@
)
mod_out_features = num_weights

+ fc_hidden_features = fc_hidden_features + (mod_out_features,)
+
+ fc = [nn.Linear(aux_in_features, fc_hidden_features[0], bias=fc_bias), nn.PReLU()]
+
+ for i in range(0, len(fc_hidden_features) - 1):
+     fc.append(nn.Linear(fc_hidden_features[i], fc_hidden_features[i + 1]))
+     fc.append(nn.PReLU())
self.fc = nn.Sequential(
-     nn.Linear(aux_in_features, fc_hidden_features, bias=fc_bias),
-     nn.PReLU(),
-     nn.Linear(fc_hidden_features, mod_out_features, bias=fc_bias),
+     *fc,
*(
(nn.Sigmoid(),)
if fc_activation == ModConvActivation.SIGMOID
@@ -458,11 +476,16 @@ def __init__(
f"Bias can only be set to ModConv2dBias.LEARNED if `modulation` is not ModConvType.NONE, "
f"but modulation is set to ModConvType.NONE."
)
- self.bias = nn.Sequential(
-     nn.Linear(aux_in_features, fc_hidden_features, bias=fc_bias),
-     nn.PReLU(),
-     nn.Linear(fc_hidden_features, out_channels, bias=fc_bias),
- )
+ bias = [nn.Linear(aux_in_features, fc_hidden_features[0], bias=fc_bias)]
+ for i in range(0, len(fc_hidden_features) - 1):
+     bias.append(nn.PReLU())
+     bias.append(
+         nn.Linear(
+             fc_hidden_features[i],
+             fc_hidden_features[i + 1] if i != (len(fc_hidden_features) - 2) else out_channels,
+         )
+     )
+ self.bias = nn.Sequential(*bias)
else:
self.bias = None

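The learned-bias branch above gets the same loop treatment. A standalone sketch of it with assumed sizes: by the point where the committed code builds this head, the modulation output width has already been appended to `fc_hidden_features`, and the final layer projects to `out_channels` instead of that appended entry.

import torch
from torch import nn

# Assumed example sizes. The last tuple entry (64) is the modulation width the
# committed code appended earlier; the bias head replaces it with out_channels.
aux_in_features = 2
out_channels = 8
fc_bias = True
fc_hidden_features = (16, 32) + (64,)

bias = [nn.Linear(aux_in_features, fc_hidden_features[0], bias=fc_bias)]
for i in range(len(fc_hidden_features) - 1):
    bias.append(nn.PReLU())
    bias.append(
        nn.Linear(
            fc_hidden_features[i],
            fc_hidden_features[i + 1] if i != (len(fc_hidden_features) - 2) else out_channels,
        )
    )
bias_head = nn.Sequential(*bias)  # Linear(2, 16), PReLU, Linear(16, 32), PReLU, Linear(32, 8)

print(bias_head(torch.randn(4, aux_in_features)).shape)  # torch.Size([4, 8])

Note that, unlike the removed nn.Sequential head, the nn.Linear layers created inside the loop use PyTorch's default bias=True rather than forwarding `fc_bias`.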
23 changes: 13 additions & 10 deletions direct/nn/unet/unet_2d.py
@@ -2,6 +2,9 @@
# Copyright (c) DIRECT Contributors

# Code borrowed / edited from: https://github.com/facebookresearch/fastMRI/blob/

+ from __future__ import annotations

import math
from typing import Callable, List, Optional, Tuple

@@ -24,7 +27,7 @@ def __init__(
modulation: ModConvType = ModConvType.NONE,
bias: ModConv2dBias = ModConv2dBias.PARAM,
aux_in_features: Optional[int] = None,
- fc_hidden_features: Optional[int] = None,
+ fc_hidden_features: Optional[tuple[int] | int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
@@ -52,7 +55,7 @@ def __init__(
aux_in_features : int, optional
Number of features in the auxiliary input variable `y`. Ignored if `modulation` is ModConvType.NONE.
Default: None.
- fc_hidden_features : int, optional
+ fc_hidden_features : int or tuple of int, optional
Number of hidden features in the modulation MLP unit. Ignored if `modulation` is ModConvType.NONE.
Default: None.
fc_groups : int, optional
@@ -120,7 +123,7 @@ def __init__(
dropout_probability: float,
modulation: ModConvType = ModConvType.NONE,
aux_in_features: Optional[int] = None,
- fc_hidden_features: Optional[int] = None,
+ fc_hidden_features: Optional[tuple[int] | int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
@@ -140,7 +143,7 @@ def __init__(
aux_in_features : int, optional
Number of features in the auxiliary input variable `y`. Ignored if `modulation` is ModConvType.NONE.
Default: None.
- fc_hidden_features : int, optional
+ fc_hidden_features : int or tuple of int, optional
Number of hidden features in the modulation MLP units. Ignored if `modulation` is ModConvType.NONE.
Default: None.
fc_groups : int, optional
@@ -223,7 +226,7 @@ def __init__(
out_channels: int,
modulation: ModConvType = ModConvType.NONE,
aux_in_features: Optional[int] = None,
- fc_hidden_features: Optional[int] = None,
+ fc_hidden_features: Optional[tuple[int] | int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
@@ -241,7 +244,7 @@ def __init__(
aux_in_features : int, optional
Number of features in the auxiliary input variable `y`. Ignored if `modulation` is ModConvType.NONE.
Default: None.
- fc_hidden_features : int, optional
+ fc_hidden_features : int or tuple of int, optional
Number of hidden features in the modulation MLP unit. Ignored if `modulation` is ModConvType.NONE.
Default: None.
fc_groups : int, optional
@@ -330,7 +333,7 @@ def __init__(
dropout_probability: float,
modulation: ModConvType = ModConvType.NONE,
aux_in_features: Optional[int] = None,
- fc_hidden_features: Optional[int] = None,
+ fc_hidden_features: Optional[tuple[int] | int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
@@ -355,7 +358,7 @@ def __init__(
aux_in_features : int, optional
Number of features in the auxiliary input variable `y`. Ignored if `modulation` is ModConvType.NONE.
Default: None.
- fc_hidden_features : int, optional
+ fc_hidden_features : int or tuple of int, optional
Number of hidden features in the modulated convolutions. Ignored if `modulation` is ModConvType.NONE.
Default: None.
fc_groups : int, optional
@@ -558,7 +561,7 @@ def __init__(
norm_groups: int = 2,
modulation: ModConvType = ModConvType.NONE,
aux_in_features: Optional[int] = None,
- fc_hidden_features: Optional[int] = None,
+ fc_hidden_features: Optional[tuple[int] | int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
@@ -585,7 +588,7 @@ def __init__(
aux_in_features : int, optional
Number of features in the auxiliary input variable `y`. Ignored if `modulation` is ModConvType.NONE.
Default: None.
- fc_hidden_features : int, optional
+ fc_hidden_features : int or tuple of int, optional
Number of hidden features in the modulated convolutions. Ignored if `modulation` is ModConvType.NONE.
Default: None.
fc_groups : int, optional
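The `from __future__ import annotations` added above is what lets the new `tuple[int] | int` annotations load on Python versions below 3.10: with postponed evaluation (PEP 563), annotations are stored as strings instead of being evaluated at definition time, so the PEP 604 `|` union and the built-in `tuple[...]` generic never execute. A quick illustration:

from __future__ import annotations  # must precede all other imports

from typing import Optional

# Without the future import, this line raises TypeError at import time on
# Python < 3.10 (and `tuple[int]` alone fails on < 3.9); with it, it parses fine.
def f(fc_hidden_features: Optional[tuple[int] | int] = None) -> None: ...

print(f.__annotations__["fc_hidden_features"])  # 'Optional[tuple[int] | int]'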
2 changes: 1 addition & 1 deletion direct/nn/vsharp/config.py
@@ -24,7 +24,7 @@ class VSharpNetConfig(ModelConfig):
initializer_activation: ActivationType = ActivationType.PRELU
conv_modulation: ModConvType = ModConvType.NONE
aux_in_features: int = 2
- fc_hidden_features: Optional[int] = None
+ fc_hidden_features: Optional[tuple[int]] = None
fc_groups: int = 1
fc_activation: ModConvActivation = ModConvActivation.SIGMOID
num_weights: Optional[int] = None
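One typing aside on the config field above (a general Python fact, not something the commit changes): in annotations, `tuple[int]` denotes a tuple of exactly one int, while the variable-length form is `tuple[int, ...]`. Nothing is enforced at runtime, but static checkers distinguish the two (Python 3.9+ shown):

from typing import Optional

one: Optional[tuple[int]] = (16,)           # a 1-tuple matches tuple[int]
many: Optional[tuple[int, ...]] = (16, 32)  # any length matches tuple[int, ...]
# A checker such as mypy flags this assignment, since (16, 32) is not a
# 1-tuple; at runtime it still succeeds because annotations are not enforced.
two: Optional[tuple[int]] = (16, 32)  # type: ignore[assignment]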
8 changes: 4 additions & 4 deletions direct/nn/vsharp/vsharp.py
@@ -33,7 +33,7 @@ def __init__(
activation: ActivationType = ActivationType.PRELU,
conv_modulation: ModConvType = ModConvType.NONE,
aux_in_features: Optional[int] = None,
- fc_hidden_features: Optional[int] = None,
+ fc_hidden_features: Optional[tuple[int] | int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
Expand All @@ -60,7 +60,7 @@ def __init__(
aux_in_features : int, optional
Number of features in the auxiliary input variable `y`. Ignored if `modulation` is ModConvType.NONE.
Default: None.
- fc_hidden_features : int, optional
+ fc_hidden_features : int or tuple of int, optional
Number of hidden features in the modulated convolutions. Ignored if `modulation` is ModConvType.NONE.
Default: None.
fc_activation : ModConvActivation
@@ -223,7 +223,7 @@ def __init__(
auxiliary_steps: int = 0,
conv_modulation: ModConvType = ModConvType.NONE,
aux_in_features: Optional[int] = None,
- fc_hidden_features: Optional[int] = None,
+ fc_hidden_features: Optional[tuple[int] | int] = None,
fc_groups: int = 1,
fc_activation: ModConvActivation = ModConvActivation.SIGMOID,
num_weights: Optional[int] = None,
@@ -266,7 +266,7 @@ def __init__(
aux_in_features : int, optional
Number of features in the auxiliary input variable `y`. Ignored if `conv_modulation` is ModConvType.NONE.
Default: None.
- fc_hidden_features : int, optional
+ fc_hidden_features : int or tuple of int, optional
Number of hidden features in the modulated convolutions. Ignored if `conv_modulation` is ModConvType.NONE.
Default: None.
fc_activation : ModConvActivation
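As a closing sanity check (again a sketch with assumed sizes, not code from the commit), passing a plain int reproduces the shape of the pre-commit single-hidden-layer MLP: the int is wrapped into a 1-tuple and the loop emits exactly one hidden projection plus the final one.

from torch import nn

aux_in_features, fc_hidden_features, mod_out_features, fc_bias = 2, 16, 64, True

if isinstance(fc_hidden_features, int):
    fc_hidden_features = (fc_hidden_features,)
fc_hidden_features = fc_hidden_features + (mod_out_features,)  # (16, 64)

fc = [nn.Linear(aux_in_features, fc_hidden_features[0], bias=fc_bias), nn.PReLU()]
for i in range(len(fc_hidden_features) - 1):
    fc.append(nn.Linear(fc_hidden_features[i], fc_hidden_features[i + 1]))
    fc.append(nn.PReLU())

# Yields Linear(2, 16), PReLU, Linear(16, 64), PReLU: the same widths as the
# removed two-Linear head, with a trailing PReLU before the activation and
# with default bias on the second Linear rather than a forwarded `fc_bias`.
print(nn.Sequential(*fc))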
