@@ -461,12 +461,14 @@ function addprocs(manager::ClusterManager; kwargs...)
461
461
462
462
cluster_mgmt_from_master_check ()
463
463
464
- lock ( worker_lock)
465
- try
466
- addprocs_locked (manager :: ClusterManager ; kwargs ... )
467
- finally
468
- unlock (worker_lock)
464
+ new_workers = @ lock worker_lock addprocs_locked (manager :: ClusterManager ; kwargs ... )
465
+ for worker in new_workers
466
+ for callback in values (worker_added_callbacks )
467
+ callback (worker)
468
+ end
469
469
end
470
+
471
+ return new_workers
470
472
end
471
473
472
474
function addprocs_locked (manager:: ClusterManager ; kwargs... )
@@ -855,13 +857,96 @@ const HDR_COOKIE_LEN=16
855
857
const map_pid_wrkr = Dict {Int, Union{Worker, LocalProcess}} ()
856
858
const map_sock_wrkr = IdDict ()
857
859
const map_del_wrkr = Set {Int} ()
860
# Registries of user-supplied worker lifecycle callbacks, stored by key. Entries
# are inserted and deleted by the `add_worker_*_callback` and
# `remove_worker_*_callback` functions that are generated with `@eval` further
# down in this file.
#
# `worker_added_callbacks`: each value is called with every new worker by
# `addprocs()`, after `worker_lock` has been released.
+ const worker_added_callbacks = Dict {Any, Base.Callable} ()
861
# `worker_exiting_callbacks`: each value is spawned as a task per pid at the
# start of `_rmprocs()`, before the workers are removed.
+ const worker_exiting_callbacks = Dict {Any, Base.Callable} ()
862
# `worker_exited_callbacks`: each value is called with the pid at the end of
# `deregister_worker()` when running on process 1.
+ const worker_exited_callbacks = Dict {Any, Base.Callable} ()
858
863
859
864
# whether process is a master or worker in a distributed setup
860
865
# Return the role currently stored in the `LPROCROLE` reference.
function myrole()
    return LPROCROLE[]
end
861
866
# Store `proctype` in the `LPROCROLE` reference; the assigned value is returned.
myrole!(proctype::Symbol) = (LPROCROLE[] = proctype)
864
869
870
+ # Callbacks
871
+
872
+ # We define the callback methods in a loop here and add docstrings for them afterwards
873
for callback_type in (:added, :exiting, :exited)
    # Derive the generated function names and the backing registry name, e.g.
    # `add_worker_added_callback`, `remove_worker_added_callback` and
    # `worker_added_callbacks` for `callback_type == :added`.
    let add_name = Symbol(:add_worker_, callback_type, :_callback),
        remove_name = Symbol(:remove_worker_, callback_type, :_callback),
        dict_name = Symbol(:worker_, callback_type, :_callbacks)

        @eval begin
            # Register `f` in the matching registry and return the key it was
            # stored under.
            #
            # Throws `ArgumentError` if `f` has no method accepting a single
            # `Int`, or if an explicit `key` is already registered.
            function $add_name(f::Base.Callable; key=nothing)
                if !hasmethod(f, Tuple{Int})
                    throw(ArgumentError("Callback function is invalid, it must be able to accept a single Int argument"))
                end

                if isnothing(key)
                    # No key supplied: build one from a fresh gensym plus the
                    # name of `f`.
                    key = Symbol(gensym(), nameof(f))
                elseif haskey($dict_name, key)
                    # Refuse to silently replace a callback that was already
                    # registered under this key; the caller must remove it first.
                    throw(ArgumentError(string("A callback function with key '", key, "' already exists")))
                end

                $dict_name[key] = f
                return key
            end

            # Delete the entry stored under `key` from the matching registry.
            $remove_name(key) = delete!($dict_name, key)
        end
    end
end
896
+
897
+ """
898
+ add_worker_added_callback(f::Base.Callable; key=nothing)
899
+
900
+ Register a callback to be called on the master process whenever a worker is
901
+ added. The callback will be called with the added worker ID,
902
+ e.g. `f(w::Int)`. Returns a unique key for the callback.
+
+ `f` must have a method accepting a single `Int`, otherwise an `ArgumentError`
+ is thrown. If `key` is `nothing` a key is generated with `gensym()` and the
+ name of `f`; otherwise the supplied `key` is used. Pass the returned key to
+ [`remove_worker_added_callback()`](@ref) to unregister the callback.
+
+ The callbacks are called one after another by `addprocs()` for each new worker,
+ after the worker lock has been released and before `addprocs()` returns.
903
+ """
904
+ function add_worker_added_callback end
905
+
906
+ """
907
+ remove_worker_added_callback(key)
908
+
909
+ Remove the callback for `key`, where `key` is the value returned by
+ [`add_worker_added_callback()`](@ref).
910
+ """
911
+ function remove_worker_added_callback end
912
+
913
+ """
914
+ add_worker_exiting_callback(f::Base.Callable; key=nothing)
915
+
916
+ Register a callback to be called on the master process immediately before a
917
+ worker is removed with [`rmprocs()`](@ref). The callback will be called with the
918
+ worker ID, e.g. `f(w::Int)`. Returns a unique key for the callback.
919
+
920
+ All callbacks will be executed asynchronously and if they don't all finish
921
+ before the `callback_timeout` passed to `rmprocs()` then the process will be
922
+ removed anyway.
+
+ `f` must have a method accepting a single `Int`, otherwise an `ArgumentError`
+ is thrown. If `key` is `nothing` a key is generated with `gensym()` and the
+ name of `f`; otherwise the supplied `key` is used. Pass the returned key to
+ [`remove_worker_exiting_callback()`](@ref) to unregister the callback.
923
+ """
924
+ function add_worker_exiting_callback end
925
+
926
+ """
927
+ remove_worker_exiting_callback(key)
928
+
929
+ Remove the callback for `key`, where `key` is the value returned by
+ [`add_worker_exiting_callback()`](@ref).
930
+ """
931
+ function remove_worker_exiting_callback end
932
+
933
+ """
934
+ add_worker_exited_callback(f::Base.Callable; key=nothing)
935
+
936
+ Register a callback to be called on the master process when a worker has exited
937
+ for any reason (i.e. not only because of [`rmprocs()`](@ref) but also the worker
938
+ segfaulting etc). The callback will be called with the worker ID,
939
+ e.g. `f(w::Int)`. Returns a unique key for the callback.
+
+ `f` must have a method accepting a single `Int`, otherwise an `ArgumentError`
+ is thrown. If `key` is `nothing` a key is generated with `gensym()` and the
+ name of `f`; otherwise the supplied `key` is used. Pass the returned key to
+ [`remove_worker_exited_callback()`](@ref) to unregister the callback.
940
+ """
941
+ function add_worker_exited_callback end
942
+
943
+ """
944
+ remove_worker_exited_callback(key)
945
+
946
+ Remove the callback for `key`, where `key` is the value returned by
+ [`add_worker_exited_callback()`](@ref).
947
+ """
948
+ function remove_worker_exited_callback end
949
+
865
950
# cluster management related API
866
951
"""
867
952
myid()
@@ -1025,7 +1110,7 @@ function cluster_mgmt_from_master_check()
1025
1110
end
1026
1111
1027
1112
"""
1028
- rmprocs(pids...; waitfor=typemax(Int))
1113
+ rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10 )
1029
1114
1030
1115
Remove the specified workers. Note that only process 1 can add or remove
1031
1116
workers.
@@ -1039,6 +1124,10 @@ Argument `waitfor` specifies how long to wait for the workers to shut down:
1039
1124
returned. The user should call [`wait`](@ref) on the task before invoking any other
1040
1125
parallel calls.
1041
1126
1127
+ The `callback_timeout` argument specifies how long to wait (in seconds) for any
1128
+ worker-exiting callbacks to finish before continuing to remove the workers (see
1129
+ [`add_worker_exiting_callback()`](@ref)).
1130
+
1042
1131
# Examples
1043
1132
```julia-repl
1044
1133
\$ julia -p 5
@@ -1055,24 +1144,36 @@ julia> workers()
1055
1144
6
1056
1145
```
1057
1146
"""
1058
function rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
    cluster_mgmt_from_master_check()

    # Flatten the arguments so that both `rmprocs(2, 3)` and `rmprocs([2, 3])`
    # yield a single collection of pids.
    pids = vcat(pids...)

    # Blocking mode: remove the workers on the current task, waiting up to
    # `waitfor` for them to shut down.
    if waitfor != 0
        _rmprocs(pids, waitfor, callback_timeout)
        # return a dummy task object that user code can wait on.
        return @async nothing
    end

    # `waitfor == 0`: run the removal in a separate task and return that task
    # right away.
    removal_task = @async _rmprocs(pids, typemax(Int), callback_timeout)
    # Yield to the scheduler so other tasks (such as the one just created) can
    # run before returning.
    yield()
    return removal_task
end
1072
1161
1073
- function _rmprocs (pids, waitfor)
1162
+ function _rmprocs (pids, waitfor, callback_timeout )
1074
1163
lock (worker_lock)
1075
1164
try
1165
+ # Run the callbacks
1166
+ callback_tasks = Task[]
1167
+ for pid in pids
1168
+ for callback in values (worker_exiting_callbacks)
1169
+ push! (callback_tasks, Threads. @spawn callback (pid))
1170
+ end
1171
+ end
1172
+
1173
+ if timedwait (() -> all (istaskdone .(callback_tasks)), callback_timeout) === :timed_out
1174
+ @warn " Some callbacks timed out, continuing to remove workers anyway"
1175
+ end
1176
+
1076
1177
rmprocset = Union{LocalProcess, Worker}[]
1077
1178
for p in pids
1078
1179
if p == 1
@@ -1218,6 +1319,14 @@ function deregister_worker(pg, pid)
1218
1319
delete! (pg. refs, id)
1219
1320
end
1220
1321
end
1322
+
1323
+ # Call callbacks on the master
1324
+ if myid () == 1
1325
+ for callback in values (worker_exited_callbacks)
1326
+ callback (pid)
1327
+ end
1328
+ end
1329
+
1221
1330
return
1222
1331
end
1223
1332
0 commit comments