diff --git a/src/ert/async_utils.py b/src/ert/async_utils.py index a5657fd585b..a4bd79b8c50 100644 --- a/src/ert/async_utils.py +++ b/src/ert/async_utils.py @@ -2,16 +2,7 @@ import asyncio import logging -from contextlib import asynccontextmanager -from typing import ( - Any, - AsyncGenerator, - Coroutine, - Generator, - MutableSequence, - TypeVar, - Union, -) +from typing import Any, Coroutine, Generator, TypeVar, Union logger = logging.getLogger(__name__) @@ -19,31 +10,6 @@ _T_co = TypeVar("_T_co", covariant=True) -@asynccontextmanager -async def background_tasks() -> AsyncGenerator[Any, Any]: - """Context manager for long-living tasks that cancel when exiting the - context - - """ - - tasks: MutableSequence[asyncio.Task[Any]] = [] - - def add(coro: Coroutine[Any, Any, Any]) -> None: - tasks.append(asyncio.create_task(coro)) - - try: - yield add - finally: - for t in tasks: - t.cancel() - for exc in await asyncio.gather(*tasks, return_exceptions=True): - if isinstance(exc, asyncio.CancelledError): - continue - if isinstance(exc, BaseException): - logger.error(str(exc), exc_info=exc) - tasks.clear() - - def new_event_loop() -> asyncio.AbstractEventLoop: loop = asyncio.new_event_loop() loop.set_task_factory(_create_task) diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py index 780dd7b6f55..85868ac84ab 100644 --- a/src/ert/scheduler/job.py +++ b/src/ert/scheduler/job.py @@ -4,7 +4,6 @@ import logging import time import uuid -from contextlib import suppress from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, List, Optional @@ -117,8 +116,7 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: except asyncio.CancelledError: await self._send(State.ABORTING) await self.driver.kill(self.iens) - with suppress(asyncio.CancelledError): - await self.returncode + self.returncode.cancel() await self._send(State.ABORTED) finally: if timeout_task and not timeout_task.done(): @@ -126,10 +124,9 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: sem.release() async def __call__( - self, start: asyncio.Event, sem: asyncio.BoundedSemaphore, max_submit: int = 2 + self, sem: asyncio.BoundedSemaphore, max_submit: int = 2 ) -> None: self._requested_max_submit = max_submit - await start.wait() for attempt in range(max_submit): await self._submit_and_run_once(sem) diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 9460eeb1c45..0efe9ba64ec 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -7,6 +7,7 @@ import ssl import time from collections import defaultdict +from contextlib import suppress from dataclasses import asdict from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, MutableMapping, Optional, Sequence @@ -15,7 +16,6 @@ from websockets import Headers from websockets.client import connect -from ert.async_utils import background_tasks from ert.constant_filenames import CERT_FILE from ert.job_queue.queue import EVTYPE_ENSEMBLE_CANCELLED, EVTYPE_ENSEMBLE_STOPPED from ert.scheduler.driver import Driver @@ -181,32 +181,67 @@ async def execute( # We need to store the loop due to when calling # cancel jobs from another thread self._loop = asyncio.get_running_loop() - async with background_tasks() as cancel_when_execute_is_done: - cancel_when_execute_is_done(self._publisher()) - cancel_when_execute_is_done(self._process_event_queue()) - cancel_when_execute_is_done(self.driver.poll()) - if min_required_realizations > 0: - cancel_when_execute_is_done( + scheduling_tasks = [ + asyncio.create_task(self._publisher()), + asyncio.create_task(self._process_event_queue()), + asyncio.create_task(self.driver.poll()), + ] + + if min_required_realizations > 0: + scheduling_tasks.append( + asyncio.create_task( self._stop_long_running_jobs(min_required_realizations) ) - cancel_when_execute_is_done(self._update_avg_job_runtime()) - - start = asyncio.Event() - sem = asyncio.BoundedSemaphore(self._max_running or len(self._jobs)) - for iens, job in self._jobs.items(): - self._tasks[iens] = asyncio.create_task( - job(start, sem, self._max_submit) - ) - - start.set() - results = await asyncio.gather( - *self._tasks.values(), return_exceptions=True ) - for result in results: - if isinstance(result, Exception): - logger.error(result) - - await self.driver.finish() + scheduling_tasks.append(asyncio.create_task(self._update_avg_job_runtime())) + + sem = asyncio.BoundedSemaphore(self._max_running or len(self._jobs)) + for iens, job in self._jobs.items(): + self._tasks[iens] = asyncio.create_task(job(sem, self._max_submit)) + + async def gather_realization_jobs() -> list[BaseException | None]: + """This makes sure that all the tasks are completed, where afterwards + we cancel scheduling_tasks.It returns list of task exceptions or None. + """ + try: + return await asyncio.gather( + *self._tasks.values(), return_exceptions=True + ) + finally: + for scheduling_task in scheduling_tasks: + scheduling_task.cancel() + + job_results: Optional[list[BaseException | None]] = None + try: + # there are two types of tasks and each type is handled differently: + # -`gather_realization_jobs`` does not raise; rather, each job task's result + # is collected into the returning list. Exceptions from the job tasks are + # handled in the `else` branch after the evaluation has stopped. + # -If exception occurs, it must necessarily come from `scheduling tasks`` + job_results = ( + await asyncio.gather(gather_realization_jobs(), *scheduling_tasks) + )[0] + except (asyncio.CancelledError, Exception) as e: + for job_task in self._tasks.values(): + job_task.cancel() + # suppress potential error during driver.kill + with suppress(Exception): + await job_task + for scheduling_task in scheduling_tasks: + scheduling_task.cancel() + # Log and re-raise non-cancellation errors. + if not isinstance(e, asyncio.CancelledError): + logger.error(str(e)) + raise e + else: + for result in job_results or []: + if not isinstance(result, asyncio.CancelledError) and isinstance( + result, Exception + ): + logger.error(str(result)) + raise result + + await self.driver.finish() if self._cancelled: logger.debug("scheduler cancelled, stopping jobs...") diff --git a/tests/integration_tests/scheduler/test_openpbs_driver.py b/tests/integration_tests/scheduler/test_openpbs_driver.py index 3759f98d00d..b3cb8c0a9c6 100644 --- a/tests/integration_tests/scheduler/test_openpbs_driver.py +++ b/tests/integration_tests/scheduler/test_openpbs_driver.py @@ -1,6 +1,11 @@ +import asyncio +from functools import partial + import pytest from ert.cli import ENSEMBLE_EXPERIMENT_MODE +from ert.cli.main import ErtCliError +from ert.scheduler.openpbs_driver import OpenPBSDriver from tests.integration_tests.run_cli import run_cli from .conftest import mock_bin @@ -26,3 +31,47 @@ def test_openpbs_driver_with_poly_example(): "--enable-scheduler", "poly.ert", ) + + +async def mock_failure(message, *args, **kwargs): + raise RuntimeError(message) + + +@pytest.mark.timeout(30) +@pytest.mark.integration_test +@pytest.mark.usefixtures("copy_poly_case") +def test_openpbs_driver_with_poly_example_failing_submit_fails_ert_and_propagates_exception_to_user( + monkeypatch, caplog +): + monkeypatch.setattr( + OpenPBSDriver, "submit", partial(mock_failure, "Submit job failed") + ) + with open("poly.ert", mode="a+", encoding="utf-8") as f: + f.write("QUEUE_SYSTEM TORQUE\nNUM_REALIZATIONS 2") + with pytest.raises(ErtCliError): + run_cli( + ENSEMBLE_EXPERIMENT_MODE, + "--enable-scheduler", + "poly.ert", + ) + assert "RuntimeError: Submit job failed" in caplog.text + + +@pytest.mark.timeout(30) +@pytest.mark.integration_test +@pytest.mark.usefixtures("copy_poly_case") +def test_openpbs_driver_with_poly_example_failing_poll_fails_ert_and_propagates_exception_to_user( + monkeypatch, caplog +): + monkeypatch.setattr( + OpenPBSDriver, "poll", partial(mock_failure, "Status polling failed") + ) + with open("poly.ert", mode="a+", encoding="utf-8") as f: + f.write("QUEUE_SYSTEM TORQUE\nNUM_REALIZATIONS 2") + with pytest.raises(ErtCliError): + run_cli( + ENSEMBLE_EXPERIMENT_MODE, + "--enable-scheduler", + "poly.ert", + ) + assert "RuntimeError: Status polling failed" in caplog.text diff --git a/tests/unit_tests/scheduler/test_scheduler.py b/tests/unit_tests/scheduler/test_scheduler.py index 407968e13fd..531054ac9fe 100644 --- a/tests/unit_tests/scheduler/test_scheduler.py +++ b/tests/unit_tests/scheduler/test_scheduler.py @@ -3,6 +3,7 @@ import random import shutil import time +from functools import partial from pathlib import Path from typing import List @@ -468,3 +469,42 @@ async def wait(): for start, next_start in zip(run_start_times[:-1], run_start_times[1:]) ] assert min(deltas) >= submit_sleep * 0.8 + + +async def mock_failure(message, *args, **kwargs): + raise RuntimeError(message) + + +@pytest.mark.timeout(5) +async def test_that_driver_poll_exceptions_are_propagated(mock_driver, realization): + driver = mock_driver() + driver.poll = partial(mock_failure, "Status polling failed") + + sch = scheduler.Scheduler(driver, [realization]) + + with pytest.raises(RuntimeError, match="Status polling failed"): + await sch.execute() + + +@pytest.mark.timeout(5) +async def test_that_publisher_exceptions_are_propagated(mock_driver, realization): + driver = mock_driver() + + sch = scheduler.Scheduler(driver, [realization]) + sch._publisher = partial(mock_failure, "Publisher failed") + + with pytest.raises(RuntimeError, match="Publisher failed"): + await sch.execute() + + +@pytest.mark.timeout(5) +async def test_that_process_event_queue_exceptions_are_propagated( + mock_driver, realization +): + driver = mock_driver() + + sch = scheduler.Scheduler(driver, [realization]) + sch._process_event_queue = partial(mock_failure, "Processing event queue failed") + + with pytest.raises(RuntimeError, match="Processing event queue failed"): + await sch.execute() diff --git a/tests/unit_tests/test_async_utils.py b/tests/unit_tests/test_async_utils.py deleted file mode 100644 index 6c977c3f962..00000000000 --- a/tests/unit_tests/test_async_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -import asyncio - -import pytest - -from ert.async_utils import background_tasks - - -@pytest.mark.timeout(1) -async def test_background_tasks(caplog): - current_task_future = asyncio.Future() - - async def task(): - current_task_future.set_result(asyncio.current_task()) - await asyncio.sleep(100) - - async with background_tasks() as bt: - bt(task()) - current_task = await current_task_future - assert not current_task.done() - - assert current_task.done() - assert caplog.records == [] - - -@pytest.mark.timeout(1) -async def test_background_tasks_with_exception(caplog): - started = asyncio.Event() - - async def task(): - started.set() - raise ValueError("Uh-oh!") - - async with background_tasks() as bt: - bt(task()) - await started.wait() - - assert len(caplog.records) == 1 - assert caplog.records[0].message == "Uh-oh!"