Skip to content
This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

Commit

Permalink
tests: new checkpoint mxnet test + fix utils (#273)
Browse files Browse the repository at this point in the history
* tests: new mxnet test + fix utils

new test added:
- test_restore_checkpoint[tensorflow, mxnet]

fix failed tests in CI
improve utils

* tests: fix comments for mxnet checkpoint test and utils
  • Loading branch information
anabwan authored Apr 7, 2019
1 parent e1e335a commit 881f78f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 14 deletions.
28 changes: 24 additions & 4 deletions rl_coach/tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,44 @@ def test_get_checkpoint_state():


@pytest.mark.functional_test
def test_restore_checkpoint(preset_args, clres, start_time=time.time(),
@pytest.mark.parametrize("framework", ["mxnet", "tensorflow"])
def test_restore_checkpoint(preset_args, clres, framework,
start_time=time.time(),
timeout=Def.TimeOuts.test_time_limit):
""" Create checkpoint and restore them in second run."""
"""
Create checkpoints and restore them in second run.
:param preset_args: all preset that can be tested for argument tests
:param clres: logs and csv files
:param framework: name of the test framework
:param start_time: test started time
:param timeout: max time for test
"""

def _create_cmd_and_run(flag):

"""
Create default command with given flag and run it
:param flag: name of the tested flag, this flag will be extended to the
running command line
:return: active process
"""
run_cmd = [
'python3', 'rl_coach/coach.py',
'-p', '{}'.format(preset_args),
'-e', '{}'.format("ExpName_" + preset_args),
'--seed', '{}'.format(42),
'-f', '{}'.format(framework),
]

test_flag = a_utils.add_one_flag_value(flag=flag)
run_cmd.extend(test_flag)

print(str(run_cmd))
p = subprocess.Popen(run_cmd, stdout=clres.stdout, stderr=clres.stdout)

return p

if framework == "mxnet":
preset_args = Def.Presets.mxnet_args_test

p_valid_params = p_utils.validation_params(preset_args)
create_cp_proc = _create_cmd_and_run(flag=['--checkpoint_save_secs', '5'])

Expand Down
15 changes: 8 additions & 7 deletions rl_coach/tests/utils/args_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,9 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
results.append(last_step[-1])
time.sleep(1)

assert results[-1] >= Def.Consts.num_hs, \
Def.Consts.ASSERT_MSG.format("bigger than " + Def.Consts.num_hs,
results[-1])
assert int(results[-1]) >= Def.Consts.num_hs, \
Def.Consts.ASSERT_MSG.format("bigger than " +
str(Def.Consts.num_hs), results[-1])

elif flag[0] == "-f" or flag[0] == "--framework":
"""
Expand Down Expand Up @@ -445,7 +445,8 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
"""
lst_csv = []
# wait until files created
csv_path = get_csv_path(clres=clres, extra_tries=10)
csv_path = get_csv_path(clres=clres, extra_tries=20,
num_expected_files=int(flag[1]))

assert len(csv_path) > 0, \
Def.Consts.ASSERT_MSG.format("paths are not found", csv_path)
Expand Down Expand Up @@ -491,8 +492,8 @@ def validate_arg_result(flag, p_valid_params, clres=None, process=None,
# wait until files created
csv_path = get_csv_path(clres=clres, extra_tries=20)

expected_files = int(flag[1])
assert len(csv_path) >= expected_files, \
Def.Consts.ASSERT_MSG.format(str(expected_files),
num_expected_files = int(flag[1])
assert len(csv_path) >= num_expected_files, \
Def.Consts.ASSERT_MSG.format(str(num_expected_files),
str(len(csv_path)))

13 changes: 10 additions & 3 deletions rl_coach/tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@ def print_progress(averaged_rewards, last_num_episodes, start_time, time_limit,


def read_csv_paths(test_path, filename_pattern, read_csv_tries=120,
extra_tries=0):
extra_tries=0, num_expected_files=None):
"""
Return file path once it found
:param test_path: test folder path
:param filename_pattern: csv file pattern
:param read_csv_tries: number of iterations until file found
:param extra_tries: add number of extra tries to check after getting all
the paths.
:param num_expected_files: find all expected file in experiment folder.
:return: |string| return csv file path
"""
csv_paths = []
Expand All @@ -68,6 +69,10 @@ def read_csv_paths(test_path, filename_pattern, read_csv_tries=120,
csv_paths = glob.glob(path.join(test_path, '*', filename_pattern))
if tries_counter > read_csv_tries:
break

if num_expected_files and num_expected_files == len(csv_paths):
break

time.sleep(1)
tries_counter += 1

Expand Down Expand Up @@ -131,17 +136,19 @@ def find_string_in_logs(log_path, str, timeout=Def.TimeOuts.wait_for_files,


def get_csv_path(clres, tries_for_csv=Def.TimeOuts.wait_for_csv,
extra_tries=0):
extra_tries=0, num_expected_files=None):
"""
Get the csv path with the results - reading csv paths will take some time
:param clres: object of files that test is creating
:param tries_for_csv: timeout of tires until getting all csv files
:param extra_tries: add number of extra tries to check after getting all
the paths.
:param num_expected_files: find all expected file in experiment folder.
:return: |list| csv path
"""
return read_csv_paths(test_path=clres.exp_path,
filename_pattern=clres.fn_pattern,
read_csv_tries=tries_for_csv,
extra_tries=extra_tries)
extra_tries=extra_tries,
num_expected_files=num_expected_files)

0 comments on commit 881f78f

Please sign in to comment.