From ae1c1104546c2e20302e5b135613e735537ed443 Mon Sep 17 00:00:00 2001 From: Zohar Malamant Date: Mon, 27 Nov 2023 14:49:47 +0100 Subject: [PATCH] Simplify JobQueue Driver options --- .../ensemble_evaluator/_builder/_legacy.py | 18 ++++-- src/ert/scheduler/driver.py | 62 ++++++++----------- src/ert/scheduler/scheduler.py | 8 +-- src/ert/simulator/simulation_context.py | 6 +- tests/unit_tests/config/test_ert_config.py | 8 +-- tests/unit_tests/config/test_queue_config.py | 6 +- tests/unit_tests/job_queue/_test_driver.py | 31 +--------- .../job_queue/_test_job_queue_node.py | 2 +- 8 files changed, 56 insertions(+), 85 deletions(-) diff --git a/src/ert/ensemble_evaluator/_builder/_legacy.py b/src/ert/ensemble_evaluator/_builder/_legacy.py index fb0170a8bb8..590e903e867 100644 --- a/src/ert/ensemble_evaluator/_builder/_legacy.py +++ b/src/ert/ensemble_evaluator/_builder/_legacy.py @@ -5,7 +5,17 @@ import threading import uuid from functools import partial, partialmethod -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, +) from cloudevents.http.event import CloudEvent @@ -183,7 +193,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches # something is long running, the evaluator will know and should send # commands to the task in order to have it killed/retried. # See https://github.com/equinor/ert/issues/1229 - queue_evaluators = None + queue_evaluators: Optional[Sequence[Callable[[], None]]] = None if ( self._analysis_config.stop_long_running and self._analysis_config.minimum_required_realizations > 0 @@ -206,9 +216,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches # NOTE: This touches files on disk... self._scheduler.add_dispatch_information_to_jobs_file() - result: str = await self._scheduler.execute( - queue_evaluators # type: ignore - ) + result: str = await self._scheduler.execute(queue_evaluators) print(result) except Exception as exc: print(exc) diff --git a/src/ert/scheduler/driver.py b/src/ert/scheduler/driver.py index 7df8f98e220..b94728b322f 100644 --- a/src/ert/scheduler/driver.py +++ b/src/ert/scheduler/driver.py @@ -4,7 +4,17 @@ import re import shlex from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, +) from ert.config.parsing.queue_system import QueueSystem @@ -19,22 +29,9 @@ class Driver(ABC): def __init__( self, - options: Optional[List[Tuple[str, str]]] = None, + options: Optional[Sequence[Tuple[str, str]]] = None, ): - self._options: Dict[str, str] = {} - - if options: - for key, value in options: - self.set_option(key, value) - - def set_option(self, option: str, value: str) -> None: - self._options.update({option: value}) - - def get_option(self, option_key: str) -> str: - return self._options[option_key] - - def has_option(self, option_key: str) -> bool: - return option_key in self._options + self.options: Dict[str, str] = dict(options or []) @abstractmethod async def submit(self, realization: "RealizationState") -> None: @@ -143,7 +140,7 @@ def __init__(self, queue_options: Optional[List[Tuple[str, str]]]): self._currently_polling = False async def run_with_retries( - self, func: Callable[[Any], Awaitable[Any]], error_msg: str = "" + self, func: Callable[[], Awaitable[Any]], error_msg: str = "" ) -> None: current_attempt = 0 while current_attempt < self._max_attempt: @@ -160,12 +157,10 @@ async def run_with_retries( async def submit(self, realization: "RealizationState") -> None: submit_cmd = self.build_submit_cmd( - [ - "-J", - f"poly_{realization.realization.run_arg.iens}", - str(realization.realization.job_script), - str(realization.realization.run_arg.runpath), - ] + "-J", + f"poly_{realization.realization.run_arg.iens}", + str(realization.realization.job_script), + str(realization.realization.run_arg.runpath), ) await self.run_with_retries( lambda: self._submit(submit_cmd, realization=realization), @@ -194,14 +189,12 @@ async def _submit( logger.info(f"Submitted job {realization} and got LSF JOBID {lsf_id}") return True - def build_submit_cmd(self, args: List[str]) -> List[str]: - submit_cmd = [ - self.get_option("BSUB_CMD") if self.has_option("BSUB_CMD") else "bsub" - ] - if self.has_option("LSF_QUEUE"): - submit_cmd += ["-q", self.get_option("LSF_QUEUE")] + def build_submit_cmd(self, *args: str) -> List[str]: + submit_cmd = [self.options.get("BSUB_CMD", "bsub")] + if (lsf_queue := self.options.get("LSF_QUEUE")) is not None: + submit_cmd += ["-q", lsf_queue] - return submit_cmd + args + return [*submit_cmd, *args] async def run_shell_command( self, command_to_run: List[str], command_name: str = "" @@ -234,10 +227,9 @@ async def poll_statuses(self) -> None: return poll_cmd = [ - str(self.get_option("BJOBS_CMD")) - if self.has_option("BJOBS_CMD") - else "bjobs" - ] + list(self._realstate_to_lsfid.values()) + self.options.get("BJOBS_CMD", "bjobs"), + *self._realstate_to_lsfid.values(), + ] try: await self.run_with_retries(lambda: self._poll_statuses(poll_cmd)) # suppress runtime error @@ -300,7 +292,7 @@ async def kill(self, realization: "RealizationState") -> None: lsf_job_id = self._realstate_to_lsfid[realization] logger.debug(f"Attempting to kill {lsf_job_id=}") kill_cmd = [ - self.get_option("BKILL_CMD") if self.has_option("BKILL_CMD") else "bkill", + self.options.get("BKILL_CMD", "bkill"), lsf_job_id, ] await self.run_with_retries( diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 50c77952a17..59a19e5fc9c 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -11,7 +11,7 @@ import ssl from collections import deque from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Union from cloudevents.conversion import to_json from cloudevents.http import CloudEvent @@ -196,8 +196,8 @@ def _add_realization(self, realization: QueueableRealization) -> int: def max_running(self) -> int: max_running = 0 - if self.driver.has_option("MAX_RUNNING"): - max_running = int(self.driver.get_option("MAX_RUNNING")) + if (value := self.driver.options.get("MAX_RUNNING")) is not None: + max_running = int(value) if max_running == 0: return len(self._realizations) return max_running @@ -318,7 +318,7 @@ async def _realization_statechange_publisher(self) -> None: async def execute( self, - evaluators: Optional[List[Callable[..., Any]]] = None, + evaluators: Optional[Sequence[Callable[[], None]]] = None, ) -> str: if evaluators is None: evaluators = [] diff --git a/src/ert/simulator/simulation_context.py b/src/ert/simulator/simulation_context.py index 8d2d71c583e..0dc67a4c123 100644 --- a/src/ert/simulator/simulation_context.py +++ b/src/ert/simulator/simulation_context.py @@ -4,7 +4,7 @@ from functools import partial from threading import Thread from time import sleep -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Tuple import numpy as np @@ -61,7 +61,7 @@ def _run_forward_model( ert.ert_config.preferred_num_cpu, ) - queue_evaluators = None + queue_evaluators: Optional[Sequence[Callable[[], None]]] = None if ( ert.ert_config.analysis_config.stop_long_running and ert.ert_config.analysis_config.minimum_required_realizations > 0 @@ -73,7 +73,7 @@ def _run_forward_model( ) ] - asyncio.run(scheduler.execute(evaluators=queue_evaluators)) # type: ignore + asyncio.run(scheduler.execute(queue_evaluators)) run_context.sim_fs.sync() diff --git a/tests/unit_tests/config/test_ert_config.py b/tests/unit_tests/config/test_ert_config.py index ebbdf9bb5e5..a997ec30fcc 100644 --- a/tests/unit_tests/config/test_ert_config.py +++ b/tests/unit_tests/config/test_ert_config.py @@ -188,11 +188,9 @@ def test_extensive_config(setup_case): assert queue_config.queue_system == QueueSystem.LSF assert snake_oil_structure_config["MAX_SUBMIT"] == queue_config.max_submit driver = Driver.create_driver(queue_config) - assert snake_oil_structure_config["MAX_RUNNING"] == driver.get_option("MAX_RUNNING") - assert snake_oil_structure_config["LSF_SERVER"] == driver.get_option("LSF_SERVER") - assert snake_oil_structure_config["LSF_RESOURCE"] == driver.get_option( - "LSF_RESOURCE" - ) + assert snake_oil_structure_config["MAX_RUNNING"] == driver.options["MAX_RUNNING"] + assert snake_oil_structure_config["LSF_SERVER"] == driver.options["LSF_SERVER"] + assert snake_oil_structure_config["LSF_RESOURCE"] == driver.options["LSF_RESOURCE"] for job_name in snake_oil_structure_config["INSTALL_JOB"]: job = ert_config.installed_jobs[job_name] diff --git a/tests/unit_tests/config/test_queue_config.py b/tests/unit_tests/config/test_queue_config.py index a349361783e..f2830042e09 100644 --- a/tests/unit_tests/config/test_queue_config.py +++ b/tests/unit_tests/config/test_queue_config.py @@ -92,7 +92,7 @@ def test_torque_queue_config_memory_pr_job(memory_with_unit_str): driver = Driver.create_driver(config.queue_config) - assert driver.get_option("MEMORY_PER_JOB") == memory_with_unit_str + assert driver.options["MEMORY_PER_JOB"] == memory_with_unit_str @pytest.mark.usefixtures("use_tmpdir", "set_site_config") @@ -177,8 +177,8 @@ def test_initializing_empty_config_queue_options_resets_to_default_value( f.write(f"QUEUE_OPTION {queue_system} MAX_RUNNING\n") config_object = ErtConfig.from_file(filename) driver = Driver.create_driver(config_object.queue_config) - assert driver.get_option(queue_system_option) == "" - assert driver.get_option("MAX_RUNNING") == "0" + assert driver.options[queue_system_option] == "" + assert driver.options["MAX_RUNNING"] == "0" for options in config_object.queue_config.queue_options[queue_system]: assert isinstance(options, tuple) diff --git a/tests/unit_tests/job_queue/_test_driver.py b/tests/unit_tests/job_queue/_test_driver.py index a62e9439003..b0cd9ebd2db 100644 --- a/tests/unit_tests/job_queue/_test_driver.py +++ b/tests/unit_tests/job_queue/_test_driver.py @@ -6,31 +6,6 @@ from ert.scheduler import Driver -@pytest.mark.xfail(reason="Needs reimplementation") -def test_set_and_unset_option(): - queue_config = QueueConfig( - job_script="script.sh", - queue_system=QueueSystem.LOCAL, - max_submit=2, - queue_options={ - QueueSystem.LOCAL: [ - ("MAX_RUNNING", "50"), - ("MAX_RUNNING", ""), - ] - }, - ) - driver = Driver.create_driver(queue_config) - assert driver.get_option("MAX_RUNNING") == "0" - assert driver.set_option("MAX_RUNNING", "42") - assert driver.get_option("MAX_RUNNING") == "42" - driver.set_option("MAX_RUNNING", "") - assert driver.get_option("MAX_RUNNING") == "0" - driver.set_option("MAX_RUNNING", "100") - assert driver.get_option("MAX_RUNNING") == "100" - driver.set_option("MAX_RUNNING", "0") - assert driver.get_option("MAX_RUNNING") == "0" - - @pytest.mark.xfail(reason="Needs reimplementation") def test_get_driver_name(): queue_config = QueueConfig(queue_system=QueueSystem.LOCAL) @@ -61,8 +36,6 @@ def test_get_slurm_queue_config(): assert queue_config.queue_system == QueueSystem.SLURM driver = Driver.create_driver(queue_config) - assert driver.get_option("SBATCH") == "/path/to/sbatch" - assert driver.get_option("SCONTROL") == "scontrol" - driver.set_option("SCONTROL", "") - assert driver.get_option("SCONTROL") == "" + assert driver.options["SBATCH"] == "/path/to/sbatch" + assert driver.options["SCONTROL"] == "scontrol" assert driver.name == "SLURM" diff --git a/tests/unit_tests/job_queue/_test_job_queue_node.py b/tests/unit_tests/job_queue/_test_job_queue_node.py index 435e167d5d6..c8357a7e7d0 100644 --- a/tests/unit_tests/job_queue/_test_job_queue_node.py +++ b/tests/unit_tests/job_queue/_test_job_queue_node.py @@ -29,7 +29,7 @@ def make_driver(queue_system: QueueSystem): result = Driver(queue_system) if queue_system == QueueSystem.TORQUE: - result.set_option("QSTAT_CMD", "qstat") + result.options["QSTAT_CMD"] = "qstat" return result