-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtests.py
187 lines (159 loc) · 6.55 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import importlib
import os
import os.path as osp
import queue
import tempfile
import time
import unittest
from multiprocessing import Queue, Process
import numpy as np
import tensorflow as tf
import easy_tf_log
# Select TensorFlow-1-style `train` and `summary` modules regardless of the
# installed TensorFlow version. The major version is compared numerically:
# comparing the raw version strings is lexicographic, so a hypothetical
# '10.0' would sort before '2' and wrongly take the TensorFlow 1 branch.
if int(tf.__version__.split('.')[0]) >= 2:
    import tensorflow.compat.v1.train as tf_train
    # Needed for creation of a TensorFlow 1 `summary` op (which behaves
    # differently from a TensorFlow 2 `summary` op), and a TensorFlow 1
    # `FileWriter` (TensorFlow 2 does have `tf.summary.create_file_writer`,
    # but the object it returns seems to be slightly different - it doesn't
    # have the `add_summary` method.)
    import tensorflow.compat.v1.summary as tf1_summary
    # FileWriter is not compatible with eager execution.
    tf.compat.v1.disable_eager_execution()
else:
    import tensorflow.train as tf_train
    import tensorflow.summary as tf1_summary
class TestEasyTFLog(unittest.TestCase):
    """
    End-to-end tests for easy_tf_log: values are logged through the
    module-level functions or a `Logger`, and the resulting TensorFlow event
    files are read back and checked.
    """

    def setUp(self):
        # Re-execute the module so that state left behind by a previous test
        # (presumably the default logger and its directory - see
        # test_no_setup, which relies on a pristine module) does not leak
        # into this one.
        importlib.reload(easy_tf_log)
        print(self._testMethodName)

    def _enter_temp_cwd(self):
        """
        Create a temporary directory, make it the current working directory
        and return its path.

        The previous working directory is restored, and the temporary
        directory removed, when the test finishes. Without this the process
        would be left sitting in a deleted directory (and on Windows the
        directory could not be removed at all).
        """
        original_cwd = os.getcwd()
        temp_dir = tempfile.TemporaryDirectory()
        # Cleanups run last-in-first-out, so the removal is registered first
        # in order for the chdir back to happen before it.
        self.addCleanup(temp_dir.cleanup)
        os.chdir(temp_dir.name)
        self.addCleanup(os.chdir, original_cwd)
        return temp_dir.name

    @staticmethod
    def _tf_major_version():
        """
        Return TensorFlow's major version as an int.

        Comparing version strings directly is lexicographic and would
        misorder a two-digit major version.
        """
        return int(tf.__version__.split('.')[0])

    @staticmethod
    def _logged_events(event_filename):
        """
        Return the events stored in `event_filename`, without the first
        event, which the tests treat as file metadata rather than a logged
        value.
        """
        return list(tf_train.summary_iterator(event_filename))[1:]

    def test_no_setup(self):
        """
        Test that if tflog() is used without any extra setup, a directory
        'logs' is created in the current directory containing the event file.
        """
        self._enter_temp_cwd()
        easy_tf_log.tflog('var', 0)
        self.assertEqual(os.listdir(), ['logs'])
        self.assertIn('events.out.tfevents', os.listdir('logs')[0])

    def test_set_dir(self):
        """
        Confirm that set_dir works.
        """
        self._enter_temp_cwd()
        easy_tf_log.set_dir('logs2')
        easy_tf_log.tflog('var', 0)
        self.assertEqual(os.listdir(), ['logs2'])
        self.assertIn('events.out.tfevents', os.listdir('logs2')[0])

    def test_set_writer(self):
        """
        Check that when using an EventFileWriter from a FileWriter,
        the resulting events file contains events from both the FileWriter
        and easy_tf_log.
        """
        self._enter_temp_cwd()
        writer = tf1_summary.FileWriter('logs')

        var = tf.Variable(0.0)
        summary_op = tf1_summary.scalar('tf_var', var)
        if self._tf_major_version() >= 2:
            session_cls = tf.compat.v1.Session
        else:
            session_cls = tf.Session
        # The session is only needed to evaluate the summary op; close it as
        # soon as that is done instead of leaking it.
        with session_cls() as sess:
            sess.run(var.initializer)
            summary = sess.run(summary_op)
        writer.add_summary(summary)

        easy_tf_log.set_writer(writer.event_writer)
        easy_tf_log.tflog('easy-tf-log_var', 0)

        self.assertEqual(os.listdir(), ['logs'])
        event_filename = osp.join('logs', os.listdir('logs')[0])
        self.assertIn('events.out.tfevents', event_filename)

        tags = {
            value.tag
            for event in tf_train.summary_iterator(event_filename)
            for value in event.summary.value
        }
        self.assertIn('tf_var', tags)
        self.assertIn('easy-tf-log_var', tags)

    def test_full(self):
        """
        Log a few values and check that the event file contains the expected
        values.
        """
        self._enter_temp_cwd()
        for i in range(10):
            easy_tf_log.tflog('foo', i)
        for i in range(10):
            easy_tf_log.tflog('bar', i)

        event_filename = osp.join('logs', os.listdir('logs')[0])
        events = self._logged_events(event_filename)
        # Without this check the test would pass vacuously if nothing (or
        # too little) had been written to the event file.
        self.assertEqual(len(events), 20)

        # Each key has its own step counter starting from zero.
        for tag, tag_events in (('foo', events[:10]), ('bar', events[10:])):
            for i, event in enumerate(tag_events):
                self.assertEqual(event.step, i)
                self.assertEqual(event.summary.value[0].tag, tag)
                self.assertEqual(event.summary.value[0].simple_value,
                                 float(i))

    def test_explicit_step(self):
        """
        Log a few values explicitly setting the step number.
        """
        self._enter_temp_cwd()
        for i in range(5):
            easy_tf_log.tflog('foo', i, step=(10 * i))
        # These ones should continue from where the previous ones left off
        for i in range(5):
            easy_tf_log.tflog('foo', i)

        event_filename = osp.join('logs', os.listdir('logs')[0])
        steps = [event.step for event in self._logged_events(event_filename)]
        # Explicit steps 0, 10, ..., 40, then implicit steps 41, ..., 45.
        # Comparing whole lists also catches missing events.
        expected_steps = ([10 * i for i in range(5)]
                          + [40 + i for i in range(1, 6)])
        self.assertEqual(steps, expected_steps)

    def test_fork(self):
        """
        Check that tflog() called in a child process returns (i.e. does not
        hang) after the log directory was set in the parent.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            easy_tf_log.set_dir(temp_dir)

            # Named so as not to shadow the `queue` module used below.
            def f(result_queue):
                easy_tf_log.tflog('foo', 0)
                result_queue.put(True)

            q = Queue()
            proc = Process(target=f, args=[q], daemon=True)
            proc.start()
            try:
                q.get(timeout=1.0)
            except queue.Empty:
                self.fail("Process did not return")
            finally:
                # Do not leave the child running while the temporary
                # directory is removed; kill it if it really is stuck.
                proc.join(timeout=1.0)
                if proc.is_alive():
                    proc.terminate()
                    proc.join()

    def test_measure_rate(self):
        """
        Check that Logger.measure_rate() logs the per-second rate of change
        between successive calls for the same key.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            logger = easy_tf_log.Logger(log_dir=temp_dir)

            logger.measure_rate('foo', 0)
            time.sleep(1)
            logger.measure_rate('foo', 10)
            time.sleep(1)
            logger.measure_rate('foo', 25)

            event_filename = osp.join(temp_dir, os.listdir(temp_dir)[0])
            rates = [event.summary.value[0].simple_value
                     for event in self._logged_events(event_filename)]
            # The first call only establishes a baseline, so three calls are
            # expected to yield two rates: (10 - 0) / 1s and (25 - 10) / 1s.
            np.testing.assert_array_almost_equal(rates, [10., 15.],
                                                 decimal=1)
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()