#
# This source file is part of the EdgeDB open source project.
#
# Copyright MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations
from dataclasses import dataclass, field
from typing import (
    Awaitable,
    Callable,
    Collection,
    Final,
    Generic,
    Iterable,
    Literal,
    Optional,
    Tuple,
    TypeVar,
)

import abc
import asyncio
import copy
import random

_T = TypeVar('_T')


@dataclass
class Context:
    """Overall parameters which apply to all requests."""

    # The maximum number of rounds in which tasks are attempted. The first
    # attempt counts as a round, so a value of 0 runs nothing.
    max_retry_count: int = 4

    # Information about the service's rate limits
    request_limits: Optional[Limits] = None

    # Whether to jitter the delay time if a retry error is produced
    jitter: bool = True

    # Initial and lower bound for guessing the delay
    guess_delay_min: Final[float] = 0.001

    # The upper bound for delays, both known and guessed
    delay_max: Final[float] = 60.0


@dataclass
class ExecutionReport:
    """Information about the tasks after they are complete"""

    # Number of attempts which produced no result at all
    unknown_error_count: int = 0

    # Messages of errors which could not be retried
    known_error_messages: list[str] = field(default_factory=list)

    # Number of tasks which were still eligible for a retry when
    # execution stopped
    remaining_retries: int = 0

    # The rate limit information as it stood when execution stopped.
    # Can be passed back in via Context.request_limits for later calls.
    updated_limits: Optional[Limits] = None


@dataclass
class Limits:
    """Information about a service's rate limits."""

    # Total limit of a resource per minute for a service.
    # A True value represents a unlimited total
    total: Optional[int | Literal[True]] = None

    # Remaining resources before the limit is hit.
    remaining: Optional[int] = None

    # A guess about the delay in seconds needed between requests.
    # To be used if no other data is available.
    guess_delay: Optional[float] = None

    def update(self, latest: Limits) -> Limits:
        """Update in place based on the latest information.

        Returns self so that calls can be chained.
        """

        # The total will change rarely. Always take the latest value if
        # it exists.
        if latest.total is not None:
            self.total = latest.total

        # The remaining amount can fluctuate quite a bit, take the smallest
        # value available.
        if self.remaining is None:
            self.remaining = latest.remaining
        elif latest.remaining is not None:
            self.remaining = min(self.remaining, latest.remaining)

        if self.total is True and self.remaining is not None:
            # If there is a remaining value, the total is not actually
            # unlimited. This includes a remaining value of 0, which means
            # the quota is exhausted.
            self.total = None

        # Always use the latest guess value if it exists.
        if latest.guess_delay is not None:
            self.guess_delay = latest.guess_delay

        return self


class Task(abc.ABC, Generic[_T]):
    """Represents an async request"""

    params: Params[_T]
    _inner: asyncio.Task

    def __init__(self, params: Params[_T]):
        self.params = params
        # The request is scheduled immediately. This needs a running event
        # loop.
        self._inner = asyncio.create_task(self.run())

    @abc.abstractmethod
    async def run(self) -> Optional[Result[_T]]:
        """Run the task and return a result."""
        raise NotImplementedError

    async def wait_result(self) -> None:
        """Wait for the request to complete."""
        await self._inner

    def get_result(self) -> Optional[Result[_T]]:
        """Get the result of the request.

        Should only be called once the request has completed.
        """
        return self._inner.result()


class Params(abc.ABC, Generic[_T]):
    """The parameters of an async request task.

    These are used to generate tasks. They may be used to generate multiple
    tasks if the task fails and is re-tried.
    """

    @abc.abstractmethod
    def cost(self) -> int:
        """Expected cost to execute the task."""
        raise NotImplementedError

    @abc.abstractmethod
    def create_task(self) -> Task[_T]:
        """Create a task using the parameters."""
        raise NotImplementedError


@dataclass(frozen=True)
class Result(abc.ABC, Generic[_T]):
    """The result of an async request.

    Some tasks may include updated request limit information in their
    response.
    """

    data: _T | Error

    # Some services can return their request limits
    request_limits: Optional[Limits] = None

    def finalize(self) -> None:
        """An optional finalize task to be run sequentially."""
        pass


@dataclass(frozen=True)
class Error:
    """Represents an error from an async request."""

    message: str

    # If there was an error, it may be possible to retry the request
    # Eg. 429 too many requests
    retry: bool


async def execute_requests(
    params: list[Params[_T]],
    *,
    ctx: Context,
) -> ExecutionReport:
    """Run a task for each of the params, retrying where possible.

    Tasks are run in rounds. Tasks whose result is a retryable Error are
    run again in the next round, for at most ctx.max_retry_count rounds.
    Every result produced is finalized, in sequence.

    ctx.request_limits is not modified. The limits as they stood when
    execution stopped are returned in the report's updated_limits.
    """
    report = ExecutionReport()

    # Set up request limits
    if ctx.request_limits is None:
        # If no other information is available, for the first attempt assume
        # there is no limit.
        request_limits = Limits(total=True)

    else:
        request_limits = copy.copy(ctx.request_limits)

    # If any tasks fail and can be retried, retry them up to a maximum number
    # of times.
    retry_count: int = 0
    active_task_indexes: set[int] = set(range(len(params)))

    while active_task_indexes and retry_count < ctx.max_retry_count:
        retry_task_indexes: set[int] = set()

        # Run tasks

        execution_strategy = _choose_execution_strategy(
            params, active_task_indexes, request_limits,
        )

        results: dict[int, Result[_T]] = await execution_strategy(
            params, active_task_indexes, request_limits, ctx,
        )

        # Check task results

        # The previous remaining value is stale, only use what this round
        # reports.
        request_limits.remaining = None

        for task_index in active_task_indexes:
            if task_index not in results:
                report.unknown_error_count += 1
                continue

            task_result = results[task_index]

            if isinstance(task_result.data, Error):
                if task_result.data.retry:
                    # task can be retried
                    retry_task_indexes.add(task_index)

                else:
                    # error with message
                    report.known_error_messages.append(
                        task_result.data.message
                    )

            task_result.finalize()

            if task_result.request_limits is not None:
                request_limits.update(task_result.request_limits)

        retry_count += 1
        active_task_indexes = retry_task_indexes

    # Note how many retries were left.
    #
    # Use the active set and not the per-round retry set, since the latter
    # does not exist if no round was run.
    report.remaining_retries += len(active_task_indexes)

    # If there is a guess rate, decrease it. If it is reused in the future,
    # this allows the guess to gradually decrease over time, approaching the
    # actual rate limit
    if request_limits.guess_delay is not None:
        request_limits.guess_delay = max(
            0.95 * request_limits.guess_delay,
            ctx.guess_delay_min,
        )

    # Hand the limits back to the caller so they can be reused.
    report.updated_limits = request_limits

    return report


def _choose_execution_strategy(
    params: list[Params[_T]],
    indexes: Collection[int],
    limits: Limits,
) -> Callable[
    [list[Params[_T]], Iterable[int], Limits, Context],
    Awaitable[dict[int, Result[_T]]],
]:
    """Pick how to run the indexed params given the known limits."""
    # Choose a strategy based on the rate limit information available.
    #
    # Note: Regardless of the strategy used, it is always possible to fail
    # a request from rate limits as the provider may be accessed by multiple
    # users.

    cost = sum(
        params[index].cost()
        for index in indexes
    )

    if (
        limits.remaining is not None
        and cost <= limits.remaining
    ) or (
        limits.total is True
    ):
        return _execute_all

    elif limits.total is not None:
        return _execute_known_limit

    else:
        return _execute_guess_limit


async def _execute_all(
    params: list[Params[_T]],
    indexes: Iterable[int],
    limits: Limits,
    ctx: Context,
) -> dict[int, Result[_T]]:
    """Run the indexed params all at once.

    Returns the results by index. Tasks which produce no result are omitted.
    """
    # Send all requests at once.
    # We are confident that all requests can be handled right away.

    tasks: dict[int, Task[_T]] = {}

    for task_index in indexes:
        tasks[task_index] = params[task_index].create_task()

    results: dict[int, Result[_T]] = {}

    for task_index, task in tasks.items():
        await task.wait_result()

        task_result = task.get_result()
        if task_result is not None:
            results[task_index] = task_result

    return results


async def _execute_known_limit(
    params: list[Params[_T]],
    indexes: Iterable[int],
    limits: Limits,
    ctx: Context,
) -> dict[int, Result[_T]]:
    """Run the indexed params one at a time, paced by limits.total.

    Returns the results by index. Tasks which produce no result are omitted.
    """
    # Send requests one at a time at a rate corresponding to the limit.

    assert limits.total is not None

    if limits.total <= 0:
        # A total of zero allows no rate to be derived. Wait as long as is
        # allowed rather than dividing by zero.
        base_delay = ctx.delay_max
    else:
        base_delay = 60.0 / limits.total * 1.1  # 10% buffer just in case

    results, _ = await _execute_with_limit(
        params,
        indexes,
        base_delay,
        ctx,
    )

    return results


async def _execute_guess_limit(
    params: list[Params[_T]],
    indexes: Iterable[int],
    limits: Limits,
    ctx: Context,
) -> dict[int, Result[_T]]:
    """Run the indexed params one at a time, guessing at the delay.

    The delay in use at the end is stored in limits.guess_delay.

    Returns the results by index. Tasks which produce no result are omitted.
    """
    # Otherwise, send requests one at a time, but try to guess the rate
    # limit by reducing it if a request fails due to too many requests.

    results, base_delay = await _execute_with_limit(
        params,
        indexes,
        (
            ctx.guess_delay_min
            if limits.guess_delay is None else
            limits.guess_delay
        ),
        ctx,
    )

    limits.guess_delay = base_delay

    return results


async def _execute_with_limit(
    params: list[Params[_T]],
    indexes: Iterable[int],
    base_delay: float,
    ctx: Context,
) -> Tuple[dict[int, Result[_T]], float]:
    """Run the indexed params one at a time with a delay before each.

    The delay before a task is base_delay times the task's cost, capped at
    ctx.delay_max. Whenever a task fails with a retryable error, base_delay
    grows for the tasks which follow.

    Returns the results by index and the final base_delay. Tasks which
    produce no result are omitted.
    """

    results: dict[int, Result[_T]] = {}

    for task_index in indexes:
        await asyncio.sleep(min(
            base_delay * params[task_index].cost(),
            ctx.delay_max,
        ))

        task = params[task_index].create_task()
        await task.wait_result()

        task_result = task.get_result()
        if task_result is None:
            # No valid result was produced
            continue

        results[task_index] = task_result

        # If the task failed but allows retry, increase the delay before
        # the next attempt.
        if (
            isinstance(task_result.data, Error)
            and task_result.data.retry
        ):
            # Double the delay, or with jitter, grow it by a random factor
            # of at least 1 and less than 2.
            backoff_factor = (1 + random.random()) if ctx.jitter else 2
            base_delay = min(
                base_delay * backoff_factor,
                ctx.delay_max,
            )

    return results, base_delay
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional

import asyncio
import unittest

from edb.server.protocol import request_queue as rq


# async_solipsism is optional. TestRequests is skipped if it is missing.
try:
    import async_solipsism
except ImportError:
    async_solipsism = None  # type: ignore


# Simulate tasks which return a TestData containing an int.
#
# TestResult keeps track of when it was "created" (returned by a task) and
# can log its finalization into a shared dict.
#
# TestParams and TestTask are defined so that the params already have the
# desired TestResults, and TestTask fetches and returns it.
#
# When testing retry behaviour, TestTask will return each provided TestResult
# in sequence.

@dataclass(frozen=True)
class TestData:
    value: int


class TestResult(rq.Result[TestData]):

    # The time this result was "produced"
    time: float = -1

    # when finalized, log the data value and time
    finalize_target: Optional[dict[int, float]] = None

    def __init__(
        self,
        finalize_target: Optional[dict[int, float]] = None,
        **kwargs
    ):
        # The remaining keyword arguments are passed on to rq.Result.
        super().__init__(**kwargs)

        self.finalize_target = finalize_target

    def finalize(self) -> None:
        # Only successful results are logged, keyed by their data value.
        if self.finalize_target is not None and isinstance(self.data, TestData):
            self.finalize_target[self.data.value] = self.time


@dataclass
class TestParams(rq.Params[TestData]):

    # Cost multiplier used to factor the rate delay
    _cost: int

    # The desired results
    _results: list[TestResult] = field(default_factory=list)

    # The index of the current retry
    _try_index: int = -1

    def cost(self) -> int:
        return self._cost

    def create_task(self) -> TestTask:
        # Each new task returns the next of the desired results.
        self._try_index += 1
        return TestTask(self)


class TestTask(rq.Task[TestData]):

    def __init__(self, params: TestParams):
        super().__init__(params=params)

    async def run(self) -> Optional[TestResult]:
        assert isinstance(self.params, TestParams)
        if self.params._try_index < len(self.params._results):
            result = self.params._results[self.params._try_index]
            # Stamp the result with the event loop's current time.
            result.time = asyncio.get_running_loop().time()

            return result

        else:
            # No more results are available, simulate an unknown failure.
            return None


def with_fake_event_loop(f):
    # async_solipsism creates an event loop with, among other things,
    # a totally fake clock which starts at 0.
    #
    # The decorated coroutine function becomes a plain function which runs
    # the coroutine to completion on a fresh loop.
    def new(*args, **kwargs):
        loop = async_solipsism.EventLoop()
        try:
            loop.run_until_complete(f(*args, **kwargs))
        finally:
            loop.close()

    return new


@unittest.skipIf(async_solipsism is None, 'async_solipsism is missing')
class TestRequests(unittest.TestCase):

    def test_limits_update_01(self):
        # Check total takes the "latest" value
        self.assertEqual(
            rq.Limits(total=None).update(rq.Limits(total=None)),
            rq.Limits(total=None),
        )

        self.assertEqual(
            rq.Limits(total=None).update(rq.Limits(total=10)),
            rq.Limits(total=10),
        )

        self.assertEqual(
            rq.Limits(total=None).update(rq.Limits(total=True)),
            rq.Limits(total=True),
        )

        self.assertEqual(
            rq.Limits(total=10).update(rq.Limits(total=None)),
            rq.Limits(total=10),
        )

        self.assertEqual(
            rq.Limits(total=10).update(rq.Limits(total=20)),
            rq.Limits(total=20),
        )

        self.assertEqual(
            rq.Limits(total=10).update(rq.Limits(total=True)),
            rq.Limits(total=True),
        )

        self.assertEqual(
            rq.Limits(total=True).update(rq.Limits(total=None)),
            rq.Limits(total=True),
        )

        self.assertEqual(
            rq.Limits(total=True).update(rq.Limits(total=True)),
            rq.Limits(total=True),
        )

        self.assertEqual(
            rq.Limits(total=True).update(rq.Limits(total=10)),
            rq.Limits(total=10),
        )

        # Check remaining takes the smallest available value
        self.assertEqual(
            rq.Limits(remaining=None).update(rq.Limits(remaining=None)),
            rq.Limits(remaining=None),
        )

        self.assertEqual(
            rq.Limits(remaining=None).update(rq.Limits(remaining=10)),
            rq.Limits(remaining=10),
        )

        self.assertEqual(
            rq.Limits(remaining=10).update(rq.Limits(remaining=None)),
            rq.Limits(remaining=10),
        )

        self.assertEqual(
            rq.Limits(remaining=10).update(rq.Limits(remaining=20)),
            rq.Limits(remaining=10),
        )

        self.assertEqual(
            rq.Limits(remaining=20).update(rq.Limits(remaining=10)),
            rq.Limits(remaining=10),
        )

        # Check that a remaining value resets an unlimited total
        self.assertEqual(
            rq.Limits(total=True, remaining=10).update(rq.Limits()),
            rq.Limits(remaining=10),
        )

        self.assertEqual(
            rq.Limits(total=True).update(rq.Limits(remaining=10)),
            rq.Limits(remaining=10),
        )

        self.assertEqual(
            rq.Limits(total=True, remaining=20).update(rq.Limits(remaining=10)),
            rq.Limits(remaining=10),
        )

    @with_fake_event_loop
    async def test_execute_requests_01(self):
        # All tasks return a valid result
        finalize_target: dict[int, float] = {}

        report = await rq.execute_requests(
            params=[
                TestParams(
                    _cost=1,
                    _results=[
                        TestResult(
                            data=TestData(1),
                            finalize_target=finalize_target,
                        )
                    ]
                ),
                TestParams(
                    _cost=2,
                    _results=[
                        TestResult(
                            data=TestData(2),
                            finalize_target=finalize_target,
                        )
                    ]
                ),
                TestParams(
                    _cost=3,
                    _results=[
                        TestResult(
                            data=TestData(3),
                            finalize_target=finalize_target,
                        )
                    ]
                ),
                TestParams(
                    _cost=4,
                    _results=[
                        TestResult(
                            data=TestData(4),
                            finalize_target=finalize_target,
                        )
                    ]
                ),
            ],
            ctx=rq.Context(jitter=False, request_limits=rq.Limits(total=True)),
        )

        # With no limit, everything is sent at once, at time 0.
        self.assertEqual(
            {1: 0, 2: 0, 3: 0, 4: 0},
            finalize_target
        )

        self.assertEqual(0, report.unknown_error_count)
        self.assertEqual([], report.known_error_messages)
        self.assertEqual(0, report.remaining_retries)

    @with_fake_event_loop
    async def test_execute_requests_02(self):
        # A mix of successes and failures
        finalize_target: dict[int, float] = {}

        report = await rq.execute_requests(
            params=[
                TestParams(
                    _cost=1,
                    _results=[
                        TestResult(
                            data=TestData(1),
                            finalize_target=finalize_target,
                        )
                    ]
                ),
                TestParams(
                    _cost=2,
                    _results=[
                        # successful retry
                        TestResult(data=rq.Error('B', True)),
                        TestResult(
                            data=TestData(2),
                            finalize_target=finalize_target,
                        ),
                    ]
                ),
                # no results at all, counted as an unknown error
                TestParams(
                    _cost=3
                ),
                TestParams(
                    _cost=4, _results=[TestResult(data=rq.Error('D', False))]
                ),
                TestParams(
                    _cost=2,
                    _results=[
                        # unsuccessful retry
                        TestResult(data=rq.Error('E', True)),
                        TestResult(data=rq.Error('E', True)),
                        TestResult(data=rq.Error('E', True)),
                        TestResult(data=rq.Error('E', True)),
                        TestResult(data=rq.Error('E', True)),
                    ]
                ),
            ],
            ctx=rq.Context(jitter=False, request_limits=rq.Limits(total=True)),
        )

        self.assertEqual(
            {1: 0, 2: 0},
            finalize_target
        )

        self.assertEqual(1, report.unknown_error_count)
        self.assertEqual(['D'], report.known_error_messages)
        # Only the task which kept failing is still waiting for a retry.
        self.assertEqual(1, report.remaining_retries)

    def test_choose_execution_strategy_01(self):
        # The total cost of these params is 10.
        params = [
            TestParams(_cost=1),
            TestParams(_cost=2),
            TestParams(_cost=3),
            TestParams(_cost=4),
        ]

        indexes = [0, 1, 2, 3]

        # If there is no rate limit, use _execute_all.
        self.assertEqual(
            rq._execute_all,
            rq._choose_execution_strategy(
                params, indexes, rq.Limits(total=True),
            ),
        )

        # If there are enough remaining requests to cover the task indexes,
        # use _execute_all.
        self.assertEqual(
            rq._execute_all,
            rq._choose_execution_strategy(
                params, indexes, rq.Limits(remaining=20),
            ),
        )

        self.assertEqual(
            rq._execute_all,
            rq._choose_execution_strategy(
                params, indexes, rq.Limits(remaining=10),
            ),
        )

        # If there are not enough remaining requests, and the rate limit is
        # known, use _execute_known_limit.
        self.assertEqual(
            rq._execute_known_limit,
            rq._choose_execution_strategy(
                params, indexes, rq.Limits(remaining=8, total=8),
            ),
        )

        self.assertEqual(
            rq._execute_known_limit,
            rq._choose_execution_strategy(
                params, indexes, rq.Limits(total=10),
            ),
        )

        # Otherwise, use _execute_guess_limit
        self.assertEqual(
            rq._execute_guess_limit,
            rq._choose_execution_strategy(
                params, indexes, rq.Limits(remaining=8, guess_delay=10),
            ),
        )

        self.assertEqual(
            rq._execute_guess_limit,
            rq._choose_execution_strategy(
                params, indexes, rq.Limits(guess_delay=10),
            ),
        )

        self.assertEqual(
            rq._execute_guess_limit,
            rq._choose_execution_strategy(
                params, indexes, rq.Limits(),
            ),
        )

    @with_fake_event_loop
    async def test_execute_all_01(self):
        # All tasks return a valid result
        results = await rq._execute_all(
            params=[
                TestParams(_cost=1, _results=[TestResult(data=TestData(1))]),
                TestParams(_cost=2, _results=[TestResult(data=TestData(2))]),
                TestParams(_cost=3, _results=[TestResult(data=TestData(3))]),
                TestParams(_cost=4, _results=[TestResult(data=TestData(4))]),
            ],
            indexes=[0, 1, 2, 3],
            limits=rq.Limits(total=True),
            ctx=rq.Context(jitter=False),
        )

        self.assertEqual(
            {
                0: TestResult(data=TestData(1)),
                1: TestResult(data=TestData(2)),
                2: TestResult(data=TestData(3)),
                3: TestResult(data=TestData(4)),
            },
            results,
        )

        # Everything is sent at once, at time 0.
        times = sorted(r.time for r in results.values())
        self.assertEqual(times, [0, 0, 0, 0])

    @with_fake_event_loop
    async def test_execute_all_02(self):
        # A mix of successes and failures
        results = await rq._execute_all(
            params=[
                TestParams(
                    _cost=1, _results=[TestResult(data=TestData(1))]
                ),
                TestParams(
                    _cost=2, _results=[TestResult(data=rq.Error('B', True))]
                ),
                TestParams(
                    _cost=3
                ),
                TestParams(
                    _cost=4, _results=[TestResult(data=rq.Error('D', False))]
                ),
            ],
            indexes=[0, 1, 2, 3],
            limits=rq.Limits(total=True),
            ctx=rq.Context(jitter=False),
        )

        # Index 2 produced no result, so it is missing.
        self.assertEqual(
            {
                0: TestResult(data=TestData(1)),
                1: TestResult(data=rq.Error('B', True)),
                3: TestResult(data=rq.Error('D', False)),
            },
            results,
        )

        times = sorted(r.time for r in results.values())
        self.assertEqual(times, [0, 0, 0])

    @with_fake_event_loop
    async def test_execute_all_03(self):
        # Run only some tasks
        results = await rq._execute_all(
            params=[
                TestParams(_cost=1, _results=[TestResult(data=TestData(1))]),
                TestParams(_cost=2, _results=[TestResult(data=TestData(2))]),
                TestParams(_cost=3, _results=[TestResult(data=TestData(3))]),
                TestParams(_cost=4, _results=[TestResult(data=TestData(4))]),
            ],
            indexes=[1, 3],
            limits=rq.Limits(total=True),
            ctx=rq.Context(jitter=False),
        )

        self.assertEqual(
            {
                1: TestResult(data=TestData(2)),
                3: TestResult(data=TestData(4)),
            },
            results,
        )

        times = sorted(r.time for r in results.values())
        self.assertEqual(times, [0, 0])

    @with_fake_event_loop
    async def test_execute_known_limit_01(self):
        # All tasks return a valid result
        results = await rq._execute_known_limit(
            params=[
                TestParams(_cost=1, _results=[TestResult(data=TestData(1))]),
                TestParams(_cost=2, _results=[TestResult(data=TestData(2))]),
                TestParams(_cost=3, _results=[TestResult(data=TestData(3))]),
                TestParams(_cost=4, _results=[TestResult(data=TestData(4))]),
            ],
            indexes=[0, 1, 2, 3],
            limits=rq.Limits(total=6),
            ctx=rq.Context(jitter=False),
        )

        self.assertEqual(
            {
                0: TestResult(data=TestData(1)),
                1: TestResult(data=TestData(2)),
                2: TestResult(data=TestData(3)),
                3: TestResult(data=TestData(4)),
            },
            results,
        )

        # The ideal delay is 60s / 6 = 10s
        # With a 1.1 factor, the base delay is 11s
        #
        # The cumulative cost factor is [1, 3, 6, 10].
        times = sorted(r.time for r in results.values())
        self.assertEqual(times, [11, 33, 66, 110])

    @with_fake_event_loop
    async def test_execute_known_limit_02(self):
        # A mix of successes and failures
        results = await rq._execute_known_limit(
            params=[
                TestParams(
                    _cost=1, _results=[TestResult(data=TestData(1))]
                ),
                TestParams(
                    _cost=2, _results=[TestResult(data=rq.Error('B', True))]
                ),
                TestParams(
                    _cost=3
                ),
                TestParams(
                    _cost=4, _results=[TestResult(data=rq.Error('D', False))]
                ),
            ],
            indexes=[0, 1, 2, 3],
            limits=rq.Limits(total=6),
            ctx=rq.Context(jitter=False),
        )

        self.assertEqual(
            {
                0: TestResult(data=TestData(1)),
                1: TestResult(data=rq.Error('B', True)),
                3: TestResult(data=rq.Error('D', False)),
            },
            results,
        )

        # The ideal delay is 60s / 6 = 10s
        # With a 1.1 factor, the base delay is 11s
        #
        # The cumulative cost factor is [1, 3, 9, 17].
        #
        # The cost increment increases after index 1 because of a retry,
        # which causes the base delay to increase.
        #
        # The final value is lower than the expected 11*17=187 because the
        # delay is always capped at the delay max.
        #
        # Index 2 has no result, but its delay still passes before index 3.
        times = sorted(r.time for r in results.values())
        self.assertEqual(times, [11, 33, 153])

    @with_fake_event_loop
    async def test_execute_known_limit_03(self):
        # Run only some tasks
        results = await rq._execute_known_limit(
            params=[
                TestParams(_cost=1, _results=[TestResult(data=TestData(1))]),
                TestParams(_cost=2, _results=[TestResult(data=TestData(2))]),
                TestParams(_cost=3, _results=[TestResult(data=TestData(3))]),
                TestParams(_cost=4, _results=[TestResult(data=TestData(4))]),
            ],
            indexes=[1, 3],
            limits=rq.Limits(total=6),
            ctx=rq.Context(jitter=False),
        )

        self.assertEqual(
            {
                1: TestResult(data=TestData(2)),
                3: TestResult(data=TestData(4)),
            },
            results,
        )

        # The ideal delay is 60s / 6 = 10s
        # With a 1.1 factor, the base delay is 11s
        #
        # The cumulative cost factor is [2, 6].
        #
        # Skipped indexes don't cause a delay
        times = sorted(r.time for r in results.values())
        self.assertEqual(times, [22, 66])

    @with_fake_event_loop
    async def test_execute_guess_limit_01(self):
        # All tasks return a valid result
        results = await rq._execute_guess_limit(
            params=[
                TestParams(_cost=1, _results=[TestResult(data=TestData(1))]),
                TestParams(_cost=2, _results=[TestResult(data=TestData(2))]),
                TestParams(_cost=3, _results=[TestResult(data=TestData(3))]),
                TestParams(_cost=4, _results=[TestResult(data=TestData(4))]),
            ],
            indexes=[0, 1, 2, 3],
            limits=rq.Limits(guess_delay=10),
            ctx=rq.Context(jitter=False),
        )

        self.assertEqual(
            {
                0: TestResult(data=TestData(1)),
                1: TestResult(data=TestData(2)),
                2: TestResult(data=TestData(3)),
                3: TestResult(data=TestData(4)),
            },
            results,
        )

        # The cumulative cost factor is [1, 3, 6, 10].
        times = sorted(r.time for r in results.values())
        self.assertEqual(times, [10, 30, 60, 100])

    @with_fake_event_loop
    async def test_execute_guess_limit_02(self):
        # A mix of successes and failures
        results = await rq._execute_guess_limit(
            params=[
                TestParams(
                    _cost=1, _results=[TestResult(data=TestData(1))]
                ),
                TestParams(
                    _cost=2, _results=[TestResult(data=rq.Error('B', True))]
                ),
                TestParams(
                    _cost=3
                ),
                TestParams(
                    _cost=4, _results=[TestResult(data=rq.Error('D', False))]
                ),
            ],
            indexes=[0, 1, 2, 3],
            limits=rq.Limits(guess_delay=10),
            ctx=rq.Context(jitter=False),
        )

        self.assertEqual(
            {
                0: TestResult(data=TestData(1)),
                1: TestResult(data=rq.Error('B', True)),
                3: TestResult(data=rq.Error('D', False)),
            },
            results,
        )

        # The cumulative cost factor is [1, 3, 9, 17].
        #
        # The cost increment increases after index 1 because of a retry,
        # which causes the base delay to increase.
        #
        # The final value is lower than the expected 10*17=170 because the
        # delay is always capped at the delay max.
        #
        # Index 2 has no result, but its delay still passes before index 3.
        times = sorted(r.time for r in results.values())
        self.assertEqual(times, [10, 30, 150])

    @with_fake_event_loop
    async def test_execute_guess_limit_03(self):
        # Run only some tasks
        results = await rq._execute_guess_limit(
            params=[
                TestParams(_cost=1, _results=[TestResult(data=TestData(1))]),
                TestParams(_cost=2, _results=[TestResult(data=TestData(2))]),
                TestParams(_cost=3, _results=[TestResult(data=TestData(3))]),
                TestParams(_cost=4, _results=[TestResult(data=TestData(4))]),
            ],
            indexes=[1, 3],
            limits=rq.Limits(guess_delay=10),
            ctx=rq.Context(jitter=False),
        )

        self.assertEqual(
            {
                1: TestResult(data=TestData(2)),
                3: TestResult(data=TestData(4)),
            },
            results,
        )

        # The cumulative cost factor is [2, 6].
        #
        # Skipped indexes don't cause a delay
        times = sorted(r.time for r in results.values())
        self.assertEqual(times, [20, 60])

    @with_fake_event_loop
    async def test_execute_guess_limit_04(self):
        # Use the minimum guess delay if no other guess was provided
        results = await rq._execute_guess_limit(
            params=[
                TestParams(_cost=1, _results=[TestResult(data=TestData(1))]),
                TestParams(_cost=2, _results=[TestResult(data=TestData(2))]),
                TestParams(_cost=3, _results=[TestResult(data=TestData(3))]),
                TestParams(_cost=4, _results=[TestResult(data=TestData(4))]),
            ],
            indexes=[0, 1, 2, 3],
            limits=rq.Limits(),
            ctx=rq.Context(jitter=False, guess_delay_min=10),
        )

        self.assertEqual(
            {
                0: TestResult(data=TestData(1)),
                1: TestResult(data=TestData(2)),
                2: TestResult(data=TestData(3)),
                3: TestResult(data=TestData(4)),
            },
            results,
        )

        # The cumulative cost factor is [1, 3, 6, 10].
        times = sorted(r.time for r in results.values())
        self.assertEqual(times, [10, 30, 60, 100])