@@ -461,12 +461,14 @@ function addprocs(manager::ClusterManager; kwargs...)
461
461
462
462
cluster_mgmt_from_master_check ()
463
463
464
- lock ( worker_lock)
465
- try
466
- addprocs_locked (manager :: ClusterManager ; kwargs ... )
467
- finally
468
- unlock (worker_lock)
464
+ new_workers = @ lock worker_lock addprocs_locked (manager :: ClusterManager ; kwargs ... )
465
+ for worker in new_workers
466
+ for callback in values (worker_added_callbacks )
467
+ callback (worker)
468
+ end
469
469
end
470
+
471
+ return new_workers
470
472
end
471
473
472
474
function addprocs_locked (manager:: ClusterManager ; kwargs... )
@@ -855,13 +857,96 @@ const HDR_COOKIE_LEN=16
855
857
const map_pid_wrkr = Dict {Int, Union{Worker, LocalProcess}} ()
856
858
const map_sock_wrkr = IdDict ()
857
859
const map_del_wrkr = Set {Int} ()
860
# Callback registries. Entries are inserted by the `add_worker_*_callback()`
# functions under the key those functions return, and removed again by the
# matching `remove_worker_*_callback()` functions.
#
# Invoked by `addprocs()` once for each newly added worker.
+ const worker_added_callbacks = Dict {Any, Base.Callable} ()
861
# Spawned by `_rmprocs()` once for each pid, before the workers are removed.
+ const worker_exiting_callbacks = Dict {Any, Base.Callable} ()
862
# Invoked by `deregister_worker()` when it runs on process 1.
+ const worker_exited_callbacks = Dict {Any, Base.Callable} ()
858
863
859
864
# whether process is a master or worker in a distributed setup
860
865
# Read the role currently stored for this process.
function myrole()
    return LPROCROLE[]
end

# Store `proctype` as the role of this process; evaluates to `proctype`.
myrole!(proctype::Symbol) = (LPROCROLE[] = proctype)
864
869
870
+ # Callbacks
871
+
872
+ # We define the callback methods in a loop here and add docstrings for them afterwards
873
# Generate an `add_worker_<type>_callback()` / `remove_worker_<type>_callback()`
# pair for every callback type. Each pair operates on the `worker_<type>_callbacks`
# dictionary of the same type. The docstrings are attached to the generated
# functions separately.
for callback_type in (:added, :exiting, :exited)
    let add_name = Symbol(:add_worker_, callback_type, :_callback),
        remove_name = Symbol(:remove_worker_, callback_type, :_callback),
        dict_name = Symbol(:worker_, callback_type, :_callbacks)

        @eval begin
            # Register `f` under `key` (or under a freshly generated key when
            # `key` is `nothing`) and return the key that was used.
            #
            # Throws an `ArgumentError` if `f` cannot be called with a single
            # `Int`, or if an explicitly given `key` is already registered.
            function $add_name(f::Base.Callable; key=nothing)
                # Callbacks are always invoked with a single worker ID
                if !hasmethod(f, Tuple{Int})
                    throw(ArgumentError("Callback function is invalid, it must be able to accept a single Int argument"))
                end

                if isnothing(key)
                    # No key requested by the caller, so generate a fresh one
                    key = Symbol(gensym(), nameof(f))
                elseif haskey($dict_name, key)
                    # Refuse to silently replace a callback that is already
                    # registered under this key; the caller must remove it first.
                    throw(ArgumentError(string("A callback with key ", repr(key), " is already registered")))
                end

                $dict_name[key] = f
                return key
            end

            # Unregister the callback stored under `key`; does nothing if no
            # callback is stored under that key.
            $remove_name(key) = delete!($dict_name, key)
        end
    end
end
896
+
897
+ """
898
+ add_worker_added_callback(f::Base.Callable; key=nothing)
899
+
900
+ Register a callback to be called on the master process whenever a worker is
901
+ added. The callback will be called with the added worker ID,
902
+ e.g. `f(w::Int)`. Returns the key the callback is registered under (`key` if given, otherwise a generated unique key).
903
+ """
904
+ function add_worker_added_callback end
905
+
906
+ """
907
+ remove_worker_added_callback(key)
908
+
909
+ Remove the callback for `key`.
910
+ """
911
+ function remove_worker_added_callback end
912
+
913
+ """
914
+ add_worker_exiting_callback(f::Base.Callable; key=nothing)
915
+
916
+ Register a callback to be called on the master process immediately before a
917
+ worker is removed with [`rmprocs()`](@ref). The callback will be called with the
918
+ worker ID, e.g. `f(w::Int)`. Returns the key the callback is registered under (`key` if given, otherwise a generated unique key).
919
+
920
+ All callbacks will be executed asynchronously and if they don't all finish
921
+ before the `callback_timeout` passed to `rmprocs()` then the process will be
922
+ removed anyway.
923
+ """
924
+ function add_worker_exiting_callback end
925
+
926
+ """
927
+ remove_worker_exiting_callback(key)
928
+
929
+ Remove the callback for `key`.
930
+ """
931
+ function remove_worker_exiting_callback end
932
+
933
+ """
934
+ add_worker_exited_callback(f::Base.Callable; key=nothing)
935
+
936
+ Register a callback to be called on the master process when a worker has exited
937
+ for any reason (i.e. not only because of [`rmprocs()`](@ref) but also the worker
938
+ segfaulting etc). The callback will be called with the worker ID,
939
+ e.g. `f(w::Int)`. Returns the key the callback is registered under (`key` if given, otherwise a generated unique key).
940
+ """
941
+ function add_worker_exited_callback end
942
+
943
+ """
944
+ remove_worker_exited_callback(key)
945
+
946
+ Remove the callback for `key`.
947
+ """
948
+ function remove_worker_exited_callback end
949
+
865
950
# cluster management related API
866
951
"""
867
952
myid()
@@ -1048,7 +1133,7 @@ function cluster_mgmt_from_master_check()
1048
1133
end
1049
1134
1050
1135
"""
1051
- rmprocs(pids...; waitfor=typemax(Int))
1136
+ rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10 )
1052
1137
1053
1138
Remove the specified workers. Note that only process 1 can add or remove
1054
1139
workers.
@@ -1062,6 +1147,10 @@ Argument `waitfor` specifies how long to wait for the workers to shut down:
1062
1147
returned. The user should call [`wait`](@ref) on the task before invoking any other
1063
1148
parallel calls.
1064
1149
1150
+ The `callback_timeout` argument specifies how long to wait, in seconds, for the
1151
+ callbacks registered with [`add_worker_exiting_callback()`](@ref) to finish
1152
+ before continuing to remove the workers.
1153
+
1065
1154
# Examples
1066
1155
```julia-repl
1067
1156
\$ julia -p 5
@@ -1078,24 +1167,36 @@ julia> workers()
1078
1167
6
1079
1168
```
1080
1169
"""
1081
function rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
    # Adding and removing workers is only permitted from the master process
    cluster_mgmt_from_master_check()

    # Flatten the arguments so that both scalars and collections of pids work
    targets = vcat(pids...)

    if waitfor != 0
        # Blocking removal: do the work right here and hand back an already
        # trivial task so that callers can uniformly `wait` on the result.
        _rmprocs(targets, waitfor, callback_timeout)
        # return a dummy task object that user code can wait on.
        return @async nothing
    end

    # `waitfor == 0`: run the removal in a background task and return that task
    # immediately; the task itself waits without a time limit.
    removal_task = @async _rmprocs(targets, typemax(Int), callback_timeout)
    yield()
    return removal_task
end
1095
1184
1096
- function _rmprocs (pids, waitfor)
1185
+ function _rmprocs (pids, waitfor, callback_timeout )
1097
1186
lock (worker_lock)
1098
1187
try
1188
+ # Run the callbacks
1189
+ callback_tasks = Task[]
1190
+ for pid in pids
1191
+ for callback in values (worker_exiting_callbacks)
1192
+ push! (callback_tasks, Threads. @spawn callback (pid))
1193
+ end
1194
+ end
1195
+
1196
+ if timedwait (() -> all (istaskdone .(callback_tasks)), callback_timeout) === :timed_out
1197
+ @warn " Some callbacks timed out, continuing to remove workers anyway"
1198
+ end
1199
+
1099
1200
rmprocset = Union{LocalProcess, Worker}[]
1100
1201
for p in pids
1101
1202
if p == 1
@@ -1241,6 +1342,14 @@ function deregister_worker(pg, pid)
1241
1342
delete! (pg. refs, id)
1242
1343
end
1243
1344
end
1345
+
1346
+ # Call callbacks on the master
1347
+ if myid () == 1
1348
+ for callback in values (worker_exited_callbacks)
1349
+ callback (pid)
1350
+ end
1351
+ end
1352
+
1244
1353
return
1245
1354
end
1246
1355
0 commit comments