Skip to content

Commit

Permalink
Raise exception from sub-task group
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Apr 19, 2024
1 parent b2d16d3 commit b02717f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
14 changes: 10 additions & 4 deletions src/asphalt/core/_concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
else:
from typing_extensions import TypeAlias

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup

T_Retval = TypeVar("T_Retval")
TeardownAction: TypeAlias = Union[Callable[[], Any], Literal["cancel"], None]
ExceptionHandler: TypeAlias = Callable[[Exception], bool]
Expand Down Expand Up @@ -179,10 +182,13 @@ async def wait_all_tasks_finished(self) -> None:
async def _run(
self, ctx: Context, resource_name: str, *, task_status: TaskStatus[None]
) -> None:
async with create_task_group() as self._task_group:
ctx.add_resource(self, resource_name)
task_status.started()
await self._start()
try:
async with create_task_group() as self._task_group:
ctx.add_resource(self, resource_name)
task_status.started()
await self._start()
except ExceptionGroup as excgrp:
raise excgrp.exceptions[0]

async def _start(
self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED
Expand Down
10 changes: 3 additions & 7 deletions tests/test_concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,9 @@ async def taskfunc() -> NoReturn:

assert len(excinfo.value.exceptions) == 1
assert isinstance(excinfo.value.exceptions[0], ExceptionGroup)
excgrp0 = excinfo.value.exceptions[0]
assert len(excgrp0.exceptions) == 1
assert isinstance(excgrp0, ExceptionGroup)
excgrp1 = excgrp0.exceptions[0]
assert isinstance(excgrp1, ExceptionGroup)
assert len(excgrp1.exceptions) == 1
assert str(excgrp1.exceptions[0]) == "foo"
excgrp = excinfo.value.exceptions[0]
assert len(excgrp.exceptions) == 1
assert str(excgrp.exceptions[0]) == "foo"

async def test_start_exception_handled(self) -> None:
handled_exception: Exception | None = None
Expand Down

0 comments on commit b02717f

Please sign in to comment.