Skip to content

Commit a845a6f

Browse files
committed
Nexus prototype
1 parent 92b7758 commit a845a6f

15 files changed

+1450
-11
lines changed

README.md

+67
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ informal introduction to the features and their implementation.
9696
- [Heartbeating and Cancellation](#heartbeating-and-cancellation)
9797
- [Worker Shutdown](#worker-shutdown)
9898
- [Testing](#testing-1)
99+
- [Nexus](#nexus)
100+
- [hello](#hello)
99101
- [Workflow Replay](#workflow-replay)
100102
- [Observability](#observability)
101103
- [Metrics](#metrics)
@@ -1314,6 +1316,71 @@ affect calls activity code might make to functions on the `temporalio.activity`
13141316
* `cancel()` can be invoked to simulate a cancellation of the activity
13151317
* `worker_shutdown()` can be invoked to simulate a worker shutdown during execution of the activity
13161318

1319+
1320+
### Nexus
1321+
1322+
See [docs.temporal.io/nexus](https://docs.temporal.io/nexus).
1323+
1324+
#### Service Interface Definition
1325+
1326+
A Nexus Service interface definition is a set of named operations, where each operation is an
1327+
`(input_type,output_type)` pair:
1328+
1329+
```python
1330+
@nexus_service
1331+
class MyNexusService:
1332+
my_operation: NexusOperation[MyOpInput, MyOpOutput]
1333+
```
1334+
1335+
### Operation implementation
1336+
1337+
```python
1338+
@nexus.service(interface.MyNexusService)
1339+
class MyNexusService:
1340+
1341+
@nexus.sync_operation
1342+
def echo(self, input: EchoInput) -> EchoOutput:
1343+
return EchoOutput(message=input.message)
1344+
```
1345+
1346+
```python
1347+
@nexus.service(interface.MyNexusService)
1348+
class MyNexusService:
1349+
1350+
@temporalio.nexus.workflow_operation
1351+
async def hello(
1352+
self, input: HelloInput
1353+
) -> AsyncWorkflowOperationResult[HelloOutput]:
1354+
return await temporalio.nexus.handler.start_workflow(HelloWorkflow.run, input)
1355+
```
1356+
1357+
1358+
### Request options
1359+
1360+
```python
1361+
@dataclass
1362+
class OperationOptions:
1363+
"""Options passed by the Nexus caller when starting an operation."""
1364+
1365+
# A callback URL is required to deliver the completion of an async operation. This URL should be
1366+
# called by a handler upon completion if the started operation is async.
1367+
callback_url: Optional[str] = None
1368+
1369+
# Optional header fields set by the caller to be attached to the callback request when an
1370+
# asynchronous operation completes.
1371+
callback_header: dict[str, str] = field(default_factory=dict)
1372+
1373+
# Request ID that may be used by the server handler to dedupe a start request.
1374+
# By default a v4 UUID will be generated by the client.
1375+
request_id: Optional[str] = None
1376+
1377+
# Links contain arbitrary caller information. Handlers may use these links as
1378+
# metadata on resources associated with an operation.
1379+
links: list[Link] = field(default_factory=list)
1380+
```
1381+
1382+
1383+
13171384
### Workflow Replay
13181385

13191386
Given a workflow's history, it can be replayed locally to check for things like non-determinism errors. For example,

pyproject.toml

+9-1
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ keywords = [
1111
"workflow",
1212
]
1313
dependencies = [
14+
"nexus-rpc",
1415
"protobuf>=3.20",
16+
"pyright==1.1.394",
1517
"python-dateutil>=2.8.2,<3 ; python_version < '3.11'",
18+
"temporalio-xray",
1619
"types-protobuf>=3.20",
1720
"typing-extensions>=4.2.0,<5",
1821
]
@@ -40,7 +43,7 @@ dev = [
4043
"psutil>=5.9.3,<6",
4144
"pydocstyle>=6.3.0,<7",
4245
"pydoctor>=24.11.1,<25",
43-
"pyright==1.1.377",
46+
"pyright==1.1.394",
4447
"pytest~=7.4",
4548
"pytest-asyncio>=0.21,<0.22",
4649
"pytest-timeout~=2.2",
@@ -185,6 +188,7 @@ exclude = [
185188

186189
[tool.ruff]
187190
target-version = "py39"
191+
extend-ignore = ["E741"] # Allow single-letter variable names like I, O
188192

189193
[build-system]
190194
requires = ["maturin>=1.0,<2.0"]
@@ -202,3 +206,7 @@ exclude = [
202206
[tool.uv]
203207
# Prevent uv commands from building the package by default
204208
package = false
209+
210+
[tool.uv.sources]
211+
nexus-rpc = { path = "../nexus-sdk-python", editable = true }
212+
temporalio-xray = { path = "../xray/sdks/python", editable = true }

temporalio/bridge/src/worker.rs

+27-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use temporal_sdk_core_api::worker::{
2020
};
2121
use temporal_sdk_core_api::Worker;
2222
use temporal_sdk_core_protos::coresdk::workflow_completion::WorkflowActivationCompletion;
23-
use temporal_sdk_core_protos::coresdk::{ActivityHeartbeat, ActivityTaskCompletion};
23+
use temporal_sdk_core_protos::coresdk::{ActivityHeartbeat, ActivityTaskCompletion, nexus::NexusTaskCompletion};
2424
use temporal_sdk_core_protos::temporal::api::history::v1::History;
2525
use tokio::sync::mpsc::{channel, Sender};
2626
use tokio_stream::wrappers::ReceiverStream;
@@ -475,6 +475,19 @@ impl WorkerRef {
475475
})
476476
}
477477

478+
fn poll_nexus_task<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
479+
let worker = self.worker.as_ref().unwrap().clone();
480+
self.runtime.future_into_py(py, async move {
481+
let bytes = match worker.poll_nexus_task().await {
482+
Ok(task) => task.encode_to_vec(),
483+
Err(PollError::ShutDown) => return Err(PollShutdownError::new_err(())),
484+
Err(err) => return Err(PyRuntimeError::new_err(format!("Poll failure: {}", err))),
485+
};
486+
let bytes: &[u8] = &bytes;
487+
Ok(Python::with_gil(|py| bytes.into_py(py)))
488+
})
489+
}
490+
478491
fn complete_workflow_activation<'p>(
479492
&self,
480493
py: Python<'p>,
@@ -505,6 +518,19 @@ impl WorkerRef {
505518
})
506519
}
507520

521+
fn complete_nexus_task<'p>(&self, py: Python<'p>, proto: &PyBytes) -> PyResult<&'p PyAny> {
522+
let worker = self.worker.as_ref().unwrap().clone();
523+
let completion = NexusTaskCompletion::decode(proto.as_bytes())
524+
.map_err(|err| PyValueError::new_err(format!("Invalid proto: {}", err)))?;
525+
self.runtime.future_into_py(py, async move {
526+
worker
527+
.complete_nexus_task(completion)
528+
.await
529+
.context("Completion failure")
530+
.map_err(Into::into)
531+
})
532+
}
533+
508534
fn record_activity_heartbeat(&self, proto: &PyBytes) -> PyResult<()> {
509535
enter_sync!(self.runtime);
510536
let heartbeat = ActivityHeartbeat::decode(proto.as_bytes())

temporalio/bridge/worker.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import temporalio.bridge.client
2727
import temporalio.bridge.proto
2828
import temporalio.bridge.proto.activity_task
29+
import temporalio.bridge.proto.nexus
2930
import temporalio.bridge.proto.workflow_activation
3031
import temporalio.bridge.proto.workflow_completion
3132
import temporalio.bridge.runtime
@@ -35,7 +36,7 @@
3536
from temporalio.bridge.temporal_sdk_bridge import (
3637
CustomSlotSupplier as BridgeCustomSlotSupplier,
3738
)
38-
from temporalio.bridge.temporal_sdk_bridge import PollShutdownError
39+
from temporalio.bridge.temporal_sdk_bridge import PollShutdownError # type: ignore
3940

4041

4142
@dataclass
@@ -156,6 +157,14 @@ async def poll_activity_task(
156157
await self._ref.poll_activity_task()
157158
)
158159

160+
async def poll_nexus_task(
161+
self,
162+
) -> temporalio.bridge.proto.nexus.NexusTask:
163+
"""Poll for a nexus task."""
164+
return temporalio.bridge.proto.nexus.NexusTask.FromString(
165+
await self._ref.poll_nexus_task()
166+
)
167+
159168
async def complete_workflow_activation(
160169
self,
161170
comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
@@ -169,6 +178,12 @@ async def complete_activity_task(
169178
"""Complete an activity task."""
170179
await self._ref.complete_activity_task(comp.SerializeToString())
171180

181+
async def complete_nexus_task(
182+
self, comp: temporalio.bridge.proto.nexus.NexusTaskCompletion
183+
) -> None:
184+
"""Complete a nexus task."""
185+
await self._ref.complete_nexus_task(comp.SerializeToString())
186+
172187
def record_activity_heartbeat(
173188
self, comp: temporalio.bridge.proto.ActivityHeartbeat
174189
) -> None:

temporalio/client.py

+20
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ async def start_workflow(
316316
start_delay: Optional[timedelta] = None,
317317
start_signal: Optional[str] = None,
318318
start_signal_args: Sequence[Any] = [],
319+
completion_callbacks: Sequence[temporalio.common.CompletionCallback] = [],
319320
rpc_metadata: Mapping[str, str] = {},
320321
rpc_timeout: Optional[timedelta] = None,
321322
request_eager_start: bool = False,
@@ -350,6 +351,7 @@ async def start_workflow(
350351
start_delay: Optional[timedelta] = None,
351352
start_signal: Optional[str] = None,
352353
start_signal_args: Sequence[Any] = [],
354+
completion_callbacks: Sequence[temporalio.common.CompletionCallback] = [],
353355
rpc_metadata: Mapping[str, str] = {},
354356
rpc_timeout: Optional[timedelta] = None,
355357
request_eager_start: bool = False,
@@ -386,6 +388,7 @@ async def start_workflow(
386388
start_delay: Optional[timedelta] = None,
387389
start_signal: Optional[str] = None,
388390
start_signal_args: Sequence[Any] = [],
391+
completion_callbacks: Sequence[temporalio.common.CompletionCallback] = [],
389392
rpc_metadata: Mapping[str, str] = {},
390393
rpc_timeout: Optional[timedelta] = None,
391394
request_eager_start: bool = False,
@@ -422,6 +425,7 @@ async def start_workflow(
422425
start_delay: Optional[timedelta] = None,
423426
start_signal: Optional[str] = None,
424427
start_signal_args: Sequence[Any] = [],
428+
completion_callbacks: Sequence[temporalio.common.CompletionCallback] = [],
425429
rpc_metadata: Mapping[str, str] = {},
426430
rpc_timeout: Optional[timedelta] = None,
427431
request_eager_start: bool = False,
@@ -456,6 +460,7 @@ async def start_workflow(
456460
start_delay: Optional[timedelta] = None,
457461
start_signal: Optional[str] = None,
458462
start_signal_args: Sequence[Any] = [],
463+
completion_callbacks: Sequence[temporalio.common.CompletionCallback] = [],
459464
rpc_metadata: Mapping[str, str] = {},
460465
rpc_timeout: Optional[timedelta] = None,
461466
request_eager_start: bool = False,
@@ -500,6 +505,8 @@ async def start_workflow(
500505
instead of traditional workflow start.
501506
start_signal_args: Arguments for start_signal if start_signal
502507
present.
508+
completion_callbacks: Callbacks to be called by the server when the workflow reaches a
509+
terminal state.
503510
rpc_metadata: Headers used on the RPC call. Keys here override
504511
client-level RPC metadata keys.
505512
rpc_timeout: Optional RPC deadline to set for the RPC call.
@@ -544,6 +551,7 @@ async def start_workflow(
544551
static_details=static_details,
545552
start_signal=start_signal,
546553
start_signal_args=start_signal_args,
554+
completion_callbacks=completion_callbacks,
547555
ret_type=result_type or result_type_from_type_hint,
548556
rpc_metadata=rpc_metadata,
549557
rpc_timeout=rpc_timeout,
@@ -579,6 +587,7 @@ async def execute_workflow(
579587
start_delay: Optional[timedelta] = None,
580588
start_signal: Optional[str] = None,
581589
start_signal_args: Sequence[Any] = [],
590+
completion_callbacks: Sequence[temporalio.common.CompletionCallback] = [],
582591
rpc_metadata: Mapping[str, str] = {},
583592
rpc_timeout: Optional[timedelta] = None,
584593
request_eager_start: bool = False,
@@ -613,6 +622,7 @@ async def execute_workflow(
613622
start_delay: Optional[timedelta] = None,
614623
start_signal: Optional[str] = None,
615624
start_signal_args: Sequence[Any] = [],
625+
completion_callbacks: Sequence[temporalio.common.CompletionCallback] = [],
616626
rpc_metadata: Mapping[str, str] = {},
617627
rpc_timeout: Optional[timedelta] = None,
618628
request_eager_start: bool = False,
@@ -649,6 +659,7 @@ async def execute_workflow(
649659
start_delay: Optional[timedelta] = None,
650660
start_signal: Optional[str] = None,
651661
start_signal_args: Sequence[Any] = [],
662+
completion_callbacks: Sequence[temporalio.common.CompletionCallback] = [],
652663
rpc_metadata: Mapping[str, str] = {},
653664
rpc_timeout: Optional[timedelta] = None,
654665
request_eager_start: bool = False,
@@ -685,6 +696,7 @@ async def execute_workflow(
685696
start_delay: Optional[timedelta] = None,
686697
start_signal: Optional[str] = None,
687698
start_signal_args: Sequence[Any] = [],
699+
completion_callbacks: Sequence[temporalio.common.CompletionCallback] = [],
688700
rpc_metadata: Mapping[str, str] = {},
689701
rpc_timeout: Optional[timedelta] = None,
690702
request_eager_start: bool = False,
@@ -719,6 +731,7 @@ async def execute_workflow(
719731
start_delay: Optional[timedelta] = None,
720732
start_signal: Optional[str] = None,
721733
start_signal_args: Sequence[Any] = [],
734+
completion_callbacks: Sequence[temporalio.common.CompletionCallback] = [],
722735
rpc_metadata: Mapping[str, str] = {},
723736
rpc_timeout: Optional[timedelta] = None,
724737
request_eager_start: bool = False,
@@ -753,6 +766,7 @@ async def execute_workflow(
753766
start_delay=start_delay,
754767
start_signal=start_signal,
755768
start_signal_args=start_signal_args,
769+
completion_callbacks=completion_callbacks,
756770
rpc_metadata=rpc_metadata,
757771
rpc_timeout=rpc_timeout,
758772
request_eager_start=request_eager_start,
@@ -5148,6 +5162,7 @@ class StartWorkflowInput:
51485162
headers: Mapping[str, temporalio.api.common.v1.Payload]
51495163
start_signal: Optional[str]
51505164
start_signal_args: Sequence[Any]
5165+
completion_callbacks: Sequence[temporalio.common.CompletionCallback]
51515166
static_summary: Optional[str]
51525167
static_details: Optional[str]
51535168
# Type may be absent
@@ -5770,6 +5785,11 @@ async def _build_start_workflow_execution_request(
57705785
req = temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest()
57715786
req.request_eager_execution = input.request_eager_start
57725787
await self._populate_start_workflow_execution_request(req, input)
5788+
for callback in input.completion_callbacks:
5789+
c = temporalio.api.common.v1.Callback()
5790+
c.nexus.url = callback.url
5791+
c.nexus.header.update(callback.header)
5792+
req.completion_callbacks.append(c)
57735793
return req
57745794

57755795
async def _build_signal_with_start_workflow_execution_request(

temporalio/common.py

+11
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,17 @@ def __setstate__(self, state: object) -> None:
195195
)
196196

197197

198+
@dataclass(frozen=True)
199+
class CompletionCallback:
200+
"""Callback to attach to various events in the system, e.g. workflow run completion."""
201+
202+
url: str
203+
"""Callback URL."""
204+
205+
header: Mapping[str, str]
206+
"""Header to attach to callback request."""
207+
208+
198209
# We choose to make this a list instead of an sequence so we can catch if people
199210
# are not sending lists each time but maybe accidentally sending a string (which
200211
# is a sequence)

temporalio/nexus/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)