Skip to content

Commit

Permalink
Add tests of asyncio.Lock and asyncio.Semaphore usage
Browse files Browse the repository at this point in the history
  • Loading branch information
dandavison committed Jun 29, 2024
1 parent 7ac4445 commit a4d39c9
Show file tree
Hide file tree
Showing 2 changed files with 322 additions and 2 deletions.
26 changes: 24 additions & 2 deletions tests/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from contextlib import closing
from datetime import timedelta
from typing import Awaitable, Callable, Optional, Sequence, Type, TypeVar
from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar

from temporalio.api.common.v1 import WorkflowExecution
from temporalio.api.enums.v1 import IndexedValueType
Expand All @@ -14,11 +14,12 @@
)
from temporalio.api.update.v1 import UpdateRef
from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest
from temporalio.client import BuildIdOpAddNewDefault, Client
from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle
from temporalio.common import SearchAttributeKey
from temporalio.service import RPCError, RPCStatusCode
from temporalio.worker import Worker, WorkflowRunner
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
from temporalio.workflow import UpdateMethodMultiParam


def new_worker(
Expand Down Expand Up @@ -128,3 +129,24 @@ async def workflow_update_exists(
if err.status != RPCStatusCode.NOT_FOUND:
raise
return False


# TODO: type update return value
async def admitted_update_task(
client: Client,
handle: WorkflowHandle,
update_method: UpdateMethodMultiParam,
id: str,
**kwargs,
) -> asyncio.Task:
"""
Return an asyncio.Task for an update after waiting for it to be admitted.
"""
update_task = asyncio.create_task(
handle.execute_update(update_method, id=id, **kwargs)
)
await assert_eq_eventually(
True,
lambda: workflow_update_exists(client, handle.id, id),
)
return update_task
298 changes: 298 additions & 0 deletions tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
WorkflowRunner,
)
from tests.helpers import (
admitted_update_task,
assert_eq_eventually,
ensure_search_attributes_present,
find_free_port,
Expand Down Expand Up @@ -5510,3 +5511,300 @@ def _unfinished_handler_warning_cls(self) -> Type:
"update": workflow.UnfinishedUpdateHandlersWarning,
"signal": workflow.UnfinishedSignalHandlersWarning,
}[self.handler_type]


# The following Lock and Semaphore tests test that asyncio concurrency primitives work as expected
# in workflow code. There is nothing Temporal-specific about the way that asyncio.Lock and
# asyncio.Semaphore are used here.


@activity.defn
async def noop_activity_for_lock_or_semaphore_tests() -> None:
return None


@dataclass
class LockOrSemaphoreWorkflowConcurrencySummary:
ever_in_critical_section: int
peak_in_critical_section: int


@dataclass
class UseLockOrSemaphoreWorkflowParameters:
n_coroutines: int = 0
semaphore_initial_value: Optional[int] = None
sleep: Optional[float] = None
timeout: Optional[float] = None


@workflow.defn
class CoroutinesUseLockWorkflow:
def __init__(self) -> None:
self.params: UseLockOrSemaphoreWorkflowParameters
self.lock_or_semaphore: Union[asyncio.Lock, asyncio.Semaphore]
self._currently_in_critical_section: set[str] = set()
self._ever_in_critical_section: set[str] = set()
self._peak_in_critical_section = 0

def init(self, params: UseLockOrSemaphoreWorkflowParameters):
self.params = params
if self.params.semaphore_initial_value is not None:
self.lock_or_semaphore = asyncio.Semaphore(
self.params.semaphore_initial_value
)
else:
self.lock_or_semaphore = asyncio.Lock()

@workflow.run
async def run(
self,
params: UseLockOrSemaphoreWorkflowParameters,
) -> LockOrSemaphoreWorkflowConcurrencySummary:
# TODO: Use workflow init method when it exists.
self.init(params)
await asyncio.gather(
*(self.coroutine(f"{i}") for i in range(self.params.n_coroutines))
)
assert not any(self._currently_in_critical_section)
return LockOrSemaphoreWorkflowConcurrencySummary(
len(self._ever_in_critical_section),
self._peak_in_critical_section,
)

async def coroutine(self, id: str):
if self.params.timeout:
try:
await asyncio.wait_for(
self.lock_or_semaphore.acquire(), self.params.timeout
)
except asyncio.TimeoutError:
return
else:
await self.lock_or_semaphore.acquire()
self._enters_critical_section(id)
try:
if self.params.sleep:
await asyncio.sleep(self.params.sleep)
else:
await workflow.execute_activity(
noop_activity_for_lock_or_semaphore_tests,
schedule_to_close_timeout=timedelta(seconds=30),
)
finally:
self.lock_or_semaphore.release()
self._exits_critical_section(id)

def _enters_critical_section(self, id: str) -> None:
self._currently_in_critical_section.add(id)
self._ever_in_critical_section.add(id)
self._peak_in_critical_section = max(
self._peak_in_critical_section,
len(self._currently_in_critical_section),
)

def _exits_critical_section(self, id: str) -> None:
self._currently_in_critical_section.remove(id)


@workflow.defn
class HandlerCoroutinesUseLockWorkflow(CoroutinesUseLockWorkflow):
def __init__(self) -> None:
super().__init__()
self.workflow_may_exit = False

@workflow.run
async def run(
self,
) -> LockOrSemaphoreWorkflowConcurrencySummary:
await workflow.wait_condition(lambda: self.workflow_may_exit)
return LockOrSemaphoreWorkflowConcurrencySummary(
len(self._ever_in_critical_section),
self._peak_in_critical_section,
)

@workflow.update
async def my_update(self, params: UseLockOrSemaphoreWorkflowParameters):
# TODO: Use workflow init method when it exists.
if not hasattr(self, "params"):
self.init(params)
assert (update_info := workflow.current_update_info())
await self.coroutine(update_info.id)

@workflow.signal
async def finish(self):
self.workflow_may_exit = True


async def _do_workflow_coroutines_lock_or_semaphore_test(
client: Client,
params: UseLockOrSemaphoreWorkflowParameters,
expectation: LockOrSemaphoreWorkflowConcurrencySummary,
):
async with new_worker(
client,
CoroutinesUseLockWorkflow,
activities=[noop_activity_for_lock_or_semaphore_tests],
) as worker:
summary = await client.execute_workflow(
CoroutinesUseLockWorkflow.run,
arg=params,
id=str(uuid.uuid4()),
task_queue=worker.task_queue,
)
assert summary == expectation


async def _do_update_handler_lock_or_semaphore_test(
client: Client,
env: WorkflowEnvironment,
params: UseLockOrSemaphoreWorkflowParameters,
n_updates: int,
expectation: LockOrSemaphoreWorkflowConcurrencySummary,
):
if env.supports_time_skipping:
pytest.skip(
"Java test server: https://github.com/temporalio/sdk-java/issues/1903"
)

task_queue = "tq"
handle = await client.start_workflow(
HandlerCoroutinesUseLockWorkflow.run,
id=f"wf-{str(uuid.uuid4())}",
task_queue=task_queue,
)
# Create updates in Admitted state, before the worker starts polling.
admitted_updates = [
await admitted_update_task(
client,
handle,
HandlerCoroutinesUseLockWorkflow.my_update,
arg=params,
id=f"update-{i}",
)
for i in range(n_updates)
]
async with new_worker(
client,
HandlerCoroutinesUseLockWorkflow,
activities=[noop_activity_for_lock_or_semaphore_tests],
task_queue=task_queue,
):
for update_task in admitted_updates:
await update_task
await handle.signal(HandlerCoroutinesUseLockWorkflow.finish)
summary = await handle.result()
assert summary == expectation


async def test_workflow_coroutines_can_use_lock(client: Client):
await _do_workflow_coroutines_lock_or_semaphore_test(
client,
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5),
# The lock limits concurrency to 1
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
ever_in_critical_section=5, peak_in_critical_section=1
),
)


async def test_update_handler_can_use_lock_to_serialize_handler_executions(
client: Client, env: WorkflowEnvironment
):
await _do_update_handler_lock_or_semaphore_test(
client,
env,
UseLockOrSemaphoreWorkflowParameters(),
n_updates=5,
# The lock limits concurrency to 1
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
ever_in_critical_section=5, peak_in_critical_section=1
),
)


async def test_workflow_coroutines_lock_acquisition_respects_timeout(client: Client):
await _do_workflow_coroutines_lock_or_semaphore_test(
client,
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, sleep=0.5, timeout=0.1),
# Second and subsequent coroutines fail to acquire the lock due to the timeout.
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
ever_in_critical_section=1, peak_in_critical_section=1
),
)


async def test_update_handler_lock_acquisition_respects_timeout(
client: Client, env: WorkflowEnvironment
):
await _do_update_handler_lock_or_semaphore_test(
client,
env,
# Second and subsequent handler executions fail to acquire the lock due to the timeout.
UseLockOrSemaphoreWorkflowParameters(sleep=0.5, timeout=0.1),
n_updates=5,
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
ever_in_critical_section=1, peak_in_critical_section=1
),
)


async def test_workflow_coroutines_can_use_semaphore(client: Client):
await _do_workflow_coroutines_lock_or_semaphore_test(
client,
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, semaphore_initial_value=3),
# The semaphore limits concurrency to 3
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
ever_in_critical_section=5, peak_in_critical_section=3
),
)


async def test_update_handler_can_use_semaphore_to_control_handler_execution_concurrency(
client: Client, env: WorkflowEnvironment
):
await _do_update_handler_lock_or_semaphore_test(
client,
env,
# The semaphore limits concurrency to 3
UseLockOrSemaphoreWorkflowParameters(semaphore_initial_value=3),
n_updates=5,
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
ever_in_critical_section=5, peak_in_critical_section=3
),
)


async def test_workflow_coroutine_semaphore_acquisition_respects_timeout(
client: Client,
):
await _do_workflow_coroutines_lock_or_semaphore_test(
client,
UseLockOrSemaphoreWorkflowParameters(
n_coroutines=5, semaphore_initial_value=3, sleep=0.5, timeout=0.1
),
# Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore
# slot fail.
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
ever_in_critical_section=3, peak_in_critical_section=3
),
)


async def test_update_handler_semaphore_acquisition_respects_timeout(
client: Client, env: WorkflowEnvironment
):
await _do_update_handler_lock_or_semaphore_test(
client,
env,
# Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore
# slot fail.
UseLockOrSemaphoreWorkflowParameters(
semaphore_initial_value=3,
sleep=0.5,
timeout=0.1,
),
n_updates=5,
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
ever_in_critical_section=3, peak_in_critical_section=3
),
)

0 comments on commit a4d39c9

Please sign in to comment.