-
Notifications
You must be signed in to change notification settings - Fork 0
/
InceptionNet.py
121 lines (106 loc) · 4.59 KB
/
InceptionNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D,BatchNormalization,Activation,MaxPool2D,Dropout,Flatten,Dense
from tensorflow.keras import Model
# Disable NumPy's summarisation ("...") so arrays converted with str() are
# rendered in full.
np.set_printoptions(threshold=np.inf)

# Load the CIFAR-10 train/test split through the Keras dataset helper.
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Divide the image arrays by 255.0; the result is floating point.
# NOTE(review): presumably the raw pixels are 8-bit (0-255), which would put
# the scaled values in [0, 1] - confirm against the dataset documentation.
x_train = x_train / 255.0
x_test = x_test / 255.0
class ConvBNRelu(Model):
    """Convolution -> batch normalisation -> ReLU, bundled as one unit.

    Args:
        ch: Number of convolution filters.
        kernelsz: Convolution kernel size.
        strides: Sliding step of the convolution.
        padding: Padding mode passed to ``Conv2D``.
    """

    def __init__(self, ch, kernelsz=3, strides=1, padding='same'):
        # The keyword arguments have defaults and may be omitted by callers;
        # ``strides`` is the sliding step of the convolution window.
        super().__init__()
        conv = Conv2D(ch, kernelsz, strides=strides, padding=padding)
        self.model = tf.keras.models.Sequential(
            [conv, BatchNormalization(), Activation('relu')]
        )

    def call(self, x):
        """Run ``x`` through the conv/BN/ReLU stack and return the result."""
        return self.model(x)
class InceptionBlk(Model):
    """Inception unit: four parallel branches concatenated along axis 3.

    Branches (each ends in a ``ConvBNRelu`` with ``ch`` filters):
        1. 1x1 conv
        2. 1x1 conv -> 3x3 conv
        3. 1x1 conv -> 5x5 conv
        4. 3x3 max-pool (stride 1) -> 1x1 conv

    ``strides`` is applied by the 1x1 convolution of every branch, so all
    four outputs share the same spatial size before concatenation.

    Args:
        ch: Number of filters produced by each branch.
        strides: Stride used by the 1x1 convolution in each branch.
    """

    def __init__(self, ch, strides=1):
        super().__init__()
        self.ch = ch
        self.strides = strides

        # Branch 1: a single 1x1 convolution.
        self.c1 = ConvBNRelu(ch, kernelsz=1, strides=strides)

        # Branch 2: 1x1 convolution followed by a 3x3 convolution.
        self.c2_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
        self.c2_2 = ConvBNRelu(ch, kernelsz=3, strides=1)

        # Branch 3: 1x1 convolution followed by a 5x5 convolution.
        self.c3_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
        self.c3_2 = ConvBNRelu(ch, kernelsz=5, strides=1)

        # Branch 4: size-preserving max-pool, then a 1x1 convolution.
        self.p4_1 = MaxPool2D(3, strides=1, padding='same')
        self.c4_2 = ConvBNRelu(ch, kernelsz=1, strides=strides)

    def call(self, x):
        """Apply the four branches to ``x`` and concatenate their outputs."""
        branch_1x1 = self.c1(x)
        branch_3x3 = self.c2_2(self.c2_1(x))
        branch_5x5 = self.c3_2(self.c3_1(x))
        branch_pool = self.c4_2(self.p4_1(x))
        # Axis 3 is the channel axis for channels-last (NHWC) input.
        return tf.concat(
            [branch_1x1, branch_3x3, branch_5x5, branch_pool], axis=3
        )
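# ---------------------------------------
# The two classes above only define reusable building-block units.
# These units still have to be stacked to form the complete network.
# Note: no Flatten layer is used below; the network applies
# GlobalAveragePooling2D before the Dense classifier instead.
# ---------------------------------------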
class Inception10(Model):
    """Small Inception-style classifier built from ``InceptionBlk`` units.

    Layout: an initial ``ConvBNRelu`` -> ``num_blocks`` blocks of two
    ``InceptionBlk`` units each (the first with stride 2, the second with
    stride 1) -> global average pooling -> softmax ``Dense`` layer.

    Args:
        num_blocks: Number of two-unit blocks to stack.
        num_classes: Number of units in the final softmax layer.
        init_ch: Filter count of the initial convolution and of the first
            block's branches; doubled after every block.
        **kwargs: Forwarded to ``tf.keras.Model.__init__``.
    """

    def __init__(self, num_blocks, num_classes, init_ch=16, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = init_ch
        self.out_channels = init_ch
        self.num_blocks = num_blocks
        self.init_ch = init_ch

        self.c1 = ConvBNRelu(init_ch)

        self.blocks = tf.keras.models.Sequential()
        for _ in range(num_blocks):
            # Each block holds two units: stride 2 first, then stride 1.
            for unit_stride in (2, 1):
                self.blocks.add(
                    InceptionBlk(self.out_channels, strides=unit_stride)
                )
            # Double the per-branch filter count for the next block.
            self.out_channels *= 2

        self.p1 = tf.keras.layers.GlobalAveragePooling2D()
        self.f1 = Dense(num_classes, activation='softmax')

    def call(self, x):
        """Return the softmax output of the final Dense layer for ``x``."""
        features = self.blocks(self.c1(x))
        return self.f1(self.p1(features))
# ---------------------------------------------------------------
# Build and compile the network. The model's last layer already applies
# softmax, so the loss is configured with from_logits=False.
model = Inception10(num_blocks=2, num_classes=10)
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)

checkpoint_save_path = "./checkpoint/Baseline.ckpt"
# Resuming from an existing checkpoint is currently disabled:
# if os.path.exists (checkpoint_save_path + '.index'):
# print('-------------load the model-----------')
# model.load_weights(checkpoint_save_path)

# Save only the weights, and only when the monitored metric improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

# Train for 2 epochs, validating on the test split after every epoch.
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=2,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)
model.summary()

# Dump every trainable variable (name, shape, values) to a text file.
# The context manager guarantees the file is closed even if a write fails.
with open('./Lenet.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')

# Per-epoch training/validation curves recorded by fit().
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
# plt.subplot(1,2,1)
# plt.plot(acc,label='Training Accuracy')
# plt.plot(val_acc,label='Validation Accuracy')
# plt.title('Training and Validation Accuracy')
# plt.legend()
# plt.subplot(1,2,2)
# plt.plot(loss,label='Training loss')
# plt.plot(val_loss,label='Validation Loss')
# plt.title('Training and Validation Loss')
# plt.legend()
# plt.show()