Skip to content

Commit

Permalink
Move callback ok wrapper to callback.py. Change signature of callback…
Browse files Browse the repository at this point in the history
…_done to match wrapper
  • Loading branch information
JHolba committed Aug 28, 2023
1 parent 1da3ab2 commit ab4e910
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 64 deletions.
41 changes: 36 additions & 5 deletions src/ert/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,27 @@
import logging
import time
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Iterable, Mapping, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Iterable, Mapping, Optional, Tuple

from ert.config import ParameterConfig, ResponseConfig, SummaryConfig
from ert.config import EnsembleConfig, ParameterConfig, ResponseConfig, SummaryConfig
from ert.run_arg import RunArg

from .load_status import LoadResult, LoadStatus
from .realization_state import RealizationState

if TYPE_CHECKING:
from ert.storage import EnsembleAccessor

CallbackArgs = Tuple[RunArg, Mapping[str, ResponseConfig]]
Callback = Callable[[RunArg, Mapping[str, ResponseConfig]], LoadResult]
CallbackDone = Callable[
[str, str, int, str, int, Optional[str], Dict[str, ResponseConfig]], LoadResult
]

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
from ert.storage import EnsembleAccessor


def _read_parameters(
runpath: str,
iens: int,
Expand Down Expand Up @@ -74,6 +78,33 @@ def _write_responses_to_storage(
return LoadResult(LoadStatus.LOAD_SUCCESSFUL, "")


def forward_model_ok_for_job_queue( # pylint: disable=too-many-arguments
storage_path: str,
ensemble_path: str,
iens: int,
runpath: str,
itr: int,
refcase_file: Optional[str],
response_configs: Dict[str, ResponseConfig],
) -> LoadResult:
from ert.storage import EnsembleAccessor, StorageAccessor

if refcase_file:
refcase = EnsembleConfig.load_refcase(refcase_file)
for key, config in response_configs.items():
if isinstance(config, SummaryConfig):
config.refcase = refcase
response_configs[key] = config
local_storage = StorageAccessor(
storage_path, ignore_migration_check=True, ignore_filelock_dangerous=True
)
ensemble_storage = EnsembleAccessor(local_storage, Path(ensemble_path))
run_arg = RunArg("", ensemble_storage, iens, itr, runpath, "")
return forward_model_ok(
run_arg=run_arg, response_configs=response_configs, update_state_map=False
)


def forward_model_ok(
run_arg: RunArg,
response_configs: Mapping[str, ResponseConfig],
Expand Down
8 changes: 4 additions & 4 deletions src/ert/ensemble_evaluator/_builder/_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

SOURCE_TEMPLATE_STEP = "/step/{step_id}"
if TYPE_CHECKING:
from ert.callbacks import Callback, CallbackArgs
from ert.callbacks import Callback, CallbackArgs, CallbackDone
from ert.run_arg import RunArg


Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__( # pylint: disable=too-many-arguments
source: str,
max_runtime: Optional[int],
callback_arguments: CallbackArgs,
done_callback: Callback,
done_callback: CallbackDone,
exit_callback: Callback,
num_cpu: int,
run_path: Path,
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(self) -> None:
# legacy parts
self._max_runtime: Optional[int] = None
self._callback_arguments: Optional[CallbackArgs] = None
self._done_callback: Optional[Callback] = None
self._done_callback: Optional[CallbackDone] = None
self._exit_callback: Optional[Callback] = None
self._num_cpu: Optional[int] = None
self._run_path: Optional[Path] = None
Expand Down Expand Up @@ -124,7 +124,7 @@ def set_callback_arguments(self, callback_arguments: CallbackArgs) -> "StepBuild
self._callback_arguments = callback_arguments
return self

def set_done_callback(self, done_callback: Callback) -> "StepBuilder":
def set_done_callback(self, done_callback: CallbackDone) -> "StepBuilder":
self._done_callback = done_callback
return self

Expand Down
40 changes: 6 additions & 34 deletions src/ert/job_queue/job_queue_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
import random
import time
import traceback
from pathlib import Path
from threading import Lock, Semaphore, Thread
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Optional

from cwrap import BaseCClass
from ecl.util.util import StringList
Expand All @@ -19,12 +18,10 @@
_refresh_status,
_submit,
)
from ert.callbacks import forward_model_ok
from ert.config import EnsembleConfig, ResponseConfig, SummaryConfig
from ert.load_status import LoadResult, LoadStatus
from ert.callbacks import CallbackDone
from ert.config import SummaryConfig
from ert.load_status import LoadStatus
from ert.realization_state import RealizationState
from ert.run_arg import RunArg
from ert.storage import EnsembleAccessor, StorageAccessor

from . import ResPrototype
from .job_status import JobStatus
Expand All @@ -39,31 +36,6 @@
logger = logging.getLogger(__name__)


def forward_model_ok_wrapper( # pylint: disable=too-many-arguments
storage_path: str,
ensemble_path: str,
iens: int,
runpath: str,
itr: int,
refcase_file: Optional[str],
response_configs: Dict[str, ResponseConfig],
) -> LoadResult:
if refcase_file:
refcase = EnsembleConfig.load_refcase(refcase_file)
for key, config in response_configs.items():
if isinstance(config, SummaryConfig):
config.refcase = refcase
response_configs[key] = config
local_storage = StorageAccessor(
storage_path, ignore_migration_check=True, ignore_filelock_dangerous=True
)
ensemble_storage = EnsembleAccessor(local_storage, Path(ensemble_path))
run_arg = RunArg("", ensemble_storage, iens, itr, runpath, "")
return forward_model_ok(
run_arg=run_arg, response_configs=response_configs, update_state_map=False
)


class _BackoffFunction:
def __init__(
self,
Expand Down Expand Up @@ -125,7 +97,7 @@ def __init__(
num_cpu: int,
status_file: str,
exit_file: str,
done_callback_function: Callback,
done_callback_function: CallbackDone,
exit_callback_function: Callback,
callback_arguments: CallbackArgs,
max_runtime: Optional[int] = None,
Expand Down Expand Up @@ -227,7 +199,7 @@ def run_done_callback(self) -> Optional[LoadStatus]:
itr = run_arg.itr
# TODO fix uncaught exception for example when pickling fails
callback_status, status_msg = mp_pool.apply(
forward_model_ok_wrapper,
self.done_callback_function,
(
storage_path,
ensemble_path,
Expand Down
4 changes: 2 additions & 2 deletions src/ert/job_queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from . import ResPrototype

if TYPE_CHECKING:
from ert.callbacks import Callback
from ert.callbacks import Callback, CallbackDone
from ert.config import ErtConfig
from ert.ensemble_evaluator import LegacyStep
from ert.run_arg import RunArg
Expand Down Expand Up @@ -484,7 +484,7 @@ def add_job_from_run_arg(
run_arg: "RunArg",
ert_config: "ErtConfig",
max_runtime: Optional[int],
ok_cb: Callback,
ok_cb: CallbackDone,
exit_cb: Callback,
num_cpu: int,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from cloudevents.http import CloudEvent

import _ert_com_protocol
from ert.callbacks import forward_model_exit, forward_model_ok
from ert.callbacks import forward_model_exit, forward_model_ok_for_job_queue
from ert.cli import MODULE_MODE
from ert.config import HookRuntime
from ert.enkf_main import EnKFMain
Expand Down Expand Up @@ -398,7 +398,7 @@ def _build_ensemble(
self.ert().resConfig().ensemble_config.response_configs,
)
).set_done_callback(
forward_model_ok
forward_model_ok_for_job_queue
).set_exit_callback(
forward_model_exit
).set_num_cpu(
Expand Down
4 changes: 2 additions & 2 deletions src/ert/simulator/simulation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from time import sleep
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

from ert.callbacks import forward_model_exit, forward_model_ok
from ert.callbacks import forward_model_exit, forward_model_ok_for_job_queue
from ert.config import HookRuntime
from ert.job_queue import Driver, JobQueue, JobQueueManager, RunStatus
from ert.realization_state import RealizationState
Expand Down Expand Up @@ -49,7 +49,7 @@ def _run_forward_model(
run_arg,
ert.resConfig(),
max_runtime,
forward_model_ok,
forward_model_ok_for_job_queue,
forward_model_exit,
ert.get_num_cpu(),
)
Expand Down
48 changes: 33 additions & 15 deletions tests/unit_tests/job_queue/test_job_queue_manager_torque.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import os
import stat
from dataclasses import dataclass
from pathlib import Path
from threading import BoundedSemaphore
from typing import Callable, TypedDict
from unittest.mock import MagicMock
from types import SimpleNamespace
from typing import Any, Callable, Optional, TypedDict

import pytest

from ert.config import QueueSystem
from ert.job_queue import Driver, JobQueueNode, JobStatus
from ert.load_status import LoadStatus
from ert.run_arg import RunArg


@pytest.fixture(name="temp_working_directory")
Expand All @@ -33,12 +33,6 @@ def fixture_dummy_config():
)


@dataclass
class RunArg:
iens: int
ensemble_storage = MagicMock()


class JobConfig(TypedDict):
job_script: str
num_cpu: int
Expand All @@ -48,15 +42,32 @@ class JobConfig(TypedDict):
exit_callback: Callable


def dummy_ok_callback(runargs, path):
(Path(path) / "OK").write_text("success", encoding="utf-8")
def dummy_ok_callback(
storage_path: str,
ensemble_path: str,
iens: int,
runpath: str,
itr: int,
refcase_file: Optional[str],
response_configs: Any,
):
(Path(runpath) / "OK").write_text("success", encoding="utf-8")
return (LoadStatus.LOAD_SUCCESSFUL, "")


def dummy_exit_callback(*_args):
Path("ERROR").write_text("failure", encoding="utf-8")


def dummy_ensemble_storage():
ensemble_storage = SimpleNamespace()
ensemble_storage.mount_point = ""
ensemble_storage.storage = SimpleNamespace()
ensemble_storage.storage.path = ""

return ensemble_storage


SIMPLE_SCRIPT = """#!/bin/sh
echo "finished successfully" > STATUS
"""
Expand Down Expand Up @@ -157,10 +168,17 @@ def _build_jobqueuenode(dummy_config: JobConfig, job_id=0):
exit_file="ERROR",
done_callback_function=dummy_config["ok_callback"],
exit_callback_function=dummy_config["exit_callback"],
callback_arguments=[
RunArg(iens=job_id),
Path(dummy_config["run_path"].format(job_id)).resolve(),
],
callback_arguments=(
RunArg(
iens=job_id,
ensemble_storage=dummy_ensemble_storage(),
runpath=dummy_config["run_path"].format(job_id),
itr=0,
job_name="jobjobjob",
run_id="runrunrun",
),
{},
),
)
return (job, runpath)

Expand Down

0 comments on commit ab4e910

Please sign in to comment.