Skip to content

Commit 8213e5e

Browse files
committed
Activity worker: don't raise on drain + refactoring
1 parent 13b27be commit 8213e5e

File tree

1 file changed

+32
-29
lines changed

1 file changed

+32
-29
lines changed

temporalio/worker/_activity.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -138,24 +138,21 @@ def __init__(
138138
self._dynamic_activity = defn
139139

140140
async def run(self) -> None:
141-
# Create a task that fails when we get a failure on the queue
142-
async def raise_from_queue() -> NoReturn:
141+
"""Continually poll for activity tasks and dispatch to handlers."""
142+
143+
async def raise_from_exception_queue() -> NoReturn:
143144
raise await self._fail_worker_exception_queue.get()
144145

145-
exception_task = asyncio.create_task(raise_from_queue())
146+
exception_task = asyncio.create_task(raise_from_exception_queue())
146147

147-
# Continually poll for activity work
148148
while True:
149149
try:
150-
# Poll for a task
151150
poll_task = asyncio.create_task(
152151
self._bridge_worker().poll_activity_task()
153152
)
154153
await asyncio.wait(
155154
[poll_task, exception_task], return_when=asyncio.FIRST_COMPLETED
156-
) # type: ignore
157-
# If exception for failing the worker happened, raise it.
158-
# Otherwise, the poll succeeded.
155+
)
159156
if exception_task.done():
160157
poll_task.cancel()
161158
await exception_task
@@ -167,11 +164,14 @@ async def raise_from_queue() -> NoReturn:
167164
# size of 1000 should be plenty for the heartbeat queue.
168165
activity = _RunningActivity(pending_heartbeats=asyncio.Queue(1000))
169166
activity.task = asyncio.create_task(
170-
self._run_activity(task.task_token, task.start, activity)
167+
self._handle_start_activity_task(
168+
task.task_token, task.start, activity
169+
)
171170
)
172171
self._running_activities[task.task_token] = activity
173172
elif task.HasField("cancel"):
174-
self._cancel(task.task_token, task.cancel)
173+
# TODO(nexus-prerelease): does the task get removed from running_activities?
174+
self._handle_cancel_activity_task(task.task_token, task.cancel)
175175
else:
176176
raise RuntimeError(f"Unrecognized activity task: {task}")
177177
except temporalio.bridge.worker.PollShutdownError:
@@ -201,16 +201,15 @@ async def drain_poll_queue(self) -> None:
201201

202202
# Only call this after run()/drain_poll_queue() have returned. This will not
203203
# raise an exception.
204-
# TODO(dan): based on the comment above it looks like the intention may have been to use
205-
# return_exceptions=True
206204
async def wait_all_completed(self) -> None:
207205
running_tasks = [v.task for v in self._running_activities.values() if v.task]
208206
if running_tasks:
209-
await asyncio.gather(*running_tasks, return_exceptions=False)
207+
await asyncio.gather(*running_tasks, return_exceptions=True)
210208

211-
def _cancel(
209+
def _handle_cancel_activity_task(
212210
self, task_token: bytes, cancel: temporalio.bridge.proto.activity_task.Cancel
213211
) -> None:
212+
"""Request cancellation of a running activity task."""
214213
activity = self._running_activities.get(task_token)
215214
if not activity:
216215
warnings.warn(f"Cannot find activity to cancel for token {task_token!r}")
@@ -275,12 +274,17 @@ async def _heartbeat_async(
275274
)
276275
activity.cancel(cancelled_due_to_heartbeat_error=err)
277276

278-
async def _run_activity(
277+
async def _handle_start_activity_task(
279278
self,
280279
task_token: bytes,
281280
start: temporalio.bridge.proto.activity_task.Start,
282281
running_activity: _RunningActivity,
283282
) -> None:
283+
"""Handle a start activity task.
284+
285+
Attempt to execute the user activity function and invoke the data converter on
286+
the result. Handle errors and send the task completion.
287+
"""
284288
logger.debug("Running activity %s (token %s)", start.activity_type, task_token)
285289
# We choose to surround interceptor creation and activity invocation in
286290
# a try block so we can mark the workflow as failed on any error instead
@@ -289,7 +293,9 @@ async def _run_activity(
289293
task_token=task_token
290294
)
291295
try:
292-
await self._execute_activity(start, running_activity, completion)
296+
result = await self._execute_activity(start, running_activity, task_token)
297+
[payload] = await self._data_converter.encode([result])
298+
completion.result.completed.result.CopyFrom(payload)
293299
except BaseException as err:
294300
try:
295301
if isinstance(err, temporalio.activity._CompleteAsyncError):
@@ -318,7 +324,7 @@ async def _run_activity(
318324
and running_activity.cancellation_details.details.paused
319325
):
320326
temporalio.activity.logger.warning(
321-
f"Completing as failure due to unhandled cancel error produced by activity pause",
327+
"Completing as failure due to unhandled cancel error produced by activity pause",
322328
)
323329
await self._data_converter.encode_failure(
324330
temporalio.exceptions.ApplicationError(
@@ -402,8 +408,12 @@ async def _execute_activity(
402408
self,
403409
start: temporalio.bridge.proto.activity_task.Start,
404410
running_activity: _RunningActivity,
405-
completion: temporalio.bridge.proto.ActivityTaskCompletion,
406-
):
411+
task_token: bytes,
412+
) -> Any:
413+
"""Invoke the user's activity function.
414+
415+
Exceptions are handled by a caller of this function.
416+
"""
407417
# Find activity or fail
408418
activity_def = self._activities.get(start.activity_type, self._dynamic_activity)
409419
if not activity_def:
@@ -523,7 +533,7 @@ async def _execute_activity(
523533
else None,
524534
started_time=_proto_to_datetime(start.started_time),
525535
task_queue=self._task_queue,
526-
task_token=completion.task_token,
536+
task_token=task_token,
527537
workflow_id=start.workflow_execution.workflow_id,
528538
workflow_namespace=start.workflow_namespace,
529539
workflow_run_id=start.workflow_execution.run_id,
@@ -562,16 +572,9 @@ async def _execute_activity(
562572
impl: ActivityInboundInterceptor = _ActivityInboundImpl(self, running_activity)
563573
for interceptor in reversed(list(self._interceptors)):
564574
impl = interceptor.intercept_activity(impl)
565-
# Init
575+
566576
impl.init(_ActivityOutboundImpl(self, running_activity.info))
567-
# Exec
568-
result = await impl.execute_activity(input)
569-
# Convert result even if none. Since Python essentially only
570-
# supports single result types (even if they are tuples), we will do
571-
# the same.
572-
completion.result.completed.result.CopyFrom(
573-
(await self._data_converter.encode([result]))[0]
574-
)
577+
return await impl.execute_activity(input)
575578

576579
def assert_activity_valid(self, activity) -> None:
577580
if self._dynamic_activity:

0 commit comments

Comments
 (0)