Skip to content

Commit

Permalink
Unify returncode values for different drivers when process is killed by
Browse files Browse the repository at this point in the history
signal

This should fix a bug in azure bleeding.
  • Loading branch information
JHolba committed Mar 21, 2024
1 parent 64fe249 commit 795982a
Show file tree
Hide file tree
Showing 12 changed files with 48 additions and 48 deletions.
3 changes: 3 additions & 0 deletions src/ert/scheduler/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

from ert.scheduler.event import Event

SIGNAL_OFFSET = 128
"""Bash and other shells add an offset of 128 to the signal value when a process exited due to a signal"""


class Driver(ABC):
"""Adapter for the HPC cluster."""
Expand Down
1 change: 0 additions & 1 deletion src/ert/scheduler/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ class StartedEvent:
class FinishedEvent:
iens: int
returncode: int
aborted: bool = False


Event = Union[StartedEvent, FinishedEvent]
13 changes: 8 additions & 5 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path
from typing import MutableMapping, Optional

from ert.scheduler.driver import Driver
from ert.scheduler.driver import SIGNAL_OFFSET, Driver
from ert.scheduler.event import FinishedEvent, StartedEvent

_TERMINATE_TIMEOUT = 10.0
Expand Down Expand Up @@ -59,9 +59,7 @@ async def _run(self, iens: int, executable: str, /, *args: str) -> None:
await self.event_queue.put(FinishedEvent(iens=iens, returncode=returncode))
except asyncio.CancelledError:
returncode = await self._kill(proc)
await self.event_queue.put(
FinishedEvent(iens=iens, returncode=returncode, aborted=True)
)
await self.event_queue.put(FinishedEvent(iens=iens, returncode=returncode))

async def _init(self, iens: int, executable: str, /, *args: str) -> Process:
"""This method exists to allow for mocking it in tests"""
Expand All @@ -82,7 +80,12 @@ async def _kill(self, proc: Process) -> int:
await asyncio.wait_for(proc.wait(), _TERMINATE_TIMEOUT)
except asyncio.TimeoutError:
proc.kill()
return await proc.wait()
ret_val = await proc.wait()
# the returncode of a subprocess will be the negative signal value
# if it terminated due to a signal.
# https://docs.python.org/3/library/subprocess.html#subprocess.CompletedProcess.returncode
# we return SIGNAL_OFFSET + signal value to be in line with lfs/pbs drivers.
return -ret_val + SIGNAL_OFFSET

async def poll(self) -> None:
"""LocalDriver does not poll"""
10 changes: 4 additions & 6 deletions src/ert/scheduler/lsf_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from ert.scheduler.driver import Driver
from ert.scheduler.driver import SIGNAL_OFFSET, Driver
from ert.scheduler.event import Event, FinishedEvent, StartedEvent

_POLL_PERIOD = 2.0 # seconds
LSF_FAILED_JOB = SIGNAL_OFFSET + 65 # first non signal returncode
"""Return code we use when lsf reports failed jobs"""

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -267,11 +269,7 @@ async def _process_job_update(self, job_id: str, new_state: AnyJob) -> None:
logger.debug(
f"Realization {iens} (LSF-id: {self._iens2jobid[iens]}) failed"
)
event = FinishedEvent(
iens=iens,
returncode=1,
aborted=True,
)
event = FinishedEvent(iens=iens, returncode=LSF_FAILED_JOB)

elif isinstance(new_state, FinishedJobSuccess):
logger.debug(
Expand Down
4 changes: 1 addition & 3 deletions src/ert/scheduler/openpbs_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,14 +310,12 @@ async def _process_job_update(self, job_id: str, new_state: AnyJob) -> None:
event = StartedEvent(iens=iens)
elif isinstance(new_state, FinishedJob):
assert new_state.returncode is not None
aborted = new_state.returncode >= 256
event = FinishedEvent(
iens=iens,
returncode=new_state.returncode,
aborted=aborted,
)

if aborted:
if new_state.returncode != 0:
logger.debug(
f"Realization {iens} (PBS-id: {self._iens2jobid[iens]}) failed"
)
Expand Down
4 changes: 2 additions & 2 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from ert.constant_filenames import CERT_FILE
from ert.job_queue.queue import EVTYPE_ENSEMBLE_CANCELLED, EVTYPE_ENSEMBLE_STOPPED
from ert.scheduler.driver import Driver
from ert.scheduler.driver import SIGNAL_OFFSET, Driver
from ert.scheduler.event import FinishedEvent
from ert.scheduler.job import Job
from ert.scheduler.job import State as JobState
Expand Down Expand Up @@ -282,7 +282,7 @@ async def _process_event_queue(self) -> None:
job.started.set()

if isinstance(event, FinishedEvent):
if event.aborted:
if event.returncode >= SIGNAL_OFFSET:
job.returncode.cancel()
else:
job.returncode.set_result(event.returncode)
Expand Down
24 changes: 11 additions & 13 deletions tests/integration_tests/scheduler/test_generic_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import pytest

from ert.scheduler.driver import SIGNAL_OFFSET
from ert.scheduler.local_driver import LocalDriver
from ert.scheduler.lsf_driver import LsfDriver
from ert.scheduler.lsf_driver import LSF_FAILED_JOB, LsfDriver
from ert.scheduler.openpbs_driver import OpenPBSDriver
from tests.utils import poll

Expand Down Expand Up @@ -60,14 +61,14 @@ async def test_submit_something_that_fails(driver, tmp_path):

expected_returncode = 42
if isinstance(driver, LsfDriver):
expected_returncode = 1
expected_returncode = LSF_FAILED_JOB

async def finished(iens, returncode, aborted):
async def finished(iens, returncode):
assert iens == 0
assert returncode == expected_returncode

if isinstance(driver, LsfDriver):
assert aborted is True
assert returncode != 0

nonlocal finished_called
finished_called = True
Expand All @@ -81,22 +82,19 @@ async def finished(iens, returncode, aborted):
async def test_kill(driver, tmp_path):
os.chdir(tmp_path)
aborted_called = False

expected_returncodes = [1]
if isinstance(driver, OpenPBSDriver):
expected_returncodes = [128 + signal.SIGTERM, 256 + signal.SIGTERM]

if isinstance(driver, LocalDriver):
expected_returncodes = [-signal.SIGTERM]
expected_returncodes = [
LSF_FAILED_JOB,
SIGNAL_OFFSET + signal.SIGTERM,
256 + signal.SIGTERM,
]

async def started(iens):
nonlocal driver
await driver.kill(iens)

async def finished(iens, returncode, aborted):
async def finished(iens, returncode):
assert iens == 0
assert returncode in expected_returncodes
assert aborted is True

nonlocal aborted_called
aborted_called = True
Expand Down
10 changes: 5 additions & 5 deletions tests/integration_tests/scheduler/test_lsf_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from ert.scheduler import LsfDriver
from ert.scheduler.lsf_driver import LSF_FAILED_JOB
from tests.utils import poll

from .conftest import mock_bin
Expand Down Expand Up @@ -76,9 +77,9 @@ async def test_submit_to_named_queue(tmp_path, caplog):
"actual_returncode, returncode_that_ert_sees",
[
([0, 0]),
([1, 1]),
([2, 1]),
([255, 1]),
([1, LSF_FAILED_JOB]),
([2, LSF_FAILED_JOB]),
([255, LSF_FAILED_JOB]),
([256, 0]), # return codes are 8 bit.
],
)
Expand All @@ -93,10 +94,9 @@ async def test_lsf_driver_masks_returncode(
os.chdir(tmp_path)
driver = LsfDriver()

async def finished(iens, returncode, aborted):
async def finished(iens, returncode):
assert iens == 0
assert returncode == returncode_that_ert_sees
assert aborted == (returncode_that_ert_sees != 0)

await driver.submit(0, "sh", "-c", f"exit {actual_returncode}")
await poll(driver, {0}, finished=finished)
4 changes: 3 additions & 1 deletion tests/unit_tests/scheduler/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import signal
from typing import Any, Coroutine, Literal

import pytest

from ert.scheduler.driver import SIGNAL_OFFSET
from ert.scheduler.local_driver import LocalDriver


Expand Down Expand Up @@ -43,7 +45,7 @@ async def _kill(self, iens):
await self._mock_kill(iens)
else:
await self._mock_kill()
return -15
return signal.SIGTERM + SIGNAL_OFFSET


@pytest.fixture
Expand Down
7 changes: 4 additions & 3 deletions tests/unit_tests/scheduler/test_local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from ert.scheduler import local_driver
from ert.scheduler.driver import SIGNAL_OFFSET
from ert.scheduler.event import FinishedEvent, StartedEvent
from ert.scheduler.local_driver import LocalDriver

Expand Down Expand Up @@ -43,7 +44,7 @@ async def test_kill():
assert await driver.event_queue.get() == StartedEvent(iens=42)
await driver.kill(42)
assert await driver.event_queue.get() == FinishedEvent(
iens=42, returncode=-signal.SIGTERM, aborted=True
iens=42, returncode=signal.SIGTERM + SIGNAL_OFFSET
)


Expand Down Expand Up @@ -73,7 +74,7 @@ async def test_kill_unresponsive_process(monkeypatch, tmp_path):

await driver.kill(42)
assert await driver.event_queue.get() == FinishedEvent(
iens=42, returncode=-signal.SIGKILL, aborted=True
iens=42, returncode=signal.SIGKILL + SIGNAL_OFFSET
)


Expand All @@ -96,7 +97,7 @@ async def test_that_killing_killed_job_does_not_raise():
assert await driver.event_queue.get() == StartedEvent(iens=23)
await driver.kill(23)
assert await driver.event_queue.get() == FinishedEvent(
iens=23, returncode=-signal.SIGTERM, aborted=True
iens=23, returncode=signal.SIGTERM + SIGNAL_OFFSET
)
# Killing a dead job should not raise an exception
await driver.kill(23)
Expand Down
7 changes: 3 additions & 4 deletions tests/unit_tests/scheduler/test_lsf_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ert.scheduler import LsfDriver
from ert.scheduler.lsf_driver import (
BSUB_FLAKY_SSH,
LSF_FAILED_JOB,
FinishedEvent,
FinishedJobFailure,
FinishedJobSuccess,
Expand Down Expand Up @@ -104,17 +105,15 @@ async def mocked_submit(self, iens, *_args, **_kwargs):
assert isinstance(state, RunningJob)
elif started and finished_success and finished_failure:
assert len(events) <= 2 # The StartedEvent is not required
assert events[-1] == FinishedEvent(
iens=0, returncode=events[-1].returncode, aborted=events[-1].aborted
)
assert events[-1] == FinishedEvent(iens=0, returncode=events[-1].returncode)
assert "1" not in driver._jobs
elif started is True and finished_success and not finished_failure:
assert len(events) <= 2 # The StartedEvent is not required
assert events[-1] == FinishedEvent(iens=0, returncode=0)
assert "1" not in driver._jobs
elif started is True and not finished_success and finished_failure:
assert len(events) <= 2 # The StartedEvent is not required
assert events[-1] == FinishedEvent(iens=0, returncode=1, aborted=True)
assert events[-1] == FinishedEvent(iens=0, returncode=LSF_FAILED_JOB)
assert "1" not in driver._jobs


Expand Down
9 changes: 4 additions & 5 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,10 @@ async def poll(driver: Driver, expected: set[int], *, started=None, finished=Non
started : Callable[[int], None]
Called for each job when it starts. Its associated realisation index is
passed.
finished : Callable[[int, int, bool], None]
finished : Callable[[int, int], None]
Called for each job when it finishes. The first argument is the
associated realisation index, the second is the returncode of the job
process and the third argument is whether the job was explicitly
aborted.
associated realisation index and the second is the returncode of the job
process.
"""
from ert.scheduler.event import FinishedEvent, StartedEvent
Expand All @@ -132,7 +131,7 @@ async def poll(driver: Driver, expected: set[int], *, started=None, finished=Non
await started(event.iens)
elif isinstance(event, FinishedEvent):
if finished is not None:
await finished(event.iens, event.returncode, event.aborted)
await finished(event.iens, event.returncode)
completed.add(event.iens)
if completed == expected:
break
Expand Down

0 comments on commit 795982a

Please sign in to comment.