-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
file_loading.py
61 lines (55 loc) · 2.41 KB
/
file_loading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Helper functions for loading files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import pandas as pd
def filename(env_name, noops, dev_measure, dev_fun, baseline, beta,
             value_discount, seed, path='', suffix=''):
  """Build the path of the CSV results file for one parameter setting.

  The base name is assembled as
  `<env_name>_<noops|nonoops>_<dev_measure>_<dev_fun>_<baseline>` followed by
  `_beta_<beta>_vd_<value_discount><suffix><_seed>.csv`.

  Args:
    env_name: value placed first in the file name.
    noops: truthy selects the tag 'noops', falsy selects 'nonoops'.
    dev_measure: value placed after the noops tag.
    dev_fun: value placed after `dev_measure`.
    baseline: value placed after `dev_fun`.
    beta: value placed after the literal 'beta_'.
    value_discount: value placed after the literal 'vd_'.
    seed: appended as `_<seed>` when truthy; a falsy seed (e.g. None or 0)
      contributes nothing to the name.
    path: directory prefix joined in front of the file name.
    suffix: text inserted verbatim just before the seed part.

  Returns:
    The result of `os.path.join(path, <file name>)`.
  """
  if noops:
    noop_tag = 'noops'
  else:
    noop_tag = 'nonoops'

  # Falsy seeds are deliberately left out of the name altogether.
  if seed:
    seed_tag = '_' + str(seed)
  else:
    seed_tag = ''

  template = ('{env_name}_{noop_str}_{dev_measure}_{dev_fun}'
              '_{baseline}_beta_{beta}_vd_{value_discount}'
              '{suffix}{seed_str}.csv')
  fields = {
      'env_name': env_name,
      'noop_str': noop_tag,
      'dev_measure': dev_measure,
      'dev_fun': dev_fun,
      'baseline': baseline,
      'beta': beta,
      'value_discount': value_discount,
      'suffix': suffix,
      'seed_str': seed_tag,
  }
  base_name = template.format(**fields)
  return os.path.join(path, base_name)
def load_files(baseline, dev_measure, dev_fun, value_discount, beta, env_name,
               noops, path, suffix, seed_list, final=True):
  """Load result files generated by run_experiment with the given parameters.

  One CSV path is built per seed with `filename(...)`; each file that exists is
  read with its first column as the index, and the per-seed frames are
  concatenated in `seed_list` order.

  Args:
    baseline: forwarded to `filename`.
    dev_measure: forwarded to `filename`.
    dev_fun: forwarded to `filename`.
    value_discount: forwarded to `filename`.
    beta: forwarded to `filename`.
    env_name: forwarded to `filename`.
    noops: forwarded to `filename`.
    path: directory prefix forwarded to `filename`.
    suffix: forwarded to `filename`.
    seed_list: iterable of seeds; each is converted with `int()` before being
      forwarded to `filename`.
    final: if True, keep only the rows whose `episode` column equals the
      largest episode value in that file (the file must then have an
      `episode` column); if False, keep every row.

  Returns:
    A pandas DataFrame with the rows of all files that were found. An empty
    DataFrame is returned when `seed_list` is empty or none of the files
    exist.
  """
  dataframes = []
  for seed in seed_list:
    f = filename(baseline=baseline, dev_measure=dev_measure, dev_fun=dev_fun,
                 value_discount=value_discount, beta=beta, env_name=env_name,
                 noops=noops, path=path, suffix=suffix, seed=int(seed))
    # Seeds without a result file are skipped rather than represented by a
    # column-less placeholder frame, so they cannot affect the concatenated
    # result's columns or dtypes.
    if not os.path.isfile(f):
      continue
    df_part = pd.read_csv(f, index_col=0)
    if final:
      # Series.max() returns NaN for a file with a header but no rows (the
      # comparison below then selects nothing) instead of raising like the
      # builtin max() does on an empty sequence; it also ignores NaN entries.
      last_episode = df_part['episode'].max()
      df_part = df_part[df_part.episode == last_episode]
    dataframes.append(df_part)

  # pd.concat raises ValueError on an empty list, which happens when
  # seed_list is empty or no file was found.
  if not dataframes:
    return pd.DataFrame()
  return pd.concat(dataframes)