-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain_classification_model.py
89 lines (63 loc) · 2.97 KB
/
train_classification_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from fastai.basics import *
from fastai.callback.all import *
from fastai.vision.all import *
from fastai.data.all import *
from fastai.vision.all import *
from fastai.callback.core import *
from functools import partial
import os
import pandas as pd
from time import time
import torch
def label_func(fname):
    """Derive the class label of an image from markers in its path.

    The full stringified path is searched (not only the file name):

    * contains ``'_non_mi_abnormality_'`` -> ``'1'``
    * otherwise contains ``'_mi_'``       -> ``'2'``
    * otherwise                           -> ``'0'``
    """
    path_text = str(fname)
    # Must be tested first: '_non_mi_abnormality_' itself contains '_mi_',
    # so the plain '_mi_' check below would also match these paths.
    if '_non_mi_abnormality_' in path_text:
        return '1'
    if '_mi_' in path_text:
        return '2'
    return '0'
# Whole-second timestamp taken once when this line runs; it is embedded in the
# names of this run's checkpoints, CSV log and details file.
session_id = int(time())
class CustomCallback(Callback):
    """Callback that saves the learner at the end of every epoch.

    Checkpoints are named ``model_session=<session_id>_epoch=<n>``, where
    ``n`` is this callback's own counter: it starts at 0 and is incremented
    once per ``after_epoch`` call.
    """

    def __init__(self):
        super().__init__()
        # Private running counter; only after_epoch advances it.
        self.epoch_idx = 0

    def after_epoch(self):
        """Save a checkpoint for the epoch that just ended, then advance the counter."""
        checkpoint_name = f'model_session={session_id}_epoch={self.epoch_idx}'
        self.learn.save(checkpoint_name)
        self.epoch_idx = self.epoch_idx + 1
# Recall / precision / F1 metric objects, one triple per `labels` selection.
# Each object receives its own fresh `labels` list.
# NOTE(review): the label values are presumably indices into the vocab
# ['0', '1', '2'] used by the DataBlock — confirm.

# labels=[0]
recall_0, precision_0, f1_score_0 = (
    RecallMulti(labels=[0]),
    PrecisionMulti(labels=[0]),
    F1ScoreMulti(labels=[0]),
)

# labels=[1]
recall_1, precision_1, f1_score_1 = (
    RecallMulti(labels=[1]),
    PrecisionMulti(labels=[1]),
    F1ScoreMulti(labels=[1]),
)

# labels=[2]
recall_2, precision_2, f1_score_2 = (
    RecallMulti(labels=[2]),
    PrecisionMulti(labels=[2]),
    F1ScoreMulti(labels=[2]),
)

# labels=[1, 2]
recall_1_2, precision_1_2, f1_score_1_2 = (
    RecallMulti(labels=[1, 2]),
    PrecisionMulti(labels=[1, 2]),
    F1ScoreMulti(labels=[1, 2]),
)

# labels=[0, 1, 2]
recall_total, precision_total, f1_score_total = (
    RecallMulti(labels=[0, 1, 2]),
    PrecisionMulti(labels=[0, 1, 2]),
    F1ScoreMulti(labels=[0, 1, 2]),
)
# ---------------------------------------------------------------------------
# Training script: build the dataloaders, fine-tune a ResNet-152 classifier,
# write the run settings to disk and export the trained learner.
# ---------------------------------------------------------------------------
path_train = 'datasets/ptb_v_classification_merged_aug/training'

# NOTE(review): a previous comment here read "removed the non-MI abnormality!",
# yet the vocab below still has three labels and label_func can still return
# '1' for '_non_mi_abnormality_' paths — confirm which is intended.
datablock = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock(vocab=['0', '1', '2'])),
    get_items=get_image_files,
    get_y=label_func,
    # Items whose path contains 'validation' go to the validation set.
    # NOTE(review): items are collected from path_train, which ends in
    # 'training' — confirm that some of those paths really contain
    # 'validation', otherwise the validation set would be empty.
    splitter=FuncSplitter(lambda path: 'validation' in str(path)),
)
batch_size = 12
dataloaders = datablock.dataloaders(path_train, bs=batch_size, verbose=True)

base_models = {'resnet152': models.resnet152}
base_model = 'resnet152'

# Build the loss object directly rather than eval()-ing a source string: the
# result is the same object, without executing dynamically assembled code.
# NOTE(review): targets come from MultiCategoryBlock and the metrics are the
# *Multi variants — confirm FocalLoss is the intended loss for this setup.
use_cuda = torch.cuda.is_available()
class_weights = torch.tensor([1.0, 1.0, 1.5])
if use_cuda:
    class_weights = class_weights.cuda()
loss_func = FocalLoss(weight=class_weights)

# Human-readable description of the loss, written to the session details
# file below. Keep it in sync with the construction above.
loss_pre_eval = 'FocalLoss(weight=torch.tensor([1.0, 1.0, 1.5])' + ('.cuda()' if use_cuda else '') + ')'

learn = vision_learner(
    dataloaders,
    base_models[base_model],
    metrics=[
        recall_0, precision_0, f1_score_0,
        recall_1, precision_1, f1_score_1,
        recall_2, precision_2, f1_score_2,
    ],
    cbs=[
        # append=True: fine_tune trains in two phases (frozen, then unfrozen),
        # each a separate fit. CSVLogger re-opens its file at the start of
        # every fit and truncates it unless append is set, which would drop
        # the frozen-phase rows. The file name is keyed by session_id, so
        # appending does not mix different runs. The header row is expected
        # to be written again at the start of the second phase.
        CSVLogger(fname=f'results_{session_id}.csv', append=True),
        CustomCallback(),
    ],
    loss_func=loss_func,
)

lr = 0.0003  # based on learn.lr_find()
num_epochs = 15

# Record the run settings both on stdout and in a per-session text file.
info_str = '\n'.join([
    f'session: {session_id}',
    f'base_model: {base_model}',
    f'lr: {lr}',
    f'batch_size: {batch_size}',
    f'num_epochs: {num_epochs}',
    f'loss_fn: {loss_pre_eval}',
])
print(info_str)
with open(f'session_{session_id}_details.txt', 'w', encoding='utf-8') as f:
    f.write(info_str)

learn.fine_tune(num_epochs, base_lr=lr)
learn.export('model_152_final.pkl')