10
10
from typing import (
11
11
Any ,
12
12
Callable ,
13
+ NoReturn ,
13
14
Optional ,
14
15
Sequence ,
15
16
Type ,
@@ -51,9 +52,8 @@ def __init__(
51
52
metric_meter : temporalio .common .MetricMeter ,
52
53
executor : Optional [concurrent .futures .ThreadPoolExecutor ],
53
54
) -> None :
54
- # TODO(nexus-prerelease): make it possible to query task queue of bridge worker
55
- # instead of passing unused task_queue into _NexusWorker,
56
- # _ActivityWorker, etc?
55
+ # TODO: make it possible to query task queue of bridge worker instead of passing
56
+ # unused task_queue into _NexusWorker, _ActivityWorker, etc?
57
57
self ._bridge_worker = bridge_worker
58
58
self ._client = client
59
59
self ._task_queue = task_queue
@@ -74,50 +74,69 @@ def __init__(
74
74
# TODO(nexus-prerelease): metric_meter
75
75
self ._metric_meter = metric_meter
76
76
self ._running_tasks : dict [bytes , asyncio .Task [Any ]] = {}
77
+ self ._fail_worker_exception_queue : asyncio .Queue [Exception ] = asyncio .Queue ()
77
78
78
79
async def run (self ) -> None :
80
+ """
81
+ Continually poll for Nexus tasks and dispatch to handlers.
82
+ """
83
+
84
+ async def raise_from_exception_queue () -> NoReturn :
85
+ raise await self ._fail_worker_exception_queue .get ()
86
+
87
+ exception_task = asyncio .create_task (raise_from_exception_queue ())
88
+
79
89
while True :
80
90
try :
81
91
poll_task = asyncio .create_task (self ._bridge_worker ().poll_nexus_task ())
82
- except Exception as err :
83
- raise RuntimeError ("Nexus worker failed" ) from err
84
-
85
- task = await poll_task
92
+ await asyncio .wait (
93
+ [poll_task , exception_task ], return_when = asyncio .FIRST_COMPLETED
94
+ )
95
+ if exception_task .done ():
96
+ poll_task .cancel ()
97
+ await exception_task
98
+ task = await poll_task
86
99
87
- if task .HasField ("task" ):
88
- task = task .task
89
- if task .request .HasField ("start_operation" ):
90
- self ._running_tasks [task .task_token ] = asyncio .create_task (
91
- self ._run_nexus_operation (
92
- task .task_token ,
93
- task .request .start_operation ,
94
- dict (task .request .header ),
100
+ if task .HasField ("task" ):
101
+ task = task .task
102
+ if task .request .HasField ("start_operation" ):
103
+ self ._running_tasks [task .task_token ] = asyncio .create_task (
104
+ self ._handle_start_operation_task (
105
+ task .task_token ,
106
+ task .request .start_operation ,
107
+ dict (task .request .header ),
108
+ )
95
109
)
96
- )
97
- elif task .request .HasField ("cancel_operation" ):
98
- # TODO(nexus-prerelease): report errors occurring during execution of user
99
- # cancellation method
100
- asyncio .create_task (
101
- self ._handle_cancel_operation (
102
- task .request .cancel_operation , task .task_token
110
+ elif task .request .HasField ("cancel_operation" ):
111
+ # TODO(nexus-prerelease): do we need to track cancel operation
112
+ # tasks as we do start operation tasks?
113
+ asyncio .create_task (
114
+ self ._handle_cancel_operation_task (
115
+ task .request .cancel_operation , task .task_token
116
+ )
117
+ )
118
+ else :
119
+ raise NotImplementedError (
120
+ f"Invalid Nexus task request: { task .request } "
121
+ )
122
+ elif task .HasField ("cancel_task" ):
123
+ task = task .cancel_task
124
+ if _task := self ._running_tasks .get (task .task_token ):
125
+ # TODO(nexus-prerelease): when do we remove the entry from _running_tasks?
126
+ _task .cancel ()
127
+ else :
128
+ temporalio .nexus .logger .warning (
129
+ f"Received cancel_task but no running operation exists for "
130
+ f"task token: { task .task_token } "
103
131
)
104
- )
105
- else :
106
- raise NotImplementedError (
107
- f"Invalid Nexus task request: { task .request } "
108
- )
109
- elif task .HasField ("cancel_task" ):
110
- task = task .cancel_task
111
- if _task := self ._running_tasks .get (task .task_token ):
112
- # TODO(nexus-prerelease): when do we remove the entry from _running_operations?
113
- _task .cancel ()
114
132
else :
115
- temporalio .nexus .logger .warning (
116
- f"Received cancel_task but no running operation exists for "
117
- f"task token: { task .task_token } "
118
- )
119
- else :
120
- raise NotImplementedError (f"Invalid Nexus task: { task } " )
133
+ raise NotImplementedError (f"Invalid Nexus task: { task } " )
134
+
135
+ # TODO(nexus-prerelease): handle poller shutdown
136
+ # except temporalio.bridge.worker.PollShutdownError
137
+
138
+ except Exception as err :
139
+ raise RuntimeError ("Nexus worker failed" ) from err
121
140
122
141
# Only call this if run() raised an error
123
142
async def drain_poll_queue (self ) -> None :
@@ -133,15 +152,23 @@ async def drain_poll_queue(self) -> None:
133
152
except temporalio .bridge .worker .PollShutdownError :
134
153
return
135
154
155
+ # Only call this after run()/drain_poll_queue() have returned. This will not
156
+ # raise an exception.
136
157
async def wait_all_completed (self ) -> None :
137
- await asyncio .gather (* self ._running_tasks .values (), return_exceptions = False )
158
+ await asyncio .gather (* self ._running_tasks .values (), return_exceptions = True )
138
159
139
160
# TODO(nexus-prerelease): stack trace pruning. See sdk-typescript NexusHandler.execute
140
161
# "Any call up to this function and including this one will be trimmed out of stack traces."
141
162
142
- async def _handle_cancel_operation (
163
+ async def _handle_cancel_operation_task (
143
164
self , request : temporalio .api .nexus .v1 .CancelOperationRequest , task_token : bytes
144
165
) -> None :
166
+ """
167
+ Handle a cancel operation task.
168
+
169
+ Attempt to execute the user cancel_operation method. Handle errors and send the
170
+ task completion.
171
+ """
145
172
ctx = temporalio .nexus .CancelOperationContext (
146
173
service = request .service ,
147
174
operation = request .operation ,
@@ -177,128 +204,48 @@ async def _handle_cancel_operation(
177
204
except Exception :
178
205
temporalio .nexus .logger .exception ("Failed to send Nexus task completion" )
179
206
180
- async def _run_nexus_operation (
207
+ async def _handle_start_operation_task (
181
208
self ,
182
209
task_token : bytes ,
183
210
start_request : temporalio .api .nexus .v1 .StartOperationRequest ,
184
- header : dict [str , str ],
211
+ headers : dict [str , str ],
185
212
) -> None :
186
- async def run () -> temporalio .bridge .proto .nexus .NexusTaskCompletion :
187
- try :
188
- ctx = temporalio .nexus .StartOperationContext (
189
- service = start_request .service ,
190
- operation = start_request .operation ,
191
- headers = header ,
192
- request_id = start_request .request_id ,
193
- callback_url = start_request .callback ,
194
- inbound_links = [
195
- nexusrpc .handler .Link (url = link .url , type = link .type )
196
- for link in start_request .links
197
- ],
198
- callback_headers = dict (start_request .callback_header ),
199
- _client = self ._client ,
200
- _task_queue = self ._task_queue ,
201
- )
202
- temporalio .nexus .current_context .set (
203
- temporalio .nexus .Context (operation_context = ctx )
204
- )
205
- input = nexusrpc .handler .LazyValue (
206
- serializer = _DummyPayloadSerializer (
207
- data_converter = self ._data_converter ,
208
- payload = start_request .payload ,
209
- ),
210
- headers = {},
211
- stream = None ,
212
- )
213
- try :
214
- result = await self ._handler .start_operation (ctx , input )
215
- except (
216
- nexusrpc .handler .UnknownServiceError ,
217
- nexusrpc .handler .UnknownOperationError ,
218
- ) as err :
219
- # TODO(nexus-prerelease): error message
220
- raise nexusrpc .handler .HandlerError (
221
- "No matching operation handler" ,
222
- type = nexusrpc .handler .HandlerErrorType .NOT_FOUND ,
223
- cause = err ,
224
- retryable = False ,
225
- ) from err
226
-
227
- except nexusrpc .handler .OperationError as err :
228
- return temporalio .bridge .proto .nexus .NexusTaskCompletion (
229
- task_token = task_token ,
230
- completed = temporalio .api .nexus .v1 .Response (
231
- start_operation = temporalio .api .nexus .v1 .StartOperationResponse (
232
- operation_error = await self ._operation_error_to_proto (err ),
233
- ),
213
+ """
214
+ Handle a start operation task.
215
+
216
+ Attempt to execute the user start_operation method and invoke the data converter
217
+ on the result. Handle errors and send the task completion.
218
+ """
219
+
220
+ try :
221
+ start_response = await self ._start_operation (start_request , headers )
222
+ except BaseException as err :
223
+ handler_err = _exception_to_handler_error (err )
224
+ completion = temporalio .bridge .proto .nexus .NexusTaskCompletion (
225
+ task_token = task_token ,
226
+ error = temporalio .api .nexus .v1 .HandlerError (
227
+ error_type = handler_err .type .value ,
228
+ failure = await self ._exception_to_failure_proto (
229
+ handler_err .__cause__
234
230
),
235
- )
236
- except BaseException as err :
237
- handler_err = _exception_to_handler_error (err )
238
- return temporalio .bridge .proto .nexus .NexusTaskCompletion (
239
- task_token = task_token ,
240
- error = temporalio .api .nexus .v1 .HandlerError (
241
- error_type = handler_err .type .value ,
242
- failure = await self ._exception_to_failure_proto (
243
- handler_err .__cause__
244
- ),
245
- retry_behavior = (
246
- temporalio .api .enums .v1 .NexusHandlerErrorRetryBehavior .NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE
247
- if handler_err .retryable
248
- else temporalio .api .enums .v1 .NexusHandlerErrorRetryBehavior .NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE
249
- ),
231
+ retry_behavior = (
232
+ temporalio .api .enums .v1 .NexusHandlerErrorRetryBehavior .NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE
233
+ if handler_err .retryable
234
+ else temporalio .api .enums .v1 .NexusHandlerErrorRetryBehavior .NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE
250
235
),
251
- )
252
- else :
253
- if isinstance (result , nexusrpc .handler .StartOperationResultAsync ):
254
- op_resp = temporalio .api .nexus .v1 .StartOperationResponse (
255
- async_success = temporalio .api .nexus .v1 .StartOperationResponse .Async (
256
- operation_token = result .token ,
257
- links = [
258
- temporalio .api .nexus .v1 .Link (url = l .url , type = l .type )
259
- for l in ctx .outbound_links
260
- ],
261
- )
262
- )
263
- elif isinstance (result , nexusrpc .handler .StartOperationResultSync ):
264
- # TODO(nexus-prerelease): error handling here; what error type should it be?
265
- [payload ] = await self ._data_converter .encode ([result .value ])
266
- op_resp = temporalio .api .nexus .v1 .StartOperationResponse (
267
- sync_success = temporalio .api .nexus .v1 .StartOperationResponse .Sync (
268
- payload = payload
269
- )
270
- )
271
- else :
272
- # TODO(nexus-prerelease): what should the error response be when the user has failed to wrap their return type?
273
- # TODO(nexus-prerelease): unify this failure completion with the path above
274
- err = TypeError (
275
- "Operation start method must return either nexusrpc.handler.StartOperationResultSync "
276
- "or nexusrpc.handler.StartOperationResultAsync"
277
- )
278
- handler_err = _exception_to_handler_error (err )
279
- return temporalio .bridge .proto .nexus .NexusTaskCompletion (
280
- task_token = task_token ,
281
- error = temporalio .api .nexus .v1 .HandlerError (
282
- error_type = handler_err .type .value ,
283
- failure = await self ._exception_to_failure_proto (
284
- handler_err .__cause__
285
- ),
286
- retry_behavior = (
287
- temporalio .api .enums .v1 .NexusHandlerErrorRetryBehavior .NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE
288
- if handler_err .retryable
289
- else temporalio .api .enums .v1 .NexusHandlerErrorRetryBehavior .NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE
290
- ),
291
- ),
292
- )
293
-
294
- return temporalio .bridge .proto .nexus .NexusTaskCompletion (
295
- task_token = task_token ,
296
- completed = temporalio .api .nexus .v1 .Response (start_operation = op_resp ),
297
- )
236
+ ),
237
+ )
238
+ else :
239
+ completion = temporalio .bridge .proto .nexus .NexusTaskCompletion (
240
+ task_token = task_token ,
241
+ completed = temporalio .api .nexus .v1 .Response (
242
+ start_operation = start_response
243
+ ),
244
+ )
298
245
299
246
try :
300
- completion = await run ()
301
247
await self ._bridge_worker ().complete_nexus_task (completion )
248
+ # TODO(nexus-prerelease): handle BrokenExecutor by failing the worker
302
249
except Exception :
303
250
temporalio .nexus .logger .exception ("Failed to send Nexus task completion" )
304
251
finally :
@@ -309,6 +256,86 @@ async def run() -> temporalio.bridge.proto.nexus.NexusTaskCompletion:
309
256
"Failed to remove completed Nexus operation"
310
257
)
311
258
259
+ async def _start_operation (
260
+ self ,
261
+ start_request : temporalio .api .nexus .v1 .StartOperationRequest ,
262
+ headers : dict [str , str ],
263
+ ) -> temporalio .api .nexus .v1 .StartOperationResponse :
264
+ """
265
+ Invoke the Nexus handler's start_operation method and construct the StartOperationResponse.
266
+
267
+ OperationError is handled by this function, since it results in a StartOperationResponse.
268
+
269
+ All other exceptions are handled by a caller of this function.
270
+ """
271
+ ctx = temporalio .nexus .StartOperationContext (
272
+ service = start_request .service ,
273
+ operation = start_request .operation ,
274
+ headers = headers ,
275
+ request_id = start_request .request_id ,
276
+ callback_url = start_request .callback ,
277
+ inbound_links = [
278
+ nexusrpc .handler .Link (url = link .url , type = link .type )
279
+ for link in start_request .links
280
+ ],
281
+ callback_headers = dict (start_request .callback_header ),
282
+ _client = self ._client ,
283
+ _task_queue = self ._task_queue ,
284
+ )
285
+ temporalio .nexus .current_context .set (
286
+ temporalio .nexus .Context (operation_context = ctx )
287
+ )
288
+ input = nexusrpc .handler .LazyValue (
289
+ serializer = _DummyPayloadSerializer (
290
+ data_converter = self ._data_converter ,
291
+ payload = start_request .payload ,
292
+ ),
293
+ headers = {},
294
+ stream = None ,
295
+ )
296
+ try :
297
+ result = await self ._handler .start_operation (ctx , input )
298
+ if isinstance (result , nexusrpc .handler .StartOperationResultAsync ):
299
+ return temporalio .api .nexus .v1 .StartOperationResponse (
300
+ async_success = temporalio .api .nexus .v1 .StartOperationResponse .Async (
301
+ operation_token = result .token ,
302
+ links = [
303
+ temporalio .api .nexus .v1 .Link (url = l .url , type = l .type )
304
+ for l in ctx .outbound_links
305
+ ],
306
+ )
307
+ )
308
+ elif isinstance (result , nexusrpc .handler .StartOperationResultSync ):
309
+ [payload ] = await self ._data_converter .encode ([result .value ])
310
+ return temporalio .api .nexus .v1 .StartOperationResponse (
311
+ sync_success = temporalio .api .nexus .v1 .StartOperationResponse .Sync (
312
+ payload = payload
313
+ )
314
+ )
315
+ else :
316
+ raise _exception_to_handler_error (
317
+ TypeError (
318
+ "Operation start method must return either nexusrpc.handler.StartOperationResultSync "
319
+ "or nexusrpc.handler.StartOperationResultAsync"
320
+ )
321
+ )
322
+
323
+ except (
324
+ nexusrpc .handler .UnknownServiceError ,
325
+ nexusrpc .handler .UnknownOperationError ,
326
+ ) as err :
327
+ # TODO(nexus-prerelease): error message
328
+ raise nexusrpc .handler .HandlerError (
329
+ "No matching operation handler" ,
330
+ type = nexusrpc .handler .HandlerErrorType .NOT_FOUND ,
331
+ cause = err ,
332
+ retryable = False ,
333
+ ) from err
334
+ except nexusrpc .handler .OperationError as err :
335
+ return temporalio .api .nexus .v1 .StartOperationResponse (
336
+ operation_error = await self ._operation_error_to_proto (err ),
337
+ )
338
+
312
339
async def _exception_to_failure_proto (
313
340
self ,
314
341
err : BaseException ,
0 commit comments