diff --git a/Classification/cnns/mobilenet_v2_model.py b/Classification/cnns/mobilenet_v2_model.py index 09736ab..029d201 100644 --- a/Classification/cnns/mobilenet_v2_model.py +++ b/Classification/cnns/mobilenet_v2_model.py @@ -58,7 +58,11 @@ def _relu6(data, prefix): def mobilenet_unit(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, data_format="NCHW", if_act=True, use_bias=False, prefix=''): - conv = flow.layers.conv2d(inputs=data, filters=num_filter, kernel_size=kernel, strides=stride, padding=pad, data_format=data_format, dilation_rate=1, groups=num_group, activation=None, use_bias=use_bias, kernel_initializer=_get_initializer("weight"), bias_initializer=_get_initializer("bias"), kernel_regularizer=_get_regularizer("weight"), bias_regularizer=_get_regularizer("bias"), name=prefix) + conv = flow.layers.conv2d(inputs=data, filters=num_filter, kernel_size=kernel, strides=stride, + padding=pad, data_format=data_format, dilation_rate=1, groups=num_group, activation=None, + use_bias=use_bias, kernel_initializer=_get_initializer("weight"), + bias_initializer=_get_initializer("bias"), kernel_regularizer=_get_regularizer("weight"), + bias_regularizer=_get_regularizer("bias"), name=prefix) bn = _batch_norm(conv, axis=1, momentum=0.9, epsilon=1e-5, name='%s-BatchNorm'%prefix) if if_act: act = _relu6(bn, prefix) @@ -156,11 +160,9 @@ def __init__(self, data_wh, multiplier, **kargs): else: self.config_map=MNETV2_CONFIGS_MAP[(224, 224)] - def build_network(self, input_data, need_transpose, data_format, class_num=1000, prefix="", **configs): + def build_network(self, input_data, data_format, class_num=1000, prefix="", **configs): self.config_map.update(configs) - if need_transpose: - input_data = flow.transpose(input_data, name="transpose", perm=[0, 3, 1, 2]) first_c = int(round(self.config_map['firstconv_filter_num']*self.multiplier)) first_layer = mobilenet_unit( data=input_data, @@ -233,11 +235,13 @@ def build_network(self, input_data, need_transpose, data_format, class_num=1000, ) return fc - def __call__(self, input_data, need_transpose, class_num=1000, prefix = "", **configs): - sym = self.build_network(input_data, need_transpose, class_num=class_num, prefix=prefix, **configs) + def __call__(self, input_data, class_num=1000, prefix = "", **configs): + sym = self.build_network(input_data, class_num=class_num, prefix=prefix, **configs) return sym -def Mobilenet(input_data, trainable=True, need_transpose=False, training=True, data_format="NCHW", num_classes=1000, multiplier=1.0, prefix = ""): +def Mobilenet(input_data, trainable=True, training=True, channel_last=False, num_classes=1000, multiplier=1.0, prefix = ""): + assert channel_last==False, "Mobilenet does not support channel_last mode, set channel_last=False will be right!" + data_format="NHWC" if channel_last else "NCHW" mobilenetgen = MobileNetV2((224,224), multiplier=multiplier) - out = mobilenetgen(input_data, need_transpose, data_format=data_format, class_num=num_classes, prefix = "MobilenetV2") + out = mobilenetgen(input_data, data_format=data_format, class_num=num_classes, prefix = "MobilenetV2") return out