Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(locked_tasks): Add timeout to auto unlock tasks #16

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions task_processor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class TaskHandler(typing.Generic[P]):
"priority",
"transaction_on_commit",
"task_identifier",
"timeout",
)

unwrapped: typing.Callable[P, None]
Expand All @@ -38,11 +39,13 @@ def __init__(
queue_size: int | None = None,
priority: TaskPriority = TaskPriority.NORMAL,
transaction_on_commit: bool = True,
timeout: timedelta | None = None,
) -> None:
self.unwrapped = f
self.queue_size = queue_size
self.priority = priority
self.transaction_on_commit = transaction_on_commit
self.timeout = timeout

task_name = task_name or f.__name__
task_module = getmodule(f).__name__.rsplit(".")[-1]
Expand Down Expand Up @@ -87,6 +90,7 @@ def delay(
scheduled_for=delay_until or timezone.now(),
priority=self.priority,
queue_size=self.queue_size,
timeout=self.timeout,
args=args,
kwargs=kwargs,
)
Expand Down Expand Up @@ -124,6 +128,7 @@ def register_task_handler( # noqa: C901
queue_size: int | None = None,
priority: TaskPriority = TaskPriority.NORMAL,
transaction_on_commit: bool = True,
timeout: timedelta | None = timedelta(seconds=60),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The longest run duration for a task in production is 00:00:30.459713

) -> typing.Callable[[typing.Callable[P, None]], TaskHandler[P]]:
"""
Turn a function into an asynchronous task.
Expand All @@ -150,6 +155,7 @@ def wrapper(f: typing.Callable[P, None]) -> TaskHandler[P]:
queue_size=queue_size,
priority=priority,
transaction_on_commit=transaction_on_commit,
timeout=timeout,
)

return wrapper
Expand All @@ -161,6 +167,7 @@ def register_recurring_task(
args: tuple[typing.Any] = (),
kwargs: dict[str, typing.Any] | None = None,
first_run_time: time | None = None,
timeout: timedelta | None = timedelta(minutes=30),
Copy link
Member Author

@gagantrivedi gagantrivedi Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The longest run duration for a recurring task in production is 00:11:49.801789

) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]:
if not os.environ.get("RUN_BY_PROCESSOR"):
# Do not register recurring tasks if not invoked by task processor
Expand All @@ -182,6 +189,7 @@ def decorator(f: typing.Callable[..., None]) -> RecurringTask:
"serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
"run_every": run_every,
"first_run_time": first_run_time,
"timeout": timeout,
},
)
return task
Expand Down
39 changes: 39 additions & 0 deletions task_processor/migrations/0012_add_locked_at_and_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Generated by Django 3.2.23 on 2025-01-06 04:51

import datetime
import os

from django.db import migrations, models

from task_processor.migrations.helpers import PostgresOnlyRunSQL


class Migration(migrations.Migration):
    """Add lock-timeout support for automatic task unlocking.

    Adds ``locked_at``/``timeout`` to ``RecurringTask`` and ``timeout`` to
    ``Task``, and replaces the ``get_recurringtasks_to_process`` stored
    function so that tasks whose lock has outlived its timeout are picked
    up again by workers.
    """

    dependencies = [
        ("task_processor", "0011_add_priority_to_get_tasks_to_process"),
    ]

    operations = [
        migrations.AddField(
            model_name="recurringtask",
            name="locked_at",
            field=models.DateTimeField(blank=True, null=True),
        ),
        migrations.AddField(
            model_name="recurringtask",
            name="timeout",
            field=models.DurationField(default=datetime.timedelta(minutes=30)),
        ),
        migrations.AddField(
            model_name="task",
            name="timeout",
            # No DB default on purpose: adding a default would rewrite and
            # lock the (potentially large) task table; NULL means "no timeout".
            field=models.DurationField(blank=True, null=True),
        ),
        # Install the updated selection function; on reverse, drop it
        # entirely (a prior migration can recreate the old version).
        PostgresOnlyRunSQL.from_sql_file(
            os.path.join(
                os.path.dirname(__file__),
                "sql",
                "0012_get_recurringtasks_to_process.sql",
            ),
            reverse_sql="DROP FUNCTION IF EXISTS get_recurringtasks_to_process",
        ),
    ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
-- Return up to num_tasks recurring tasks and lock them for processing.
-- A task is eligible when it is not locked, OR when its lock has expired
-- (locked_at is older than NOW() - timeout) — this auto-recovers tasks
-- left locked by a crashed or hung worker.
CREATE OR REPLACE FUNCTION get_recurringtasks_to_process(num_tasks integer)
RETURNS SETOF task_processor_recurringtask AS $$
DECLARE
    row_to_return task_processor_recurringtask;
BEGIN
    -- Select the tasks that need to be processed
    FOR row_to_return IN
        SELECT *
        FROM task_processor_recurringtask
        WHERE is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout)
        ORDER BY id
        LIMIT num_tasks
        -- Select for update to ensure that no other workers can select these tasks while in this transaction block
        FOR UPDATE SKIP LOCKED
    LOOP
        -- Lock every selected task (by updating `is_locked` to true)
        UPDATE task_processor_recurringtask
        -- Lock this row by setting is_locked True, so that no other workers can select these tasks after this
        -- transaction is complete (but the tasks are still being executed by the current worker)
        SET is_locked = TRUE, locked_at = NOW()
        WHERE id = row_to_return.id;
        -- If we don't explicitly update the columns here, the client will receive a row
        -- that is locked but still shows `is_locked` as `False` and `locked_at` as `None`.
        -- NOW() is stable within a transaction, so these values match the UPDATE above.
        row_to_return.is_locked := TRUE;
        row_to_return.locked_at := NOW();
        RETURN NEXT row_to_return;
    END LOOP;

    RETURN;
END;
$$ LANGUAGE plpgsql;

10 changes: 9 additions & 1 deletion task_processor/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing
import uuid
from datetime import datetime
from datetime import datetime, timedelta

import simplejson as json
from django.core.serializers.json import DjangoJSONEncoder
Expand Down Expand Up @@ -61,6 +61,7 @@ def mark_success(self):

def unlock(self):
self.is_locked = False
self.locked_at = None

def run(self):
return self.callable(*self.args, **self.kwargs)
Expand All @@ -80,6 +81,8 @@ def callable(self) -> typing.Callable:
class Task(AbstractBaseTask):
scheduled_for = models.DateTimeField(blank=True, null=True, default=timezone.now)

timeout = models.DurationField(null=True, blank=True)

# denormalise failures and completion so that we can use select_for_update
num_failures = models.IntegerField(default=0)
completed = models.BooleanField(default=False)
Expand Down Expand Up @@ -109,6 +112,7 @@ def create(
*,
args: typing.Tuple[typing.Any] = None,
kwargs: typing.Dict[str, typing.Any] = None,
timeout: timedelta | None = None,
) -> "Task":
if queue_size and cls._is_queue_full(task_identifier, queue_size):
raise TaskQueueFullError(
Expand All @@ -121,6 +125,7 @@ def create(
priority=priority,
serialized_args=cls.serialize_data(args or tuple()),
serialized_kwargs=cls.serialize_data(kwargs or dict()),
timeout=timeout,
)

@classmethod
Expand All @@ -146,6 +151,9 @@ def mark_success(self):
class RecurringTask(AbstractBaseTask):
run_every = models.DurationField()
first_run_time = models.TimeField(blank=True, null=True)
locked_at = models.DateTimeField(blank=True, null=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we don't need this on the regular Task model because we can use scheduled_for ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially planned to add the unlock feature only for recurring tasks, but on second thought, I have now added it to regular tasks as well.


timeout = models.DurationField(default=timedelta(minutes=30))

objects = RecurringTaskManager()

Expand Down
20 changes: 15 additions & 5 deletions task_processor/processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import traceback
import typing
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta

from django.utils import timezone
Expand Down Expand Up @@ -78,7 +79,7 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]:

# update all tasks that were not deleted
to_update = [task for task in tasks if task.id]
RecurringTask.objects.bulk_update(to_update, fields=["is_locked"])
RecurringTask.objects.bulk_update(to_update, fields=["is_locked", "locked_at"])

if task_runs:
RecurringTaskRun.objects.bulk_create(task_runs)
Expand All @@ -93,16 +94,25 @@ def _run_task(task: typing.Union[Task, RecurringTask]) -> typing.Tuple[Task, Tas
task_run = task.task_runs.model(started_at=timezone.now(), task=task)

try:
task.run()
task_run.result = TaskResult.SUCCESS
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(task.run)
timeout = task.timeout.total_seconds() if task.timeout else None
future.result(timeout=timeout) # Wait for completion or timeout

task_run.result = TaskResult.SUCCESS
task_run.finished_at = timezone.now()
task.mark_success()

except Exception as e:
# For errors that don't include a default message (e.g., TimeoutError),
matthewelwell marked this conversation as resolved.
Show resolved Hide resolved
# fall back to using repr.
err_msg = str(e) or repr(e)

logger.error(
"Failed to execute task '%s'. Exception was: %s",
"Failed to execute task '%s', with id %d. Exception: %s",
task.task_identifier,
str(e),
task.id,
err_msg,
exc_info=True,
)
logger.debug("args: %s", str(task.args))
Expand Down
Loading
Loading