Skip to content

Commit

Permalink
Make ErtThread signal and re-raising opt-in
Browse files Browse the repository at this point in the history
This commit does the following:
1. Adds a function `set_signal_handler` that must be called before
   re-raising works
2. Adds a kwarg `should_raise` that must be True for re-raising to
   work
3. Workflow jobs don't reraise, only log
  • Loading branch information
pinkwah committed Mar 6, 2024
1 parent 06da87b commit 55210ff
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 26 deletions.
41 changes: 25 additions & 16 deletions src/_ert/threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
import threading
from threading import Thread as _Thread
from types import FrameType
from typing import Optional
from typing import Any, Callable, Iterable, Optional

logger = logging.getLogger(__name__)


_current_exception: Optional[ErtThreadError] = None
_can_raise = False


class ErtThreadError(Exception):
def __init__(self, exception: BaseException, thread: _Thread) -> None:
super().__init__(repr(exception))
Expand All @@ -26,29 +30,28 @@ def __str__(self) -> str:


class ErtThread(_Thread):
def __init__(
self,
target: Callable[..., Any],
name: str | None = None,
args: Iterable[Any] = (),
*,
daemon: bool | None = None,
should_raise: bool = True,
) -> None:
super().__init__(target=target, name=name, args=args, daemon=daemon)
self._should_raise = should_raise

def run(self) -> None:
try:
super().run()
except BaseException as exc:
logger.error(str(exc), exc_info=exc)

# Re-raising this exception on main thread can have unknown
# repercussions in production. Potentially, an unnecessary thread
# was dying due to an exception and we didn't care, but with this
# change this would bring down all of Ert. We take the conservative
# approach and make re-raising optional, and enable it only for
# the test suite.
if os.environ.get("_ERT_THREAD_RAISE", ""):
if _can_raise and self._should_raise:
_raise_on_main_thread(exc)


_current_exception: Optional[BaseException] = None


def _raise_on_main_thread(exception: BaseException) -> None:
if threading.main_thread() is threading.current_thread():
raise exception

global _current_exception # noqa: PLW0603
_current_exception = ErtThreadError(exception, threading.current_thread())

Expand All @@ -68,4 +71,10 @@ def _handler(signum: int, frametype: FrameType | None) -> None:
raise current_exception


signal.signal(signal.SIGUSR1, _handler)
def set_signal_handler() -> None:
global _can_raise # noqa: PLW0603
if _can_raise:
return

signal.signal(signal.SIGUSR1, _handler)
_can_raise = True
4 changes: 4 additions & 0 deletions src/ert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from resdata import set_abort_handler

import ert.shared
from _ert.threading import set_signal_handler
from ert.cli import (
ENSEMBLE_EXPERIMENT_MODE,
ENSEMBLE_SMOOTHER_MODE,
Expand Down Expand Up @@ -537,6 +538,9 @@ def main() -> None:
warnings.filterwarnings("ignore", category=DeprecationWarning)
locale.setlocale(locale.LC_NUMERIC, "C")

# Have ErtThread re-raise uncaught exceptions on main thread
set_signal_handler()

args = ert_parser(None, sys.argv[1:])

log_dir = os.path.abspath(args.logdir)
Expand Down
18 changes: 12 additions & 6 deletions src/ert/gui/tools/plugins/plugin_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,22 @@ def run(self):

run_function = partial(self.__runWorkflowJob, plugin, arguments)

workflow_job_thread = ErtThread(name="ert_gui_workflow_job_thread")
workflow_job_thread.daemon = True
workflow_job_thread.run = run_function
workflow_job_thread = ErtThread(
name="ert_gui_workflow_job_thread",
target=run_function,
daemon=True,
should_raise=False,
)
workflow_job_thread.start()

poll_function = partial(self.__pollRunner, dialog)

self.poll_thread = ErtThread(name="ert_gui_workflow_job_poll_thread")
self.poll_thread.daemon = True
self.poll_thread.run = poll_function
self.poll_thread = ErtThread(
name="ert_gui_workflow_job_poll_thread",
target=poll_function,
daemon=True,
should_raise=False,
)
self.poll_thread.start()

dialog.show()
Expand Down
9 changes: 6 additions & 3 deletions src/ert/gui/tools/workflows/run_workflow_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,12 @@ def startWorkflow(self):
)
self._running_workflow_dialog.closeButtonPressed.connect(self.cancelWorkflow)

workflow_thread = ErtThread(name="ert_gui_workflow_thread")
workflow_thread.daemon = True
workflow_thread.run = self.runWorkflow
workflow_thread = ErtThread(
name="ert_gui_workflow_thread",
target=self.runWorkflow,
daemon=True,
should_raise=False,
)

workflow = self.ert.ert_config.workflows[self.getCurrentWorkflowName()]
self._workflow_runner = WorkflowRunner(
Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import TYPE_CHECKING, cast
from unittest.mock import MagicMock

from _ert.threading import set_signal_handler

if sys.version_info >= (3, 9):
from importlib.resources import files
else:
Expand Down Expand Up @@ -50,7 +52,7 @@ def log_check():
@pytest.fixture(scope="session", autouse=True)
def _reraise_thread_exceptions_on_main_thread():
"""Allow `ert.shared.threading.ErtThread` to re-raise exceptions on main thread"""
os.environ["_ERT_THREAD_RAISE"] = "1"
set_signal_handler()


@pytest.fixture
Expand Down

0 comments on commit 55210ff

Please sign in to comment.