diff --git a/src/_ert/threading.py b/src/_ert/threading.py index 3893deb7f5c..b5c0b98c90e 100644 --- a/src/_ert/threading.py +++ b/src/_ert/threading.py @@ -6,11 +6,15 @@ import threading from threading import Thread as _Thread from types import FrameType -from typing import Optional +from typing import Any, Callable, Iterable, Optional logger = logging.getLogger(__name__) +_current_exception: Optional[ErtThreadError] = None +_can_raise = False + + class ErtThreadError(Exception): def __init__(self, exception: BaseException, thread: _Thread) -> None: super().__init__(repr(exception)) @@ -26,29 +30,28 @@ def __str__(self) -> str: class ErtThread(_Thread): + def __init__( + self, + target: Callable[..., Any], + name: str | None = None, + args: Iterable[Any] = (), + *, + daemon: bool | None = None, + should_raise: bool = True, + ) -> None: + super().__init__(target=target, name=name, args=args, daemon=daemon) + self._should_raise = should_raise + def run(self) -> None: try: super().run() except BaseException as exc: logger.error(str(exc), exc_info=exc) - - # Re-raising this exception on main thread can have unknown - # repercussions in production. Potentially, an unnecessary thread - # was dying due to an exception and we didn't care, but with this - # change this would bring down all of Ert. We take the conservative - # approach and make re-raising optional, and enable it only for - # the test suite. - if os.environ.get("_ERT_THREAD_RAISE", ""): + if _can_raise and self._should_raise: _raise_on_main_thread(exc) -_current_exception: Optional[BaseException] = None - - def _raise_on_main_thread(exception: BaseException) -> None: - if threading.main_thread() is threading.current_thread(): - raise exception - global _current_exception # noqa: PLW0603 _current_exception = ErtThreadError(exception, threading.current_thread()) @@ -68,4 +71,10 @@ def _handler(signum: int, frametype: FrameType | None) -> None: raise current_exception -signal.signal(signal.SIGUSR1, _handler) +def set_signal_handler() -> None: + global _can_raise # noqa: PLW0603 + if _can_raise: + return + + signal.signal(signal.SIGUSR1, _handler) + _can_raise = True diff --git a/src/ert/__main__.py b/src/ert/__main__.py index e0443e25ded..927bb296d35 100755 --- a/src/ert/__main__.py +++ b/src/ert/__main__.py @@ -14,6 +14,7 @@ from resdata import set_abort_handler import ert.shared +from _ert.threading import set_signal_handler from ert.cli import ( ENSEMBLE_EXPERIMENT_MODE, ENSEMBLE_SMOOTHER_MODE, @@ -537,6 +538,9 @@ def main() -> None: warnings.filterwarnings("ignore", category=DeprecationWarning) locale.setlocale(locale.LC_NUMERIC, "C") + # Have ErtThread re-raise uncaught exceptions on main thread + set_signal_handler() + args = ert_parser(None, sys.argv[1:]) log_dir = os.path.abspath(args.logdir) diff --git a/src/ert/gui/tools/plugins/plugin_runner.py b/src/ert/gui/tools/plugins/plugin_runner.py index 078ea76b33c..b5ae705ea1c 100644 --- a/src/ert/gui/tools/plugins/plugin_runner.py +++ b/src/ert/gui/tools/plugins/plugin_runner.py @@ -36,16 +36,22 @@ def run(self): run_function = partial(self.__runWorkflowJob, plugin, arguments) - workflow_job_thread = ErtThread(name="ert_gui_workflow_job_thread") - workflow_job_thread.daemon = True - workflow_job_thread.run = run_function + workflow_job_thread = ErtThread( + name="ert_gui_workflow_job_thread", + target=run_function, + daemon=True, + should_raise=False, + ) workflow_job_thread.start() poll_function = partial(self.__pollRunner, dialog) - self.poll_thread = ErtThread(name="ert_gui_workflow_job_poll_thread") - self.poll_thread.daemon = True - self.poll_thread.run = poll_function + self.poll_thread = ErtThread( + name="ert_gui_workflow_job_poll_thread", + target=poll_function, + daemon=True, + should_raise=False, + ) self.poll_thread.start() dialog.show() diff --git a/src/ert/gui/tools/workflows/run_workflow_widget.py b/src/ert/gui/tools/workflows/run_workflow_widget.py index aedacc5448b..76bb588ed1d 100644 --- a/src/ert/gui/tools/workflows/run_workflow_widget.py +++ b/src/ert/gui/tools/workflows/run_workflow_widget.py @@ -115,9 +115,12 @@ def startWorkflow(self): ) self._running_workflow_dialog.closeButtonPressed.connect(self.cancelWorkflow) - workflow_thread = ErtThread(name="ert_gui_workflow_thread") - workflow_thread.daemon = True - workflow_thread.run = self.runWorkflow + workflow_thread = ErtThread( + name="ert_gui_workflow_thread", + target=self.runWorkflow, + daemon=True, + should_raise=False, + ) workflow = self.ert.ert_config.workflows[self.getCurrentWorkflowName()] self._workflow_runner = WorkflowRunner( diff --git a/tests/conftest.py b/tests/conftest.py index ac5223874f6..e4f83281242 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,8 @@ from typing import TYPE_CHECKING, cast from unittest.mock import MagicMock +from _ert.threading import set_signal_handler + if sys.version_info >= (3, 9): from importlib.resources import files else: @@ -50,7 +52,7 @@ def log_check(): @pytest.fixture(scope="session", autouse=True) def _reraise_thread_exceptions_on_main_thread(): """Allow `ert.shared.threading.ErtThread` to re-raise exceptions on main thread""" - os.environ["_ERT_THREAD_RAISE"] = "1" + set_signal_handler() @pytest.fixture