diff --git a/kotsu/run.py b/kotsu/run.py
index 8aebddf..1908845 100644
--- a/kotsu/run.py
+++ b/kotsu/run.py
@@ -56,8 +56,6 @@ def run(
     results_df["runtime_secs"] = results_df["runtime_secs"].astype(int)
     results_df = results_df.set_index(["validation_id", "model_id"], drop=False)
 
-    results_list = []
-
     for validation_spec in validation_registry.all():
         if validation_spec.deprecated:
             logger.info(f"Skipping validation: {validation_spec.id} - as is deprecated.")
@@ -66,7 +64,6 @@ def run(
             if model_spec.deprecated:
                 logger.info(f"Skipping model: {model_spec.id} - as is deprecated.")
                 continue
-
             if (
                 not force_rerun == "all"
                 and not (isinstance(force_rerun, list) and model_spec.id in force_rerun)
@@ -91,15 +88,30 @@ def run(
             model = model_spec.make()
             results, elapsed_secs = _run_validation_model(validation, model, run_params)
             results = _add_meta_data_to_results(results, elapsed_secs, validation_spec, model_spec)
-            results_list.append(results)
-
-    additional_results_df = pd.DataFrame.from_records(results_list)
+            results_df = _save_results(results, results_df, results_path)
+
+
+def _save_results(results: dict, results_df: pd.DataFrame, results_path: str):
+    """Append results for a model/validation combination to results_df and save.
+
+    Args:
+        results: Dictionary containing the key result scores for the current
+            model/validation combination.
+        results_df: The main DataFrame containing previous results, if any,
+            to which the new results are appended.
+        results_path: The path where results_df is saved.
+    """
+    additional_results_df = pd.DataFrame.from_records([results])
     results_df = results_df.append(additional_results_df, ignore_index=True)
     results_df = results_df.drop_duplicates(subset=["validation_id", "model_id"], keep="last")
     results_df = results_df.sort_values(by=["validation_id", "model_id"]).reset_index(drop=True)
     store.write(
         results_df, results_path, to_front_cols=["validation_id", "model_id", "runtime_secs"]
     )
+    results_df = results_df.set_index(["validation_id", "model_id"], drop=False)
+
+    return results_df
 
 
 def _form_validation_partial_with_store_dirs(
diff --git a/tests/test_run.py b/tests/test_run.py
index 6cae1ee..26b5656 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -1,4 +1,5 @@
 import logging
+import sys
 from unittest import mock
 
 import pandas as pd
@@ -256,3 +257,88 @@ def test_raise_if_valiation_returns_privilidged_key_name(result, mocker, tmpdir)
         validation_registry,
         results_path=results_path,
     )
+
+
+def _add_meta_data_mock_fail(results, elapsed_secs, validation_spec, model_spec):
+    if model_spec.id == "model_2":
+        sys.exit()
+    else:
+        return _add_meta_data_mock(results, elapsed_secs, validation_spec, model_spec)
+
+
+def _add_meta_data_mock(results, elapsed_secs, validation_spec, model_spec):
+    results_meta_data = {
+        "validation_id": validation_spec.id,
+        "model_id": model_spec.id,
+        "runtime_secs": elapsed_secs,
+    }
+
+    return {**results, **results_meta_data}
+
+
+def test_interruption_results_saving(mocker, tmpdir):
+    patched_run_validation_model = mocker.patch(
+        "kotsu.run._run_validation_model",
+        side_effect=[
+            ({"test_result": "result_1"}, 10),
+            ({"test_result": "result_2"}, 20),
+            ({"test_result": "result_3"}, 30),
+        ],
+    )
+
+    models = ["model_1", "model_2"]
+    model_registry = FakeRegistry(models)
+    validations = ["validation_1"]
+    validation_registry = FakeRegistry(validations)
+
+    results_path = str(tmpdir) + "validation_results.csv"
+
+    # 1 - check that if the run exits before saving model_2, model_1's results are still saved
+    _ = mocker.patch("kotsu.run._add_meta_data_to_results", side_effect=_add_meta_data_mock_fail)
+
+    with pytest.raises(SystemExit):
+        kotsu.run.run(
+            model_registry,
+            validation_registry,
+            results_path=results_path,
+        )
+    out_df = pd.read_csv(results_path)
+
+    results_df_failed = pd.DataFrame(
+        [
+            {
+                "validation_id": "validation_1",
+                "model_id": "model_1",
+                "runtime_secs": 10,
+                "test_result": "result_1",
+            },
+        ]
+    )
+    assert patched_run_validation_model.call_count == 2
+    pd.testing.assert_frame_equal(out_df, results_df_failed)
+
+    # 2 - if running again (no fail) then check that
+    # (a) we have only called the validation once more (not twice)
+    # (b) all the results are saved
+    _ = mocker.patch("kotsu.run._add_meta_data_to_results", side_effect=_add_meta_data_mock)
+    kotsu.run.run(model_registry, validation_registry, results_path=results_path)
+    out_df = pd.read_csv(results_path)
+
+    results_df_all = pd.DataFrame(
+        [
+            {
+                "validation_id": "validation_1",
+                "model_id": "model_1",
+                "runtime_secs": 10,
+                "test_result": "result_1",
+            },
+            {
+                "validation_id": "validation_1",
+                "model_id": "model_2",
+                "runtime_secs": 30,
+                "test_result": "result_3",
+            },
+        ]
+    )
+    assert patched_run_validation_model.call_count == 3
+    pd.testing.assert_frame_equal(out_df, results_df_all)