-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlogger.py
132 lines (109 loc) · 3.75 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# A simple torch style logger
# (C) Wei YANG 2017
from __future__ import absolute_import
# import matplotlib.pyplot as plt
# Public API. LoggerMonitor and savefig are commented out below (they need
# matplotlib); listing them here made `from logger import *` raise
# AttributeError. Re-add them if they are re-enabled.
__all__ = ['Logger']
#
# def savefig(fname, dpi=None):
# dpi = 150 if dpi == None else dpi
# plt.savefig(fname, dpi=dpi)
#
#
# def plot_overlap(logger, names=None):
# names = logger.names if names == None else names
# numbers = logger.numbers
# for _, name in enumerate(names):
# x = np.arange(len(numbers[name]))
# plt.plot(x, np.asarray(numbers[name]))
# return [logger.title + '(' + name + ')' for name in names]
class Logger(object):
    '''Save training process to log file with simple plot function.

    The log is a tab-separated text file: the first line holds the column
    names, every following line holds one row of values formatted with six
    decimal places. Every field, including the last one, is followed by a
    tab character.

    Args:
        fpath: path of the log file, or None to keep values in memory only
            (nothing is written to disk).
        title: optional label for this log, used by the (currently disabled)
            plotting helpers. Defaults to ''.
        resume: if True, load the column names and rows already stored in
            ``fpath`` and append new rows to it instead of truncating it.

    Attributes:
        names: column names, in file order.
        numbers: dict mapping each column name to the list of its values.
        file: open handle of the log file, or None when ``fpath`` is None.
    '''

    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        # Initialised up front so set_names()/append() also work for an
        # in-memory logger (fpath=None), and append() before set_names()
        # fails with the 'Numbers do not match names' message instead of an
        # AttributeError.
        self.names = []
        self.numbers = {}
        if fpath is None:
            return
        if resume:
            self._load(fpath)
            self.file = open(fpath, 'a+')
        else:
            self.file = open(fpath, 'w+')

    def _load(self, fpath):
        '''Read the header and every data row of an existing log file.'''
        with open(fpath, 'r') as f:
            header = f.readline().rstrip()
            # An empty file has no header, hence no columns.
            self.names = header.split('\t') if header else []
            self.numbers = dict((name, []) for name in self.names)
            for line in f:
                fields = line.rstrip()
                if not fields:
                    continue  # skip blank lines instead of recording ''
                # zip() ignores fields beyond the header instead of raising
                # IndexError on a malformed row.
                for name, value in zip(self.names, fields.split('\t')):
                    self.numbers[name].append(self._parse_value(value))

    @staticmethod
    def _parse_value(text):
        '''Return ``text`` as a float so resumed values match appended ones.

        Text that is not numeric is returned unchanged.
        '''
        try:
            return float(text)
        except ValueError:
            return text

    def set_names(self, names):
        '''Set the column names and write them as the log's header line.

        Resets the stored values of every column. When resuming a log whose
        loaded header already equals ``names``, nothing is written and the
        loaded values are kept, so the file does not get a second header.
        '''
        if self.resume and self.names and names == self.names:
            return
        self.names = names
        self.numbers = dict((name, []) for name in self.names)
        if self.file is not None:
            self.file.write(''.join(name + '\t' for name in self.names) + '\n')
            self.file.flush()

    def append(self, numbers):
        '''Record one row of values and write it to the log file.

        Args:
            numbers: one value per column, in the order of ``self.names``;
                each value must support the ``'.6f'`` format spec.

        Raises:
            AssertionError: if the number of values differs from the number
                of columns.
        '''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        # Format the whole row first so a value that cannot be formatted
        # raises before anything is written or recorded (no partial rows).
        row = ''.join('{0:.6f}\t'.format(num) for num in numbers) + '\n'
        if self.file is not None:
            self.file.write(row)
            self.file.flush()
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)

    # Plotting helper, disabled together with the numpy/matplotlib imports:
    # def plot(self, names=None):
    #     names = self.names if names == None else names
    #     numbers = self.numbers
    #     for _, name in enumerate(names):
    #         x = np.arange(len(numbers[name]))
    #         plt.plot(x, np.asarray(numbers[name]))
    #     plt.legend([self.title + '(' + name + ')' for name in names])
    #     plt.grid(True)

    def close(self):
        '''Close the log file, if one was opened.'''
        if self.file is not None:
            self.file.close()
# class LoggerMonitor(object):
# '''Load and visualize multiple logs.'''
#
# def __init__(self, paths):
# '''paths is a distionary with {name:filepath} pair'''
# self.loggers = []
# for title, path in paths.items():
# logger = Logger(path, title=title, resume=True)
# self.loggers.append(logger)
#
# def plot(self, names=None):
# plt.figure()
# plt.subplot(121)
# legend_text = []
# for logger in self.loggers:
# legend_text += plot_overlap(logger, names)
# plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
# plt.grid(True)
if __name__ == '__main__':
    # Example: write a synthetic log, then load it back with resume=True.
    # Only the standard library is used because the numpy/matplotlib-based
    # helpers (Logger.plot, LoggerMonitor, savefig) are commented out in this
    # module; the previous demo called LoggerMonitor and savefig and
    # therefore failed with NameError.
    import math
    import os
    import random
    import tempfile

    log_path = os.path.join(tempfile.mkdtemp(), 'test.txt')
    logger = Logger(log_path)
    logger.set_names(['Train loss', 'Valid loss', 'Test loss'])
    length = 100
    for t in range(length):
        decay = math.exp(-t / 10.0)
        logger.append([decay + random.random() * 0.1 for _ in range(3)])
    logger.close()

    resumed = Logger(log_path, resume=True)
    print('Read {0} rows of {1} from {2}'.format(
        len(resumed.numbers['Valid loss']), resumed.names, log_path))
    resumed.close()

    # Example: logger monitor (requires re-enabling LoggerMonitor and savefig
    # above, plus numpy and matplotlib):
    # paths = {'resadvnet20': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
    #          'resadvnet32': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
    #          'resadvnet44': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', }
    # field = ['Valid Acc.']
    # monitor = LoggerMonitor(paths)
    # monitor.plot(names=field)
    # savefig('test.eps')