Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(locked_tasks): Add timeout to auto unlock tasks #16

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions task_processor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class TaskHandler(typing.Generic[P]):
"priority",
"transaction_on_commit",
"task_identifier",
"timeout",
)

unwrapped: typing.Callable[P, None]
Expand All @@ -38,11 +39,13 @@ def __init__(
queue_size: int | None = None,
priority: TaskPriority = TaskPriority.NORMAL,
transaction_on_commit: bool = True,
timeout: timedelta | None = None,
) -> None:
self.unwrapped = f
self.queue_size = queue_size
self.priority = priority
self.transaction_on_commit = transaction_on_commit
self.timeout = timeout

task_name = task_name or f.__name__
task_module = getmodule(f).__name__.rsplit(".")[-1]
Expand Down Expand Up @@ -87,6 +90,7 @@ def delay(
scheduled_for=delay_until or timezone.now(),
priority=self.priority,
queue_size=self.queue_size,
timeout=self.timeout,
args=args,
kwargs=kwargs,
)
Expand Down Expand Up @@ -124,6 +128,7 @@ def register_task_handler( # noqa: C901
queue_size: int | None = None,
priority: TaskPriority = TaskPriority.NORMAL,
transaction_on_commit: bool = True,
timeout: timedelta | None = timedelta(seconds=60),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The longest run duration for a task in production is 00:00:30.459713

) -> typing.Callable[[typing.Callable[P, None]], TaskHandler[P]]:
"""
Turn a function into an asynchronous task.
Expand All @@ -150,6 +155,7 @@ def wrapper(f: typing.Callable[P, None]) -> TaskHandler[P]:
queue_size=queue_size,
priority=priority,
transaction_on_commit=transaction_on_commit,
timeout=timeout,
)

return wrapper
Expand All @@ -161,6 +167,7 @@ def register_recurring_task(
args: tuple[typing.Any] = (),
kwargs: dict[str, typing.Any] | None = None,
first_run_time: time | None = None,
timeout: timedelta | None = timedelta(minutes=30),
Copy link
Member Author

@gagantrivedi gagantrivedi Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The longest run duration for a recurring task in production is 00:11:49.801789

) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]:
if not os.environ.get("RUN_BY_PROCESSOR"):
# Do not register recurring tasks if not invoked by task processor
Expand All @@ -182,6 +189,7 @@ def decorator(f: typing.Callable[..., None]) -> RecurringTask:
"serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
"run_every": run_every,
"first_run_time": first_run_time,
"timeout": timeout,
},
)
return task
Expand Down
60 changes: 60 additions & 0 deletions task_processor/migrations/0012_add_locked_at_and_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Generated by Django 3.2.23 on 2025-01-06 04:51

from task_processor.migrations.helpers import PostgresOnlyRunSQL
import datetime
from django.db import migrations, models
import os


class Migration(migrations.Migration):
    # Adds lock-timeout support to the task processor:
    #   * `locked_at` timestamps on Task and RecurringTask so stale locks
    #     can be detected,
    #   * `timeout` durations (1 minute for tasks, 30 minutes for recurring
    #     tasks) after which a locked row becomes eligible for pickup again,
    #   * replacement get_*_to_process() SQL functions that honour the new
    #     columns (Postgres only).

    dependencies = [
        ("task_processor", "0011_add_priority_to_get_tasks_to_process"),
    ]

    operations = [
        # New nullable lock timestamp; NULL means "not currently locked".
        migrations.AddField(
            model_name="recurringtask",
            name="locked_at",
            field=models.DateTimeField(blank=True, null=True),
        ),
        migrations.AddField(
            model_name="task",
            name="locked_at",
            field=models.DateTimeField(blank=True, null=True),
        ),
        # Defaults here back-fill existing rows; new rows get their timeout
        # from the model field / decorator arguments.
        migrations.AddField(
            model_name="recurringtask",
            name="timeout",
            field=models.DurationField(default=datetime.timedelta(minutes=30)),
        ),
        migrations.AddField(
            model_name="task",
            name="timeout",
            field=models.DurationField(default=datetime.timedelta(minutes=1)),
        ),
        # CREATE OR REPLACE the SQL functions; reverse_sql restores the
        # previous function definitions so the migration is reversible.
        PostgresOnlyRunSQL.from_sql_file(
            os.path.join(
                os.path.dirname(__file__),
                "sql",
                "0012_get_recurringtasks_to_process.sql",
            ),
            reverse_sql=os.path.join(
                os.path.dirname(__file__),
                "sql",
                "0008_get_recurringtasks_to_process.sql",
            ),
        ),
        PostgresOnlyRunSQL.from_sql_file(
            os.path.join(
                os.path.dirname(__file__),
                "sql",
                "0012_get_tasks_to_process.sql",
            ),
            reverse_sql=os.path.join(
                os.path.dirname(__file__),
                "sql",
                "0011_get_tasks_to_process.sql",
            ),
        ),
    ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
CREATE OR REPLACE FUNCTION get_recurringtasks_to_process(num_tasks integer)
RETURNS SETOF task_processor_recurringtask AS $$
DECLARE
    candidate task_processor_recurringtask;
BEGIN
    -- Pick up to num_tasks recurring tasks that are either unlocked, or
    -- whose lock is older than their timeout (treated as abandoned by a
    -- dead worker).
    FOR candidate IN
        SELECT *
        FROM task_processor_recurringtask
        WHERE is_locked = FALSE
           OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout)
        ORDER BY id
        LIMIT num_tasks
        -- Row-level lock: concurrent workers running this function skip
        -- rows already claimed inside another open transaction.
        FOR UPDATE SKIP LOCKED
    LOOP
        -- Persist the claim: after this transaction commits, other workers
        -- see the row as locked while this worker executes the task.
        UPDATE task_processor_recurringtask
        SET is_locked = TRUE, locked_at = NOW()
        WHERE id = candidate.id;
        -- Mirror the update onto the returned record; otherwise the client
        -- would receive is_locked = FALSE / locked_at = NULL. NOW() returns
        -- the same value for the whole transaction, so both values match
        -- what the UPDATE wrote.
        candidate.is_locked := TRUE;
        candidate.locked_at := NOW();
        RETURN NEXT candidate;
    END LOOP;

    RETURN;
END;
$$ LANGUAGE plpgsql

31 changes: 31 additions & 0 deletions task_processor/migrations/sql/0012_get_tasks_to_process.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
CREATE OR REPLACE FUNCTION get_tasks_to_process(num_tasks integer)
RETURNS SETOF task_processor_task AS $$
DECLARE
    row_to_return task_processor_task;
BEGIN
    -- Select the tasks that need to be processed: not failed out, due to
    -- run, not completed, and either unlocked or holding a lock older than
    -- the task's timeout (worker presumed dead).
    FOR row_to_return IN
        SELECT *
        FROM task_processor_task
        WHERE num_failures < 3
          AND scheduled_for < NOW()
          AND completed = FALSE
          AND (is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout))
        ORDER BY priority ASC, scheduled_for ASC, created_at ASC
        LIMIT num_tasks
        -- Select for update to ensure that no other workers can select
        -- these tasks while in this transaction block.
        FOR UPDATE SKIP LOCKED
    LOOP
        -- Lock this row by setting is_locked TRUE, so that no other workers
        -- can select this task after this transaction is complete (but the
        -- task is still being executed by the current worker).
        UPDATE task_processor_task
        SET is_locked = TRUE, locked_at = NOW()
        WHERE id = row_to_return.id;
        -- If we don't explicitly update the columns here, the client will
        -- receive a row that is locked but still shows `is_locked` as
        -- `False` and `locked_at` as `None`.
        row_to_return.is_locked := TRUE;
        -- BUGFIX: previously only `is_locked` was refreshed, so callers saw
        -- a stale NULL `locked_at`. NOW() is constant within a transaction,
        -- so this matches the value written by the UPDATE above.
        row_to_return.locked_at := NOW();
        RETURN NEXT row_to_return;
    END LOOP;

    RETURN;
END;
$$ LANGUAGE plpgsql

11 changes: 10 additions & 1 deletion task_processor/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing
import uuid
from datetime import datetime
from datetime import datetime, timedelta

import simplejson as json
from django.core.serializers.json import DjangoJSONEncoder
Expand Down Expand Up @@ -30,6 +30,8 @@ class AbstractBaseTask(models.Model):
serialized_kwargs = models.TextField(blank=True, null=True)
is_locked = models.BooleanField(default=False)

locked_at = models.DateTimeField(blank=True, null=True)

class Meta:
abstract = True

Expand Down Expand Up @@ -61,6 +63,7 @@ def mark_success(self):

def unlock(self):
self.is_locked = False
self.locked_at = None

def run(self):
return self.callable(*self.args, **self.kwargs)
Expand All @@ -80,6 +83,8 @@ def callable(self) -> typing.Callable:
class Task(AbstractBaseTask):
scheduled_for = models.DateTimeField(blank=True, null=True, default=timezone.now)

timeout = models.DurationField(default=timedelta(minutes=1))

# denormalise failures and completion so that we can use select_for_update
num_failures = models.IntegerField(default=0)
completed = models.BooleanField(default=False)
Expand Down Expand Up @@ -109,6 +114,7 @@ def create(
*,
args: typing.Tuple[typing.Any] = None,
kwargs: typing.Dict[str, typing.Any] = None,
timeout: timedelta | None = timedelta(seconds=60),
) -> "Task":
if queue_size and cls._is_queue_full(task_identifier, queue_size):
raise TaskQueueFullError(
Expand All @@ -121,6 +127,7 @@ def create(
priority=priority,
serialized_args=cls.serialize_data(args or tuple()),
serialized_kwargs=cls.serialize_data(kwargs or dict()),
timeout=timeout,
)

@classmethod
Expand All @@ -147,6 +154,8 @@ class RecurringTask(AbstractBaseTask):
run_every = models.DurationField()
first_run_time = models.TimeField(blank=True, null=True)

timeout = models.DurationField(default=timedelta(minutes=30))

objects = RecurringTaskManager()

class Meta:
Expand Down
23 changes: 17 additions & 6 deletions task_processor/processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import traceback
import typing
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta

from django.utils import timezone
Expand Down Expand Up @@ -36,7 +37,8 @@ def run_tasks(num_tasks: int = 1) -> typing.List[TaskRun]:

if executed_tasks:
Task.objects.bulk_update(
executed_tasks, fields=["completed", "num_failures", "is_locked"]
executed_tasks,
fields=["completed", "num_failures", "is_locked", "locked_at"],
)

if task_runs:
Expand Down Expand Up @@ -78,7 +80,7 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]:

# update all tasks that were not deleted
to_update = [task for task in tasks if task.id]
RecurringTask.objects.bulk_update(to_update, fields=["is_locked"])
RecurringTask.objects.bulk_update(to_update, fields=["is_locked", "locked_at"])

if task_runs:
RecurringTaskRun.objects.bulk_create(task_runs)
Expand All @@ -93,16 +95,25 @@ def _run_task(task: typing.Union[Task, RecurringTask]) -> typing.Tuple[Task, Tas
task_run = task.task_runs.model(started_at=timezone.now(), task=task)

try:
task.run()
task_run.result = TaskResult.SUCCESS
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(task.run)
timeout = task.timeout.total_seconds() if task.timeout else None
future.result(timeout=timeout) # Wait for completion or timeout

task_run.result = TaskResult.SUCCESS
task_run.finished_at = timezone.now()
task.mark_success()

except Exception as e:
# For errors that don't include a default message (e.g., TimeoutError),
matthewelwell marked this conversation as resolved.
Show resolved Hide resolved
# fall back to using repr.
err_msg = str(e) or repr(e)

logger.error(
"Failed to execute task '%s'. Exception was: %s",
"Failed to execute task '%s', with id %d. Exception: %s",
task.task_identifier,
str(e),
task.id,
err_msg,
exc_info=True,
)
logger.debug("args: %s", str(task.args))
Expand Down
Loading
Loading