diff --git a/src/asphalt/core/_concurrent.py b/src/asphalt/core/_concurrent.py index d4172704..0e57a2d5 100644 --- a/src/asphalt/core/_concurrent.py +++ b/src/asphalt/core/_concurrent.py @@ -22,6 +22,7 @@ if sys.version_info >= (3, 10): from typing import TypeAlias else: + from exceptiongroup import ExceptionGroup from typing_extensions import TypeAlias T_Retval = TypeVar("T_Retval") @@ -179,10 +180,13 @@ async def wait_all_tasks_finished(self) -> None: async def _run( self, ctx: Context, resource_name: str, *, task_status: TaskStatus[None] ) -> None: - async with create_task_group() as self._task_group: - ctx.add_resource(self, resource_name) - task_status.started() - await self._start() + try: + async with create_task_group() as self._task_group: + ctx.add_resource(self, resource_name) + task_status.started() + await self._start() + except ExceptionGroup as excgrp: + raise excgrp.exceptions[0] async def _start( self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED diff --git a/tests/test_concurrent.py b/tests/test_concurrent.py index d24902db..b9c5a37d 100644 --- a/tests/test_concurrent.py +++ b/tests/test_concurrent.py @@ -108,13 +108,9 @@ async def taskfunc() -> NoReturn: assert len(excinfo.value.exceptions) == 1 assert isinstance(excinfo.value.exceptions[0], ExceptionGroup) - excgrp0 = excinfo.value.exceptions[0] - assert len(excgrp0.exceptions) == 1 - assert isinstance(excgrp0, ExceptionGroup) - excgrp1 = excgrp0.exceptions[0] - assert isinstance(excgrp1, ExceptionGroup) - assert len(excgrp1.exceptions) == 1 - assert str(excgrp1.exceptions[0]) == "foo" + excgrp = excinfo.value.exceptions[0] + assert len(excgrp.exceptions) == 1 + assert str(excgrp.exceptions[0]) == "foo" async def test_start_exception_handled(self) -> None: handled_exception: Exception | None = None