feat: add StreamActivatedJobs
dimastbk committed Oct 8, 2024
1 parent eaa793a commit 51df5ce
Showing 6 changed files with 211 additions and 14 deletions.
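
A minimal usage sketch of the new streaming mode, based on the ZeebeWorker parameters added in this commit. The channel helper, task type, and handler below are illustrative; with stream_enabled=True the worker opens a job stream per task while still polling for jobs created before the stream existed.

import asyncio

from pyzeebe import ZeebeWorker, create_insecure_channel

channel = create_insecure_channel()  # assumes a gateway reachable on the default address
worker = ZeebeWorker(channel, stream_enabled=True)  # opt in to StreamActivatedJobs

@worker.task(task_type="example_task")  # task type name is illustrative
async def example_task(name: str) -> dict:
    return {"greeting": f"Hello {name}"}

asyncio.run(worker.work())  # starts a JobPoller and a JobStreamer for each registered task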
13 changes: 13 additions & 0 deletions pyzeebe/errors/job_errors.py
@@ -16,6 +16,19 @@ def __init__(self, task_type: str, worker: str, timeout: int, max_jobs_to_activa
super().__init__(msg)


class StreamActivateJobsRequestInvalidError(PyZeebeError):
def __init__(self, task_type: str, worker: str, timeout: int):
msg = "Failed to activate jobs. Reasons:"
if task_type == "" or task_type is None:
msg = msg + "task_type is empty, "
if worker == "" or task_type is None:
msg = msg + "worker is empty, "
if timeout < 1:
msg = msg + "job timeout is smaller than 0ms, "

super().__init__(msg)


class JobAlreadyDeactivatedError(PyZeebeError):
def __init__(self, job_key: int) -> None:
super().__init__(f"Job {job_key} was already stopped (Completed/Failed/Error)")
30 changes: 30 additions & 0 deletions pyzeebe/grpc_internals/zeebe_job_adapter.py
@@ -9,13 +9,15 @@
ActivateJobsRequest,
CompleteJobRequest,
FailJobRequest,
StreamActivatedJobsRequest,
ThrowErrorRequest,
)

from pyzeebe.errors import (
ActivateJobsRequestInvalidError,
JobAlreadyDeactivatedError,
JobNotFoundError,
StreamActivateJobsRequestInvalidError,
)
from pyzeebe.grpc_internals.grpc_utils import is_error_status
from pyzeebe.grpc_internals.zeebe_adapter_base import ZeebeAdapterBase
@@ -63,6 +65,34 @@ async def activate_jobs(
raise ActivateJobsRequestInvalidError(task_type, worker, timeout, max_jobs_to_activate) from grpc_error
await self._handle_grpc_error(grpc_error)

async def stream_activate_jobs(
self,
task_type: str,
worker: str,
timeout: int,
variables_to_fetch: Iterable[str],
request_timeout: int,
tenant_ids: Optional[Iterable[str]] = None,
) -> AsyncGenerator[Job, None]:
try:
async for raw_job in self._gateway_stub.StreamActivatedJobs(
StreamActivatedJobsRequest(
type=task_type,
worker=worker,
timeout=timeout,
fetchVariable=variables_to_fetch,
tenantIds=tenant_ids or [],
),
timeout=request_timeout,
):
job = self._create_job_from_raw_job(raw_job)
logger.debug("Got job: %s from zeebe", job)
yield job
except grpc.aio.AioRpcError as grpc_error:
if is_error_status(grpc_error, grpc.StatusCode.INVALID_ARGUMENT):
raise StreamActivateJobsRequestInvalidError(task_type, worker, timeout) from grpc_error
await self._handle_grpc_error(grpc_error)

def _create_job_from_raw_job(self, response: ActivatedJob) -> Job:
return Job(
key=response.key,
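
For orientation, a hedged sketch of how the new adapter method could be consumed directly; in practice the JobStreamer introduced in job_poller.py below is the intended caller. The adapter construction and argument values are illustrative.

from pyzeebe.grpc_internals.zeebe_adapter import ZeebeAdapter

async def consume_stream(grpc_channel) -> None:
    adapter = ZeebeAdapter(grpc_channel)  # grpc_channel: an existing grpc.aio channel
    async for job in adapter.stream_activate_jobs(
        task_type="example_task",
        worker="my-worker",
        timeout=30000,  # job timeout, in milliseconds
        variables_to_fetch=[],  # an empty list streams all variables
        request_timeout=3600,  # how long the underlying gRPC call stays open
        tenant_ids=None,
    ):
        print("received job", job.key)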
91 changes: 80 additions & 11 deletions pyzeebe/worker/job_poller.py
@@ -1,6 +1,7 @@
import abc
import asyncio
import logging
from typing import List, Optional
from typing import Optional, final

from pyzeebe.errors import (
ActivateJobsRequestInvalidError,
@@ -17,17 +18,17 @@
logger = logging.getLogger(__name__)


class JobPoller:
class JobPollerABC(abc.ABC):
def __init__(
self,
zeebe_adapter: ZeebeJobAdapter,
task: Task,
queue: "asyncio.Queue[Job]",
queue: asyncio.Queue[Job],
worker_name: str,
request_timeout: int,
task_state: TaskState,
poll_retry_delay: int,
tenant_ids: Optional[List[str]],
tenant_ids: Optional[list[str]],
) -> None:
self.zeebe_adapter = zeebe_adapter
self.task = task
@@ -43,6 +44,9 @@ async def poll(self) -> None:
while self.should_poll():
await self.activate_max_jobs()

@abc.abstractmethod
async def poll_once(self) -> None: ...

async def activate_max_jobs(self) -> None:
if self.calculate_max_jobs_to_activate() > 0:
await self.poll_once()
@@ -54,6 +58,20 @@ async def activate_max_jobs(self) -> None:
)
await asyncio.sleep(self.poll_retry_delay)

def should_poll(self) -> bool:
return not self.stop_event.is_set() and (self.zeebe_adapter.connected or self.zeebe_adapter.retrying_connection)

def calculate_max_jobs_to_activate(self) -> int:
worker_max_jobs = self.task.config.max_running_jobs - self.task_state.count_active()
return min(worker_max_jobs, self.task.config.max_jobs_to_activate)

async def stop(self) -> None:
self.stop_event.set()
await self.queue.join()


@final
class JobPoller(JobPollerABC):
async def poll_once(self) -> None:
try:
jobs = self.zeebe_adapter.activate_jobs(
@@ -83,13 +101,64 @@ async def poll_once(self) -> None:
)
await asyncio.sleep(5)

def should_poll(self) -> bool:
return not self.stop_event.is_set() and (self.zeebe_adapter.connected or self.zeebe_adapter.retrying_connection)

def calculate_max_jobs_to_activate(self) -> int:
worker_max_jobs = self.task.config.max_running_jobs - self.task_state.count_active()
return min(worker_max_jobs, self.task.config.max_jobs_to_activate)
@final
class JobStreamer(JobPollerABC):
def __init__(
self,
zeebe_adapter: ZeebeJobAdapter,
task: Task,
queue: asyncio.Queue[Job],
worker_name: str,
request_timeout: int,
task_state: TaskState,
poll_retry_delay: int,
tenant_ids: Optional[list[str]],
) -> None:
super().__init__(
zeebe_adapter,
task,
queue,
worker_name,
request_timeout,
task_state,
poll_retry_delay,
tenant_ids,
)
self._create_stream()

def _create_stream(self) -> None:
self._stream = self.zeebe_adapter.stream_activate_jobs(
task_type=self.task.type,
worker=self.worker_name,
timeout=self.task.config.timeout_ms,
variables_to_fetch=self.task.config.variables_to_fetch or [],
request_timeout=self.request_timeout,
tenant_ids=self.tenant_ids,
)

async def poll_once(self) -> None:
try:
job = await self._stream.__anext__()
self.task_state.add(job)
await self.queue.put(job)
except StopAsyncIteration:
self._create_stream()
except ActivateJobsRequestInvalidError:
logger.warning("Activate job requests was invalid for task %s", self.task.type)
raise
except (
ZeebeBackPressureError,
ZeebeGatewayUnavailableError,
ZeebeInternalError,
ZeebeDeadlineExceeded,
) as error:
logger.warning(
"Failed to activate jobs from the gateway. Exception: %s. Retrying in 5 seconds...",
repr(error),
)
await asyncio.sleep(5)

async def stop(self) -> None:
self.stop_event.set()
await self.queue.join()
await self._stream.aclose()
await super().stop()
24 changes: 22 additions & 2 deletions pyzeebe/worker/worker.py
@@ -12,7 +12,7 @@
from pyzeebe.task import task_builder
from pyzeebe.task.exception_handler import ExceptionHandler
from pyzeebe.worker.job_executor import JobExecutor
from pyzeebe.worker.job_poller import JobPoller
from pyzeebe.worker.job_poller import JobPoller, JobStreamer
from pyzeebe.worker.task_router import ZeebeTaskRouter
from pyzeebe.worker.task_state import TaskState

@@ -34,6 +34,7 @@ def __init__(
poll_retry_delay: int = 5,
tenant_ids: Optional[List[str]] = None,
exception_handler: Optional[ExceptionHandler] = None,
stream_enabled: bool = False,
):
"""
Args:
@@ -47,6 +48,7 @@
watcher_max_errors_factor (int): Number of consecutive errors for a task watcher will accept before raising MaxConsecutiveTaskThreadError
poll_retry_delay (int): The number of seconds to wait before attempting to poll again when reaching max amount of running jobs
tenant_ids (List[str]): A list of tenant IDs for which to activate jobs. New in Zeebe 8.3.
stream_enabled (bool): Enables the job worker to stream jobs. It will still poll for older jobs, but streaming is favored. New in Zeebe 8.4.
"""
super().__init__(before, after, exception_handler)
self.zeebe_adapter = ZeebeAdapter(grpc_channel, max_connection_retries)
@@ -57,11 +59,13 @@ def __init__(
self.poll_retry_delay = poll_retry_delay
self.tenant_ids = tenant_ids
self._job_pollers: List[JobPoller] = []
self._job_streamers: List[JobStreamer] = []
self._job_executors: List[JobExecutor] = []
self._stop_event = anyio.Event()
self._stream_enabled = stream_enabled

def _init_tasks(self) -> None:
self._job_executors, self._job_pollers = [], []
self._job_executors, self._job_pollers, self._job_streamers = [], [], []

for task in self.tasks:
jobs_queue: "asyncio.Queue[Job]" = asyncio.Queue()
@@ -82,6 +86,19 @@ def _init_tasks(self) -> None:
self._job_pollers.append(poller)
self._job_executors.append(executor)

if self._stream_enabled:
streamer = JobStreamer(
self.zeebe_adapter,
task,
jobs_queue,
self.name,
self.request_timeout,
task_state,
self.poll_retry_delay,
self.tenant_ids,
)
self._job_streamers.append(streamer)

async def work(self) -> None:
"""
Start the worker. The worker will poll zeebe for jobs of each task in a different thread.
@@ -100,6 +117,9 @@ async def work(self) -> None:
for poller in self._job_pollers:
tg.start_soon(poller.poll)

for streamer in self._job_streamers:
tg.start_soon(streamer.poll)

for executor in self._job_executors:
tg.start_soon(executor.execute)

31 changes: 31 additions & 0 deletions tests/unit/utils/gateway_mock.py
@@ -68,6 +68,37 @@ def ActivateJobs(self, request, context):
)
yield ActivateJobsResponse(jobs=jobs)

def StreamActivatedJobs(self, request, context):
if not request.type:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
return ActivatedJob()

if request.timeout <= 0:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
return ActivatedJob()

if not request.worker:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
return ActivatedJob()

for active_job in self.active_jobs.values():
if active_job.type == request.type:
yield ActivatedJob(
key=active_job.key,
type=active_job.type,
processInstanceKey=active_job.process_instance_key,
bpmnProcessId=active_job.bpmn_process_id,
processDefinitionVersion=active_job.process_definition_version,
processDefinitionKey=active_job.process_definition_key,
elementId=active_job.element_id,
elementInstanceKey=active_job.element_instance_key,
customHeaders=json.dumps(active_job.custom_headers),
worker=active_job.worker,
retries=active_job.retries,
deadline=active_job.deadline,
variables=json.dumps(active_job.variables),
)

def CompleteJob(self, request, context):
if request.jobKey in self.active_jobs.keys():
active_job = self.active_jobs.get(request.jobKey)
36 changes: 35 additions & 1 deletion tests/unit/worker/job_poller_test.py
@@ -6,7 +6,7 @@
from pyzeebe.grpc_internals.zeebe_adapter import ZeebeAdapter
from pyzeebe.job.job import Job
from pyzeebe.task.task import Task
from pyzeebe.worker.job_poller import JobPoller
from pyzeebe.worker.job_poller import JobPoller, JobStreamer
from pyzeebe.worker.task_state import TaskState
from tests.unit.utils.gateway_mock import GatewayMock
from tests.unit.utils.random_utils import random_job
@@ -17,6 +17,13 @@ def job_poller(zeebe_adapter: ZeebeAdapter, task: Task, queue: asyncio.Queue, ta
return JobPoller(zeebe_adapter, task, queue, "test_worker", 100, task_state, 0, None)


@pytest.fixture
def job_stream_poller(
zeebe_adapter: ZeebeAdapter, task: Task, queue: asyncio.Queue, task_state: TaskState
) -> JobStreamer:
return JobStreamer(zeebe_adapter, task, queue, "test_worker", 100, task_state, 0, [])


@pytest.mark.asyncio
class TestPollOnce:
async def test_one_job_is_polled(
@@ -44,6 +51,33 @@ async def test_job_is_added_to_task_state(
assert job_poller.task_state.count_active() == 1


@pytest.mark.asyncio
class TestStreamPollOnce:
async def test_one_job_is_polled(
self, job_stream_poller: JobStreamer, queue: asyncio.Queue, job_from_task: Job, grpc_servicer: GatewayMock
):
grpc_servicer.active_jobs[job_from_task.key] = job_from_task

await job_stream_poller.poll_once()

job: Job = queue.get_nowait()
assert job.key == job_from_task.key

async def test_no_job_is_polled(self, job_stream_poller: JobStreamer, queue: asyncio.Queue):
await job_stream_poller.poll_once()

assert queue.empty()

async def test_job_is_added_to_task_state(
self, job_stream_poller: JobStreamer, job_from_task: Job, grpc_servicer: GatewayMock
):
grpc_servicer.active_jobs[job_from_task.key] = job_from_task

await job_stream_poller.poll_once()

assert job_stream_poller.task_state.count_active() == 1


class TestShouldPoll:
def test_should_poll_returns_expected_result_when_disconnected(self, job_poller: JobPoller):
job_poller.zeebe_adapter.connected = False
