Skip to content

Commit 7a121bb

Browse files
committed
gemini
1 parent 6f1eefc commit 7a121bb

File tree

3 files changed

+50
-24
lines changed

3 files changed

+50
-24
lines changed

temporalio/worker/_interceptor.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import asyncio
65
import concurrent.futures
76
from dataclasses import dataclass
87
from datetime import timedelta
@@ -425,6 +424,6 @@ def start_local_activity(
425424

426425
async def start_nexus_operation(
427426
self, input: StartNexusOperationInput
428-
) -> asyncio.Task:
427+
) -> temporalio.workflow.NexusOperationHandle[Any]:
429428
"""Called for every :py:func:`temporalio.workflow.start_nexus_operation` call."""
430429
return await self.next.start_nexus_operation(input)

temporalio/worker/_workflow_instance.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1433,7 +1433,7 @@ async def workflow_start_nexus_operation(
14331433
input: Any,
14341434
schedule_to_close_timeout: Optional[timedelta] = None,
14351435
headers: Optional[Mapping[str, str]] = None,
1436-
) -> asyncio.Task:
1436+
) -> temporalio.workflow.NexusOperationHandle[Any]:
14371437
return await self._outbound.start_nexus_operation(
14381438
StartNexusOperationInput(
14391439
endpoint=endpoint,
@@ -2860,13 +2860,17 @@ async def cancel(self) -> None:
28602860
await self._instance._cancel_external_workflow(command)
28612861

28622862

2863-
class _NexusOperationHandle:
2863+
I = TypeVar("I")
2864+
O = TypeVar("O")
2865+
2866+
2867+
class _NexusOperationHandle(temporalio.workflow.NexusOperationHandle[O]):
28642868
def __init__(
28652869
self,
28662870
instance: _WorkflowInstanceImpl,
28672871
seq: int,
28682872
input: StartNexusOperationInput,
2869-
fn: Coroutine[Any, Any, Any],
2873+
fn: Coroutine[Any, Any, O],
28702874
):
28712875
self._instance = instance
28722876
self._seq = seq
@@ -2876,6 +2880,9 @@ def __init__(
28762880
self._result_fut: asyncio.Future[Any] = instance.create_future()
28772881
self._operation_id: Optional[str] = None
28782882

2883+
async def result(self) -> O:
2884+
return await self._result_fut
2885+
28792886
def _resolve_start_success(self, operation_id: str) -> None:
28802887
span = xray.get_current_span()
28812888
span.add_event("_resolve_start_success", {"operation_id": operation_id})

temporalio/workflow.py

+39-19
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
import temporalio.converter
5757
import temporalio.exceptions
5858
import temporalio.workflow
59-
from temporalio import workflow
6059

6160
from .types import (
6261
AnyType,
@@ -751,15 +750,15 @@ def workflow_start_local_activity(
751750
) -> ActivityHandle[Any]: ...
752751

753752
@abstractmethod
754-
def workflow_start_nexus_operation(
753+
async def workflow_start_nexus_operation(
755754
self,
756755
endpoint: str,
757756
service: str,
758757
operation: str,
759758
input: Any,
760759
schedule_to_close_timeout: Optional[timedelta] = None,
761760
headers: Optional[Mapping[str, str]] = None,
762-
) -> asyncio.Task: ...
761+
) -> NexusOperationHandle[Any]: ...
763762

764763
@abstractmethod
765764
def workflow_time_ns(self) -> int: ...
@@ -4250,21 +4249,43 @@ async def execute_child_workflow(
42504249
O = TypeVar("O")
42514250

42524251

4252+
class NexusOperationHandle(Generic[O]):
4253+
async def result(self) -> O:
4254+
raise NotImplementedError
4255+
4256+
42534257
async def start_nexus_operation(
42544258
endpoint: str,
42554259
service: str,
42564260
operation: str,
42574261
input: Any,
4262+
*,
42584263
schedule_to_close_timeout: Optional[timedelta] = None,
42594264
headers: Optional[Mapping[str, str]] = None,
4260-
) -> asyncio.Task:
4265+
) -> NexusOperationHandle[Any]:
4266+
"""Start a Nexus operation and return its handle.
4267+
4268+
Args:
4269+
endpoint: The Nexus endpoint.
4270+
service: The Nexus service.
4271+
operation: The Nexus operation.
4272+
input: The Nexus operation input.
4273+
schedule_to_close_timeout: Timeout for the entire operation attempt.
4274+
headers: Headers to send with the Nexus HTTP request.
4275+
4276+
Returns:
4277+
A handle to the Nexus operation. The result can be obtained as
4278+
```python
4279+
await handle.result()
4280+
```
4281+
"""
42614282
return await _Runtime.current().workflow_start_nexus_operation(
4262-
endpoint,
4263-
service,
4264-
operation,
4265-
input,
4266-
schedule_to_close_timeout,
4267-
headers,
4283+
endpoint=endpoint,
4284+
service=service,
4285+
operation=operation,
4286+
input=input,
4287+
schedule_to_close_timeout=schedule_to_close_timeout,
4288+
headers=headers,
42684289
)
42694290

42704291

@@ -5029,7 +5050,7 @@ async def method(
50295050
schedule_to_close_timeout: Optional[timedelta] = None,
50305051
headers: Optional[Mapping[str, str]] = None,
50315052
):
5032-
return await workflow.start_nexus_operation(
5053+
return await temporalio.workflow.start_nexus_operation(
50335054
endpoint=endpoint,
50345055
service=interface.__name__,
50355056
operation=name,
@@ -5042,15 +5063,15 @@ async def method(
50425063

50435064
methods[name] = method
50445065

5045-
class StartOperation:
5066+
class _ServiceClient:
50465067
async def start_operation(
50475068
self,
50485069
operation: Callable[[Any, I], Awaitable[O]],
50495070
input: I,
50505071
schedule_to_close_timeout: Optional[timedelta] = None,
50515072
headers: Optional[Mapping[str, str]] = None,
5052-
) -> asyncio.Task[O]:
5053-
return await workflow.start_nexus_operation(
5073+
) -> NexusOperationHandle[O]:
5074+
return await temporalio.workflow.start_nexus_operation(
50545075
endpoint=endpoint,
50555076
service=interface.__name__,
50565077
operation=operation.__name__,
@@ -5061,7 +5082,7 @@ async def start_operation(
50615082
headers=headers or {},
50625083
)
50635084

5064-
cls = type(f"{interface.__name__}Client", (StartOperation,), methods)
5085+
cls = type(f"{interface.__name__}Client", (_ServiceClient,), methods)
50655086
return cls() # type: ignore
50665087

50675088

@@ -5082,8 +5103,8 @@ async def start_operation(
50825103
input: I,
50835104
schedule_to_close_timeout: Optional[timedelta] = None,
50845105
headers: Optional[Mapping[str, str]] = None,
5085-
) -> asyncio.Task[O]:
5086-
return await workflow.start_nexus_operation(
5106+
) -> NexusOperationHandle[O]:
5107+
return await temporalio.workflow.start_nexus_operation(
50875108
endpoint=self._endpoint,
50885109
service=self._service_name,
50895110
operation=operation.name,
@@ -5104,5 +5125,4 @@ async def execute_operation(
51045125
handle = await self.start_operation(
51055126
operation, input, schedule_to_close_timeout, headers
51065127
)
5107-
# TODO(dan): handle.result()
5108-
return await handle
5128+
return await handle.result()

0 commit comments

Comments
 (0)