Skip to content

Commit

Permalink
fix: prevent pending tasks from racing teardown
Browse files Browse the repository at this point in the history
  • Loading branch information
gmega committed Jan 27, 2025
1 parent 90dda4f commit 63b4c51
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 31 deletions.
2 changes: 2 additions & 0 deletions benchmarks/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def cmd_run_experiment(experiments: Dict[str, ExperimentBuilder[Experiment]], ar
logger.info(DECLogEntry.adapt_instance(experiment))
experiment.build().run()

print(f"Experiment {args.experiment} completed successfully.")


def cmd_describe_experiment(args):
if not args.type:
Expand Down
26 changes: 23 additions & 3 deletions benchmarks/core/concurrency.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from concurrent import futures
from concurrent.futures.thread import ThreadPoolExecutor
from queue import Queue
from typing import Iterable, Iterator, List
from typing import Iterable, Iterator, List, cast

from typing_extensions import TypeVar

Expand Down Expand Up @@ -50,8 +51,27 @@ def _consume(task: Iterable[T]) -> None:
yield item

# This will cause any exceptions thrown in tasks to be re-raised.
for future in task_futures:
future.result()
ensure_successful(task_futures)

finally:
executor.shutdown(wait=True)


def ensure_successful(futs: Iterable[futures.Future[T]]) -> List[T]:
future_list = list(futs)
futures.wait(future_list, return_when=futures.ALL_COMPLETED)

# We treat cancelled futures as if they were successful.
exceptions = [
fut.exception()
for fut in future_list
if not fut.cancelled() and fut.exception() is not None
]

if exceptions:
raise ExceptionGroup(
"One or more computations failed to complete successfully",
cast(List[Exception], exceptions),
)

return [cast(T, fut.result()) for fut in future_list]
52 changes: 37 additions & 15 deletions benchmarks/core/experiments/static_experiment.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
from multiprocessing.pool import ThreadPool
from concurrent.futures.thread import ThreadPoolExecutor

from time import sleep
from typing import Sequence, Optional

from typing_extensions import Generic, List, Tuple

from benchmarks.core.concurrency import ensure_successful
from benchmarks.core.experiments.experiments import ExperimentWithLifecycle
from benchmarks.core.network import (
TInitialMetadata,
Expand Down Expand Up @@ -36,8 +38,8 @@ def __init__(
self.file_size = file_size
self.seed = seed

self._pool = ThreadPool(
processes=len(network) - len(seeders)
self._executor = ThreadPoolExecutor(
max_workers=len(network) - len(seeders)
if concurrency is None
else concurrency
)
Expand Down Expand Up @@ -71,20 +73,34 @@ def _leech(leecher):
_log_request(leecher, "leech", str(self.meta), RequestEventType.end)
return download

downloads = list(self._pool.imap_unordered(_leech, leechers))

logger.info("Now waiting for downloads to complete")

def _await_for_download(element: Tuple[int, DownloadHandle]) -> int:
downloads = ensure_successful(
[self._executor.submit(_leech, leecher) for leecher in leechers]
)

def _await_for_download(
element: Tuple[int, DownloadHandle],
) -> Tuple[int, DownloadHandle]:
index, download = element
if not download.await_for_completion():
raise Exception(
f"Download ({index}, {str(download)}) did not complete in time."
)
return index

for i in self._pool.imap_unordered(_await_for_download, enumerate(downloads)):
logger.info("Download %d / %d completed", i + 1, len(downloads))
logger.info(
"Download %d / %d completed (node: %s)",
index + 1,
len(downloads),
download.node.name,
)
return element

ensure_successful(
[
self._executor.submit(_await_for_download, (i, download))
for i, download in enumerate(downloads)
]
)

# FIXME this is a hack to ensure that nodes get a chance to log their data before we
# run the teardown hook and remove the torrents.
Expand All @@ -96,15 +112,21 @@ def _remove(element: Tuple[int, Node[TNetworkHandle, TInitialMetadata]]):
index, node = element
assert self._cid is not None # to please mypy
node.remove(self._cid)
return index
logger.info("Node %d (%s) removed file", index + 1, node.name)
return element

try:
for i in self._pool.imap_unordered(_remove, enumerate(self.nodes)):
logger.info("Node %d removed file", i + 1)
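# Since teardown might be called as the result of an exception, it's expected
# that not all removes will succeed, so we don't check their result.
# NOTE(review): ensure_successful waits for every remove and then raises an
# ExceptionGroup if any of them failed, so failures are NOT ignored here --
# confirm which behavior is intended.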
ensure_successful(
[
self._executor.submit(_remove, (i, node))
for i, node in enumerate(self.nodes)
]
)
finally:
logger.info("Shut down thread pool.")
self._pool.close()
self._pool.join()
self._executor.shutdown(wait=True)
logger.info("Done.")

def _split_nodes(
Expand Down
77 changes: 69 additions & 8 deletions benchmarks/core/experiments/tests/test_static_experiment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from dataclasses import dataclass
from io import StringIO
from typing import Optional, List
Expand All @@ -19,26 +20,41 @@ def __str__(self):


class MockNode(Node[MockGenData, str]):
def __init__(self, name="mock_node") -> None:
def __init__(
self,
name="mock_node",
download_lag: float = 0,
should_fail_download: bool = False,
) -> None:
self._name = name
self.seeding: Optional[MockGenData] = None
self.leeching: Optional[MockGenData] = None
self.download_was_awaited = False

self.cleanup_was_called = False
self.download_lag = download_lag
self.download_completed = False
self.download_failed = False

self.should_fail_download = should_fail_download

@property
def name(self) -> str:
return self._name

def genseed(self, size: int, seed: int, meta: str) -> MockGenData:
self.seeding = MockGenData(size=size, seed=seed, name=meta)
self.download_completed = True
return self.seeding

def leech(self, handle: MockGenData):
self.leeching = handle
return MockDownloadHandle(self)
return MockDownloadHandle(self, self.download_lag, self.should_fail_download)

def remove(self, handle: MockGenData):
assert (
self.download_completed or self.download_failed
), "Removing download before completion"

if self.leeching is not None:
assert self.leeching == handle
elif self.seeding is not None:
Expand All @@ -49,19 +65,40 @@ def remove(self, handle: MockGenData):
)

self.remove_was_called = True
return True


class MockDownloadHandle(DownloadHandle):
def __init__(self, parent: MockNode) -> None:
def __init__(
self, parent: MockNode, lag: float = 0, should_fail: bool = False
) -> None:
self.parent = parent
self.lag = lag
self.should_fail = should_fail

@property
def node(self):
return self.parent

def await_for_completion(self, timeout: float = 0) -> bool:
self.parent.download_was_awaited = True
if self.should_fail:
self.parent.download_failed = True
raise Exception("Oooops, I failed!")
time.sleep(self.lag)
self.parent.download_completed = True
return True


def mock_network(n: int) -> List[MockNode]:
return [MockNode(f"node-{i}") for i in range(n)]
def mock_network(
    n: int, fail: Optional[List[int]] = None, download_lag: float = 0.0
) -> List[MockNode]:
    """Build ``n`` mock nodes named ``node-0`` through ``node-<n-1>``.

    Nodes whose index appears in ``fail`` are created with
    ``should_fail_download=True``; every node gets the same ``download_lag``.
    """
    failing = fail or []
    nodes: List[MockNode] = []
    for index in range(n):
        node = MockNode(
            f"node-{index}",
            should_fail_download=index in failing,
            download_lag=download_lag,
        )
        nodes.append(node)
    return nodes


def test_should_generate_correct_data_and_seed():
Expand Down Expand Up @@ -104,7 +141,7 @@ def test_should_download_at_remaining_nodes():
if node.leeching is not None:
assert node.leeching == gendata
assert node.seeding is None
assert node.download_was_awaited
assert node.download_completed
actual_leechers.add(index)

assert actual_leechers == set(range(13)) - set(seeders)
Expand Down Expand Up @@ -199,3 +236,27 @@ def test_should_delete_file_from_nodes_at_the_end_of_the_experiment():

assert network[0].remove_was_called
assert network[1].remove_was_called


def test_should_not_have_pending_download_operations_running_at_teardown():
network = mock_network(n=3, fail=[1], download_lag=1)
seeders = [0]

experiment = StaticDisseminationExperiment(
seeders=seeders,
network=network,
meta="dataset-1",
file_size=1000,
seed=12,
)

try:
experiment.run()
except* Exception as e:
assert len(e.exceptions) == 1
assert str(e.exceptions[0]) == "Oooops, I failed!"

# Downloads should have been marked as completed even
# though we had one exception.
assert network[0].download_completed
assert network[2].download_completed
6 changes: 6 additions & 0 deletions benchmarks/core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
class DownloadHandle(ABC):
"""A :class:`DownloadHandle` is a reference to an ongoing download operation."""

@property
@abstractmethod
def node(self) -> "Node":
"""The node that initiated the download."""
pass

@abstractmethod
def await_for_completion(self, timeout: float = 0) -> bool:
"""Blocks the current thread until either the download completes or a timeout expires.
Expand Down
43 changes: 41 additions & 2 deletions benchmarks/core/tests/test_concurrency.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from concurrent.futures.thread import ThreadPoolExecutor
from threading import Semaphore
from typing import Iterable

import pytest

from benchmarks.core.concurrency import pflatmap
from benchmarks.core.concurrency import pflatmap, ensure_successful


@pytest.fixture
def executor():
    """Yield a three-worker thread pool, shut down once the test is done."""
    # Leaving the ``with`` block calls shutdown(wait=True) on the pool.
    with ThreadPoolExecutor(max_workers=3) as pool:
        yield pool


def test_should_run_iterators_in_separate_threads():
Expand Down Expand Up @@ -44,7 +56,34 @@ def faulty_task():
for val in it:
actual_vals.add(val)
assert False, "ValueError was not raised"
except ValueError:
except* ValueError:
pass

assert actual_vals == reference_vals


def test_should_return_results_when_no_failures_occur(executor):
    """Every result is returned when all submitted tasks succeed."""

    def reliable_task(i: int) -> int:
        return i

    expected = set(range(10))
    submitted = (executor.submit(reliable_task, value) for value in range(10))
    results = ensure_successful(submitted)

    assert set(results) == expected


def test_should_raise_exception_when_one_task_fails(executor):
def reliable_task(i: int) -> int:
return i

def faulty_task(i: int):
raise ValueError("I'm very faulty")

try:
ensure_successful(
executor.submit(reliable_task if i % 2 == 0 else faulty_task, i)
for i in range(10)
)
except* ValueError as e:
assert len(e.exceptions) == 5
for exception in e.exceptions:
assert str(exception) == "I'm very faulty"
6 changes: 5 additions & 1 deletion benchmarks/deluge/deluge_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,13 @@ def __getattr__(self, item):

class DelugeDownloadHandle(DownloadHandle):
def __init__(self, torrent: Torrent, node: DelugeNode) -> None:
self.node = node
self._node = node
self.torrent = torrent

@property
def node(self) -> DelugeNode:
return self._node

def await_for_completion(self, timeout: float = 0) -> bool:
name = self.torrent.name

Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,9 @@ types-pyyaml = "^6.0.12.20240917"
types-requests = "^2.32.0.20241016"
httpx = "^0.28.1"


[tool.poetry.group.dev.dependencies]
pre-commit = "^4.0.1"


[tool.poetry.group.agent.dependencies]
uvicorn = "^0.34.0"

Expand All @@ -42,6 +40,9 @@ markers = [
[tool.mypy]
ignore_missing_imports = true

[tool.ruff]
target-version = "py312"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Expand Down

0 comments on commit 63b4c51

Please sign in to comment.