@@ -138,24 +138,23 @@ def __init__(
138
138
self ._dynamic_activity = defn
139
139
140
140
async def run (self ) -> None :
141
- # Create a task that fails when we get a failure on the queue
142
- async def raise_from_queue () -> NoReturn :
141
+ """
142
+ Continually poll for activity tasks and dispatch to handlers.
143
+ """
144
+
145
+ async def raise_from_exception_queue () -> NoReturn :
143
146
raise await self ._fail_worker_exception_queue .get ()
144
147
145
- exception_task = asyncio .create_task (raise_from_queue ())
148
+ exception_task = asyncio .create_task (raise_from_exception_queue ())
146
149
147
- # Continually poll for activity work
148
150
while True :
149
151
try :
150
- # Poll for a task
151
152
poll_task = asyncio .create_task (
152
153
self ._bridge_worker ().poll_activity_task ()
153
154
)
154
155
await asyncio .wait (
155
156
[poll_task , exception_task ], return_when = asyncio .FIRST_COMPLETED
156
- ) # type: ignore
157
- # If exception for failing the worker happened, raise it.
158
- # Otherwise, the poll succeeded.
157
+ )
159
158
if exception_task .done ():
160
159
poll_task .cancel ()
161
160
await exception_task
@@ -167,11 +166,14 @@ async def raise_from_queue() -> NoReturn:
167
166
# size of 1000 should be plenty for the heartbeat queue.
168
167
activity = _RunningActivity (pending_heartbeats = asyncio .Queue (1000 ))
169
168
activity .task = asyncio .create_task (
170
- self ._run_activity (task .task_token , task .start , activity )
169
+ self ._handle_start_activity_task (
170
+ task .task_token , task .start , activity
171
+ )
171
172
)
172
173
self ._running_activities [task .task_token ] = activity
173
174
elif task .HasField ("cancel" ):
174
- self ._cancel (task .task_token , task .cancel )
175
+ # TODO(nexus-prerelease): does the task get removed from running_activities?
176
+ self ._handle_cancel_activity_task (task .task_token , task .cancel )
175
177
else :
176
178
raise RuntimeError (f"Unrecognized activity task: { task } " )
177
179
except temporalio .bridge .worker .PollShutdownError :
@@ -201,16 +203,17 @@ async def drain_poll_queue(self) -> None:
201
203
202
204
# Only call this after run()/drain_poll_queue() have returned. This will not
203
205
# raise an exception.
204
- # TODO(dan): based on the comment above it looks like the intention may have been to use
205
- # return_exceptions=True
206
206
async def wait_all_completed (self ) -> None :
207
207
running_tasks = [v .task for v in self ._running_activities .values () if v .task ]
208
208
if running_tasks :
209
- await asyncio .gather (* running_tasks , return_exceptions = False )
209
+ await asyncio .gather (* running_tasks , return_exceptions = True )
210
210
211
- def _cancel (
211
+ def _handle_cancel_activity_task (
212
212
self , task_token : bytes , cancel : temporalio .bridge .proto .activity_task .Cancel
213
213
) -> None :
214
+ """
215
+ Request cancellation of a running activity task.
216
+ """
214
217
activity = self ._running_activities .get (task_token )
215
218
if not activity :
216
219
warnings .warn (f"Cannot find activity to cancel for token { task_token !r} " )
@@ -275,12 +278,18 @@ async def _heartbeat_async(
275
278
)
276
279
activity .cancel (cancelled_due_to_heartbeat_error = err )
277
280
278
- async def _run_activity (
281
+ async def _handle_start_activity_task (
279
282
self ,
280
283
task_token : bytes ,
281
284
start : temporalio .bridge .proto .activity_task .Start ,
282
285
running_activity : _RunningActivity ,
283
286
) -> None :
287
+ """
288
+ Handle a start activity task.
289
+
290
+ Attempt to execute the user activity function and invoke the data converter on
291
+ the result. Handle errors and send the task completion.
292
+ """
284
293
logger .debug ("Running activity %s (token %s)" , start .activity_type , task_token )
285
294
# We choose to surround interceptor creation and activity invocation in
286
295
# a try block so we can mark the workflow as failed on any error instead
@@ -289,7 +298,9 @@ async def _run_activity(
289
298
task_token = task_token
290
299
)
291
300
try :
292
- await self ._execute_activity (start , running_activity , completion )
301
+ result = await self ._execute_activity (start , running_activity , task_token )
302
+ [payload ] = await self ._data_converter .encode ([result ])
303
+ completion .result .completed .result .CopyFrom (payload )
293
304
except BaseException as err :
294
305
try :
295
306
if isinstance (err , temporalio .activity ._CompleteAsyncError ):
@@ -318,7 +329,7 @@ async def _run_activity(
318
329
and running_activity .cancellation_details .details .paused
319
330
):
320
331
temporalio .activity .logger .warning (
321
- f "Completing as failure due to unhandled cancel error produced by activity pause" ,
332
+ "Completing as failure due to unhandled cancel error produced by activity pause" ,
322
333
)
323
334
await self ._data_converter .encode_failure (
324
335
temporalio .exceptions .ApplicationError (
@@ -402,8 +413,13 @@ async def _execute_activity(
402
413
self ,
403
414
start : temporalio .bridge .proto .activity_task .Start ,
404
415
running_activity : _RunningActivity ,
405
- completion : temporalio .bridge .proto .ActivityTaskCompletion ,
406
- ):
416
+ task_token : bytes ,
417
+ ) -> Any :
418
+ """
419
+ Invoke the user's activity function.
420
+
421
+ Exceptions are handled by a caller of this function.
422
+ """
407
423
# Find activity or fail
408
424
activity_def = self ._activities .get (start .activity_type , self ._dynamic_activity )
409
425
if not activity_def :
@@ -523,7 +539,7 @@ async def _execute_activity(
523
539
else None ,
524
540
started_time = _proto_to_datetime (start .started_time ),
525
541
task_queue = self ._task_queue ,
526
- task_token = completion . task_token ,
542
+ task_token = task_token ,
527
543
workflow_id = start .workflow_execution .workflow_id ,
528
544
workflow_namespace = start .workflow_namespace ,
529
545
workflow_run_id = start .workflow_execution .run_id ,
@@ -562,16 +578,9 @@ async def _execute_activity(
562
578
impl : ActivityInboundInterceptor = _ActivityInboundImpl (self , running_activity )
563
579
for interceptor in reversed (list (self ._interceptors )):
564
580
impl = interceptor .intercept_activity (impl )
565
- # Init
581
+
566
582
impl .init (_ActivityOutboundImpl (self , running_activity .info ))
567
- # Exec
568
- result = await impl .execute_activity (input )
569
- # Convert result even if none. Since Python essentially only
570
- # supports single result types (even if they are tuples), we will do
571
- # the same.
572
- completion .result .completed .result .CopyFrom (
573
- (await self ._data_converter .encode ([result ]))[0 ]
574
- )
583
+ return await impl .execute_activity (input )
575
584
576
585
def assert_activity_valid (self , activity ) -> None :
577
586
if self ._dynamic_activity :
0 commit comments