-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmain.py
42 lines (34 loc) · 1.22 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import json
import importlib
import torch
from option import get_option
from solver import Solver
from tester import Tester
from utils import LogWritter
import glob
from tqdm import tqdm
def find_checkpoint(ckpt_root, pattern):
    """Return the first checkpoint path matching ``pattern`` under ``ckpt_root``.

    Matches are sorted before choosing, because ``glob.glob`` returns paths in
    arbitrary (filesystem-dependent) order; without sorting, a pattern that
    matches several files would select a different checkpoint from run to run.

    Args:
        ckpt_root: directory that is searched.
        pattern: glob pattern (or plain file name) relative to ``ckpt_root``.

    Returns:
        The lexicographically smallest matching path.

    Raises:
        FileNotFoundError: if nothing matches. An explicit exception is used
            instead of ``assert`` so the check survives ``python -O``.
    """
    matches = sorted(glob.glob("{}/{}".format(ckpt_root, pattern)))
    if not matches:
        raise FileNotFoundError(
            "cannot find checkpoint {} in {}".format(pattern, ckpt_root))
    return matches[0]


def dataset_display_name(test_dataset):
    """Return the human-readable part of a test dataset identifier.

    Takes the second ``_``-separated field (``"a_B_c"`` -> ``"B"``). This
    presumably matches a ``<prefix>_<name>...`` naming convention for
    ``opt.test_dataset``; TODO confirm against ``option.get_option``. When the
    identifier has no underscore, the whole identifier is returned instead of
    raising ``IndexError``.
    """
    parts = test_dataset.split("_")
    return parts[1] if len(parts) > 1 else test_dataset


def main():
    """Script entry point: train a model, or evaluate a pretrained checkpoint.

    Parses options with ``get_option()``, seeds torch with ``opt.seed`` and
    imports the model module ``model.<opt.model lower-cased>``.

    * ``opt.test_only`` set: builds a ``Tester``, locates the checkpoint
      matching ``opt.pretrain`` under ``opt.ckpt_root``, evaluates it and
      prints the result (labelled MAE).
    * otherwise: prints the options as JSON, writes them to the log file
      (``mode='w'``), and trains via ``Solver.fit()``.

    Raises:
        FileNotFoundError: in test mode, when no checkpoint matches.
    """
    opt = get_option()
    torch.manual_seed(opt.seed)
    module = importlib.import_module("model.{}".format(opt.model.lower()))
    # Created in both modes (as before), since constructing it may have side
    # effects (e.g. creating a log directory) that other code relies on.
    logger = LogWritter(opt)

    if opt.test_only:
        tester = Tester(module, opt)
        ckpt_path = find_checkpoint(opt.ckpt_root, opt.pretrain)
        print("Evaluate {} (loaded from {}) on {} dataset".format(
            opt.model, ckpt_path, dataset_display_name(opt.test_dataset)))
        result = tester.evaluate(path=ckpt_path)
        msg = "ckpt:{} MAE: {:.4f}\n".format(ckpt_path, result)
        print(msg)
        print('done testing')
    else:
        # Record the full configuration of a training run before it starts.
        msg = json.dumps(vars(opt), indent=4)
        print(msg)
        logger.update_txt(msg + '\n', mode='w')
        solver = Solver(module, opt)
        solver.fit()


if __name__ == "__main__":
    main()