Skip to content

Commit

Permalink
Removed source from the TaskiqSchedule. (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius authored Oct 15, 2023
1 parent 43af7a9 commit 02c2817
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 48 deletions.
4 changes: 0 additions & 4 deletions docs/examples/extending/schedule_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ async def get_schedules(self) -> List["ScheduledTask"]:
args=[],
kwargs={},
cron="* * * * *",
#
# We need point on self source for calling pre_send / post_send when
# task is ready to be enqueued.
source=self,
),
]

Expand Down
5 changes: 4 additions & 1 deletion taskiq/abc/schedule_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def shutdown(self) -> None: # noqa: B027
async def get_schedules(self) -> List["ScheduledTask"]:
"""Get list of taskiq schedules."""

async def add_schedule(self, schedule: "ScheduledTask") -> None: # noqa: B027
async def add_schedule(self, schedule: "ScheduledTask") -> None:
"""
Add a new schedule.
Expand All @@ -33,6 +33,9 @@ async def add_schedule(self, schedule: "ScheduledTask") -> None: # noqa: B027
:param schedule: schedule to add.
"""
raise NotImplementedError(
f"The source {self.__class__.__name__} does not support adding schedules.",
)

def pre_send( # noqa: B027
self,
Expand Down
47 changes: 27 additions & 20 deletions taskiq/cli/scheduler/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import sys
from datetime import datetime, timedelta
from logging import basicConfig, getLevelName, getLogger
from typing import List, Optional
from typing import Dict, List, Optional

import pytz
from pycron import is_now

from taskiq.abc.schedule_source import ScheduleSource
from taskiq.cli.scheduler.args import SchedulerArgs
from taskiq.cli.utils import import_object, import_tasks
from taskiq.scheduler.scheduler import ScheduledTask, TaskiqScheduler
Expand All @@ -32,7 +33,7 @@ def to_tz_aware(time: datetime) -> datetime:

async def schedules_updater(
scheduler: TaskiqScheduler,
current_schedules: List[ScheduledTask],
current_schedules: Dict[ScheduleSource, List[ScheduledTask]],
event: asyncio.Event,
) -> None:
"""
Expand All @@ -48,7 +49,7 @@ async def schedules_updater(
"""
while True:
logger.debug("Started schedule update.")
new_schedules: "List[ScheduledTask]" = []
new_schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {}
for source in scheduler.sources:
try:
schedules = await source.get_schedules()
Expand All @@ -60,10 +61,13 @@ async def schedules_updater(
logger.debug(exc, exc_info=True)
continue

new_schedules = scheduler.merge_func(new_schedules, schedules)
new_schedules[source] = scheduler.merge_func(
new_schedules.get(source) or [],
schedules,
)

current_schedules.clear()
current_schedules.extend(new_schedules)
current_schedules.update(new_schedules)
event.set()
await asyncio.sleep(scheduler.refresh_delay)

Expand Down Expand Up @@ -100,6 +104,7 @@ def get_task_delay(task: ScheduledTask) -> Optional[int]:

async def delayed_send(
scheduler: TaskiqScheduler,
source: ScheduleSource,
task: ScheduledTask,
delay: int,
) -> None:
Expand All @@ -115,13 +120,14 @@ async def delayed_send(
the delay and send the task after some delay.
:param scheduler: current scheduler.
:param source: source of the task.
:param task: task to send.
:param delay: how long to wait.
"""
if delay > 0:
await asyncio.sleep(delay)
logger.info("Sending task %s.", task.task_name)
await scheduler.on_ready(task)
await scheduler.on_ready(source, task)


async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
Expand All @@ -134,33 +140,34 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
:param scheduler: current scheduler.
"""
loop = asyncio.get_event_loop()
tasks: "List[ScheduledTask]" = []
schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {}

current_task = asyncio.current_task()
first_update_event = asyncio.Event()
updater_task = loop.create_task(
schedules_updater(
scheduler,
tasks,
schedules,
first_update_event,
),
)
if current_task is not None:
current_task.add_done_callback(lambda _: updater_task.cancel())
await first_update_event.wait()
while True:
for task in tasks:
try:
task_delay = get_task_delay(task)
except ValueError:
logger.warning(
"Cannot parse cron: %s for task: %s",
task.cron,
task.task_name,
)
continue
if task_delay is not None:
loop.create_task(delayed_send(scheduler, task, task_delay))
for source, task_list in schedules.items():
for task in task_list:
try:
task_delay = get_task_delay(task)
except ValueError:
logger.warning(
"Cannot parse cron: %s for task: %s",
task.cron,
task.task_name,
)
continue
if task_delay is not None:
loop.create_task(delayed_send(scheduler, source, task, task_delay))

delay = (
datetime.now().replace(second=1, microsecond=0)
Expand Down
1 change: 0 additions & 1 deletion taskiq/schedule_sources/label_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ async def get_schedules(self) -> List["ScheduledTask"]:
cron=schedule.get("cron"),
time=schedule.get("time"),
cron_offset=schedule.get("cron_offset"),
source=self,
),
)
return schedules
Expand Down
17 changes: 16 additions & 1 deletion taskiq/scheduler/merge_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from typing import TYPE_CHECKING, List

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -34,8 +35,22 @@ def only_unique(
:param new_tasks: newly discovered tasks.
:return: list of unique schedules.
"""
result = old_tasks
result = copy.copy(old_tasks)
for task in new_tasks:
if task not in result:
result.append(task)
return result


def only_new(
_old_tasks: List["ScheduledTask"],
new_tasks: List["ScheduledTask"],
) -> List["ScheduledTask"]:
"""
This function preserves only new schedules.
:param old_tasks: previously discovered tasks.
:param new_tasks: newly discovered schedules.
:return: list of new schedules.
"""
return new_tasks
11 changes: 5 additions & 6 deletions taskiq/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from taskiq.abc.broker import AsyncBroker
from taskiq.kicker import AsyncKicker
from taskiq.scheduler.merge_functions import preserve_all
from taskiq.scheduler.merge_functions import only_new
from taskiq.utils import maybe_awaitable

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -19,7 +19,6 @@ class ScheduledTask:
labels: Dict[str, Any]
args: List[Any]
kwargs: Dict[str, Any]
source: "ScheduleSource" # Backward point to source which this task belongs to
cron: Optional[str] = field(default=None)
cron_offset: Optional[Union[str, timedelta]] = field(default=None)
time: Optional[datetime] = field(default=None)
Expand All @@ -44,7 +43,7 @@ def __init__(
merge_func: Callable[
[List["ScheduledTask"], List["ScheduledTask"]],
List["ScheduledTask"],
] = preserve_all,
] = only_new,
refresh_delay: float = 30.0,
) -> None: # pragma: no cover
self.broker = broker
Expand All @@ -61,19 +60,19 @@ async def startup(self) -> None: # pragma: no cover
"""
await self.broker.startup()

async def on_ready(self, task: ScheduledTask) -> None:
async def on_ready(self, source: "ScheduleSource", task: ScheduledTask) -> None:
"""
This method is called when task is ready to be enqueued.
It's triggered on proper time depending on `task.cron` or `task.time` attribute.
:param task: task to send
"""
await maybe_awaitable(task.source.pre_send(task))
await maybe_awaitable(source.pre_send(task))
await AsyncKicker(task.task_name, self.broker, task.labels).kiq(
*task.args,
**task.kwargs,
)
await maybe_awaitable(task.source.post_send(task))
await maybe_awaitable(source.post_send(task))

async def shutdown(self) -> None:
"""Shutdown the scheduler process."""
Expand Down
11 changes: 0 additions & 11 deletions tests/cli/scheduler/test_task_delays.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
from tzlocal import get_localzone

from taskiq.cli.scheduler.run import get_task_delay
from taskiq.schedule_sources.label_based import LabelScheduleSource
from taskiq.scheduler.scheduler import ScheduledTask

DUMMY_SOURCE = LabelScheduleSource(broker=None) # type: ignore


def test_should_run_success() -> None:
hour = datetime.datetime.utcnow().hour
Expand All @@ -19,7 +16,6 @@ def test_should_run_success() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
cron=f"* {hour} * * *",
),
)
Expand All @@ -35,7 +31,6 @@ def test_should_run_cron_str_offset() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
cron=f"* {hour} * * *",
cron_offset=str(zone),
),
Expand All @@ -52,7 +47,6 @@ def test_should_run_cron_td_offset() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
cron=f"* {hour} * * *",
cron_offset=datetime.timedelta(hours=offset),
),
Expand All @@ -68,7 +62,6 @@ def test_time_utc_without_zone() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
time=time - datetime.timedelta(seconds=1),
),
)
Expand All @@ -83,7 +76,6 @@ def test_time_utc_with_zone() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
time=time - datetime.timedelta(seconds=1),
),
)
Expand All @@ -99,7 +91,6 @@ def test_time_utc_with_local_zone() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
time=time - datetime.timedelta(seconds=1),
),
)
Expand All @@ -114,7 +105,6 @@ def test_time_localtime_without_zone() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
time=time - datetime.timedelta(seconds=1),
),
)
Expand All @@ -130,7 +120,6 @@ def test_time_delay() -> None:
labels={},
args=[],
kwargs={},
source=DUMMY_SOURCE,
time=time,
),
)
Expand Down
1 change: 0 additions & 1 deletion tests/schedule_sources/test_label_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def task() -> None:
labels={"schedule": schedule_label},
args=[],
kwargs={},
source=source,
),
]

Expand Down
4 changes: 1 addition & 3 deletions tests/scheduler/test_label_based_sched.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None:
def task() -> None:
pass

source = LabelScheduleSource(broker)
schedules = await source.get_schedules()
schedules = await LabelScheduleSource(broker).get_schedules()
assert schedules == [
ScheduledTask(
cron=schedule_label[0].get("cron"),
Expand All @@ -40,7 +39,6 @@ def task() -> None:
labels={"schedule": schedule_label},
args=[],
kwargs={},
source=source,
),
]

Expand Down

0 comments on commit 02c2817

Please sign in to comment.