Skip to content

Commit

Permalink
[TORCH] FED DEBUG
Browse files Browse the repository at this point in the history
  • Loading branch information
GuyPerets106 committed Aug 2, 2024
1 parent 5918950 commit 9c33742
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceNam
case length(EndStreamWaitingList) of
0 -> ok;
_ ->
io:format("Removing streams from waiting list ~p~n",[EndStreamWaitingList]),
Func = fun(StreamName) ->
stream_handler(end_stream, train, StreamName, DistributedBehaviorFunc),
CurrentEndStreamWaitingList = ets:lookup_element(get(generic_worker_ets), end_streams_waiting_list, ?ETS_KEYVAL_VAL_IDX),
Expand All @@ -220,6 +221,7 @@ wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , S
case length(EndStreamWaitingList) of
0 -> ok;
_ ->
io:format("Removing streams from waiting list ~p~n",[EndStreamWaitingList]),
Func = fun(StreamName) ->
stream_handler(end_stream, train, StreamName, DistributedBehaviorFunc),
CurrentEndStreamWaitingList = ets:lookup_element(get(generic_worker_ets), end_streams_waiting_list, ?ETS_KEYVAL_VAL_IDX),
Expand All @@ -230,12 +232,12 @@ wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , S
end,
{next_state, NextState, State};

wait(cast, {end_stream , StreamName}, State = #workerGeneric_state{myName = _MyName, distributedBehaviorFunc = _DistributedBehaviorFunc}) ->
wait(cast, {end_stream , StreamName}, State = #workerGeneric_state{myName = MyName, distributedBehaviorFunc = _DistributedBehaviorFunc}) ->
%logger:notice("Waiting, next state - idle"),
CurrentEndStreamWaitingList = ets:lookup_element(get(generic_worker_ets), end_streams_waiting_list, ?ETS_KEYVAL_VAL_IDX),
NewEndStreamWaitingList = CurrentEndStreamWaitingList ++ [StreamName],
ets:update_element(get(generic_worker_ets), end_streams_waiting_list, {?ETS_KEYVAL_VAL_IDX, NewEndStreamWaitingList}),
% io:format("@wait ~p got end stream from ~p~n",[MyName, StreamName]),
io:format("@wait ~p got end stream from ~p, added to waiting list...~n",[MyName, StreamName]),
{next_state, wait, State};

wait(cast, {post_train_update, Data}, State = #workerGeneric_state{myName = _MyName, distributedBehaviorFunc = DistributedBehaviorFunc, postBatchFunc = PostBatchFunc}) ->
Expand Down

0 comments on commit 9c33742

Please sign in to comment.