Skip to content

Commit

Permalink
Exclude comm handshake from connect timeout (#7698)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Grainger <[email protected]>
  • Loading branch information
fjetter and graingert authored Jul 28, 2023
1 parent 9eb6728 commit 2751741
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 81 deletions.
30 changes: 4 additions & 26 deletions distributed/comm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sys
import weakref
from abc import ABC, abstractmethod
from contextlib import suppress
from typing import Any, ClassVar

import dask
Expand Down Expand Up @@ -264,20 +263,8 @@ async def on_connection(
) -> None:
local_info = {**comm.handshake_info(), **(handshake_overrides or {})}

timeout = dask.config.get("distributed.comm.timeouts.connect")
timeout = parse_timedelta(timeout, default="seconds")
try:
# Timeout is to ensure that we'll terminate connections eventually.
# Connector side will employ smaller timeouts and we should only
# reach this if the comm is dead anyhow.
await wait_for(comm.write(local_info), timeout=timeout)
handshake = await wait_for(comm.read(), timeout=timeout)
# This would be better, but connections leak if worker is closed quickly
# write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
except Exception as e:
with suppress(Exception):
await comm.close()
raise CommClosedError(f"Comm {comm!r} closed.") from e
await comm.write(local_info)
handshake = await comm.read()

comm.remote_info = handshake
comm.remote_info["address"] = comm.peer_address
Expand Down Expand Up @@ -386,17 +373,8 @@ def time_left():
**comm.handshake_info(),
**(handshake_overrides or {}),
}
try:
# This would be better, but connections leak if worker is closed quickly
# write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
handshake = await wait_for(comm.read(), time_left())
await wait_for(comm.write(local_info), time_left())
except Exception as exc:
with suppress(Exception):
await comm.close()
raise OSError(
f"Timed out during handshake while connecting to {addr} after {timeout} s"
) from exc
await comm.write(local_info)
handshake = await comm.read()

comm.remote_info = handshake
comm.remote_info["address"] = comm._peer_addr
Expand Down
9 changes: 4 additions & 5 deletions distributed/comm/tests/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,7 @@ class UnreliableBackend(tcp.TCPBackend):
listener.stop()


@pytest.mark.slow
@gen_test()
async def test_handshake_slow_comm(tcp, monkeypatch):
class SlowComm(tcp.TCP):
Expand Down Expand Up @@ -999,11 +1000,9 @@ def get_connector(self):

import dask

with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}):
with pytest.raises(
IOError, match="Timed out during handshake while connecting to"
):
await connect(listener.contact_address)
# The connect itself is fast. Only the handshake is slow
with dask.config.set({"distributed.comm.timeouts.connect": "500ms"}):
await connect(listener.contact_address)
finally:
listener.stop()

Expand Down
16 changes: 0 additions & 16 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6011,22 +6011,6 @@ async def test_client_timeout_2():
assert stop - start < 1


@gen_test()
async def test_client_active_bad_port():
import tornado.httpserver
import tornado.web

application = tornado.web.Application([(r"/", tornado.web.RequestHandler)])
http_server = tornado.httpserver.HTTPServer(application)
http_server.listen(8080)
with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}):
c = Client("127.0.0.1:8080", asynchronous=True)
with pytest.raises((TimeoutError, IOError)):
async with c:
pass
http_server.stop()


@pytest.mark.parametrize("direct", [True, False])
@gen_cluster(client=True, client_kwargs={"serializers": ["dask", "msgpack"]})
async def test_turn_off_pickle(c, s, a, b, direct):
Expand Down
49 changes: 27 additions & 22 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,30 +1063,35 @@ async def kill(self, *, timeout, reason=None):
@pytest.mark.slow
@gen_cluster(client=True, Worker=SlowKillNanny, nthreads=[("", 1)] * 2)
async def test_restart_nanny_timeout_exceeded(c, s, a, b):
f = c.submit(div, 1, 0)
fr = c.submit(inc, 1, resources={"FOO": 1})
await wait(f)
assert s.erred_tasks
assert s.computations
assert s.unrunnable
assert s.tasks

with pytest.raises(
TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s"
):
await c.restart(timeout="1s")
assert a.kill_called.is_set()
assert b.kill_called.is_set()
try:
f = c.submit(div, 1, 0)
fr = c.submit(inc, 1, resources={"FOO": 1})
await wait(f)
assert s.erred_tasks
assert s.computations
assert s.unrunnable
assert s.tasks

assert not s.workers
assert not s.erred_tasks
assert not s.computations
assert not s.unrunnable
assert not s.tasks
with pytest.raises(
TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s"
):
await c.restart(timeout="1s")
assert a.kill_called.is_set()
assert b.kill_called.is_set()

assert not s.workers
assert not s.erred_tasks
assert not s.computations
assert not s.unrunnable
assert not s.tasks

assert not c.futures
assert f.status == "cancelled"
assert fr.status == "cancelled"
finally:
a.kill_proceed.set()
b.kill_proceed.set()

assert not c.futures
assert f.status == "cancelled"
assert fr.status == "cancelled"


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
Expand Down
13 changes: 2 additions & 11 deletions distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,16 +599,9 @@ async def test_dump_cluster_state_unresponsive_local_worker(s, a, b, tmp_path):


@pytest.mark.slow
@gen_cluster(
client=True,
Worker=Nanny,
config={"distributed.comm.timeouts.connect": "600ms"},
)
@gen_cluster(client=True, Worker=Nanny)
async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmp_path):
clog_fut = asyncio.create_task(
c.run(lambda dask_scheduler: dask_scheduler.stop(), workers=[a.worker_address])
)
await asyncio.sleep(0.2)
await c.run(lambda dask_worker: dask_worker.stop(), workers=[a.worker_address])

await dump_cluster_state(s, [a, b], str(tmp_path), "dump")
with open(f"{tmp_path}/dump.yaml") as fh:
Expand All @@ -620,8 +613,6 @@ async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmp_path):
"OSError('Timed out trying to connect to"
)

clog_fut.cancel()


# Note: WINDOWS constant doesn't work with `mypy --platform win32`
if sys.platform == "win32":
Expand Down
3 changes: 2 additions & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,8 @@ async def close( # type: ignore
for pc in self.periodic_callbacks.values():
pc.stop()

self.stop()

# Cancel async instructions
await BaseWorker.close(self, timeout=timeout)

Expand Down Expand Up @@ -1638,7 +1640,6 @@ def _close(executor, wait):
executor=executor, wait=executor_wait
) # Just run it directly

self.stop()
await self.rpc.close()

self.status = Status.closed
Expand Down

0 comments on commit 2751741

Please sign in to comment.