From 7aaf9b815db89588f443f3c1950bae2f4ecbefd5 Mon Sep 17 00:00:00 2001 From: Jonathan Karlsen Date: Tue, 12 Mar 2024 15:36:26 +0100 Subject: [PATCH] tmp2 --- src/ert/scheduler/driver.py | 1 + src/ert/scheduler/lsf_driver.py | 2 +- src/ert/scheduler/openpbs_driver.py | 2 +- src/ert/scheduler/scheduler.py | 61 ++++++++++---------- tests/unit_tests/scheduler/test_scheduler.py | 2 + 5 files changed, 37 insertions(+), 31 deletions(-) diff --git a/src/ert/scheduler/driver.py b/src/ert/scheduler/driver.py index 1207b2db6a0..35f596af895 100644 --- a/src/ert/scheduler/driver.py +++ b/src/ert/scheduler/driver.py @@ -12,6 +12,7 @@ class Driver(ABC): def __init__(self, **kwargs: Dict[str, str]) -> None: self._event_queue: Optional[asyncio.Queue[Event]] = None + self._is_polling = True @property def event_queue(self) -> asyncio.Queue[Event]: diff --git a/src/ert/scheduler/lsf_driver.py b/src/ert/scheduler/lsf_driver.py index 3d4ac51a259..20e26aa5792 100644 --- a/src/ert/scheduler/lsf_driver.py +++ b/src/ert/scheduler/lsf_driver.py @@ -175,7 +175,7 @@ async def kill(self, iens: int) -> None: return async def poll(self) -> None: - while True: + while self._is_polling: if not self._jobs.keys(): await asyncio.sleep(self._poll_period) continue diff --git a/src/ert/scheduler/openpbs_driver.py b/src/ert/scheduler/openpbs_driver.py index bd5a48cce77..93c9906fd4b 100644 --- a/src/ert/scheduler/openpbs_driver.py +++ b/src/ert/scheduler/openpbs_driver.py @@ -262,7 +262,7 @@ async def kill(self, iens: int) -> None: raise RuntimeError(process_message) async def poll(self) -> None: - while True: + while self._is_polling: if not self._jobs: await asyncio.sleep(_POLL_PERIOD) continue diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 316812aed0b..72274b83ea5 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -90,6 +90,7 @@ def __init__( self.completed_jobs: asyncio.Queue[int] = asyncio.Queue() self._cancelled = False + self._scheduling_tasks_running = True self._max_submit = max_submit self._max_running = max_running @@ -166,7 +167,7 @@ async def _publisher(self) -> None: ping_interval=60, close_timeout=60, ): - while True: + while self._scheduling_tasks_running: event = await self._events.get() await conn.send(event) @@ -211,40 +212,38 @@ async def gather_realization_jobs() -> list[BaseException | None]: *self._tasks.values(), return_exceptions=True ) finally: - for scheduling_task in scheduling_tasks: - scheduling_task.cancel() + self._scheduling_tasks_running = False + self.driver._is_polling = False + job_results: Optional[list[BaseException | None]] = None try: # there are two types of tasks and each type is handled differently: # -`gather_realization_jobs` does not raise; rather, each job task's result # is collected into the returning list. Exceptions from the job tasks are # handled in the `else` branch after the evaluation has stopped. # -If exception occurs, it must necessarily come from `scheduling tasks` - - await asyncio.gather(gather_realization_jobs(), *scheduling_tasks) - + job_results = ( + await asyncio.gather(gather_realization_jobs(), *scheduling_tasks) + )[0] except (asyncio.CancelledError, Exception) as e: - if isinstance(e, asyncio.CancelledError): - for result in self._tasks.values() or []: - try: - await result - except asyncio.CancelledError: - continue - except Exception as e: - logger.error(str(result)) - raise e - else: - for job_task in self._tasks.values(): - job_task.cancel() - # suppress potential error during driver.kill - with suppress(Exception): - await job_task - for scheduling_task in scheduling_tasks: - scheduling_task.cancel() - # Log and re-raise non-cancellation errors. - if not isinstance(e, asyncio.CancelledError): - logger.error(str(e)) - raise e + for job_task in self._tasks.values(): + job_task.cancel() + # suppress potential error during driver.kill + with suppress(Exception): + await job_task + for scheduling_task in scheduling_tasks: + scheduling_task.cancel() + # Log and re-raise non-cancellation errors. + if not isinstance(e, asyncio.CancelledError): + logger.error(str(e)) + raise e + else: + for result in job_results: + if not isinstance(result, asyncio.CancelledError) and isinstance( + result, Exception + ): + logger.error(str(result)) + raise result await self.driver.finish() @@ -255,8 +254,12 @@ async def gather_realization_jobs() -> list[BaseException | None]: return EVTYPE_ENSEMBLE_STOPPED async def _process_event_queue(self) -> None: - while True: - event = await self.driver.event_queue.get() + while self._scheduling_tasks_running: + try: + event = await self.driver.event_queue.get_nowait() + except Exception: + await asyncio.sleep(0.1) + continue job = self._jobs[event.iens] # Any event implies the job has at least started diff --git a/tests/unit_tests/scheduler/test_scheduler.py b/tests/unit_tests/scheduler/test_scheduler.py index 31b3b46fcd8..6c87718a30e 100644 --- a/tests/unit_tests/scheduler/test_scheduler.py +++ b/tests/unit_tests/scheduler/test_scheduler.py @@ -223,6 +223,7 @@ async def wait(): assert "Realization 0 stopped due to MAX_RUNTIME=1 seconds" in caplog.text +@pytest.mark.timeout(15) async def test_no_resubmit_on_max_runtime_kill(realization, mock_driver): retries = 0 @@ -244,6 +245,7 @@ async def wait(): assert retries == 1 +@pytest.mark.timeout(15) @pytest.mark.parametrize("max_running", [0, 1, 2, 10]) async def test_max_running(max_running, mock_driver, storage, tmp_path): runs: List[bool] = []