-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcheckpoints.py
47 lines (37 loc) · 1.25 KB
/
checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# checkpoints.py
import os
import torch
class Checkpoints:
    """Save and load model checkpoints under a configured directory.

    The save directory is created on construction if it does not already
    exist.
    """

    def __init__(self, args):
        """Read checkpoint locations from *args* and ensure the save dir exists.

        Args:
            args: Object exposing a ``save`` attribute (directory checkpoints
                are written to) and a ``resume`` attribute (path to resume
                from, or ``None``) — presumably an ``argparse.Namespace``.
        """
        self.dir_save = args.save
        self.dir_load = args.resume
        # NOTE: reading ``args.prevmodel`` is currently disabled, so the
        # previous-model path is always None.
        self.prevmodel = None
        # exist_ok avoids the check-then-create race of isdir() + makedirs();
        # it still raises FileExistsError if the path exists as a non-directory.
        os.makedirs(self.dir_save, exist_ok=True)

    def latest(self, name):
        """Return the configured checkpoint path for *name*.

        Args:
            name: Either ``'resume'`` or ``'prevmodel'``.

        Returns:
            The matching path, or ``None`` if it was not configured.

        Raises:
            KeyError: If *name* is not one of the supported keys.
        """
        paths = {'resume': self.dir_load, 'prevmodel': self.prevmodel}
        return paths[name]

    def save(self, epoch, model, best):
        """Write a checkpoint for *epoch* when *best* is truthy.

        The file is named ``model_<len(model)>_epoch_<epoch>.pth`` inside the
        save directory.

        Args:
            epoch: Epoch number embedded in the file name (formatted with %d).
            model: Sequence whose first element maps names to objects with a
                ``state_dict()`` method (presumably ``torch.nn.Module``s).
            best: Nothing is written unless this is truthy.

        Returns:
            The path that was written, or ``None`` when nothing was saved.
        """
        if not best:
            return None
        # NOTE(review): only model[0] is serialized even though len(model)
        # goes into the file name — confirm this is intended.
        state = {key: value.state_dict() for key, value in model[0].items()}
        path = os.path.join(
            self.dir_save, 'model_%d_epoch_%d.pth' % (len(model), epoch))
        torch.save(state, path)
        return path

    def load(self, filename, map_location=None):
        """Load and return a checkpoint file.

        Args:
            filename: Path of the checkpoint to read.
            map_location: Forwarded to ``torch.load`` (e.g. ``'cpu'`` to load
                GPU-saved tensors on a CPU-only machine). Defaults to None,
                which is ``torch.load``'s own default.

        Returns:
            The deserialized checkpoint, or ``None`` if *filename* is not an
            existing file.
        """
        if not os.path.isfile(filename):
            print("=> no checkpoint found at '{}'".format(filename))
            return None
        print("=> loading checkpoint '{}'".format(filename))
        # torch.load unpickles the file: only load checkpoints from trusted
        # sources.
        return torch.load(filename, map_location=map_location)