-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnew_baseling.py
71 lines (53 loc) · 2.16 KB
/
new_baseling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
'''
Description: Fine-tune the Keras ImageNet-pretrained ResNet50 baseline on the images under config.val.
version:
Author: QIU Yaowen
Date: 2022-03-20 20:45:58
LastEditors: Andy
LastEditTime: 2022-03-21 10:03:31
'''
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import numpy as np
import config
from tensorflow.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import preprocess_input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint,LearningRateScheduler,CSVLogger
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.metrics import TopKCategoricalAccuracy
# Learning-rate schedule settings shared by the optimizer and the
# LearningRateScheduler callback, so the two values can no longer drift apart.
INITIAL_LR = 1e-6     # learning rate the step schedule decays from
LR_DROP = 0.95        # multiplicative factor applied at every drop
EPOCHS_PER_DROP = 1   # number of epochs between successive drops


def step_decay(epoch):
    """Return the learning rate for ``epoch`` under a step-decay schedule.

    lr = INITIAL_LR * LR_DROP ** floor((1 + epoch) / EPOCHS_PER_DROP)

    The ``1 + epoch`` term treats epochs as 1-indexed, so with
    EPOCHS_PER_DROP == 1 the first epoch (epoch == 0) already runs at
    INITIAL_LR * LR_DROP rather than INITIAL_LR.

    The function deliberately takes a single positional argument. Keras'
    LearningRateScheduler first tries ``schedule(epoch, current_lr)`` and
    falls back to ``schedule(epoch)`` on TypeError. A second positional
    parameter would therefore silently receive the current learning rate.

    Args:
        epoch: zero-based epoch index supplied by Keras.

    Returns:
        The learning rate for that epoch, as a NumPy float64.
    """
    return INITIAL_LR * np.power(LR_DROP, np.floor((1 + epoch) / EPOCHS_PER_DROP))


if __name__ == '__main__':
    ######## Fixed parameters ########
    epochs = 5
    val_batch_size = 64
    target_size = (224, 224)   # (height, width) every image is resized to
    nums_of_classes = 1000     # size of ResNet50's default ImageNet output layer

    # `learning_rate` replaces the deprecated `lr` alias. Newer tf.keras
    # optimizers only warn about `lr` (and may not apply it) or reject it.
    optimizer = Adam(learning_rate=INITIAL_LR)
    model = ResNet50()  # Keras defaults: weights='imagenet', include_top=True
    model.compile(loss='categorical_crossentropy',
                  metrics=['acc', TopKCategoricalAccuracy(k=5)],
                  optimizer=optimizer)

    ######## Output directories ########
    # CSVLogger does not create parent directories, so make sure both output
    # folders exist before any callback or save call writes into them.
    os.makedirs('model', exist_ok=True)
    os.makedirs('logs', exist_ok=True)

    ######## Callbacks ########
    # save_best_only=False: a checkpoint is written after every epoch. The
    # filename embeds the epoch number and that epoch's training accuracy.
    filepath = "Best_model_{epoch:02d}_{acc:.4f}.hdf5"
    checkpoint = ModelCheckpoint(filepath='model/' + filepath, monitor='acc',
                                 verbose=1, save_best_only=False)
    lr_scheduler = LearningRateScheduler(step_decay)
    csv_logger = CSVLogger('logs/baseline.log')

    ######## Data ########
    val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
    val_generator = val_datagen.flow_from_directory(
        config.val,
        target_size=target_size,
        batch_size=val_batch_size,
        interpolation='bilinear')
    # One-hot labels have one column per sub-directory. Fail early with a clear
    # message instead of a shape mismatch inside fit().
    # NOTE(review): label indices follow sorted sub-directory names. They must
    # line up with ResNet50's ImageNet class order (presumably the folders are
    # WordNet IDs). Confirm for config.val.
    if val_generator.num_classes != nums_of_classes:
        raise ValueError(
            f"Expected {nums_of_classes} class sub-directories under "
            f"{config.val!r}, found {val_generator.num_classes}; the one-hot "
            "labels would not match ResNet50's output layer.")

    ######## Training ########
    # NOTE(review): the same generator is passed as both training data and
    # validation data, so the reported val_* metrics are not held-out. If a
    # separate training split exists in `config`, pass it as `x` instead.
    # Confirm this is intentional.
    # NOTE(review): `workers` / `max_queue_size` are tf.keras 2.x fit()
    # arguments; they were removed in Keras 3.
    history = model.fit(x=val_generator,
                        validation_data=val_generator,
                        epochs=epochs, verbose=1, max_queue_size=40,
                        workers=20,
                        callbacks=[checkpoint, lr_scheduler, csv_logger])

    ######## Save the final model ########
    model.save('model/Baseline_Model.h5')