From 6e07161fb0e512dcd4703a4330e31f486c684e60 Mon Sep 17 00:00:00 2001
From: Gagan Trivedi
Date: Mon, 6 Jan 2025 10:24:52 +0530
Subject: [PATCH] fix(17/recurring-task-lock): Add timeout to auto-unlock task

Recurring tasks could stay locked indefinitely if the worker that locked
them died mid-run. Track when a lock was taken (`locked_at`) and give every
task a `timeout`, so that other workers skip a locked task only until its
timeout elapses, and so that task execution itself is bounded.

---
 task_processor/decorators.py                  |   8 ++
 .../0012_add_locked_at_and_timeout.py         |  39 ++++++
 .../0012_get_recurringtasks_to_process.sql    |  32 +++++
 task_processor/models.py                      |  10 +-
 task_processor/processor.py                   |  20 ++-
 .../test_unit_task_processor_processor.py     | 130 +++++++++++++++++-
 6 files changed, 228 insertions(+), 11 deletions(-)
 create mode 100644 task_processor/migrations/0012_add_locked_at_and_timeout.py
 create mode 100644 task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql

diff --git a/task_processor/decorators.py b/task_processor/decorators.py
index f5c306b..27958c7 100644
--- a/task_processor/decorators.py
+++ b/task_processor/decorators.py
@@ -26,6 +26,7 @@ class TaskHandler(typing.Generic[P]):
         "priority",
         "transaction_on_commit",
         "task_identifier",
+        "timeout",
     )
 
     unwrapped: typing.Callable[P, None]
@@ -38,11 +39,13 @@ def __init__(
         queue_size: int | None = None,
         priority: TaskPriority = TaskPriority.NORMAL,
         transaction_on_commit: bool = True,
+        timeout: timedelta | None = None,
     ) -> None:
         self.unwrapped = f
         self.queue_size = queue_size
         self.priority = priority
         self.transaction_on_commit = transaction_on_commit
+        self.timeout = timeout
 
         task_name = task_name or f.__name__
         task_module = getmodule(f).__name__.rsplit(".")[-1]
@@ -87,6 +90,7 @@ def delay(
             scheduled_for=delay_until or timezone.now(),
             priority=self.priority,
             queue_size=self.queue_size,
+            timeout=self.timeout,
             args=args,
             kwargs=kwargs,
         )
@@ -124,6 +128,7 @@ def register_task_handler(  # noqa: C901
     queue_size: int | None = None,
     priority: TaskPriority = TaskPriority.NORMAL,
     transaction_on_commit: bool = True,
+    timeout: timedelta | None = timedelta(seconds=60),
 ) -> typing.Callable[[typing.Callable[P, None]], TaskHandler[P]]:
     """
     Turn a function into an asynchronous task.
@@ -150,6 +155,7 @@ def wrapper(f: typing.Callable[P, None]) -> TaskHandler[P]:
             queue_size=queue_size,
             priority=priority,
             transaction_on_commit=transaction_on_commit,
+            timeout=timeout,
         )
 
     return wrapper
@@ -161,6 +167,7 @@ def register_recurring_task(
     args: tuple[typing.Any] = (),
     kwargs: dict[str, typing.Any] | None = None,
     first_run_time: time | None = None,
+    timeout: timedelta | None = timedelta(minutes=30),
 ) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]:
     if not os.environ.get("RUN_BY_PROCESSOR"):
         # Do not register recurring tasks if not invoked by task processor
@@ -182,6 +189,7 @@ def decorator(f: typing.Callable[..., None]) -> RecurringTask:
             "serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
             "run_every": run_every,
             "first_run_time": first_run_time,
+            "timeout": timeout,
         },
     )
     return task
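
For a concrete picture of the new surface area: both decorators now accept a
`timeout` argument (defaulting to 60 seconds for one-off tasks and 30 minutes
for recurring ones), and the value is persisted with the task when it is
enqueued. A usage sketch follows; the task functions below are hypothetical
and not part of this patch:

    from datetime import timedelta

    from task_processor.decorators import (
        register_recurring_task,
        register_task_handler,
    )


    @register_task_handler(timeout=timedelta(seconds=10))
    def send_export(export_id: int) -> None:  # hypothetical task
        ...


    # Only registered when the worker sets RUN_BY_PROCESSOR.
    @register_recurring_task(
        run_every=timedelta(minutes=5), timeout=timedelta(minutes=1)
    )
    def cleanup_stale_rows() -> None:  # hypothetical task
        ...


    # The timeout travels with the queued task:
    send_export.delay(args=(42,))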
diff --git a/task_processor/migrations/0012_add_locked_at_and_timeout.py b/task_processor/migrations/0012_add_locked_at_and_timeout.py
new file mode 100644
index 0000000..a3f65a3
--- /dev/null
+++ b/task_processor/migrations/0012_add_locked_at_and_timeout.py
@@ -0,0 +1,39 @@
+# Generated by Django 3.2.23 on 2025-01-06 04:51
+
+import datetime
+import os
+from django.db import migrations, models
+from task_processor.migrations.helpers import PostgresOnlyRunSQL
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("task_processor", "0011_add_priority_to_get_tasks_to_process"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="recurringtask",
+            name="locked_at",
+            field=models.DateTimeField(blank=True, null=True),
+        ),
+        migrations.AddField(
+            model_name="recurringtask",
+            name="timeout",
+            field=models.DurationField(default=datetime.timedelta(minutes=30)),
+        ),
+        migrations.AddField(
+            model_name="task",
+            name="timeout",
+            field=models.DurationField(blank=True, null=True),
+        ),
+        PostgresOnlyRunSQL.from_sql_file(
+            os.path.join(
+                os.path.dirname(__file__),
+                "sql",
+                "0012_get_recurringtasks_to_process.sql",
+            ),
+            reverse_sql="DROP FUNCTION IF EXISTS get_recurringtasks_to_process",
+        ),
+    ]
diff --git a/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql b/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql
new file mode 100644
index 0000000..d8483f2
--- /dev/null
+++ b/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql
@@ -0,0 +1,32 @@
+CREATE OR REPLACE FUNCTION get_recurringtasks_to_process(num_tasks integer)
+RETURNS SETOF task_processor_recurringtask AS $$
+DECLARE
+    row_to_return task_processor_recurringtask;
+BEGIN
+    -- Select the tasks that need to be processed
+    FOR row_to_return IN
+        SELECT *
+        FROM task_processor_recurringtask
+        WHERE is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout)
+        ORDER BY id
+        LIMIT num_tasks
+        -- Select for update to ensure that no other workers can select these tasks while in this transaction block
+        FOR UPDATE SKIP LOCKED
+    LOOP
+        -- Lock every selected task (by updating `is_locked` to true)
+        UPDATE task_processor_recurringtask
+        -- Setting is_locked and locked_at means no other worker can select this task after this
+        -- transaction completes (while the task is still being executed by the current worker)
+        SET is_locked = TRUE, locked_at = NOW()
+        WHERE id = row_to_return.id;
+        -- If we don't explicitly update the columns here, the client will receive a row
+        -- that is locked but still shows `is_locked` as `False` and `locked_at` as `None`.
+        row_to_return.is_locked := TRUE;
+        row_to_return.locked_at := NOW();
+        RETURN NEXT row_to_return;
+    END LOOP;
+
+    RETURN;
+END;
+$$ LANGUAGE plpgsql;
+
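For review context, rows returned by get_recurringtasks_to_process come back
already locked (`is_locked = TRUE`, `locked_at = NOW()`), so concurrent
workers skip them until the lock is released or the row's `timeout` elapses.
A minimal sketch of how the Python side might consume the function; the
manager method name below is illustrative, not necessarily the one this repo
uses:

    from task_processor.models import RecurringTask

    def get_recurring_tasks_to_process(num_tasks: int) -> list[RecurringTask]:
        # The SQL function both selects and locks, so this single round trip
        # is safe to run from many workers at once (FOR UPDATE SKIP LOCKED).
        return list(
            RecurringTask.objects.raw(
                "SELECT * FROM get_recurringtasks_to_process(%s)", [num_tasks]
            )
        )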
diff --git a/task_processor/models.py b/task_processor/models.py
index 9093b22..c6e6248 100644
--- a/task_processor/models.py
+++ b/task_processor/models.py
@@ -1,6 +1,6 @@
 import typing
 import uuid
-from datetime import datetime
+from datetime import datetime, timedelta
 
 import simplejson as json
 from django.core.serializers.json import DjangoJSONEncoder
@@ -61,6 +61,7 @@ def mark_success(self):
 
     def unlock(self):
         self.is_locked = False
+        self.locked_at = None
 
     def run(self):
         return self.callable(*self.args, **self.kwargs)
@@ -80,6 +81,8 @@ def callable(self) -> typing.Callable:
 
 class Task(AbstractBaseTask):
     scheduled_for = models.DateTimeField(blank=True, null=True, default=timezone.now)
+    timeout = models.DurationField(null=True, blank=True)
+
     # denormalise failures and completion so that we can use select_for_update
     num_failures = models.IntegerField(default=0)
     completed = models.BooleanField(default=False)
@@ -109,6 +112,7 @@ def create(
         *,
         args: typing.Tuple[typing.Any] = None,
         kwargs: typing.Dict[str, typing.Any] = None,
+        timeout: timedelta | None = None,
     ) -> "Task":
         if queue_size and cls._is_queue_full(task_identifier, queue_size):
             raise TaskQueueFullError(
@@ -121,6 +125,7 @@ def create(
             priority=priority,
             serialized_args=cls.serialize_data(args or tuple()),
             serialized_kwargs=cls.serialize_data(kwargs or dict()),
+            timeout=timeout,
         )
 
     @classmethod
@@ -146,6 +151,9 @@ def mark_success(self):
 class RecurringTask(AbstractBaseTask):
     run_every = models.DurationField()
     first_run_time = models.TimeField(blank=True, null=True)
+    locked_at = models.DateTimeField(blank=True, null=True)
+
+    timeout = models.DurationField(default=timedelta(minutes=30))
 
     objects = RecurringTaskManager()
diff --git a/task_processor/processor.py b/task_processor/processor.py
index 93d5436..e1281ae 100644
--- a/task_processor/processor.py
+++ b/task_processor/processor.py
@@ -1,6 +1,7 @@
 import logging
 import traceback
 import typing
+from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
 
 from django.utils import timezone
@@ -78,7 +79,7 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]:
 
     # update all tasks that were not deleted
     to_update = [task for task in tasks if task.id]
-    RecurringTask.objects.bulk_update(to_update, fields=["is_locked"])
+    RecurringTask.objects.bulk_update(to_update, fields=["is_locked", "locked_at"])
 
     if task_runs:
         RecurringTaskRun.objects.bulk_create(task_runs)
@@ -93,16 +94,25 @@ def _run_task(task: typing.Union[Task, RecurringTask]) -> typing.Tuple[Task, Tas
     task_run = task.task_runs.model(started_at=timezone.now(), task=task)
 
     try:
-        task.run()
-        task_run.result = TaskResult.SUCCESS
+        with ThreadPoolExecutor(max_workers=1) as executor:
+            future = executor.submit(task.run)
+            timeout = task.timeout.total_seconds() if task.timeout else None
+            future.result(timeout=timeout)  # Wait for completion or timeout
+            task_run.result = TaskResult.SUCCESS
 
         task_run.finished_at = timezone.now()
         task.mark_success()
+
     except Exception as e:
+        # For errors that don't include a default message (e.g., TimeoutError),
+        # fall back to using repr.
+        err_msg = str(e) or repr(e)
+
         logger.error(
-            "Failed to execute task '%s'. Exception was: %s",
+            "Failed to execute task '%s', with id %d. Exception: %s",
             task.task_identifier,
-            str(e),
+            task.id,
+            err_msg,
             exc_info=True,
         )
         logger.debug("args: %s", str(task.args))
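One behavioural note on the executor pattern above: future.result(timeout=...)
raises TimeoutError as soon as the deadline passes, but it does not stop the
thread running the task. The run is marked as failed while the task function
continues in the background, and exiting the `with` block then waits for it,
because ThreadPoolExecutor.__exit__ calls shutdown(wait=True). A stand-alone
sketch of that behaviour:

    import time
    from concurrent.futures import ThreadPoolExecutor, TimeoutError

    def slow_task() -> None:
        time.sleep(2)

    with ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(slow_task)
        try:
            future.result(timeout=0.1)  # raises TimeoutError after 0.1s
        except TimeoutError:
            print("run marked as failed")  # slow_task() is still running here
    # Exiting the block calls shutdown(wait=True), so we only get past this
    # point once slow_task() actually returns (~2 seconds later).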
Exception: %s", task.task_identifier, - str(e), + task.id, + err_msg, exc_info=True, ) logger.debug("args: %s", str(task.args)) diff --git a/tests/unit/task_processor/test_unit_task_processor_processor.py b/tests/unit/task_processor/test_unit_task_processor_processor.py index 04826a3..aab4ae6 100644 --- a/tests/unit/task_processor/test_unit_task_processor_processor.py +++ b/tests/unit/task_processor/test_unit_task_processor_processor.py @@ -1,5 +1,6 @@ import logging import time +import typing import uuid from datetime import timedelta from threading import Thread @@ -8,6 +9,7 @@ from django.core.cache import cache from django.utils import timezone from freezegun import freeze_time +from pytest import MonkeyPatch from task_processor.decorators import ( register_recurring_task, @@ -28,6 +30,11 @@ ) from task_processor.task_registry import registered_tasks +if typing.TYPE_CHECKING: + # This import breaks private-package-test workflow in core + from tests.unit.task_processor.conftest import GetTaskProcessorCaplog + + DEFAULT_CACHE_KEY = "foo" DEFAULT_CACHE_VALUE = "bar" @@ -63,6 +70,83 @@ def test_run_task_runs_task_and_creates_task_run_object_when_success(db): assert task.completed +def test_run_task_kills_task_after_timeout( + db: None, + get_task_processor_caplog: "GetTaskProcessorCaplog", +) -> None: + # Given + caplog = get_task_processor_caplog(logging.ERROR) + task = Task.create( + _sleep.task_identifier, + scheduled_for=timezone.now(), + args=(1,), + timeout=timedelta(microseconds=1), + ) + task.save() + + # When + task_runs = run_tasks() + + # Then + assert len(task_runs) == TaskRun.objects.filter(task=task).count() == 1 + task_run = task_runs[0] + assert task_run.result == TaskResult.FAILURE + assert task_run.started_at + assert task_run.finished_at is None + assert "TimeoutError" in task_run.error_details + + task.refresh_from_db() + + assert task.completed is False + assert task.num_failures == 1 + assert task.is_locked is False + + assert len(caplog.records) == 1 + assert caplog.records[0].message == ( + f"Failed to execute task '{task.task_identifier}', with id {task.id}. Exception: TimeoutError()" + ) + + +def test_run_recurring_task_kills_task_after_timeout( + db: None, + monkeypatch: MonkeyPatch, + get_task_processor_caplog: "GetTaskProcessorCaplog", +) -> None: + # Given + caplog = get_task_processor_caplog(logging.ERROR) + monkeypatch.setenv("RUN_BY_PROCESSOR", "True") + + @register_recurring_task( + run_every=timedelta(seconds=1), timeout=timedelta(microseconds=1) + ) + def _dummy_recurring_task(): + time.sleep(1) + + task = RecurringTask.objects.get( + task_identifier=_dummy_recurring_task.task_identifier + ) + # When + task_runs = run_recurring_tasks() + + # Then + assert len(task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 1 + task_run = task_runs[0] + assert task_run.result == TaskResult.FAILURE + assert task_run.started_at + assert task_run.finished_at is None + assert "TimeoutError" in task_run.error_details + + task.refresh_from_db() + + assert task.locked_at is None + assert task.is_locked is False + + assert len(caplog.records) == 1 + assert caplog.records[0].message == ( + f"Failed to execute task '{task.task_identifier}', with id {task.id}. 
Exception: TimeoutError()" + ) + + def test_run_recurring_tasks_runs_task_and_creates_recurring_task_run_object_when_success( db, monkeypatch, @@ -91,6 +175,43 @@ def _dummy_recurring_task(): assert task_run.error_details is None +def test_run_recurring_tasks_runs_locked_task_after_tiemout( + db: None, + monkeypatch: MonkeyPatch, +) -> None: + # Given + monkeypatch.setenv("RUN_BY_PROCESSOR", "True") + + @register_recurring_task(run_every=timedelta(hours=1)) + def _dummy_recurring_task(): + cache.set(DEFAULT_CACHE_KEY, DEFAULT_CACHE_VALUE) + + task = RecurringTask.objects.get( + task_identifier=_dummy_recurring_task.task_identifier + ) + task.is_locked = True + task.locked_at = timezone.now() - timedelta(hours=1) + task.save() + + # When + task_runs = run_recurring_tasks() + + # Then + assert cache.get(DEFAULT_CACHE_KEY) + + assert len(task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 1 + task_run = task_runs[0] + assert task_run.result == TaskResult.SUCCESS + assert task_run.started_at + assert task_run.finished_at + assert task_run.error_details is None + + # And the task is no longer locked + task.refresh_from_db() + assert task.is_locked is False + assert task.locked_at is None + + @pytest.mark.django_db(transaction=True) def test_run_recurring_tasks_multiple_runs(db, run_by_processor): # Given @@ -211,12 +332,11 @@ def _a_task(): def test_run_task_runs_task_and_creates_task_run_object_when_failure( - db: None, caplog: pytest.LogCaptureFixture + db: None, + get_task_processor_caplog: "GetTaskProcessorCaplog", ) -> None: # Given - task_processor_logger = logging.getLogger("task_processor") - task_processor_logger.propagate = True - task_processor_logger.level = logging.DEBUG + caplog = get_task_processor_caplog(logging.DEBUG) msg = "Error!" task = Task.create( @@ -243,7 +363,7 @@ def test_run_task_runs_task_and_creates_task_run_object_when_failure( log_record = caplog.records[0] assert log_record.levelname == "ERROR" assert log_record.message == ( - f"Failed to execute task '{task.task_identifier}'. Exception was: {msg}" + f"Failed to execute task '{task.task_identifier}', with id {task.id}. Exception: {msg}" ) debug_log_args, debug_log_kwargs = caplog.records[1:]