-
Notifications
You must be signed in to change notification settings - Fork 396
/
utils.py
34 lines (28 loc) · 835 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import time
import tensorflow as tf
from six.moves import range
from logging import getLogger
# Logger for this module; timeit() reports call durations through it.
logger = getLogger(name=__name__)
def get_model_dir(config, exceptions=None):
    """Build a checkpoint directory path that encodes the run configuration.

    The path has the form
    ``checkpoints/<env_name>/env_name=<env_name>/<key>=<value>/.../``.
    The environment name comes first, both bare and as ``env_name=...``.
    Every remaining public, non-callable attribute of ``config`` follows in
    sorted order. List values are rendered as comma-separated items.

    Args:
        config: Object whose attributes describe the run (presumably a
            flags / namespace object). It must expose an ``env_name``
            attribute.
        exceptions: Optional container of attribute names to leave out of
            the path. ``None`` (the default) means nothing is excluded.

    Returns:
        The joined path as a string, always ending with ``'/'``.

    Raises:
        ValueError: if ``'env_name'`` is not among ``dir(config)``.
    """
    # Previously a None default crashed on the membership test below.
    excluded = exceptions if exceptions is not None else ()

    keys = sorted(dir(config))
    # Raises ValueError when env_name is missing, as before.
    keys.remove('env_name')
    # dir() also reports dunders, private names and methods. Their string
    # forms (e.g. bound methods with memory addresses) are meaningless in a
    # path and differ between runs, so keep only public names.
    keys = ['env_name'] + [key for key in keys if not key.startswith('_')]

    names = [config.env_name]
    for key in keys:
        # Only use useful flags.
        if key in excluded:
            continue
        value = getattr(config, key)
        if callable(value):
            # Methods are not configuration values.
            continue
        if isinstance(value, list):
            value = ",".join(str(item) for item in value)
        names.append("%s=%s" % (key, value))

    return os.path.join('checkpoints', *names) + '/'
def timeit(f):
    """Decorator that logs how long each call to ``f`` takes.

    The wall-clock duration, measured with ``time.time()``, is logged at
    INFO level through the module ``logger`` as ``"<name> : <seconds> sec"``.
    The wrapped function's return value is passed through unchanged. If
    ``f`` raises, the exception propagates and nothing is logged.

    Args:
        f: The callable to time.

    Returns:
        A wrapper with the same metadata (``__name__``, ``__doc__``,
        ``__wrapped__``, ...) as ``f``.
    """
    # Local import keeps this fix self-contained.
    from functools import wraps

    @wraps(f)  # preserve f's name/docstring for introspection and logs
    def timed(*args, **kwargs):
        start_time = time.time()
        result = f(*args, **kwargs)
        end_time = time.time()
        # Lazy %-args: the logger formats only if INFO is enabled.
        logger.info("%s : %2.2f sec", f.__name__, end_time - start_time)
        return result
    return timed