Skip to content

Commit

Permalink
Add workflow.instance() API for obtaining current workflow instance (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dandavison authored Jan 24, 2025
1 parent 150878f commit 044b1de
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 0 deletions.
3 changes: 3 additions & 0 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,9 @@ def workflow_get_update_validator(self, name: Optional[str]) -> Optional[Callabl
def workflow_info(self) -> temporalio.workflow.Info:
return self._outbound.info()

def workflow_instance(self) -> Any:
return self._object

def workflow_is_continue_as_new_suggested(self) -> bool:
return self._continue_as_new_suggested

Expand Down
12 changes: 12 additions & 0 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,9 @@ def workflow_get_update_validator(
@abstractmethod
def workflow_info(self) -> Info: ...

@abstractmethod
def workflow_instance(self) -> Any: ...

@abstractmethod
def workflow_is_continue_as_new_suggested(self) -> bool: ...

Expand Down Expand Up @@ -818,6 +821,15 @@ def info() -> Info:
return _Runtime.current().workflow_info()


def instance() -> Any:
"""Current workflow's instance.
Returns:
The currently running workflow instance.
"""
return _Runtime.current().workflow_instance()


def memo() -> Mapping[str, Any]:
"""Current workflow's memo values, converted without type hints.
Expand Down
41 changes: 41 additions & 0 deletions tests/worker/test_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,44 @@ def pop_trace(name: str, filter: Optional[Callable[[Any], bool]] = None) -> Any:

# Confirm no unexpected traces
assert not interceptor_traces


class WorkflowInstanceAccessInterceptor(Interceptor):
def workflow_interceptor_class(
self, input: WorkflowInterceptorClassInput
) -> Optional[Type[WorkflowInboundInterceptor]]:
return WorkflowInstanceAccessInboundInterceptor


class WorkflowInstanceAccessInboundInterceptor(WorkflowInboundInterceptor):
async def execute_workflow(self, input: ExecuteWorkflowInput) -> int:
# Return integer difference between ids of workflow instance obtained from workflow run method and
# from workflow.instance(). They should be the same, so the difference should be 0.
from_workflow_instance_api = workflow.instance()
assert from_workflow_instance_api is not None
id_from_workflow_instance_api = id(from_workflow_instance_api)
id_from_workflow_run_method = await super().execute_workflow(input)
return id_from_workflow_run_method - id_from_workflow_instance_api


@workflow.defn
class WorkflowInstanceAccessWorkflow:
@workflow.run
async def run(self) -> int:
return id(self)


async def test_workflow_instance_access_from_interceptor(client: Client):
task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client,
task_queue=task_queue,
workflows=[WorkflowInstanceAccessWorkflow],
interceptors=[WorkflowInstanceAccessInterceptor()],
):
difference = await client.execute_workflow(
WorkflowInstanceAccessWorkflow.run,
id=f"workflow_{uuid.uuid4()}",
task_queue=task_queue,
)
assert difference == 0

0 comments on commit 044b1de

Please sign in to comment.