forked from owruby/shake-shake_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
45 lines (33 loc) · 1.19 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# -*- coding: utf-8 -*-
import os
import math
import json
from datetime import datetime
def cosine_lr(opt, base_lr, e, epochs):
    """Assign a cosine-annealed learning rate to every group of ``opt``.

    The rate is ``0.5 * base_lr * (cos(pi * e / epochs) + 1)``: it equals
    ``base_lr`` at ``e == 0`` and decays to 0 at ``e == epochs``.

    Args:
        opt: object exposing ``param_groups``, an iterable of dict-like
            groups whose ``"lr"`` entry is overwritten in place (presumably
            a ``torch.optim`` optimizer -- confirm at call sites).
        base_lr: learning rate at the start of the schedule.
        e: current position in the schedule (e.g. an epoch index).
        epochs: schedule length; must be non-zero (it is a divisor).

    Returns:
        The learning rate that was written into every group.
    """
    # Keep the original operand order so the float result is bit-identical.
    angle = math.pi * e / epochs
    new_lr = 0.5 * base_lr * (math.cos(angle) + 1)
    for group in opt.param_groups:
        group["lr"] = new_lr
    return new_lr
def accuracy(y, t):
    """Count how many arg-max predictions in ``y`` match the targets ``t``.

    Despite the name this returns a *count*, not a ratio.

    Args:
        y: score tensor; the prediction is the index of the maximum along
            dim 1 (assumes shape (batch, classes) -- TODO confirm).
        t: target class indices; must have as many elements as there are
            predictions, since it is reshaped with ``view_as``.

    Returns:
        A 0-dim tensor on the CPU holding the number of correct predictions.
    """
    # Work on graph-free views; no gradients are needed for a metric.
    _, predicted = y.detach().max(dim=1, keepdim=True)
    targets = t.detach().view_as(predicted)
    hits = predicted.eq(targets)
    return hits.cpu().sum()
class Logger:
    """Tab-separated logger that writes rows to ``log_dir/log.txt`` and stdout.

    The first line written is the header row: ``headers`` followed by an
    ``"EndTime."`` column. Each :meth:`write` call appends one row:
    - the first value is formatted with ``{}``;
    - the remaining values are formatted with ``{:.6f}``;
    - a final column holds a ``%m/%d %H:%M:%S`` timestamp.

    The instance can also be used as a context manager, which closes the log
    file on exit.
    """

    def __init__(self, log_dir, headers):
        """Create ``log_dir`` if needed, open ``log.txt`` and write the header.

        Args:
            log_dir: directory receiving ``log.txt`` (truncated if it already
                exists) and, via :meth:`write_hp`, ``hp.json``.
            headers: column names, as any sequence of strings (list or tuple).
        """
        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)
        # Copy into a list so tuples (or other sequences) are accepted too.
        headers = list(headers)
        header_str = "\t".join(headers + ["EndTime."])
        # The first data column is free-form (presumably an epoch/step
        # counter), the rest are floats, and the trailing "{}" receives the
        # timestamp. With no headers the row holds only the timestamp,
        # matching the single "EndTime." header column.
        n_cols = len(headers)
        columns = (
            ["{}"] * min(n_cols, 1)
            + ["{:.6f}"] * max(n_cols - 1, 0)
            + ["{}"]
        )
        self.print_str = "\t".join(columns)
        self.f = open(os.path.join(log_dir, "log.txt"), "w")
        self.f.write(header_str + "\n")
        self.f.flush()
        print(header_str)

    def write(self, *args):
        """Append one row to ``log.txt`` and echo the same line to stdout.

        Args:
            *args: one value per header column, in header order; every value
                after the first must support the ``.6f`` format spec.
        """
        now_time = datetime.now().strftime("%m/%d %H:%M:%S")
        # Format once and reuse for both the file and stdout.
        line = self.print_str.format(*args, now_time)
        self.f.write(line + "\n")
        # Flush per row so the log is readable while training is running.
        self.f.flush()
        print(line)

    def write_hp(self, hp):
        """Serialize ``hp`` as JSON into ``log_dir/hp.json``, overwriting it.

        The file is closed before returning, so the JSON is fully written.
        """
        with open(os.path.join(self.log_dir, "hp.json"), "w") as fp:
            json.dump(hp, fp)

    def close(self):
        """Close the log file; calling this more than once is harmless."""
        self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        # Do not suppress exceptions raised inside the ``with`` body.
        return False