Fix test nanny timeout (#8847)
fjetter authored Sep 3, 2024
1 parent 5639216 commit d728052
Showing 2 changed files with 86 additions and 29 deletions.
47 changes: 27 additions & 20 deletions distributed/nanny.py
@@ -516,7 +516,7 @@ async def _():
await self.instantiate()

try:
await wait_for(_(), timeout)
await wait_for(asyncio.shield(_()), timeout)
except asyncio.TimeoutError:
logger.error(
f"Restart timed out after {timeout}s; returning before finished"
@@ -745,26 +745,30 @@ async def start(self) -> Status:
os.environ.update(self.pre_spawn_env)

try:
await self.process.start()
except OSError:
logger.exception("Nanny failed to start process", exc_info=True)
# NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
await self.process.terminate()
self.status = Status.failed
try:
msg = await self._wait_until_connected(uid)
except Exception:
# NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
await self.process.terminate()
self.status = Status.failed
raise
try:
await self.process.start()
except OSError:
# This can only happen if the actual process creation failed, e.g.
# multiprocessing.Process.start failed. This is not tested!
logger.exception("Nanny failed to start process", exc_info=True)
# NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
await self.process.terminate()
self.status = Status.failed
try:
msg = await self._wait_until_connected(uid)
except Exception:
# NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
await self.process.terminate()
self.status = Status.failed
raise
finally:
self.running.set()
if not msg:
return self.status
self.worker_address = msg["address"]
self.worker_dir = msg["dir"]
assert self.worker_address
self.status = Status.running
self.running.set()

return self.status

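The restructuring above nests the existing start/connect error handling inside an outer try/finally so that self.running.set() runs on every exit path, not only on success. A toy sketch of why that matters, assuming a concurrent kill() blocked on the same event (the class and names below are illustrative, not the distributed API):

import asyncio


class ToyNanny:
    # Illustrative only: mirrors the shape of the fix, not distributed's API.
    def __init__(self) -> None:
        self.running = asyncio.Event()
        self.failed = False

    async def start(self) -> None:
        try:
            # Stand-in for spawning the process and waiting for the worker
            # to connect; either step may raise.
            await asyncio.sleep(0.01)
            raise OSError("spawn failed")
        except OSError:
            self.failed = True
        finally:
            # Always release anyone blocked in kill(); without this, a
            # failed start would leave kill() waiting forever.
            self.running.set()

    async def kill(self) -> None:
        await self.running.wait()
        print("kill can proceed; failed =", self.failed)


async def main() -> None:
    n = ToyNanny()
    await asyncio.gather(n.kill(), n.start())


asyncio.run(main())
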
@@ -799,6 +803,7 @@ def mark_stopped(self):
msg = self._death_message(self.process.pid, r)
logger.info(msg)
self.status = Status.stopped
self.running.clear()
self.stopped.set()
# Release resources
self.process.close()
@@ -830,22 +835,24 @@ async def kill(
"""
deadline = time() + timeout

if self.status == Status.stopped:
return
if self.status == Status.stopping:
await self.stopped.wait()
return
# If the process is not properly up it will not watch the closing queue
# and we may end up leaking this process
# Therefore wait for it to be properly started before killing it
if self.status == Status.starting:
await self.running.wait()

assert self.status in (
Status.stopping,
Status.stopped,
Status.running,
Status.failed, # process failed to start, but hasn't been joined yet
Status.closing_gracefully,
), self.status
if self.status == Status.stopped:
return
if self.status == Status.stopping:
await self.stopped.wait()
return
self.status = Status.stopping
logger.info("Nanny asking worker to close. Reason: %s", reason)

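In kill, the early returns for stopped/stopping now come after waiting for the process to finish starting, and the assertion accepts Status.failed. The reasoning in the diff's comment is that a process which is not fully up does not yet watch the closing queue, so acting on a pre-wait status snapshot can either leak the process or trip the assertion once a failed start has set Status.failed. A compact sketch of that ordering, again with illustrative names rather than the real API:

import asyncio
from enum import Enum, auto


class Status(Enum):
    starting = auto()
    running = auto()
    stopping = auto()
    stopped = auto()
    failed = auto()


class ToyProcess:
    # Illustrative only; shows why the status check must follow the wait.
    def __init__(self) -> None:
        self.status = Status.starting
        self.running = asyncio.Event()

    async def start(self) -> None:
        await asyncio.sleep(0.01)      # spawn + handshake stand-in
        self.status = Status.failed    # e.g. the worker crashed on startup
        self.running.set()             # set even on failure (see start hunk)

    async def kill(self) -> None:
        if self.status == Status.starting:
            await self.running.wait()
        # Only now is the status settled; stopped/stopping/failed are all
        # legitimate outcomes of a start that raced with this kill.
        if self.status in (Status.stopped, Status.stopping):
            return
        print("proceeding to kill; status =", self.status)


async def main() -> None:
    p = ToyProcess()
    await asyncio.gather(p.kill(), p.start())


asyncio.run(main())
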
68 changes: 59 additions & 9 deletions distributed/tests/test_nanny.py
@@ -208,25 +208,47 @@ async def test_scheduler_file():
s.stop()


@pytest.mark.xfail(
os.environ.get("MINDEPS") == "true",
reason="Timeout errors with mindeps environment",
)
@gen_cluster(client=True, Worker=Nanny, nthreads=[("127.0.0.1", 2)])
async def test_nanny_timeout(c, s, a):
@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
async def test_nanny_restart(c, s, a):
x = await c.scatter(123)
assert await c.submit(lambda: 1) == 1

await a.restart()

while x.status != "cancelled":
await asyncio.sleep(0.1)

assert await c.submit(lambda: 1) == 1


@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
async def test_nanny_restart_timeout(c, s, a):
x = await c.scatter(123)
with captured_logger(
logging.getLogger("distributed.nanny"), level=logging.ERROR
) as logger:
await a.restart(timeout=0.1)
await a.restart(timeout=0)

out = logger.getvalue()
assert "timed out" in out.lower()

start = time()
while x.status != "cancelled":
await asyncio.sleep(0.1)
assert time() < start + 7

assert await c.submit(lambda: 1) == 1


@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
async def test_nanny_restart_timeout_stress(c, s, a):
x = await c.scatter(123)
restarts = [a.restart(timeout=random.random()) for _ in range(100)]
await asyncio.gather(*restarts)

while x.status != "cancelled":
await asyncio.sleep(0.1)

assert await c.submit(lambda: 1) == 1
assert len(s.workers) == 1


@gen_cluster(
@@ -582,6 +604,34 @@ async def test_worker_start_exception(s):
assert logs.getvalue().count("ValueError: broken") == 1, logs.getvalue()


@gen_cluster(nthreads=[])
async def test_worker_start_exception_while_killing(s):
nanny = Nanny(s.address, worker_class=BrokenWorker)

async def try_to_kill_nanny():
while not nanny.process or nanny.process.status != Status.starting:
await asyncio.sleep(0)
await nanny.kill()

kill_task = asyncio.create_task(try_to_kill_nanny())
with captured_logger(logger="distributed.nanny", level=logging.WARNING) as logs:
with raises_with_cause(
RuntimeError,
"Nanny failed to start",
RuntimeError,
"BrokenWorker failed to start",
):
async with nanny:
pass
await kill_task
assert nanny.status == Status.failed
# ^ NOTE: `Nanny.close` sets it to `closed`, then `Server.start._close_on_failure` sets it to `failed`
assert nanny.process is None
assert "Restarting worker" not in logs.getvalue()
# Avoid excessive spewing. (It's also printed once extra within the subprocess, which is okay.)
assert logs.getvalue().count("ValueError: broken") == 1, logs.getvalue()


@gen_cluster(nthreads=[])
async def test_failure_during_worker_initialization(s):
with captured_logger(logger="distributed.nanny", level=logging.WARNING) as logs:
