diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index d8966a0d..728f1ed8 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -4,7 +4,7 @@ import uuid from contextlib import closing from datetime import timedelta -from typing import Awaitable, Callable, Optional, Sequence, Type, TypeVar +from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar from temporalio.api.common.v1 import WorkflowExecution from temporalio.api.enums.v1 import IndexedValueType @@ -14,11 +14,12 @@ ) from temporalio.api.update.v1 import UpdateRef from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest -from temporalio.client import BuildIdOpAddNewDefault, Client +from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle from temporalio.common import SearchAttributeKey from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import Worker, WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner +from temporalio.workflow import UpdateMethodMultiParam def new_worker( @@ -128,3 +129,24 @@ async def workflow_update_exists( if err.status != RPCStatusCode.NOT_FOUND: raise return False + + +# TODO: type update return value +async def admitted_update_task( + client: Client, + handle: WorkflowHandle, + update_method: UpdateMethodMultiParam, + id: str, + **kwargs, +) -> asyncio.Task: + """ + Return an asyncio.Task for an update after waiting for it to be admitted. + """ + update_task = asyncio.create_task( + handle.execute_update(update_method, id=id, **kwargs) + ) + await assert_eq_eventually( + True, + lambda: workflow_update_exists(client, handle.id, id), + ) + return update_task diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index fcd957ae..4b4cbb01 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -105,6 +105,7 @@ WorkflowRunner, ) from tests.helpers import ( + admitted_update_task, assert_eq_eventually, ensure_search_attributes_present, find_free_port, @@ -5505,3 +5506,300 @@ def _unfinished_handler_warning_cls(self) -> Type: "update": workflow.UnfinishedUpdateHandlersWarning, "signal": workflow.UnfinishedSignalHandlersWarning, }[self.handler_type] + + +# The following Lock and Semaphore tests test that asyncio concurrency primitives work as expected +# in workflow code. There is nothing Temporal-specific about the way that asyncio.Lock and +# asyncio.Semaphore are used here. + + +@activity.defn +async def noop_activity_for_lock_or_semaphore_tests() -> None: + return None + + +@dataclass +class LockOrSemaphoreWorkflowConcurrencySummary: + ever_in_critical_section: int + peak_in_critical_section: int + + +@dataclass +class UseLockOrSemaphoreWorkflowParameters: + n_coroutines: int = 0 + semaphore_initial_value: Optional[int] = None + sleep: Optional[float] = None + timeout: Optional[float] = None + + +@workflow.defn +class CoroutinesUseLockWorkflow: + def __init__(self) -> None: + self.params: UseLockOrSemaphoreWorkflowParameters + self.lock_or_semaphore: Union[asyncio.Lock, asyncio.Semaphore] + self._currently_in_critical_section: set[str] = set() + self._ever_in_critical_section: set[str] = set() + self._peak_in_critical_section = 0 + + def init(self, params: UseLockOrSemaphoreWorkflowParameters): + self.params = params + if self.params.semaphore_initial_value is not None: + self.lock_or_semaphore = asyncio.Semaphore( + self.params.semaphore_initial_value + ) + else: + self.lock_or_semaphore = asyncio.Lock() + + @workflow.run + async def run( + self, + params: UseLockOrSemaphoreWorkflowParameters, + ) -> LockOrSemaphoreWorkflowConcurrencySummary: + # TODO: Use workflow init method when it exists. + self.init(params) + await asyncio.gather( + *(self.coroutine(f"{i}") for i in range(self.params.n_coroutines)) + ) + assert not any(self._currently_in_critical_section) + return LockOrSemaphoreWorkflowConcurrencySummary( + len(self._ever_in_critical_section), + self._peak_in_critical_section, + ) + + async def coroutine(self, id: str): + if self.params.timeout: + try: + await asyncio.wait_for( + self.lock_or_semaphore.acquire(), self.params.timeout + ) + except asyncio.TimeoutError: + return + else: + await self.lock_or_semaphore.acquire() + self._enters_critical_section(id) + try: + if self.params.sleep: + await asyncio.sleep(self.params.sleep) + else: + await workflow.execute_activity( + noop_activity_for_lock_or_semaphore_tests, + schedule_to_close_timeout=timedelta(seconds=30), + ) + finally: + self.lock_or_semaphore.release() + self._exits_critical_section(id) + + def _enters_critical_section(self, id: str) -> None: + self._currently_in_critical_section.add(id) + self._ever_in_critical_section.add(id) + self._peak_in_critical_section = max( + self._peak_in_critical_section, + len(self._currently_in_critical_section), + ) + + def _exits_critical_section(self, id: str) -> None: + self._currently_in_critical_section.remove(id) + + +@workflow.defn +class HandlerCoroutinesUseLockWorkflow(CoroutinesUseLockWorkflow): + def __init__(self) -> None: + super().__init__() + self.workflow_may_exit = False + + @workflow.run + async def run( + self, + ) -> LockOrSemaphoreWorkflowConcurrencySummary: + await workflow.wait_condition(lambda: self.workflow_may_exit) + return LockOrSemaphoreWorkflowConcurrencySummary( + len(self._ever_in_critical_section), + self._peak_in_critical_section, + ) + + @workflow.update + async def my_update(self, params: UseLockOrSemaphoreWorkflowParameters): + # TODO: Use workflow init method when it exists. + if not hasattr(self, "params"): + self.init(params) + assert (update_info := workflow.current_update_info()) + await self.coroutine(update_info.id) + + @workflow.signal + async def finish(self): + self.workflow_may_exit = True + + +async def _do_workflow_coroutines_lock_or_semaphore_test( + client: Client, + params: UseLockOrSemaphoreWorkflowParameters, + expectation: LockOrSemaphoreWorkflowConcurrencySummary, +): + async with new_worker( + client, + CoroutinesUseLockWorkflow, + activities=[noop_activity_for_lock_or_semaphore_tests], + ) as worker: + summary = await client.execute_workflow( + CoroutinesUseLockWorkflow.run, + arg=params, + id=str(uuid.uuid4()), + task_queue=worker.task_queue, + ) + assert summary == expectation + + +async def _do_update_handler_lock_or_semaphore_test( + client: Client, + env: WorkflowEnvironment, + params: UseLockOrSemaphoreWorkflowParameters, + n_updates: int, + expectation: LockOrSemaphoreWorkflowConcurrencySummary, +): + if env.supports_time_skipping: + pytest.skip( + "Java test server: https://github.com/temporalio/sdk-java/issues/1903" + ) + + task_queue = "tq" + handle = await client.start_workflow( + HandlerCoroutinesUseLockWorkflow.run, + id=f"wf-{str(uuid.uuid4())}", + task_queue=task_queue, + ) + # Create updates in Admitted state, before the worker starts polling. + admitted_updates = [ + await admitted_update_task( + client, + handle, + HandlerCoroutinesUseLockWorkflow.my_update, + arg=params, + id=f"update-{i}", + ) + for i in range(n_updates) + ] + async with new_worker( + client, + HandlerCoroutinesUseLockWorkflow, + activities=[noop_activity_for_lock_or_semaphore_tests], + task_queue=task_queue, + ): + for update_task in admitted_updates: + await update_task + await handle.signal(HandlerCoroutinesUseLockWorkflow.finish) + summary = await handle.result() + assert summary == expectation + + +async def test_workflow_coroutines_can_use_lock(client: Client): + await _do_workflow_coroutines_lock_or_semaphore_test( + client, + UseLockOrSemaphoreWorkflowParameters(n_coroutines=5), + # The lock limits concurrency to 1 + expectation=LockOrSemaphoreWorkflowConcurrencySummary( + ever_in_critical_section=5, peak_in_critical_section=1 + ), + ) + + +async def test_update_handler_can_use_lock_to_serialize_handler_executions( + client: Client, env: WorkflowEnvironment +): + await _do_update_handler_lock_or_semaphore_test( + client, + env, + UseLockOrSemaphoreWorkflowParameters(), + n_updates=5, + # The lock limits concurrency to 1 + expectation=LockOrSemaphoreWorkflowConcurrencySummary( + ever_in_critical_section=5, peak_in_critical_section=1 + ), + ) + + +async def test_workflow_coroutines_lock_acquisition_respects_timeout(client: Client): + await _do_workflow_coroutines_lock_or_semaphore_test( + client, + UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, sleep=0.5, timeout=0.1), + # Second and subsequent coroutines fail to acquire the lock due to the timeout. + expectation=LockOrSemaphoreWorkflowConcurrencySummary( + ever_in_critical_section=1, peak_in_critical_section=1 + ), + ) + + +async def test_update_handler_lock_acquisition_respects_timeout( + client: Client, env: WorkflowEnvironment +): + await _do_update_handler_lock_or_semaphore_test( + client, + env, + # Second and subsequent handler executions fail to acquire the lock due to the timeout. + UseLockOrSemaphoreWorkflowParameters(sleep=0.5, timeout=0.1), + n_updates=5, + expectation=LockOrSemaphoreWorkflowConcurrencySummary( + ever_in_critical_section=1, peak_in_critical_section=1 + ), + ) + + +async def test_workflow_coroutines_can_use_semaphore(client: Client): + await _do_workflow_coroutines_lock_or_semaphore_test( + client, + UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, semaphore_initial_value=3), + # The semaphore limits concurrency to 3 + expectation=LockOrSemaphoreWorkflowConcurrencySummary( + ever_in_critical_section=5, peak_in_critical_section=3 + ), + ) + + +async def test_update_handler_can_use_semaphore_to_control_handler_execution_concurrency( + client: Client, env: WorkflowEnvironment +): + await _do_update_handler_lock_or_semaphore_test( + client, + env, + # The semaphore limits concurrency to 3 + UseLockOrSemaphoreWorkflowParameters(semaphore_initial_value=3), + n_updates=5, + expectation=LockOrSemaphoreWorkflowConcurrencySummary( + ever_in_critical_section=5, peak_in_critical_section=3 + ), + ) + + +async def test_workflow_coroutine_semaphore_acquisition_respects_timeout( + client: Client, +): + await _do_workflow_coroutines_lock_or_semaphore_test( + client, + UseLockOrSemaphoreWorkflowParameters( + n_coroutines=5, semaphore_initial_value=3, sleep=0.5, timeout=0.1 + ), + # Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore + # slot fail. + expectation=LockOrSemaphoreWorkflowConcurrencySummary( + ever_in_critical_section=3, peak_in_critical_section=3 + ), + ) + + +async def test_update_handler_semaphore_acquisition_respects_timeout( + client: Client, env: WorkflowEnvironment +): + await _do_update_handler_lock_or_semaphore_test( + client, + env, + # Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore + # slot fail. + UseLockOrSemaphoreWorkflowParameters( + semaphore_initial_value=3, + sleep=0.5, + timeout=0.1, + ), + n_updates=5, + expectation=LockOrSemaphoreWorkflowConcurrencySummary( + ever_in_critical_section=3, peak_in_critical_section=3 + ), + )