Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed source from the TaskiqSchedule. #218

Merged
merged 4 commits into from
Oct 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions docs/examples/extending/schedule_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ async def get_schedules(self) -> List["ScheduledTask"]:
args=[],
kwargs={},
cron="* * * * *",
#
# We need point on self source for calling pre_send / post_send when
# task is ready to be enqueued.
source=self,
),
]

Expand Down
5 changes: 4 additions & 1 deletion taskiq/abc/schedule_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def shutdown(self) -> None: # noqa: B027
async def get_schedules(self) -> List["ScheduledTask"]:
"""Get list of taskiq schedules."""

async def add_schedule(self, schedule: "ScheduledTask") -> None: # noqa: B027
async def add_schedule(self, schedule: "ScheduledTask") -> None:
"""
Add a new schedule.

Expand All @@ -33,6 +33,9 @@ async def add_schedule(self, schedule: "ScheduledTask") -> None: # noqa: B027

:param schedule: schedule to add.
"""
raise NotImplementedError(
f"The source {self.__class__.__name__} does not support adding schedules.",
)

def pre_send( # noqa: B027
self,
Expand Down
47 changes: 27 additions & 20 deletions taskiq/cli/scheduler/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import sys
from datetime import datetime, timedelta
from logging import basicConfig, getLevelName, getLogger
from typing import List, Optional
from typing import Dict, List, Optional

import pytz
from pycron import is_now

from taskiq.abc.schedule_source import ScheduleSource
from taskiq.cli.scheduler.args import SchedulerArgs
from taskiq.cli.utils import import_object, import_tasks
from taskiq.scheduler.scheduler import ScheduledTask, TaskiqScheduler
Expand All @@ -32,7 +33,7 @@ def to_tz_aware(time: datetime) -> datetime:

async def schedules_updater(
scheduler: TaskiqScheduler,
current_schedules: List[ScheduledTask],
current_schedules: Dict[ScheduleSource, List[ScheduledTask]],
event: asyncio.Event,
) -> None:
"""
Expand All @@ -48,7 +49,7 @@ async def schedules_updater(
"""
while True:
logger.debug("Started schedule update.")
new_schedules: "List[ScheduledTask]" = []
new_schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {}
for source in scheduler.sources:
try:
schedules = await source.get_schedules()
Expand All @@ -60,10 +61,13 @@ async def schedules_updater(
logger.debug(exc, exc_info=True)
continue

new_schedules = scheduler.merge_func(new_schedules, schedules)
new_schedules[source] = scheduler.merge_func(
new_schedules.get(source) or [],
schedules,
)

current_schedules.clear()
current_schedules.extend(new_schedules)
current_schedules.update(new_schedules)
event.set()
await asyncio.sleep(scheduler.refresh_delay)

Expand Down Expand Up @@ -100,6 +104,7 @@ def get_task_delay(task: ScheduledTask) -> Optional[int]:

async def delayed_send(
scheduler: TaskiqScheduler,
source: ScheduleSource,
task: ScheduledTask,
delay: int,
) -> None:
Expand All @@ -115,13 +120,14 @@ async def delayed_send(
the delay and send the task after some delay.

:param scheduler: current scheduler.
:param source: source of the task.
:param task: task to send.
:param delay: how long to wait.
"""
if delay > 0:
await asyncio.sleep(delay)
logger.info("Sending task %s.", task.task_name)
await scheduler.on_ready(task)
await scheduler.on_ready(source, task)


async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
Expand All @@ -134,33 +140,34 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
:param scheduler: current scheduler.
"""
loop = asyncio.get_event_loop()
tasks: "List[ScheduledTask]" = []
schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {}

current_task = asyncio.current_task()
first_update_event = asyncio.Event()
updater_task = loop.create_task(
schedules_updater(
scheduler,
tasks,
schedules,
first_update_event,
),
)
if current_task is not None:
current_task.add_done_callback(lambda _: updater_task.cancel())
await first_update_event.wait()
while True:
for task in tasks:
try:
task_delay = get_task_delay(task)
except ValueError:
logger.warning(
"Cannot parse cron: %s for task: %s",
task.cron,
task.task_name,
)
continue
if task_delay is not None:
loop.create_task(delayed_send(scheduler, task, task_delay))
for source, task_list in schedules.items():
for task in task_list:
try:
task_delay = get_task_delay(task)
except ValueError:
logger.warning(
"Cannot parse cron: %s for task: %s",
task.cron,
task.task_name,
)
continue
if task_delay is not None:
loop.create_task(delayed_send(scheduler, source, task, task_delay))

delay = (
datetime.now().replace(second=1, microsecond=0)
Expand Down
1 change: 0 additions & 1 deletion taskiq/schedule_sources/label_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ async def get_schedules(self) -> List["ScheduledTask"]:
cron=schedule.get("cron"),
time=schedule.get("time"),
cron_offset=schedule.get("cron_offset"),
source=self,
),
)
return schedules
Expand Down
17 changes: 16 additions & 1 deletion taskiq/scheduler/merge_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from typing import TYPE_CHECKING, List

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -34,8 +35,22 @@ def only_unique(
:param new_tasks: newly discovered tasks.
:return: list of unique schedules.
"""
result = old_tasks
result = copy.copy(old_tasks)
for task in new_tasks:
if task not in result:
result.append(task)
return result


def only_new(
_old_tasks: List["ScheduledTask"],
new_tasks: List["ScheduledTask"],
) -> List["ScheduledTask"]:
"""
This function preserves only new schedules.

:param old_tasks: previously discovered tasks.
:param new_tasks: newly discovered schedules.
:return: list of new schedules.
"""
return new_tasks
11 changes: 5 additions & 6 deletions taskiq/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from taskiq.abc.broker import AsyncBroker
from taskiq.kicker import AsyncKicker
from taskiq.scheduler.merge_functions import preserve_all
from taskiq.scheduler.merge_functions import only_new
from taskiq.utils import maybe_awaitable

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -19,7 +19,6 @@ class ScheduledTask:
labels: Dict[str, Any]
args: List[Any]
kwargs: Dict[str, Any]
source: "ScheduleSource" # Backward point to source which this task belongs to
cron: Optional[str] = field(default=None)
cron_offset: Optional[Union[str, timedelta]] = field(default=None)
time: Optional[datetime] = field(default=None)
Expand All @@ -44,7 +43,7 @@ def __init__(
merge_func: Callable[
[List["ScheduledTask"], List["ScheduledTask"]],
List["ScheduledTask"],
] = preserve_all,
] = only_new,
refresh_delay: float = 30.0,
) -> None: # pragma: no cover
self.broker = broker
Expand All @@ -61,19 +60,19 @@ async def startup(self) -> None: # pragma: no cover
"""
await self.broker.startup()

async def on_ready(self, task: ScheduledTask) -> None:
async def on_ready(self, source: "ScheduleSource", task: ScheduledTask) -> None:
"""
This method is called when task is ready to be enqueued.

It's triggered on proper time depending on `task.cron` or `task.time` attribute.
:param task: task to send
"""
await maybe_awaitable(task.source.pre_send(task))
await maybe_awaitable(source.pre_send(task))
await AsyncKicker(task.task_name, self.broker, task.labels).kiq(
*task.args,
**task.kwargs,
)
await maybe_awaitable(task.source.post_send(task))
await maybe_awaitable(source.post_send(task))

async def shutdown(self) -> None:
"""Shutdown the scheduler process."""
Expand Down
11 changes: 0 additions & 11 deletions tests/cli/scheduler/test_task_delays.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
from tzlocal import get_localzone

from taskiq.cli.scheduler.run import get_task_delay
from taskiq.schedule_sources.label_based import LabelScheduleSource
from taskiq.scheduler.scheduler import ScheduledTask

DUMMY_SOURCE = LabelScheduleSource(broker=None) # type: ignore


def test_should_run_success() -> None:
hour = datetime.datetime.utcnow().hour
Expand All @@ -19,7 +16,6 @@ def test_should_run_success() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
cron=f"* {hour} * * *",
),
)
Expand All @@ -35,7 +31,6 @@ def test_should_run_cron_str_offset() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
cron=f"* {hour} * * *",
cron_offset=str(zone),
),
Expand All @@ -52,7 +47,6 @@ def test_should_run_cron_td_offset() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
cron=f"* {hour} * * *",
cron_offset=datetime.timedelta(hours=offset),
),
Expand All @@ -68,7 +62,6 @@ def test_time_utc_without_zone() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
time=time - datetime.timedelta(seconds=1),
),
)
Expand All @@ -83,7 +76,6 @@ def test_time_utc_with_zone() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
time=time - datetime.timedelta(seconds=1),
),
)
Expand All @@ -99,7 +91,6 @@ def test_time_utc_with_local_zone() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
time=time - datetime.timedelta(seconds=1),
),
)
Expand All @@ -114,7 +105,6 @@ def test_time_localtime_without_zone() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
time=time - datetime.timedelta(seconds=1),
),
)
Expand All @@ -130,7 +120,6 @@ def test_time_delay() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
time=time,
),
)
Expand Down
1 change: 0 additions & 1 deletion tests/schedule_sources/test_label_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def task() -> None:
labels={"schedule": schedule_label},
args=[],
kwargs={},
source=source,
),
]

Expand Down
4 changes: 1 addition & 3 deletions tests/scheduler/test_label_based_sched.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None:
def task() -> None:
pass

source = LabelScheduleSource(broker)
schedules = await source.get_schedules()
schedules = await LabelScheduleSource(broker).get_schedules()
assert schedules == [
ScheduledTask(
cron=schedule_label[0].get("cron"),
Expand All @@ -40,7 +39,6 @@ def task() -> None:
labels={"schedule": schedule_label},
args=[],
kwargs={},
source=source,
),
]

Expand Down
Loading