diff --git a/taskiq/cli/scheduler/run.py b/taskiq/cli/scheduler/run.py index 1e35f01..8001965 100644 --- a/taskiq/cli/scheduler/run.py +++ b/taskiq/cli/scheduler/run.py @@ -32,45 +32,45 @@ def to_tz_aware(time: datetime) -> datetime: return time -async def schedules_updater( +async def get_schedules(source: ScheduleSource) -> List[ScheduledTask]: + """ + Get schedules from source. + + If source raises an exception, it will be + logged and an empty list will be returned. + + :param source: source to get schedules from. + """ + try: + return await source.get_schedules() + except Exception as exc: + logger.warning( + "Cannot update schedules with source: %s", + source, + ) + logger.debug(exc, exc_info=True) + return [] + + +async def get_all_schedules( scheduler: TaskiqScheduler, - current_schedules: Dict[ScheduleSource, List[ScheduledTask]], - event: asyncio.Event, -) -> None: +) -> Dict[ScheduleSource, List[ScheduledTask]]: """ - Periodic update to schedules. + Task to update all schedules. - This task periodically checks for new schedules, - assembles the final list and replaces current - schedule with a new one. + This function updates all schedules + from all sources and returns a dict + with source as a key and list of + scheduled tasks as a value. :param scheduler: current scheduler. - :param current_schedules: list of schedules. - :param event: event when schedules are updated. + :return: dict with source as a key and list of scheduled tasks as a value. """ - while True: - logger.debug("Started schedule update.") - new_schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {} - for source in scheduler.sources: - try: - schedules = await source.get_schedules() - except Exception as exc: - logger.warning( - "Cannot update schedules with source: %s", - source, - ) - logger.debug(exc, exc_info=True) - continue - - new_schedules[source] = scheduler.merge_func( - new_schedules.get(source) or [], - schedules, - ) - - current_schedules.clear() - current_schedules.update(new_schedules) - event.set() - await asyncio.sleep(scheduler.refresh_delay) + logger.debug("Started schedule update.") + schedules = await asyncio.gather( + *[get_schedules(source) for source in scheduler.sources], + ) + return dict(zip(scheduler.sources, schedules)) def get_task_delay(task: ScheduledTask) -> Optional[int]: @@ -141,23 +141,14 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None: :param scheduler: current scheduler. """ loop = asyncio.get_event_loop() - schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {} - - current_task = asyncio.current_task() - first_update_event = asyncio.Event() - updater_task = loop.create_task( - schedules_updater( - scheduler, - schedules, - first_update_event, - ), - ) - if current_task is not None: - current_task.add_done_callback(lambda _: updater_task.cancel()) - await first_update_event.wait() running_schedules = set() while True: - for source, task_list in schedules.items(): + # We use this method to correctly sleep for one minute. + next_minute = datetime.now().replace(second=0, microsecond=0) + timedelta( + minutes=1, + ) + scheduled_tasks = await get_all_schedules(scheduler) + for source, task_list in scheduled_tasks.items(): for task in task_list: try: task_delay = get_task_delay(task) @@ -175,11 +166,7 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None: running_schedules.add(send_task) send_task.add_done_callback(running_schedules.discard) - delay = ( - datetime.now().replace(second=1, microsecond=0) - + timedelta(minutes=1) - - datetime.now() - ) + delay = next_minute - datetime.now() await asyncio.sleep(delay.total_seconds()) diff --git a/taskiq/scheduler/scheduler.py b/taskiq/scheduler/scheduler.py index 3087c7a..7de51c2 100644 --- a/taskiq/scheduler/scheduler.py +++ b/taskiq/scheduler/scheduler.py @@ -1,7 +1,6 @@ -from typing import TYPE_CHECKING, Callable, List +from typing import TYPE_CHECKING, List from taskiq.kicker import AsyncKicker -from taskiq.scheduler.merge_functions import only_new from taskiq.scheduler.scheduled_task import ScheduledTask from taskiq.utils import maybe_awaitable @@ -17,16 +16,9 @@ def __init__( self, broker: "AsyncBroker", sources: List["ScheduleSource"], - merge_func: Callable[ - [List["ScheduledTask"], List["ScheduledTask"]], - List["ScheduledTask"], - ] = only_new, - refresh_delay: float = 30.0, ) -> None: # pragma: no cover self.broker = broker self.sources = sources - self.refresh_delay = refresh_delay - self.merge_func = merge_func async def startup(self) -> None: # pragma: no cover """ diff --git a/tests/cli/scheduler/test_updater.py b/tests/cli/scheduler/test_updater.py new file mode 100644 index 0000000..c2a7b9e --- /dev/null +++ b/tests/cli/scheduler/test_updater.py @@ -0,0 +1,87 @@ +from datetime import datetime +from typing import List, Union + +import pytest + +from taskiq import InMemoryBroker, ScheduleSource +from taskiq.cli.scheduler.run import get_all_schedules +from taskiq.scheduler.scheduled_task import ScheduledTask +from taskiq.scheduler.scheduler import TaskiqScheduler + + +class DummySource(ScheduleSource): + def __init__(self, schedules: Union[Exception, List[ScheduledTask]]) -> None: + self.schedules = schedules + + async def get_schedules(self) -> List[ScheduledTask]: + """Return test schedules, or raise an exception.""" + if isinstance(self.schedules, Exception): + raise self.schedules + return self.schedules + + +@pytest.mark.anyio +async def test_get_schedules_success() -> None: + """Tests that schedules are returned correctly.""" + schedules1 = [ + ScheduledTask( + task_name="a", + labels={}, + args=[], + kwargs={}, + time=datetime.now(), + ), + ScheduledTask( + task_name="b", + labels={}, + args=[], + kwargs={}, + time=datetime.now(), + ), + ] + schedules2 = [ + ScheduledTask( + task_name="c", + labels={}, + args=[], + kwargs={}, + time=datetime.now(), + ), + ] + sources: List[ScheduleSource] = [ + DummySource(schedules1), + DummySource(schedules2), + ] + + schedules = await get_all_schedules( + TaskiqScheduler(InMemoryBroker(), sources), + ) + assert schedules == { + sources[0]: schedules1, + sources[1]: schedules2, + } + + +@pytest.mark.anyio +async def test_get_schedules_error() -> None: + """Tests that if source returned an error, empty list will be returned.""" + source1 = DummySource( + [ + ScheduledTask( + task_name="a", + labels={}, + args=[], + kwargs={}, + time=datetime.now(), + ), + ], + ) + source2 = DummySource(Exception("test")) + + schedules = await get_all_schedules( + TaskiqScheduler(InMemoryBroker(), [source1, source2]), + ) + assert schedules == { + source1: source1.schedules, + source2: [], + }