Skip to content

Commit

Permalink
Add background_tasks for canceling tasks
Browse files Browse the repository at this point in the history
This contextmanager makes it simpler to ensure that infinitely-running
tasks get canceled when appropriate, even when an exception occurs.
  • Loading branch information
pinkwah committed Dec 14, 2023
1 parent 07818cf commit 65c62c6
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 23 deletions.
38 changes: 37 additions & 1 deletion src/ert/async_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,50 @@
from __future__ import annotations

import asyncio
import logging
import sys
from contextlib import asynccontextmanager
from traceback import print_exception
from typing import Any, Coroutine, Generator, TypeVar, Union
from typing import (
Any,
AsyncGenerator,
Coroutine,
Generator,
MutableSequence,
TypeVar,
Union,
)

logger = logging.getLogger(__name__)

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)


@asynccontextmanager
async def background_tasks() -> AsyncGenerator[Any, Any]:
"""Context manager for long-living tasks that cancel when exiting the
context
"""

tasks: MutableSequence[asyncio.Task[Any]] = []

def add(coro: Coroutine[Any, Any, Any]) -> None:
tasks.append(asyncio.create_task(coro))

try:
yield add
finally:
for t in tasks:
t.cancel()
for exc in await asyncio.gather(*tasks, return_exceptions=True):
if isinstance(exc, asyncio.CancelledError):
continue
logger.error(str(exc), exc_info=exc)
tasks.clear()


def new_event_loop() -> asyncio.AbstractEventLoop:
loop = asyncio.new_event_loop()
loop.set_task_factory(_create_task)
Expand Down
10 changes: 3 additions & 7 deletions src/ert/scheduler/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ async def kill(self, iens: int) -> None:
iens: Realization number.
"""

def create_poll_task(self) -> Optional[asyncio.Task[None]]:
"""Create a `asyncio.Task` for polling the cluster.
Returns:
`asyncio.Task`, or None if polling is not applicable (eg. for LocalDriver)
"""
return None
@abstractmethod
async def poll(self) -> None:
"""Poll for new job events"""
3 changes: 3 additions & 0 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ async def _wait_until_finish(
except asyncio.CancelledError:
proc.terminate()
await self.event_queue.put((iens, JobEvent.ABORTED))

async def poll(self) -> None:
"""LocalDriver does not poll"""
27 changes: 12 additions & 15 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from websockets import Headers
from websockets.client import connect

from ert.async_utils import background_tasks
from ert.job_queue.queue import EVTYPE_ENSEMBLE_CANCELLED, EVTYPE_ENSEMBLE_STOPPED
from ert.scheduler.driver import Driver, JobEvent
from ert.scheduler.job import Job
Expand Down Expand Up @@ -124,23 +125,19 @@ async def execute(
if queue_evaluators is not None:
logger.warning(f"Ignoring queue_evaluators: {queue_evaluators}")

publisher_task = asyncio.create_task(self._publisher())
poller_task = self.driver.create_poll_task()
event_queue_task = asyncio.create_task(self._process_event_queue())
async with background_tasks() as cancel_when_execute_is_done:
cancel_when_execute_is_done(self._publisher())
cancel_when_execute_is_done(self._process_event_queue())
cancel_when_execute_is_done(self.driver.poll())

start = asyncio.Event()
sem = asyncio.BoundedSemaphore(semaphore._initial_value if semaphore else 10) # type: ignore
for iens, job in self._jobs.items():
self._tasks[iens] = asyncio.create_task(job(start, sem))
start = asyncio.Event()
sem = asyncio.BoundedSemaphore(semaphore._initial_value if semaphore else 10) # type: ignore
for iens, job in self._jobs.items():
self._tasks[iens] = asyncio.create_task(job(start, sem))

start.set()
for task in self._tasks.values():
await task

publisher_task.cancel()
event_queue_task.cancel()
if poller_task:
poller_task.cancel()
start.set()
for task in self._tasks.values():
await task

if self._cancelled:
return EVTYPE_ENSEMBLE_CANCELLED
Expand Down
38 changes: 38 additions & 0 deletions tests/unit_tests/test_async_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import asyncio

import pytest

from ert.async_utils import background_tasks


@pytest.mark.timeout(1)
async def test_background_tasks(caplog):
current_task_future = asyncio.Future()

async def task():
current_task_future.set_result(asyncio.current_task())
await asyncio.sleep(100)

async with background_tasks() as bt:
bt(task())
current_task = await current_task_future
assert not current_task.done()

assert current_task.done()
assert caplog.records == []


@pytest.mark.timeout(1)
async def test_background_tasks_with_exception(caplog):
started = asyncio.Event()

async def task():
started.set()
raise ValueError("Uh-oh!")

async with background_tasks() as bt:
bt(task())
await started.wait()

assert len(caplog.records) == 1
assert caplog.records[0].message == "Uh-oh!"

0 comments on commit 65c62c6

Please sign in to comment.