Skip to content

Commit

Permalink
Fix decide_worker picking a closing worker (#8032)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Aug 3, 2023
1 parent 6867c5b commit 84e1984
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 22 deletions.
12 changes: 5 additions & 7 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2207,7 +2207,7 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
restrictions.
Out of eligible workers holding dependencies of ``ts``, selects the worker
where, considering worker backlong and data-transfer costs, the task is
where, considering worker backlog and data-transfer costs, the task is
estimated to start running the soonest.
Returns
Expand All @@ -2222,9 +2222,6 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:

valid_workers = self.valid_workers(ts)
if valid_workers is None and len(self.running) < len(self.workers):
if not self.running:
return None

# If there were no restrictions, `valid_workers()` didn't subset by
# `running`.
valid_workers = self.running
Expand Down Expand Up @@ -8197,7 +8194,7 @@ def _task_to_client_msgs(ts: TaskState) -> dict[str, list[dict[str, Any]]]:

def decide_worker(
ts: TaskState,
all_workers: Iterable[WorkerState],
all_workers: set[WorkerState],
valid_workers: set[WorkerState] | None,
objective: Callable[[WorkerState], Any],
) -> WorkerState | None:
Expand All @@ -8218,12 +8215,13 @@ def decide_worker(
"""
assert all(dts.who_has for dts in ts.dependencies)
if ts.actor:
candidates = set(all_workers)
candidates = all_workers.copy()
else:
candidates = {wws for dts in ts.dependencies for wws in dts.who_has}
candidates &= all_workers
if valid_workers is None:
if not candidates:
candidates = set(all_workers)
candidates = all_workers.copy()
else:
candidates &= valid_workers
if not candidates:
Expand Down
68 changes: 53 additions & 15 deletions distributed/tests/test_failed_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,59 @@ def test_submit_after_failed_worker_sync(loop):
assert total.result() == sum(map(inc, range(10)))


@pytest.mark.slow()
@pytest.mark.parametrize("compute_on_failed", [False, True])
@gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "500ms"})
async def test_submit_after_failed_worker_async(c, s, a, b, compute_on_failed):
async with Nanny(s.address, nthreads=2) as n:
await c.wait_for_workers(3)

L = c.map(inc, range(10))
await wait(L)

kill_task = asyncio.create_task(n.kill())
compute_addr = n.worker_address if compute_on_failed else a.address
total = c.submit(sum, L, workers=[compute_addr], allow_other_workers=True)
assert await total == sum(range(1, 11))
await kill_task
@pytest.mark.parametrize("when", ["closing", "closed"])
@pytest.mark.parametrize("y_on_failed", [False, True])
@pytest.mark.parametrize("x_on_failed", [False, True])
@gen_cluster(
client=True,
nthreads=[("", 1)] * 2,
config={"distributed.comm.timeouts.connect": "1s"},
)
async def test_submit_after_failed_worker_async(
c, s, a, b, x_on_failed, y_on_failed, when, monkeypatch
):
a_ws = s.workers[a.address]

x = c.submit(
inc,
1,
key="x",
workers=[b.address if x_on_failed else a.address],
allow_other_workers=True,
)
await wait(x)

if when == "closed":
await b.close()
await async_poll_for(lambda: b.address not in s.workers, timeout=5)
elif when == "closing":
orig_remove_worker = s.remove_worker
in_remove_worker = asyncio.Event()
wait_remove_worker = asyncio.Event()

async def remove_worker(*args, **kwargs):
in_remove_worker.set()
await wait_remove_worker.wait()
return await orig_remove_worker(*args, **kwargs)

monkeypatch.setattr(s, "remove_worker", remove_worker)
await b.close()
await in_remove_worker.wait()
assert s.workers[b.address].status.name == "closing"

y = c.submit(
inc,
x,
key="y",
workers=[b.address if y_on_failed else a.address],
allow_other_workers=True,
)
await async_poll_for(lambda: "y" in s.tasks, timeout=5)

if when == "closing":
wait_remove_worker.set()
assert await y == 3
assert s.tasks["y"].who_has == {a_ws}


@gen_cluster(client=True, timeout=60)
Expand Down

0 comments on commit 84e1984

Please sign in to comment.