Skip to content

Commit 8b04241

Browse files
committed
fixup! Add support for worker state callbacks
1 parent 90f44f6 commit 8b04241

File tree

4 files changed

+94
-40
lines changed

4 files changed

+94
-40
lines changed

docs/src/index.md

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,8 +55,10 @@ DistributedNext.cluster_cookie(::Any)
5555
## Callbacks
5656

5757
```@docs
58-
DistributedNext.add_worker_added_callback
59-
DistributedNext.remove_worker_added_callback
58+
DistributedNext.add_worker_starting_callback
59+
DistributedNext.remove_worker_starting_callback
60+
DistributedNext.add_worker_started_callback
61+
DistributedNext.remove_worker_started_callback
6062
DistributedNext.add_worker_exiting_callback
6163
DistributedNext.remove_worker_exiting_callback
6264
DistributedNext.add_worker_exited_callback

ext/ReviseExt.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,8 @@ Revise.is_master_worker(worker::DistributedNextWorker) = worker.id == 1
2323

2424
function __init__()
2525
Revise.register_workers_function(get_workers)
26-
DistributedNext.add_worker_added_callback(pid -> Revise.init_worker(DistributedNextWorker(pid));
27-
key="DistributedNext-integration")
26+
DistributedNext.add_worker_started_callback(pid -> Revise.init_worker(DistributedNextWorker(pid));
27+
key="DistributedNext-integration")
2828
end
2929

3030
end

src/cluster.jl

Lines changed: 69 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -463,23 +463,17 @@ function addprocs(manager::ClusterManager; kwargs...)
463463

464464
cluster_mgmt_from_master_check()
465465

466+
# Call worker-starting callbacks
467+
warning_interval = params[:callback_warning_interval]
468+
_run_callbacks_concurrently("worker-starting", worker_starting_callbacks,
469+
warning_interval, [(manager, kwargs)])
470+
471+
# Add new workers
466472
new_workers = @lock worker_lock addprocs_locked(manager::ClusterManager, params)
467473

468-
callback_tasks = Dict{Any, Task}()
469-
for worker in new_workers
470-
for (name, callback) in worker_added_callbacks
471-
callback_tasks[name] = Threads.@spawn callback(worker)
472-
end
473-
end
474-
475-
running_callbacks = () -> ["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)]
476-
while timedwait(() -> isempty(running_callbacks()), params[:callback_warning_interval]) === :timed_out
477-
callbacks_str = join(running_callbacks(), ", ")
478-
@warn "Waiting for these worker-added callbacks to finish: $(callbacks_str)"
479-
end
480-
481-
# Wait on the tasks so that exceptions bubble up
482-
wait.(values(callback_tasks))
474+
# Call worker-started callbacks
475+
_run_callbacks_concurrently("worker-started", worker_started_callbacks,
476+
warning_interval, new_workers)
483477

484478
return new_workers
485479
end
@@ -870,7 +864,8 @@ const HDR_COOKIE_LEN=16
870864
const map_pid_wrkr = Dict{Int, Union{Worker, LocalProcess}}()
871865
const map_sock_wrkr = IdDict()
872866
const map_del_wrkr = Set{Int}()
873-
const worker_added_callbacks = Dict{Any, Base.Callable}()
867+
const worker_starting_callbacks = Dict{Any, Base.Callable}()
868+
const worker_started_callbacks = Dict{Any, Base.Callable}()
874869
const worker_exiting_callbacks = Dict{Any, Base.Callable}()
875870
const worker_exited_callbacks = Dict{Any, Base.Callable}()
876871

@@ -882,9 +877,29 @@ end
882877

883878
# Callbacks
884879

885-
function _add_callback(f, key, dict)
886-
if !hasmethod(f, Tuple{Int})
887-
throw(ArgumentError("Callback function is invalid, it must be able to accept a single Int argument"))
880+
function _run_callbacks_concurrently(callbacks_name, callbacks_dict, warning_interval, arglist)
881+
callback_tasks = Dict{Any, Task}()
882+
for args in arglist
883+
for (name, callback) in callbacks_dict
884+
callback_tasks[name] = Threads.@spawn callback(args...)
885+
end
886+
end
887+
888+
running_callbacks = () -> ["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)]
889+
while timedwait(() -> isempty(running_callbacks()), warning_interval) === :timed_out
890+
callbacks_str = join(running_callbacks(), ", ")
891+
@warn "Waiting for these $(callbacks_name) callbacks to finish: $(callbacks_str)"
892+
end
893+
894+
# Wait on the tasks so that exceptions bubble up
895+
wait.(values(callback_tasks))
896+
end
897+
898+
function _add_callback(f, key, dict; arg_types=Tuple{Int})
899+
desired_signature = "f(" * join(["::$(t)" for t in arg_types.types], ", ") * ")"
900+
901+
if !hasmethod(f, arg_types)
902+
throw(ArgumentError("Callback function is invalid, it must be able to be called with these argument types: $(desired_signature)"))
888903
elseif haskey(dict, key)
889904
throw(ArgumentError("A callback function with key '$(key)' already exists"))
890905
end
@@ -900,29 +915,58 @@ end
900915
_remove_callback(key, dict) = delete!(dict, key)
901916

902917
"""
903-
add_worker_added_callback(f::Base.Callable; key=nothing)
918+
add_worker_starting_callback(f::Base.Callable; key=nothing)
919+
920+
Register a callback to be called on the master process immediately before new
921+
workers are started. The callback `f` will be called with the `ClusterManager`
922+
instance that is being used and a dictionary of parameters related to adding
923+
workers, i.e. `f(manager, params)`. The `params` dictionary is specific to the
924+
`manager` type. Note that the `LocalManager` and `SSHManager` cluster managers
925+
in DistributedNext are not fully documented yet, see the
926+
[managers.jl](https://github.com/JuliaParallel/DistributedNext.jl/blob/master/src/managers.jl)
927+
file for their definitions.
928+
929+
!!! warning
930+
Adding workers can fail so it is not guaranteed that the workers requested
931+
will exist.
932+
933+
The worker-starting callbacks will be executed concurrently. If one throws an
934+
exception it will not be caught and will bubble up through [`addprocs`](@ref).
935+
936+
Keep in mind that the callbacks will add to the time taken to launch workers; so
937+
try to either keep the callbacks fast to execute, or do the actual work
938+
asynchronously by spawning a task in the callback (beware of race conditions if
939+
you do this).
940+
"""
941+
add_worker_starting_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_starting_callbacks;
942+
arg_types=Tuple{ClusterManager, Dict})
943+
944+
remove_worker_starting_callback(key) = _remove_callback(key, worker_starting_callbacks)
945+
946+
"""
947+
add_worker_started_callback(f::Base.Callable; key=nothing)
904948
905949
Register a callback to be called on the master process whenever a worker is
906950
added. The callback will be called with the added worker ID,
907951
e.g. `f(w::Int)`. Chooses and returns a unique key for the callback if `key` is
908952
not specified.
909953
910-
The worker-added callbacks will be executed concurrently. If one throws an
954+
The worker-started callbacks will be executed concurrently. If one throws an
911955
exception it will not be caught and will bubble up through [`addprocs()`](@ref).
912956
913957
Keep in mind that the callbacks will add to the time taken to launch workers; so
914958
try to either keep the callbacks fast to execute, or do the actual
915959
initialization asynchronously by spawning a task in the callback (beware of race
916960
conditions if you do this).
917961
"""
918-
add_worker_added_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_added_callbacks)
962+
add_worker_started_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_started_callbacks)
919963

920964
"""
921-
remove_worker_added_callback(key)
965+
remove_worker_started_callback(key)
922966
923-
Remove the callback for `key` that was added with [`add_worker_added_callback()`](@ref).
967+
Remove the callback for `key` that was added with [`add_worker_started_callback()`](@ref).
924968
"""
925-
remove_worker_added_callback(key) = _remove_callback(key, worker_added_callbacks)
969+
remove_worker_started_callback(key) = _remove_callback(key, worker_started_callbacks)
926970

927971
"""
928972
add_worker_exiting_callback(f::Base.Callable; key=nothing)

test/distributed_exec.jl

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1939,40 +1939,47 @@ end
19391939
@testset "Worker state callbacks" begin
19401940
rmprocs(other_workers())
19411941

1942+
# Adding a callback with an invalid signature should fail
1943+
@test_throws ArgumentError DistributedNext.add_worker_started_callback(() -> nothing)
1944+
19421945
# Smoke test to ensure that all the callbacks are executed
1943-
added_workers = Int[]
1946+
starting_managers = []
1947+
started_workers = Int[]
19441948
exiting_workers = Int[]
19451949
exited_workers = Int[]
1946-
added_key = DistributedNext.add_worker_added_callback(pid -> (push!(added_workers, pid); error("foo")))
1950+
starting_key = DistributedNext.add_worker_starting_callback((manager, kwargs) -> push!(starting_managers, manager))
1951+
started_key = DistributedNext.add_worker_started_callback(pid -> (push!(started_workers, pid); error("foo")))
19471952
exiting_key = DistributedNext.add_worker_exiting_callback(pid -> push!(exiting_workers, pid))
19481953
exited_key = DistributedNext.add_worker_exited_callback(pid -> push!(exited_workers, pid))
19491954

1950-
# Test that the worker-added exception bubbles up
1955+
# Test that the worker-started exception bubbles up
19511956
@test_throws TaskFailedException addprocs(1)
19521957

19531958
pid = only(workers())
1954-
@test added_workers == [pid]
1959+
@test only(starting_managers) isa DistributedNext.LocalManager
1960+
@test started_workers == [pid]
19551961
rmprocs(workers())
19561962
@test exiting_workers == [pid]
19571963
@test exited_workers == [pid]
19581964

19591965
# Trying to reset an existing callback should fail
1960-
@test_throws ArgumentError DistributedNext.add_worker_added_callback(Returns(nothing); key=added_key)
1966+
@test_throws ArgumentError DistributedNext.add_worker_started_callback(Returns(nothing); key=started_key)
19611967

19621968
# Remove the callbacks
1963-
DistributedNext.remove_worker_added_callback(added_key)
1969+
DistributedNext.remove_worker_starting_callback(starting_key)
1970+
DistributedNext.remove_worker_started_callback(started_key)
19641971
DistributedNext.remove_worker_exiting_callback(exiting_key)
19651972
DistributedNext.remove_worker_exited_callback(exited_key)
19661973

19671974
# Test that the worker-exiting `callback_timeout` option works and that we
1968-
# get warnings about slow worker-added callbacks.
1975+
# get warnings about slow worker-started callbacks.
19691976
event = Base.Event()
19701977
callback_task = nothing
1971-
added_key = DistributedNext.add_worker_added_callback(_ -> sleep(0.5))
1978+
started_key = DistributedNext.add_worker_started_callback(_ -> sleep(0.5))
19721979
exiting_key = DistributedNext.add_worker_exiting_callback(_ -> (callback_task = current_task(); wait(event)))
19731980

1974-
@test_logs (:warn, r"Waiting for these worker-added callbacks.+") match_mode=:any addprocs(1; callback_warning_interval=0.05)
1975-
DistributedNext.remove_worker_added_callback(added_key)
1981+
@test_logs (:warn, r"Waiting for these worker-started callbacks.+") match_mode=:any addprocs(1; callback_warning_interval=0.05)
1982+
DistributedNext.remove_worker_started_callback(started_key)
19761983

19771984
@test_logs (:warn, r"Some worker-exiting callbacks have not yet finished.+") rmprocs(workers(); callback_timeout=0.5)
19781985
DistributedNext.remove_worker_exiting_callback(exiting_key)
@@ -1981,7 +1988,8 @@ end
19811988
wait(callback_task)
19821989

19831990
# Test that the initial callbacks were indeed removed
1984-
@test length(added_workers) == 1
1991+
@test length(starting_managers) == 1
1992+
@test length(started_workers) == 1
19851993
@test length(exiting_workers) == 1
19861994
@test length(exited_workers) == 1
19871995
end

0 commit comments

Comments (0)