-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_ACL.py
74 lines (54 loc) · 2.54 KB
/
main_ACL.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
Created on December 11, 2021.
main_ACL.py
@author: Soroosh Tayebi Arasteh <[email protected]>
https://github.com/tayebiarasteh/
"""
import pdb
from torch.nn import CrossEntropyLoss
import torch
import os
from models.ACL_model import ACL_net
from config.serde import open_experiment, create_experiment, delete_experiment
from Train_Valid_ACL import Training
import warnings
# Silence every Python warning for the whole interpreter. This runs at import
# time, so it also hides warnings (including deprecation warnings) raised by
# torch and by any module that imports this file.
warnings.filterwarnings('ignore')
def main_train_3D(global_config_path="/home/soroosh/Documents/Repositories/ACL_tear/config/config.yaml", valid=False,
                  resume=False, augment=False, experiment_name='name'):
    """Main function for training + validation directly on 3D volumes.

    Parameters
    ----------
    global_config_path: str
        Path to the global YAML config file. The default is an absolute path
        on the original author's machine; pass your own path elsewhere.
    valid: bool
        Intended to toggle validation. NOTE(review): this flag is currently
        not read by this function; ``valid_loader`` is passed to
        ``Training.execute_training`` regardless of its value.
    resume: bool
        If True, reopen an existing experiment (``open_experiment``) and load
        its checkpoint; otherwise create a new experiment
        (``create_experiment``) and set up a fresh model.
    augment: bool
        Forwarded unchanged to ``Training.execute_training`` as
        ``augmentation``.
    experiment_name: str
        Name of the existing experiment when resuming training, or
        name of the new experiment when starting a new training.
    """
    if resume:
        params = open_experiment(experiment_name, global_config_path)
    else:
        params = create_experiment(experiment_name, global_config_path)
    cfg_path = params["cfg_path"]

    # Changeable network parameters
    model = ACL_net()
    # The loss *class* (not an instance) is handed to the trainer.
    # NOTE(review): presumably Training instantiates it — confirm in Train_Valid_ACL.
    loss_function = CrossEntropyLoss
    # lr / weight_decay are explicitly cast to float, presumably because the
    # YAML loader may return values such as 1e-4 as strings — confirm.
    optimizer = torch.optim.Adam(model.parameters(), lr=float(params['Network']['lr']),
                                 weight_decay=float(params['Network']['weight_decay']),
                                 amsgrad=params['Network']['amsgrad'])

    trainer = Training(cfg_path, num_iterations=params['num_iterations'], resume=resume)
    if resume:
        trainer.load_checkpoint(model=model, optimiser=optimizer, loss_function=loss_function)
    else:
        trainer.setup_model(model=model, optimiser=optimizer, loss_function=loss_function)

    # loading the data
    # NOTE(review): these are constant all-ones tensors, not DataLoaders — looks
    # like a placeholder for the real data pipeline. TODO replace. The 5-D shape
    # (1, 2, 110, 281, 285) is presumably (batch, channel, D, H, W) — confirm.
    train_loader = torch.ones((1, 2, 110, 281, 285))
    valid_loader = torch.ones((1, 2, 110, 281, 285))
    trainer.execute_training(train_loader=train_loader, valid_loader=valid_loader, augmentation=augment)
if __name__ == '__main__':
    # Script entry point: delete the 'testtest' experiment (presumably clearing
    # leftovers from an earlier run), then train it from scratch (resume=False).
    config_path = "/home/soroosh/Documents/Repositories/ACL_tear/config/config.yaml"
    experiment = 'testtest'
    delete_experiment(experiment_name=experiment, global_config_path=config_path)
    main_train_3D(
        global_config_path=config_path,
        valid=False,
        resume=False,
        augment=False,
        experiment_name=experiment,
    )