Skip to content

Commit

Permalink
Simplified scheduler. (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius authored Oct 17, 2023
1 parent 76d4846 commit b223097
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 62 deletions.
93 changes: 40 additions & 53 deletions taskiq/cli/scheduler/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,45 +32,45 @@ def to_tz_aware(time: datetime) -> datetime:
return time


async def schedules_updater(
async def get_schedules(source: ScheduleSource) -> List[ScheduledTask]:
    """
    Fetch the scheduled tasks of a single source.

    A failing source never propagates its error to the caller:
    the failure is logged and the source is treated as if it
    had no schedules at all.

    :param source: source to get schedules from.
    :return: schedules of the source, or an empty list on failure.
    """
    try:
        fetched = await source.get_schedules()
    except Exception as exc:
        # The warning stays short; the traceback is only emitted at debug level.
        logger.warning("Cannot update schedules with source: %s", source)
        logger.debug(exc, exc_info=True)
        return []
    return fetched


async def get_all_schedules(
scheduler: TaskiqScheduler,
current_schedules: Dict[ScheduleSource, List[ScheduledTask]],
event: asyncio.Event,
) -> None:
) -> Dict[ScheduleSource, List[ScheduledTask]]:
"""
Periodic update to schedules.
Task to update all schedules.
This task periodically checks for new schedules,
assembles the final list and replaces current
schedule with a new one.
This function updates all schedules
from all sources and returns a dict
with source as a key and list of
scheduled tasks as a value.
:param scheduler: current scheduler.
:param current_schedules: list of schedules.
:param event: event when schedules are updated.
:return: dict with source as a key and list of scheduled tasks as a value.
"""
while True:
logger.debug("Started schedule update.")
new_schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {}
for source in scheduler.sources:
try:
schedules = await source.get_schedules()
except Exception as exc:
logger.warning(
"Cannot update schedules with source: %s",
source,
)
logger.debug(exc, exc_info=True)
continue

new_schedules[source] = scheduler.merge_func(
new_schedules.get(source) or [],
schedules,
)

current_schedules.clear()
current_schedules.update(new_schedules)
event.set()
await asyncio.sleep(scheduler.refresh_delay)
logger.debug("Started schedule update.")
schedules = await asyncio.gather(
*[get_schedules(source) for source in scheduler.sources],
)
return dict(zip(scheduler.sources, schedules))


def get_task_delay(task: ScheduledTask) -> Optional[int]:
Expand Down Expand Up @@ -141,23 +141,14 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
:param scheduler: current scheduler.
"""
loop = asyncio.get_event_loop()
schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {}

current_task = asyncio.current_task()
first_update_event = asyncio.Event()
updater_task = loop.create_task(
schedules_updater(
scheduler,
schedules,
first_update_event,
),
)
if current_task is not None:
current_task.add_done_callback(lambda _: updater_task.cancel())
await first_update_event.wait()
running_schedules = set()
while True:
for source, task_list in schedules.items():
# We use this method to correctly sleep for one minute.
next_minute = datetime.now().replace(second=0, microsecond=0) + timedelta(
minutes=1,
)
scheduled_tasks = await get_all_schedules(scheduler)
for source, task_list in scheduled_tasks.items():
for task in task_list:
try:
task_delay = get_task_delay(task)
Expand All @@ -175,11 +166,7 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
running_schedules.add(send_task)
send_task.add_done_callback(running_schedules.discard)

delay = (
datetime.now().replace(second=1, microsecond=0)
+ timedelta(minutes=1)
- datetime.now()
)
delay = next_minute - datetime.now()
await asyncio.sleep(delay.total_seconds())


Expand Down
10 changes: 1 addition & 9 deletions taskiq/scheduler/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Callable, List
from typing import TYPE_CHECKING, List

from taskiq.kicker import AsyncKicker
from taskiq.scheduler.merge_functions import only_new
from taskiq.scheduler.scheduled_task import ScheduledTask
from taskiq.utils import maybe_awaitable

Expand All @@ -17,16 +16,9 @@ def __init__(
self,
broker: "AsyncBroker",
sources: List["ScheduleSource"],
merge_func: Callable[
[List["ScheduledTask"], List["ScheduledTask"]],
List["ScheduledTask"],
] = only_new,
refresh_delay: float = 30.0,
) -> None: # pragma: no cover
self.broker = broker
self.sources = sources
self.refresh_delay = refresh_delay
self.merge_func = merge_func

async def startup(self) -> None: # pragma: no cover
"""
Expand Down
87 changes: 87 additions & 0 deletions tests/cli/scheduler/test_updater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from datetime import datetime
from typing import List, Union

import pytest

from taskiq import InMemoryBroker, ScheduleSource
from taskiq.cli.scheduler.run import get_all_schedules
from taskiq.scheduler.scheduled_task import ScheduledTask
from taskiq.scheduler.scheduler import TaskiqScheduler


class DummySource(ScheduleSource):
    """Schedule source stub that returns fixed schedules or raises a fixed error."""

    def __init__(self, schedules: Union[Exception, List[ScheduledTask]]) -> None:
        # Either the tasks to hand back, or the exception to raise on fetch.
        self.schedules = schedules

    async def get_schedules(self) -> List[ScheduledTask]:
        """Return the configured schedules, or raise the configured exception."""
        outcome = self.schedules
        if not isinstance(outcome, Exception):
            return outcome
        raise outcome


@pytest.mark.anyio
async def test_get_schedules_success() -> None:
    """Check that every source is mapped to exactly the schedules it provides."""

    def make_task(name: str) -> ScheduledTask:
        # All tasks share the same shape; only the task name differs.
        return ScheduledTask(
            task_name=name,
            labels={},
            args=[],
            kwargs={},
            time=datetime.now(),
        )

    first_batch = [make_task("a"), make_task("b")]
    second_batch = [make_task("c")]
    first_source = DummySource(first_batch)
    second_source = DummySource(second_batch)
    scheduler = TaskiqScheduler(InMemoryBroker(), [first_source, second_source])

    collected = await get_all_schedules(scheduler)

    assert collected == {
        first_source: first_batch,
        second_source: second_batch,
    }


@pytest.mark.anyio
async def test_get_schedules_error() -> None:
    """Check that a failing source yields an empty list next to a healthy one."""
    healthy_tasks = [
        ScheduledTask(
            task_name="a",
            labels={},
            args=[],
            kwargs={},
            time=datetime.now(),
        ),
    ]
    healthy = DummySource(healthy_tasks)
    # This source raises on every fetch instead of returning tasks.
    broken = DummySource(Exception("test"))
    scheduler = TaskiqScheduler(InMemoryBroker(), [healthy, broken])

    collected = await get_all_schedules(scheduler)

    assert collected == {
        healthy: healthy.schedules,
        broken: [],
    }

0 comments on commit b223097

Please sign in to comment.