Skip to content

Commit

Permalink
Treat exceptions in long-lived tasks by cancelling the jobs correctly
Browse files Browse the repository at this point in the history
We collectively gather the results of the realization tasks and the scheduling_tasks
(long-lived scheduling tasks). The main point is that exceptions from realization tasks are treated differently from
exceptions from scheduling tasks: an exception in a scheduling task requires immediate handling.

This includes tests for the cases where the OpenPBS driver hangs or fails, as well as
tests for scheduler-related exceptions.

Additionally, this commit removes:
1) async_utils.background_tasks and test_async_utils.py
2) the `start` asyncio.Event, as this feature is not really required

Co-authored-by: Jonathan Karlsen <[email protected]>
  • Loading branch information
xjules and jonathan-eq committed Mar 9, 2024
1 parent 2f02032 commit 4a2bf1b
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 102 deletions.
36 changes: 1 addition & 35 deletions src/ert/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,14 @@

import asyncio
import logging
from contextlib import asynccontextmanager
from typing import (
Any,
AsyncGenerator,
Coroutine,
Generator,
MutableSequence,
TypeVar,
Union,
)
from typing import Any, Coroutine, Generator, TypeVar, Union

logger = logging.getLogger(__name__)

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)


@asynccontextmanager
async def background_tasks() -> AsyncGenerator[Any, Any]:
    """Run long-living coroutines as tasks that are cancelled on context exit.

    The context yields a callable that takes a coroutine and schedules it as
    an ``asyncio`` task. When the context is left, every task scheduled this
    way is cancelled and awaited. Exceptions other than cancellation that the
    tasks produced are logged as errors, not re-raised.
    """

    tasks: MutableSequence[asyncio.Task[Any]] = []

    def add(coro: Coroutine[Any, Any, Any]) -> None:
        tasks.append(asyncio.create_task(coro))

    try:
        yield add
    finally:
        for task in tasks:
            task.cancel()
        # return_exceptions=True hands back cancellations and failures as
        # values, so one failing task cannot hide the outcome of the others.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        for outcome in outcomes:
            is_failure = isinstance(outcome, BaseException)
            if is_failure and not isinstance(outcome, asyncio.CancelledError):
                logger.error(str(outcome), exc_info=outcome)
        tasks.clear()


def new_event_loop() -> asyncio.AbstractEventLoop:
loop = asyncio.new_event_loop()
loop.set_task_factory(_create_task)
Expand Down
7 changes: 2 additions & 5 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
import time
import uuid
from contextlib import suppress
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional
Expand Down Expand Up @@ -117,19 +116,17 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
except asyncio.CancelledError:
await self._send(State.ABORTING)
await self.driver.kill(self.iens)
with suppress(asyncio.CancelledError):
await self.returncode
self.returncode.cancel()
await self._send(State.ABORTED)
finally:
if timeout_task and not timeout_task.done():
timeout_task.cancel()
sem.release()

async def __call__(
self, start: asyncio.Event, sem: asyncio.BoundedSemaphore, max_submit: int = 2
self, sem: asyncio.BoundedSemaphore, max_submit: int = 2
) -> None:
self._requested_max_submit = max_submit
await start.wait()

for attempt in range(max_submit):
await self._submit_and_run_once(sem)
Expand Down
83 changes: 59 additions & 24 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ssl
import time
from collections import defaultdict
from contextlib import suppress
from dataclasses import asdict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, MutableMapping, Optional, Sequence
Expand All @@ -15,7 +16,6 @@
from websockets import Headers
from websockets.client import connect

from ert.async_utils import background_tasks
from ert.constant_filenames import CERT_FILE
from ert.job_queue.queue import EVTYPE_ENSEMBLE_CANCELLED, EVTYPE_ENSEMBLE_STOPPED
from ert.scheduler.driver import Driver
Expand Down Expand Up @@ -181,32 +181,67 @@ async def execute(
# We need to store the loop due to when calling
# cancel jobs from another thread
self._loop = asyncio.get_running_loop()
async with background_tasks() as cancel_when_execute_is_done:
cancel_when_execute_is_done(self._publisher())
cancel_when_execute_is_done(self._process_event_queue())
cancel_when_execute_is_done(self.driver.poll())
if min_required_realizations > 0:
cancel_when_execute_is_done(
scheduling_tasks = [
asyncio.create_task(self._publisher()),
asyncio.create_task(self._process_event_queue()),
asyncio.create_task(self.driver.poll()),
]

if min_required_realizations > 0:
scheduling_tasks.append(
asyncio.create_task(
self._stop_long_running_jobs(min_required_realizations)
)
cancel_when_execute_is_done(self._update_avg_job_runtime())

start = asyncio.Event()
sem = asyncio.BoundedSemaphore(self._max_running or len(self._jobs))
for iens, job in self._jobs.items():
self._tasks[iens] = asyncio.create_task(
job(start, sem, self._max_submit)
)

start.set()
results = await asyncio.gather(
*self._tasks.values(), return_exceptions=True
)
for result in results:
if isinstance(result, Exception):
logger.error(result)

await self.driver.finish()
scheduling_tasks.append(asyncio.create_task(self._update_avg_job_runtime()))

sem = asyncio.BoundedSemaphore(self._max_running or len(self._jobs))
for iens, job in self._jobs.items():
self._tasks[iens] = asyncio.create_task(job(sem, self._max_submit))

async def gather_realization_jobs() -> list[BaseException | None]:
    """Wait for all realization tasks, then cancel the scheduling tasks.

    The realization tasks are gathered with ``return_exceptions=True``, so an
    exception raised by one of them is returned as an entry of the resulting
    list instead of being raised. The scheduling tasks are cancelled in a
    ``finally`` clause, i.e. also when the wait itself is interrupted.
    """
    try:
        outcomes = await asyncio.gather(
            *self._tasks.values(), return_exceptions=True
        )
        return outcomes
    finally:
        for task in scheduling_tasks:
            task.cancel()

job_results: Optional[list[BaseException | None]] = None
try:
# There are two types of tasks, and each type is handled differently:
# - `gather_realization_jobs` does not raise; rather, each job task's result
#   is collected into the returned list. Exceptions from the job tasks are
#   handled in the `else` branch after the evaluation has stopped.
# - If an exception occurs here, it must therefore come from `scheduling_tasks`.
job_results = (
await asyncio.gather(gather_realization_jobs(), *scheduling_tasks)
)[0]
except (asyncio.CancelledError, Exception) as e:
for job_task in self._tasks.values():
job_task.cancel()
# suppress potential error during driver.kill
with suppress(Exception):
await job_task
for scheduling_task in scheduling_tasks:
scheduling_task.cancel()
# Log and re-raise non-cancellation errors.
if not isinstance(e, asyncio.CancelledError):
logger.error(str(e))
raise e
else:
for result in job_results or []:
if not isinstance(result, asyncio.CancelledError) and isinstance(
result, Exception
):
logger.error(str(result))
raise result

await self.driver.finish()

if self._cancelled:
logger.debug("scheduler cancelled, stopping jobs...")
Expand Down
49 changes: 49 additions & 0 deletions tests/integration_tests/scheduler/test_openpbs_driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import asyncio
from functools import partial

import pytest

from ert.cli import ENSEMBLE_EXPERIMENT_MODE
from ert.cli.main import ErtCliError
from ert.scheduler.openpbs_driver import OpenPBSDriver
from tests.integration_tests.run_cli import run_cli

from .conftest import mock_bin
Expand All @@ -26,3 +31,47 @@ def test_openpbs_driver_with_poly_example():
"--enable-scheduler",
"poly.ert",
)


async def mock_failure(message, *args, **kwargs):
    """Fail unconditionally with a ``RuntimeError`` carrying *message*.

    Any additional positional or keyword arguments are accepted and ignored,
    so the coroutine can stand in for callables of any signature.
    """
    failure = RuntimeError(message)
    raise failure


@pytest.mark.timeout(30)
@pytest.mark.integration_test
@pytest.mark.usefixtures("copy_poly_case")
def test_openpbs_driver_with_poly_example_failing_submit_fails_ert_and_propagates_exception_to_user(
    monkeypatch, caplog
):
    """A failing ``OpenPBSDriver.submit`` must abort the CLI run.

    ``submit`` is patched to raise ``RuntimeError``. Running the ensemble
    experiment is then expected to raise ``ErtCliError``, and the original
    error message must appear in the captured log.
    """
    failing_submit = partial(mock_failure, "Submit job failed")
    monkeypatch.setattr(OpenPBSDriver, "submit", failing_submit)

    # Append the queue settings to the configuration before running it.
    with open("poly.ert", mode="a+", encoding="utf-8") as config_file:
        config_file.write("QUEUE_SYSTEM TORQUE\nNUM_REALIZATIONS 2")

    cli_arguments = (ENSEMBLE_EXPERIMENT_MODE, "--enable-scheduler", "poly.ert")
    with pytest.raises(ErtCliError):
        run_cli(*cli_arguments)

    assert "RuntimeError: Submit job failed" in caplog.text


@pytest.mark.timeout(30)
@pytest.mark.integration_test
@pytest.mark.usefixtures("copy_poly_case")
def test_openpbs_driver_with_poly_example_failing_poll_fails_ert_and_propagates_exception_to_user(
    monkeypatch, caplog
):
    """A failing ``OpenPBSDriver.poll`` must abort the CLI run.

    ``poll`` is patched to raise ``RuntimeError``. Running the ensemble
    experiment is then expected to raise ``ErtCliError``, and the original
    error message must appear in the captured log.
    """
    failing_poll = partial(mock_failure, "Status polling failed")
    monkeypatch.setattr(OpenPBSDriver, "poll", failing_poll)

    # Append the queue settings to the configuration before running it.
    with open("poly.ert", mode="a+", encoding="utf-8") as config_file:
        config_file.write("QUEUE_SYSTEM TORQUE\nNUM_REALIZATIONS 2")

    cli_arguments = (ENSEMBLE_EXPERIMENT_MODE, "--enable-scheduler", "poly.ert")
    with pytest.raises(ErtCliError):
        run_cli(*cli_arguments)

    assert "RuntimeError: Status polling failed" in caplog.text
40 changes: 40 additions & 0 deletions tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import random
import shutil
import time
from functools import partial
from pathlib import Path
from typing import List

Expand Down Expand Up @@ -468,3 +469,42 @@ async def wait():
for start, next_start in zip(run_start_times[:-1], run_start_times[1:])
]
assert min(deltas) >= submit_sleep * 0.8


async def mock_failure(message, *args, **kwargs):
    """Fail unconditionally with a ``RuntimeError`` carrying *message*.

    Any additional positional or keyword arguments are accepted and ignored,
    so the coroutine can stand in for callables of any signature.
    """
    failure = RuntimeError(message)
    raise failure


@pytest.mark.timeout(5)
async def test_that_driver_poll_exceptions_are_propagated(mock_driver, realization):
    """An error raised by ``driver.poll`` must surface from ``execute``."""
    expected_message = "Status polling failed"

    failing_driver = mock_driver()
    failing_driver.poll = partial(mock_failure, expected_message)
    sched = scheduler.Scheduler(failing_driver, [realization])

    with pytest.raises(RuntimeError, match=expected_message):
        await sched.execute()


@pytest.mark.timeout(5)
async def test_that_publisher_exceptions_are_propagated(mock_driver, realization):
    """An error raised by the scheduler's ``_publisher`` must surface from ``execute``."""
    expected_message = "Publisher failed"

    sched = scheduler.Scheduler(mock_driver(), [realization])
    # Replace the publisher coroutine function with one that always fails.
    sched._publisher = partial(mock_failure, expected_message)

    with pytest.raises(RuntimeError, match=expected_message):
        await sched.execute()


@pytest.mark.timeout(5)
async def test_that_process_event_queue_exceptions_are_propagated(
    mock_driver, realization
):
    """An error raised by ``_process_event_queue`` must surface from ``execute``."""
    expected_message = "Processing event queue failed"

    sched = scheduler.Scheduler(mock_driver(), [realization])
    # Replace the event-queue coroutine function with one that always fails.
    sched._process_event_queue = partial(mock_failure, expected_message)

    with pytest.raises(RuntimeError, match=expected_message):
        await sched.execute()
38 changes: 0 additions & 38 deletions tests/unit_tests/test_async_utils.py

This file was deleted.

0 comments on commit 4a2bf1b

Please sign in to comment.