Skip to content

Commit

Permalink
fix(17/recurring-task-lock): Add timeout to auto-unlock stuck tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
gagantrivedi committed Jan 6, 2025
1 parent f92adfd commit 6e07161
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 11 deletions.
8 changes: 8 additions & 0 deletions task_processor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class TaskHandler(typing.Generic[P]):
"priority",
"transaction_on_commit",
"task_identifier",
"timeout",
)

unwrapped: typing.Callable[P, None]
Expand All @@ -38,11 +39,13 @@ def __init__(
queue_size: int | None = None,
priority: TaskPriority = TaskPriority.NORMAL,
transaction_on_commit: bool = True,
timeout: timedelta | None = None,
) -> None:
self.unwrapped = f
self.queue_size = queue_size
self.priority = priority
self.transaction_on_commit = transaction_on_commit
self.timeout = timeout

task_name = task_name or f.__name__
task_module = getmodule(f).__name__.rsplit(".")[-1]
Expand Down Expand Up @@ -87,6 +90,7 @@ def delay(
scheduled_for=delay_until or timezone.now(),
priority=self.priority,
queue_size=self.queue_size,
timeout=self.timeout,
args=args,
kwargs=kwargs,
)
Expand Down Expand Up @@ -124,6 +128,7 @@ def register_task_handler( # noqa: C901
queue_size: int | None = None,
priority: TaskPriority = TaskPriority.NORMAL,
transaction_on_commit: bool = True,
timeout: timedelta | None = timedelta(seconds=60),
) -> typing.Callable[[typing.Callable[P, None]], TaskHandler[P]]:
"""
Turn a function into an asynchronous task.
Expand All @@ -150,6 +155,7 @@ def wrapper(f: typing.Callable[P, None]) -> TaskHandler[P]:
queue_size=queue_size,
priority=priority,
transaction_on_commit=transaction_on_commit,
timeout=timeout,
)

return wrapper
Expand All @@ -161,6 +167,7 @@ def register_recurring_task(
args: tuple[typing.Any] = (),
kwargs: dict[str, typing.Any] | None = None,
first_run_time: time | None = None,
timeout: timedelta | None = timedelta(minutes=30),
) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]:
if not os.environ.get("RUN_BY_PROCESSOR"):
# Do not register recurring tasks if not invoked by task processor
Expand All @@ -182,6 +189,7 @@ def decorator(f: typing.Callable[..., None]) -> RecurringTask:
"serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
"run_every": run_every,
"first_run_time": first_run_time,
"timeout": timeout,
},
)
return task
Expand Down
39 changes: 39 additions & 0 deletions task_processor/migrations/0012_add_locked_at_and_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Generated by Django 3.2.23 on 2025-01-06 04:51

from task_processor.migrations.helpers import PostgresOnlyRunSQL
import datetime
from django.db import migrations, models
import os


class Migration(migrations.Migration):
    """Add lock-timeout support to the task processor.

    - ``RecurringTask.locked_at``: timestamp recorded when a worker locks the
      task (nullable, so pre-existing rows remain valid).
    - ``RecurringTask.timeout``: how long a lock may be held before it is
      considered stale; defaults to 30 minutes.
    - ``Task.timeout``: optional per-task execution timeout (nullable).

    Also replaces the ``get_recurringtasks_to_process`` SQL function so that
    stale-locked recurring tasks (locked longer than their timeout) become
    eligible for processing again.
    """

    dependencies = [
        ("task_processor", "0011_add_priority_to_get_tasks_to_process"),
    ]

    operations = [
        migrations.AddField(
            model_name="recurringtask",
            name="locked_at",
            field=models.DateTimeField(blank=True, null=True),
        ),
        migrations.AddField(
            model_name="recurringtask",
            name="timeout",
            field=models.DurationField(default=datetime.timedelta(minutes=30)),
        ),
        migrations.AddField(
            model_name="task",
            name="timeout",
            field=models.DurationField(blank=True, null=True),
        ),
        # Postgres-only: (re)create the stored function from the bundled SQL
        # file; reversing the migration simply drops the function.
        PostgresOnlyRunSQL.from_sql_file(
            os.path.join(
                os.path.dirname(__file__),
                "sql",
                "0012_get_recurringtasks_to_process.sql",
            ),
            reverse_sql="DROP FUNCTION IF EXISTS get_recurringtasks_to_process",
        ),
    ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
CREATE OR REPLACE FUNCTION get_recurringtasks_to_process(num_tasks integer)
RETURNS SETOF task_processor_recurringtask AS $$
DECLARE
    row_to_return task_processor_recurringtask;
BEGIN
    -- Select the tasks that need to be processed: either currently unlocked,
    -- or locked long enough ago that the lock exceeded the task's timeout
    -- (i.e. a stale lock left behind by a crashed or hung worker).
    FOR row_to_return IN
        SELECT *
        FROM task_processor_recurringtask
        WHERE is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout)
        ORDER BY id
        LIMIT num_tasks
        -- Select for update to ensure that no other workers can select these tasks while in this transaction block
        FOR UPDATE SKIP LOCKED
    LOOP
        -- Lock every selected task (by updating `is_locked` to true)
        UPDATE task_processor_recurringtask
        -- Lock this row by setting is_locked True, so that no other workers can select these tasks after this
        -- transaction is complete (but the tasks are still being executed by the current worker)
        SET is_locked = TRUE, locked_at = NOW()
        WHERE id = row_to_return.id;
        -- If we don't explicitly update the columns here, the client will receive a row
        -- that is locked but still shows `is_locked` as `False` and `locked_at` as `None`.
        row_to_return.is_locked := TRUE;
        row_to_return.locked_at := NOW();
        RETURN NEXT row_to_return;
    END LOOP;

    RETURN;
END;
$$ LANGUAGE plpgsql;

10 changes: 9 additions & 1 deletion task_processor/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing
import uuid
from datetime import datetime
from datetime import datetime, timedelta

import simplejson as json
from django.core.serializers.json import DjangoJSONEncoder
Expand Down Expand Up @@ -61,6 +61,7 @@ def mark_success(self):

def unlock(self):
self.is_locked = False
self.locked_at = None

def run(self):
return self.callable(*self.args, **self.kwargs)
Expand All @@ -80,6 +81,8 @@ def callable(self) -> typing.Callable:
class Task(AbstractBaseTask):
scheduled_for = models.DateTimeField(blank=True, null=True, default=timezone.now)

timeout = models.DurationField(null=True, blank=True)

# denormalise failures and completion so that we can use select_for_update
num_failures = models.IntegerField(default=0)
completed = models.BooleanField(default=False)
Expand Down Expand Up @@ -109,6 +112,7 @@ def create(
*,
args: typing.Tuple[typing.Any] = None,
kwargs: typing.Dict[str, typing.Any] = None,
timeout: timedelta | None = None,
) -> "Task":
if queue_size and cls._is_queue_full(task_identifier, queue_size):
raise TaskQueueFullError(
Expand All @@ -121,6 +125,7 @@ def create(
priority=priority,
serialized_args=cls.serialize_data(args or tuple()),
serialized_kwargs=cls.serialize_data(kwargs or dict()),
timeout=timeout,
)

@classmethod
Expand All @@ -146,6 +151,9 @@ def mark_success(self):
class RecurringTask(AbstractBaseTask):
run_every = models.DurationField()
first_run_time = models.TimeField(blank=True, null=True)
locked_at = models.DateTimeField(blank=True, null=True)

timeout = models.DurationField(default=timedelta(minutes=30))

objects = RecurringTaskManager()

Expand Down
20 changes: 15 additions & 5 deletions task_processor/processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import traceback
import typing
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta

from django.utils import timezone
Expand Down Expand Up @@ -78,7 +79,7 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]:

# update all tasks that were not deleted
to_update = [task for task in tasks if task.id]
RecurringTask.objects.bulk_update(to_update, fields=["is_locked"])
RecurringTask.objects.bulk_update(to_update, fields=["is_locked", "locked_at"])

if task_runs:
RecurringTaskRun.objects.bulk_create(task_runs)
Expand All @@ -93,16 +94,25 @@ def _run_task(task: typing.Union[Task, RecurringTask]) -> typing.Tuple[Task, Tas
task_run = task.task_runs.model(started_at=timezone.now(), task=task)

try:
task.run()
task_run.result = TaskResult.SUCCESS
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(task.run)
timeout = task.timeout.total_seconds() if task.timeout else None
future.result(timeout=timeout) # Wait for completion or timeout

task_run.result = TaskResult.SUCCESS
task_run.finished_at = timezone.now()
task.mark_success()

except Exception as e:
# For errors that don't include a default message (e.g., TimeoutError),
# fall back to using repr.
err_msg = str(e) or repr(e)

logger.error(
"Failed to execute task '%s'. Exception was: %s",
"Failed to execute task '%s', with id %d. Exception: %s",
task.task_identifier,
str(e),
task.id,
err_msg,
exc_info=True,
)
logger.debug("args: %s", str(task.args))
Expand Down
Loading

0 comments on commit 6e07161

Please sign in to comment.