-
Notifications
You must be signed in to change notification settings - Fork 2
/
trainFromExisting.py
49 lines (37 loc) · 1.3 KB
/
trainFromExisting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from fastai.vision.all import *
from pathlib import Path
def label_func(x):
    """Label an image by the folder it lives in.

    Assumes a one-folder-per-class layout: the class of file *x* is the
    name of its immediate parent directory.
    """
    containing_dir = x.parent
    return containing_dir.name
# Backbone architecture and number of training epochs for this run.
arch = resnet18
epochs = 1
# Full string form of the architecture object (presumably
# "<function resnet18 at 0x...>"); kept for any code that still reads it.
archname = str(arch)
# Short architecture name used in the exported model filename.
# The previous ``archname[10:len(archname) - 23]`` slice only produced the right
# name when the repr's hex address was exactly 16 digits (typical on 64-bit
# Windows); elsewhere it truncated the name. ``__name__`` is platform-independent.
name = getattr(arch, '__name__', type(arch).__name__)
def train(path='data_all', bs=16, item_tfms=None):
    """Build the DataLoaders and a pretrained CNN learner for the image dataset.

    Image files under *path* are gathered with ``get_image_files`` and
    labelled by ``label_func`` (their parent folder name), so the expected
    layout is ``<path>/<class_name>/<image>``.

    Args:
        path: Root folder of the dataset. Defaults to ``data_all`` relative to
            the current working directory. The previous backslash-separated
            literal only resolved correctly on Windows.
        bs: Batch size for the DataLoaders.
        item_tfms: Optional per-item fastai transform(s), e.g. a ``Resize``,
            forwarded to ``ImageDataLoaders.from_path_func``. This is a
            non-destructive alternative to resizing the image files on disk.
            Without it, the images presumably must already share one size to
            be batched.

    Returns:
        A fastai learner built from the module-level ``arch`` with pretrained
        weights and an ``accuracy`` metric (not yet trained on this data).

    Raises:
        FileNotFoundError: if no image files are found under *path*.
    """
    image_files = get_image_files(path)
    print(f"Total Images:{len(image_files)}")
    if len(image_files) == 0:
        # Fail early with a clear message instead of an opaque error from fastai.
        raise FileNotFoundError(f"No image files found under {Path(path).resolve()}")
    # Validation split / seed are left at fastai's defaults, as before.
    dls = ImageDataLoaders.from_path_func(
        path, image_files, label_func, bs=bs, item_tfms=item_tfms
    )
    learn = cnn_learner(dls, arch, pretrained=True, metrics=accuracy)
    print("Loaded")
    return learn
if __name__ == '__main__':
    learner = train()
    learner.fit(epochs)

    # Export next to this script, e.g. jobData/model_<arch>-<epochs>.pkl.
    modelpath = Path(__file__).parent.resolve() / f"jobData/model_{name}-{epochs}.pkl"
    # Create jobData up front so the export cannot fail on a missing folder.
    modelpath.parent.mkdir(parents=True, exist_ok=True)
    learner.export(modelpath)

    # Diagnostics: sample predictions, confusion matrix, worst-predicted images.
    learner.show_results()
    interp = ClassificationInterpretation.from_learner(learner)
    interp.plot_confusion_matrix(figsize=(6, 6))
    interp.plot_top_losses(20, figsize=(10, 10))
    try:
        interp.print_classification_report()
    except Exception as err:
        # The report is optional: keep the run going, but say why it was skipped
        # (a bare ``except: pass`` also swallowed KeyboardInterrupt silently).
        print(f"Classification report unavailable: {err!r}")

    # Outside an interactive session, matplotlib only displays the figures
    # created above once plt.show() is called.
    import matplotlib.pyplot as plt
    plt.show()

    # Alternatives to the plain fit above:
    # learner.lr_find()
    # learner.fine_tune(2, 3e-3)