Skip to content

Commit

Permalink
feat(taskworker): Add namespace parameter to taskworker client (#81860)
Browse files Browse the repository at this point in the history
Now that taskbroker gRPC supports getting tasks by namespace, this PR
enables the taskworker to use that parameter if provided.


Depends on: getsentry/taskbroker#74
  • Loading branch information
enochtangg authored Dec 10, 2024
1 parent 3997d3f commit e8f21d4
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 23 deletions.
2 changes: 1 addition & 1 deletion requirements-dev-frozen.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ sentry-forked-django-stubs==5.1.1.post1
sentry-forked-djangorestframework-stubs==3.15.1.post2
sentry-kafka-schemas==0.1.122
sentry-ophio==1.0.0
sentry-protos==0.1.37
sentry-protos==0.1.39
sentry-redis-tools==0.1.7
sentry-relay==0.9.3
sentry-sdk==2.19.2
Expand Down
2 changes: 1 addition & 1 deletion requirements-frozen.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ s3transfer==0.10.0
sentry-arroyo==2.18.2
sentry-kafka-schemas==0.1.122
sentry-ophio==1.0.0
sentry-protos==0.1.37
sentry-protos==0.1.39
sentry-redis-tools==0.1.7
sentry-relay==0.9.3
sentry-sdk==2.19.2
Expand Down
9 changes: 7 additions & 2 deletions src/sentry/runner/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,18 @@ def worker(ignore_unknown_queues: bool, **options: Any) -> None:
@click.option(
"--max-task-count", help="Number of tasks this worker should run before exiting", default=10000
)
@click.option(
"--namespace", help="The dedicated task namespace that taskworker operates on", default=None
)
@log_options()
@configuration
def taskworker(rpc_host: str, max_task_count: int, **options: Any) -> None:
def taskworker(rpc_host: str, max_task_count: int, namespace: str | None, **options: Any) -> None:
from sentry.taskworker.worker import TaskWorker

with managed_bgtasks(role="taskworker"):
worker = TaskWorker(rpc_host=rpc_host, max_task_count=max_task_count, **options)
worker = TaskWorker(
rpc_host=rpc_host, max_task_count=max_task_count, namespace=namespace, **options
)
exitcode = worker.start()
raise SystemExit(exitcode)

Expand Down
14 changes: 8 additions & 6 deletions src/sentry/taskworker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import grpc
from sentry_protos.sentry.v1.taskworker_pb2 import (
FetchNextTask,
GetTaskRequest,
SetTaskStatusRequest,
TaskActivation,
Expand All @@ -24,13 +25,14 @@ def __init__(self, host: str) -> None:
self._channel = grpc.insecure_channel(self._host)
self._stub = ConsumerServiceStub(self._channel)

def get_task(self) -> TaskActivation | None:
def get_task(self, namespace: str | None = None) -> TaskActivation | None:
"""
Fetch a pending task
Fetch a pending task.
Will return None when there are no tasks to fetch
If a namespace is provided, only tasks for that namespace will be fetched.
This will return None if there are no tasks to fetch.
"""
request = GetTaskRequest()
request = GetTaskRequest(namespace=namespace)
try:
response = self._stub.GetTask(request)
except grpc.RpcError as err:
Expand All @@ -42,7 +44,7 @@ def get_task(self) -> TaskActivation | None:
return None

def update_task(
self, task_id: str, status: TaskActivationStatus.ValueType, fetch_next: bool = True
self, task_id: str, status: TaskActivationStatus.ValueType, fetch_next_task: FetchNextTask
) -> TaskActivation | None:
"""
Update the status for a given task activation.
Expand All @@ -52,7 +54,7 @@ def update_task(
request = SetTaskStatusRequest(
id=task_id,
status=status,
fetch_next=fetch_next,
fetch_next_task=fetch_next_task,
)
try:
response = self._stub.SetTaskStatus(request)
Expand Down
12 changes: 10 additions & 2 deletions src/sentry/taskworker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TASK_ACTIVATION_STATUS_COMPLETE,
TASK_ACTIVATION_STATUS_FAILURE,
TASK_ACTIVATION_STATUS_RETRY,
FetchNextTask,
TaskActivation,
)

Expand Down Expand Up @@ -67,12 +68,17 @@ class TaskWorker:
"""

def __init__(
self, rpc_host: str, max_task_count: int | None = None, **options: dict[str, Any]
self,
rpc_host: str,
max_task_count: int | None = None,
namespace: str | None = None,
**options: dict[str, Any],
) -> None:
self.options = options
self._execution_count = 0
self._worker_id = uuid4().hex
self._max_task_count = max_task_count
self._namespace = namespace
self.client = TaskworkerClient(rpc_host)
self._pool: Pool | None = None
self._build_pool()
Expand Down Expand Up @@ -134,7 +140,7 @@ def start(self) -> int:

def fetch_task(self) -> TaskActivation | None:
try:
activation = self.client.get_task()
activation = self.client.get_task(self._namespace)
except grpc.RpcError:
metrics.incr("taskworker.worker.get_task.failed")
logger.info("get_task failed. Retrying in 1 second")
Expand Down Expand Up @@ -177,6 +183,7 @@ def process_task(self, activation: TaskActivation) -> TaskActivation | None:
return self.client.update_task(
task_id=activation.id,
status=TASK_ACTIVATION_STATUS_FAILURE,
fetch_next_task=FetchNextTask(namespace=self._namespace),
)

if task.at_most_once:
Expand Down Expand Up @@ -264,4 +271,5 @@ def process_task(self, activation: TaskActivation) -> TaskActivation | None:
return self.client.update_task(
task_id=activation.id,
status=next_state,
fetch_next_task=FetchNextTask(namespace=self._namespace),
)
64 changes: 61 additions & 3 deletions tests/sentry/taskworker/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from google.protobuf.message import Message
from sentry_protos.sentry.v1.taskworker_pb2 import (
TASK_ACTIVATION_STATUS_RETRY,
FetchNextTask,
GetTaskResponse,
SetTaskStatusResponse,
TaskActivation,
Expand Down Expand Up @@ -97,6 +98,31 @@ def test_get_task_ok():
assert result.namespace == "testing"


def test_get_task_with_namespace():
channel = MockChannel()
channel.add_response(
"/sentry_protos.sentry.v1.ConsumerService/GetTask",
GetTaskResponse(
task=TaskActivation(
id="abc123",
namespace="testing",
taskname="do_thing",
parameters="",
headers={},
processing_deadline_duration=10,
)
),
)
with patch("sentry.taskworker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskworkerClient("localhost:50051")
result = client.get_task(namespace="testing")

assert result
assert result.id
assert result.namespace == "testing"


def test_get_task_not_found():
channel = MockChannel()
channel.add_response(
Expand Down Expand Up @@ -142,11 +168,39 @@ def test_update_task_ok_with_next():
with patch("sentry.taskworker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskworkerClient("localhost:50051")
result = client.update_task("abc123", TASK_ACTIVATION_STATUS_RETRY)
result = client.update_task(
"abc123", TASK_ACTIVATION_STATUS_RETRY, FetchNextTask(namespace=None)
)
assert result
assert result.id == "abc123"


def test_update_task_ok_with_next_namespace():
channel = MockChannel()
channel.add_response(
"/sentry_protos.sentry.v1.ConsumerService/SetTaskStatus",
SetTaskStatusResponse(
task=TaskActivation(
id="abc123",
namespace="testing",
taskname="do_thing",
parameters="",
headers={},
processing_deadline_duration=10,
)
),
)
with patch("sentry.taskworker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskworkerClient("localhost:50051")
result = client.update_task(
"abc123", TASK_ACTIVATION_STATUS_RETRY, FetchNextTask(namespace="testing")
)
assert result
assert result.id == "abc123"
assert result.namespace == "testing"


def test_update_task_ok_no_next():
channel = MockChannel()
channel.add_response(
Expand All @@ -155,7 +209,9 @@ def test_update_task_ok_no_next():
with patch("sentry.taskworker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskworkerClient("localhost:50051")
result = client.update_task("abc123", TASK_ACTIVATION_STATUS_RETRY)
result = client.update_task(
"abc123", TASK_ACTIVATION_STATUS_RETRY, FetchNextTask(namespace=None)
)
assert result is None


Expand All @@ -168,5 +224,7 @@ def test_update_task_not_found():
with patch("sentry.taskworker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskworkerClient("localhost:50051")
result = client.update_task("abc123", TASK_ACTIVATION_STATUS_RETRY)
result = client.update_task(
"abc123", TASK_ACTIVATION_STATUS_RETRY, FetchNextTask(namespace=None)
)
assert result is None
33 changes: 25 additions & 8 deletions tests/sentry/taskworker/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
TASK_ACTIVATION_STATUS_COMPLETE,
TASK_ACTIVATION_STATUS_FAILURE,
TASK_ACTIVATION_STATUS_RETRY,
FetchNextTask,
TaskActivation,
)

Expand Down Expand Up @@ -110,7 +111,9 @@ def test_process_task_complete(self) -> None:
result = taskworker.process_task(SIMPLE_TASK)

mock_update.assert_called_with(
task_id=SIMPLE_TASK.id, status=TASK_ACTIVATION_STATUS_COMPLETE
task_id=SIMPLE_TASK.id,
status=TASK_ACTIVATION_STATUS_COMPLETE,
fetch_next_task=FetchNextTask(namespace=None),
)

assert result
Expand All @@ -123,7 +126,9 @@ def test_process_task_retry(self) -> None:
result = taskworker.process_task(RETRY_TASK)

mock_update.assert_called_with(
task_id=RETRY_TASK.id, status=TASK_ACTIVATION_STATUS_RETRY
task_id=RETRY_TASK.id,
status=TASK_ACTIVATION_STATUS_RETRY,
fetch_next_task=FetchNextTask(namespace=None),
)

assert result
Expand All @@ -136,7 +141,9 @@ def test_process_task_failure(self) -> None:
result = taskworker.process_task(FAIL_TASK)

mock_update.assert_called_with(
task_id=FAIL_TASK.id, status=TASK_ACTIVATION_STATUS_FAILURE
task_id=FAIL_TASK.id,
status=TASK_ACTIVATION_STATUS_FAILURE,
fetch_next_task=FetchNextTask(namespace=None),
)
assert result
assert result.id == SIMPLE_TASK.id
Expand All @@ -148,7 +155,9 @@ def test_process_task_at_most_once(self) -> None:
result = taskworker.process_task(AT_MOST_ONCE_TASK)

mock_update.assert_called_with(
task_id=AT_MOST_ONCE_TASK.id, status=TASK_ACTIVATION_STATUS_COMPLETE
task_id=AT_MOST_ONCE_TASK.id,
status=TASK_ACTIVATION_STATUS_COMPLETE,
fetch_next_task=FetchNextTask(namespace=None),
)
assert taskworker.process_task(AT_MOST_ONCE_TASK) is None
assert result
Expand All @@ -169,7 +178,9 @@ def test_start_max_task_count(self) -> None:
assert result == 0
assert mock_client.get_task.called
mock_client.update_task.assert_called_with(
task_id=SIMPLE_TASK.id, status=TASK_ACTIVATION_STATUS_COMPLETE
task_id=SIMPLE_TASK.id,
status=TASK_ACTIVATION_STATUS_COMPLETE,
fetch_next_task=FetchNextTask(namespace=None),
)

def test_start_loop(self) -> None:
Expand All @@ -188,10 +199,14 @@ def test_start_loop(self) -> None:
assert mock_client.update_task.call_count == 2

mock_client.update_task.assert_any_call(
task_id=SIMPLE_TASK.id, status=TASK_ACTIVATION_STATUS_COMPLETE
task_id=SIMPLE_TASK.id,
status=TASK_ACTIVATION_STATUS_COMPLETE,
fetch_next_task=FetchNextTask(namespace=None),
)
mock_client.update_task.assert_any_call(
task_id=RETRY_TASK.id, status=TASK_ACTIVATION_STATUS_RETRY
task_id=RETRY_TASK.id,
status=TASK_ACTIVATION_STATUS_RETRY,
fetch_next_task=FetchNextTask(namespace=None),
)

def test_start_keyboard_interrupt(self) -> None:
Expand All @@ -210,5 +225,7 @@ def test_start_unknown_task(self) -> None:
result = taskworker.start()
assert result == 0, "Exit zero, all tasks complete"
mock_client.update_task.assert_any_call(
task_id=UNDEFINED_TASK.id, status=TASK_ACTIVATION_STATUS_FAILURE
task_id=UNDEFINED_TASK.id,
status=TASK_ACTIVATION_STATUS_FAILURE,
fetch_next_task=FetchNextTask(namespace=None),
)

0 comments on commit e8f21d4

Please sign in to comment.