import torch
import torch.nn as nn
from positional_norm import PositionalNorm


class DenseBlock(nn.Module):
    """DenseBlock as described in https://arxiv.org/abs/1608.06993

    The dense block consists of multiple convolutional layers, each having the
    same number of output channels. However, each layer takes as input the
    original input concatenated with the outputs of all previous layers from
    the block.
    """

    def __init__(self, in_chan, out_chan, n_layers):
        """Init a DenseBlock.

        Args:
            in_chan: int
                Number of channels of the input tensor.
            out_chan: int
                Number of output channels for each of the conv layers.
            n_layers: int
                Number of conv layers in the dense block.
        """
        super().__init__()
        # Each layer of the dense block consists of a `1x1` convolution followed
        # by a `3x3` convolution. Each convolution is preceded by a normalization
        # layer and a non-linear activation, i.e. the so-called "pre-activation"
        # structure. See: https://arxiv.org/abs/1603.05027
        # Note that the number of input channels grows with every layer in the
        # dense block, while the number of output channels remains fixed. We use
        # the `1x1` convolution in order to reduce the number of channels before
        # passing through the expensive `3x3` layer.
        # In the paper the authors state that the `1x1` convolution does not
        # reduce the channels directly to `out_chan`, but rather uses a
        # bottleneck size of `4 x out_chan`.
        # (A small shape-check sketch after this class illustrates how the
        # channel count grows.)
        layers = []
        for _ in range(n_layers):
            layers.append(nn.Sequential(
                PositionalNorm(in_chan),
                nn.ReLU(),
                nn.Conv2d(in_chan, 4*out_chan, kernel_size=1),
                PositionalNorm(4*out_chan),
                nn.ReLU(),
                nn.Conv2d(4*out_chan, out_chan, kernel_size=3, padding="same"),
            ))
            in_chan += out_chan
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        out = x
        for layer in self.layers:
            y = layer(out)
            out = torch.cat((out, y), dim=1)
        return out
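

# Illustrative shape-check sketch (added; not part of the original module).
# Each of the `n_layers` layers appends `out_chan` channels to its input, so a
# dense block maps (N, in_chan, H, W) to (N, in_chan + n_layers * out_chan, H, W),
# while `padding="same"` keeps the spatial size unchanged. The concrete numbers
# below are arbitrary, and the check assumes the `PositionalNorm` import above
# resolves to a normalization layer taking the channel count as its argument.
def _demo_dense_block():
    block = DenseBlock(in_chan=8, out_chan=16, n_layers=3)
    x = torch.randn(2, 8, 32, 32)
    y = block(x)
    # 8 input channels + 3 layers * 16 channels each = 56 output channels.
    assert y.shape == (2, 8 + 3 * 16, 32, 32)
    return y.shape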


class DenseNet(nn.Module):
    """DenseNet as described in https://arxiv.org/abs/1608.06993"""

    def __init__(self, in_shape, config):
        """Init a DenseNet model.

        Args:
            in_shape: tuple(int)
                The shape of the input tensors. Images should be shaped
                channels first, i.e. in_shape = (C, H, W).
            config: dict
                Dictionary with configuration parameters, containing:
                stem_chan: int
                    Number of feature channels in the stem of the model.
                modules: list(tuple(int, int))
                    Each tuple contains two values:
                    The first value is the number of convolutional layers in
                    the dense block.
                    The second value is the number of filters for each of the
                    convolutional layers.
                out_classes: int
                    Number of output classes.
        """
        super().__init__()
        C, H, W = in_shape
        stem_chan = config["stem_chan"]
        modules = config["modules"]
        out_classes = config["out_classes"]

        # The body of the DenseNet consists of stacking multiple dense blocks.
        # Since each dense block increases the number of channels, a so-called
        # transition layer is added between dense blocks to control the
        # complexity. The transition layer reduces the number of channels by
        # using a `1x1` convolution. It also halves the spatial dimensions using
        # average pooling with stride 2.
        # (A small shape-check sketch after this class walks through this
        # channel and resolution bookkeeping.)
        body = []
        in_chan = stem_chan
        for i, (n_layers, out_chan) in enumerate(modules):
            # Dense block.
            body.append(DenseBlock(in_chan, out_chan, n_layers))
            chan = in_chan + n_layers * out_chan
            # Transition layer.
            if i < len(modules)-1:
                body.append(nn.Sequential(
                    PositionalNorm(chan),
                    nn.ReLU(),
                    nn.Conv2d(chan, chan//2, kernel_size=1),   # C => C // 2
                    nn.AvgPool2d(kernel_size=2, stride=2),     # (H, W) => (H//2, W//2)
                ))
                in_chan = chan//2
        body.append(PositionalNorm(chan))
        body.append(nn.ReLU())
        body = nn.Sequential(*body)

        self.net = nn.Sequential(
            # Stem.
            nn.Conv2d(in_channels=C, out_channels=stem_chan, kernel_size=3, padding="same"),
            PositionalNorm(stem_chan),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # (H, W) => (H//2, W//2)

            # Body.
            body,

            # Head.
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(chan, out_classes),
        )

    def forward(self, x):
        return self.net(x)
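

# Illustrative shape-check sketch (added; not part of the original module).
# The stem max-pool halves the spatial resolution once, and every transition
# layer halves it again while also halving the channel count; only the last
# dense block is not followed by a transition. With the hypothetical two-module
# config below, a 32x32 input reaches the head at 8x8 with 32 channels, and the
# head reduces it to `out_classes` logits. The config values are arbitrary.
def _demo_densenet_shapes():
    net = DenseNet(
        in_shape=(3, 32, 32),
        config={"stem_chan": 16, "modules": [(2, 8), (2, 8)], "out_classes": 10},
    )
    logits = net(torch.randn(4, 3, 32, 32))
    assert logits.shape == (4, 10)
    return logits.shape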


# Default DenseNet model for CIFAR-10.
DenseNet_CIFAR10 = DenseNet(
    in_shape=(3, 32, 32),
    config={
        "stem_chan": 64,
        "modules": [(4, 16), (4, 16), (4, 16), (4, 16)],
        "out_classes": 10,
    },
)
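

# Minimal usage sketch (added; not part of the original module), guarded behind
# `__main__` so it only runs when the file is executed directly. It calls the
# demos above and pushes a dummy CIFAR-10 sized batch through the default model.
if __name__ == "__main__":
    print(_demo_dense_block())         # expected: torch.Size([2, 56, 32, 32])
    print(_demo_densenet_shapes())     # expected: torch.Size([4, 10])
    dummy = torch.randn(8, 3, 32, 32)  # a batch of 8 fake CIFAR-10 sized images
    print(DenseNet_CIFAR10(dummy).shape)  # expected: torch.Size([8, 10])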