Skip to content

Commit

Permalink
Remove cwn conv from vsharp and unet. Add mod conv
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Jan 19, 2024
1 parent 174ed7d commit 3d74ee3
Show file tree
Hide file tree
Showing 9 changed files with 620 additions and 247 deletions.
326 changes: 326 additions & 0 deletions direct/nn/conv/modulated_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,326 @@
# Copyright (c) DIRECT Contributors

"""direct.nn.conv.modulated_conv module"""

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from direct.types import DirectEnum, IntOrTuple

__all__ = ["ModConv2d", "ModConv2dBias", "ModConvTranspose2d"]


class ModConv2dBias(DirectEnum):
    """Bias modes accepted by the modulated convolution layers in this module."""

    # Bias is predicted from the auxiliary variable `y` by a small MLP
    # (the layers only allow this when `modulation` is True).
    LEARNED = "learned"
    # Bias is a plain learnable `nn.Parameter` with one entry per output channel.
    PARAM = "param"
    # No bias term is added.
    NONE = "none"


class ModConv2d(nn.Module):
    r"""Modulated Conv2d module.

    If `modulation`=False and `bias`=ModConv2dBias.PARAM this is identical to nn.Conv2d:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k=0}^{C_{\text{in}}-1} \text{weight}(C_{\text{out}_j}, k) * \text{input}(N_i, k)

    If `modulation`=True, this will compute:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k=0}^{C_{\text{in}}-1} \text{MLP}(y(N_i))(C_{\text{out}_j}, k) \text{weight}(C_{\text{out}_j}, k)
        * \text{input}(N_i, k).

    where :math:`*` is a 2D cross-correlation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: IntOrTuple,
        modulation: bool = False,
        stride: IntOrTuple = 1,
        padding: IntOrTuple = 0,
        dilation: IntOrTuple = 1,
        bias: ModConv2dBias = ModConv2dBias.PARAM,
        aux_in_features: Optional[int] = None,
        fc_hidden_features: Optional[int] = None,
        fc_bias: Optional[bool] = True,
    ):
        """Inits :class:`ModConv2d`.

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        kernel_size : int or tuple of int
            Size of the convolutional kernel.
        modulation : bool, optional
            If True, apply modulation using an MLP on the auxiliary variable `y`, by default False.
        stride : int or tuple of int, optional
            Stride of the convolution, by default 1.
        padding : int or tuple of int, optional
            Padding added to all sides of the input, by default 0.
        dilation : int or tuple of int, optional
            Spacing between kernel elements, by default 1.
        bias : ModConv2dBias, optional
            Type of bias, by default ModConv2dBias.PARAM.
        aux_in_features : int, optional
            Number of features in the auxiliary input variable `y`, by default None.
        fc_hidden_features : int, optional
            Number of hidden features in the modulation MLP, by default None.
        fc_bias : bool, optional
            If True, enable bias in the modulation MLP, by default True.

        Raises
        ------
        ValueError
            If `modulation` is True and `aux_in_features` or `fc_hidden_features` is None, or if
            `bias` is ModConv2dBias.LEARNED while `modulation` is False.
        """
        super().__init__()

        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # Uniform init bound derived from the fan-in (in_channels * kernel area); it is reused
        # below for the PARAM bias. The weight is sampled before any other parameter is created
        # so that seeded initialisation stays reproducible.
        k = math.sqrt(1 / (in_channels * self.kernel_size[0] * self.kernel_size[1]))
        self.weight = nn.Parameter(
            torch.FloatTensor(out_channels, in_channels, self.kernel_size[0], self.kernel_size[1]).uniform_(-k, k)
        )

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.modulation = modulation
        self.aux_in_features = aux_in_features
        self.fc_hidden_features = fc_hidden_features
        self.fc_bias = fc_bias

        if modulation:
            if aux_in_features is None:
                raise ValueError("Value for `aux_in_features` can't be None with `modulation`=True.")
            if fc_hidden_features is None:
                raise ValueError("Value for `fc_hidden_features` can't be None with `modulation`=True.")
            # MLP mapping `y` to one scalar per (out_channel, in_channel) pair.
            self.fc = nn.Sequential(
                nn.Linear(aux_in_features, fc_hidden_features, bias=fc_bias),
                nn.PReLU(),
                nn.Linear(fc_hidden_features, in_channels * out_channels, bias=fc_bias),
            )

        if bias == ModConv2dBias.PARAM:
            self.bias = nn.Parameter(torch.FloatTensor(out_channels).uniform_(-k, k))
        elif bias == ModConv2dBias.LEARNED:
            if not modulation:
                raise ValueError(
                    "Bias can only be set to ModConv2dBias.LEARNED if `modulation`=True, "
                    "but modulation is set to False."
                )
            # MLP mapping `y` to one bias value per output channel.
            self.bias = nn.Sequential(
                nn.Linear(aux_in_features, fc_hidden_features, bias=fc_bias),
                nn.PReLU(),
                nn.Linear(fc_hidden_features, out_channels, bias=fc_bias),
            )
        else:
            self.bias = None

    def __repr__(self):
        """Representation of :class:`ModConv2d`."""
        # NOTE: `bias=` interpolates the stored bias object itself (Parameter, Sequential or None),
        # not the ModConv2dBias mode that was passed to the constructor.
        return (
            f"ModConv2d(in_channels={self.in_channels}, out_channels={self.out_channels}, "
            f"kernel_size={self.kernel_size}, modulation={self.modulation}, "
            f"stride={self.stride}, padding={self.padding}, "
            f"dilation={self.dilation}, bias={self.bias}, aux_in_features={self.aux_in_features}, "
            f"fc_hidden_features={self.fc_hidden_features}, fc_bias={self.fc_bias})"
        )

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass of :class:`ModConv2d`.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (N, `in_channels`, H, W).
        y : torch.Tensor, optional
            Auxiliary variable of shape (N, `aux_in_features`) to be used if `modulation` is set to True. Default: None

        Returns
        -------
        torch.Tensor
            Output tensor of shape (N, `out_channels`, H_out, W_out).

        Raises
        ------
        ValueError
            If `modulation` is True and `y` is None.
        """
        if self.modulation and y is None:
            # Fail early with a clear message instead of an opaque error inside nn.Linear.
            raise ValueError("Auxiliary input `y` can't be None with `modulation`=True.")

        if not self.modulation:
            out = F.conv2d(x, self.weight, stride=self.stride, padding=self.padding, dilation=self.dilation)
        else:
            # One (out_channels, in_channels, 1, 1) scaling tensor per sample; it broadcasts over
            # the spatial kernel dimensions of `self.weight`.
            fc_out = self.fc(y).view(x.shape[0], self.out_channels, self.in_channels, 1, 1)
            # Every sample has its own effective kernel, so convolve sample by sample.
            out = torch.cat(
                [
                    F.conv2d(
                        x[i : i + 1],
                        fc_out[i] * self.weight,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                    )
                    for i in range(x.shape[0])
                ],
                0,
            )

        if self.bias is not None:
            if isinstance(self.bias, nn.parameter.Parameter):
                # Shared bias: broadcast over batch and spatial dimensions.
                bias = self.bias.view(1, -1, 1, 1)
            else:
                # Learned bias: one vector per sample, predicted from `y`.
                bias = self.bias(y).view(x.shape[0], -1, 1, 1)
            out = out + bias

        return out


class ModConvTranspose2d(nn.Module):
    """Modulated ConvTranspose2d module.

    Transposed-convolution counterpart of the modulated convolution: the weight is stored with
    shape (`in_channels`, `out_channels`, kH, kW) and applied with `F.conv_transpose2d`.

    If `modulation`=False the stored weight is applied directly. If `modulation`=True, an MLP maps
    the auxiliary variable `y` to one scalar per (input channel, output channel) pair for every
    sample, and the weight is multiplied by these scalars before the transposed convolution is
    applied to that sample.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: IntOrTuple,
        modulation: bool = False,
        stride: IntOrTuple = 1,
        padding: IntOrTuple = 0,
        dilation: IntOrTuple = 1,
        bias: ModConv2dBias = ModConv2dBias.PARAM,
        aux_in_features: Optional[int] = None,
        fc_hidden_features: Optional[int] = None,
        fc_bias: Optional[bool] = True,
    ):
        """Inits :class:`ModConvTranspose2d`.

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        kernel_size : int or tuple of int
            Size of the convolutional kernel.
        modulation : bool, optional
            If True, apply modulation using an MLP on the auxiliary variable `y`, by default False.
        stride : int or tuple of int, optional
            Stride of the convolution, by default 1.
        padding : int or tuple of int, optional
            Padding added to all sides of the input, by default 0.
        dilation : int or tuple of int, optional
            Spacing between kernel elements, by default 1.
        bias : ModConv2dBias, optional
            Type of bias, by default ModConv2dBias.PARAM.
        aux_in_features : int, optional
            Number of features in the auxiliary input variable `y`, by default None.
        fc_hidden_features : int, optional
            Number of hidden features in the modulation MLP, by default None.
        fc_bias : bool, optional
            If True, enable bias in the modulation MLP, by default True.

        Raises
        ------
        ValueError
            If `modulation` is True and `aux_in_features` or `fc_hidden_features` is None, or if
            `bias` is ModConv2dBias.LEARNED while `modulation` is False.
        """
        super().__init__()

        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # Uniform init bound computed from in_channels * kernel area; it is reused below for the
        # PARAM bias. The weight is sampled before any other parameter is created so that seeded
        # initialisation stays reproducible.
        k = math.sqrt(1 / (in_channels * self.kernel_size[0] * self.kernel_size[1]))
        # Transposed-convolution weight layout: (in_channels, out_channels, kH, kW).
        self.weight = nn.Parameter(
            torch.FloatTensor(in_channels, out_channels, self.kernel_size[0], self.kernel_size[1]).uniform_(-k, k)
        )

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.modulation = modulation
        self.aux_in_features = aux_in_features
        self.fc_hidden_features = fc_hidden_features
        self.fc_bias = fc_bias

        if modulation:
            if aux_in_features is None:
                raise ValueError("Value for `aux_in_features` can't be None with `modulation`=True.")
            if fc_hidden_features is None:
                raise ValueError("Value for `fc_hidden_features` can't be None with `modulation`=True.")
            # MLP mapping `y` to one scalar per (in_channel, out_channel) pair.
            self.fc = nn.Sequential(
                nn.Linear(aux_in_features, fc_hidden_features, bias=fc_bias),
                nn.PReLU(),
                nn.Linear(fc_hidden_features, in_channels * out_channels, bias=fc_bias),
            )

        if bias == ModConv2dBias.PARAM:
            self.bias = nn.Parameter(torch.FloatTensor(out_channels).uniform_(-k, k))
        elif bias == ModConv2dBias.LEARNED:
            if not modulation:
                raise ValueError(
                    "Bias can only be set to ModConv2dBias.LEARNED if `modulation`=True, "
                    "but modulation is set to False."
                )
            # MLP mapping `y` to one bias value per output channel.
            self.bias = nn.Sequential(
                nn.Linear(aux_in_features, fc_hidden_features, bias=fc_bias),
                nn.PReLU(),
                nn.Linear(fc_hidden_features, out_channels, bias=fc_bias),
            )
        else:
            self.bias = None

    def __repr__(self):
        """Representation of :class:`ModConvTranspose2d`."""
        # NOTE: `bias=` interpolates the stored bias object itself (Parameter, Sequential or None),
        # not the ModConv2dBias mode that was passed to the constructor.
        return (
            f"ModConvTranspose2d(in_channels={self.in_channels}, out_channels={self.out_channels}, "
            f"kernel_size={self.kernel_size}, modulation={self.modulation}, "
            f"stride={self.stride}, padding={self.padding}, "
            f"dilation={self.dilation}, bias={self.bias}, aux_in_features={self.aux_in_features}, "
            f"fc_hidden_features={self.fc_hidden_features}, fc_bias={self.fc_bias})"
        )

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass of :class:`ModConvTranspose2d`.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (N, `in_channels`, H, W).
        y : torch.Tensor, optional
            Auxiliary variable of shape (N, `aux_in_features`) to be used if `modulation` is set to True. Default: None

        Returns
        -------
        torch.Tensor
            Output tensor of shape (N, `out_channels`, H_out, W_out).

        Raises
        ------
        ValueError
            If `modulation` is True and `y` is None.
        """
        if self.modulation and y is None:
            # Fail early with a clear message instead of an opaque error inside nn.Linear.
            raise ValueError("Auxiliary input `y` can't be None with `modulation`=True.")

        if not self.modulation:
            out = F.conv_transpose2d(x, self.weight, stride=self.stride, padding=self.padding, dilation=self.dilation)
        else:
            # One (in_channels, out_channels, 1, 1) scaling tensor per sample, matching the
            # transposed weight layout; it broadcasts over the spatial kernel dimensions.
            fc_out = self.fc(y).view(x.shape[0], self.in_channels, self.out_channels, 1, 1)
            # Every sample has its own effective kernel, so process the batch sample by sample.
            out = torch.cat(
                [
                    F.conv_transpose2d(
                        x[i : i + 1],
                        fc_out[i] * self.weight,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                    )
                    for i in range(x.shape[0])
                ],
                0,
            )

        if self.bias is not None:
            if isinstance(self.bias, nn.parameter.Parameter):
                # Shared bias: broadcast over batch and spatial dimensions.
                bias = self.bias.view(1, -1, 1, 1)
            else:
                # Learned bias: one vector per sample, predicted from `y`.
                bias = self.bias(y).view(x.shape[0], -1, 1, 1)
            out = out + bias

        return out
4 changes: 3 additions & 1 deletion direct/nn/get_nn_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def _get_model_config(
"num_filters": kwargs.get("unet_num_filters", 32),
"num_pool_layers": kwargs.get("unet_num_pool_layers", 4),
"dropout_probability": kwargs.get("unet_dropout", 0.0),
"cwn_conv": kwargs.get("cwn_conv", False),
"modulation": kwargs.get("modulation", False),
"aux_in_features": kwargs.get("aux_in_features", None),
"fc_hidden_features": kwargs.get("fc_hidden_features", None),
}
)
elif model_architecture_name == "resnet":
Expand Down
11 changes: 6 additions & 5 deletions direct/nn/unet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ class UnetModel2dConfig(ModelConfig):
num_filters: int = 16
num_pool_layers: int = 4
dropout_probability: float = 0.0
cwn_conv: bool = False
modulation: bool = False
aux_in_features: Optional[int] = None
fc_hidden_features: Optional[int] = None


class NormUnetModel2dConfig(ModelConfig):
Expand All @@ -22,7 +24,9 @@ class NormUnetModel2dConfig(ModelConfig):
num_pool_layers: int = 4
dropout_probability: float = 0.0
norm_groups: int = 2
cwn_conv: bool = False
modulation: bool = False
aux_in_features: Optional[int] = None
fc_hidden_features: Optional[int] = None


@dataclass
Expand All @@ -32,7 +36,6 @@ class UnetModel3dConfig(ModelConfig):
num_filters: int = 16
num_pool_layers: int = 4
dropout_probability: float = 0.0
cwn_conv: bool = False


class NormUnetModel3dConfig(ModelConfig):
Expand All @@ -42,15 +45,13 @@ class NormUnetModel3dConfig(ModelConfig):
num_pool_layers: int = 4
dropout_probability: float = 0.0
norm_groups: int = 2
cwn_conv: bool = False


@dataclass
class Unet2dConfig(ModelConfig):
num_filters: int = 16
num_pool_layers: int = 4
dropout_probability: float = 0.0
cwn_conv: bool = False
skip_connection: bool = False
normalized: bool = False
image_initialization: str = "zero_filled"
Loading

0 comments on commit 3d74ee3

Please sign in to comment.