Skip to content

Commit 0fde874

Browse files
committed
Activity worker: don't raise on drain + refactoring
1 parent 13b27be commit 0fde874

File tree

1 file changed

+38
-29
lines changed

1 file changed

+38
-29
lines changed

temporalio/worker/_activity.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -138,24 +138,23 @@ def __init__(
138138
self._dynamic_activity = defn
139139

140140
async def run(self) -> None:
141-
# Create a task that fails when we get a failure on the queue
142-
async def raise_from_queue() -> NoReturn:
141+
"""
142+
Continually poll for activity tasks and dispatch to handlers.
143+
"""
144+
145+
async def raise_from_exception_queue() -> NoReturn:
143146
raise await self._fail_worker_exception_queue.get()
144147

145-
exception_task = asyncio.create_task(raise_from_queue())
148+
exception_task = asyncio.create_task(raise_from_exception_queue())
146149

147-
# Continually poll for activity work
148150
while True:
149151
try:
150-
# Poll for a task
151152
poll_task = asyncio.create_task(
152153
self._bridge_worker().poll_activity_task()
153154
)
154155
await asyncio.wait(
155156
[poll_task, exception_task], return_when=asyncio.FIRST_COMPLETED
156-
) # type: ignore
157-
# If exception for failing the worker happened, raise it.
158-
# Otherwise, the poll succeeded.
157+
)
159158
if exception_task.done():
160159
poll_task.cancel()
161160
await exception_task
@@ -167,11 +166,14 @@ async def raise_from_queue() -> NoReturn:
167166
# size of 1000 should be plenty for the heartbeat queue.
168167
activity = _RunningActivity(pending_heartbeats=asyncio.Queue(1000))
169168
activity.task = asyncio.create_task(
170-
self._run_activity(task.task_token, task.start, activity)
169+
self._handle_start_activity_task(
170+
task.task_token, task.start, activity
171+
)
171172
)
172173
self._running_activities[task.task_token] = activity
173174
elif task.HasField("cancel"):
174-
self._cancel(task.task_token, task.cancel)
175+
# TODO(nexus-prerelease): does the task get removed from running_activities?
176+
self._handle_cancel_activity_task(task.task_token, task.cancel)
175177
else:
176178
raise RuntimeError(f"Unrecognized activity task: {task}")
177179
except temporalio.bridge.worker.PollShutdownError:
@@ -201,16 +203,17 @@ async def drain_poll_queue(self) -> None:
201203

202204
# Only call this after run()/drain_poll_queue() have returned. This will not
203205
# raise an exception.
204-
# TODO(dan): based on the comment above it looks like the intention may have been to use
205-
# return_exceptions=True
206206
async def wait_all_completed(self) -> None:
207207
running_tasks = [v.task for v in self._running_activities.values() if v.task]
208208
if running_tasks:
209-
await asyncio.gather(*running_tasks, return_exceptions=False)
209+
await asyncio.gather(*running_tasks, return_exceptions=True)
210210

211-
def _cancel(
211+
def _handle_cancel_activity_task(
212212
self, task_token: bytes, cancel: temporalio.bridge.proto.activity_task.Cancel
213213
) -> None:
214+
"""
215+
Request cancellation of a running activity task.
216+
"""
214217
activity = self._running_activities.get(task_token)
215218
if not activity:
216219
warnings.warn(f"Cannot find activity to cancel for token {task_token!r}")
@@ -275,12 +278,18 @@ async def _heartbeat_async(
275278
)
276279
activity.cancel(cancelled_due_to_heartbeat_error=err)
277280

278-
async def _run_activity(
281+
async def _handle_start_activity_task(
279282
self,
280283
task_token: bytes,
281284
start: temporalio.bridge.proto.activity_task.Start,
282285
running_activity: _RunningActivity,
283286
) -> None:
287+
"""
288+
Handle a start activity task.
289+
290+
Attempt to execute the user activity function and invoke the data converter on
291+
the result. Handle errors and send the task completion.
292+
"""
284293
logger.debug("Running activity %s (token %s)", start.activity_type, task_token)
285294
# We choose to surround interceptor creation and activity invocation in
286295
# a try block so we can mark the workflow as failed on any error instead
@@ -289,7 +298,9 @@ async def _run_activity(
289298
task_token=task_token
290299
)
291300
try:
292-
await self._execute_activity(start, running_activity, completion)
301+
result = await self._execute_activity(start, running_activity, task_token)
302+
[payload] = await self._data_converter.encode([result])
303+
completion.result.completed.result.CopyFrom(payload)
293304
except BaseException as err:
294305
try:
295306
if isinstance(err, temporalio.activity._CompleteAsyncError):
@@ -318,7 +329,7 @@ async def _run_activity(
318329
and running_activity.cancellation_details.details.paused
319330
):
320331
temporalio.activity.logger.warning(
321-
f"Completing as failure due to unhandled cancel error produced by activity pause",
332+
"Completing as failure due to unhandled cancel error produced by activity pause",
322333
)
323334
await self._data_converter.encode_failure(
324335
temporalio.exceptions.ApplicationError(
@@ -402,8 +413,13 @@ async def _execute_activity(
402413
self,
403414
start: temporalio.bridge.proto.activity_task.Start,
404415
running_activity: _RunningActivity,
405-
completion: temporalio.bridge.proto.ActivityTaskCompletion,
406-
):
416+
task_token: bytes,
417+
) -> Any:
418+
"""
419+
Invoke the user's activity function.
420+
421+
Exceptions are handled by a caller of this function.
422+
"""
407423
# Find activity or fail
408424
activity_def = self._activities.get(start.activity_type, self._dynamic_activity)
409425
if not activity_def:
@@ -523,7 +539,7 @@ async def _execute_activity(
523539
else None,
524540
started_time=_proto_to_datetime(start.started_time),
525541
task_queue=self._task_queue,
526-
task_token=completion.task_token,
542+
task_token=task_token,
527543
workflow_id=start.workflow_execution.workflow_id,
528544
workflow_namespace=start.workflow_namespace,
529545
workflow_run_id=start.workflow_execution.run_id,
@@ -562,16 +578,9 @@ async def _execute_activity(
562578
impl: ActivityInboundInterceptor = _ActivityInboundImpl(self, running_activity)
563579
for interceptor in reversed(list(self._interceptors)):
564580
impl = interceptor.intercept_activity(impl)
565-
# Init
581+
566582
impl.init(_ActivityOutboundImpl(self, running_activity.info))
567-
# Exec
568-
result = await impl.execute_activity(input)
569-
# Convert result even if none. Since Python essentially only
570-
# supports single result types (even if they are tuples), we will do
571-
# the same.
572-
completion.result.completed.result.CopyFrom(
573-
(await self._data_converter.encode([result]))[0]
574-
)
583+
return await impl.execute_activity(input)
575584

576585
def assert_activity_valid(self, activity) -> None:
577586
if self._dynamic_activity:

0 commit comments

Comments
 (0)