Skip to content

Commit

Permalink
Ensure ConnectionPool is closed even if network stack swallows CancelledErrors (#8928)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Nov 8, 2024
1 parent c38c509 commit ff8b2f4
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
14 changes: 12 additions & 2 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,7 @@ def __init__(
)
self._pending_count = 0
self._connecting_count = 0
self._connecting_close_timeout = 5
self.status = Status.init

def _validate(self) -> None:
Expand Down Expand Up @@ -1537,7 +1538,9 @@ def callback(task: asyncio.Task[Comm]) -> None:
try:
return connect_attempt.result()
except asyncio.CancelledError:
raise CommClosedError(reason)
if reason:
raise CommClosedError(reason)
raise

def reuse(self, addr: str, comm: Comm) -> None:
"""
Expand Down Expand Up @@ -1615,8 +1618,15 @@ async def close(self) -> None:
for _ in comms:
self.semaphore.release()

start = time()
while self._connecting:
await asyncio.sleep(0.005)
if time() - start > self._connecting_close_timeout:
logger.warning(
"Pending connections refuse to cancel. %d connections pending. Closing anyway.",
len(self._connecting),
)
break
await asyncio.sleep(0.01)


def coerce_to_address(o):
Expand Down
1 change: 1 addition & 0 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,7 @@ async def close( # type:ignore[override]
self.status = Status.closed
await super().close()
self.__exit_stack.__exit__(None, None, None)
logger.info("Nanny at %r closed.", self.address_safe)
return "OK"

async def _log_event(self, topic, msg):
Expand Down
43 changes: 42 additions & 1 deletion distributed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import dask

from distributed.batched import BatchedSend
from distributed.comm.core import CommClosedError
from distributed.comm.core import CommClosedError, FatalCommClosedError
from distributed.comm.registry import backends
from distributed.comm.tcp import TCPBackend, TCPListener
from distributed.core import (
Expand Down Expand Up @@ -707,6 +707,47 @@ async def connect_to_server():
assert all(t.cancelled() for t in tasks)


@gen_test()
async def test_connection_pool_catch_all_cancellederrors(monkeypatch):
    """Check that ``ConnectionPool.close()`` returns despite a stubborn connect.

    The TCP backend is swapped for one whose connector reacts to a
    cancellation by sleeping before re-raising it. With the pool's
    ``_connecting_close_timeout`` forced to zero, the test asserts that
    ``close()`` completes and that the "refuse to cancel" warning is logged
    on the ``distributed.core`` logger.
    """
    from distributed.comm.registry import backends
    from distributed.comm.tcp import TCPBackend, TCPConnector

    connect_entered = asyncio.Event()
    # Never set anywhere in this test, so connect() stays parked on it.
    release_connect = asyncio.Event()

    class BlockedConnector(TCPConnector):
        async def connect(self, address, deserialize, **connection_args):
            # Deliberately artificial: pretend that some layer further down
            # the stack holds up a cancellation. The ConnectionPool must
            # still be able to close in that situation.
            connect_entered.set()
            try:
                await release_connect.wait()
            except asyncio.CancelledError:
                # Delay the propagation of the cancellation.
                await asyncio.sleep(30)
                raise
            else:
                raise FatalCommClosedError()

    class BlockedConnectBackend(TCPBackend):
        _connector_class = BlockedConnector

    monkeypatch.setitem(backends, "tcp", BlockedConnectBackend())

    async with Server({}) as server:
        await server.listen("tcp://")
        pool = await ConnectionPool(limit=2)
        # Do not wait for pending connection attempts when closing.
        pool._connecting_close_timeout = 0

        # Keep a reference to the connect attempt while the test runs.
        connect_task = asyncio.create_task(pool.connect(server.address))

        # Wait until the connector is running and the pool has registered
        # the attempt as in-flight.
        await connect_entered.wait()
        while not pool._connecting_count:
            await asyncio.sleep(0.1)

        with captured_logger("distributed.core") as log_stream:
            await pool.close()
        assert "Pending connections refuse to cancel" in log_stream.getvalue()


@gen_test()
async def test_remove_cancels_connect_attempts():
loop = asyncio.get_running_loop()
Expand Down

0 comments on commit ff8b2f4

Please sign in to comment.