# Re-train the pre-trained base model on your own data

In order to get the base model, you can either load it through torch hub:

```python
from fastai.vision.all import *

trainedModel = torch.hub.load("chfc-cmi", "cmr_seg_base")
```

or download the model file from the releases page and load it with fastai:

```python
from fastai.vision.all import *

trainedModel = load_learner("resnet34_5percent_size256_extremeTfms_ceLoss_fastai2.pkl")
```

Next, you have to define your own dataloaders. This can be done through the DataBlock API of fastai. Here is an example, assuming that your data is organized like this:

```
data
├── images
│   ├── train
│   │   ├── img-1.png
│   │   └── …
│   └── val
│       ├── img-1.png
│       └── …
└── masks
    ├── train
    │   ├── img-1.png
    │   └── …
    └── val
        ├── img-1.png
        └── …
```
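If you are setting this structure up from scratch, a minimal sketch like the following (assuming a writable working directory) creates the expected folder layout before you copy your images and masks into it:

```python
from pathlib import Path

# create the data/{images,masks}/{train,val} layout described above
root = Path("data")
for kind in ("images", "masks"):
    for split in ("train", "val"):
        (root / kind / split).mkdir(parents=True, exist_ok=True)
```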

We can get appropriate data loaders with this code:

```python
def label_func(x):
    # derive the mask path from the image path ("images/…" -> "masks/…")
    return str(x).replace("image", "mask")

def get_parent(x):
    # items whose parent directory is named 'val' go to the validation set
    return Path(x).parent.name == 'val'

dbl = DataBlock(
    blocks=(ImageBlock, MaskBlock(codes=np.array(["background", "left_ventricle", "myocardium"]))),
    get_items=get_image_files,
    get_y=label_func,
    splitter=FuncSplitter(get_parent),
    item_tfms=Resize(512, method='crop'),
    batch_tfms=aug_transforms(do_flip=True, max_rotate=90, max_lighting=.4, max_zoom=1.2, size=256),
)

dls = dbl.dataloaders("data/images", bs=16)
```
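The two helper functions rely only on path conventions, so you can sanity-check them on example paths before building the dataloaders:

```python
from pathlib import Path

def label_func(x):
    # image path -> mask path ("images" becomes "masks")
    return str(x).replace("image", "mask")

def get_parent(x):
    # True marks an item for the validation set
    return Path(x).parent.name == 'val'

print(label_func("data/images/train/img-1.png"))  # data/masks/train/img-1.png
print(get_parent("data/images/val/img-1.png"))    # True
print(get_parent("data/images/train/img-1.png"))  # False
```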

Now, you can replace the existing dataloaders on the trainedModel with the new ones:

```python
trainedModel.dls = dls
```

Now the trained model can be used to make predictions on the whole validation set (without re-training), and then be re-trained:

```python
# look at results without re-training
trainedModel.show_results()
val_preds = trainedModel.get_preds()

# continue training: first the head only, then the full network
trainedModel.freeze()
trainedModel.lr_find()
trainedModel.fit_one_cycle(10, lr_max=1e-4)
trainedModel.unfreeze()
trainedModel.lr_find()
trainedModel.fit_one_cycle(10, lr_max=1e-5)
```
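For a segmentation learner like this one, get_preds returns per-pixel class probabilities of shape (n_items, n_classes, height, width) alongside the targets; the predicted mask is the argmax over the class axis. A minimal numpy sketch of that step, using a toy 1×3×2×2 probability array (the real predictions are torch tensors, but the argmax logic is the same):

```python
import numpy as np

# toy batch: 1 item, 3 classes, 2x2 pixels
probs = np.array([[[[0.7, 0.1], [0.2, 0.1]],    # background
                   [[0.2, 0.8], [0.1, 0.2]],    # left_ventricle
                   [[0.1, 0.1], [0.7, 0.7]]]])  # myocardium
masks = probs.argmax(axis=1)  # integer class index per pixel
print(masks)  # [[[0 1] [2 2]]]
```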