-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresnet.py
144 lines (123 loc) · 5.04 KB
/
resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
import torch.nn as nn
from positional_norm import PositionalNorm
class ResidualBlock(nn.Module):
    """Residual block as described in https://arxiv.org/abs/1512.03385

    Computes ``relu(F(x) + shortcut(x))``, where ``F`` is two `3x3`
    convolutions each followed by normalization, and the shortcut is either
    the identity or a `1x1` convolution when the shape must change.
    """

    def __init__(self, in_chan, out_chan, stride=1):
        """Init a ResidualBlock.

        Args:
            in_chan: int
                Number of channels of the input tensor.
            out_chan: int
                Number of channels of the output tensor.
            stride: int, optional
                Stride of the convolving kernel. The stride is applied only to
                the first `3x3` convolutional layer, as well as to the `1x1`
                convolutional layer. Default value is 1.
        """
        super().__init__()
        # The residual function is modelled using two `3x3` convolutional layers
        # with the same number of filters. Each convolution is followed by
        # normalization; the final ReLU is applied after the addition in forward.
        self.res = nn.Sequential(
            nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=stride, padding=1),
            PositionalNorm(out_chan),
            nn.ReLU(),
            nn.Conv2d(out_chan, out_chan, kernel_size=3, padding="same"),
            PositionalNorm(out_chan),
        )
        # If the residual block changes the number of channels or the spatial
        # size, the input is forwarded through a `1x1` convolutional layer to
        # transform it into the desired shape for the addition operation.
        # Use `nn.Identity` rather than a bare lambda so the shortcut is a
        # proper submodule: it appears in `repr`, survives pickling/`torch.save`
        # of the model object, and moves with `.to(device)`. It has no
        # parameters, so the state_dict is unaffected.
        if in_chan != out_chan or stride != 1:
            self.id = nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride)
        else:
            self.id = nn.Identity()

    def forward(self, x):
        # Post-activation residual: relu(F(x) + shortcut(x)).
        return torch.relu(self.res(x) + self.id(x))
class ResNet(nn.Module):
    """ResNet-18 as described in https://arxiv.org/abs/1512.03385"""

    def __init__(self, in_shape, config):
        """Init a ResNet model.

        Args:
            in_shape: tuple(int)
                The shape of the input tensors, channels first: (C, H, W).
            config: dict
                Dictionary with configuration parameters, containing:
                stem_chan: int
                    Number of feature channels in the stem of the model.
                modules: list(tuple(int, int))
                    Each tuple is (number of residual blocks, number of
                    filters). All blocks within one module share the same
                    number of filters.
                out_classes: int
                    Number of output classes.
        """
        super().__init__()
        C, _, _ = in_shape
        stem_chan = config["stem_chan"]
        out_classes = config["out_classes"]

        # Body: a sequence of modules of residual blocks. The first block of
        # each module adapts the channel count and (from the second module on)
        # halves the spatial dimensions with stride 2. The very first module
        # keeps stride 1 because the stem's max-pool already down-sampled.
        blocks = []
        chan = stem_chan
        cur_stride = 1
        for num_blocks, width in config["modules"]:
            blocks.append(ResidualBlock(chan, width, stride=cur_stride))  # (H, W) => (H//2, W//2)
            for _ in range(num_blocks - 1):
                blocks.append(ResidualBlock(width, width))
            chan = width
            cur_stride = 2

        self.net = nn.Sequential(
            # Stem: a 3x3 convolution followed by norm, ReLU and max-pool.
            nn.Conv2d(in_channels=C, out_channels=stem_chan, kernel_size=3, padding="same"),
            PositionalNorm(stem_chan),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # (H, W) => (H//2, W//2)
            # Body.
            nn.Sequential(*blocks),
            # Head: global average pooling into a single linear classifier.
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(chan, out_classes),
        )

    def forward(self, x):
        return self.net(x)
# Default ResNet models for CIFAR-10.
ResNet_CIFAR10 = ResNet(
    in_shape=(3, 32, 32),
    config={
        "stem_chan": 64,
        "modules": [(3, 16), (3, 32), (3, 64)],
        "out_classes": 10,
    },
)

# Standard ResNet-18 block layout (2-2-2-2), adapted to CIFAR-10 inputs.
ResNet_18 = ResNet(
    in_shape=(3, 32, 32),
    config={
        "stem_chan": 64,
        "modules": [(2, 64), (2, 128), (2, 256), (2, 512)],
        "out_classes": 10,
    },
)

# Standard ResNet-34 block layout (3-4-6-3), adapted to CIFAR-10 inputs.
ResNet_34 = ResNet(
    in_shape=(3, 32, 32),
    config={
        "stem_chan": 64,
        "modules": [(3, 64), (4, 128), (6, 256), (3, 512)],
        "out_classes": 10,
    },
)
#