-
Notifications
You must be signed in to change notification settings - Fork 3
/
resnet_from_scratch.py
355 lines (280 loc) · 10.4 KB
/
resnet_from_scratch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
import functools
import os

import numpy as np
import tensorflow
from tensorflow.keras import Model
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import (Activation, Add, BatchNormalization,
                                     Conv2D, Dense, Flatten,
                                     GlobalAveragePooling2D, Input, Lambda)
from tensorflow.keras.optimizers import schedules, SGD
@functools.lru_cache(maxsize=None)
def _training_set_size():
    """
    Return the number of samples in the CIFAR-10 training split.

    The value is cached after the first call: model_configuration() is
    invoked many times while the model is built (once per residual block,
    among others), and without the cache every call would reload the whole
    dataset only to measure its length.
    """
    (input_train, _), (_, _) = load_dataset()
    return len(input_train)


def model_configuration():
    """
    Get configuration variables for the model.

    A new dict (with new loss, optimizer, initializer and callback objects)
    is built on every call; only the training-set size is cached.

    Returns:
        dict: configuration values keyed by name (see the literal below).
        "steps_per_epoch", "val_steps_per_epoch", "num_epochs" and
        "training_ds_size" are plain Python ints, as Model.fit expects.
    """
    # Dataset size (loaded once, then cached)
    num_train_samples = _training_set_size()
    # Generic config
    width, height, channels = 32, 32, 3
    batch_size = 128
    num_classes = 10
    validation_split = 0.1  # 45k/5k train/validation split per the He et al. paper
    verbose = 1
    n = 3
    init_fm_dim = 16
    shortcut_type = "identity"  # or: projection
    # Split sizes, computed the same way ImageDataGenerator.flow(subset=...)
    # splits the data: the first int(N * validation_split) samples form the
    # validation subset and the remainder forms the training subset.
    val_size = int(num_train_samples * validation_split)
    train_size = num_train_samples - val_size
    # Number of steps per epoch is dependent on batch size. Integer division
    # yields the same floored values as before, but as Python ints rather
    # than TF tensors.
    maximum_number_iterations = 64000  # per the He et al. paper
    steps_per_epoch = train_size // batch_size
    val_steps_per_epoch = val_size // batch_size
    epochs = maximum_number_iterations // steps_per_epoch
    # Define loss function (the model outputs logits, see from_logits)
    loss = tensorflow.keras.losses.CategoricalCrossentropy(from_logits=True)
    # Learning rate config per the He et al. paper
    boundaries = [32000, 48000]
    values = [0.1, 0.01, 0.001]
    lr_schedule = schedules.PiecewiseConstantDecay(boundaries, values)
    # Set layer init
    initializer = tensorflow.keras.initializers.HeNormal()
    # Define optimizer
    optimizer_momentum = 0.9
    optimizer_additional_metrics = ["accuracy"]
    optimizer = SGD(learning_rate=lr_schedule, momentum=optimizer_momentum)
    # Load Tensorboard callback
    tensorboard = TensorBoard(
        log_dir=os.path.join(os.getcwd(), "logs"),
        histogram_freq=1,
        write_images=True
    )
    # Save a model checkpoint after every epoch
    checkpoint = ModelCheckpoint(
        os.path.join(os.getcwd(), "model_checkpoint"),
        save_freq="epoch"
    )
    # Add callbacks to list
    callbacks = [
        tensorboard,
        checkpoint
    ]
    # Create config dictionary
    config = {
        "width": width,
        "height": height,
        "dim": channels,
        "batch_size": batch_size,
        "num_classes": num_classes,
        "validation_split": validation_split,
        "verbose": verbose,
        "stack_n": n,
        "initial_num_feature_maps": init_fm_dim,
        "training_ds_size": train_size,
        "steps_per_epoch": steps_per_epoch,
        "val_steps_per_epoch": val_steps_per_epoch,
        "num_epochs": epochs,
        "loss": loss,
        "optim": optimizer,
        "optim_learning_rate_schedule": lr_schedule,
        "optim_momentum": optimizer_momentum,
        "optim_additional_metrics": optimizer_additional_metrics,
        "initializer": initializer,
        "callbacks": callbacks,
        "shortcut_type": shortcut_type
    }
    return config
def load_dataset():
    """
    Fetch CIFAR-10 through Keras and hand back its (train, test) pairs.

    Each pair is (images, labels), exactly as returned by
    cifar10.load_data().
    """
    train_and_test = cifar10.load_data()
    return train_and_test
def random_crop(img, random_crop_size):
    """
    Cut a randomly positioned window out of a single channels-last image.

    SOURCE: https://jkjung-avt.github.io/keras-image-cropping/

    Args:
        img: array of shape (height, width, 3).
        random_crop_size: (crop_height, crop_width); each must not exceed
            the corresponding image dimension.

    Returns:
        A slice of ``img`` with shape (crop_height, crop_width, 3).

    Raises:
        ValueError: if ``img`` is not a 3-channel (height, width, channels)
            image, or if the crop is larger than the image.
    """
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if len(img.shape) != 3 or img.shape[2] != 3:
        raise ValueError(
            f"random_crop expects a (height, width, 3) image, got shape {img.shape}")
    height, width = img.shape[0], img.shape[1]
    dy, dx = random_crop_size
    if dy > height or dx > width:
        raise ValueError(
            f"crop size {(dy, dx)} exceeds image size {(height, width)}")
    # Draw x before y, as before, so seeded runs reproduce the same crops.
    x = np.random.randint(0, width - dx + 1)
    y = np.random.randint(0, height - dy + 1)
    return img[y:(y + dy), x:(x + dx), :]
def crop_generator(batches, crop_length):
    """Yield randomly cropped versions of the batches produced by ``batches``.

    Takes as input a Keras ImageGen (Iterator), or any iterable of
    ``(batch_x, batch_y)`` pairs, and crops every image of each batch to a
    random ``crop_length`` x ``crop_length`` window via random_crop(),
    passing the targets through unchanged.
    SOURCE: https://jkjung-avt.github.io/keras-image-cropping/

    If ``batches`` is finite, iteration ends cleanly when it is exhausted.
    The previous ``next()`` call inside a generator turned that into a
    RuntimeError (PEP 479).
    """
    for batch_x, batch_y in batches:
        # Match the input's dtype and channel count instead of always
        # allocating a 3-channel float64 buffer.
        batch_crops = np.empty(
            (batch_x.shape[0], crop_length, crop_length, batch_x.shape[-1]),
            dtype=batch_x.dtype)
        for i, image in enumerate(batch_x):
            batch_crops[i] = random_crop(image, (crop_length, crop_length))
        yield (batch_crops, batch_y)
def preprocessed_dataset():
    """
    Load and preprocess the CIFAR-10 dataset.

    Training batches are augmented: 4-pixel zero padding, random crops back
    to the configured size, and random horizontal flips. Validation and test
    batches get the same normalization but no augmentation, so their
    metrics are measured on the original images.

    Returns:
        (train_batches, validation_batches, test_batches)
    """
    (input_train, target_train), (input_test, target_test) = load_dataset()
    # Retrieve sizes and split settings from the model configuration
    config = model_configuration()
    crop_length = config.get("height")
    num_classes = config.get("num_classes")
    batch_size = config.get("batch_size")
    validation_split = config.get("validation_split")
    preprocess_input = tensorflow.keras.applications.resnet50.preprocess_input
    # Convert scalar targets to categorical ones
    target_train = tensorflow.keras.utils.to_categorical(target_train, num_classes)
    target_test = tensorflow.keras.utils.to_categorical(target_test, num_classes)
    # Data augmentation (training only): zero-pad so images can be randomly
    # cropped back to crop_length x crop_length
    paddings = tensorflow.constant([[0, 0], [4, 4], [4, 4], [0, 0]])
    padded_train = tensorflow.pad(input_train, paddings, mode="CONSTANT")
    # Training batches: flips + random crops
    train_generator = tensorflow.keras.preprocessing.image.ImageDataGenerator(
        validation_split=validation_split,
        horizontal_flip=True,
        rescale=1./255,
        preprocessing_function=preprocess_input
    )
    train_batches = train_generator.flow(padded_train, target_train,
                                         batch_size=batch_size, subset="training")
    train_batches = crop_generator(train_batches, crop_length)
    # Validation batches: the same validation_split selects the same held-out
    # samples, but from the unpadded images and without flips or crops.
    validation_generator = tensorflow.keras.preprocessing.image.ImageDataGenerator(
        validation_split=validation_split,
        rescale=1./255,
        preprocessing_function=preprocess_input
    )
    validation_batches = validation_generator.flow(input_train, target_train,
                                                   batch_size=batch_size,
                                                   subset="validation")
    # Data generator for testing data (normalization only)
    test_generator = tensorflow.keras.preprocessing.image.ImageDataGenerator(
        preprocessing_function=preprocess_input,
        rescale=1./255)
    test_batches = test_generator.flow(input_test, target_test, batch_size=batch_size)
    return train_batches, validation_batches, test_batches
def residual_block(x, number_of_filters, match_filter_size=False):
    """
    Residual block with two 3x3 convolutions and a shortcut connection.

    Computes relu(F(x) + shortcut(x)), where F is
    Conv3x3 -> BatchNorm -> ReLU -> Conv3x3 -> BatchNorm.

    Args:
        x: input tensor (channels-last; BatchNormalization uses axis=3).
        number_of_filters: number of output feature maps of the block.
        match_filter_size: when True, the first convolution uses stride 2
            and the shortcut is adapted to the halved spatial size and new
            filter count according to config "shortcut_type":
            "identity" subsamples and zero-pads the channel axis by
            number_of_filters // 4 on each side (so it assumes the input has
            number_of_filters // 2 channels); "projection" uses a strided
            1x1 convolution.

    Returns:
        Output tensor of the block.

    Raises:
        ValueError: if match_filter_size is True and shortcut_type is
            neither "identity" nor "projection".
    """
    config = model_configuration()
    base_initializer = config.get("initializer")
    shortcut_type = config.get("shortcut_type")

    def fresh_initializer():
        # One initializer instance per layer: recent Keras versions warn that
        # calling a single unseeded initializer instance several times can
        # return identical values.
        return base_initializer.__class__.from_config(base_initializer.get_config())

    if match_filter_size and shortcut_type not in ("identity", "projection"):
        # Without an adapted shortcut the Add below would fail on a shape
        # mismatch; report the real cause instead.
        raise ValueError(
            f"unknown shortcut_type {shortcut_type!r}; "
            "expected 'identity' or 'projection'")
    # Create skip connection
    x_skip = x
    # Perform the original mapping (downsample in the first conv if needed)
    strides = (2, 2) if match_filter_size else (1, 1)
    x = Conv2D(number_of_filters, kernel_size=(3, 3), strides=strides,
               kernel_initializer=fresh_initializer(), padding="same")(x_skip)
    x = BatchNormalization(axis=3)(x)
    x = Activation("relu")(x)
    x = Conv2D(number_of_filters, kernel_size=(3, 3),
               kernel_initializer=fresh_initializer(), padding="same")(x)
    x = BatchNormalization(axis=3)(x)
    # Perform matching of filter numbers if necessary
    if match_filter_size and shortcut_type == "identity":
        pad = number_of_filters // 4
        x_skip = Lambda(lambda t: tensorflow.pad(
            t[:, ::2, ::2, :],
            tensorflow.constant([[0, 0], [0, 0], [0, 0], [pad, pad]]),
            mode="CONSTANT"))(x_skip)
    elif match_filter_size and shortcut_type == "projection":
        x_skip = Conv2D(number_of_filters, kernel_size=(1, 1),
                        kernel_initializer=fresh_initializer(),
                        strides=(2, 2))(x_skip)
    # Add the skip connection to the regular mapping
    x = Add()([x, x_skip])
    # Nonlinearly activate the result
    x = Activation("relu")(x)
    return x
def ResidualBlocks(x):
    """
    Stack all residual blocks of the network on top of ``x``.

    There are three groups of ``stack_n`` blocks each. The first group keeps
    ``initial_num_feature_maps`` filters. Every later group doubles the
    filter count in its first block, which is built with
    ``match_filter_size=True`` (strided, with an adapted shortcut).
    """
    config = model_configuration()
    num_filters = config.get("initial_num_feature_maps")
    blocks_per_group = config.get("stack_n")
    # Paper: "Then we use a stack of 6n layers (...) with 2n layers for each
    # feature map size." -> 3 groups. Each block here holds 2 weighted
    # layers, so each group consists of n blocks.
    for group_index in range(3):
        for block_index in range(blocks_per_group):
            starts_new_group = group_index > 0 and block_index == 0
            if starts_new_group:
                num_filters *= 2
            x = residual_block(x, num_filters, match_filter_size=starts_new_group)
    return x
def model_base(shp):
    """
    Base structure of the model, with residual blocks attached.

    Args:
        shp: input shape tuple; init_model passes (width, height, dim).

    Returns:
        (inputs, outputs): the Input tensor and the logits tensor. No softmax
        is applied because the configured loss uses from_logits=True.
    """
    # Read the configuration once (the original called it twice, which
    # rebuilt the whole config a second time just for the initializer).
    config = model_configuration()
    base_initializer = config.get("initializer")

    def fresh_initializer():
        # Separate initializer instance per layer: recent Keras versions warn
        # when one unseeded initializer instance is called multiple times.
        return base_initializer.__class__.from_config(base_initializer.get_config())

    inputs = Input(shape=shp)
    x = Conv2D(config.get("initial_num_feature_maps"), kernel_size=(3, 3),
               strides=(1, 1), kernel_initializer=fresh_initializer(),
               padding="same")(inputs)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = ResidualBlocks(x)
    x = GlobalAveragePooling2D()(x)
    # No-op: global pooling already yields (batch, channels). Kept so the
    # layer stack is unchanged.
    x = Flatten()(x)
    outputs = Dense(config.get("num_classes"),
                    kernel_initializer=fresh_initializer())(x)
    return inputs, outputs
def init_model():
    """
    Build the ResNet, compile it with the configured loss, optimizer and
    metrics, print its summary, and return it.
    """
    config = model_configuration()
    input_shape = (config.get("width"), config.get("height"), config.get("dim"))
    inputs, outputs = model_base(input_shape)
    # NOTE: the visible model_configuration() defines no "name" key, so this
    # passes name=None and Keras chooses a default model name.
    model = Model(inputs, outputs, name=config.get("name"))
    compile_kwargs = {
        "loss": config.get("loss"),
        "optimizer": config.get("optim"),
        "metrics": config.get("optim_additional_metrics"),
    }
    model.compile(**compile_kwargs)
    model.summary()
    return model
def train_model(model, train_batches, validation_batches):
    """
    Train an initialized model.

    Args:
        model: compiled Keras model (e.g. from init_model()).
        train_batches: iterable of (images, targets) batches, as produced by
            preprocessed_dataset().
        validation_batches: batches evaluated after every epoch.

    Returns:
        The same model object, trained in place by Model.fit.
    """
    config = model_configuration()
    # batch_size is deliberately not passed: the inputs already yield whole
    # batches, and Keras documents that batch_size must not be specified for
    # generators / Sequence inputs.
    # int(): Keras expects integer epochs/steps. This also tolerates config
    # values stored as scalar tensors.
    model.fit(train_batches,
              epochs=int(config.get("num_epochs")),
              verbose=config.get("verbose"),
              callbacks=config.get("callbacks"),
              steps_per_epoch=int(config.get("steps_per_epoch")),
              validation_data=validation_batches,
              validation_steps=int(config.get("val_steps_per_epoch")))
    return model
def evaluate_model(model, test_batches):
    """
    Print the test-set loss and accuracy of a trained model.

    Uses the first two values returned by model.evaluate(): the loss, then
    the first compiled metric ("accuracy" in this script's configuration).
    """
    results = model.evaluate(test_batches, verbose=0)
    test_loss = results[0]
    test_accuracy = results[1]
    print(f'Test loss: {test_loss} / Test accuracy: {test_accuracy}')
def training_process():
    """
    End-to-end run: prepare the data, build the ResNet, train it, and then
    report its performance on the test set.
    """
    train_data, val_data, test_data = preprocessed_dataset()
    network = init_model()
    network = train_model(network, train_data, val_data)
    evaluate_model(network, test_data)
if __name__ == "__main__":
    # Run the full pipeline only when executed as a script, not on import.
    training_process()