diff --git a/src/ert/config/parsing/hook_runtime.py b/src/ert/config/parsing/hook_runtime.py index 9a5d854c66d..08d5621cd11 100644 --- a/src/ert/config/parsing/hook_runtime.py +++ b/src/ert/config/parsing/hook_runtime.py @@ -7,3 +7,5 @@ class HookRuntime(StrEnum): PRE_UPDATE = "PRE_UPDATE" POST_UPDATE = "POST_UPDATE" PRE_FIRST_UPDATE = "PRE_FIRST_UPDATE" + PRE_EXPERIMENT = "PRE_EXPERIMENT" + POST_EXPERIMENT = "POST_EXPERIMENT" diff --git a/src/ert/run_models/ensemble_experiment.py b/src/ert/run_models/ensemble_experiment.py index 25348477b8f..6c88894f209 100644 --- a/src/ert/run_models/ensemble_experiment.py +++ b/src/ert/run_models/ensemble_experiment.py @@ -14,7 +14,7 @@ from .base_run_model import BaseRunModel, StatusEvents if TYPE_CHECKING: - from ert.config import ErtConfig, QueueConfig + from ert.config import ErtConfig, HookRuntime, QueueConfig logger = logging.getLogger(__name__) @@ -81,23 +81,27 @@ def run_experiment( self.set_env_key("_ERT_EXPERIMENT_ID", str(self.experiment.id)) self.set_env_key("_ERT_ENSEMBLE_ID", str(self.ensemble.id)) + self.set_env_key("_ERT_ITERATION", "0") + self.set_env_key("_IS_FINAL_ITERATION", "False") run_args = create_run_arguments( self.run_paths, np.array(self.active_realizations, dtype=bool), ensemble=self.ensemble, ) + + self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, self.ensemble) sample_prior( self.ensemble, np.where(self.active_realizations)[0], random_seed=self.random_seed, ) - self._evaluate_and_postprocess( run_args, self.ensemble, evaluator_server_config, ) + self.run_workflows(HookRuntime.POST_EXPERIMENT, self._storage, self.ensemble) @classmethod def name(cls) -> str: diff --git a/src/ert/run_models/ensemble_smoother.py b/src/ert/run_models/ensemble_smoother.py index 547efdfcd3e..a72a08e0c3b 100644 --- a/src/ert/run_models/ensemble_smoother.py +++ b/src/ert/run_models/ensemble_smoother.py @@ -6,7 +6,7 @@ import numpy as np -from ert.config import ErtConfig +from ert.config import ErtConfig, HookRuntime from ert.enkf_main import sample_prior from ert.ensemble_evaluator import EvaluatorServerConfig from ert.storage import Storage @@ -80,7 +80,10 @@ def run_experiment( np.array(self.active_realizations, dtype=bool), ensemble=prior, ) + self.set_env_key("_ERT_ITERATION", "0") + self.set_env_key("_IS_FINAL_ITERATION", "True") + self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, self.ensemble) sample_prior( prior, np.where(self.active_realizations)[0], @@ -105,6 +108,7 @@ def run_experiment( posterior, evaluator_server_config, ) + self.run_workflows(HookRuntime.POST_EXPERIMENT, self._storage, prior) @classmethod def name(cls) -> str: diff --git a/src/ert/run_models/iterated_ensemble_smoother.py b/src/ert/run_models/iterated_ensemble_smoother.py index 2a674e062fd..a03f797cbb1 100644 --- a/src/ert/run_models/iterated_ensemble_smoother.py +++ b/src/ert/run_models/iterated_ensemble_smoother.py @@ -144,11 +144,18 @@ def run_experiment( ensemble=prior, ) + self.set_env_key("_ERT_ITERATION", "0") + self.set_env_key( + "_IS_FINAL_ITERATION", + "False", + ) + self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, prior) sample_prior( prior, np.where(self.active_realizations)[0], random_seed=self.random_seed, ) + self._evaluate_and_postprocess( prior_args, prior, @@ -157,6 +164,11 @@ def run_experiment( self.run_workflows(HookRuntime.PRE_FIRST_UPDATE, self._storage, prior) for prior_iter in range(self._total_iterations): + self.set_env_key("_ERT_ITERATION", str(prior_iter + 1)) + self.set_env_key( + "_IS_FINAL_ITERATION", + "True" if (prior_iter == self._total_iterations - 1) else "False", + ) self.send_event( RunModelUpdateBeginEvent(iteration=prior_iter, run_id=prior.id) ) @@ -219,6 +231,8 @@ def run_experiment( ) prior = posterior + self.run_workflows(HookRuntime.POST_EXPERIMENT, self._storage, prior) + @classmethod def name(cls) -> str: return "Iterated ensemble smoother" diff --git a/src/ert/run_models/multiple_data_assimilation.py b/src/ert/run_models/multiple_data_assimilation.py index eb394b852b6..f476f547814 100644 --- a/src/ert/run_models/multiple_data_assimilation.py +++ b/src/ert/run_models/multiple_data_assimilation.py @@ -7,7 +7,7 @@ import numpy as np -from ert.config import ErtConfig +from ert.config import ErtConfig, HookRuntime from ert.enkf_main import sample_prior from ert.ensemble_evaluator import EvaluatorServerConfig from ert.storage import Ensemble, Storage @@ -97,6 +97,14 @@ def run_experiment( f"Experiment misconfigured, got starting iteration: {self.start_iteration}," f"restart iteration = {prior.iteration + 1}" ) + + self.set_env_key("_ERT_ITERATION", str(self.start_iteration)) + self.set_env_key( + "_IS_FINAL_ITERATION", + "True" + if (self.start_iteration == self._total_iterations - 1) + else "False", + ) except (KeyError, ValueError) as err: raise ErtRunError( f"Prior ensemble with ID: {id} does not exists" @@ -124,6 +132,8 @@ def run_experiment( np.array(self.active_realizations, dtype=bool), ensemble=prior, ) + + self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, self.ensemble) sample_prior( prior, np.where(self.active_realizations)[0], @@ -155,6 +165,8 @@ def run_experiment( ) prior = posterior + self.run_workflows(HookRuntime.POST_EXPERIMENT, self._storage, prior) + @staticmethod def parse_weights(weights: str) -> List[float]: """Parse weights string and scale weights such that their reciprocals sum diff --git a/tests/ert/ui_tests/cli/test_cli.py b/tests/ert/ui_tests/cli/test_cli.py index ed45c984885..41c440a9ed1 100644 --- a/tests/ert/ui_tests/cli/test_cli.py +++ b/tests/ert/ui_tests/cli/test_cli.py @@ -537,6 +537,74 @@ def test_that_stop_on_fail_workflow_jobs_stop_ert( run_cli(TEST_RUN_MODE, "--disable-monitor", "poly.ert") +@pytest.mark.usefixtures("copy_poly_case") +def test_that_post_experiment_hook_works( + monkeypatch, +): + monkeypatch.setattr(_ert.threading, "_can_raise", False) + + # The executable + with open("dump_final_ensemble_id.sh", "w", encoding="utf-8") as f: + f.write( + dedent("""#!/bin/bash + echo $_IS_FINAL_ITERATION> final_ensemble_info.txt + """) + ) + os.chmod("dump_final_ensemble_id.sh", 0o755) + + # The workflow job + with open("DUMP_FINAL_ENSEMBLE_ID", "w", encoding="utf-8") as s: + s.write(""" + INTERNAL False + EXECUTABLE dump_final_ensemble_info.sh + """) + + # The workflow + with open("POST_EXPERIMENT_DUMP.WF", "w", encoding="utf-8") as s: + s.write("""dump_final_ensemble_id""") + + # The executable + with open("dump_first_ensemble_id.sh", "w", encoding="utf-8") as f: + f.write( + dedent("""#!/bin/bash + echo $_ERT_ITERATION > first_ensemble_id.txt + """) + ) + os.chmod("dump_first_ensemble_id.sh", 0o755) + + # The workflow job + with open("DUMP_FIRST_ENSEMBLE_ID", "w", encoding="utf-8") as s: + s.write(""" + INTERNAL False + EXECUTABLE dump_first_ensemble_id.sh + """) + + # The workflow + with open("PRE_EXPERIMENT_DUMP.WF", "w", encoding="utf-8") as s: + s.write("""dump_first_ensemble_id""") + + with open("poly.ert", mode="a", encoding="utf-8") as fh: + fh.write( + dedent( + """ + NUM_REALIZATIONS 2 + + LOAD_WORKFLOW_JOB DUMP_FINAL_ENSEMBLE_ID dump_final_ensemble_id + LOAD_WORKFLOW POST_EXPERIMENT_DUMP.WF POST_EXPERIMENT_DUMP + HOOK_WORKFLOW POST_EXPERIMENT_DUMP POST_EXPERIMENT + + LOAD_WORKFLOW_JOB DUMP_FIRST_ENSEMBLE_ID dump_first_ensemble_id + LOAD_WORKFLOW PRE_EXPERIMENT_DUMP.WF PRE_EXPERIMENT_DUMP + HOOK_WORKFLOW PRE_EXPERIMENT_DUMP PRE_EXPERIMENT + """ + ) + ) + + run_cli(ITERATIVE_ENSEMBLE_SMOOTHER_MODE, "--disable-monitor", "poly.ert") + + # ...2do assert correct contents in files + + @pytest.fixture(name="mock_cli_run") def fixture_mock_cli_run(monkeypatch): end_event = Mock()