# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base model configuration for CNN benchmarks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import tensorflow as tf
import convnet_builder
import mlperf
# BuildNetworkResult encapsulates the result (e.g. logits) of a
# Model.build_network() call.
BuildNetworkResult = namedtuple(
'BuildNetworkResult',
[
'logits', # logits of the network
'extra_info', # Model specific extra information
])
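# A typical Model.build_network() implementation returns, for example,
# BuildNetworkResult(logits=logits, extra_info=None); extra_info can carry
# model-specific tensors such as the auxiliary logits produced by
# CNNModel.build_network() below.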
class Model(object):
"""Base model config for DNN benchmarks."""
def __init__(self,
model_name,
batch_size,
learning_rate,
fp16_loss_scale,
params=None):
self.model_name = model_name
self.batch_size = batch_size
self.default_batch_size = batch_size
self.learning_rate = learning_rate
# TODO(reedwm) Set custom loss scales for each model instead of using the
# default of 128.
self.fp16_loss_scale = fp16_loss_scale
# use_tf_layers specifies whether to build the model using tf.layers.
# fp16_vars specifies whether to create the variables in float16.
if params:
self.use_tf_layers = params.use_tf_layers
self.fp16_vars = params.fp16_vars
self.data_type = tf.float16 if params.use_fp16 else tf.float32
else:
self.use_tf_layers = True
self.fp16_vars = False
self.data_type = tf.float32
def get_model_name(self):
return self.model_name
def get_batch_size(self):
return self.batch_size
def set_batch_size(self, batch_size):
self.batch_size = batch_size
def get_default_batch_size(self):
return self.default_batch_size
def get_fp16_loss_scale(self):
return self.fp16_loss_scale
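  # Loss scaling compensates for the limited dynamic range of float16
  # gradients. The scale is applied outside this class; a sketch of the usual
  # pattern (illustrative only, not the benchmark's exact implementation):
  #
  #   scale = model.get_fp16_loss_scale()
  #   grads = tf.gradients(loss * scale, tf.trainable_variables())
  #   grads = [g / scale for g in grads]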
def filter_l2_loss_vars(self, variables):
"""Filters out variables that the L2 loss should not be computed for.
By default, this filters out batch normalization variables and keeps all
other variables. This behavior can be overridden by subclasses.
Args:
variables: A list of the trainable variables.
Returns:
A list of variables that the L2 loss should be computed for.
"""
mlperf.logger.log(key=mlperf.tags.MODEL_EXCLUDE_BN_FROM_L2,
value=True)
return [v for v in variables if 'batchnorm' not in v.name]
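  # Illustrative sketch (an assumption, not part of the benchmarks): a subclass
  # that also wants to exclude biases from weight decay could override this as
  #
  #   def filter_l2_loss_vars(self, variables):
  #     return [v for v in variables
  #             if 'batchnorm' not in v.name and 'biases' not in v.name]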
def get_learning_rate(self, global_step, batch_size):
del global_step
del batch_size
return self.learning_rate
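  # Subclasses can override get_learning_rate() to implement a schedule.
  # Illustrative sketch using a TF1-style decay (an assumption, not the
  # schedule any benchmark model actually uses):
  #
  #   def get_learning_rate(self, global_step, batch_size):
  #     return tf.train.exponential_decay(
  #         self.learning_rate, global_step,
  #         decay_steps=10000, decay_rate=0.94, staircase=True)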
def get_input_shapes(self, subset):
"""Returns the list of expected shapes of all the inputs to this model."""
del subset
raise NotImplementedError('Must be implemented in derived classes')
def get_input_data_types(self, subset):
"""Returns the list of data types of all the inputs to this model."""
del subset
raise NotImplementedError('Must be implemented in derived classes')
def get_synthetic_inputs(self, input_name, nclass):
"""Returns the ops to generate synthetic inputs."""
raise NotImplementedError('Must be implemented in derived classes')
def build_network(self, inputs, phase_train, nclass):
"""Builds the forward pass of the model.
Args:
inputs: The list of inputs, including labels
phase_train: True during training. False during evaluation.
nclass: Number of classes that the inputs can belong to.
Returns:
A BuildNetworkResult which contains the logits and model-specific extra
information.
"""
raise NotImplementedError('Must be implemented in derived classes')
def loss_function(self, inputs, build_network_result):
"""Returns the op to measure the loss of the model.
Args:
inputs: the input list of the model.
build_network_result: a BuildNetworkResult returned by build_network().
Returns:
The loss tensor of the model.
"""
raise NotImplementedError('Must be implemented in derived classes')
# TODO(laigd): have accuracy_function() take build_network_result instead.
def accuracy_function(self, inputs, logits):
"""Returns the ops to measure the accuracy of the model."""
raise NotImplementedError('Must be implemented in derived classes')
  def postprocess(self, results):
    """Postprocesses results returned from the model in Python."""
    return results

  def reached_target(self):
    """Returns True if training should stop because the target is reached.

    Subclasses can override this to implement custom stopping criteria.
    """
    return False
class CNNModel(Model):
"""Base model configuration for CNN benchmarks."""
# TODO(laigd): reduce the number of parameters and read everything from
# params.
def __init__(self,
model,
image_size,
batch_size,
learning_rate,
layer_counts=None,
fp16_loss_scale=128,
params=None):
super(CNNModel, self).__init__(
model, batch_size, learning_rate, fp16_loss_scale,
params=params)
self.image_size = image_size
self.layer_counts = layer_counts
self.depth = 3
self.params = params
    self.data_format = params.data_format if params else 'NCHW'
    # Guard against params being None, matching the data_format default above.
    self.input_data_format = (params.input_data_format if params
                              else self.data_format)
def get_layer_counts(self):
return self.layer_counts
def skip_final_affine_layer(self):
"""Returns if the caller of this class should skip the final affine layer.
Normally, this class adds a final affine layer to the model after calling
    self.add_inference(), to generate the logits. If a subclass overrides this
method to return True, the caller should not add the final affine layer.
This is useful for tests.
"""
return False
  def add_backbone_saver(self):
    """Creates a tf.train.Saver as self.backbone_saver for loading backbone.

    A tf.train.Saver must be created and stored in self.backbone_saver before
    load_backbone_model is called, with a variable name mapping that loads
    variables from the checkpoint correctly into the current model.
    """
    raise NotImplementedError(self.get_model_name() +
                              ' does not have a backbone model.')

  def load_backbone_model(self, sess, backbone_model_path):
    """Loads variable values from a pre-trained backbone model.

    This should be used at the beginning of the training process for transfer
    learning models that start from checkpoints of base models.

    Args:
      sess: session to train the model.
      backbone_model_path: path to the backbone model checkpoint file.
    """
    del sess, backbone_model_path
    raise NotImplementedError(self.get_model_name() +
                              ' does not have a backbone model.')
def add_inference(self, cnn):
"""Adds the core layers of the CNN's forward pass.
This should build the forward pass layers, except for the initial transpose
of the images and the final Dense layer producing the logits. The layers
    should be built with the ConvNetBuilder `cnn`, so that when this function
    returns, `cnn.top_layer` and `cnn.top_size` refer to the last layer and the
    number of units of that layer, respectively.
Args:
cnn: A ConvNetBuilder to build the forward pass layers with.
"""
del cnn
raise NotImplementedError('Must be implemented in derived classes')
def get_input_data_types(self, subset):
"""Return data types of inputs for the specified subset."""
del subset # Same types for both 'train' and 'validation' subsets.
return [self.data_type, tf.int32]
def get_input_shapes(self, subset):
"""Return data shapes of inputs for the specified subset."""
del subset # Same shapes for both 'train' and 'validation' subsets.
# Each input is of shape [batch_size, height, width, depth]
# Each label is of shape [batch_size]
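    # For example, with image_size=224, batch_size=32, and the default depth
    # of 3, this returns [[32, 224, 224, 3], [32]].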
return [[self.batch_size, self.image_size, self.image_size, self.depth],
[self.batch_size]]
def get_synthetic_inputs(self, input_name, nclass):
# Synthetic input should be within [0, 255].
image_shape, label_shape = self.get_input_shapes('train')
inputs = tf.truncated_normal(
image_shape,
dtype=self.data_type,
mean=127,
stddev=60,
name=self.model_name + '_synthetic_inputs')
inputs = tf.contrib.framework.local_variable(inputs, name=input_name)
labels = tf.random_uniform(
label_shape,
minval=0,
maxval=nclass - 1,
dtype=tf.int32,
name=self.model_name + '_synthetic_labels')
return (inputs, labels)
def build_network(self,
inputs,
phase_train=True,
nclass=1001):
"""Returns logits from input images.
Args:
inputs: The input images and labels
phase_train: True during training. False during evaluation.
nclass: Number of classes that the images can belong to.
Returns:
A BuildNetworkResult which contains the logits and model-specific extra
information.
"""
images = inputs[0]
if self.data_format == 'NCHW' and self.input_data_format == 'NHWC':
images = tf.transpose(images, [0, 3, 1, 2])
elif self.data_format == 'NHWC' and self.input_data_format == 'NCHW':
images = tf.transpose(images, [0, 2, 3, 1])
else:
# No need to transpose since self.data_format == self.input_data_format
pass
var_type = tf.float32
if self.data_type == tf.float16 and self.fp16_vars:
var_type = tf.float16
network = convnet_builder.ConvNetBuilder(
images, self.depth, phase_train, self.use_tf_layers, self.data_format,
self.data_type, var_type)
with tf.variable_scope('cg', custom_getter=network.get_custom_getter()):
self.add_inference(network)
# Add the final fully-connected class layer
logits = (
network.affine(nclass, activation='linear')
if not self.skip_final_affine_layer() else network.top_layer)
mlperf.logger.log(key=mlperf.tags.MODEL_HP_FINAL_SHAPE,
value=logits.shape.as_list()[1:])
aux_logits = None
if network.aux_top_layer is not None:
with network.switch_to_aux_top_layer():
aux_logits = network.affine(nclass, activation='linear', stddev=0.001)
if self.data_type == tf.float16:
# TODO(reedwm): Determine if we should do this cast here.
logits = tf.cast(logits, tf.float32)
if aux_logits is not None:
aux_logits = tf.cast(aux_logits, tf.float32)
return BuildNetworkResult(
logits=logits, extra_info=None if aux_logits is None else aux_logits)
def loss_function(self, inputs, build_network_result):
"""Returns the op to measure the loss of the model."""
logits = build_network_result.logits
_, labels = inputs
# TODO(laigd): consider putting the aux logit in the Inception model,
# which could call super.loss_function twice, once with the normal logits
# and once with the aux logits.
aux_logits = build_network_result.extra_info
with tf.name_scope('xentropy'):
mlperf.logger.log(key=mlperf.tags.MODEL_HP_LOSS_FN, value=mlperf.tags.CCE)
cross_entropy = tf.losses.sparse_softmax_cross_entropy(
logits=logits, labels=labels)
loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
if aux_logits is not None:
with tf.name_scope('aux_xentropy'):
aux_cross_entropy = tf.losses.sparse_softmax_cross_entropy(
logits=aux_logits, labels=labels)
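        # The 0.4 weight below follows the weighting commonly used for
        # Inception-style auxiliary classifiers.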
aux_loss = 0.4 * tf.reduce_mean(aux_cross_entropy, name='aux_loss')
loss = tf.add_n([loss, aux_loss])
return loss
def accuracy_function(self, inputs, logits):
"""Returns the ops to measure the accuracy of the model."""
_, labels = inputs
top_1_op = tf.reduce_sum(
tf.cast(tf.nn.in_top_k(logits, labels, 1), self.data_type))
top_5_op = tf.reduce_sum(
tf.cast(tf.nn.in_top_k(logits, labels, 5), self.data_type))
return {'top_1_accuracy': top_1_op, 'top_5_accuracy': top_5_op}
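

# Illustrative sketch only (not used by the benchmarks): a minimal CNNModel
# subclass picks its hyperparameters in __init__ and builds its layers in
# add_inference() with the ConvNetBuilder it is given. The layer choices below
# are arbitrary.
class ExampleCNNModel(CNNModel):
  """Tiny example model demonstrating the CNNModel interface."""

  def __init__(self, params=None):
    super(ExampleCNNModel, self).__init__(
        'example', image_size=224, batch_size=32, learning_rate=0.005,
        params=params)

  def add_inference(self, cnn):
    # Two conv/pool blocks followed by a fully-connected layer; build_network()
    # adds the final affine layer that produces the logits.
    cnn.conv(64, 3, 3)
    cnn.mpool(2, 2)
    cnn.conv(128, 3, 3)
    cnn.mpool(2, 2)
    cnn.affine(256)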