-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvgg_pretrained.py
90 lines (76 loc) · 2.27 KB
/
vgg_pretrained.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from pathlib import Path

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Plant selection and dataset paths.
# flow_from_directory infers labels from sub-directory names, so each path is
# expected to contain one sub-directory per class.
plant = 'Apple'
dataset_train_path = Path('data', 'images', plant)
dataset_valid_path = Path('data', 'valid', plant)

# Load dataset.
# The VGG16 ImageNet weights were trained on inputs transformed by
# vgg16.preprocess_input (RGB->BGR plus per-channel mean subtraction). Feeding
# raw 0-255 RGB pixels to the frozen convolutional base degrades its features,
# so the same transform is applied here. NOTE: any inference code using the
# saved model must apply preprocess_input as well.
image_size = (200, 200)
traindata = ImageDataGenerator(
    preprocessing_function=preprocess_input,
).flow_from_directory(directory=dataset_train_path, target_size=image_size)
testdata = ImageDataGenerator(
    preprocessing_function=preprocess_input,
).flow_from_directory(directory=dataset_valid_path, target_size=image_size)
# Model definition: frozen ImageNet-pretrained VGG16 convolutional base plus a
# new softmax classification head.
vgg = VGG16(
    # Use the exact shape the training iterator yields (target_size + channels)
    # instead of repeating the image size by hand.
    input_shape=traindata.image_shape,
    weights='imagenet',
    include_top=False,
)
# Freeze the pretrained base so only the new head is trained.
for layer in vgg.layers:
    layer.trainable = False

x = Flatten()(vgg.output)
# One output unit per class sub-directory discovered by flow_from_directory.
# A hard-coded unit count breaks categorical cross-entropy whenever the
# dataset's number of classes differs from it.
outputs = Dense(traindata.num_classes, activation='softmax')(x)
model = Model(inputs=vgg.input, outputs=outputs)

opt = Adam(learning_rate=1e-4)
model.compile(
    optimizer=opt,
    # flow_from_directory defaults to class_mode='categorical' (one-hot labels),
    # which matches categorical cross-entropy.
    loss=tf.keras.losses.categorical_crossentropy,
    metrics=['accuracy'],
)
model.summary()
# Train.
# Keep the model with the best validation accuracy. save_freq must be 'epoch':
# an integer save_freq counts *batches*, and val_accuracy is only computed at
# the end of an epoch. A per-batch checkpoint with save_best_only therefore
# never sees the monitored value and never writes the file.
checkpoint = ModelCheckpoint(
    f'{plant}_pretrained_vgg16.h5',
    monitor='val_accuracy',
    verbose=1,
    save_best_only=True,
    mode='max',
    save_freq='epoch',
)
# Stop once validation accuracy has not improved for 20 consecutive epochs.
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    patience=20,
    verbose=1,
    mode='max',
)
callbacks = [checkpoint, early_stopping]

# No batch_size argument: the directory iterators already yield batches
# (flow_from_directory defaults to batch_size=32). Keras documents that
# batch_size must not be passed for generator / Sequence inputs.
# steps_per_epoch=10 caps each "epoch" at 10 training batches.
hist = model.fit(
    traindata,
    validation_data=testdata,
    steps_per_epoch=10,
    epochs=100,
    verbose=1,
    callbacks=callbacks,
)
# Plot training curves.
def _plot_history(history, metric, title, ylabel, filename):
    """Plot train/validation curves for *metric* from a Keras History dict.

    Draws ``history[metric]`` and ``history['val_' + metric]`` on a fresh
    figure, saves the figure to *filename*, then shows it.
    """
    # Start a new figure for every plot. With a non-interactive backend
    # plt.show() does not close the current figure, so without this the second
    # plot would be drawn on top of the first one.
    plt.figure()
    plt.plot(history[metric])
    plt.plot(history['val_' + metric])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Val'], loc='upper left')
    plt.savefig(filename)
    plt.show()


_plot_history(hist.history, 'accuracy', 'Model Accuracy', 'Accuracy', 'accuracy.png')
_plot_history(hist.history, 'loss', 'Model Loss', 'Loss', 'loss.png')