Skip to content

Commit eaef5d9

Browse files
committed
feat(dropout): 实现随机失活操作,修改ThreeNet和LeNet5实现随机失活操作
1 parent ded9602 commit eaef5d9

File tree

2 files changed

+87
-26
lines changed

2 files changed

+87
-26
lines changed

nn/functional.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# @Time : 19-6-7 下午2:46
4+
# @Author : zj
5+
6+
import numpy as np
7+
8+
9+
def dropout(shape, p):
10+
assert len(shape) == 2
11+
return (np.random.ranf(shape) < p) / p
12+
13+
14+
def dropout2d(shape, p):
15+
assert len(shape) == 4
16+
N, C, H, W = shape[:4]
17+
U = (np.random.rand(N, C) < p) / p
18+
res = np.ones(shape)
19+
for i in range(N):
20+
for j in range(C):
21+
res[i, j] *= U[i, j]
22+
23+
return res
24+
25+
26+
if __name__ == '__main__':
27+
res = dropout((3, 4), 0.5)
28+
# res = dropout2d((1, 4, 2, 2), 0.5)
29+
print(res)

nn/nets.py

+58-26
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,8 @@
33
# @Time : 19-5-27 下午1:34
44
# @Author : zj
55

6-
import numpy as np
7-
from abc import ABCMeta, abstractmethod
8-
from nn.im2row import *
9-
from nn.pool2row import *
10-
from nn.layer_utils import *
116
from nn.layers import *
7+
import nn.functional as F
128

139

1410
class Net(metaclass=ABCMeta):
@@ -82,14 +78,17 @@ class ThreeLayerNet(Net):
8278
实现3层神经网络
8379
"""
8480

85-
def __init__(self, num_in, num_h_one, num_h_two, num_out, momentum=0, nesterov=False, p_h=1.0, ):
81+
def __init__(self, num_in, num_h_one, num_h_two, num_out, momentum=0, nesterov=False, p_h=1.0):
8682
super(ThreeLayerNet, self).__init__()
8783
self.fc1 = FC(num_in, num_h_one, momentum=momentum, nesterov=nesterov)
8884
self.relu1 = ReLU()
8985
self.fc2 = FC(num_h_one, num_h_two, momentum=momentum, nesterov=nesterov)
9086
self.relu2 = ReLU()
9187
self.fc3 = FC(num_h_two, num_out, momentum=momentum, nesterov=nesterov)
88+
9289
self.p_h = p_h
90+
self.U1 = None
91+
self.U2 = None
9392

9493
def __call__(self, inputs):
9594
return self.forward(inputs)
@@ -98,21 +97,21 @@ def forward(self, inputs):
9897
# inputs.shape = [N, D_in]
9998
assert len(inputs.shape) == 2
10099
a1 = self.relu1(self.fc1(inputs))
101-
U1 = np.random.ranf(a1.shape) < self.p_h
102-
a1 *= U1
100+
self.U1 = F.dropout(a1.shape, self.p_h)
101+
a1 *= self.U1
103102

104103
a2 = self.relu2(self.fc2(a1))
105-
U2 = np.random.ranf(a2.shape) < self.p_h
106-
a2 *= U2
104+
self.U2 = F.dropout(a2.shape, self.p_h)
105+
a2 *= self.U2
107106

108107
z3 = self.fc3(a2)
109108

110109
return z3
111110

112111
def backward(self, grad_out):
113-
da2 = self.fc3.backward(grad_out)
112+
da2 = self.fc3.backward(grad_out) * self.U2
114113
dz2 = self.relu2.backward(da2)
115-
da1 = self.fc2.backward(dz2)
114+
da1 = self.fc2.backward(dz2) * self.U1
116115
dz1 = self.relu1.backward(da1)
117116
da0 = self.fc1.backward(dz1)
118117

@@ -125,11 +124,7 @@ def predict(self, inputs):
125124
# inputs.shape = [N, D_in]
126125
assert len(inputs.shape) == 2
127126
a1 = self.relu1(self.fc1(inputs))
128-
a1 *= self.p_h
129-
130127
a2 = self.relu2(self.fc2(a1))
131-
a2 *= self.p_h
132-
133128
z3 = self.fc3(a2)
134129

135130
return z3
@@ -142,68 +137,86 @@ def set_params(self, params):
142137
self.fc1.set_params(params['fc1'])
143138
self.fc2.set_params(params['fc2'])
144139
self.fc3.set_params(params['fc3'])
145-
self.p_h = params['p_h']
140+
self.p_h = params.get('p_h', 1.0)
146141

147142

148143
class LeNet5(Net):
149144
"""
150145
LeNet-5网络
151146
"""
152147

153-
def __init__(self, momentum=0):
148+
def __init__(self, momentum=0, nesterov=False, p_h=1.0):
154149
super(LeNet5, self).__init__()
155-
self.conv1 = Conv2d(1, 5, 5, 6, stride=1, padding=0, momentum=momentum)
156-
self.conv2 = Conv2d(6, 5, 5, 16, stride=1, padding=0, momentum=momentum)
157-
self.conv3 = Conv2d(16, 5, 5, 120, stride=1, padding=0, momentum=momentum)
150+
self.conv1 = Conv2d(1, 5, 5, 6, stride=1, padding=0, momentum=momentum, nesterov=nesterov)
151+
self.conv2 = Conv2d(6, 5, 5, 16, stride=1, padding=0, momentum=momentum, nesterov=nesterov)
152+
self.conv3 = Conv2d(16, 5, 5, 120, stride=1, padding=0, momentum=momentum, nesterov=nesterov)
158153

159154
self.maxPool1 = MaxPool(2, 2, 6, stride=2)
160155
self.maxPool2 = MaxPool(2, 2, 16, stride=2)
161-
self.fc1 = FC(120, 84, momentum=momentum)
162-
self.fc2 = FC(84, 10, momentum=momentum)
156+
self.fc1 = FC(120, 84, momentum=momentum, nesterov=nesterov)
157+
self.fc2 = FC(84, 10, momentum=momentum, nesterov=nesterov)
163158

164159
self.relu1 = ReLU()
165160
self.relu2 = ReLU()
166161
self.relu3 = ReLU()
167162
self.relu4 = ReLU()
168163

164+
self.p_h = p_h
165+
self.U1 = None
166+
self.U2 = None
167+
self.U3 = None
168+
self.U4 = None
169+
169170
def __call__(self, inputs):
170171
return self.forward(inputs)
171172

172173
def forward(self, inputs):
173174
# inputs.shape = [N, C, H, W]
174175
assert len(inputs.shape) == 4
175176
x = self.relu1(self.conv1(inputs))
177+
self.U1 = F.dropout2d(x.shape, self.p_h)
178+
x *= self.U1
179+
176180
x = self.maxPool1(x)
177181
x = self.relu2(self.conv2(x))
182+
self.U2 = F.dropout2d(x.shape, self.p_h)
183+
x *= self.U2
184+
178185
x = self.maxPool2(x)
179186
x = self.relu3(self.conv3(x))
187+
self.U3 = F.dropout2d(x.shape, self.p_h)
188+
x *= self.U3
189+
180190
# (N, C, 1, 1) -> (N, C)
181191
x = x.reshape(x.shape[0], -1)
182192
x = self.relu4(self.fc1(x))
193+
self.U4 = F.dropout(x.shape, self.p_h)
194+
183195
x = self.fc2(x)
184196

185197
return x
186198

187199
def backward(self, grad_out):
188200
da6 = self.fc2.backward(grad_out)
201+
da6 *= self.U4
189202

190203
dz6 = self.relu4.backward(da6)
191204
da5 = self.fc1.backward(dz6)
192205
# [N, C] -> [N, C, 1, 1]
193206
N, C = da5.shape[:2]
194207
da5 = da5.reshape(N, C, 1, 1)
195-
208+
da5 *= self.U3
196209
dz5 = self.relu3.backward(da5)
197210
da4 = self.conv3.backward(dz5)
198211

199212
dz4 = self.maxPool2.backward(da4)
200-
213+
da4 *= self.U2
201214
dz3 = self.relu2.backward(dz4)
202215
da2 = self.conv2.backward(dz3)
203216

204217
da1 = self.maxPool1.backward(da2)
218+
da1 *= self.U1
205219
dz1 = self.relu1.backward(da1)
206-
207220
self.conv1.backward(dz1)
208221

209222
def update(self, lr=1e-3, reg=1e-3):
@@ -213,6 +226,21 @@ def update(self, lr=1e-3, reg=1e-3):
213226
self.conv2.update(learning_rate=lr, regularization_rate=reg)
214227
self.conv1.update(learning_rate=lr, regularization_rate=reg)
215228

229+
def predict(self, inputs):
230+
# inputs.shape = [N, C, H, W]
231+
assert len(inputs.shape) == 4
232+
x = self.relu1(self.conv1(inputs))
233+
x = self.maxPool1(x)
234+
x = self.relu2(self.conv2(x))
235+
x = self.maxPool2(x)
236+
x = self.relu3(self.conv3(x))
237+
# (N, C, 1, 1) -> (N, C)
238+
x = x.reshape(x.shape[0], -1)
239+
x = self.relu4(self.fc1(x))
240+
x = self.fc2(x)
241+
242+
return x
243+
216244
def get_params(self):
217245
out = dict()
218246
out['conv1'] = self.conv1.get_params()
@@ -222,6 +250,8 @@ def get_params(self):
222250
out['fc1'] = self.fc1.get_params()
223251
out['fc2'] = self.fc2.get_params()
224252

253+
out['p_h'] = self.p_h
254+
225255
return out
226256

227257
def set_params(self, params):
@@ -231,3 +261,5 @@ def set_params(self, params):
231261

232262
self.fc1.set_params(params['fc1'])
233263
self.fc2.set_params(params['fc2'])
264+
265+
self.p_h = params.get('p_h', 1.0)

0 commit comments

Comments
 (0)