forked from tensorpack/tensorpack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathshufflenet.py
executable file
·209 lines (172 loc) · 7.5 KB
/
shufflenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: shufflenet.py
import sys
import argparse
import numpy as np
import os
import cv2
import tensorflow as tf
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import logger, QueueInput, InputDesc, PlaceholderInput, TowerContext
from tensorpack.models import *
from tensorpack.callbacks import *
from tensorpack.train import *
from tensorpack.dataflow import imgaug
from tensorpack.tfutils import argscope, get_model_loader
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.utils.gpu import get_nr_gpu
from imagenet_utils import (
fbresnet_augmentor, get_imagenet_dataflow, ImageNetModel, GoogleNetResize)
# Total training batch size summed over all GPU towers; get_config() divides it
# (floor division) by the number of towers to get the per-tower batch size.
TOTAL_BATCH_SIZE = 256
@layer_register(log_shape=True)
def DepthConv(x, out_channel, kernel_shape, padding='SAME', stride=1,
              W_init=None, nl=tf.identity, data_format='NCHW'):
    """Depthwise 2D convolution layer built on ``tf.nn.depthwise_conv2d``.

    Each input channel is convolved with ``out_channel // in_channel`` of its
    own ``kernel_shape x kernel_shape`` filters.

    Args:
        x: 4D input tensor laid out according to ``data_format``. Its channel
            dimension must be statically known.
        out_channel (int): number of output channels; must be a multiple of
            the number of input channels.
        kernel_shape (int): spatial size of the square kernel.
        padding (str): padding mode forwarded to ``tf.nn.depthwise_conv2d``.
        stride (int): spatial stride, applied to both height and width.
        W_init: kernel initializer. Defaults to
            ``tf.contrib.layers.variance_scaling_initializer()``.
        nl: callable applied to the convolution result as
            ``nl(conv, name='output')``.
        data_format (str): 'NCHW' (default; the previous hard-coded layout)
            or 'NHWC'.

    Returns:
        Whatever ``nl`` returns.

    Raises:
        ValueError: on an unsupported ``data_format``, an unknown (or zero)
            input channel count, or an ``out_channel`` that is not a multiple
            of the input channel count.
    """
    if data_format == 'NCHW':
        channel_axis = 1
        strides = [1, 1, stride, stride]
    elif data_format == 'NHWC':
        channel_axis = 3
        strides = [1, stride, stride, 1]
    else:
        raise ValueError("DepthConv: unsupported data_format {!r}".format(data_format))

    in_channel = x.get_shape().as_list()[channel_axis]
    if not in_channel:
        raise ValueError("DepthConv: the input channel dimension must be statically known "
                         "and non-zero, got {}".format(in_channel))
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if out_channel % in_channel != 0:
        raise ValueError("DepthConv: out_channel ({}) must be a multiple of the number of "
                         "input channels ({})".format(out_channel, in_channel))
    channel_mult = out_channel // in_channel

    if W_init is None:
        W_init = tf.contrib.layers.variance_scaling_initializer()
    # Depthwise filters are [kernel_h, kernel_w, in_channel, channel_multiplier].
    filter_shape = [kernel_shape, kernel_shape, in_channel, channel_mult]
    W = tf.get_variable('W', filter_shape, initializer=W_init)
    conv = tf.nn.depthwise_conv2d(x, W, strides, padding=padding, data_format=data_format)
    return nl(conv, name='output')
@under_name_scope()
def channel_shuffle(l, group, data_format='NCHW'):
    """Interleave the channels of ``l`` across ``group`` groups (ShuffleNet).

    The channel axis is viewed as a ``[group, C // group]`` matrix, transposed
    and flattened back, so channel ``g * (C // group) + k`` moves to position
    ``k * group + g``.

    Args:
        l: 4D tensor laid out according to ``data_format``. Its channel and
            spatial dimensions must be statically known, because they are
            passed straight into ``tf.reshape``.
        group (int): number of groups; must evenly divide the channel count.
        data_format (str): 'NCHW' (default; the previous hard-coded layout)
            or 'NHWC'.

    Returns:
        A tensor with the same layout as ``l`` and its channels permuted.

    Raises:
        ValueError: on an unsupported ``data_format``, or when the channel
            count is unknown or not divisible by ``group``.
    """
    in_shape = l.get_shape().as_list()
    if data_format == 'NCHW':
        in_channel, spatial = in_shape[1], in_shape[2:]
    elif data_format == 'NHWC':
        in_channel, spatial = in_shape[3], in_shape[1:3]
    else:
        raise ValueError("channel_shuffle: unsupported data_format {!r}".format(data_format))

    # Without this check a non-divisible channel count surfaces as an opaque
    # element-count mismatch from tf.reshape.
    if in_channel is None or in_channel % group != 0:
        raise ValueError("channel_shuffle: channel count {} must be known and divisible "
                         "by group={}".format(in_channel, group))
    per_group = in_channel // group

    if data_format == 'NCHW':
        l = tf.reshape(l, [-1, group, per_group] + spatial)
        l = tf.transpose(l, [0, 2, 1, 3, 4])
        l = tf.reshape(l, [-1, in_channel] + spatial)
    else:
        l = tf.reshape(l, [-1] + spatial + [group, per_group])
        l = tf.transpose(l, [0, 1, 2, 4, 3])
        l = tf.reshape(l, [-1] + spatial + [in_channel])
    return l
def BN(x, name):
    """Batch-norm activation callback for Conv2D / DepthConv ``nl=`` arguments.

    DepthConv invokes its ``nl`` as ``nl(conv, name='output')``, so this helper
    must accept a ``name`` keyword. That name is intentionally ignored: the
    BatchNorm layer is always created under the name 'bn'.
    """
    del name  # unused on purpose; the layer name is fixed to 'bn'
    normalized = BatchNorm('bn', x)
    return normalized
class Model(ImageNetModel):
    """ShuffleNet (group=8) ImageNet classifier built on ``ImageNetModel``."""
    weight_decay = 4e-5

    def get_logits(self, image):
        """Build the ShuffleNet trunk and return the 1000-way logits.

        Args:
            image: input image tensor handed in by ``ImageNetModel``
                (presumably already preprocessed -- see imagenet_utils). It
                must be channels-first; see the data_format check below.

        Returns:
            Output of the final ``FullyConnected('linear', ..., 1000)`` layer.

        Raises:
            ValueError: if ``self.data_format`` is not channels-first.
        """
        # DepthConv and channel_shuffle are called below without a data_format
        # argument and treat axis 1 as the channel axis (as does
        # shufflenet_unit). A channels-last layout would silently build a wrong
        # network, so fail loudly instead.
        if self.data_format not in ('NCHW', 'channels_first'):
            raise ValueError("ShuffleNet only supports NCHW inputs, got data_format={!r}"
                             .format(self.data_format))

        group = 8
        # Width of the stem convolution. No stage outputs this many channels,
        # so shufflenet_unit uses it to recognize the very first unit.
        stem_channel = 16
        # (output channels, number of units) for stages group1, group2, group3.
        stages = [(224, 4), (416, 6), (832, 4)]

        def shufflenet_unit(l, out_channel, group, stride):
            """One ShuffleNet unit: residual add if stride == 1, else pool + concat."""
            in_channel = l.get_shape().as_list()[1]
            shortcut = l

            # We do not apply group convolution on the first pointwise layer
            # because the number of input channels is relatively small.
            first_split = group if in_channel != stem_channel else 1
            l = Conv2D('conv1', l, out_channel // 4, 1, split=first_split, nl=BNReLU)
            l = channel_shuffle(l, group)
            l = DepthConv('dconv', l, out_channel // 4, 3, nl=BN, stride=stride)
            # When downsampling, the concat with the pooled shortcut below
            # supplies the remaining in_channel channels.
            l = Conv2D('conv2', l,
                       out_channel if stride == 1 else out_channel - in_channel,
                       1, split=group, nl=BN)
            if stride == 1:     # unit (b)
                return tf.nn.relu(shortcut + l)
            # unit (c)
            shortcut = AvgPooling('avgpool', shortcut, 3, 2, padding='SAME')
            return tf.concat([shortcut, tf.nn.relu(l)], axis=1)

        with argscope([Conv2D, MaxPooling, AvgPooling, GlobalAvgPooling, BatchNorm],
                      data_format=self.data_format), \
                argscope(Conv2D, use_bias=False):
            l = Conv2D('conv1', image, stem_channel, 3, stride=2, nl=BNReLU)
            l = MaxPooling('pool1', l, 3, 2, padding='SAME')

            # Scope names group1..group3 / block0..blockN match the original
            # unrolled code, so checkpoint variable names are unchanged.
            for stage_idx, (channel, num_units) in enumerate(stages, start=1):
                with tf.variable_scope('group{}'.format(stage_idx)):
                    for i in range(num_units):
                        with tf.variable_scope('block{}'.format(i)):
                            # the first unit of each stage downsamples (stride 2)
                            l = shufflenet_unit(l, channel, group, 2 if i == 0 else 1)

            l = GlobalAvgPooling('gap', l)
            logits = FullyConnected('linear', l, 1000)
            return logits
def get_data(name, batch, data_dir=None):
    """Build the ILSVRC dataflow for one split.

    Args:
        name (str): dataset split. 'train' enables random training-time
            augmentation; any other value gets a deterministic resize (shortest
            edge 256, bicubic) followed by a 224x224 center crop.
        batch (int): batch size forwarded to ``get_imagenet_dataflow``.
        data_dir (str): ILSVRC dataset directory. When omitted, falls back to
            the module-level ``args.data`` (the ``--data`` CLI flag), which only
            exists when this file is run as a script.

    Returns:
        The dataflow returned by ``get_imagenet_dataflow``.

    Raises:
        ValueError: if no dataset directory is available.
    """
    if data_dir is None:
        # `args` is only defined under `__main__`; look it up defensively so
        # that importing this module and calling get_data() does not NameError.
        cli_args = globals().get('args')
        data_dir = getattr(cli_args, 'data', None)
    if not data_dir:
        raise ValueError("get_data: no ILSVRC dataset directory given (pass --data)")

    is_train = name == 'train'
    if is_train:
        # rgb-bgr conversion for the constants copied from fb.resnet.torch
        eigval = np.asarray([0.2175, 0.0188, 0.0045][::-1]) * 255.0
        eigvec = np.array(
            [[-0.5675, 0.7192, 0.4009],
             [-0.5808, -0.0045, -0.8140],
             [-0.5836, -0.6948, 0.4203]],
            dtype='float32')[::-1, ::-1]
        augmentors = [
            GoogleNetResize(crop_area_fraction=0.49),
            imgaug.RandomOrderAug(
                [imgaug.BrightnessScale((0.6, 1.4), clip=False),
                 imgaug.Contrast((0.6, 1.4), clip=False),
                 imgaug.Saturation(0.4, rgb=False),
                 imgaug.Lighting(0.1, eigval=eigval, eigvec=eigvec)]),
            imgaug.Flip(horiz=True),
        ]
    else:
        augmentors = [
            imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
            imgaug.CenterCrop((224, 224)),
        ]
    return get_imagenet_dataflow(data_dir, name, batch, augmentors)
def get_config(model, nr_tower):
    """Assemble the TrainConfig for training ``model`` on ``nr_tower`` towers.

    TOTAL_BATCH_SIZE is divided (floor division) among the towers. Validation
    reports top-1 / top-5 error, using a single InferenceRunner for one tower
    and a DataParallelInferenceRunner across all towers otherwise.
    """
    per_tower_batch = TOTAL_BATCH_SIZE // nr_tower
    logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, per_tower_batch))

    train_flow = get_data('train', per_tower_batch)
    val_flow = get_data('val', per_tower_batch)

    # 10x learning-rate decay at 30, 60 and 90 (epochs, per
    # ScheduledHyperParamSetter's default schedule unit).
    lr_schedule = [(0, 3e-1), (30, 3e-2), (60, 3e-3), (90, 3e-4)]
    callbacks = [
        ModelSaver(),
        ScheduledHyperParamSetter('learning_rate', lr_schedule),
        HumanHyperParamSetter('learning_rate'),
    ]

    validators = [
        ClassificationError('wrong-top1', 'val-error-top1'),
        ClassificationError('wrong-top5', 'val-error-top5'),
    ]
    if nr_tower != 1:
        # multi-GPU inference (with mandatory queue prefetch)
        inference = DataParallelInferenceRunner(val_flow, validators, list(range(nr_tower)))
    else:
        # single-GPU inference with queue prefetch
        inference = InferenceRunner(QueueInput(val_flow), validators)
    callbacks.append(inference)

    return TrainConfig(
        model=model,
        dataflow=train_flow,
        callbacks=callbacks,
        steps_per_epoch=5000,
        max_epoch=100,
    )
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--data', help='ILSVRC dataset dir')
    parser.add_argument('--flops', action='store_true', help='print flops and exit')
    # NOTE: `args` must remain a module-level global: get_data() looks up
    # the global `args.data`.
    args = parser.parse_args()
    if not args.flops and not args.data:
        # Training needs the dataset; fail here with a usage message instead of
        # deep inside the dataflow construction.
        parser.error('--data is required unless --flops is given')

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    model = Model()
    if args.flops:
        # manually build the graph with batch=1
        input_desc = [
            InputDesc(tf.float32, [1, 224, 224, 3], 'input'),
            InputDesc(tf.int32, [1], 'label')
        ]
        # Named `placeholder_input` so the builtin `input` is not shadowed.
        placeholder_input = PlaceholderInput()
        placeholder_input.setup(input_desc)
        with TowerContext('', is_training=True):
            model.build_graph(placeholder_input)

        tf.profiler.profile(
            tf.get_default_graph(),
            cmd='op',
            options=tf.profiler.ProfileOptionBuilder.float_operation())
    else:
        logger.set_logger_dir(os.path.join('train_log', 'shufflenet'))
        nr_tower = max(get_nr_gpu(), 1)
        config = get_config(model, nr_tower)
        launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_tower))