diff --git a/benchmarks/cli.py b/benchmarks/cli.py index 8de0889..8b27435 100644 --- a/benchmarks/cli.py +++ b/benchmarks/cli.py @@ -60,6 +60,8 @@ def cmd_run_experiment(experiments: Dict[str, ExperimentBuilder[Experiment]], ar logger.info(DECLogEntry.adapt_instance(experiment)) experiment.build().run() + print(f"Experiment {args.experiment} completed successfully.") + def cmd_describe_experiment(args): if not args.type: diff --git a/benchmarks/core/concurrency.py b/benchmarks/core/concurrency.py index a7a5170..e1d53ce 100644 --- a/benchmarks/core/concurrency.py +++ b/benchmarks/core/concurrency.py @@ -1,6 +1,7 @@ +from concurrent import futures from concurrent.futures.thread import ThreadPoolExecutor from queue import Queue -from typing import Iterable, Iterator, List +from typing import Iterable, Iterator, List, cast from typing_extensions import TypeVar @@ -50,8 +51,27 @@ def _consume(task: Iterable[T]) -> None: yield item # This will cause any exceptions thrown in tasks to be re-raised. - for future in task_futures: - future.result() + ensure_successful(task_futures) finally: executor.shutdown(wait=True) + + +def ensure_successful(futs: Iterable[futures.Future[T]]) -> List[T]: + future_list = list(futs) + futures.wait(future_list, return_when=futures.ALL_COMPLETED) + + # We treat cancelled futures as if they were successful. + exceptions = [ + fut.exception() + for fut in future_list + if not fut.cancelled() and fut.exception() is not None + ] + + if exceptions: + raise ExceptionGroup( + "One or more computations failed to complete successfully", + cast(List[Exception], exceptions), + ) + + return [cast(T, fut.result()) for fut in future_list] diff --git a/benchmarks/core/experiments/static_experiment.py b/benchmarks/core/experiments/static_experiment.py index a7773f2..6cde1f9 100644 --- a/benchmarks/core/experiments/static_experiment.py +++ b/benchmarks/core/experiments/static_experiment.py @@ -1,10 +1,12 @@ import logging -from multiprocessing.pool import ThreadPool +from concurrent.futures.thread import ThreadPoolExecutor + from time import sleep from typing import Sequence, Optional from typing_extensions import Generic, List, Tuple +from benchmarks.core.concurrency import ensure_successful from benchmarks.core.experiments.experiments import ExperimentWithLifecycle from benchmarks.core.network import ( TInitialMetadata, @@ -36,8 +38,8 @@ def __init__( self.file_size = file_size self.seed = seed - self._pool = ThreadPool( - processes=len(network) - len(seeders) + self._executor = ThreadPoolExecutor( + max_workers=len(network) - len(seeders) if concurrency is None else concurrency ) @@ -71,20 +73,34 @@ def _leech(leecher): _log_request(leecher, "leech", str(self.meta), RequestEventType.end) return download - downloads = list(self._pool.imap_unordered(_leech, leechers)) - logger.info("Now waiting for downloads to complete") - def _await_for_download(element: Tuple[int, DownloadHandle]) -> int: + downloads = ensure_successful( + [self._executor.submit(_leech, leecher) for leecher in leechers] + ) + + def _await_for_download( + element: Tuple[int, DownloadHandle], + ) -> Tuple[int, DownloadHandle]: index, download = element if not download.await_for_completion(): raise Exception( f"Download ({index}, {str(download)}) did not complete in time." ) - return index - - for i in self._pool.imap_unordered(_await_for_download, enumerate(downloads)): - logger.info("Download %d / %d completed", i + 1, len(downloads)) + logger.info( + "Download %d / %d completed (node: %s)", + index + 1, + len(downloads), + download.node.name, + ) + return element + + ensure_successful( + [ + self._executor.submit(_await_for_download, (i, download)) + for i, download in enumerate(downloads) + ] + ) # FIXME this is a hack to ensure that nodes get a chance to log their data before we # run the teardown hook and remove the torrents. @@ -96,15 +112,21 @@ def _remove(element: Tuple[int, Node[TNetworkHandle, TInitialMetadata]]): index, node = element assert self._cid is not None # to please mypy node.remove(self._cid) - return index + logger.info("Node %d (%s) removed file", index + 1, node.name) + return element try: - for i in self._pool.imap_unordered(_remove, enumerate(self.nodes)): - logger.info("Node %d removed file", i + 1) + # Since teardown might be called as the result of an exception, it's expected + # that not all removes will succeed, so we don't check their result. + ensure_successful( + [ + self._executor.submit(_remove, (i, node)) + for i, node in enumerate(self.nodes) + ] + ) finally: logger.info("Shut down thread pool.") - self._pool.close() - self._pool.join() + self._executor.shutdown(wait=True) logger.info("Done.") def _split_nodes( diff --git a/benchmarks/core/experiments/tests/test_static_experiment.py b/benchmarks/core/experiments/tests/test_static_experiment.py index 16397b2..0945aea 100644 --- a/benchmarks/core/experiments/tests/test_static_experiment.py +++ b/benchmarks/core/experiments/tests/test_static_experiment.py @@ -1,3 +1,4 @@ +import time from dataclasses import dataclass from io import StringIO from typing import Optional, List @@ -19,12 +20,22 @@ def __str__(self): class MockNode(Node[MockGenData, str]): - def __init__(self, name="mock_node") -> None: + def __init__( + self, + name="mock_node", + download_lag: float = 0, + should_fail_download: bool = False, + ) -> None: self._name = name self.seeding: Optional[MockGenData] = None self.leeching: Optional[MockGenData] = None - self.download_was_awaited = False + self.cleanup_was_called = False + self.download_lag = download_lag + self.download_completed = False + self.download_failed = False + + self.should_fail_download = should_fail_download @property def name(self) -> str: @@ -32,13 +43,18 @@ def name(self) -> str: def genseed(self, size: int, seed: int, meta: str) -> MockGenData: self.seeding = MockGenData(size=size, seed=seed, name=meta) + self.download_completed = True return self.seeding def leech(self, handle: MockGenData): self.leeching = handle - return MockDownloadHandle(self) + return MockDownloadHandle(self, self.download_lag, self.should_fail_download) def remove(self, handle: MockGenData): + assert ( + self.download_completed or self.download_failed + ), "Removing download before completion" + if self.leeching is not None: assert self.leeching == handle elif self.seeding is not None: @@ -49,19 +65,40 @@ def remove(self, handle: MockGenData): ) self.remove_was_called = True + return True class MockDownloadHandle(DownloadHandle): - def __init__(self, parent: MockNode) -> None: + def __init__( + self, parent: MockNode, lag: float = 0, should_fail: bool = False + ) -> None: self.parent = parent + self.lag = lag + self.should_fail = should_fail + + @property + def node(self): + return self.parent def await_for_completion(self, timeout: float = 0) -> bool: - self.parent.download_was_awaited = True + if self.should_fail: + self.parent.download_failed = True + raise Exception("Oooops, I failed!") + time.sleep(self.lag) + self.parent.download_completed = True return True -def mock_network(n: int) -> List[MockNode]: - return [MockNode(f"node-{i}") for i in range(n)] +def mock_network( + n: int, fail: Optional[List[int]] = None, download_lag: float = 0.0 +) -> List[MockNode]: + fail_list = fail or [] + return [ + MockNode( + f"node-{i}", should_fail_download=i in fail_list, download_lag=download_lag + ) + for i in range(n) + ] def test_should_generate_correct_data_and_seed(): @@ -104,7 +141,7 @@ def test_should_download_at_remaining_nodes(): if node.leeching is not None: assert node.leeching == gendata assert node.seeding is None - assert node.download_was_awaited + assert node.download_completed actual_leechers.add(index) assert actual_leechers == set(range(13)) - set(seeders) @@ -199,3 +236,27 @@ def test_should_delete_file_from_nodes_at_the_end_of_the_experiment(): assert network[0].remove_was_called assert network[1].remove_was_called + + +def test_should_not_have_pending_download_operations_running_at_teardown(): + network = mock_network(n=3, fail=[1], download_lag=1) + seeders = [0] + + experiment = StaticDisseminationExperiment( + seeders=seeders, + network=network, + meta="dataset-1", + file_size=1000, + seed=12, + ) + + try: + experiment.run() + except* Exception as e: + assert len(e.exceptions) == 1 + assert str(e.exceptions[0]) == "Oooops, I failed!" + + # Downloads should have been marked as completed even + # though we had one exception. + assert network[0].download_completed + assert network[2].download_completed diff --git a/benchmarks/core/network.py b/benchmarks/core/network.py index 76c9fc9..b660979 100644 --- a/benchmarks/core/network.py +++ b/benchmarks/core/network.py @@ -9,6 +9,12 @@ class DownloadHandle(ABC): """A :class:`DownloadHandle` is a reference to an ongoing download operation.""" + @property + @abstractmethod + def node(self) -> "Node": + """The node that initiated the download.""" + pass + @abstractmethod def await_for_completion(self, timeout: float = 0) -> bool: """Blocks the current thread until either the download completes or a timeout expires. diff --git a/benchmarks/core/tests/test_concurrency.py b/benchmarks/core/tests/test_concurrency.py index e330133..0f6aa92 100644 --- a/benchmarks/core/tests/test_concurrency.py +++ b/benchmarks/core/tests/test_concurrency.py @@ -1,9 +1,21 @@ +from concurrent.futures.thread import ThreadPoolExecutor from threading import Semaphore from typing import Iterable import pytest -from benchmarks.core.concurrency import pflatmap +from benchmarks.core.concurrency import pflatmap, ensure_successful + + +@pytest.fixture +def executor(): + executor = None + try: + executor = ThreadPoolExecutor(max_workers=3) + yield executor + finally: + if executor is not None: + executor.shutdown(wait=True) def test_should_run_iterators_in_separate_threads(): @@ -44,7 +56,34 @@ def faulty_task(): for val in it: actual_vals.add(val) assert False, "ValueError was not raised" - except ValueError: + except* ValueError: pass assert actual_vals == reference_vals + + +def test_should_return_results_when_no_failures_occur(executor): + def reliable_task(i: int) -> int: + return i + + assert set( + ensure_successful(executor.submit(reliable_task, i) for i in range(10)) + ) == set(range(10)) + + +def test_should_raise_exception_when_one_task_fails(executor): + def reliable_task(i: int) -> int: + return i + + def faulty_task(i: int): + raise ValueError("I'm very faulty") + + try: + ensure_successful( + executor.submit(reliable_task if i % 2 == 0 else faulty_task, i) + for i in range(10) + ) + except* ValueError as e: + assert len(e.exceptions) == 5 + for exception in e.exceptions: + assert str(exception) == "I'm very faulty" diff --git a/benchmarks/deluge/deluge_node.py b/benchmarks/deluge/deluge_node.py index f6edfa6..a04a71c 100644 --- a/benchmarks/deluge/deluge_node.py +++ b/benchmarks/deluge/deluge_node.py @@ -207,9 +207,13 @@ def __getattr__(self, item): class DelugeDownloadHandle(DownloadHandle): def __init__(self, torrent: Torrent, node: DelugeNode) -> None: - self.node = node + self._node = node self.torrent = torrent + @property + def node(self) -> DelugeNode: + return self._node + def await_for_completion(self, timeout: float = 0) -> bool: name = self.torrent.name diff --git a/pyproject.toml b/pyproject.toml index d050445..2d241e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,11 +26,9 @@ types-pyyaml = "^6.0.12.20240917" types-requests = "^2.32.0.20241016" httpx = "^0.28.1" - [tool.poetry.group.dev.dependencies] pre-commit = "^4.0.1" - [tool.poetry.group.agent.dependencies] uvicorn = "^0.34.0" @@ -42,6 +40,9 @@ markers = [ [tool.mypy] ignore_missing_imports = true +[tool.ruff] +target-version = "py312" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api"