-
Notifications
You must be signed in to change notification settings - Fork 0
/
finetuning.py
67 lines (52 loc) · 2.24 KB
/
finetuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from utils.modelFactory import createFinetuningModel
from utils.data import get_data_pretraining, load_config, get_data_finetuning, save_config
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch
import os
def main():
    """Run a sequence of finetuning experiments described in config.yaml.

    The 'finetuning' section of the config holds three parallel lists
    ('trainsplit', 'epochs', 'name'); each triple defines one run. For
    every run this builds a fresh model, trains it with PyTorch Lightning,
    and writes the per-run config plus the 5 best checkpoints (by
    validation loss) to savedmodel.path/<name>.

    Raises:
        ValueError: if the three per-run config lists differ in length.
    """
    # Load config file
    config = load_config('config.yaml')

    finetuned = config['finetuning']
    # Explicit validation: `assert` would be stripped under `python -O`.
    if not (len(finetuned['trainsplit']) == len(finetuned['epochs'])
            == len(finetuned['name'])):
        raise ValueError(
            "finetuning config lists 'trainsplit', 'epochs' and 'name' "
            "must all have the same length")

    # NOTE: zip() captures references to the original list objects, so
    # rebinding the scalar per-run values onto the same dict keys inside
    # the loop does not disturb the iteration.
    for train_split, epochs, name in zip(finetuned['trainsplit'],
                                         finetuned['epochs'],
                                         finetuned['name']):
        # Specialize the config for this single run so that save_config
        # records the exact settings used.
        config['finetuning']['trainsplit'] = train_split
        config['finetuning']['epochs'] = epochs
        config['finetuning']['name'] = name

        # Fresh model and data per run.
        model = createFinetuningModel(config)
        train, val, test = get_data_finetuning(config)  # test split unused here

        save_path = os.path.join(config['savedmodel']['path'], name)
        # exist_ok avoids the check-then-create race of exists()+makedirs().
        os.makedirs(save_path, exist_ok=True)
        save_config(config, save_path)

        checkpoint_callback = ModelCheckpoint(
            dirpath=save_path,
            save_top_k=5,        # keep the 5 best checkpoints
            mode='min',          # lower val_loss is better
            monitor='val_loss',
        )

        batch_size = config['finetuning']['batch_size']
        dataloader_training = torch.utils.data.DataLoader(
            train,
            batch_size,
            shuffle=True,
            num_workers=8)
        dataloader_val = torch.utils.data.DataLoader(
            val,
            batch_size,
            shuffle=False,
            num_workers=8)

        # Fall back to CPU when no CUDA device is available.
        accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
        trainer = pl.Trainer(
            max_epochs=epochs,
            devices='auto',
            accelerator=accelerator,
            callbacks=[checkpoint_callback],
            log_every_n_steps=15,
        )
        trainer.fit(model=model,
                    train_dataloaders=dataloader_training,
                    val_dataloaders=dataloader_val)


if __name__ == '__main__':
    main()