Skip to content

Commit b223097

Browse files
authored
Simplified scheduler. (#221)
1 parent 76d4846 commit b223097

File tree

3 files changed

+128
-62
lines changed

3 files changed

+128
-62
lines changed

taskiq/cli/scheduler/run.py

Lines changed: 40 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -32,45 +32,45 @@ def to_tz_aware(time: datetime) -> datetime:
3232
return time
3333

3434

35-
async def schedules_updater(
35+
async def get_schedules(source: ScheduleSource) -> List[ScheduledTask]:
36+
"""
37+
Get schedules from source.
38+
39+
If source raises an exception, it will be
40+
logged and an empty list will be returned.
41+
42+
:param source: source to get schedules from.
43+
"""
44+
try:
45+
return await source.get_schedules()
46+
except Exception as exc:
47+
logger.warning(
48+
"Cannot update schedules with source: %s",
49+
source,
50+
)
51+
logger.debug(exc, exc_info=True)
52+
return []
53+
54+
55+
async def get_all_schedules(
3656
scheduler: TaskiqScheduler,
37-
current_schedules: Dict[ScheduleSource, List[ScheduledTask]],
38-
event: asyncio.Event,
39-
) -> None:
57+
) -> Dict[ScheduleSource, List[ScheduledTask]]:
4058
"""
41-
Periodic update to schedules.
59+
Task to update all schedules.
4260
43-
This task periodically checks for new schedules,
44-
assembles the final list and replaces current
45-
schedule with a new one.
61+
This function updates all schedules
62+
from all sources and returns a dict
63+
with source as a key and list of
64+
scheduled tasks as a value.
4665
4766
:param scheduler: current scheduler.
48-
:param current_schedules: list of schedules.
49-
:param event: event when schedules are updated.
67+
:return: dict with source as a key and list of scheduled tasks as a value.
5068
"""
51-
while True:
52-
logger.debug("Started schedule update.")
53-
new_schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {}
54-
for source in scheduler.sources:
55-
try:
56-
schedules = await source.get_schedules()
57-
except Exception as exc:
58-
logger.warning(
59-
"Cannot update schedules with source: %s",
60-
source,
61-
)
62-
logger.debug(exc, exc_info=True)
63-
continue
64-
65-
new_schedules[source] = scheduler.merge_func(
66-
new_schedules.get(source) or [],
67-
schedules,
68-
)
69-
70-
current_schedules.clear()
71-
current_schedules.update(new_schedules)
72-
event.set()
73-
await asyncio.sleep(scheduler.refresh_delay)
69+
logger.debug("Started schedule update.")
70+
schedules = await asyncio.gather(
71+
*[get_schedules(source) for source in scheduler.sources],
72+
)
73+
return dict(zip(scheduler.sources, schedules))
7474

7575

7676
def get_task_delay(task: ScheduledTask) -> Optional[int]:
@@ -141,23 +141,14 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
141141
:param scheduler: current scheduler.
142142
"""
143143
loop = asyncio.get_event_loop()
144-
schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {}
145-
146-
current_task = asyncio.current_task()
147-
first_update_event = asyncio.Event()
148-
updater_task = loop.create_task(
149-
schedules_updater(
150-
scheduler,
151-
schedules,
152-
first_update_event,
153-
),
154-
)
155-
if current_task is not None:
156-
current_task.add_done_callback(lambda _: updater_task.cancel())
157-
await first_update_event.wait()
158144
running_schedules = set()
159145
while True:
160-
for source, task_list in schedules.items():
146+
# We use this method to correctly sleep for one minute.
147+
next_minute = datetime.now().replace(second=0, microsecond=0) + timedelta(
148+
minutes=1,
149+
)
150+
scheduled_tasks = await get_all_schedules(scheduler)
151+
for source, task_list in scheduled_tasks.items():
161152
for task in task_list:
162153
try:
163154
task_delay = get_task_delay(task)
@@ -175,11 +166,7 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
175166
running_schedules.add(send_task)
176167
send_task.add_done_callback(running_schedules.discard)
177168

178-
delay = (
179-
datetime.now().replace(second=1, microsecond=0)
180-
+ timedelta(minutes=1)
181-
- datetime.now()
182-
)
169+
delay = next_minute - datetime.now()
183170
await asyncio.sleep(delay.total_seconds())
184171

185172

taskiq/scheduler/scheduler.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from typing import TYPE_CHECKING, Callable, List
1+
from typing import TYPE_CHECKING, List
22

33
from taskiq.kicker import AsyncKicker
4-
from taskiq.scheduler.merge_functions import only_new
54
from taskiq.scheduler.scheduled_task import ScheduledTask
65
from taskiq.utils import maybe_awaitable
76

@@ -17,16 +16,9 @@ def __init__(
1716
self,
1817
broker: "AsyncBroker",
1918
sources: List["ScheduleSource"],
20-
merge_func: Callable[
21-
[List["ScheduledTask"], List["ScheduledTask"]],
22-
List["ScheduledTask"],
23-
] = only_new,
24-
refresh_delay: float = 30.0,
2519
) -> None: # pragma: no cover
2620
self.broker = broker
2721
self.sources = sources
28-
self.refresh_delay = refresh_delay
29-
self.merge_func = merge_func
3022

3123
async def startup(self) -> None: # pragma: no cover
3224
"""

tests/cli/scheduler/test_updater.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from datetime import datetime
2+
from typing import List, Union
3+
4+
import pytest
5+
6+
from taskiq import InMemoryBroker, ScheduleSource
7+
from taskiq.cli.scheduler.run import get_all_schedules
8+
from taskiq.scheduler.scheduled_task import ScheduledTask
9+
from taskiq.scheduler.scheduler import TaskiqScheduler
10+
11+
12+
class DummySource(ScheduleSource):
13+
def __init__(self, schedules: Union[Exception, List[ScheduledTask]]) -> None:
14+
self.schedules = schedules
15+
16+
async def get_schedules(self) -> List[ScheduledTask]:
17+
"""Return test schedules, or raise an exception."""
18+
if isinstance(self.schedules, Exception):
19+
raise self.schedules
20+
return self.schedules
21+
22+
23+
@pytest.mark.anyio
24+
async def test_get_schedules_success() -> None:
25+
"""Tests that schedules are returned correctly."""
26+
schedules1 = [
27+
ScheduledTask(
28+
task_name="a",
29+
labels={},
30+
args=[],
31+
kwargs={},
32+
time=datetime.now(),
33+
),
34+
ScheduledTask(
35+
task_name="b",
36+
labels={},
37+
args=[],
38+
kwargs={},
39+
time=datetime.now(),
40+
),
41+
]
42+
schedules2 = [
43+
ScheduledTask(
44+
task_name="c",
45+
labels={},
46+
args=[],
47+
kwargs={},
48+
time=datetime.now(),
49+
),
50+
]
51+
sources: List[ScheduleSource] = [
52+
DummySource(schedules1),
53+
DummySource(schedules2),
54+
]
55+
56+
schedules = await get_all_schedules(
57+
TaskiqScheduler(InMemoryBroker(), sources),
58+
)
59+
assert schedules == {
60+
sources[0]: schedules1,
61+
sources[1]: schedules2,
62+
}
63+
64+
65+
@pytest.mark.anyio
66+
async def test_get_schedules_error() -> None:
67+
"""Tests that if source returned an error, empty list will be returned."""
68+
source1 = DummySource(
69+
[
70+
ScheduledTask(
71+
task_name="a",
72+
labels={},
73+
args=[],
74+
kwargs={},
75+
time=datetime.now(),
76+
),
77+
],
78+
)
79+
source2 = DummySource(Exception("test"))
80+
81+
schedules = await get_all_schedules(
82+
TaskiqScheduler(InMemoryBroker(), [source1, source2]),
83+
)
84+
assert schedules == {
85+
source1: source1.schedules,
86+
source2: [],
87+
}

0 commit comments

Comments
 (0)