Skip to content

Commit bce0fd0

Browse files
committed
Refactor Nexus worker
1 parent d30c686 commit bce0fd0

File tree

1 file changed

+180
-153
lines changed

1 file changed

+180
-153
lines changed

temporalio/worker/_nexus.py

Lines changed: 180 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import (
1111
Any,
1212
Callable,
13+
NoReturn,
1314
Optional,
1415
Sequence,
1516
Type,
@@ -51,9 +52,8 @@ def __init__(
5152
metric_meter: temporalio.common.MetricMeter,
5253
executor: Optional[concurrent.futures.ThreadPoolExecutor],
5354
) -> None:
54-
# TODO(nexus-prerelease): make it possible to query task queue of bridge worker
55-
# instead of passing unused task_queue into _NexusWorker,
56-
# _ActivityWorker, etc?
55+
# TODO: make it possible to query task queue of bridge worker instead of passing
56+
# unused task_queue into _NexusWorker, _ActivityWorker, etc?
5757
self._bridge_worker = bridge_worker
5858
self._client = client
5959
self._task_queue = task_queue
@@ -74,50 +74,69 @@ def __init__(
7474
# TODO(nexus-prerelease): metric_meter
7575
self._metric_meter = metric_meter
7676
self._running_tasks: dict[bytes, asyncio.Task[Any]] = {}
77+
self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue()
7778

7879
async def run(self) -> None:
80+
"""
81+
Continually poll for Nexus tasks and dispatch to handlers.
82+
"""
83+
84+
async def raise_from_exception_queue() -> NoReturn:
85+
raise await self._fail_worker_exception_queue.get()
86+
87+
exception_task = asyncio.create_task(raise_from_exception_queue())
88+
7989
while True:
8090
try:
8191
poll_task = asyncio.create_task(self._bridge_worker().poll_nexus_task())
82-
except Exception as err:
83-
raise RuntimeError("Nexus worker failed") from err
84-
85-
task = await poll_task
92+
await asyncio.wait(
93+
[poll_task, exception_task], return_when=asyncio.FIRST_COMPLETED
94+
)
95+
if exception_task.done():
96+
poll_task.cancel()
97+
await exception_task
98+
task = await poll_task
8699

87-
if task.HasField("task"):
88-
task = task.task
89-
if task.request.HasField("start_operation"):
90-
self._running_tasks[task.task_token] = asyncio.create_task(
91-
self._run_nexus_operation(
92-
task.task_token,
93-
task.request.start_operation,
94-
dict(task.request.header),
100+
if task.HasField("task"):
101+
task = task.task
102+
if task.request.HasField("start_operation"):
103+
self._running_tasks[task.task_token] = asyncio.create_task(
104+
self._handle_start_operation_task(
105+
task.task_token,
106+
task.request.start_operation,
107+
dict(task.request.header),
108+
)
95109
)
96-
)
97-
elif task.request.HasField("cancel_operation"):
98-
# TODO(nexus-prerelease): report errors occurring during execution of user
99-
# cancellation method
100-
asyncio.create_task(
101-
self._handle_cancel_operation(
102-
task.request.cancel_operation, task.task_token
110+
elif task.request.HasField("cancel_operation"):
111+
# TODO(nexus-prerelease): do we need to track cancel operation
112+
# tasks as we do start operation tasks?
113+
asyncio.create_task(
114+
self._handle_cancel_operation_task(
115+
task.request.cancel_operation, task.task_token
116+
)
117+
)
118+
else:
119+
raise NotImplementedError(
120+
f"Invalid Nexus task request: {task.request}"
121+
)
122+
elif task.HasField("cancel_task"):
123+
task = task.cancel_task
124+
if _task := self._running_tasks.get(task.task_token):
125+
# TODO(nexus-prerelease): when do we remove the entry from _running_operations?
126+
_task.cancel()
127+
else:
128+
temporalio.nexus.logger.warning(
129+
f"Received cancel_task but no running operation exists for "
130+
f"task token: {task.task_token}"
103131
)
104-
)
105-
else:
106-
raise NotImplementedError(
107-
f"Invalid Nexus task request: {task.request}"
108-
)
109-
elif task.HasField("cancel_task"):
110-
task = task.cancel_task
111-
if _task := self._running_tasks.get(task.task_token):
112-
# TODO(nexus-prerelease): when do we remove the entry from _running_operations?
113-
_task.cancel()
114132
else:
115-
temporalio.nexus.logger.warning(
116-
f"Received cancel_task but no running operation exists for "
117-
f"task token: {task.task_token}"
118-
)
119-
else:
120-
raise NotImplementedError(f"Invalid Nexus task: {task}")
133+
raise NotImplementedError(f"Invalid Nexus task: {task}")
134+
135+
# TODO(nexus-prerelease): handle poller shutdown
136+
# except temporalio.bridge.worker.PollShutdownError
137+
138+
except Exception as err:
139+
raise RuntimeError("Nexus worker failed") from err
121140

122141
# Only call this if run() raised an error
123142
async def drain_poll_queue(self) -> None:
@@ -133,15 +152,23 @@ async def drain_poll_queue(self) -> None:
133152
except temporalio.bridge.worker.PollShutdownError:
134153
return
135154

155+
# Only call this after run()/drain_poll_queue() have returned. This will not
156+
# raise an exception.
136157
async def wait_all_completed(self) -> None:
137-
await asyncio.gather(*self._running_tasks.values(), return_exceptions=False)
158+
await asyncio.gather(*self._running_tasks.values(), return_exceptions=True)
138159

139160
# TODO(nexus-prerelease): stack trace pruning. See sdk-typescript NexusHandler.execute
140161
# "Any call up to this function and including this one will be trimmed out of stack traces.""
141162

142-
async def _handle_cancel_operation(
163+
async def _handle_cancel_operation_task(
143164
self, request: temporalio.api.nexus.v1.CancelOperationRequest, task_token: bytes
144165
) -> None:
166+
"""
167+
Handle a cancel operation task.
168+
169+
Attempt to execute the user cancel_operation method. Handle errors and send the
170+
task completion.
171+
"""
145172
ctx = temporalio.nexus.CancelOperationContext(
146173
service=request.service,
147174
operation=request.operation,
@@ -177,128 +204,48 @@ async def _handle_cancel_operation(
177204
except Exception:
178205
temporalio.nexus.logger.exception("Failed to send Nexus task completion")
179206

180-
async def _run_nexus_operation(
207+
async def _handle_start_operation_task(
181208
self,
182209
task_token: bytes,
183210
start_request: temporalio.api.nexus.v1.StartOperationRequest,
184-
header: dict[str, str],
211+
headers: dict[str, str],
185212
) -> None:
186-
async def run() -> temporalio.bridge.proto.nexus.NexusTaskCompletion:
187-
try:
188-
ctx = temporalio.nexus.StartOperationContext(
189-
service=start_request.service,
190-
operation=start_request.operation,
191-
headers=header,
192-
request_id=start_request.request_id,
193-
callback_url=start_request.callback,
194-
inbound_links=[
195-
nexusrpc.handler.Link(url=link.url, type=link.type)
196-
for link in start_request.links
197-
],
198-
callback_headers=dict(start_request.callback_header),
199-
_client=self._client,
200-
_task_queue=self._task_queue,
201-
)
202-
temporalio.nexus.current_context.set(
203-
temporalio.nexus.Context(operation_context=ctx)
204-
)
205-
input = nexusrpc.handler.LazyValue(
206-
serializer=_DummyPayloadSerializer(
207-
data_converter=self._data_converter,
208-
payload=start_request.payload,
209-
),
210-
headers={},
211-
stream=None,
212-
)
213-
try:
214-
result = await self._handler.start_operation(ctx, input)
215-
except (
216-
nexusrpc.handler.UnknownServiceError,
217-
nexusrpc.handler.UnknownOperationError,
218-
) as err:
219-
# TODO(nexus-prerelease): error message
220-
raise nexusrpc.handler.HandlerError(
221-
"No matching operation handler",
222-
type=nexusrpc.handler.HandlerErrorType.NOT_FOUND,
223-
cause=err,
224-
retryable=False,
225-
) from err
226-
227-
except nexusrpc.handler.OperationError as err:
228-
return temporalio.bridge.proto.nexus.NexusTaskCompletion(
229-
task_token=task_token,
230-
completed=temporalio.api.nexus.v1.Response(
231-
start_operation=temporalio.api.nexus.v1.StartOperationResponse(
232-
operation_error=await self._operation_error_to_proto(err),
233-
),
213+
"""
214+
Handle a start operation task.
215+
216+
Attempt to execute the user start_operation method and invoke the data converter
217+
on the result. Handle errors and send the task completion.
218+
"""
219+
220+
try:
221+
start_response = await self._start_operation(start_request, headers)
222+
except BaseException as err:
223+
handler_err = _exception_to_handler_error(err)
224+
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
225+
task_token=task_token,
226+
error=temporalio.api.nexus.v1.HandlerError(
227+
error_type=handler_err.type.value,
228+
failure=await self._exception_to_failure_proto(
229+
handler_err.__cause__
234230
),
235-
)
236-
except BaseException as err:
237-
handler_err = _exception_to_handler_error(err)
238-
return temporalio.bridge.proto.nexus.NexusTaskCompletion(
239-
task_token=task_token,
240-
error=temporalio.api.nexus.v1.HandlerError(
241-
error_type=handler_err.type.value,
242-
failure=await self._exception_to_failure_proto(
243-
handler_err.__cause__
244-
),
245-
retry_behavior=(
246-
temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE
247-
if handler_err.retryable
248-
else temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE
249-
),
231+
retry_behavior=(
232+
temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE
233+
if handler_err.retryable
234+
else temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE
250235
),
251-
)
252-
else:
253-
if isinstance(result, nexusrpc.handler.StartOperationResultAsync):
254-
op_resp = temporalio.api.nexus.v1.StartOperationResponse(
255-
async_success=temporalio.api.nexus.v1.StartOperationResponse.Async(
256-
operation_token=result.token,
257-
links=[
258-
temporalio.api.nexus.v1.Link(url=l.url, type=l.type)
259-
for l in ctx.outbound_links
260-
],
261-
)
262-
)
263-
elif isinstance(result, nexusrpc.handler.StartOperationResultSync):
264-
# TODO(nexus-prerelease): error handling here; what error type should it be?
265-
[payload] = await self._data_converter.encode([result.value])
266-
op_resp = temporalio.api.nexus.v1.StartOperationResponse(
267-
sync_success=temporalio.api.nexus.v1.StartOperationResponse.Sync(
268-
payload=payload
269-
)
270-
)
271-
else:
272-
# TODO(nexus-prerelease): what should the error response be when the user has failed to wrap their return type?
273-
# TODO(nexus-prerelease): unify this failure completion with the path above
274-
err = TypeError(
275-
"Operation start method must return either nexusrpc.handler.StartOperationResultSync "
276-
"or nexusrpc.handler.StartOperationResultAsync"
277-
)
278-
handler_err = _exception_to_handler_error(err)
279-
return temporalio.bridge.proto.nexus.NexusTaskCompletion(
280-
task_token=task_token,
281-
error=temporalio.api.nexus.v1.HandlerError(
282-
error_type=handler_err.type.value,
283-
failure=await self._exception_to_failure_proto(
284-
handler_err.__cause__
285-
),
286-
retry_behavior=(
287-
temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE
288-
if handler_err.retryable
289-
else temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE
290-
),
291-
),
292-
)
293-
294-
return temporalio.bridge.proto.nexus.NexusTaskCompletion(
295-
task_token=task_token,
296-
completed=temporalio.api.nexus.v1.Response(start_operation=op_resp),
297-
)
236+
),
237+
)
238+
else:
239+
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
240+
task_token=task_token,
241+
completed=temporalio.api.nexus.v1.Response(
242+
start_operation=start_response
243+
),
244+
)
298245

299246
try:
300-
completion = await run()
301247
await self._bridge_worker().complete_nexus_task(completion)
248+
# TODO(nexus-prerelease): handle BrokenExecutor by failing the worker
302249
except Exception:
303250
temporalio.nexus.logger.exception("Failed to send Nexus task completion")
304251
finally:
@@ -309,6 +256,86 @@ async def run() -> temporalio.bridge.proto.nexus.NexusTaskCompletion:
309256
"Failed to remove completed Nexus operation"
310257
)
311258

259+
async def _start_operation(
260+
self,
261+
start_request: temporalio.api.nexus.v1.StartOperationRequest,
262+
headers: dict[str, str],
263+
) -> temporalio.api.nexus.v1.StartOperationResponse:
264+
"""
265+
Invoke the Nexus handler's start_operation method and construct the StartOperationResponse.
266+
267+
OperationError is handled by this function, since it results in a StartOperationResponse.
268+
269+
All other exceptions are handled by a caller of this function.
270+
"""
271+
ctx = temporalio.nexus.StartOperationContext(
272+
service=start_request.service,
273+
operation=start_request.operation,
274+
headers=headers,
275+
request_id=start_request.request_id,
276+
callback_url=start_request.callback,
277+
inbound_links=[
278+
nexusrpc.handler.Link(url=link.url, type=link.type)
279+
for link in start_request.links
280+
],
281+
callback_headers=dict(start_request.callback_header),
282+
_client=self._client,
283+
_task_queue=self._task_queue,
284+
)
285+
temporalio.nexus.current_context.set(
286+
temporalio.nexus.Context(operation_context=ctx)
287+
)
288+
input = nexusrpc.handler.LazyValue(
289+
serializer=_DummyPayloadSerializer(
290+
data_converter=self._data_converter,
291+
payload=start_request.payload,
292+
),
293+
headers={},
294+
stream=None,
295+
)
296+
try:
297+
result = await self._handler.start_operation(ctx, input)
298+
if isinstance(result, nexusrpc.handler.StartOperationResultAsync):
299+
return temporalio.api.nexus.v1.StartOperationResponse(
300+
async_success=temporalio.api.nexus.v1.StartOperationResponse.Async(
301+
operation_token=result.token,
302+
links=[
303+
temporalio.api.nexus.v1.Link(url=l.url, type=l.type)
304+
for l in ctx.outbound_links
305+
],
306+
)
307+
)
308+
elif isinstance(result, nexusrpc.handler.StartOperationResultSync):
309+
[payload] = await self._data_converter.encode([result.value])
310+
return temporalio.api.nexus.v1.StartOperationResponse(
311+
sync_success=temporalio.api.nexus.v1.StartOperationResponse.Sync(
312+
payload=payload
313+
)
314+
)
315+
else:
316+
raise _exception_to_handler_error(
317+
TypeError(
318+
"Operation start method must return either nexusrpc.handler.StartOperationResultSync "
319+
"or nexusrpc.handler.StartOperationResultAsync"
320+
)
321+
)
322+
323+
except (
324+
nexusrpc.handler.UnknownServiceError,
325+
nexusrpc.handler.UnknownOperationError,
326+
) as err:
327+
# TODO(nexus-prerelease): error message
328+
raise nexusrpc.handler.HandlerError(
329+
"No matching operation handler",
330+
type=nexusrpc.handler.HandlerErrorType.NOT_FOUND,
331+
cause=err,
332+
retryable=False,
333+
) from err
334+
except nexusrpc.handler.OperationError as err:
335+
return temporalio.api.nexus.v1.StartOperationResponse(
336+
operation_error=await self._operation_error_to_proto(err),
337+
)
338+
312339
async def _exception_to_failure_proto(
313340
self,
314341
err: BaseException,

0 commit comments

Comments
 (0)