diff --git a/models/resnet/resnet.py b/models/resnet/resnet.py index 5c6c4fc..4c57bc3 100644 --- a/models/resnet/resnet.py +++ b/models/resnet/resnet.py @@ -1,4 +1,3 @@ -import torch import torch.nn as nn import torch.nn.functional as F @@ -18,8 +17,7 @@ class ResNet(nn.Module): """ def __init__(self, in_channels, mid_channels, out_channels, num_blocks, kernel_size, padding): super(ResNet, self).__init__() - self.in_norm = nn.BatchNorm2d(in_channels, affine=False) - self.in_conv = WNConv2d(2 * in_channels, mid_channels, kernel_size, padding, bias=True) + self.in_conv = WNConv2d(in_channels, mid_channels, kernel_size, padding, bias=True) self.in_skip = WNConv2d(mid_channels, mid_channels, kernel_size=1, padding=0, bias=True) self.blocks = nn.ModuleList([ResidualBlock(mid_channels, mid_channels) @@ -31,8 +29,6 @@ def __init__(self, in_channels, mid_channels, out_channels, num_blocks, kernel_s self.out_conv = WNConv2d(mid_channels, out_channels, kernel_size=1, padding=0, bias=True) def forward(self, x): - x = self.in_norm(x) - x = F.relu(torch.cat((x, -x), dim=1)) x = self.in_conv(x) x_skip = self.in_skip(x)