Skip to content

Commit a787157

Browse files
committed
Add support for worker state callbacks
1 parent 8779372 commit a787157

File tree

4 files changed

+173
-10
lines changed

4 files changed

+173
-10
lines changed

docs/src/_changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ This documents notable changes in DistributedNext.jl. The format is based on
1818
incompatibilities from both libraries being used simultaneously ([#10]).
1919
- [`other_workers()`](@ref) and [`other_procs()`](@ref) were implemented and
2020
exported ([#18]).
21+
- Implemented callback support for worker lifecycle events, i.e. workers being added, exiting, and exited ([#17]).
2122

2223
## [v1.0.0] - 2024-12-02
2324

docs/src/index.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ DistributedNext.cluster_cookie()
5252
DistributedNext.cluster_cookie(::Any)
5353
```
5454

55+
## Callbacks
56+
57+
```@docs
58+
DistributedNext.add_worker_added_callback
59+
DistributedNext.remove_worker_added_callback
60+
DistributedNext.add_worker_exiting_callback
61+
DistributedNext.remove_worker_exiting_callback
62+
DistributedNext.add_worker_exited_callback
63+
DistributedNext.remove_worker_exited_callback
64+
```
65+
5566
## Cluster Manager Interface
5667

5768
This interface provides a mechanism to launch and manage Julia workers on different cluster environments.

src/cluster.jl

Lines changed: 119 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -461,12 +461,14 @@ function addprocs(manager::ClusterManager; kwargs...)
461461

462462
cluster_mgmt_from_master_check()
463463

464-
lock(worker_lock)
465-
try
466-
addprocs_locked(manager::ClusterManager; kwargs...)
467-
finally
468-
unlock(worker_lock)
464+
new_workers = @lock worker_lock addprocs_locked(manager::ClusterManager; kwargs...)
465+
for worker in new_workers
466+
for callback in values(worker_added_callbacks)
467+
callback(worker)
468+
end
469469
end
470+
471+
return new_workers
470472
end
471473

472474
function addprocs_locked(manager::ClusterManager; kwargs...)
@@ -855,13 +857,96 @@ const HDR_COOKIE_LEN=16
855857
const map_pid_wrkr = Dict{Int, Union{Worker, LocalProcess}}()
856858
const map_sock_wrkr = IdDict()
857859
const map_del_wrkr = Set{Int}()
860+
# Registries for the worker-state callbacks, one dictionary per event type.
# Entries are inserted by `add_worker_<type>_callback` (under the caller-supplied
# or generated key) and deleted by `remove_worker_<type>_callback`.
const worker_added_callbacks = Dict{Any, Base.Callable}()
861+
const worker_exiting_callbacks = Dict{Any, Base.Callable}()
862+
const worker_exited_callbacks = Dict{Any, Base.Callable}()
858863

859864
# Return the role of this process (the value stored in `LPROCROLE`), i.e. whether
# the process is a master or worker in a distributed setup
860865
myrole() = LPROCROLE[]
861866
# Set the role of this process by storing `proctype` in `LPROCROLE`; the stored
# value is what `myrole()` returns.
function myrole!(proctype::Symbol)
862867
LPROCROLE[] = proctype
863868
end
864869

870+
# Callbacks
871+
872+
# We define the callback methods in a loop here and add docstrings for them afterwards
873+
# For each event type, generate a pair of functions:
#   add_worker_<type>_callback(f; key=nothing)  -> stores `f` in `worker_<type>_callbacks`
#   remove_worker_<type>_callback(key)          -> deletes `key` from `worker_<type>_callbacks`
for callback_type in (:added, :exiting, :exited)
874+
let add_name = Symbol(:add_worker_, callback_type, :_callback),
875+
remove_name = Symbol(:remove_worker_, callback_type, :_callback),
876+
dict_name = Symbol(:worker_, callback_type, :_callbacks)
877+
878+
@eval begin
879+
function $add_name(f::Base.Callable; key=nothing)
880+
# Reject callbacks that have no method callable as `f(::Int)`
if !hasmethod(f, Tuple{Int})
881+
throw(ArgumentError("Callback function is invalid, it must be able to accept a single Int argument"))
882+
end
883+
884+
# No key supplied: build one from a fresh `gensym()` and the name of `f`
if isnothing(key)
885+
key = Symbol(gensym(), nameof(f))
886+
end
887+
888+
# Note: plain assignment, so an entry already stored under `key` is replaced
$dict_name[key] = f
889+
return key
890+
end
891+
892+
# Unregister by deleting the entry for `key` from the matching dictionary
$remove_name(key) = delete!($dict_name, key)
893+
end
894+
end
895+
end
896+
897+
"""
898+
add_worker_added_callback(f::Base.Callable; key=nothing)
899+
900+
Register a callback to be called on the master process whenever a worker is
901+
added. The callback will be called with the added worker ID,
902+
e.g. `f(w::Int)`. Returns a unique key for the callback.

If `key` is not given, one is generated automatically. If `key` is given it is
used as-is, and a callback already registered under the same key is replaced.
The returned key can be passed to [`remove_worker_added_callback()`](@ref) to
unregister the callback.

Throws an `ArgumentError` if `f` has no method that accepts a single `Int`
argument.
903+
"""
904+
function add_worker_added_callback end
905+
906+
"""
907+
remove_worker_added_callback(key)
908+
909+
Remove the worker-added callback registered under `key`, where `key` is the
value returned by [`add_worker_added_callback()`](@ref).
910+
"""
911+
function remove_worker_added_callback end
912+
913+
"""
914+
add_worker_exiting_callback(f::Base.Callable; key=nothing)
915+
916+
Register a callback to be called on the master process immediately before a
917+
worker is removed with [`rmprocs()`](@ref). The callback will be called with the
918+
worker ID, e.g. `f(w::Int)`. Returns a unique key for the callback.
919+
920+
All callbacks will be executed asynchronously and if they don't all finish
921+
before the `callback_timeout` passed to `rmprocs()` then the process will be
922+
removed anyway.

If `key` is not given, one is generated automatically. If `key` is given it is
used as-is, and a callback already registered under the same key is replaced.
The returned key can be passed to [`remove_worker_exiting_callback()`](@ref) to
unregister the callback.

Throws an `ArgumentError` if `f` has no method that accepts a single `Int`
argument.
923+
"""
924+
function add_worker_exiting_callback end
925+
926+
"""
927+
remove_worker_exiting_callback(key)
928+
929+
Remove the worker-exiting callback registered under `key`, where `key` is the
value returned by [`add_worker_exiting_callback()`](@ref).
930+
"""
931+
function remove_worker_exiting_callback end
932+
933+
"""
934+
add_worker_exited_callback(f::Base.Callable; key=nothing)
935+
936+
Register a callback to be called on the master process when a worker has exited
937+
for any reason (i.e. not only because of [`rmprocs()`](@ref) but also the worker
938+
segfaulting etc). The callback will be called with the worker ID,
939+
e.g. `f(w::Int)`. Returns a unique key for the callback.

If `key` is not given, one is generated automatically. If `key` is given it is
used as-is, and a callback already registered under the same key is replaced.
The returned key can be passed to [`remove_worker_exited_callback()`](@ref) to
unregister the callback.

Throws an `ArgumentError` if `f` has no method that accepts a single `Int`
argument.
940+
"""
941+
function add_worker_exited_callback end
942+
943+
"""
944+
remove_worker_exited_callback(key)
945+
946+
Remove the worker-exited callback registered under `key`, where `key` is the
value returned by [`add_worker_exited_callback()`](@ref).
947+
"""
948+
function remove_worker_exited_callback end
949+
865950
# cluster management related API
866951
"""
867952
myid()
@@ -1048,7 +1133,7 @@ function cluster_mgmt_from_master_check()
10481133
end
10491134

10501135
"""
1051-
rmprocs(pids...; waitfor=typemax(Int))
1136+
rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
10521137
10531138
Remove the specified workers. Note that only process 1 can add or remove
10541139
workers.
@@ -1062,6 +1147,10 @@ Argument `waitfor` specifies how long to wait for the workers to shut down:
10621147
returned. The user should call [`wait`](@ref) on the task before invoking any other
10631148
parallel calls.
10641149
1150+
The `callback_timeout` argument specifies how long to wait, in seconds, for any
1151+
worker-exiting callbacks to finish before continuing to remove the workers (see
1152+
[`add_worker_exiting_callback()`](@ref)).
1153+
10651154
# Examples
10661155
```julia-repl
10671156
\$ julia -p 5
@@ -1078,24 +1167,36 @@ julia> workers()
10781167
6
10791168
```
10801169
"""
1081-
function rmprocs(pids...; waitfor=typemax(Int))
1170+
function rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
10821171
cluster_mgmt_from_master_check()
10831172

10841173
pids = vcat(pids...)
10851174
if waitfor == 0
1086-
t = @async _rmprocs(pids, typemax(Int))
1175+
t = @async _rmprocs(pids, typemax(Int), callback_timeout)
10871176
yield()
10881177
return t
10891178
else
1090-
_rmprocs(pids, waitfor)
1179+
_rmprocs(pids, waitfor, callback_timeout)
10911180
# return a dummy task object that user code can wait on.
10921181
return @async nothing
10931182
end
10941183
end
10951184

1096-
function _rmprocs(pids, waitfor)
1185+
function _rmprocs(pids, waitfor, callback_timeout)
10971186
lock(worker_lock)
10981187
try
1188+
# Run the callbacks
1189+
callback_tasks = Task[]
1190+
for pid in pids
1191+
for callback in values(worker_exiting_callbacks)
1192+
push!(callback_tasks, Threads.@spawn callback(pid))
1193+
end
1194+
end
1195+
1196+
if timedwait(() -> all(istaskdone.(callback_tasks)), callback_timeout) === :timed_out
1197+
@warn "Some callbacks timed out, continuing to remove workers anyway"
1198+
end
1199+
10991200
rmprocset = Union{LocalProcess, Worker}[]
11001201
for p in pids
11011202
if p == 1
@@ -1241,6 +1342,14 @@ function deregister_worker(pg, pid)
12411342
delete!(pg.refs, id)
12421343
end
12431344
end
1345+
1346+
# Call callbacks on the master
1347+
if myid() == 1
1348+
for callback in values(worker_exited_callbacks)
1349+
callback(pid)
1350+
end
1351+
end
1352+
12441353
return
12451354
end
12461355

test/distributed_exec.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

33
using DistributedNext, Random, Serialization, Sockets
4+
import DistributedNext
45
import DistributedNext: launch, manage
56

67

@@ -1934,6 +1935,47 @@ include("splitrange.jl")
19341935
end
19351936
end
19361937

1938+
# Exercises the add/remove functions for the worker added/exiting/exited
# callbacks and the `callback_timeout` keyword of `rmprocs()`.
@testset "Worker state callbacks" begin
1939+
# Start without any workers so that the pids recorded below are unambiguous
if nprocs() > 1
1940+
rmprocs(workers())
1941+
end
1942+
1943+
# Smoke test to ensure that all the callbacks are executed
1944+
added_workers = Int[]
1945+
exiting_workers = Int[]
1946+
exited_workers = Int[]
1947+
added_key = DistributedNext.add_worker_added_callback(pid -> push!(added_workers, pid))
1948+
exiting_key = DistributedNext.add_worker_exiting_callback(pid -> push!(exiting_workers, pid))
1949+
exited_key = DistributedNext.add_worker_exited_callback(pid -> push!(exited_workers, pid))
1950+
1951+
# Add one worker and remove it again; each callback should have seen its pid
pid = only(addprocs(1))
1952+
@test added_workers == [pid]
1953+
rmprocs(workers())
1954+
@test exiting_workers == [pid]
1955+
@test exited_workers == [pid]
1956+
1957+
# Remove the callbacks
1958+
DistributedNext.remove_worker_added_callback(added_key)
1959+
DistributedNext.remove_worker_exiting_callback(exiting_key)
1960+
DistributedNext.remove_worker_exited_callback(exited_key)
1961+
1962+
# Test that the `callback_timeout` option works
1963+
event = Base.Event()
1964+
callback_task = nothing
1965+
# This callback records the task it runs in and then blocks on `event`, so it
# cannot finish until `notify(event)` is called further down
exiting_key = DistributedNext.add_worker_exiting_callback(_ -> (callback_task = current_task(); wait(event)))
1966+
addprocs(1)
1967+
1968+
@test_logs (:warn, r"Some callbacks timed out.+") rmprocs(workers(); callback_timeout=0.5)
1969+
1970+
# Unblock the callback and wait for its task to finish
notify(event)
1971+
wait(callback_task)
# NOTE(review): the blocking exiting callback registered above is never removed
# in this testset; consider `remove_worker_exiting_callback(exiting_key)` here.
1972+
1973+
# Test that the previous callbacks were indeed removed
# (the vectors must still hold only the pid recorded in the first cycle)
1974+
@test length(added_workers) == 1
1975+
@test length(exiting_workers) == 1
1976+
@test length(exited_workers) == 1
1977+
end
1978+
19371979
# Run topology tests last after removing all workers, since a given
19381980
# cluster at any time only supports a single topology.
19391981
if nprocs() > 1

0 commit comments

Comments
 (0)