Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add post/pre experiment simulation hooks #8993

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ert/config/parsing/hook_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ class HookRuntime(StrEnum):
PRE_UPDATE = "PRE_UPDATE"
POST_UPDATE = "POST_UPDATE"
PRE_FIRST_UPDATE = "PRE_FIRST_UPDATE"
PRE_EXPERIMENT = "PRE_EXPERIMENT"
POST_EXPERIMENT = "POST_EXPERIMENT"
8 changes: 6 additions & 2 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .base_run_model import BaseRunModel, StatusEvents

if TYPE_CHECKING:
from ert.config import ErtConfig, QueueConfig
from ert.config import ErtConfig, HookRuntime, QueueConfig


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -81,23 +81,27 @@ def run_experiment(

self.set_env_key("_ERT_EXPERIMENT_ID", str(self.experiment.id))
self.set_env_key("_ERT_ENSEMBLE_ID", str(self.ensemble.id))
self.set_env_key("_ERT_ITERATION", "0")
self.set_env_key("_IS_FINAL_ITERATION", "False")

run_args = create_run_arguments(
self.run_paths,
np.array(self.active_realizations, dtype=bool),
ensemble=self.ensemble,
)

self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, self.ensemble)
sample_prior(
self.ensemble,
np.where(self.active_realizations)[0],
random_seed=self.random_seed,
)

self._evaluate_and_postprocess(
run_args,
self.ensemble,
evaluator_server_config,
)
self.run_workflows(HookRuntime.POST_EXPERIMENT, self._storage, self.ensemble)

@classmethod
def name(cls) -> str:
Expand Down
6 changes: 5 additions & 1 deletion src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from ert.config import ErtConfig
from ert.config import ErtConfig, HookRuntime
from ert.enkf_main import sample_prior
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Storage
Expand Down Expand Up @@ -80,7 +80,10 @@
np.array(self.active_realizations, dtype=bool),
ensemble=prior,
)
self.set_env_key("_ERT_ITERATION", "0")
self.set_env_key("_IS_FINAL_ITERATION", "True")

self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, self.ensemble)

Check failure on line 86 in src/ert/run_models/ensemble_smoother.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

"EnsembleSmoother" has no attribute "ensemble"
sample_prior(
prior,
np.where(self.active_realizations)[0],
Expand All @@ -105,6 +108,7 @@
posterior,
evaluator_server_config,
)
self.run_workflows(HookRuntime.POST_EXPERIMENT, self._storage, prior)

@classmethod
def name(cls) -> str:
Expand Down
14 changes: 14 additions & 0 deletions src/ert/run_models/iterated_ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,18 @@ def run_experiment(
ensemble=prior,
)

self.set_env_key("_ERT_ITERATION", "0")
self.set_env_key(
"_IS_FINAL_ITERATION",
"False",
)
self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, prior)
sample_prior(
prior,
np.where(self.active_realizations)[0],
random_seed=self.random_seed,
)

self._evaluate_and_postprocess(
prior_args,
prior,
Expand All @@ -157,6 +164,11 @@ def run_experiment(

self.run_workflows(HookRuntime.PRE_FIRST_UPDATE, self._storage, prior)
for prior_iter in range(self._total_iterations):
self.set_env_key("_ERT_ITERATION", str(prior_iter + 1))
self.set_env_key(
"_IS_FINAL_ITERATION",
"True" if (prior_iter == self._total_iterations - 1) else "False",
)
self.send_event(
RunModelUpdateBeginEvent(iteration=prior_iter, run_id=prior.id)
)
Expand Down Expand Up @@ -219,6 +231,8 @@ def run_experiment(
)
prior = posterior

self.run_workflows(HookRuntime.POST_EXPERIMENT, self._storage, prior)

@classmethod
def name(cls) -> str:
return "Iterated ensemble smoother"
Expand Down
14 changes: 13 additions & 1 deletion src/ert/run_models/multiple_data_assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from ert.config import ErtConfig
from ert.config import ErtConfig, HookRuntime
from ert.enkf_main import sample_prior
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Storage
Expand Down Expand Up @@ -97,6 +97,14 @@
f"Experiment misconfigured, got starting iteration: {self.start_iteration},"
f"restart iteration = {prior.iteration + 1}"
)

self.set_env_key("_ERT_ITERATION", str(self.start_iteration))
self.set_env_key(
"_IS_FINAL_ITERATION",
"True"
if (self.start_iteration == self._total_iterations - 1)
else "False",
)
except (KeyError, ValueError) as err:
raise ErtRunError(
f"Prior ensemble with ID: {id} does not exists"
Expand Down Expand Up @@ -124,6 +132,8 @@
np.array(self.active_realizations, dtype=bool),
ensemble=prior,
)

self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, self.ensemble)

Check failure on line 136 in src/ert/run_models/multiple_data_assimilation.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

"MultipleDataAssimilation" has no attribute "ensemble"
sample_prior(
prior,
np.where(self.active_realizations)[0],
Expand Down Expand Up @@ -155,6 +165,8 @@
)
prior = posterior

self.run_workflows(HookRuntime.POST_EXPERIMENT, self._storage, prior)

@staticmethod
def parse_weights(weights: str) -> List[float]:
"""Parse weights string and scale weights such that their reciprocals sum
Expand Down
68 changes: 68 additions & 0 deletions tests/ert/ui_tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,74 @@ def test_that_stop_on_fail_workflow_jobs_stop_ert(
run_cli(TEST_RUN_MODE, "--disable-monitor", "poly.ert")


@pytest.mark.usefixtures("copy_poly_case")
def test_that_post_experiment_hook_works(
monkeypatch,
):
monkeypatch.setattr(_ert.threading, "_can_raise", False)

# The executable
with open("dump_final_ensemble_id.sh", "w", encoding="utf-8") as f:
f.write(
dedent("""#!/bin/bash
echo $_IS_FINAL_ITERATION> final_ensemble_info.txt
""")
)
os.chmod("dump_final_ensemble_id.sh", 0o755)

# The workflow job
with open("DUMP_FINAL_ENSEMBLE_ID", "w", encoding="utf-8") as s:
s.write("""
INTERNAL False
EXECUTABLE dump_final_ensemble_info.sh
""")

# The workflow
with open("POST_EXPERIMENT_DUMP.WF", "w", encoding="utf-8") as s:
s.write("""dump_final_ensemble_id""")

# The executable
with open("dump_first_ensemble_id.sh", "w", encoding="utf-8") as f:
f.write(
dedent("""#!/bin/bash
echo $_ERT_ITERATION > first_ensemble_id.txt
""")
)
os.chmod("dump_first_ensemble_id.sh", 0o755)

# The workflow job
with open("DUMP_FIRST_ENSEMBLE_ID", "w", encoding="utf-8") as s:
s.write("""
INTERNAL False
EXECUTABLE dump_first_ensemble_id.sh
""")

# The workflow
with open("PRE_EXPERIMENT_DUMP.WF", "w", encoding="utf-8") as s:
s.write("""dump_first_ensemble_id""")

with open("poly.ert", mode="a", encoding="utf-8") as fh:
fh.write(
dedent(
"""
NUM_REALIZATIONS 2

LOAD_WORKFLOW_JOB DUMP_FINAL_ENSEMBLE_ID dump_final_ensemble_id
LOAD_WORKFLOW POST_EXPERIMENT_DUMP.WF POST_EXPERIMENT_DUMP
HOOK_WORKFLOW POST_EXPERIMENT_DUMP POST_EXPERIMENT

LOAD_WORKFLOW_JOB DUMP_FIRST_ENSEMBLE_ID dump_first_ensemble_id
LOAD_WORKFLOW PRE_EXPERIMENT_DUMP.WF PRE_EXPERIMENT_DUMP
HOOK_WORKFLOW PRE_EXPERIMENT_DUMP PRE_EXPERIMENT
"""
)
)

run_cli(ITERATIVE_ENSEMBLE_SMOOTHER_MODE, "--disable-monitor", "poly.ert")

# ...2do assert correct contents in files


@pytest.fixture(name="mock_cli_run")
def fixture_mock_cli_run(monkeypatch):
end_event = Mock()
Expand Down
Loading