@@ -138,24 +138,21 @@ def __init__(
138
138
self ._dynamic_activity = defn
139
139
140
140
async def run (self ) -> None :
141
- # Create a task that fails when we get a failure on the queue
142
- async def raise_from_queue () -> NoReturn :
141
+ """Continually poll for activity tasks and dispatch to handlers."""
142
+
143
+ async def raise_from_exception_queue () -> NoReturn :
143
144
raise await self ._fail_worker_exception_queue .get ()
144
145
145
- exception_task = asyncio .create_task (raise_from_queue ())
146
+ exception_task = asyncio .create_task (raise_from_exception_queue ())
146
147
147
- # Continually poll for activity work
148
148
while True :
149
149
try :
150
- # Poll for a task
151
150
poll_task = asyncio .create_task (
152
151
self ._bridge_worker ().poll_activity_task ()
153
152
)
154
153
await asyncio .wait (
155
154
[poll_task , exception_task ], return_when = asyncio .FIRST_COMPLETED
156
- ) # type: ignore
157
- # If exception for failing the worker happened, raise it.
158
- # Otherwise, the poll succeeded.
155
+ )
159
156
if exception_task .done ():
160
157
poll_task .cancel ()
161
158
await exception_task
@@ -167,11 +164,14 @@ async def raise_from_queue() -> NoReturn:
167
164
# size of 1000 should be plenty for the heartbeat queue.
168
165
activity = _RunningActivity (pending_heartbeats = asyncio .Queue (1000 ))
169
166
activity .task = asyncio .create_task (
170
- self ._run_activity (task .task_token , task .start , activity )
167
+ self ._handle_start_activity_task (
168
+ task .task_token , task .start , activity
169
+ )
171
170
)
172
171
self ._running_activities [task .task_token ] = activity
173
172
elif task .HasField ("cancel" ):
174
- self ._cancel (task .task_token , task .cancel )
173
+ # TODO(nexus-prerelease): does the task get removed from running_activities?
174
+ self ._handle_cancel_activity_task (task .task_token , task .cancel )
175
175
else :
176
176
raise RuntimeError (f"Unrecognized activity task: { task } " )
177
177
except temporalio .bridge .worker .PollShutdownError :
@@ -201,16 +201,15 @@ async def drain_poll_queue(self) -> None:
201
201
202
202
# Only call this after run()/drain_poll_queue() have returned. This will not
203
203
# raise an exception.
204
- # TODO(dan): based on the comment above it looks like the intention may have been to use
205
- # return_exceptions=True
206
204
async def wait_all_completed (self ) -> None :
207
205
running_tasks = [v .task for v in self ._running_activities .values () if v .task ]
208
206
if running_tasks :
209
- await asyncio .gather (* running_tasks , return_exceptions = False )
207
+ await asyncio .gather (* running_tasks , return_exceptions = True )
210
208
211
- def _cancel (
209
+ def _handle_cancel_activity_task (
212
210
self , task_token : bytes , cancel : temporalio .bridge .proto .activity_task .Cancel
213
211
) -> None :
212
+ """Request cancellation of a running activity task."""
214
213
activity = self ._running_activities .get (task_token )
215
214
if not activity :
216
215
warnings .warn (f"Cannot find activity to cancel for token { task_token !r} " )
@@ -275,12 +274,17 @@ async def _heartbeat_async(
275
274
)
276
275
activity .cancel (cancelled_due_to_heartbeat_error = err )
277
276
278
- async def _run_activity (
277
+ async def _handle_start_activity_task (
279
278
self ,
280
279
task_token : bytes ,
281
280
start : temporalio .bridge .proto .activity_task .Start ,
282
281
running_activity : _RunningActivity ,
283
282
) -> None :
283
+ """Handle a start activity task.
284
+
285
+ Attempt to execute the user activity function and invoke the data converter on
286
+ the result. Handle errors and send the task completion.
287
+ """
284
288
logger .debug ("Running activity %s (token %s)" , start .activity_type , task_token )
285
289
# We choose to surround interceptor creation and activity invocation in
286
290
# a try block so we can mark the workflow as failed on any error instead
@@ -289,7 +293,9 @@ async def _run_activity(
289
293
task_token = task_token
290
294
)
291
295
try :
292
- await self ._execute_activity (start , running_activity , completion )
296
+ result = await self ._execute_activity (start , running_activity , task_token )
297
+ [payload ] = await self ._data_converter .encode ([result ])
298
+ completion .result .completed .result .CopyFrom (payload )
293
299
except BaseException as err :
294
300
try :
295
301
if isinstance (err , temporalio .activity ._CompleteAsyncError ):
@@ -318,7 +324,7 @@ async def _run_activity(
318
324
and running_activity .cancellation_details .details .paused
319
325
):
320
326
temporalio .activity .logger .warning (
321
- f "Completing as failure due to unhandled cancel error produced by activity pause" ,
327
+ "Completing as failure due to unhandled cancel error produced by activity pause" ,
322
328
)
323
329
await self ._data_converter .encode_failure (
324
330
temporalio .exceptions .ApplicationError (
@@ -402,8 +408,12 @@ async def _execute_activity(
402
408
self ,
403
409
start : temporalio .bridge .proto .activity_task .Start ,
404
410
running_activity : _RunningActivity ,
405
- completion : temporalio .bridge .proto .ActivityTaskCompletion ,
406
- ):
411
+ task_token : bytes ,
412
+ ) -> Any :
413
+ """Invoke the user's activity function.
414
+
415
+ Exceptions are handled by a caller of this function.
416
+ """
407
417
# Find activity or fail
408
418
activity_def = self ._activities .get (start .activity_type , self ._dynamic_activity )
409
419
if not activity_def :
@@ -523,7 +533,7 @@ async def _execute_activity(
523
533
else None ,
524
534
started_time = _proto_to_datetime (start .started_time ),
525
535
task_queue = self ._task_queue ,
526
- task_token = completion . task_token ,
536
+ task_token = task_token ,
527
537
workflow_id = start .workflow_execution .workflow_id ,
528
538
workflow_namespace = start .workflow_namespace ,
529
539
workflow_run_id = start .workflow_execution .run_id ,
@@ -562,16 +572,9 @@ async def _execute_activity(
562
572
impl : ActivityInboundInterceptor = _ActivityInboundImpl (self , running_activity )
563
573
for interceptor in reversed (list (self ._interceptors )):
564
574
impl = interceptor .intercept_activity (impl )
565
- # Init
575
+
566
576
impl .init (_ActivityOutboundImpl (self , running_activity .info ))
567
- # Exec
568
- result = await impl .execute_activity (input )
569
- # Convert result even if none. Since Python essentially only
570
- # supports single result types (even if they are tuples), we will do
571
- # the same.
572
- completion .result .completed .result .CopyFrom (
573
- (await self ._data_converter .encode ([result ]))[0 ]
574
- )
577
+ return await impl .execute_activity (input )
575
578
576
579
def assert_activity_valid (self , activity ) -> None :
577
580
if self ._dynamic_activity :
0 commit comments