Skip to content

Commit

Permalink
Simplify JobQueue Driver options
Browse files Browse the repository at this point in the history
  • Loading branch information
pinkwah committed Nov 29, 2023
1 parent 070e8d5 commit ae1c110
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 85 deletions.
18 changes: 13 additions & 5 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
import threading
import uuid
from functools import partial, partialmethod
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
)

from cloudevents.http.event import CloudEvent

Expand Down Expand Up @@ -183,7 +193,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
# something is long running, the evaluator will know and should send
# commands to the task in order to have it killed/retried.
# See https://github.com/equinor/ert/issues/1229
queue_evaluators = None
queue_evaluators: Optional[Sequence[Callable[[], None]]] = None
if (
self._analysis_config.stop_long_running
and self._analysis_config.minimum_required_realizations > 0
Expand All @@ -206,9 +216,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
# NOTE: This touches files on disk...
self._scheduler.add_dispatch_information_to_jobs_file()

result: str = await self._scheduler.execute(
queue_evaluators # type: ignore
)
result: str = await self._scheduler.execute(queue_evaluators)
print(result)
except Exception as exc:
print(exc)
Expand Down
62 changes: 27 additions & 35 deletions src/ert/scheduler/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@
import re
import shlex
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
)

from ert.config.parsing.queue_system import QueueSystem

Expand All @@ -19,22 +29,9 @@
class Driver(ABC):
def __init__(
self,
options: Optional[List[Tuple[str, str]]] = None,
options: Optional[Sequence[Tuple[str, str]]] = None,
):
self._options: Dict[str, str] = {}

if options:
for key, value in options:
self.set_option(key, value)

def set_option(self, option: str, value: str) -> None:
self._options.update({option: value})

def get_option(self, option_key: str) -> str:
return self._options[option_key]

def has_option(self, option_key: str) -> bool:
return option_key in self._options
self.options: Dict[str, str] = dict(options or [])

@abstractmethod
async def submit(self, realization: "RealizationState") -> None:
Expand Down Expand Up @@ -143,7 +140,7 @@ def __init__(self, queue_options: Optional[List[Tuple[str, str]]]):
self._currently_polling = False

async def run_with_retries(
self, func: Callable[[Any], Awaitable[Any]], error_msg: str = ""
self, func: Callable[[], Awaitable[Any]], error_msg: str = ""
) -> None:
current_attempt = 0
while current_attempt < self._max_attempt:
Expand All @@ -160,12 +157,10 @@ async def run_with_retries(

async def submit(self, realization: "RealizationState") -> None:
submit_cmd = self.build_submit_cmd(
[
"-J",
f"poly_{realization.realization.run_arg.iens}",
str(realization.realization.job_script),
str(realization.realization.run_arg.runpath),
]
"-J",
f"poly_{realization.realization.run_arg.iens}",
str(realization.realization.job_script),
str(realization.realization.run_arg.runpath),
)
await self.run_with_retries(
lambda: self._submit(submit_cmd, realization=realization),
Expand Down Expand Up @@ -194,14 +189,12 @@ async def _submit(
logger.info(f"Submitted job {realization} and got LSF JOBID {lsf_id}")
return True

def build_submit_cmd(self, args: List[str]) -> List[str]:
submit_cmd = [
self.get_option("BSUB_CMD") if self.has_option("BSUB_CMD") else "bsub"
]
if self.has_option("LSF_QUEUE"):
submit_cmd += ["-q", self.get_option("LSF_QUEUE")]
def build_submit_cmd(self, *args: str) -> List[str]:
submit_cmd = [self.options.get("BSUB_CMD", "bsub")]
if (lsf_queue := self.options.get("LSF_QUEUE")) is not None:
submit_cmd += ["-q", lsf_queue]

return submit_cmd + args
return [*submit_cmd, *args]

async def run_shell_command(
self, command_to_run: List[str], command_name: str = ""
Expand Down Expand Up @@ -234,10 +227,9 @@ async def poll_statuses(self) -> None:
return

poll_cmd = [
str(self.get_option("BJOBS_CMD"))
if self.has_option("BJOBS_CMD")
else "bjobs"
] + list(self._realstate_to_lsfid.values())
self.options.get("BJOBS_CMD", "bjobs"),
*self._realstate_to_lsfid.values(),
]
try:
await self.run_with_retries(lambda: self._poll_statuses(poll_cmd))
# suppress runtime error
Expand Down Expand Up @@ -300,7 +292,7 @@ async def kill(self, realization: "RealizationState") -> None:
lsf_job_id = self._realstate_to_lsfid[realization]
logger.debug(f"Attempting to kill {lsf_job_id=}")
kill_cmd = [
self.get_option("BKILL_CMD") if self.has_option("BKILL_CMD") else "bkill",
self.options.get("BKILL_CMD", "bkill"),
lsf_job_id,
]
await self.run_with_retries(
Expand Down
8 changes: 4 additions & 4 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ssl
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Union

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent
Expand Down Expand Up @@ -196,8 +196,8 @@ def _add_realization(self, realization: QueueableRealization) -> int:

def max_running(self) -> int:
max_running = 0
if self.driver.has_option("MAX_RUNNING"):
max_running = int(self.driver.get_option("MAX_RUNNING"))
if (value := self.driver.options.get("MAX_RUNNING")) is not None:
max_running = int(value)
if max_running == 0:
return len(self._realizations)
return max_running
Expand Down Expand Up @@ -318,7 +318,7 @@ async def _realization_statechange_publisher(self) -> None:

async def execute(
self,
evaluators: Optional[List[Callable[..., Any]]] = None,
evaluators: Optional[Sequence[Callable[[], None]]] = None,
) -> str:
if evaluators is None:
evaluators = []
Expand Down
6 changes: 3 additions & 3 deletions src/ert/simulator/simulation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial
from threading import Thread
from time import sleep
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Tuple

import numpy as np

Expand Down Expand Up @@ -61,7 +61,7 @@ def _run_forward_model(
ert.ert_config.preferred_num_cpu,
)

queue_evaluators = None
queue_evaluators: Optional[Sequence[Callable[[], None]]] = None
if (
ert.ert_config.analysis_config.stop_long_running
and ert.ert_config.analysis_config.minimum_required_realizations > 0
Expand All @@ -73,7 +73,7 @@ def _run_forward_model(
)
]

asyncio.run(scheduler.execute(evaluators=queue_evaluators)) # type: ignore
asyncio.run(scheduler.execute(queue_evaluators))

run_context.sim_fs.sync()

Expand Down
8 changes: 3 additions & 5 deletions tests/unit_tests/config/test_ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,9 @@ def test_extensive_config(setup_case):
assert queue_config.queue_system == QueueSystem.LSF
assert snake_oil_structure_config["MAX_SUBMIT"] == queue_config.max_submit
driver = Driver.create_driver(queue_config)
assert snake_oil_structure_config["MAX_RUNNING"] == driver.get_option("MAX_RUNNING")
assert snake_oil_structure_config["LSF_SERVER"] == driver.get_option("LSF_SERVER")
assert snake_oil_structure_config["LSF_RESOURCE"] == driver.get_option(
"LSF_RESOURCE"
)
assert snake_oil_structure_config["MAX_RUNNING"] == driver.options["MAX_RUNNING"]
assert snake_oil_structure_config["LSF_SERVER"] == driver.options["LSF_SERVER"]
assert snake_oil_structure_config["LSF_RESOURCE"] == driver.options["LSF_RESOURCE"]

for job_name in snake_oil_structure_config["INSTALL_JOB"]:
job = ert_config.installed_jobs[job_name]
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/config/test_queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_torque_queue_config_memory_pr_job(memory_with_unit_str):

driver = Driver.create_driver(config.queue_config)

assert driver.get_option("MEMORY_PER_JOB") == memory_with_unit_str
assert driver.options["MEMORY_PER_JOB"] == memory_with_unit_str


@pytest.mark.usefixtures("use_tmpdir", "set_site_config")
Expand Down Expand Up @@ -177,8 +177,8 @@ def test_initializing_empty_config_queue_options_resets_to_default_value(
f.write(f"QUEUE_OPTION {queue_system} MAX_RUNNING\n")
config_object = ErtConfig.from_file(filename)
driver = Driver.create_driver(config_object.queue_config)
assert driver.get_option(queue_system_option) == ""
assert driver.get_option("MAX_RUNNING") == "0"
assert driver.options[queue_system_option] == ""
assert driver.options["MAX_RUNNING"] == "0"
for options in config_object.queue_config.queue_options[queue_system]:
assert isinstance(options, tuple)

Expand Down
31 changes: 2 additions & 29 deletions tests/unit_tests/job_queue/_test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,6 @@
from ert.scheduler import Driver


@pytest.mark.xfail(reason="Needs reimplementation")
def test_set_and_unset_option():
queue_config = QueueConfig(
job_script="script.sh",
queue_system=QueueSystem.LOCAL,
max_submit=2,
queue_options={
QueueSystem.LOCAL: [
("MAX_RUNNING", "50"),
("MAX_RUNNING", ""),
]
},
)
driver = Driver.create_driver(queue_config)
assert driver.get_option("MAX_RUNNING") == "0"
assert driver.set_option("MAX_RUNNING", "42")
assert driver.get_option("MAX_RUNNING") == "42"
driver.set_option("MAX_RUNNING", "")
assert driver.get_option("MAX_RUNNING") == "0"
driver.set_option("MAX_RUNNING", "100")
assert driver.get_option("MAX_RUNNING") == "100"
driver.set_option("MAX_RUNNING", "0")
assert driver.get_option("MAX_RUNNING") == "0"


@pytest.mark.xfail(reason="Needs reimplementation")
def test_get_driver_name():
queue_config = QueueConfig(queue_system=QueueSystem.LOCAL)
Expand Down Expand Up @@ -61,8 +36,6 @@ def test_get_slurm_queue_config():
assert queue_config.queue_system == QueueSystem.SLURM
driver = Driver.create_driver(queue_config)

assert driver.get_option("SBATCH") == "/path/to/sbatch"
assert driver.get_option("SCONTROL") == "scontrol"
driver.set_option("SCONTROL", "")
assert driver.get_option("SCONTROL") == ""
assert driver.options["SBATCH"] == "/path/to/sbatch"
assert driver.options["SCONTROL"] == "scontrol"
assert driver.name == "SLURM"
2 changes: 1 addition & 1 deletion tests/unit_tests/job_queue/_test_job_queue_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
def make_driver(queue_system: QueueSystem):
result = Driver(queue_system)
if queue_system == QueueSystem.TORQUE:
result.set_option("QSTAT_CMD", "qstat")
result.options["QSTAT_CMD"] = "qstat"
return result


Expand Down

0 comments on commit ae1c110

Please sign in to comment.