Skip to content

Commit d5d1c90

Browse files
committed
Add support for worker state callbacks
1 parent 802d4f8 commit d5d1c90

File tree

4 files changed

+173
-10
lines changed

4 files changed

+173
-10
lines changed

docs/src/_changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ This documents notable changes in DistributedNext.jl. The format is based on
1616
- A watcher mechanism has been added to detect when both the Distributed stdlib
1717
and DistributedNext may be active and adding workers. This should help prevent
1818
incompatibilities from both libraries being used simultaneously ([#10]).
19+
- Implemented callback support for workers being added/removed, etc. ([#17]).
1920

2021
## [v1.0.0] - 2024-12-02
2122

docs/src/index.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,17 @@ DistributedNext.cluster_cookie()
5050
DistributedNext.cluster_cookie(::Any)
5151
```
5252

53+
## Callbacks
54+
55+
```@docs
56+
DistributedNext.add_worker_added_callback
57+
DistributedNext.remove_worker_added_callback
58+
DistributedNext.add_worker_exiting_callback
59+
DistributedNext.remove_worker_exiting_callback
60+
DistributedNext.add_worker_exited_callback
61+
DistributedNext.remove_worker_exited_callback
62+
```
63+
5364
## Cluster Manager Interface
5465

5566
This interface provides a mechanism to launch and manage Julia workers on different cluster environments.

src/cluster.jl

Lines changed: 119 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -461,12 +461,14 @@ function addprocs(manager::ClusterManager; kwargs...)
461461

462462
cluster_mgmt_from_master_check()
463463

464-
lock(worker_lock)
465-
try
466-
addprocs_locked(manager::ClusterManager; kwargs...)
467-
finally
468-
unlock(worker_lock)
464+
new_workers = @lock worker_lock addprocs_locked(manager::ClusterManager; kwargs...)
465+
for worker in new_workers
466+
for callback in values(worker_added_callbacks)
467+
callback(worker)
468+
end
469469
end
470+
471+
return new_workers
470472
end
471473

472474
function addprocs_locked(manager::ClusterManager; kwargs...)
@@ -855,13 +857,96 @@ const HDR_COOKIE_LEN=16
855857
const map_pid_wrkr = Dict{Int, Union{Worker, LocalProcess}}()
856858
const map_sock_wrkr = IdDict()
857859
const map_del_wrkr = Set{Int}()
860+
# Registries for worker state callbacks, keyed by a caller-supplied key (or a
# gensym'd one generated at registration time) so individual callbacks can be
# removed later. Each callback is invoked with a worker ID (`Int`).
const worker_added_callbacks = Dict{Any, Base.Callable}()
const worker_exiting_callbacks = Dict{Any, Base.Callable}()
const worker_exited_callbacks = Dict{Any, Base.Callable}()
858863

859864
# whether process is a master or worker in a distributed setup
860865
myrole() = LPROCROLE[]
861866
function myrole!(proctype::Symbol)
862867
LPROCROLE[] = proctype
863868
end
864869

870+
# Callbacks

# Shared implementation behind the six public add/remove functions below; the
# docstrings for the public names are attached to their stubs further down.
# Validates that `f` can accept a worker ID and stores it in `registry` under
# `key`, generating a unique key when the caller did not supply one.
function _register_worker_callback(registry::Dict{Any, Base.Callable}, f::Base.Callable, key)
    if !hasmethod(f, Tuple{Int})
        throw(ArgumentError("Callback function is invalid, it must be able to accept a single Int argument"))
    end

    if isnothing(key)
        # Prefix with a gensym so distinct registrations of the same function
        # never collide.
        key = Symbol(gensym(), nameof(f))
    end

    registry[key] = f
    return key
end

add_worker_added_callback(f::Base.Callable; key=nothing) =
    _register_worker_callback(worker_added_callbacks, f, key)
remove_worker_added_callback(key) = delete!(worker_added_callbacks, key)

add_worker_exiting_callback(f::Base.Callable; key=nothing) =
    _register_worker_callback(worker_exiting_callbacks, f, key)
remove_worker_exiting_callback(key) = delete!(worker_exiting_callbacks, key)

add_worker_exited_callback(f::Base.Callable; key=nothing) =
    _register_worker_callback(worker_exited_callbacks, f, key)
remove_worker_exited_callback(key) = delete!(worker_exited_callbacks, key)
896+
897+
"""
898+
add_worker_added_callback(f::Base.Callable; key=nothing)
899+
900+
Register a callback to be called on the master process whenever a worker is
901+
added. The callback will be called with the added worker ID,
902+
e.g. `f(w::Int)`. Returns a unique key for the callback.
903+
"""
904+
function add_worker_added_callback end
905+
906+
"""
907+
remove_worker_added_callback(key)
908+
909+
Remove the callback for `key`.
910+
"""
911+
function remove_worker_added_callback end
912+
913+
"""
914+
add_worker_exiting_callback(f::Base.Callable; key=nothing)
915+
916+
Register a callback to be called on the master process immediately before a
917+
worker is removed with [`rmprocs()`](@ref). The callback will be called with the
918+
worker ID, e.g. `f(w::Int)`. Returns a unique key for the callback.
919+
920+
All callbacks will be executed asynchronously and if they don't all finish
921+
before the `callback_timeout` passed to `rmprocs()` then the process will be
922+
removed anyway.
923+
"""
924+
function add_worker_exiting_callback end
925+
926+
"""
927+
remove_worker_exiting_callback(key)
928+
929+
Remove the callback for `key`.
930+
"""
931+
function remove_worker_exiting_callback end
932+
933+
"""
934+
add_worker_exited_callback(f::Base.Callable; key=nothing)
935+
936+
Register a callback to be called on the master process when a worker has exited
937+
for any reason (i.e. not only because of [`rmprocs()`](@ref) but also the worker
938+
segfaulting etc). The callback will be called with the worker ID,
939+
e.g. `f(w::Int)`. Returns a unique key for the callback.
940+
"""
941+
function add_worker_exited_callback end
942+
943+
"""
944+
remove_worker_exited_callback(key)
945+
946+
Remove the callback for `key`.
947+
"""
948+
function remove_worker_exited_callback end
949+
865950
# cluster management related API
866951
"""
867952
myid()
@@ -1025,7 +1110,7 @@ function cluster_mgmt_from_master_check()
10251110
end
10261111

10271112
"""
1028-
rmprocs(pids...; waitfor=typemax(Int))
1113+
rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
10291114
10301115
Remove the specified workers. Note that only process 1 can add or remove
10311116
workers.
@@ -1039,6 +1124,10 @@ Argument `waitfor` specifies how long to wait for the workers to shut down:
10391124
returned. The user should call [`wait`](@ref) on the task before invoking any other
10401125
parallel calls.
10411126
1127+
The `callback_timeout` specifies how long to wait for any callbacks to execute
1128+
before continuing to remove the workers (see
1129+
[`add_worker_exiting_callback()`](@ref)).
1130+
10421131
# Examples
10431132
```julia-repl
10441133
\$ julia -p 5
@@ -1055,24 +1144,36 @@ julia> workers()
10551144
6
10561145
```
10571146
"""
1058-
function rmprocs(pids...; waitfor=typemax(Int))
1147+
function rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
    cluster_mgmt_from_master_check()

    # Flatten the varargs: callers may pass individual IDs, vectors of IDs, or
    # a mix of both.
    worker_ids = vcat(pids...)

    if waitfor != 0
        # Synchronous removal. Still return a (trivially finished) task so
        # user code can uniformly `wait` on the result.
        _rmprocs(worker_ids, waitfor, callback_timeout)
        return @async nothing
    end

    # `waitfor == 0`: run the removal in the background and hand the task back
    # to the caller, yielding once so it gets a chance to start.
    removal_task = @async _rmprocs(worker_ids, typemax(Int), callback_timeout)
    yield()
    return removal_task
end
10721161

1073-
function _rmprocs(pids, waitfor)
1162+
function _rmprocs(pids, waitfor, callback_timeout)
10741163
lock(worker_lock)
10751164
try
1165+
# Run the callbacks
1166+
callback_tasks = Task[]
1167+
for pid in pids
1168+
for callback in values(worker_exiting_callbacks)
1169+
push!(callback_tasks, Threads.@spawn callback(pid))
1170+
end
1171+
end
1172+
1173+
if timedwait(() -> all(istaskdone.(callback_tasks)), callback_timeout) === :timed_out
1174+
@warn "Some callbacks timed out, continuing to remove workers anyway"
1175+
end
1176+
10761177
rmprocset = Union{LocalProcess, Worker}[]
10771178
for p in pids
10781179
if p == 1
@@ -1218,6 +1319,14 @@ function deregister_worker(pg, pid)
12181319
delete!(pg.refs, id)
12191320
end
12201321
end
1322+
1323+
# Call callbacks on the master
1324+
if myid() == 1
1325+
for callback in values(worker_exited_callbacks)
1326+
callback(pid)
1327+
end
1328+
end
1329+
12211330
return
12221331
end
12231332

test/distributed_exec.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

33
using DistributedNext, Random, Serialization, Sockets
4+
import DistributedNext
45
import DistributedNext: launch, manage
56

67

@@ -1927,6 +1928,47 @@ include("splitrange.jl")
19271928
end
19281929
end
19291930

1931+
@testset "Worker state callbacks" begin
1932+
if nprocs() > 1
1933+
rmprocs(workers())
1934+
end
1935+
1936+
# Smoke test to ensure that all the callbacks are executed
1937+
added_workers = Int[]
1938+
exiting_workers = Int[]
1939+
exited_workers = Int[]
1940+
added_key = DistributedNext.add_worker_added_callback(pid -> push!(added_workers, pid))
1941+
exiting_key = DistributedNext.add_worker_exiting_callback(pid -> push!(exiting_workers, pid))
1942+
exited_key = DistributedNext.add_worker_exited_callback(pid -> push!(exited_workers, pid))
1943+
1944+
pid = only(addprocs(1))
1945+
@test added_workers == [pid]
1946+
rmprocs(workers())
1947+
@test exiting_workers == [pid]
1948+
@test exited_workers == [pid]
1949+
1950+
# Remove the callbacks
1951+
DistributedNext.remove_worker_added_callback(added_key)
1952+
DistributedNext.remove_worker_exiting_callback(exiting_key)
1953+
DistributedNext.remove_worker_exited_callback(exited_key)
1954+
1955+
# Test that the `callback_timeout` option works
1956+
event = Base.Event()
1957+
callback_task = nothing
1958+
exiting_key = DistributedNext.add_worker_exiting_callback(_ -> (callback_task = current_task(); wait(event)))
1959+
addprocs(1)
1960+
1961+
@test_logs (:warn, r"Some callbacks timed out.+") rmprocs(workers(); callback_timeout=0.5)
1962+
1963+
notify(event)
1964+
wait(callback_task)
1965+
1966+
# Test that the previous callbacks were indeed removed
1967+
@test length(added_workers) == 1
1968+
@test length(exiting_workers) == 1
1969+
@test length(exited_workers) == 1
1970+
end
1971+
19301972
# Run topology tests last after removing all workers, since a given
19311973
# cluster at any time only supports a single topology.
19321974
if nprocs() > 1

0 commit comments

Comments
 (0)