From 7b58660fd5defe87aa198dd5e8c7e3e7d57e4059 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 7 May 2024 10:57:00 +0000 Subject: [PATCH 01/52] [W2WComm] Updates --- .../src/Bridge/onnWorkers/w2wCom.erl | 14 +++++++----- .../onnWorkers/workerFederatedClient.erl | 22 ++++++++++--------- .../src/Bridge/onnWorkers/workerGeneric.erl | 4 ++-- .../src/Client/clientStateHandler.erl | 9 ++------ .../NerlnetApp/src/Client/clientStatem.erl | 15 +++++++------ 5 files changed, 32 insertions(+), 32 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl index cdfd020e..d65152b9 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl @@ -14,21 +14,23 @@ start_link(Args = {WorkerName, _ClientStatemPid}) -> {ok,Gen_Server_Pid} = gen_server:start_link({local, WorkerName}, ?MODULE, Args, []), Gen_Server_Pid. -init({WorkerName, ClientStatemPid}) -> +init({WorkerName, MyClientPid}) -> InboxQueue = queue:new(), W2wEts = ets:new(w2w_ets, [set]), put(worker_name, WorkerName), - put(client_statem_pid, ClientStatemPid), + put(client_statem_pid, MyClientPid), + % TODO Send init message to client with the {WorkerName , W2WCOMM_PID} put(w2w_ets, W2wEts), ets:insert(W2wEts, {inbox_queue, InboxQueue}), {ok, []}. -% Messages are of the form: {FromWorkerName, Data} +% Received messages are of the form: {worker_to_worker_msg, FromWorkerName, ThisWorkerName, Data} handle_cast({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, Data}, State) -> case get(worker_name) of ThisWorkerName -> ok; _ -> throw({error, "The provided worker name is not this worker"}) end, + % Saved messages are of the form: {FromWorkerName, , Data} Message = {FromWorkerName, Data}, add_msg_to_inbox_queue(Message), io:format("Worker ~p received message from ~p: ~p~n", [ThisWorkerName, FromWorkerName, Data]), %TODO remove @@ -62,9 +64,9 @@ add_msg_to_inbox_queue(Message) -> InboxQueueUpdated = queue:in(Message, InboxQueue), ets:insert(W2WEts, {inbox_queue, InboxQueueUpdated}). -send_message(FromWorkerName, ToWorkerName, Data) -> - Msg = {?W2WCOM_ATOM, FromWorkerName, ToWorkerName, Data}, - MyClient = client_name, % TODO +send_message(FromWorker, TargetWorker, Data) -> + Msg = {?W2WCOM_ATOM, FromWorker, TargetWorker, Data}, + MyClient = get(client_statem_pid), gen_server:cast(MyClient, Msg). diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 90a80bc8..b2dc5e6f 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -63,16 +63,17 @@ init({GenWorkerEts, WorkerData}) -> ets:insert(FedratedClientEts, {sync_max_count, SyncMaxCount}), ets:insert(FedratedClientEts, {sync_count, SyncMaxCount}), ets:insert(FedratedClientEts, {server_update, false}), - io:format("finished init in ~p~n",[MyName]). + io:format("finished init in ~p~n",[MyName]). %% TODO REMOVE pre_idle({GenWorkerEts, _WorkerData}) -> ThisEts = get_this_client_ets(GenWorkerEts), %% send to server that this worker is part of the federated workers - ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + _ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), % No longer needed? MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), - gen_statem:cast(ClientPID,{custom_worker_message,{MyName, ServerName}}), - io:format("sent ~p init message: ~p~n",[ServerName, {MyName, ServerName}]). + % gen_statem:cast(ClientPID,{custom_worker_message,{MyName, ServerName}}), + w2wCom:send_message(MyName, ServerName, {MyName, ServerName}), %% ****** NEW - TEST NEEDED ****** + io:format("@pre_idle: Worker ~p updates federated server ~p~n",[MyName , ServerName]). post_idle({_GenWorkerEts, _WorkerData}) -> ok. @@ -93,18 +94,19 @@ post_train({GenWorkerEts, _WorkerData}) -> if SyncCount == 0 -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), Weights = nerlNIF:call_to_get_weights(GenWorkerEts, ModelID), - ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), + _ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), % No longer needed? + ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), % io:format("Worker ~p entering update and got weights ~p~n",[MyName, Weights]), ets:update_counter(ThisEts, sync_count, MaxSyncCount), - % io:format("Worker ~p entering update~n",[MyName]), - gen_statem:cast(ClientPID, {update, {MyName, ServerName, Weights}}), - _ToUpdate = true; + io:format("@post_train: Worker ~p updates federated server ~p~n",[MyName , ServerName]), + w2wCom:send_message(MyName, ServerName , Weights), %% ****** NEW - TEST NEEDED ****** + % gen_statem:cast(ClientPID, {update, {MyName, ServerName, Weights}}), + _ToUpdate = true; % ? true -> ets:update_counter(ThisEts, sync_count, -1), - _ToUpdate = false + _ToUpdate = false % ? end. %% nothing? diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 89d3f19c..e56cf7a9 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -70,7 +70,7 @@ init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData ets:insert(GenWorkerEts,{distributed_system_args, DistributedSystemArgs}), ets:insert(GenWorkerEts,{distributed_system_type, DistributedSystemType}), ets:insert(GenWorkerEts,{controller_message_q, []}), %% empty Queue TODO Deprecated - % Worker to Worker communication gen_server + % Worker to Worker communication module - this is a gen_server W2wComPid = w2wCom:start_link({WorkerName, ClientPid}), put(w2wcom_pid, W2wComPid), @@ -174,7 +174,7 @@ wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneri wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> BatchTimeStamp = erlang:system_time(nanosecond), gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , BatchID , BatchTimeStamp}), - ToUpdate = DistributedBehaviorFunc(post_train, {get(generic_worker_ets),DistributedWorkerData}), + ToUpdate = DistributedBehaviorFunc(post_train, {get(generic_worker_ets),DistributedWorkerData}), %% Change to W2WComm if ToUpdate -> {next_state, update, State#workerGeneric_state{nextState=NextState}}; true -> {next_state, NextState, State} end; diff --git a/src_erl/NerlnetApp/src/Client/clientStateHandler.erl b/src_erl/NerlnetApp/src/Client/clientStateHandler.erl index e1765a4d..e9c6d09d 100644 --- a/src_erl/NerlnetApp/src/Client/clientStateHandler.erl +++ b/src_erl/NerlnetApp/src/Client/clientStateHandler.erl @@ -19,13 +19,8 @@ init(Req0, [Action,Client_StateM_Pid]) -> {ok,Body,_} = cowboy_req:read_body(Req0), %% io:format("client state_handler got body:~p~n",[Body]), case Action of - custom_worker_message -> - case binary_to_term(Body) of - {To, custom_worker_message, Data} -> %% handshake - gen_statem:cast(Client_StateM_Pid,{custom_worker_message,Data}); - {From, update, Data} -> %% updating weights - gen_statem:cast(Client_StateM_Pid,{update,Data}) - end; + worker_to_worker_msg -> {worker_to_worker_msg , From , To , Data} = binary_to_term(Body), + gen_statem:cast(Client_StateM_Pid,{worker_to_worker_msg , From , To , Data}); batch -> gen_statem:cast(Client_StateM_Pid,{sample,Body}); idle -> gen_statem:cast(Client_StateM_Pid,{idle}); training -> gen_statem:cast(Client_StateM_Pid,{training}); diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 3b3b0334..e881ad22 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -143,21 +143,22 @@ waitforWorkers(cast, EventContent, State = #client_statem_state{myName = MyName} %% initiating workers when they include federated workers. init stage == handshake between federated worker client and server %% TODO: make custom_worker_message in all states to send messages from workers to entities (not just client) -idle(cast, In = {custom_worker_message, {From, To}}, State = #client_statem_state{etsRef = EtsRef}) -> +idle(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = #client_statem_state{etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - WorkerOfThisClient = ets:member(EtsRef, To), + WorkerOfThisClient = ets:member(EtsRef, ToWorker), if WorkerOfThisClient -> - TargetWorkerPID = ets:lookup_element(EtsRef, To, ?WORKER_PID_IDX), - gen_statem:cast(TargetWorkerPID,{post_idle,From}), + % Extract W2WPID from Ets + TargetWorkerW2WPID = ets:lookup_element(EtsRef, ToWorker, ?WORKER_PID_IDX), + gen_statem:cast(TargetWorkerW2WPID,{worker_to_worker_msg, FromWorker, ToWorker, Data}), stats:increment_messages_sent(ClientStatsEts); true -> %% send to FedServer that worker From is connecting to it - DestClient = maps:get(To, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), - MessageBody = {DestClient, custom_worker_message, {From, To}}, + DestClient = maps:get(ToWorker, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), + MessageBody = {worker_to_worker_msg, FromWorker, ToWorker, Data}, {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), - nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(custom_worker_message), term_to_binary(MessageBody)), + nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(worker_to_worker_msg), term_to_binary(MessageBody)), stats:increment_messages_sent(ClientStatsEts), stats:increment_bytes_sent(ClientStatsEts , nerl_tools:calculate_size(MessageBody)) end, From 1829d85d1348b15843d69ff556ff57fd2a55ed23 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Sat, 11 May 2024 16:23:10 +0000 Subject: [PATCH 02/52] [W2W_Com] Added handle_w2w_msg function in clientStatem at all states --- .../src/Bridge/onnWorkers/w2wCom.erl | 2 +- .../src/Bridge/onnWorkers/workerGeneric.erl | 5 +- .../NerlnetApp/src/Client/clientStatem.erl | 75 ++++++++++++------- .../src/Client/clientWorkersFunctions.erl | 4 +- 4 files changed, 53 insertions(+), 33 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl index d65152b9..792264b3 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl @@ -62,7 +62,7 @@ add_msg_to_inbox_queue(Message) -> W2WEts = get(w2w_ets), {_, InboxQueue} = ets:lookup(W2WEts, inbox_queue), InboxQueueUpdated = queue:in(Message, InboxQueue), - ets:insert(W2WEts, {inbox_queue, InboxQueueUpdated}). + ets:update_element(W2WEts, inbox_queue, {inbox_queue, InboxQueueUpdated}). send_message(FromWorker, TargetWorker, Data) -> Msg = {?W2WCOM_ATOM, FromWorker, TargetWorker, Data}, diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index fd52fc34..4a7f2259 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -36,7 +36,8 @@ start_link(ARGS) -> %{ok,Pid} = gen_statem:start_link({local, element(1, ARGS)}, ?MODULE, ARGS, []), %% name this machine by unique name {ok,Pid} = gen_statem:start_link(?MODULE, ARGS, []), - Pid. + W2W_Pid = get(w2wcom_pid), + {Pid , W2W_Pid}. %%%=================================================================== %%% gen_statem callbacks @@ -49,7 +50,7 @@ start_link(ARGS) -> init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , ClientPid , WorkerStatsEts}) -> nerl_tools:setup_logger(?MODULE), {ModelID , ModelType , ModelArgs , LayersSizes, LayersTypes, LayersFunctionalityCodes, LearningRate , Epochs, - OptimizerType, OptimizerArgs , LossMethod , DistributedSystemType , DistributedSystemArgs} = WorkerArgs, + OptimizerType, OptimizerArgs , LossMethod , DistributedSystemType , DistributedSystemArgs} = WorkerArgs, GenWorkerEts = ets:new(generic_worker,[set]), put(generic_worker_ets, GenWorkerEts), put(client_pid, ClientPid), diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index e881ad22..14cd8b31 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -26,6 +26,7 @@ -define(ETS_KV_VAL_IDX, 2). % key value pairs --> value index is 2 -define(WORKER_PID_IDX, 1). +-define(W2W_PID_IDX, 2). -define(SERVER, ?MODULE). %% client ETS table: {WorkerName, WorkerPid, WorkerArgs, TimingTuple} @@ -147,21 +148,7 @@ idle(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = #cli ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - WorkerOfThisClient = ets:member(EtsRef, ToWorker), - if WorkerOfThisClient -> - % Extract W2WPID from Ets - TargetWorkerW2WPID = ets:lookup_element(EtsRef, ToWorker, ?WORKER_PID_IDX), - gen_statem:cast(TargetWorkerW2WPID,{worker_to_worker_msg, FromWorker, ToWorker, Data}), - stats:increment_messages_sent(ClientStatsEts); - true -> - %% send to FedServer that worker From is connecting to it - DestClient = maps:get(ToWorker, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), - MessageBody = {worker_to_worker_msg, FromWorker, ToWorker, Data}, - {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), - nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(worker_to_worker_msg), term_to_binary(MessageBody)), - stats:increment_messages_sent(ClientStatsEts), - stats:increment_bytes_sent(ClientStatsEts , nerl_tools:calculate_size(MessageBody)) - end, + handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data), {keep_state, State}; idle(cast, _In = {statistics}, State = #client_statem_state{ myName = MyName, etsRef = EtsRef}) -> @@ -226,20 +213,27 @@ training(cast, MessageIn = {update, {From, To, Data}}, State = #client_statem_st %% This is a generic way to move data from worker to worker %% TODO fix variables names to make it more generic %% federated server sends AvgWeights to workers -training(cast, InMessage = {custom_worker_message, WorkersList, WeightsTensor}, State = #client_statem_state{etsRef = EtsRef}) -> +% training(cast, InMessage = {custom_worker_message, WorkersList, WeightsTensor}, State = #client_statem_state{etsRef = EtsRef}) -> +% ClientStatsEts = get(client_stats_ets), +% stats:increment_messages_received(ClientStatsEts), +% stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(InMessage)), +% Func = fun(WorkerName) -> +% DestClient = maps:get(WorkerName, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), +% MessageBody = term_to_binary({DestClient, update, {_FedServer = "server", WorkerName, WeightsTensor}}), % TODO - fix client should not be aware of the data of custom worker message + +% {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), +% nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(custom_worker_message), MessageBody), +% stats:increment_messages_sent(ClientStatsEts), +% stats:increment_bytes_sent(ClientStatsEts , nerl_tools:calculate_size(MessageBody)) +% end, +% lists:foreach(Func, WorkersList), % can be optimized with broadcast instead of unicast +% {keep_state, State}; + +training(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = #client_statem_state{etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), - stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(InMessage)), - Func = fun(WorkerName) -> - DestClient = maps:get(WorkerName, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), - MessageBody = term_to_binary({DestClient, update, {_FedServer = "server", WorkerName, WeightsTensor}}), % TODO - fix client should not be aware of the data of custom worker message - - {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), - nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(custom_worker_message), MessageBody), - stats:increment_messages_sent(ClientStatsEts), - stats:increment_bytes_sent(ClientStatsEts , nerl_tools:calculate_size(MessageBody)) - end, - lists:foreach(Func, WorkersList), % can be optimized with broadcast instead of unicast + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data), {keep_state, State}; % TODO Validate this state - sample and empty list @@ -336,6 +330,13 @@ predict(cast,_In = {training}, State = #client_statem_state{myName = MyName}) -> ?LOG_ERROR("client ~p got training request in predict state",[MyName]), {next_state, predict, State#client_statem_state{nextState = predict}}; +predict(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = #client_statem_state{etsRef = EtsRef}) -> + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data), + {keep_state, State}; + %% The source sends message to main server that it has finished %% The main server updates its' clients to move to state 'idle' predict(cast, In = {idle}, State = #client_statem_state{etsRef = EtsRef , myName = _MyName}) -> @@ -404,4 +405,22 @@ create_encoded_stats_str(ListStatsEts) -> %% |w1&bytes_sent:6.0:float#bad_messages:0:int....| ?API_SERVER_ENTITY_SEPERATOR ++ atom_to_list(WorkerName) ++ ?WORKER_SEPERATOR ++ WorkerEncStatsStr end, - lists:flatten(lists:map(Func , ListStatsEts)). \ No newline at end of file + lists:flatten(lists:map(Func , ListStatsEts)). + +handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data) -> + ClientStatsEts = get(client_stats_ets), + WorkerOfThisClient = ets:member(EtsRef, ToWorker), + if WorkerOfThisClient -> + % Extract W2WPID from Ets + TargetWorkerW2WPID = ets:lookup_element(get(workers_ets), ToWorker, ?W2W_PID_IDX), + gen_statem:cast(TargetWorkerW2WPID,{worker_to_worker_msg, FromWorker, ToWorker, Data}), + stats:increment_messages_sent(ClientStatsEts); + true -> + %% Send to the correct client + DestClient = maps:get(ToWorker, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), + MessageBody = {worker_to_worker_msg, FromWorker, ToWorker, Data}, + {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), + nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(worker_to_worker_msg), term_to_binary(MessageBody)), + stats:increment_messages_sent(ClientStatsEts), + stats:increment_bytes_sent(ClientStatsEts , nerl_tools:calculate_size(MessageBody)) + end. \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl b/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl index ffcbf9cf..a027a1c6 100644 --- a/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl +++ b/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl @@ -47,8 +47,8 @@ create_workers(ClientName, ClientEtsRef , ShaToModelArgsMap , EtsStats) -> WorkerArgs = {ModelID , ModelType , ModelArgs , LayersSizes, LayersTypes, LayersFunctions, LearningRate , Epochs, Optimizer, OptimizerArgs , LossMethod , DistributedSystemType , DistributedSystemArgs}, - WorkerPid = workerGeneric:start_link({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , MyClientPid , WorkerStatsETS}), - ets:insert(WorkersETS, {WorkerName, {WorkerPid, WorkerArgs}}), + {WorkerPid , W2W_Pid} = workerGeneric:start_link({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , MyClientPid , WorkerStatsETS}), + ets:insert(WorkersETS, {WorkerName, {WorkerPid, W2W_Pid, WorkerArgs}}), ets:insert(EtsStats, {WorkerName, WorkerStatsETS}), WorkerName From b034461618514e8470a3a0d26666d98ac1552091 Mon Sep 17 00:00:00 2001 From: GuyPErets106 Date: Thu, 16 May 2024 19:41:05 +0000 Subject: [PATCH 03/52] [W2W] Fed Exp Jsons --- .../conn_fed_synt_1d_2c_2r_1s_4w_1ws.json | 9 ++ .../dc_fed_synt_1d_2c_2r_1s_4w_1ws.json | 145 ++++++++++++++++++ .../exp_fed_synt_1d_2c_2r_1s_4w_1ws.json | 39 +++++ 3 files changed, 193 insertions(+) create mode 100644 inputJsonsFiles/ConnectionMap/conn_fed_synt_1d_2c_2r_1s_4w_1ws.json create mode 100644 inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json create mode 100644 inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json diff --git a/inputJsonsFiles/ConnectionMap/conn_fed_synt_1d_2c_2r_1s_4w_1ws.json b/inputJsonsFiles/ConnectionMap/conn_fed_synt_1d_2c_2r_1s_4w_1ws.json new file mode 100644 index 00000000..d2c7418c --- /dev/null +++ b/inputJsonsFiles/ConnectionMap/conn_fed_synt_1d_2c_2r_1s_4w_1ws.json @@ -0,0 +1,9 @@ +{ + "connectionsMap": + { + "r1":["mainServer", "r2"], + "r2":["r3", "s1"], + "r3":["r4", "c1"], + "r4":["r1", "c2"] + } +} diff --git a/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json b/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json new file mode 100644 index 00000000..1a789aa0 --- /dev/null +++ b/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json @@ -0,0 +1,145 @@ +{ + "nerlnetSettings": { + "frequency": "100", + "batchSize": "50" + }, + "mainServer": { + "port": "8900", + "args": "" + }, + "apiServer": { + "port": "8901", + "args": "" + }, + "devices": [ + { + "name": "c0vm0", + "ipv4": "10.0.0.5", + "entities": "mainServer,c1,c2,r1,r2,s1,apiServer" + } + ], + "routers": [ + { + "name": "r1", + "port": "8905", + "policy": "0" + }, + { + "name": "r2", + "port": "8906", + "policy": "0" + } + ], + "sources": [ + { + "name": "s1", + "port": "8904", + "frequency": "100", + "policy": "0", + "epochs": "1", + "type": "0" + } + ], + "clients": [ + { + "name": "c1", + "port": "8902", + "workers": "w1,w2,ws" + }, + { + "name": "c2", + "port": "8903", + "workers": "w3,w4" + } + ], + "workers": [ + { + "name": "w1", + "model_sha": "7c0c5327ad2632a8a1107ed60f03b5bb49fc098332e7b91a12f214d045c6dd74" + }, + { + "name": "w2", + "model_sha": "7c0c5327ad2632a8a1107ed60f03b5bb49fc098332e7b91a12f214d045c6dd74" + }, + { + "name": "ws", + "model_sha": "24cfe345509ff1d121e437fc0baf3fb8feba88dda87db11b7c9c7aaff065c40b" + }, + { + "name": "w3", + "model_sha": "7c0c5327ad2632a8a1107ed60f03b5bb49fc098332e7b91a12f214d045c6dd74" + }, + { + "name": "w4", + "model_sha": "7c0c5327ad2632a8a1107ed60f03b5bb49fc098332e7b91a12f214d045c6dd74" + } + ], + "model_sha": { + "7c0c5327ad2632a8a1107ed60f03b5bb49fc098332e7b91a12f214d045c6dd74": { + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,10,5,3,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,5", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Bounding:9 |", + "layers_functions": "1,6,6,11,4", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.01", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "1", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "none", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" + }, + "24cfe345509ff1d121e437fc0baf3fb8feba88dda87db11b7c9c7aaff065c40b": { + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,10,5,3,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,5", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Bounding:9 |", + "layers_functions": "1,6,6,11,4", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.01", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "2", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "none", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" + } + } +} \ No newline at end of file diff --git a/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json b/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json new file mode 100644 index 00000000..f6d9433d --- /dev/null +++ b/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json @@ -0,0 +1,39 @@ +{ + "experimentName": "synthetic_3_gausians", + "batchSize": 100, + "csvFilePath": "/tmp/nerlnet/data/NerlnetData-master/nerlnet/synthetic_norm/synthetic_full.csv", + "numOfFeatures": "5", + "numOfLabels": "3", + "headersNames": "Norm(0:1),Norm(4:1),Norm(10:3)", + "Phases": + [ + { + "phaseName": "training_phase", + "phaseType": "training", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "0", + "numOfBatches": "300", + "workers": "w1,w2,w3,w4" + } + ] + }, + { + "phaseName": "prediction_phase", + "phaseType": "prediction", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "30000", + "numOfBatches": "200", + "workers": "w1,w2,w3,w4" + } + ] + } + ] + } + + \ No newline at end of file From 7e9b71690d46443638079ea82e67d85900976acc Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Thu, 16 May 2024 21:17:59 +0000 Subject: [PATCH 04/52] [W2W] WIP --- .../conn_fed_synt_1d_2c_2r_1s_4w_1ws.json | 6 +- .../dc_fed_synt_1d_2c_2r_1s_4w_1ws.json | 4 +- .../exp_fed_synt_1d_2c_2r_1s_4w_1ws.json | 2 +- .../src/Bridge/onnWorkers/w2wCom.erl | 34 +++++++++- .../onnWorkers/workerFederatedClient.erl | 62 ++++++++++++++++--- .../onnWorkers/workerFederatedServer.erl | 50 ++++++++++----- .../src/Client/clientWorkersFunctions.erl | 2 +- 7 files changed, 126 insertions(+), 34 deletions(-) diff --git a/inputJsonsFiles/ConnectionMap/conn_fed_synt_1d_2c_2r_1s_4w_1ws.json b/inputJsonsFiles/ConnectionMap/conn_fed_synt_1d_2c_2r_1s_4w_1ws.json index d2c7418c..db39f106 100644 --- a/inputJsonsFiles/ConnectionMap/conn_fed_synt_1d_2c_2r_1s_4w_1ws.json +++ b/inputJsonsFiles/ConnectionMap/conn_fed_synt_1d_2c_2r_1s_4w_1ws.json @@ -1,9 +1,7 @@ { "connectionsMap": { - "r1":["mainServer", "r2"], - "r2":["r3", "s1"], - "r3":["r4", "c1"], - "r4":["r1", "c2"] + "r1":["mainServer", "r2" , "c2"], + "r2":["r1", "s1" , "c1"] } } diff --git a/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json b/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json index 1a789aa0..5c79934d 100644 --- a/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json +++ b/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json @@ -103,7 +103,7 @@ "_doc_infraType": " opennn:0 | wolfengine:1 |", "distributedSystemType": "1", "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", - "distributedSystemArgs": "none", + "distributedSystemArgs": "SyncMaxCount=5", "_doc_distributedSystemArgs": "String", "distributedSystemToken": "9922u", "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" @@ -136,7 +136,7 @@ "_doc_infraType": " opennn:0 | wolfengine:1 |", "distributedSystemType": "2", "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", - "distributedSystemArgs": "none", + "distributedSystemArgs": "SyncMaxCount=5", "_doc_distributedSystemArgs": "String", "distributedSystemToken": "9922u", "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" diff --git a/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json b/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json index f6d9433d..cd7501ed 100644 --- a/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json +++ b/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json @@ -1,6 +1,6 @@ { "experimentName": "synthetic_3_gausians", - "batchSize": 100, + "batchSize": 50, "csvFilePath": "/tmp/nerlnet/data/NerlnetData-master/nerlnet/synthetic_norm/synthetic_full.csv", "numOfFeatures": "5", "numOfLabels": "3", diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl index 792264b3..16157808 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl @@ -5,7 +5,9 @@ -export([start_link/1]). -export([init/1, handle_cast/2, handle_call/3]). --export([send_message/3, get_all_messages/0]). % methods that are used by worker +-export([send_message/3, get_all_messages/0 , sync_inbox/0]). % methods that are used by worker + +-define(SYNC_INBOX_TIMEOUT, 30000). % 30 seconds %% @doc Spawns the server and registers the local name (unique) -spec(start_link(args) -> @@ -56,6 +58,8 @@ handle_call(_Call, _From, State) -> get_all_messages() -> W2WEts = get(w2w_ets), {_, InboxQueue} = ets:lookup(W2WEts, inbox_queue), + NewEmptyQueue = queue:new(), + ets:update_element(W2WEts, inbox_queue, {inbox_queue, NewEmptyQueue}), InboxQueue. add_msg_to_inbox_queue(Message) -> @@ -69,5 +73,29 @@ send_message(FromWorker, TargetWorker, Data) -> MyClient = get(client_statem_pid), gen_server:cast(MyClient, Msg). - - \ No newline at end of file +is_inbox_empty() -> + W2WEts = get(w2w_ets), + {_ , InboxQueue} = ets:lookup(W2WEts, inbox_queue), + queue:len(InboxQueue) == 0. + + +% Think about better alternative to this method + +timeout(Timeout) -> + receive + stop -> ok; + _ -> timeout(Timeout) + after Timeout -> throw("Timeout reached") + end. + +sync_inbox() -> + TimeoutPID = spawn(fun() -> timeout(?SYNC_INBOX_TIMEOUT) end), + sync_inbox(TimeoutPID). + +sync_inbox(TimeoutPID) -> + timer:sleep(10), % 10 ms + IsInboxEmpty = is_inbox_empty(), + if + IsInboxEmpty -> sync_inbox(TimeoutPID); + true -> TimeoutPID ! stop + end. \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index b2dc5e6f..5f903333 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -4,9 +4,11 @@ -include("/usr/local/lib/nerlnet-lib/NErlNet/src_erl/NerlnetApp/src/nerl_tools.hrl"). -include("workerDefinitions.hrl"). +-include("w2wCom.hrl"). -define(WORKER_FEDERATED_CLIENT_ETS_FIELDS, [my_name, client_pid, server_name, sync_max_count, sync_count]). -define(FEDERATED_CLIENT_ETS_KEY_IN_GENWORKER_ETS, fedrated_client_ets). +-define(DEFAULT_SYNC_MAX_COUNT_ARG, 1). % %% Federated mode % wait(cast, {loss, {LOSS_FUNC,Time_NIF}}, State = #workerGeneric_state{clientPid = ClientPid,ackClient = AckClient, myName = MyName, nextState = NextState, count = Count, countLimit = CountLimit, modelId = Mid}) -> @@ -35,7 +37,6 @@ % end. %% Data = -record(workerFederatedClient, {syncCount, syncMaxCount, serverAddr}). - controller(FuncName, {GenWorkerEts, WorkerData}) -> case FuncName of init -> init({GenWorkerEts, WorkerData}); @@ -51,28 +52,70 @@ controller(FuncName, {GenWorkerEts, WorkerData}) -> get_this_client_ets(GenWorkerEts) -> ets:lookup_element(GenWorkerEts, federated_client_ets, ?ETS_KEYVAL_VAL_IDX). +parse_args(Args) -> + ArgsList = string:split(Args, "," , all), + Func = fun(Arg) -> + [Key, Val] = string:split(Arg, "="), + {Key, Val} + end, + lists:map(Func, ArgsList). % Returns list of tuples [{Key, Val}, ...] + +sync_max_count_init(FedClientEts , ArgsList) -> + case lists:keyfind("sync_max_count", 1, ArgsList) of + false -> Val = ?DEFAULT_SYNC_MAX_COUNT_ARG; + {_, Val} -> list_to_integer(Val) + end, + ets:insert(FedClientEts, {sync_max_count, Val}). + %% handshake with workers / server init({GenWorkerEts, WorkerData}) -> % create an ets for this client and save it to generic worker ets FedratedClientEts = ets:new(federated_client,[set]), ets:insert(GenWorkerEts, {federated_client_ets, FedratedClientEts}), - {SyncMaxCount, MyName, ServerName} = WorkerData, + io:format("@FedClient: ~p~n",[WorkerData]), + {MyName, Args, Token} = WorkerData, + ArgsList = parse_args(Args), + sync_max_count_init(FedratedClientEts, ArgsList), % create fields in this ets + ets:insert(FedratedClientEts, {my_token, Token}), ets:insert(FedratedClientEts, {my_name, MyName}), - ets:insert(FedratedClientEts, {server_name, ServerName}), - ets:insert(FedratedClientEts, {sync_max_count, SyncMaxCount}), - ets:insert(FedratedClientEts, {sync_count, SyncMaxCount}), + ets:insert(FedratedClientEts, {server_name, none}), % update later + ets:insert(FedratedClientEts, {sync_count, 0}), ets:insert(FedratedClientEts, {server_update, false}), + ets:insert(FedratedClientEts, {handshake_done, false}), io:format("finished init in ~p~n",[MyName]). %% TODO REMOVE +handshake(EtsRef) -> + w2wCom:sync_inbox(), + InboxQueue = w2wCom:get_all_messages(), + MessagesList = queue:to_list(InboxQueue), + %% Throw exception if there is more than 1 message in the queue or if its empty + Func = + fun({?W2WCOM_ATOM, FromServer, MyName, {handshake, ServerToken}}) -> + ets:insert(EtsRef, {server_name, FromServer}), + ets:insert(EtsRef, {token , ServerToken}), + MyToken = ets:lookup_element(EtsRef, my_token, ?ETS_KEYVAL_VAL_IDX), + if + ServerToken =/= MyToken -> not_my_server; + true -> w2wCom:send_message(MyName, FromServer, {handshake, MyToken}) , + ets:update_element(EtsRef, handshake_done, true) + end + end, + lists:foreach(Func, MessagesList), + % Check if handshake is done + HandshakeDone = ets:lookup_element(EtsRef, handshake_done, ?ETS_KEYVAL_VAL_IDX), + if HandshakeDone -> ok; + true -> handshake(EtsRef) + end. + pre_idle({GenWorkerEts, _WorkerData}) -> ThisEts = get_this_client_ets(GenWorkerEts), - %% send to server that this worker is part of the federated workers _ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), % No longer needed? MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), - % gen_statem:cast(ClientPID,{custom_worker_message,{MyName, ServerName}}), - w2wCom:send_message(MyName, ServerName, {MyName, ServerName}), %% ****** NEW - TEST NEEDED ****** + % Waiting for handshake from server + handshake(ThisEts), + io:format("@pre_idle: Worker ~p updates federated server ~p~n",[MyName , ServerName]). post_idle({_GenWorkerEts, _WorkerData}) -> ok. @@ -131,4 +174,5 @@ update({GenWorkerEts, NerlTensorWeights}) -> % receive _ -> non % after 1 -> worker_event_polling(T-1) % end -% end. \ No newline at end of file +% end. + diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 877d534a..7d56c3a1 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -11,6 +11,7 @@ -define(ETS_TYPE_IDX, 2). -define(ETS_WEIGHTS_AND_BIAS_NERLTENSOR_IDX, 3). -define(ETS_NERLTENSOR_TYPE_IDX, 2). +-define(DEFAULT_SYNC_MAX_COUNT_ARG, 1). controller(FuncName, {GenWorkerEts, WorkerData}) -> @@ -28,19 +29,40 @@ controller(FuncName, {GenWorkerEts, WorkerData}) -> get_this_server_ets(GenWorkerEts) -> ets:lookup_element(GenWorkerEts, federated_server_ets, ?ETS_KEYVAL_VAL_IDX). +parse_args(Args) -> + ArgsList = string:split(Args, "," , all), + Func = fun(Arg) -> + [Key, Val] = string:split(Arg, "="), + {Key, Val} + end, + lists:map(Func, ArgsList). % Returns list of tuples [{Key, Val}, ...] + +sync_max_count_init(FedServerEts , ArgsList) -> + case lists:keyfind("sync_max_count", 1, ArgsList) of + false -> Val = ?DEFAULT_SYNC_MAX_COUNT_ARG; + {_, Val} -> list_to_integer(Val) + end, + ets:insert(FedServerEts, {sync_max_count, Val}). + %% handshake with workers / server init({GenWorkerEts, WorkerData}) -> - Type = float, % update from data - {SyncMaxCount, MyName, WorkersNamesList} = WorkerData, FederatedServerEts = ets:new(federated_server,[set]), + {MyName, Args, Token} = WorkerData, + ArgsList = parse_args(Args), + sync_max_count_init(FederatedServerEts, ArgsList), ets:insert(GenWorkerEts, {federated_server_ets, FederatedServerEts}), - ets:insert(FederatedServerEts, {workers, [MyName]}), %% start with only self in list, get others in network thru handshake - ets:insert(FederatedServerEts, {sync_max_count, SyncMaxCount}), - ets:insert(FederatedServerEts, {sync_count, SyncMaxCount}), + ets:insert(FederatedServerEts, {fed_clients, []}), + ets:insert(FederatedServerEts, {sync_count, 0}), ets:insert(FederatedServerEts, {my_name, MyName}), - ets:insert(FederatedServerEts, {nerltensor_type, Type}). + ets:insert(FederatedServerEts, {token , Token}). + + -pre_idle({GenWorkerEts, WorkerName}) -> ok. +pre_idle({_GenWorkerEts, _WorkerName}) -> + % Extract all workers in nerlnet network + % Send handshake message to all workers + % Wait for all workers to send handshake message back + timer:sleep(500) % 0.5 second post_idle({GenWorkerEts, WorkerName}) -> ThisEts = get_this_server_ets(GenWorkerEts), @@ -49,10 +71,10 @@ post_idle({GenWorkerEts, WorkerName}) -> ets:insert(ThisEts, {workers, Workers++[WorkerName]}). %% Send updated weights if set -pre_train({GenWorkerEts, WorkerData}) -> ok. +pre_train({_GenWorkerEts, _WorkerData}) -> ok. %% calculate avg of weights when set -post_train({GenWorkerEts, WorkerData}) -> +post_train({GenWorkerEts, _WorkerData}) -> ThisEts = get_this_server_ets(GenWorkerEts), SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), if SyncCount == 0 -> @@ -63,10 +85,10 @@ post_train({GenWorkerEts, WorkerData}) -> gen_statem:cast(ClientPID, {update, {MyName, MyName, Weights}}), MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), ets:update_counter(ThisEts, sync_count, MaxSyncCount), - ToUpdate = true; + _ToUpdate = true; true -> ets:update_counter(ThisEts, sync_count, -1), - ToUpdate = false + _ToUpdate = false end. % ThisEts = get_this_server_ets(GenWorkerEts), % Weights = generate_avg_weights(ThisEts), @@ -74,15 +96,15 @@ post_train({GenWorkerEts, WorkerData}) -> % gen_statem:cast({update, Weights}). %TODO complete send to all workers in lists:foreach %% nothing? -pre_predict({GenWorkerEts, WorkerData}) -> ok. +pre_predict({_GenWorkerEts, _WorkerData}) -> ok. %% nothing? -post_predict({GenWorkerEts, WorkerData}) -> ok. +post_predict({_GenWorkerEts, _WorkerData}) -> ok. %% FedServer keeps an ets list of tuples: {WorkerName, worker, WeightsAndBiasNerlTensor} %% in update get weights of clients, if got from all => avg and send back update({GenWorkerEts, WorkerData}) -> - {WorkerName, Me, NerlTensorWeights} = WorkerData, + {WorkerName, _Me, NerlTensorWeights} = WorkerData, ThisEts = get_this_server_ets(GenWorkerEts), %% update weights in ets ets:insert(ThisEts, {WorkerName, worker, NerlTensorWeights}), diff --git a/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl b/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl index a027a1c6..f75b3a50 100644 --- a/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl +++ b/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl @@ -19,7 +19,7 @@ case DistributedSystemType of %% Parse args eg. batch_sync_count ?DC_DISTRIBUTED_SYSTEM_TYPE_FEDSERVERAVG_IDX_STR -> DistributedBehaviorFunc = fun workerFederatedServer:controller/2, - DistributedWorkerData = {_ServerName = WorkerName , _Args = DistributedSystemArgs, _Token = DistributedSystemToken, _WorkersNamesList = []} + DistributedWorkerData = {_ServerName = WorkerName , _Args = DistributedSystemArgs, _Token = DistributedSystemToken} end, {DistributedBehaviorFunc , DistributedWorkerData}. From d01c874e34331cd27e3718ee104508b18527bf73 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Thu, 16 May 2024 22:45:42 +0000 Subject: [PATCH 05/52] [W2W] WIP --- .../src/Bridge/onnWorkers/w2wCom.erl | 7 +- .../onnWorkers/workerFederatedClient.erl | 87 +++---------------- .../onnWorkers/workerFederatedServer.erl | 45 +++++++--- .../src/Bridge/onnWorkers/workerGeneric.erl | 54 ++---------- .../NerlnetApp/src/Client/clientStatem.erl | 4 +- .../src/Client/clientWorkersFunctions.erl | 8 +- 6 files changed, 62 insertions(+), 143 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl index 16157808..9efb0f3e 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl @@ -79,17 +79,16 @@ is_inbox_empty() -> queue:len(InboxQueue) == 0. -% Think about better alternative to this method -timeout(Timeout) -> +timeout_throw(Timeout) -> receive stop -> ok; - _ -> timeout(Timeout) + _ -> timeout_throw(Timeout) after Timeout -> throw("Timeout reached") end. sync_inbox() -> - TimeoutPID = spawn(fun() -> timeout(?SYNC_INBOX_TIMEOUT) end), + TimeoutPID = spawn(fun() -> timeout_throw(?SYNC_INBOX_TIMEOUT) end), sync_inbox(TimeoutPID). sync_inbox(TimeoutPID) -> diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 5f903333..9a530713 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -10,33 +10,6 @@ -define(FEDERATED_CLIENT_ETS_KEY_IN_GENWORKER_ETS, fedrated_client_ets). -define(DEFAULT_SYNC_MAX_COUNT_ARG, 1). -% %% Federated mode -% wait(cast, {loss, {LOSS_FUNC,Time_NIF}}, State = #workerGeneric_state{clientPid = ClientPid,ackClient = AckClient, myName = MyName, nextState = NextState, count = Count, countLimit = CountLimit, modelId = Mid}) -> -% % {LOSS_FUNC,_TimeCpp} = LossAndTime, -% if Count == CountLimit -> -% % Get weights -% Ret_weights = nerlNIF:call_to_get_weights(Mid), -% % Ret_weights_tuple = niftest:call_to_get_weights(Mid), -% % {Weights,Bias,Biases_sizes_list,Wheights_sizes_list} = Ret_weights_tuple, - -% % ListToSend = [Weights,Bias,Biases_sizes_list,Wheights_sizes_list], - -% % Send weights and loss value -% gen_statem:cast(ClientPid,{loss, federated_weights, MyName, LOSS_FUNC, Ret_weights}), %% TODO Add Time and Time_NIF to the cast -% checkAndAck(MyName,ClientPid,AckClient), -% % Reset count and go to state train -% {next_state, NextState, State#workerNN_state{ackClient = 0, count = 0}}; - -% true -> -% %% Send back the loss value -% gen_statem:cast(ClientPid,{loss, MyName, LOSS_FUNC,Time_NIF/1000}), %% TODO Add Time and Time_NIF to the cast -% checkAndAck(MyName,ClientPid,AckClient), - - -% {next_state, NextState, State#workerNN_state{ackClient = 0, count = Count + 1}} -% end. - -%% Data = -record(workerFederatedClient, {syncCount, syncMaxCount, serverAddr}). controller(FuncName, {GenWorkerEts, WorkerData}) -> case FuncName of init -> init({GenWorkerEts, WorkerData}); @@ -45,8 +18,7 @@ controller(FuncName, {GenWorkerEts, WorkerData}) -> pre_train -> pre_train({GenWorkerEts, WorkerData}); post_train -> post_train({GenWorkerEts, WorkerData}); pre_predict -> pre_predict({GenWorkerEts, WorkerData}); - post_predict-> post_predict({GenWorkerEts, WorkerData}); - update -> update({GenWorkerEts, WorkerData}) + post_predict-> post_predict({GenWorkerEts, WorkerData}) end. get_this_client_ets(GenWorkerEts) -> @@ -83,13 +55,12 @@ init({GenWorkerEts, WorkerData}) -> ets:insert(FedratedClientEts, {sync_count, 0}), ets:insert(FedratedClientEts, {server_update, false}), ets:insert(FedratedClientEts, {handshake_done, false}), - io:format("finished init in ~p~n",[MyName]). %% TODO REMOVE + spawn(fun() -> handshake(FedratedClientEts) end). handshake(EtsRef) -> w2wCom:sync_inbox(), InboxQueue = w2wCom:get_all_messages(), MessagesList = queue:to_list(InboxQueue), - %% Throw exception if there is more than 1 message in the queue or if its empty Func = fun({?W2WCOM_ATOM, FromServer, MyName, {handshake, ServerToken}}) -> ets:insert(EtsRef, {server_name, FromServer}), @@ -108,48 +79,30 @@ handshake(EtsRef) -> true -> handshake(EtsRef) end. -pre_idle({GenWorkerEts, _WorkerData}) -> - ThisEts = get_this_client_ets(GenWorkerEts), - _ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), % No longer needed? - MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), - ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), - % Waiting for handshake from server - handshake(ThisEts), - - io:format("@pre_idle: Worker ~p updates federated server ~p~n",[MyName , ServerName]). +pre_idle({_GenWorkerEts, _WorkerData}) -> ok. post_idle({_GenWorkerEts, _WorkerData}) -> ok. -%% set weights from fedserver -pre_train({_GenWorkerEts, _WorkerData}) -> ok. - % ThisEts = get_this_client_ets(GenWorkerEts), - % ToUpdate = ets:lookup_element(ThisEts, server_update, ?ETS_KEYVAL_VAL_IDX), - % if ToUpdate -> - % ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - % nerlNIF:call_to_set_weights(ModelID, Weights); - % true -> nothing - % end. +% After SyncMaxCount , sync_inbox to get the updated model from FedServer +pre_train({GenWorkerEts, NerlTensorWeights}) -> + ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), + nerlNIF:call_to_set_weights(ModelID, NerlTensorWeights). %% every countLimit batches, send updated weights post_train({GenWorkerEts, _WorkerData}) -> ThisEts = get_this_client_ets(GenWorkerEts), SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), - if SyncCount == 0 -> + MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), + if SyncCount == MaxSyncCount -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), Weights = nerlNIF:call_to_get_weights(GenWorkerEts, ModelID), - _ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), % No longer needed? ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), - MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), - % io:format("Worker ~p entering update and got weights ~p~n",[MyName, Weights]), - ets:update_counter(ThisEts, sync_count, MaxSyncCount), io:format("@post_train: Worker ~p updates federated server ~p~n",[MyName , ServerName]), w2wCom:send_message(MyName, ServerName , Weights), %% ****** NEW - TEST NEEDED ****** - % gen_statem:cast(ClientPID, {update, {MyName, ServerName, Weights}}), - _ToUpdate = true; % ? + ets:update_element(ThisEts, sync_count, {?ETS_KEYVAL_VAL_IDX , 0}); true -> - ets:update_counter(ThisEts, sync_count, -1), - _ToUpdate = false % ? + ets:update_counter(ThisEts, sync_count, 1) end. %% nothing? @@ -158,21 +111,3 @@ pre_predict({_GenWorkerEts, WorkerData}) -> WorkerData. %% nothing? post_predict(Data) -> Data. -%% gets weights from federated server -update({GenWorkerEts, NerlTensorWeights}) -> - % ThisEts = get_this_client_ets(GenWorkerEts), - ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - nerlNIF:call_to_set_weights(ModelID, NerlTensorWeights). - % io:format("updated weights in worker ~p~n",[ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX)]). - -%%------------------------------------------ -% worker_event_polling(0) -> ?LOG_ERROR("worker event polling takes too long!"); -% worker_event_polling(Weights) -> -% if length(Weights) == 1 -> Weights; -% length(Weights) > 1 -> ?LOG_ERROR("more than 1 messages pending!"); -% true -> %% wait for info to update -% receive _ -> non -% after 1 -> worker_event_polling(T-1) -% end -% end. - diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 7d56c3a1..0eba6fa4 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -3,15 +3,19 @@ -export([controller/2]). -include("workerDefinitions.hrl"). +-include("w2wCom.hrl"). -import(nerlNIF,[nerltensor_scalar_multiplication_nif/3]). -import(nerlTensor,[sum_nerltensors_lists/2]). +-import(w2wCom,[send_message/3, get_all_messages/0, is_inbox_empty/0]). + -define(ETS_WID_IDX, 1). -define(ETS_TYPE_IDX, 2). -define(ETS_WEIGHTS_AND_BIAS_NERLTENSOR_IDX, 3). -define(ETS_NERLTENSOR_TYPE_IDX, 2). -define(DEFAULT_SYNC_MAX_COUNT_ARG, 1). +-define(HANDSHAKE_TIMEOUT, 2000). % 2 seconds controller(FuncName, {GenWorkerEts, WorkerData}) -> @@ -47,28 +51,49 @@ sync_max_count_init(FedServerEts , ArgsList) -> %% handshake with workers / server init({GenWorkerEts, WorkerData}) -> FederatedServerEts = ets:new(federated_server,[set]), - {MyName, Args, Token} = WorkerData, + {MyName, Args, Token , WorkersList} = WorkerData, ArgsList = parse_args(Args), sync_max_count_init(FederatedServerEts, ArgsList), ets:insert(GenWorkerEts, {federated_server_ets, FederatedServerEts}), + ets:insert(FederatedServerEts, {workers, WorkersList}), ets:insert(FederatedServerEts, {fed_clients, []}), ets:insert(FederatedServerEts, {sync_count, 0}), ets:insert(FederatedServerEts, {my_name, MyName}), - ets:insert(FederatedServerEts, {token , Token}). + ets:insert(FederatedServerEts, {token , Token}), + put(fed_server_ets, FederatedServerEts). -pre_idle({_GenWorkerEts, _WorkerName}) -> +pre_idle({_GenWorkerEts, _WorkerName}) -> ok. + + +post_idle({_GenWorkerEts, _WorkerName}) -> % Extract all workers in nerlnet network % Send handshake message to all workers % Wait for all workers to send handshake message back - timer:sleep(500) % 0.5 second - -post_idle({GenWorkerEts, WorkerName}) -> - ThisEts = get_this_server_ets(GenWorkerEts), - io:format("adding worker ~p to fed workers~n",[WorkerName]), - Workers = ets:lookup_element(ThisEts, workers, ?ETS_KEYVAL_VAL_IDX), - ets:insert(ThisEts, {workers, Workers++[WorkerName]}). + FedServerEts = get(fed_server_ets), + FedServerName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + WorkersList = ets:lookup_element(FedServerEts, workers, ?ETS_KEYVAL_VAL_IDX), + MyToken = ets:lookup_element(FedServerEts, token, ?ETS_KEYVAL_VAL_IDX), + Func = fun(FedClient) -> + w2wCom:send_message(FedClient, FedServerName, {handshake, MyToken}) + end, + lists:foreach(Func, WorkersList), + timer:sleep(?HANDSHAKE_TIMEOUT), + IsEmpty = w2wCom:is_inbox_empty(), + if IsEmpty == true -> + throw("Handshake failed, none of the workers responded in time"); + true -> ok + end, + InboxQueue = w2wCom:get_all_messages(), + MessagesList = queue:to_list(InboxQueue), + MsgFunc = + fun({?W2WCOM_ATOM, FromWorker, _MyName, {handshake, _WorkerToken}}) -> + FedWorkers = ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX), + ets:update_element(FedServerEts, fed_clients, {?ETS_KEYVAL_VAL_IDX , [FromWorker] ++ FedWorkers}) + end, + lists:foreach(MsgFunc, MessagesList), + io:format("Handshake done~n"). %% Send updated weights if set pre_train({_GenWorkerEts, _WorkerData}) -> ok. diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 4a7f2259..3816a01c 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -176,10 +176,8 @@ wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneri wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> BatchTimeStamp = erlang:system_time(nanosecond), gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , BatchID , BatchTimeStamp}), - ToUpdate = DistributedBehaviorFunc(post_train, {get(generic_worker_ets),DistributedWorkerData}), %% Change to W2WComm - if ToUpdate -> {next_state, update, State#workerGeneric_state{nextState=NextState}}; - true -> {next_state, NextState, State} - end; + DistributedBehaviorFunc(post_train, {get(generic_worker_ets),DistributedWorkerData}), %% Change to W2WComm + {next_state, NextState, State}; wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> BatchTimeStamp = erlang:system_time(nanosecond), @@ -219,48 +217,6 @@ wait(cast, Data, State) -> worker_controller_message_queue(Data), {keep_state, State}. -%% treated runaway message in nerlNIF:call_to_fet_weights -% update(info, Data, State) -> -% ?LOG_NOTICE(?LOG_HEADER++"Worker ~p got data thru info: ~p\n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), Data]), -% ?LOG_INFO("Worker ets is: ~p",[ets:match_object(get(generic_worker_ets), {'$0', '$1'})]), -% {keep_state, State}; - -%% TODO FIX CONTROLLER -update(cast, {update, _From, NerltensorWeights}, State = #workerGeneric_state{distributedBehaviorFunc = DistributedBehaviorFunc, nextState = NextState}) -> - ?LOG_WARNING("************* Unrecognized update method , next state: ~p **************" , [NextState]), - DistributedBehaviorFunc(update, {get(generic_worker_ets), NerltensorWeights}), - {next_state, NextState, State}; - -%% Worker updates its' client that it is available (in idle state) -update(cast, {idle}, State = #workerGeneric_state{myName = MyName}) -> - update_client_avilable_worker(MyName), - {next_state, idle, State#workerGeneric_state{nextState = idle}}; - - -%% TODO Guy MOVE THIS FUNCTION TO CONTROLLER -update(cast, Data, State = #workerGeneric_state{distributedBehaviorFunc = DistributedBehaviorFunc, nextState = NextState}) -> - % io:format("worker ~p got ~p~n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), Data]), - case Data of - %% FedClient update avg weights - {update, "server", _Me, NerltensorWeights} -> - DistributedBehaviorFunc(update, {get(generic_worker_ets), NerltensorWeights}), - % io:format("worker ~p updated model and going to ~p state~n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), NextState]), - {next_state, NextState, State}; - %% FedServer get weights from clients - {update, WorkerName, Me, NerlTensorWeights} -> - StillUpdate = DistributedBehaviorFunc(update, {get(generic_worker_ets), {WorkerName, Me, NerlTensorWeights}}), - if StillUpdate -> - % io:format("worker ~p in update waiting to go to ~p state~n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), NextState]), - {keep_state, State#workerGeneric_state{nextState=NextState}}; - true -> - {next_state, NextState, State#workerGeneric_state{}} - end; - %% got sample from source. discard and add missed count TODO: add to Q - {sample, _Tensor} -> - %%ets:update_counter(get(generic_worker_ets), missedBatches, 1), - {keep_state, State} - end. - %% State train train(cast, {sample, BatchID ,{<<>>, _Type}}, State) -> @@ -288,8 +244,9 @@ train(cast, {set_weights,Ret_weights_list}, State = #workerGeneric_state{modelID {next_state, train, State}; -train(cast, {idle}, State = #workerGeneric_state{myName = MyName}) -> +train(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> update_client_avilable_worker(MyName), + DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), train}), {next_state, idle, State}; train(cast, Data, State) -> @@ -314,8 +271,9 @@ predict(cast, {sample , SourceName , BatchID , {PredictBatchTensor, Type}}, Stat _Pid = spawn(fun()-> nerlNIF:call_to_predict(ModelId , {PredictBatchTensor, Type} , CurrPID , BatchID, SourceName) end), {next_state, wait, State#workerGeneric_state{nextState = predict , currentBatchID = BatchID}}; -predict(cast, {idle}, State = #workerGeneric_state{myName = MyName}) -> +predict(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> update_client_avilable_worker(MyName), + DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), predict}), {next_state, idle, State}; predict(cast, Data, State) -> diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 14cd8b31..258fb833 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -77,8 +77,8 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh ClientStatsEts = stats:generate_stats_ets(), %% client stats ets inside ets_stats ets:insert(EtsStats, {MyName, ClientStatsEts}), put(ets_stats, EtsStats), - ets:insert(EtsRef, {workerToClient, WorkerToClientMap}), - ets:insert(EtsRef, {workersNames, ClientWorkers}), + ets:insert(EtsRef, {workerToClient, WorkerToClientMap}), % All workers in the network (map to their client) + ets:insert(EtsRef, {workersNames, ClientWorkers}), % All THIS Client's workers ets:insert(EtsRef, {nerlnetGraph, NerlnetGraph}), ets:insert(EtsRef, {myName, MyName}), MyWorkersToShaMap = maps:filter(fun(Worker , _SHA) -> lists:member(Worker , ClientWorkers) end , WorkerShaMap), diff --git a/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl b/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl index f75b3a50..c7d0038d 100644 --- a/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl +++ b/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl @@ -8,7 +8,7 @@ -export([create_workers/4]). -export([get_worker_pid/2 , get_worker_stats_ets/2 , get_workers_names/1]). -get_distributed_worker_behavior(DistributedSystemType , WorkerName , DistributedSystemArgs , DistributedSystemToken) -> +get_distributed_worker_behavior(ClientEtsRef, DistributedSystemType , WorkerName , DistributedSystemArgs , DistributedSystemToken) -> case DistributedSystemType of ?DC_DISTRIBUTED_SYSTEM_TYPE_NONE_IDX_STR -> DistributedBehaviorFunc = fun workerNN:controller/2, @@ -18,8 +18,10 @@ case DistributedSystemType of DistributedWorkerData = {_WorkerName = WorkerName , _Args = DistributedSystemArgs, _Token = DistributedSystemToken}; %% Parse args eg. batch_sync_count ?DC_DISTRIBUTED_SYSTEM_TYPE_FEDSERVERAVG_IDX_STR -> + WorkersMap = ets:lookup_element(ClientEtsRef, workerToClient, ?DATA_IDX), + WorkersList = [Worker || {Worker, _Val} <- maps:to_list(WorkersMap)], DistributedBehaviorFunc = fun workerFederatedServer:controller/2, - DistributedWorkerData = {_ServerName = WorkerName , _Args = DistributedSystemArgs, _Token = DistributedSystemToken} + DistributedWorkerData = {_ServerName = WorkerName , _Args = DistributedSystemArgs, _Token = DistributedSystemToken , _WorkersList = WorkersList} end, {DistributedBehaviorFunc , DistributedWorkerData}. @@ -43,7 +45,7 @@ create_workers(ClientName, ClientEtsRef , ShaToModelArgsMap , EtsStats) -> MyClientPid = self(), % TODO add documentation about this case of % move this case to module called client_controller - {DistributedBehaviorFunc , DistributedWorkerData} = get_distributed_worker_behavior(DistributedSystemType , WorkerName , DistributedSystemArgs , DistributedSystemToken), + {DistributedBehaviorFunc , DistributedWorkerData} = get_distributed_worker_behavior(ClientEtsRef, DistributedSystemType , WorkerName , DistributedSystemArgs , DistributedSystemToken), WorkerArgs = {ModelID , ModelType , ModelArgs , LayersSizes, LayersTypes, LayersFunctions, LearningRate , Epochs, Optimizer, OptimizerArgs , LossMethod , DistributedSystemType , DistributedSystemArgs}, From 76636d81bea733618819b6b0408a9f4eced7020f Mon Sep 17 00:00:00 2001 From: GuyPErets106 Date: Fri, 17 May 2024 14:07:59 +0000 Subject: [PATCH 06/52] [W2W] Updated Jsons --- .../dc_AEC_1d_2c_1s_4r_4w.json | 67 +++++++++++++++++++ .../Workers/worker_ae_classifier.json | 33 +++++++++ .../Workers/worker_fed_client.json | 33 +++++++++ .../Workers/worker_fed_server.json | 33 +++++++++ 4 files changed, 166 insertions(+) create mode 100644 inputJsonsFiles/DistributedConfig/dc_AEC_1d_2c_1s_4r_4w.json create mode 100644 inputJsonsFiles/Workers/worker_ae_classifier.json create mode 100644 inputJsonsFiles/Workers/worker_fed_client.json create mode 100644 inputJsonsFiles/Workers/worker_fed_server.json diff --git a/inputJsonsFiles/DistributedConfig/dc_AEC_1d_2c_1s_4r_4w.json b/inputJsonsFiles/DistributedConfig/dc_AEC_1d_2c_1s_4r_4w.json new file mode 100644 index 00000000..753ac010 --- /dev/null +++ b/inputJsonsFiles/DistributedConfig/dc_AEC_1d_2c_1s_4r_4w.json @@ -0,0 +1,67 @@ +{ + "nerlnetSettings": { + "frequency": "200", + "batchSize": "100" + }, + "mainServer": { + "port": "8081", + "args": "" + }, + "apiServer": { + "port": "8082", + "args": "" + }, + "devices": [ + { + "name": "pc1", + "ipv4": "10.211.55.3", + "entities": "c1,c2,r2,r1,r3,r4,s1,apiServer,mainServer" + } + ], + "routers": [ + { + "name": "r1", + "port": "8086", + "policy": "0" + }, + { + "name": "r2", + "port": "8087", + "policy": "0" + }, + { + "name": "r3", + "port": "8088", + "policy": "0" + }, + { + "name": "r4", + "port": "8089", + "policy": "0" + } + ], + "sources": [ + { + "name": "s1", + "port": "8085", + "frequency": "200", + "policy": "0", + "epochs": "1", + "type": "0" + } + ], + "clients": [ + { + "name": "c1", + "port": "8083", + "workers": "" + }, + { + "name": "c2", + "port": "8084", + "workers": "" + } + ], + "workers": [], + "model_sha": {} +} \ No newline at end of file diff --git a/inputJsonsFiles/Workers/worker_ae_classifier.json b/inputJsonsFiles/Workers/worker_ae_classifier.json new file mode 100644 index 00000000..96d59122 --- /dev/null +++ b/inputJsonsFiles/Workers/worker_ae_classifier.json @@ -0,0 +1,33 @@ +{ + "modelType": "9", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "11,6,4,6,11", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,3", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", + "layers_functions": "1,7,7,7,11", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.01", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "0", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "none", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "none", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" +} \ No newline at end of file diff --git a/inputJsonsFiles/Workers/worker_fed_client.json b/inputJsonsFiles/Workers/worker_fed_client.json new file mode 100644 index 00000000..28964195 --- /dev/null +++ b/inputJsonsFiles/Workers/worker_fed_client.json @@ -0,0 +1,33 @@ +{ + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,10,5,3,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,5", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Bounding:9 |", + "layers_functions": "1,6,6,11,4", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.01", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "1", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "none", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" +} \ No newline at end of file diff --git a/inputJsonsFiles/Workers/worker_fed_server.json b/inputJsonsFiles/Workers/worker_fed_server.json new file mode 100644 index 00000000..d9eb7758 --- /dev/null +++ b/inputJsonsFiles/Workers/worker_fed_server.json @@ -0,0 +1,33 @@ +{ + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,10,5,3,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,5", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Bounding:9 |", + "layers_functions": "1,6,6,11,4", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.01", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "2", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "none", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" +} \ No newline at end of file From 359f5ef193de6d7133a21cf825d3ad7eee512317 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Sun, 19 May 2024 19:57:06 +0000 Subject: [PATCH 07/52] [W2W] Independent Exp Works --- .../dc_fed_synt_1d_2c_2r_1s_4w_1ws.json | 8 +- .../exp_fed_synt_1d_2c_2r_1s_4w_1ws.json | 77 ++++---- .../exp_test_synt_1d_2c_1s_4r_4w new.json | 2 +- src_erl/NerlnetApp/src/Bridge/nerlNIF.erl | 13 +- .../src/Bridge/onnWorkers/w2wCom.erl | 90 ++++++--- .../onnWorkers/workerFederatedClient.erl | 106 ++++++---- .../onnWorkers/workerFederatedServer.erl | 186 +++++++++--------- .../src/Bridge/onnWorkers/workerGeneric.erl | 43 ++-- .../src/Client/clientStateHandler.erl | 1 - .../NerlnetApp/src/Client/clientStatem.erl | 69 +++---- .../src/Client/clientWorkersFunctions.erl | 9 +- src_erl/NerlnetApp/src/nerlnetApp_app.erl | 3 +- src_py/apiServer/apiServer.py | 4 +- 13 files changed, 325 insertions(+), 286 deletions(-) diff --git a/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json b/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json index 5c79934d..7814c903 100644 --- a/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json +++ b/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json @@ -1,7 +1,7 @@ { "nerlnetSettings": { "frequency": "100", - "batchSize": "50" + "batchSize": "100" }, "mainServer": { "port": "8900", @@ -34,7 +34,7 @@ { "name": "s1", "port": "8904", - "frequency": "100", + "frequency": "200", "policy": "0", "epochs": "1", "type": "0" @@ -80,7 +80,7 @@ "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", "modelArgs": "", "_doc_modelArgs": "Extra arguments to model", - "layersSizes": "5,10,5,3,3", + "layersSizes": "5,2,2,2,3", "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", "layerTypesList": "1,3,3,3,5", "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Bounding:9 |", @@ -113,7 +113,7 @@ "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", "modelArgs": "", "_doc_modelArgs": "Extra arguments to model", - "layersSizes": "5,10,5,3,3", + "layersSizes": "5,2,2,2,3", "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", "layerTypesList": "1,3,3,3,5", "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Bounding:9 |", diff --git a/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json b/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json index cd7501ed..34df8a56 100644 --- a/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json +++ b/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json @@ -1,39 +1,40 @@ { - "experimentName": "synthetic_3_gausians", - "batchSize": 50, - "csvFilePath": "/tmp/nerlnet/data/NerlnetData-master/nerlnet/synthetic_norm/synthetic_full.csv", - "numOfFeatures": "5", - "numOfLabels": "3", - "headersNames": "Norm(0:1),Norm(4:1),Norm(10:3)", - "Phases": - [ - { - "phaseName": "training_phase", - "phaseType": "training", - "sourcePieces": - [ - { - "sourceName": "s1", - "startingSample": "0", - "numOfBatches": "300", - "workers": "w1,w2,w3,w4" - } - ] - }, - { - "phaseName": "prediction_phase", - "phaseType": "prediction", - "sourcePieces": - [ - { - "sourceName": "s1", - "startingSample": "30000", - "numOfBatches": "200", - "workers": "w1,w2,w3,w4" - } - ] - } - ] - } - - \ No newline at end of file + "experimentName": "synthetic_3_gausians", + "batchSize": 100, + "csvFilePath": "/tmp/nerlnet/data/NerlnetData-master/nerlnet/synthetic_norm/synthetic_full.csv", + "numOfFeatures": "5", + "numOfLabels": "3", + "headersNames": "Norm(0:1),Norm(4:1),Norm(10:3)", + "Phases": + [ + { + "phaseName": "training_phase", + "phaseType": "training", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "0", + "numOfBatches": "500", + "workers": "w1,w2,w3,w4", + "nerltensorType": "float" + } + ] + }, + { + "phaseName": "prediction_phase", + "phaseType": "prediction", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "50000", + "numOfBatches": "500", + "workers": "w1,w2,w3,w4", + "nerltensorType": "float" + } + ] + } + ] +} + diff --git a/inputJsonsFiles/experimentsFlow/exp_test_synt_1d_2c_1s_4r_4w new.json b/inputJsonsFiles/experimentsFlow/exp_test_synt_1d_2c_1s_4r_4w new.json index ae583eec..f05f5f01 100644 --- a/inputJsonsFiles/experimentsFlow/exp_test_synt_1d_2c_1s_4r_4w new.json +++ b/inputJsonsFiles/experimentsFlow/exp_test_synt_1d_2c_1s_4r_4w new.json @@ -36,6 +36,6 @@ } ] } - ] + ] } diff --git a/src_erl/NerlnetApp/src/Bridge/nerlNIF.erl b/src_erl/NerlnetApp/src/Bridge/nerlNIF.erl index f5250170..823de1e4 100644 --- a/src_erl/NerlnetApp/src/Bridge/nerlNIF.erl +++ b/src_erl/NerlnetApp/src/Bridge/nerlNIF.erl @@ -3,7 +3,7 @@ -include("nerlTensor.hrl"). -export([init/0,nif_preload/0,get_active_models_ids_list/0, train_nif/3,update_nerlworker_train_params_nif/6,call_to_train/5,predict_nif/3,call_to_predict/5,get_weights_nif/1,printTensor/2]). --export([call_to_get_weights/2,call_to_set_weights/2]). +-export([call_to_get_weights/1,call_to_set_weights/2]). -export([decode_nif/2, nerltensor_binary_decode/2]). -export([encode_nif/2, nerltensor_encode/5, nerltensor_conversion/2, get_all_binary_types/0, get_all_nerltensor_list_types/0]). -export([erl_type_conversion/1]). @@ -77,21 +77,22 @@ call_to_predict(ModelID, {BatchTensor, Type}, WorkerPid, BatchID , SourceName)-> gen_statem:cast(WorkerPid,{predictRes, nan, BatchID , SourceName}) end. -call_to_get_weights(ThisEts, ModelID)-> +% This function calls to get_weights_nif() and waits for the result using receive block +% Returns {NerlTensorWeights , BinaryType} +call_to_get_weights(ModelID)-> try ?LOG_INFO("Calling get weights in model ~p~n",{ModelID}), _RetVal = get_weights_nif(ModelID), - recv_call_loop(ThisEts) + recv_call_loop() catch Err:E -> ?LOG_ERROR("Couldnt get weights from worker~n~p~n",{Err,E}), [] end. %% sometimes the receive loop gets OTP calls that its not supposed to in high freq. wait for nerktensor of weights -recv_call_loop(ThisEts) -> +recv_call_loop() -> receive {'$gen_cast', _Any} -> ?LOG_WARNING("Missed batch in call of get_weigths"), - ets:update_counter(ThisEts, missedBatches, 1), - recv_call_loop(ThisEts); + recv_call_loop(); NerlTensorWeights -> NerlTensorWeights end. diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl index 9efb0f3e..08a75d41 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl @@ -5,9 +5,11 @@ -export([start_link/1]). -export([init/1, handle_cast/2, handle_call/3]). --export([send_message/3, get_all_messages/0 , sync_inbox/0]). % methods that are used by worker +-export([send_message/4, get_all_messages/1 , sync_inbox/1]). % methods that are used by worker +-define(ETS_KEYVAL_VAL_IDX, 2). -define(SYNC_INBOX_TIMEOUT, 30000). % 30 seconds +-define(DEFAULT_SYNC_INBOX_BUSY_WAITING_SLEEP, 5). % 5 milliseconds %% @doc Spawns the server and registers the local name (unique) -spec(start_link(args) -> @@ -26,8 +28,28 @@ init({WorkerName, MyClientPid}) -> ets:insert(W2wEts, {inbox_queue, InboxQueue}), {ok, []}. +handle_cast({update_gen_worker_pid, GenWorkerPid}, State) -> + put(gen_worker_pid, GenWorkerPid), + {noreply, State}; + +handle_cast(Msg, State) -> + io:format("@w2wCom: Wrong message received ~p~n", [Msg]), + {noreply, State}. + +% This handler also triggers the state machine during training state +handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, {post_train_update, Data}}, _From, State) -> + case get(worker_name) of + ThisWorkerName -> ok; + _ -> throw({error, "The provided worker name is not this worker"}) + end, + % Saved messages are of the form: {FromWorkerName, , Data} + Message = {FromWorkerName, Data}, + add_msg_to_inbox_queue(Message), + gen_server:cast(get(gen_worker_pid), {post_train_update}), + {reply, {ok, post_train_update}, State}; + % Received messages are of the form: {worker_to_worker_msg, FromWorkerName, ThisWorkerName, Data} -handle_cast({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, Data}, State) -> +handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, Data}, _From, State) -> case get(worker_name) of ThisWorkerName -> ok; _ -> throw({error, "The provided worker name is not this worker"}) @@ -35,49 +57,53 @@ handle_cast({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, Data}, State) -> % Saved messages are of the form: {FromWorkerName, , Data} Message = {FromWorkerName, Data}, add_msg_to_inbox_queue(Message), - io:format("Worker ~p received message from ~p: ~p~n", [ThisWorkerName, FromWorkerName, Data]), %TODO remove - {noreply, State}; + {reply, {ok, "Message received"}, State}; % Token messages are tupe of: {FromWorkerName, Token, Data} -handle_cast({?W2WCOM_TOKEN_CAST_ATOM, FromWorkerName, ThisWorkerName, Token, Data}, State) -> +handle_call({?W2WCOM_TOKEN_CAST_ATOM, FromWorkerName, ThisWorkerName, Token, Data}, _From, State) -> case get(worker_name) of ThisWorkerName -> ok; _ -> throw({error, "The provided worker name is not this worker"}) end, Message = {FromWorkerName, Token, Data}, add_msg_to_inbox_queue(Message), - io:format("Worker ~p received token message from ~p: ~p~n", [ThisWorkerName, FromWorkerName, Data]), %TODO remove - {noreply, State}; + {reply, {ok, "Message received"}, State}; -handle_cast(_Msg, State) -> - {noreply, State}. -handle_call(_Call, _From, State) -> - {noreply, State}. +handle_call({is_inbox_empty}, _From, State) -> + W2WEts = get(w2w_ets), + InboxQueue = ets:lookup_element(W2WEts, inbox_queue, ?ETS_KEYVAL_VAL_IDX), + IsInboxEmpty = queue:len(InboxQueue) == 0, + {reply, {ok, IsInboxEmpty}, State}; -get_all_messages() -> +handle_call({get_inbox_queue}, _From, State) -> W2WEts = get(w2w_ets), - {_, InboxQueue} = ets:lookup(W2WEts, inbox_queue), NewEmptyQueue = queue:new(), - ets:update_element(W2WEts, inbox_queue, {inbox_queue, NewEmptyQueue}), + InboxQueue = ets:lookup_element(W2WEts, inbox_queue, ?ETS_KEYVAL_VAL_IDX), + ets:update_element(W2WEts, inbox_queue, {?ETS_KEYVAL_VAL_IDX, NewEmptyQueue}), + {reply, {ok, InboxQueue}, State}; + +handle_call({get_client_pid}, _From, State) -> + {reply, {ok, get(client_statem_pid)}, State}; + +handle_call(_Call, _From, State) -> + {noreply, State}. + +get_all_messages(W2WPid) -> % Returns the InboxQueue and flush it + {ok , InboxQueue} = gen_server:call(W2WPid, {get_inbox_queue}), InboxQueue. -add_msg_to_inbox_queue(Message) -> +add_msg_to_inbox_queue(Message) -> % Only w2wCom process executes this function W2WEts = get(w2w_ets), - {_, InboxQueue} = ets:lookup(W2WEts, inbox_queue), + InboxQueue = ets:lookup_element(W2WEts, inbox_queue, ?ETS_KEYVAL_VAL_IDX), InboxQueueUpdated = queue:in(Message, InboxQueue), - ets:update_element(W2WEts, inbox_queue, {inbox_queue, InboxQueueUpdated}). - -send_message(FromWorker, TargetWorker, Data) -> - Msg = {?W2WCOM_ATOM, FromWorker, TargetWorker, Data}, - MyClient = get(client_statem_pid), - gen_server:cast(MyClient, Msg). + ets:update_element(W2WEts, inbox_queue, {?ETS_KEYVAL_VAL_IDX, InboxQueueUpdated}). -is_inbox_empty() -> - W2WEts = get(w2w_ets), - {_ , InboxQueue} = ets:lookup(W2WEts, inbox_queue), - queue:len(InboxQueue) == 0. +send_message(W2WPid, FromWorker, TargetWorker, Data) -> + Msg = {?W2WCOM_ATOM, FromWorker, TargetWorker, Data}, + {ok, MyClient} = gen_server:call(W2WPid, {get_client_pid}), + gen_statem:cast(MyClient, Msg). timeout_throw(Timeout) -> @@ -87,14 +113,14 @@ timeout_throw(Timeout) -> after Timeout -> throw("Timeout reached") end. -sync_inbox() -> +sync_inbox(W2WPid) -> TimeoutPID = spawn(fun() -> timeout_throw(?SYNC_INBOX_TIMEOUT) end), - sync_inbox(TimeoutPID). + sync_inbox(TimeoutPID , W2WPid). -sync_inbox(TimeoutPID) -> - timer:sleep(10), % 10 ms - IsInboxEmpty = is_inbox_empty(), +sync_inbox(TimeoutPID, W2WPid) -> + timer:sleep(?DEFAULT_SYNC_INBOX_BUSY_WAITING_SLEEP), + {ok , IsInboxEmpty} = gen_server:call(W2WPid, {is_inbox_empty}), if - IsInboxEmpty -> sync_inbox(TimeoutPID); + IsInboxEmpty -> sync_inbox(TimeoutPID, W2WPid); true -> TimeoutPID ! stop end. \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 9a530713..94da7180 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -6,6 +6,8 @@ -include("workerDefinitions.hrl"). -include("w2wCom.hrl"). +-import(nerlNIF, [call_to_get_weights/2, call_to_set_weights/2]). + -define(WORKER_FEDERATED_CLIENT_ETS_FIELDS, [my_name, client_pid, server_name, sync_max_count, sync_count]). -define(FEDERATED_CLIENT_ETS_KEY_IN_GENWORKER_ETS, fedrated_client_ets). -define(DEFAULT_SYNC_MAX_COUNT_ARG, 1). @@ -33,21 +35,21 @@ parse_args(Args) -> lists:map(Func, ArgsList). % Returns list of tuples [{Key, Val}, ...] sync_max_count_init(FedClientEts , ArgsList) -> - case lists:keyfind("sync_max_count", 1, ArgsList) of - false -> Val = ?DEFAULT_SYNC_MAX_COUNT_ARG; - {_, Val} -> list_to_integer(Val) + case lists:keyfind("SyncMaxCount", 1, ArgsList) of + false -> ValInt = ?DEFAULT_SYNC_MAX_COUNT_ARG; + {_, Val} -> ValInt = list_to_integer(Val) % Val is a list (string) in the json so needs to be converted end, - ets:insert(FedClientEts, {sync_max_count, Val}). + ets:insert(FedClientEts, {sync_max_count, ValInt}). -%% handshake with workers / server +%% handshake with workers / server at the end of init init({GenWorkerEts, WorkerData}) -> % create an ets for this client and save it to generic worker ets - FedratedClientEts = ets:new(federated_client,[set]), + FedratedClientEts = ets:new(federated_client,[set, public]), ets:insert(GenWorkerEts, {federated_client_ets, FedratedClientEts}), - io:format("@FedClient: ~p~n",[WorkerData]), {MyName, Args, Token} = WorkerData, ArgsList = parse_args(Args), sync_max_count_init(FedratedClientEts, ArgsList), + W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), % create fields in this ets ets:insert(FedratedClientEts, {my_token, Token}), ets:insert(FedratedClientEts, {my_name, MyName}), @@ -55,38 +57,68 @@ init({GenWorkerEts, WorkerData}) -> ets:insert(FedratedClientEts, {sync_count, 0}), ets:insert(FedratedClientEts, {server_update, false}), ets:insert(FedratedClientEts, {handshake_done, false}), + ets:insert(FedratedClientEts, {handshake_wait, false}), + ets:insert(FedratedClientEts, {w2wcom_pid, W2WPid}), spawn(fun() -> handshake(FedratedClientEts) end). -handshake(EtsRef) -> - w2wCom:sync_inbox(), - InboxQueue = w2wCom:get_all_messages(), - MessagesList = queue:to_list(InboxQueue), - Func = - fun({?W2WCOM_ATOM, FromServer, MyName, {handshake, ServerToken}}) -> - ets:insert(EtsRef, {server_name, FromServer}), - ets:insert(EtsRef, {token , ServerToken}), - MyToken = ets:lookup_element(EtsRef, my_token, ?ETS_KEYVAL_VAL_IDX), - if - ServerToken =/= MyToken -> not_my_server; - true -> w2wCom:send_message(MyName, FromServer, {handshake, MyToken}) , - ets:update_element(EtsRef, handshake_done, true) - end - end, - lists:foreach(Func, MessagesList), - % Check if handshake is done - HandshakeDone = ets:lookup_element(EtsRef, handshake_done, ?ETS_KEYVAL_VAL_IDX), - if HandshakeDone -> ok; - true -> handshake(EtsRef) - end. +handshake(FedClientEts) -> + W2WPid = ets:lookup_element(FedClientEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:sync_inbox(W2WPid), + InboxQueue = w2wCom:get_all_messages(W2WPid), + MessagesList = queue:to_list(InboxQueue), + Func = + fun({FedServer , {handshake, ServerToken}}) -> + ets:insert(FedClientEts, {server_name, FedServer}), + ets:insert(FedClientEts, {my_token , ServerToken}), + MyToken = ets:lookup_element(FedClientEts, my_token, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(FedClientEts, my_name, ?ETS_KEYVAL_VAL_IDX), + if + ServerToken =/= MyToken -> not_my_server; + true -> w2wCom:send_message(W2WPid, MyName, FedServer, {handshake, MyToken}), + ets:update_element(FedClientEts, handshake_wait, {?ETS_KEYVAL_VAL_IDX, true}) + end + end, + lists:foreach(Func, MessagesList). pre_idle({_GenWorkerEts, _WorkerData}) -> ok. -post_idle({_GenWorkerEts, _WorkerData}) -> ok. +post_idle({GenWorkerEts, _WorkerData}) -> + FedClientEts = get_this_client_ets(GenWorkerEts), + W2WPid = ets:lookup_element(FedClientEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + Token = ets:lookup_element(FedClientEts, my_token, ?ETS_KEYVAL_VAL_IDX), + HandshakeWait = ets:lookup_element(FedClientEts, handshake_wait, ?ETS_KEYVAL_VAL_IDX), + case HandshakeWait of + true -> HandshakeDone = ets:lookup_element(FedClientEts, handshake_done, ?ETS_KEYVAL_VAL_IDX), + case HandshakeDone of + false -> + w2wCom:sync_inbox(W2WPid), + InboxQueue = w2wCom:get_all_messages(W2WPid), + ets:update_element(FedClientEts, handshake_done, {?ETS_KEYVAL_VAL_IDX, true}), + [{_FedServer, {handshake_done, Token}}] = queue:to_list(InboxQueue); + true -> ok + end; + false -> post_idle({GenWorkerEts, _WorkerData}) % busy waiting until handshake is done + end. + + % After SyncMaxCount , sync_inbox to get the updated model from FedServer -pre_train({GenWorkerEts, NerlTensorWeights}) -> - ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - nerlNIF:call_to_set_weights(ModelID, NerlTensorWeights). +pre_train({GenWorkerEts, _NerlTensorWeights}) -> + ThisEts = get_this_client_ets(GenWorkerEts), + SyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_count, ?ETS_KEYVAL_VAL_IDX), + MaxSyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_max_count, ?ETS_KEYVAL_VAL_IDX), + if SyncCount == MaxSyncCount -> + W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:sync_inbox(W2WPid), % waiting for server to average the weights and send it + WorkerName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), + InboxQueue = w2wCom:get_all_messages(W2WPid), + [UpdateWeightsMsg] = queue:to_list(InboxQueue), + {_FedServer , {update_weights, UpdatedWeights}} = UpdateWeightsMsg, + ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), + nerlNIF:call_to_set_weights(ModelID, UpdatedWeights), + ets:update_element(ThisEts, sync_count, {?ETS_KEYVAL_VAL_IDX , 0}); + true -> ets:update_counter(ThisEts, sync_count, 1) + end. %% every countLimit batches, send updated weights post_train({GenWorkerEts, _WorkerData}) -> @@ -95,14 +127,12 @@ post_train({GenWorkerEts, _WorkerData}) -> MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), if SyncCount == MaxSyncCount -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - Weights = nerlNIF:call_to_get_weights(GenWorkerEts, ModelID), + Weights = nerlNIF:call_to_get_weights(ModelID), ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), - io:format("@post_train: Worker ~p updates federated server ~p~n",[MyName , ServerName]), - w2wCom:send_message(MyName, ServerName , Weights), %% ****** NEW - TEST NEEDED ****** - ets:update_element(ThisEts, sync_count, {?ETS_KEYVAL_VAL_IDX , 0}); - true -> - ets:update_counter(ThisEts, sync_count, 1) + W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:send_message(W2WPid, MyName, ServerName , {post_train_update, Weights}); %% ****** NEW - TEST NEEDED ****** + true -> ok end. %% nothing? diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 0eba6fa4..ca6e0459 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -5,7 +5,7 @@ -include("workerDefinitions.hrl"). -include("w2wCom.hrl"). --import(nerlNIF,[nerltensor_scalar_multiplication_nif/3]). +-import(nerlNIF,[nerltensor_scalar_multiplication_nif/3, call_to_get_weights/1, call_to_set_weights/2]). -import(nerlTensor,[sum_nerltensors_lists/2]). -import(w2wCom,[send_message/3, get_all_messages/0, is_inbox_empty/0]). @@ -26,10 +26,11 @@ controller(FuncName, {GenWorkerEts, WorkerData}) -> pre_train -> pre_train({GenWorkerEts, WorkerData}); post_train -> post_train({GenWorkerEts, WorkerData}); pre_predict -> pre_predict({GenWorkerEts, WorkerData}); - post_predict -> post_predict({GenWorkerEts, WorkerData}); - update -> update({GenWorkerEts, WorkerData}) + post_predict -> post_predict({GenWorkerEts, WorkerData}) end. + +% After adding put(Ets) to init this function is not needed get_this_server_ets(GenWorkerEts) -> ets:lookup_element(GenWorkerEts, federated_server_ets, ?ETS_KEYVAL_VAL_IDX). @@ -41,25 +42,29 @@ parse_args(Args) -> end, lists:map(Func, ArgsList). % Returns list of tuples [{Key, Val}, ...] -sync_max_count_init(FedServerEts , ArgsList) -> - case lists:keyfind("sync_max_count", 1, ArgsList) of - false -> Val = ?DEFAULT_SYNC_MAX_COUNT_ARG; - {_, Val} -> list_to_integer(Val) +sync_max_count_init(FedClientEts , ArgsList) -> + case lists:keyfind("SyncMaxCount", 1, ArgsList) of + false -> ValInt = ?DEFAULT_SYNC_MAX_COUNT_ARG; + {_, Val} -> ValInt = list_to_integer(Val) end, - ets:insert(FedServerEts, {sync_max_count, Val}). + ets:insert(FedClientEts, {sync_max_count, ValInt}). %% handshake with workers / server init({GenWorkerEts, WorkerData}) -> FederatedServerEts = ets:new(federated_server,[set]), {MyName, Args, Token , WorkersList} = WorkerData, + BroadcastWorkers = WorkersList -- [MyName], ArgsList = parse_args(Args), sync_max_count_init(FederatedServerEts, ArgsList), ets:insert(GenWorkerEts, {federated_server_ets, FederatedServerEts}), - ets:insert(FederatedServerEts, {workers, WorkersList}), + W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + ets:insert(FederatedServerEts, {w2wcom_pid, W2WPid}), + ets:insert(FederatedServerEts, {broadcast_workers_list, BroadcastWorkers}), ets:insert(FederatedServerEts, {fed_clients, []}), ets:insert(FederatedServerEts, {sync_count, 0}), ets:insert(FederatedServerEts, {my_name, MyName}), ets:insert(FederatedServerEts, {token , Token}), + ets:insert(FederatedServerEts, {weights_list, []}), put(fed_server_ets, FederatedServerEts). @@ -67,104 +72,91 @@ init({GenWorkerEts, WorkerData}) -> pre_idle({_GenWorkerEts, _WorkerName}) -> ok. -post_idle({_GenWorkerEts, _WorkerName}) -> - % Extract all workers in nerlnet network - % Send handshake message to all workers - % Wait for all workers to send handshake message back - FedServerEts = get(fed_server_ets), - FedServerName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), - WorkersList = ets:lookup_element(FedServerEts, workers, ?ETS_KEYVAL_VAL_IDX), - MyToken = ets:lookup_element(FedServerEts, token, ?ETS_KEYVAL_VAL_IDX), - Func = fun(FedClient) -> - w2wCom:send_message(FedClient, FedServerName, {handshake, MyToken}) - end, - lists:foreach(Func, WorkersList), - timer:sleep(?HANDSHAKE_TIMEOUT), - IsEmpty = w2wCom:is_inbox_empty(), - if IsEmpty == true -> - throw("Handshake failed, none of the workers responded in time"); - true -> ok - end, - InboxQueue = w2wCom:get_all_messages(), - MessagesList = queue:to_list(InboxQueue), - MsgFunc = - fun({?W2WCOM_ATOM, FromWorker, _MyName, {handshake, _WorkerToken}}) -> - FedWorkers = ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX), - ets:update_element(FedServerEts, fed_clients, {?ETS_KEYVAL_VAL_IDX , [FromWorker] ++ FedWorkers}) - end, - lists:foreach(MsgFunc, MessagesList), - io:format("Handshake done~n"). +% Extract all workers in nerlnet network +% Send handshake message to all workers +% Wait for all workers to send handshake message back +post_idle({GenWorkerEts, _WorkerName}) -> + HandshakeDone = ets:lookup_element(GenWorkerEts, handshake_done, ?ETS_KEYVAL_VAL_IDX), + case HandshakeDone of + false -> + FedServerEts = get(fed_server_ets), + FedServerName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + WorkersList = ets:lookup_element(FedServerEts, broadcast_workers_list, ?ETS_KEYVAL_VAL_IDX), + W2WPid = ets:lookup_element(FedServerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + MyToken = ets:lookup_element(FedServerEts, token, ?ETS_KEYVAL_VAL_IDX), + Func = fun(FedClient) -> + w2wCom:send_message(W2WPid, FedServerName, FedClient, {handshake, MyToken}) + end, + lists:foreach(Func, WorkersList), + timer:sleep(?HANDSHAKE_TIMEOUT), + InboxQueue = w2wCom:get_all_messages(W2WPid), + IsEmpty = queue:len(InboxQueue) == 0, + if IsEmpty == true -> + throw("Handshake failed, none of the workers responded in time"); + true -> ok + end, + MessagesList = queue:to_list(InboxQueue), + MsgFunc = + fun({FedClient, {handshake, _Token}}) -> + FedClients = ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX), + ets:update_element(FedServerEts, fed_clients, {?ETS_KEYVAL_VAL_IDX , [FedClient] ++ FedClients}), + w2wCom:send_message(W2WPid, FedServerName, FedClient, {handshake_done, MyToken}) + end, + lists:foreach(MsgFunc, MessagesList), + ets:update_element(GenWorkerEts, handshake_done, {?ETS_KEYVAL_VAL_IDX, true}), + io:format("**************** @FedServer Handshake done ****************~n"); + true -> ok + end. %% Send updated weights if set pre_train({_GenWorkerEts, _WorkerData}) -> ok. -%% calculate avg of weights when set -post_train({GenWorkerEts, _WorkerData}) -> +% 1. get weights from all workers +% 2. average them +% 3. set new weights to model +% 4. send new weights to all workers +post_train({GenWorkerEts, WorkerData}) when length(WorkerData) == 0 -> % WorkerData = [] ThisEts = get_this_server_ets(GenWorkerEts), - SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), - if SyncCount == 0 -> - ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - Weights = nerlNIF:call_to_get_weights(GenWorkerEts, ModelID), - ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), - gen_statem:cast(ClientPID, {update, {MyName, MyName, Weights}}), - MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), - ets:update_counter(ThisEts, sync_count, MaxSyncCount), - _ToUpdate = true; - true -> - ets:update_counter(ThisEts, sync_count, -1), - _ToUpdate = false + FedServerEts = get(fed_server_ets), + NumOfWorkers = length(ets:lookup_element(ThisEts, broadcast_workers_list, ?ETS_KEYVAL_VAL_IDX)), + W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + InboxQueue = w2wCom:get_all_messages(W2WPid), + MessagesList = queue:to_list(InboxQueue), + ReceivedWeights = [WorkersWeights || {_WorkerName, {WorkersWeights, _BinaryType}} <- MessagesList], + CurrWorkersWeightsList = ets:lookup_element(FedServerEts, weights_list, ?ETS_KEYVAL_VAL_IDX), + TotalWorkersWeights = CurrWorkersWeightsList ++ ReceivedWeights, + case length(TotalWorkersWeights) == NumOfWorkers of + true -> + ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), + {CurrentModelWeights, BinaryType} = nerlNIF:call_to_get_weights(ModelID), + FedServerName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + AllWorkersWeightsList = TotalWorkersWeights ++ [CurrentModelWeights], + AvgWeightsNerlTensor = generate_avg_weights(AllWorkersWeightsList, BinaryType), + nerlNIF:call_to_set_weights(ModelID, AvgWeightsNerlTensor), %% update self weights to new model + Func = fun(FedClient) -> + FedServerName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), + W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:send_message(W2WPid, FedServerName, FedClient, {update_weights, AvgWeightsNerlTensor}) + end, + WorkersList = ets:lookup_element(ThisEts, broadcast_workers_list, ?ETS_KEYVAL_VAL_IDX), + lists:foreach(Func, WorkersList), + ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, []}); + false -> ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, TotalWorkersWeights}) end. - % ThisEts = get_this_server_ets(GenWorkerEts), - % Weights = generate_avg_weights(ThisEts), - - % gen_statem:cast({update, Weights}). %TODO complete send to all workers in lists:foreach - + %% nothing? pre_predict({_GenWorkerEts, _WorkerData}) -> ok. %% nothing? post_predict({_GenWorkerEts, _WorkerData}) -> ok. -%% FedServer keeps an ets list of tuples: {WorkerName, worker, WeightsAndBiasNerlTensor} -%% in update get weights of clients, if got from all => avg and send back -update({GenWorkerEts, WorkerData}) -> - {WorkerName, _Me, NerlTensorWeights} = WorkerData, - ThisEts = get_this_server_ets(GenWorkerEts), - %% update weights in ets - ets:insert(ThisEts, {WorkerName, worker, NerlTensorWeights}), - - %% check if there are queued messages, and treat them accordingly - MessageQueue = ets:lookup_element(GenWorkerEts, controller_message_q, ?ETS_KEYVAL_VAL_IDX), - % io:format("MessageQ=~p~n",[MessageQueue]), - [ets:insert(ThisEts, {WorkerName, worker, NerlTensorWeights}) || {Action, WorkerName, To, NerlTensorWeights} <- MessageQueue, Action == update], - % reset q - ets:delete(GenWorkerEts, controller_message_q), - ets:insert(GenWorkerEts, {controller_message_q, []}), - - %% check if got all weights of workers - WorkersList = ets:lookup_element(ThisEts, workers, ?ETS_KEYVAL_VAL_IDX), - GotWorkers = [ element(?ETS_WID_IDX, Attr) || Attr <- ets:tab2list(ThisEts), element(?ETS_TYPE_IDX, Attr) == worker], - % io:format("My workers=~p, have vectors from=~p~n",[WorkersList,GotWorkers]), - WaitingFor = WorkersList -- GotWorkers, - - if WaitingFor == [] -> - AvgWeightsNerlTensor = generate_avg_weights(ThisEts), - % io:format("AvgWeights = ~p~n",[AvgWeightsNerlTensor]), - ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - nerlNIF:call_to_set_weights(ModelID, AvgWeightsNerlTensor), %% update self weights to new model - [ets:delete(ThisEts, OldWorkerName) || OldWorkerName <- WorkersList ],%% delete old tensors for next aggregation phase - ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - gen_statem:cast(ClientPID, {custom_worker_message, WorkersList, AvgWeightsNerlTensor}), - false; - true -> true end. %% return StillUpdate = true - - -generate_avg_weights(FedEts) -> - BinaryType = ets:lookup_element(FedEts, nerltensor_type, ?ETS_NERLTENSOR_TYPE_IDX), - ListOfWorkersNerlTensors = [ element(?TENSOR_DATA_IDX, element(?ETS_WEIGHTS_AND_BIAS_NERLTENSOR_IDX, Attr)) || Attr <- ets:tab2list(FedEts), element(?ETS_TYPE_IDX, Attr) == worker], - % io:format("Tensors to sum = ~p~n",[ListOfWorkersNerlTensors]), - NerlTensors = length(ListOfWorkersNerlTensors), - [FinalSumNerlTensor] = nerlTensor:sum_nerltensors_lists(ListOfWorkersNerlTensors, BinaryType), + +generate_avg_weights(AllWorkersWeightsList, BinaryType) -> + % io:format("AllWorkersWeightsList = ~p~n",[AllWorkersWeightsList]), + NumNerlTensors = length(AllWorkersWeightsList), + if + NumNerlTensors > 1 -> [FinalSumNerlTensor] = nerlTensor:sum_nerltensors_lists(AllWorkersWeightsList, BinaryType); + true -> FinalSumNerlTensor = hd(AllWorkersWeightsList) + end, % io:format("Summed = ~p~n",[FinalSumNerlTensor]), - nerlNIF:nerltensor_scalar_multiplication_nif(FinalSumNerlTensor, BinaryType, 1.0/NerlTensors). + nerlNIF:nerltensor_scalar_multiplication_nif(FinalSumNerlTensor, BinaryType, 1.0/NumNerlTensors). diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 3816a01c..aebc7784 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -22,7 +22,7 @@ -export([init/1, format_status/2, state_name/3, handle_event/4, terminate/3, code_change/4, callback_mode/0]). %% States functions --export([idle/3, train/3, predict/3, wait/3, update/3]). +-export([idle/3, train/3, predict/3, wait/3]). %% ackClient :: need to tell mainserver that worker is safe and going to new state after wait state @@ -36,8 +36,7 @@ start_link(ARGS) -> %{ok,Pid} = gen_statem:start_link({local, element(1, ARGS)}, ?MODULE, ARGS, []), %% name this machine by unique name {ok,Pid} = gen_statem:start_link(?MODULE, ARGS, []), - W2W_Pid = get(w2wcom_pid), - {Pid , W2W_Pid}. + Pid. %%%=================================================================== %%% gen_statem callbacks @@ -47,16 +46,17 @@ start_link(ARGS) -> %% @doc Whenever a gen_statem is started using gen_statem:start/[3,4] or %% gen_statem:start_link/[3,4], this function is called by the new process to initialize. %% distributedBehaviorFunc is the special behavior of the worker regrading the distributed system e.g. federated client/server -init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , ClientPid , WorkerStatsEts}) -> +init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , ClientPid , WorkerStatsEts , W2WPid}) -> nerl_tools:setup_logger(?MODULE), {ModelID , ModelType , ModelArgs , LayersSizes, LayersTypes, LayersFunctionalityCodes, LearningRate , Epochs, OptimizerType, OptimizerArgs , LossMethod , DistributedSystemType , DistributedSystemArgs} = WorkerArgs, - GenWorkerEts = ets:new(generic_worker,[set]), + GenWorkerEts = ets:new(generic_worker,[set, public]), put(generic_worker_ets, GenWorkerEts), put(client_pid, ClientPid), put(worker_stats_ets , WorkerStatsEts), SourceBatchesEts = ets:new(source_batches,[set]), put(source_batches_ets, SourceBatchesEts), + ets:insert(GenWorkerEts,{w2wcom_pid, W2WPid}), ets:insert(GenWorkerEts,{worker_name, WorkerName}), ets:insert(GenWorkerEts,{model_id, ModelID}), ets:insert(GenWorkerEts,{model_type, ModelType}), @@ -71,10 +71,10 @@ init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData ets:insert(GenWorkerEts,{optimizer_args, OptimizerArgs}), ets:insert(GenWorkerEts,{distributed_system_args, DistributedSystemArgs}), ets:insert(GenWorkerEts,{distributed_system_type, DistributedSystemType}), - ets:insert(GenWorkerEts,{controller_message_q, []}), %% empty Queue TODO Deprecated + ets:insert(GenWorkerEts,{controller_message_q, []}), %% TODO Deprecated + ets:insert(GenWorkerEts,{handshake_done, false}), % Worker to Worker communication module - this is a gen_server - W2wComPid = w2wCom:start_link({WorkerName, ClientPid}), - put(w2wcom_pid, W2wComPid), + Res = nerlNIF:new_nerlworker_nif(ModelID , ModelType, ModelArgs, LayersSizes, LayersTypes, LayersFunctionalityCodes, LearningRate, Epochs, OptimizerType, OptimizerArgs, LossMethod , DistributedSystemType , DistributedSystemArgs), @@ -86,7 +86,7 @@ init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData ?LOG_ERROR("Failed to create worker ~p\n",[WorkerName]), exit(nif_failed_to_create) end, - + DistributedBehaviorFunc(pre_idle,{GenWorkerEts, DistributedWorkerData}), {ok, idle, #workerGeneric_state{myName = WorkerName , modelID = ModelID , distributedBehaviorFunc = DistributedBehaviorFunc , distributedWorkerData = DistributedWorkerData}}. %% @private @@ -139,26 +139,18 @@ code_change(_OldVsn, StateName, State = #workerGeneric_state{}, _Extra) -> %% State idle -%% Event from clientStatem -idle(cast, {pre_idle}, State = #workerGeneric_state{myName = _MyName,distributedBehaviorFunc = DistributedBehaviorFunc}) -> - DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), empty}), - {next_state, idle, State}; - -%% Event from clientStatem -idle(cast, {post_idle, From}, State = #workerGeneric_state{myName = _MyName,distributedBehaviorFunc = DistributedBehaviorFunc}) -> - DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), From}), - {next_state, idle, State}; - % Go from idle to train -idle(cast, {training}, State = #workerGeneric_state{myName = MyName}) -> +idle(cast, {training}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> worker_controller_empty_message_queue(), + DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), train}), update_client_avilable_worker(MyName), {next_state, train, State#workerGeneric_state{lastPhase = train}}; % Go from idle to predict -idle(cast, {predict}, State = #workerGeneric_state{myName = MyName}) -> +idle(cast, {predict}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> worker_controller_empty_message_queue(), update_client_avilable_worker(MyName), + DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), predict}), {next_state, predict, State#workerGeneric_state{lastPhase = predict}}; idle(cast, _Param, State) -> @@ -173,10 +165,10 @@ wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneri {next_state, NextState, State}; -wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> +wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc}) -> BatchTimeStamp = erlang:system_time(nanosecond), gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , BatchID , BatchTimeStamp}), - DistributedBehaviorFunc(post_train, {get(generic_worker_ets),DistributedWorkerData}), %% Change to W2WComm + DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients {next_state, NextState, State}; wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> @@ -238,11 +230,13 @@ train(cast, {sample, SourceName ,BatchID ,{NerlTensorOfSamples, NerlTensorType}} %% TODO: implement send model and weights by demand (Tensor / XML) train(cast, {set_weights,Ret_weights_list}, State = #workerGeneric_state{modelID = ModelId}) -> %% Set weights - %io:format("####sending new weights to workers####~n"), nerlNIF:call_to_set_weights(ModelId, Ret_weights_list), %% TODO wrong usage %logger:notice("####end set weights train####~n"), {next_state, train, State}; +train(cast, {post_train_update} ,State = #workerGeneric_state{distributedBehaviorFunc = DistributedBehaviorFunc}) -> + DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), + {next_state, train, State}; train(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> update_client_avilable_worker(MyName), @@ -267,7 +261,6 @@ predict(cast, {sample , SourceName , BatchID , {PredictBatchTensor, Type}}, Stat DistributedBehaviorFunc(pre_predict, {get(generic_worker_ets),DistributedWorkerData}), WorkersStatsEts = get(worker_stats_ets), stats:increment_by_value(WorkersStatsEts , batches_received_predict , 1), - %% io:format("Pred Tensor: ~p~n",[nerlNIF:nerltensor_conversion({PredictBatchTensor , Type} , nerlNIF:erl_type_conversion(Type))]), _Pid = spawn(fun()-> nerlNIF:call_to_predict(ModelId , {PredictBatchTensor, Type} , CurrPID , BatchID, SourceName) end), {next_state, wait, State#workerGeneric_state{nextState = predict , currentBatchID = BatchID}}; diff --git a/src_erl/NerlnetApp/src/Client/clientStateHandler.erl b/src_erl/NerlnetApp/src/Client/clientStateHandler.erl index e9c6d09d..07141d54 100644 --- a/src_erl/NerlnetApp/src/Client/clientStateHandler.erl +++ b/src_erl/NerlnetApp/src/Client/clientStateHandler.erl @@ -17,7 +17,6 @@ init(Req0, [Action,Client_StateM_Pid]) -> {ok,Body,_} = cowboy_req:read_body(Req0), -%% io:format("client state_handler got body:~p~n",[Body]), case Action of worker_to_worker_msg -> {worker_to_worker_msg , From , To , Data} = binary_to_term(Body), gen_statem:cast(Client_StateM_Pid,{worker_to_worker_msg , From , To , Data}); diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 258fb833..7d75845b 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -72,7 +72,7 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh inets:start(), ?LOG_INFO("Client ~p is connected to: ~p~n",[MyName, [digraph:vertex(NerlnetGraph,Vertex) || Vertex <- digraph:out_neighbours(NerlnetGraph,MyName)]]), % nerl_tools:start_connection([digraph:vertex(NerlnetGraph,Vertex) || Vertex <- digraph:out_neighbours(NerlnetGraph,MyName)]), - EtsRef = ets:new(client_data, [set]), %% client_data is responsible for functional attributes + EtsRef = ets:new(client_data, [set, public]), %% client_data is responsible for functional attributes EtsStats = ets:new(ets_stats, [set]), %% ets_stats is responsible for holding all the ets stats (client + workers) ClientStatsEts = stats:generate_stats_ets(), %% client stats ets inside ets_stats ets:insert(EtsStats, {MyName, ClientStatsEts}), @@ -84,13 +84,14 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh MyWorkersToShaMap = maps:filter(fun(Worker , _SHA) -> lists:member(Worker , ClientWorkers) end , WorkerShaMap), ets:insert(EtsRef, {workers_to_sha_map, MyWorkersToShaMap}), ets:insert(EtsRef, {sha_to_models_map , ShaToModelArgsMap}), + ets:insert(EtsRef, {w2wcom_pids, #{}}), {MyRouterHost,MyRouterPort} = nerl_tools:getShortPath(MyName,?MAIN_SERVER_ATOM, NerlnetGraph), ets:insert(EtsRef, {my_router,{MyRouterHost,MyRouterPort}}), - clientWorkersFunctions:create_workers(MyName , EtsRef , ShaToModelArgsMap , EtsStats), %% send pre_idle signal to workers WorkersNames = clientWorkersFunctions:get_workers_names(EtsRef), - [gen_statem:cast(clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), {pre_idle}) || WorkerName <- WorkersNames], + Pids = [clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName) || WorkerName <- WorkersNames], + [gen_statem:cast(WorkerPid, {pre_idle}) || WorkerPid <- Pids], % update dictionary WorkersEts = ets:lookup_element(EtsRef , workers_ets , ?DATA_IDX), @@ -125,6 +126,13 @@ waitforWorkers(cast, In = {stateChange,WorkerName}, State = #client_statem_state _-> {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = NewWaitforWorkers}} end; +waitforWorkers(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = #client_statem_state{etsRef = EtsRef}) -> + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data), + {keep_state, State}; + waitforWorkers(cast, In = {NewState}, State = #client_statem_state{myName = _MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -155,12 +163,10 @@ idle(cast, _In = {statistics}, State = #client_statem_state{ myName = MyName, et EtsStats = get(ets_stats), ClientStatsEts = get(client_stats_ets), ClientStatsEncStr = stats:encode_ets_to_http_bin_str(ClientStatsEts), - %ClientStatsToSend = atom_to_list(MyName) ++ ?API_SERVER_WITHIN_ENTITY_SEPERATOR ++ ClientStatsEncStr ++ ?API_SERVER_ENTITY_SEPERATOR, stats:increment_messages_received(ClientStatsEts), ListStatsEts = ets:tab2list(EtsStats) -- [{MyName , ClientStatsEts}], WorkersStatsEncStr = create_encoded_stats_str(ListStatsEts), DataToSend = ClientStatsEncStr ++ WorkersStatsEncStr, - %% io:format("DataToSend: ~p~n",[DataToSend]), StatsBody = {MyName , DataToSend}, {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), nerl_tools:http_router_request(RouterHost, RouterPort, [?MAIN_SERVER_ATOM], atom_to_list(statistics), StatsBody), @@ -170,7 +176,8 @@ idle(cast, _In = {statistics}, State = #client_statem_state{ myName = MyName, et idle(cast, In = {training}, State = #client_statem_state{myName = _MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), - stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), MessageToCast = {training}, + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + MessageToCast = {training}, cast_message_to_workers(EtsRef, MessageToCast), {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = clientWorkersFunctions:get_workers_names(EtsRef), nextState = training}}; @@ -210,25 +217,6 @@ training(cast, MessageIn = {update, {From, To, Data}}, State = #client_statem_st {keep_state, State}; -%% This is a generic way to move data from worker to worker -%% TODO fix variables names to make it more generic -%% federated server sends AvgWeights to workers -% training(cast, InMessage = {custom_worker_message, WorkersList, WeightsTensor}, State = #client_statem_state{etsRef = EtsRef}) -> -% ClientStatsEts = get(client_stats_ets), -% stats:increment_messages_received(ClientStatsEts), -% stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(InMessage)), -% Func = fun(WorkerName) -> -% DestClient = maps:get(WorkerName, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), -% MessageBody = term_to_binary({DestClient, update, {_FedServer = "server", WorkerName, WeightsTensor}}), % TODO - fix client should not be aware of the data of custom worker message - -% {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), -% nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(custom_worker_message), MessageBody), -% stats:increment_messages_sent(ClientStatsEts), -% stats:increment_bytes_sent(ClientStatsEts , nerl_tools:calculate_size(MessageBody)) -% end, -% lists:foreach(Func, WorkersList), % can be optimized with broadcast instead of unicast -% {keep_state, State}; - training(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = #client_statem_state{etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -409,18 +397,21 @@ create_encoded_stats_str(ListStatsEts) -> handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data) -> ClientStatsEts = get(client_stats_ets), - WorkerOfThisClient = ets:member(EtsRef, ToWorker), - if WorkerOfThisClient -> - % Extract W2WPID from Ets - TargetWorkerW2WPID = ets:lookup_element(get(workers_ets), ToWorker, ?W2W_PID_IDX), - gen_statem:cast(TargetWorkerW2WPID,{worker_to_worker_msg, FromWorker, ToWorker, Data}), - stats:increment_messages_sent(ClientStatsEts); - true -> - %% Send to the correct client - DestClient = maps:get(ToWorker, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), - MessageBody = {worker_to_worker_msg, FromWorker, ToWorker, Data}, - {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), - nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(worker_to_worker_msg), term_to_binary(MessageBody)), - stats:increment_messages_sent(ClientStatsEts), - stats:increment_bytes_sent(ClientStatsEts , nerl_tools:calculate_size(MessageBody)) + WorkersOfThisClient = ets:lookup_element(EtsRef, workersNames, ?DATA_IDX), + WorkerOfThisClient = lists:member(ToWorker, WorkersOfThisClient), + case WorkerOfThisClient of + true -> + % Extract W2WPID from Ets + W2WPidsMap = ets:lookup_element(EtsRef, w2wcom_pids, ?DATA_IDX), + TargetWorkerW2WPID = maps:get(ToWorker, W2WPidsMap), + {ok, _Reply} = gen_server:call(TargetWorkerW2WPID, {worker_to_worker_msg, FromWorker, ToWorker, Data}), + stats:increment_messages_sent(ClientStatsEts); + _ -> + %% Send to the correct client + DestClient = maps:get(ToWorker, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), + MessageBody = {worker_to_worker_msg, FromWorker, ToWorker, Data}, + {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), + nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(worker_to_worker_msg), MessageBody), + stats:increment_messages_sent(ClientStatsEts), + stats:increment_bytes_sent(ClientStatsEts , nerl_tools:calculate_size(MessageBody)) end. \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl b/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl index c7d0038d..d859fea0 100644 --- a/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl +++ b/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl @@ -46,12 +46,17 @@ create_workers(ClientName, ClientEtsRef , ShaToModelArgsMap , EtsStats) -> % TODO add documentation about this case of % move this case to module called client_controller {DistributedBehaviorFunc , DistributedWorkerData} = get_distributed_worker_behavior(ClientEtsRef, DistributedSystemType , WorkerName , DistributedSystemArgs , DistributedSystemToken), + W2wComPid = w2wCom:start_link({WorkerName, MyClientPid}), % TODO Switch to monitor instead of link WorkerArgs = {ModelID , ModelType , ModelArgs , LayersSizes, LayersTypes, LayersFunctions, LearningRate , Epochs, Optimizer, OptimizerArgs , LossMethod , DistributedSystemType , DistributedSystemArgs}, - {WorkerPid , W2W_Pid} = workerGeneric:start_link({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , MyClientPid , WorkerStatsETS}), - ets:insert(WorkersETS, {WorkerName, {WorkerPid, W2W_Pid, WorkerArgs}}), + WorkerPid = workerGeneric:start_link({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , MyClientPid , WorkerStatsETS , W2wComPid}), + gen_server:cast(W2wComPid, {update_gen_worker_pid, WorkerPid}), + ets:insert(WorkersETS, {WorkerName, {WorkerPid, WorkerArgs}}), ets:insert(EtsStats, {WorkerName, WorkerStatsETS}), + W2WPidMap = ets:lookup_element(ClientEtsRef, w2wcom_pids, ?DATA_IDX), + W2WPidMapNew = maps:put(WorkerName, W2wComPid, W2WPidMap), + ets:update_element(ClientEtsRef, w2wcom_pids, {?DATA_IDX, W2WPidMapNew}), WorkerName end, diff --git a/src_erl/NerlnetApp/src/nerlnetApp_app.erl b/src_erl/NerlnetApp/src/nerlnetApp_app.erl index 4c9836ff..bb058afe 100644 --- a/src_erl/NerlnetApp/src/nerlnetApp_app.erl +++ b/src_erl/NerlnetApp/src/nerlnetApp_app.erl @@ -245,7 +245,8 @@ createClientsAndWorkers() -> {"/clientTraining",clientStateHandler, [training,ClientStatemPid]}, {"/clientIdle",clientStateHandler, [idle,ClientStatemPid]}, {"/clientPredict",clientStateHandler, [predict,ClientStatemPid]}, - {"/batch",clientStateHandler, [batch,ClientStatemPid]} + {"/batch",clientStateHandler, [batch,ClientStatemPid]}, + {"/worker_to_worker_msg",clientStateHandler, [worker_to_worker_msg,ClientStatemPid]} ]} ]), init_cowboy_start_clear(Client, {DeviceName, Port},NerlClientDispatch) diff --git a/src_py/apiServer/apiServer.py b/src_py/apiServer/apiServer.py index 89267411..ca83815f 100644 --- a/src_py/apiServer/apiServer.py +++ b/src_py/apiServer/apiServer.py @@ -210,11 +210,11 @@ def list_datasets(self): repo_csv_files = [file for file in files if file.endswith('.csv')] datasets[repo["id"]] = repo_csv_files for i , (repo_name , files) in enumerate(datasets.items()): - LOG_INFO(f'{i}. {repo_name}: {files}') + print(f'{i}. {repo_name}: {files}') except utils._errors.RepositoryNotFoundError: LOG_INFO(f"Failed to find the repository '{repo}'. Check your '{HF_DATA_REPO_PATHS_JSON}' file or network access.") - def download_dataset(self, repo_idx : int | list[int], download_dir_path : str = DEFAULT_NERLNET_TMP_DATA_DIR): + def download_dataset(self, repo_idx : int, download_dir_path : str = DEFAULT_NERLNET_TMP_DATA_DIR): with open(HF_DATA_REPO_PATHS_JSON) as file: repo_ids = json.load(file) try: From 42868f89069fe68443005776d4074cbf96ef7422 Mon Sep 17 00:00:00 2001 From: GuyPErets106 Date: Sun, 19 May 2024 20:02:52 +0000 Subject: [PATCH 08/52] [W2W] Added DC Dist_Fed Json --- .../dc_fed_dist_2d_3c_2s_3r_6w.json | 176 ++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json diff --git a/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json new file mode 100644 index 00000000..a0235157 --- /dev/null +++ b/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json @@ -0,0 +1,176 @@ +{ + "nerlnetSettings": { + "frequency": "100", + "batchSize": "100" + }, + "mainServer": { + "port": "8900", + "args": "" + }, + "apiServer": { + "port": "8901", + "args": "" + }, + "devices": [ + { + "name": "c0vm0", + "ipv4": "10.0.0.5", + "entities": "mainServer,c1,c2,r1,r2,s1,apiServer" + }, + { + "name": "c0vm7", + "ipv4": "10.0.0.12", + "entities": "c3,r3,s2" + } + ], + "routers": [ + { + "name": "r1", + "port": "8905", + "policy": "0" + }, + { + "name": "r2", + "port": "8906", + "policy": "0" + }, + { + "name": "r3", + "port": "8901", + "policy": "0" + } + ], + "sources": [ + { + "name": "s1", + "port": "8904", + "frequency": "200", + "policy": "0", + "epochs": "1", + "type": "0" + }, + { + "name": "s2", + "port": "8902", + "frequency": "100", + "policy": "0", + "epochs": "1", + "type": "0" + } + ], + "clients": [ + { + "name": "c1", + "port": "8902", + "workers": "w1,w2,ws" + }, + { + "name": "c2", + "port": "8903", + "workers": "w3,w4" + }, + { + "name": "c3", + "port": "8900", + "workers": "w5,w6" + } + ], + "workers": [ + { + "name": "w1", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w2", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "ws", + "model_sha": "c081daf49b8332585243b68cb828ebc9b947528601a6852688cea0312b3e3914" + }, + { + "name": "w3", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w4", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w5", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w6", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + } + ], + "model_sha": { + "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896": { + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,2,2,2,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,5", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", + "layers_functions": "1,6,6,11,4", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.01", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "1", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "SyncMaxCount=5", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" + }, + "c081daf49b8332585243b68cb828ebc9b947528601a6852688cea0312b3e3914": { + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,2,2,2,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,5", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", + "layers_functions": "1,6,6,11,4", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.01", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "2", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "SyncMaxCount=5", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" + } + } +} \ No newline at end of file From 1a67f1ae08f136b0143f3c207686d184fdedf55f Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Sun, 19 May 2024 20:28:38 +0000 Subject: [PATCH 09/52] [W2W] Testing Distributed --- .../conn_fed_dist_2d_3c_2s_3r_6w.json | 8 +++ .../exp_fed_dist_2d_3c_2s_3r_6w.json | 54 +++++++++++++++++++ .../exp_fed_synt_1d_2c_2r_1s_4w_1ws.json | 1 + .../onnWorkers/workerFederatedClient.erl | 3 +- .../onnWorkers/workerFederatedServer.erl | 5 +- .../NerlnetApp/src/Source/sourceStatem.erl | 2 +- 6 files changed, 68 insertions(+), 5 deletions(-) create mode 100644 inputJsonsFiles/ConnectionMap/conn_fed_dist_2d_3c_2s_3r_6w.json create mode 100644 inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json diff --git a/inputJsonsFiles/ConnectionMap/conn_fed_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/ConnectionMap/conn_fed_dist_2d_3c_2s_3r_6w.json new file mode 100644 index 00000000..9ffce810 --- /dev/null +++ b/inputJsonsFiles/ConnectionMap/conn_fed_dist_2d_3c_2s_3r_6w.json @@ -0,0 +1,8 @@ +{ + "connectionsMap": + { + "r1":["mainServer", "r2" , "c2" , "r3"], + "r2":["r1", "s1" , "c1" , "r3"], + "r3":["r1", "r2" , "s2" , "c3"] + } +} diff --git a/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json new file mode 100644 index 00000000..754c0016 --- /dev/null +++ b/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json @@ -0,0 +1,54 @@ +{ + "experimentName": "synthetic_3_gausians", + "experimentType": "classification", + "batchSize": 100, + "csvFilePath": "/tmp/nerlnet/data/NerlnetData-master/nerlnet/synthetic_norm/synthetic_full.csv", + "numOfFeatures": "5", + "numOfLabels": "3", + "headersNames": "Norm(0:1),Norm(4:1),Norm(10:3)", + "Phases": + [ + { + "phaseName": "training_phase", + "phaseType": "training", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "0", + "numOfBatches": "100", + "workers": "w1,w2,w3,w4", + "nerltensorType": "float" + }, + { + "sourceName": "s2", + "startingSample": "10000", + "numOfBatches": "100", + "workers": "w5,w6", + "nerltensorType": "float" + } + ] + }, + { + "phaseName": "prediction_phase", + "phaseType": "prediction", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "20000", + "numOfBatches": "100", + "workers": "w1,w2,w3,w4", + "nerltensorType": "float" + }, + { + "sourceName": "s2", + "startingSample": "30000", + "numOfBatches": "100", + "workers": "w5,w6", + "nerltensorType": "float" + } + ] + } +] +} \ No newline at end of file diff --git a/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json b/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json index 34df8a56..7c7c32ff 100644 --- a/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json +++ b/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json @@ -1,5 +1,6 @@ { "experimentName": "synthetic_3_gausians", + "experimentType": "classification", "batchSize": 100, "csvFilePath": "/tmp/nerlnet/data/NerlnetData-master/nerlnet/synthetic_norm/synthetic_full.csv", "numOfFeatures": "5", diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 94da7180..dc3029e5 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -106,11 +106,12 @@ post_idle({GenWorkerEts, _WorkerData}) -> pre_train({GenWorkerEts, _NerlTensorWeights}) -> ThisEts = get_this_client_ets(GenWorkerEts), SyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_count, ?ETS_KEYVAL_VAL_IDX), + WorkerName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), + ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), MaxSyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_max_count, ?ETS_KEYVAL_VAL_IDX), if SyncCount == MaxSyncCount -> W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), w2wCom:sync_inbox(W2WPid), % waiting for server to average the weights and send it - WorkerName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), InboxQueue = w2wCom:get_all_messages(W2WPid), [UpdateWeightsMsg] = queue:to_list(InboxQueue), {_FedServer , {update_weights, UpdatedWeights}} = UpdateWeightsMsg, diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index ca6e0459..54de7ed6 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -103,8 +103,7 @@ post_idle({GenWorkerEts, _WorkerName}) -> w2wCom:send_message(W2WPid, FedServerName, FedClient, {handshake_done, MyToken}) end, lists:foreach(MsgFunc, MessagesList), - ets:update_element(GenWorkerEts, handshake_done, {?ETS_KEYVAL_VAL_IDX, true}), - io:format("**************** @FedServer Handshake done ****************~n"); + ets:update_element(GenWorkerEts, handshake_done, {?ETS_KEYVAL_VAL_IDX, true}); true -> ok end. @@ -118,7 +117,7 @@ pre_train({_GenWorkerEts, _WorkerData}) -> ok. post_train({GenWorkerEts, WorkerData}) when length(WorkerData) == 0 -> % WorkerData = [] ThisEts = get_this_server_ets(GenWorkerEts), FedServerEts = get(fed_server_ets), - NumOfWorkers = length(ets:lookup_element(ThisEts, broadcast_workers_list, ?ETS_KEYVAL_VAL_IDX)), + NumOfWorkers = length(ets:lookup_element(ThisEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), InboxQueue = w2wCom:get_all_messages(W2WPid), MessagesList = queue:to_list(InboxQueue), diff --git a/src_erl/NerlnetApp/src/Source/sourceStatem.erl b/src_erl/NerlnetApp/src/Source/sourceStatem.erl index ecc44850..ec5e320d 100644 --- a/src_erl/NerlnetApp/src/Source/sourceStatem.erl +++ b/src_erl/NerlnetApp/src/Source/sourceStatem.erl @@ -124,7 +124,7 @@ idle(cast, {batchList, WorkersList, NumOfBatches, NerlTensorType, Data}, State) ?LOG_NOTICE("Source ~p, workers are: ~p", [MyName, WorkersList]), ?LOG_NOTICE("Source ~p, sample size: ~p", [MyName, SampleSize]), ets:update_element(EtsRef, sample_size, [{?DATA_IDX, SampleSize}]), - ?LOG_INFO("Source ~p updated transmission list, total avilable batches to send: ~p~n",[MyName, NumOfBatches]), + ?LOG_INFO("Source ~p updated transmission list, total available batches to send: ~p~n",[MyName, NumOfBatches]), %% send an ACK to mainserver that the CSV file is ready {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), nerl_tools:http_router_request(RouterHost, RouterPort, [?MAIN_SERVER_ATOM], atom_to_list(dataReady), MyName), From c62b37b62cd49f0f5901b7ec28a542d89c31b14b Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Mon, 20 May 2024 13:49:06 +0000 Subject: [PATCH 10/52] [W2W] Added start/end_stream messages --- .../dc_fed_dist_2d_3c_2s_3r_6w.json | 2 +- .../exp_fed_dist_2d_3c_2s_3r_6w.json | 14 ++--- .../src/Bridge/onnWorkers/w2wCom.erl | 11 ++++ .../onnWorkers/workerFederatedClient.erl | 56 ++++++++++++++----- .../onnWorkers/workerFederatedServer.erl | 25 +++++++-- .../src/Bridge/onnWorkers/workerGeneric.erl | 53 +++++++++++++++--- .../src/Bridge/onnWorkers/workerNN.erl | 8 +++ .../src/Client/clientStateHandler.erl | 4 +- .../NerlnetApp/src/Client/clientStatem.erl | 47 ++++++++++++++-- .../NerlnetApp/src/Source/sourceStatem.erl | 16 ++++++ src_erl/NerlnetApp/src/nerlnetApp_app.erl | 4 +- 11 files changed, 200 insertions(+), 40 deletions(-) diff --git a/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json index a0235157..0af11814 100644 --- a/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json +++ b/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json @@ -52,7 +52,7 @@ { "name": "s2", "port": "8902", - "frequency": "100", + "frequency": "200", "policy": "0", "epochs": "1", "type": "0" diff --git a/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json index 754c0016..83011c96 100644 --- a/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json +++ b/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json @@ -16,14 +16,14 @@ { "sourceName": "s1", "startingSample": "0", - "numOfBatches": "100", + "numOfBatches": "200", "workers": "w1,w2,w3,w4", "nerltensorType": "float" }, { "sourceName": "s2", - "startingSample": "10000", - "numOfBatches": "100", + "startingSample": "20000", + "numOfBatches": "200", "workers": "w5,w6", "nerltensorType": "float" } @@ -36,15 +36,15 @@ [ { "sourceName": "s1", - "startingSample": "20000", - "numOfBatches": "100", + "startingSample": "40000", + "numOfBatches": "50", "workers": "w1,w2,w3,w4", "nerltensorType": "float" }, { "sourceName": "s2", - "startingSample": "30000", - "numOfBatches": "100", + "startingSample": "45000", + "numOfBatches": "50", "workers": "w5,w6", "nerltensorType": "float" } diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl index 08a75d41..aca88671 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl @@ -48,6 +48,17 @@ handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, {post_train_update, D gen_server:cast(get(gen_worker_pid), {post_train_update}), {reply, {ok, post_train_update}, State}; +handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, {worker_done, Data}}, _From, State) -> + case get(worker_name) of + ThisWorkerName -> ok; + _ -> throw({error, "The provided worker name is not this worker"}) + end, + % Saved messages are of the form: {FromWorkerName, , Data} + Message = {FromWorkerName, Data}, + add_msg_to_inbox_queue(Message), + gen_server:cast(get(gen_worker_pid), {worker_done}), + {reply, {ok, worker_done}, State}; + % Received messages are of the form: {worker_to_worker_msg, FromWorkerName, ThisWorkerName, Data} handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, Data}, _From, State) -> case get(worker_name) of diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index dc3029e5..b18e0713 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -10,7 +10,7 @@ -define(WORKER_FEDERATED_CLIENT_ETS_FIELDS, [my_name, client_pid, server_name, sync_max_count, sync_count]). -define(FEDERATED_CLIENT_ETS_KEY_IN_GENWORKER_ETS, fedrated_client_ets). --define(DEFAULT_SYNC_MAX_COUNT_ARG, 1). +-define(DEFAULT_SYNC_MAX_COUNT_ARG, 100). controller(FuncName, {GenWorkerEts, WorkerData}) -> case FuncName of @@ -20,7 +20,10 @@ controller(FuncName, {GenWorkerEts, WorkerData}) -> pre_train -> pre_train({GenWorkerEts, WorkerData}); post_train -> post_train({GenWorkerEts, WorkerData}); pre_predict -> pre_predict({GenWorkerEts, WorkerData}); - post_predict-> post_predict({GenWorkerEts, WorkerData}) + post_predict -> post_predict({GenWorkerEts, WorkerData}); + start_stream -> start_stream({GenWorkerEts, WorkerData}); + end_stream -> end_stream({GenWorkerEts, WorkerData}); + worker_done -> worker_done({GenWorkerEts, WorkerData}) end. get_this_client_ets(GenWorkerEts) -> @@ -59,6 +62,7 @@ init({GenWorkerEts, WorkerData}) -> ets:insert(FedratedClientEts, {handshake_done, false}), ets:insert(FedratedClientEts, {handshake_wait, false}), ets:insert(FedratedClientEts, {w2wcom_pid, W2WPid}), + ets:insert(FedratedClientEts, {casting_sources, []}), spawn(fun() -> handshake(FedratedClientEts) end). handshake(FedClientEts) -> @@ -80,6 +84,24 @@ handshake(FedClientEts) -> end, lists:foreach(Func, MessagesList). +start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName] + SourceName = hd(WorkerData), + ThisEts = get_this_client_ets(GenWorkerEts), + ets:update_element(ThisEts, stream_occuring , {?ETS_KEYVAL_VAL_IDX, true}), + CastingSources = ets:lookup_element(ThisEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), + NewCastingSources = CastingSources ++ [SourceName], + ets:update_element(ThisEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}). + % ***** Add SourcesList ***** + +end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName] + SourceName = hd(WorkerData), + ThisEts = get_this_client_ets(GenWorkerEts), + ets:update_element(ThisEts, stream_occuring , {?ETS_KEYVAL_VAL_IDX, false}), + CastingSources = ets:lookup_element(ThisEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), + NewCastingSources = CastingSources -- [SourceName], + ets:update_element(ThisEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}). + + pre_idle({_GenWorkerEts, _WorkerData}) -> ok. post_idle({GenWorkerEts, _WorkerData}) -> @@ -123,17 +145,22 @@ pre_train({GenWorkerEts, _NerlTensorWeights}) -> %% every countLimit batches, send updated weights post_train({GenWorkerEts, _WorkerData}) -> - ThisEts = get_this_client_ets(GenWorkerEts), - SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), - MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), - if SyncCount == MaxSyncCount -> - ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - Weights = nerlNIF:call_to_get_weights(ModelID), - ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), - W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:send_message(W2WPid, MyName, ServerName , {post_train_update, Weights}); %% ****** NEW - TEST NEEDED ****** - true -> ok + CastingSources = ets:lookup_element(get_this_client_ets(GenWorkerEts), casting_sources, ?ETS_KEYVAL_VAL_IDX), + case CastingSources of + [] -> ok; + _ -> + ThisEts = get_this_client_ets(GenWorkerEts), + SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), + MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), + if SyncCount == MaxSyncCount -> + ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), + Weights = nerlNIF:call_to_get_weights(ModelID), + ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), + W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:send_message(W2WPid, MyName, ServerName , {post_train_update, Weights}); %% ****** NEW - TEST NEEDED ****** + true -> ok + end end. %% nothing? @@ -142,3 +169,6 @@ pre_predict({_GenWorkerEts, WorkerData}) -> WorkerData. %% nothing? post_predict(Data) -> Data. +worker_done({_GenWorkerEts, _WorkerData}) -> ok. + + diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 54de7ed6..936cc657 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -26,7 +26,10 @@ controller(FuncName, {GenWorkerEts, WorkerData}) -> pre_train -> pre_train({GenWorkerEts, WorkerData}); post_train -> post_train({GenWorkerEts, WorkerData}); pre_predict -> pre_predict({GenWorkerEts, WorkerData}); - post_predict -> post_predict({GenWorkerEts, WorkerData}) + post_predict -> post_predict({GenWorkerEts, WorkerData}); + start_stream -> start_stream({GenWorkerEts, WorkerData}); + end_stream -> end_stream({GenWorkerEts, WorkerData}); + worker_done -> worker_done({GenWorkerEts, WorkerData}) end. @@ -61,6 +64,7 @@ init({GenWorkerEts, WorkerData}) -> ets:insert(FederatedServerEts, {w2wcom_pid, W2WPid}), ets:insert(FederatedServerEts, {broadcast_workers_list, BroadcastWorkers}), ets:insert(FederatedServerEts, {fed_clients, []}), + ets:insert(FederatedServerEts, {training_workers , []}), ets:insert(FederatedServerEts, {sync_count, 0}), ets:insert(FederatedServerEts, {my_name, MyName}), ets:insert(FederatedServerEts, {token , Token}), @@ -68,10 +72,21 @@ init({GenWorkerEts, WorkerData}) -> put(fed_server_ets, FederatedServerEts). +start_stream({_GenWorkerEts, _WorkerData}) -> ok. + +end_stream({_GenWorkerEts, _WorkerData}) -> ok. pre_idle({_GenWorkerEts, _WorkerName}) -> ok. +worker_done({GenWorkerEts, WorkerData}) -> + WorkerName = hd(WorkerData), + ThisEts = get_this_server_ets(GenWorkerEts), + TrainingWorkers = ets:lookup_element(ThisEts, training_workers, ?ETS_KEYVAL_VAL_IDX), + UpdatedTrainingWorkers = lists:delete(WorkerName, TrainingWorkers), + ets:update_element(ThisEts, training_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedTrainingWorkers}). + + % Extract all workers in nerlnet network % Send handshake message to all workers % Wait for all workers to send handshake message back @@ -103,6 +118,8 @@ post_idle({GenWorkerEts, _WorkerName}) -> w2wCom:send_message(W2WPid, FedServerName, FedClient, {handshake_done, MyToken}) end, lists:foreach(MsgFunc, MessagesList), + UpdatedTrainingWorkers = ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX), + ets:update_element(FedServerEts, training_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedTrainingWorkers}), ets:update_element(GenWorkerEts, handshake_done, {?ETS_KEYVAL_VAL_IDX, true}); true -> ok end. @@ -117,14 +134,14 @@ pre_train({_GenWorkerEts, _WorkerData}) -> ok. post_train({GenWorkerEts, WorkerData}) when length(WorkerData) == 0 -> % WorkerData = [] ThisEts = get_this_server_ets(GenWorkerEts), FedServerEts = get(fed_server_ets), - NumOfWorkers = length(ets:lookup_element(ThisEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), + NumOfTrainingWorkers = length(ets:lookup_element(ThisEts, training_workers, ?ETS_KEYVAL_VAL_IDX)), W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), InboxQueue = w2wCom:get_all_messages(W2WPid), MessagesList = queue:to_list(InboxQueue), ReceivedWeights = [WorkersWeights || {_WorkerName, {WorkersWeights, _BinaryType}} <- MessagesList], CurrWorkersWeightsList = ets:lookup_element(FedServerEts, weights_list, ?ETS_KEYVAL_VAL_IDX), TotalWorkersWeights = CurrWorkersWeightsList ++ ReceivedWeights, - case length(TotalWorkersWeights) == NumOfWorkers of + case length(TotalWorkersWeights) == NumOfTrainingWorkers of % Why not timeout true -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), {CurrentModelWeights, BinaryType} = nerlNIF:call_to_get_weights(ModelID), @@ -137,7 +154,7 @@ post_train({GenWorkerEts, WorkerData}) when length(WorkerData) == 0 -> % WorkerD W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), w2wCom:send_message(W2WPid, FedServerName, FedClient, {update_weights, AvgWeightsNerlTensor}) end, - WorkersList = ets:lookup_element(ThisEts, broadcast_workers_list, ?ETS_KEYVAL_VAL_IDX), + WorkersList = ets:lookup_element(ThisEts, training_workers, ?ETS_KEYVAL_VAL_IDX), lists:foreach(Func, WorkersList), ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, []}); false -> ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, TotalWorkersWeights}) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index aebc7784..d7c38300 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -141,6 +141,7 @@ code_change(_OldVsn, StateName, State = #workerGeneric_state{}, _Extra) -> % Go from idle to train idle(cast, {training}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + io:format("@idle got training , Worker ~p is going to state idle...~n",[MyName]), worker_controller_empty_message_queue(), DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), train}), update_client_avilable_worker(MyName), @@ -148,13 +149,14 @@ idle(cast, {training}, State = #workerGeneric_state{myName = MyName , distribute % Go from idle to predict idle(cast, {predict}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + io:format("@idle got predict , Worker ~p is going to state idle...~n",[MyName]), worker_controller_empty_message_queue(), update_client_avilable_worker(MyName), DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), predict}), {next_state, predict, State#workerGeneric_state{lastPhase = predict}}; -idle(cast, _Param, State) -> - % io:fwrite("Same state idle, command: ~p\n",[Param]), +idle(cast, _Param, State = #workerGeneric_state{myName = MyName}) -> + io:format("@idle Worker ~p is going to state idle...~n",[MyName]), {next_state, idle, State}. %% Waiting for receiving results or loss function @@ -181,21 +183,27 @@ wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , S {next_state, NextState, State} end; -wait(cast, {idle}, State) -> +wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> %logger:notice("Waiting, next state - idle"), + io:format("Worker ~p @wait is going to state idle...~n",[MyName]), + update_client_avilable_worker(MyName), + DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), train}), {next_state, wait, State#workerGeneric_state{nextState = idle}}; wait(cast, {training}, State) -> %logger:notice("Waiting, next state - train"), + io:format("@wait got training , Worker is going to state idle...~n"), % gen_statem:cast(ClientPid,{stateChange,WorkerName}), {next_state, wait, State#workerGeneric_state{nextState = train}}; wait(cast, {predict}, State) -> + io:format("@wait got predict , Worker is going to state idle...~n"), %logger:notice("Waiting, next state - predict"), {next_state, wait, State#workerGeneric_state{nextState = predict}}; %% Worker in wait can't treat incoming message wait(cast, _BatchData , State = #workerGeneric_state{lastPhase = LastPhase}) -> + io:format("@wait got something , Worker ~p is going to state idle...~n"), case LastPhase of train -> ets:update_counter(get(worker_stats_ets), batches_dropped_train , 1); @@ -206,26 +214,27 @@ wait(cast, _BatchData , State = #workerGeneric_state{lastPhase = LastPhase}) -> wait(cast, Data, State) -> % logger:notice("worker ~p in wait cant treat message: ~p\n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), Data]), + io:format("@wait got something2 , Worker is going to state idle...~n"), worker_controller_message_queue(Data), {keep_state, State}. %% State train train(cast, {sample, BatchID ,{<<>>, _Type}}, State) -> - ?LOG_WARNING("Empty sample received , batch id: ~p",[BatchID]), + ?LOG_WARNING("Empty sample received , batch id: ~p~n",[BatchID]), WorkerStatsEts = get(worker_stats_ets), stats:increment_by_value(WorkerStatsEts , empty_batches , 1), {next_state, train, State#workerGeneric_state{nextState = train , currentBatchID = BatchID}}; %% Change SampleListTrain to NerlTensor -train(cast, {sample, SourceName ,BatchID ,{NerlTensorOfSamples, NerlTensorType}}, State = #workerGeneric_state{modelID = ModelId, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> +train(cast, {sample, SourceName ,BatchID ,{NerlTensorOfSamples, NerlTensorType}}, State = #workerGeneric_state{modelID = ModelId, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData, myName = MyName}) -> % NerlTensor = nerltensor_conversion({NerlTensorOfSamples, Type}, erl_float), MyPid = self(), - NewWorkerData = DistributedBehaviorFunc(pre_train, {get(generic_worker_ets),DistributedWorkerData}), + DistributedBehaviorFunc(pre_train, {get(generic_worker_ets),DistributedWorkerData}), WorkersStatsEts = get(worker_stats_ets), stats:increment_by_value(WorkersStatsEts , batches_received_train , 1), _Pid = spawn(fun()-> nerlNIF:call_to_train(ModelId , {NerlTensorOfSamples, NerlTensorType} ,MyPid , BatchID , SourceName) end), - {next_state, wait, State#workerGeneric_state{nextState = train, distributedWorkerData = NewWorkerData , currentBatchID = BatchID}}; + {next_state, wait, State#workerGeneric_state{nextState = train, currentBatchID = BatchID}}; %% TODO: implement send model and weights by demand (Tensor / XML) train(cast, {set_weights,Ret_weights_list}, State = #workerGeneric_state{modelID = ModelId}) -> @@ -234,17 +243,31 @@ train(cast, {set_weights,Ret_weights_list}, State = #workerGeneric_state{modelID %logger:notice("####end set weights train####~n"), {next_state, train, State}; -train(cast, {post_train_update} ,State = #workerGeneric_state{distributedBehaviorFunc = DistributedBehaviorFunc}) -> +train(cast, {post_train_update}, State = #workerGeneric_state{distributedBehaviorFunc = DistributedBehaviorFunc}) -> DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), {next_state, train, State}; +train(cast, {worker_done}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + DistributedBehaviorFunc(worker_done, {get(generic_worker_ets),[MyName]}), + {next_state, idle, State}; + +train(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + DistributedBehaviorFunc(start_stream, {get(generic_worker_ets), [SourceName]}), + {next_state, train, State}; + +train(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + DistributedBehaviorFunc(end_stream, {get(generic_worker_ets), [SourceName]}), + {next_state, train, State}; + train(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + io:format("@train Worker ~p is going to state idle...~n",[MyName]), update_client_avilable_worker(MyName), DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), train}), {next_state, idle, State}; -train(cast, Data, State) -> +train(cast, Data, State = #workerGeneric_state{myName = MyName}) -> % logger:notice("worker ~p in wait cant treat message: ~p\n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), Data]), + io:format("~p Got unknown message in train state: ~p~n",[MyName , Data]), worker_controller_message_queue(Data), {keep_state, State}. @@ -264,6 +287,18 @@ predict(cast, {sample , SourceName , BatchID , {PredictBatchTensor, Type}}, Stat _Pid = spawn(fun()-> nerlNIF:call_to_predict(ModelId , {PredictBatchTensor, Type} , CurrPID , BatchID, SourceName) end), {next_state, wait, State#workerGeneric_state{nextState = predict , currentBatchID = BatchID}}; +predict(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + DistributedBehaviorFunc(start_stream, {get(generic_worker_ets), [SourceName]}), + {next_state, train, State}; + +predict(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + DistributedBehaviorFunc(end_stream, {get(generic_worker_ets), [SourceName]}), + {next_state, train, State}; + +predict(cast, {worker_done}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + DistributedBehaviorFunc(worker_done, {get(generic_worker_ets),[MyName]}), + {next_state, idle, State}; + predict(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> update_client_avilable_worker(MyName), DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), predict}), diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerNN.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerNN.erl index 3e3e8be4..b68d8a96 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerNN.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerNN.erl @@ -11,6 +11,9 @@ controller(FuncName, {GenWorkerEts, WorkerData}) -> post_train -> post_train({GenWorkerEts, WorkerData}); pre_predict -> pre_predict({GenWorkerEts, WorkerData}); post_predict -> post_predict({GenWorkerEts, WorkerData}); + start_stream -> start_stream({GenWorkerEts, WorkerData}); + end_stream -> end_stream({GenWorkerEts, WorkerData}); + worker_done -> worker_done({GenWorkerEts, WorkerData}); update -> update({GenWorkerEts, WorkerData}) end. @@ -30,5 +33,10 @@ post_predict({_GenWorkerEts, _WorkerData}) -> ok. update({_GenWorkerEts, _WorkerData}) -> ok. +start_stream({_GenWorkerEts, _WorkerData}) -> ok. + +end_stream({_GenWorkerEts, _WorkerData}) -> ok. + +worker_done({_GenWorkerEts, _WorkerData}) -> ok. diff --git a/src_erl/NerlnetApp/src/Client/clientStateHandler.erl b/src_erl/NerlnetApp/src/Client/clientStateHandler.erl index 07141d54..5d7a933b 100644 --- a/src_erl/NerlnetApp/src/Client/clientStateHandler.erl +++ b/src_erl/NerlnetApp/src/Client/clientStateHandler.erl @@ -24,7 +24,9 @@ init(Req0, [Action,Client_StateM_Pid]) -> idle -> gen_statem:cast(Client_StateM_Pid,{idle}); training -> gen_statem:cast(Client_StateM_Pid,{training}); predict -> gen_statem:cast(Client_StateM_Pid,{predict}); - statistics -> gen_statem:cast(Client_StateM_Pid,{statistics}) + statistics -> gen_statem:cast(Client_StateM_Pid,{statistics}); + start_stream -> gen_statem:cast(Client_StateM_Pid,{start_stream, Body}); + end_stream -> gen_statem:cast(Client_StateM_Pid,{end_stream, Body}) end, %% reply ACKnowledge to main server for initiating, later send finished initiating http_request from client_stateM diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 7d75845b..100b95ab 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -121,9 +121,11 @@ waitforWorkers(cast, In = {stateChange,WorkerName}, State = #client_statem_state stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), case NewWaitforWorkers of % TODO Guy here we need to check for keep alive with workers [] -> send_client_is_ready(MyName), % when all workers done their work + io:format("Client ~p is ready~n",[MyName]), stats:increment_messages_sent(ClientStatsEts), {next_state, NextState, State#client_statem_state{waitforWorkers = []}}; - _-> {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = NewWaitforWorkers}} + _ -> io:format("Client ~p is waiting for workers ~p~n",[MyName,NewWaitforWorkers]), + {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = NewWaitforWorkers}} end; waitforWorkers(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = #client_statem_state{etsRef = EtsRef}) -> @@ -255,13 +257,32 @@ training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef MessageToCast = {idle}, cast_message_to_workers(EtsRef, MessageToCast), Workers = clientWorkersFunctions:get_workers_names(EtsRef), - ?LOG_INFO("~p setting workers at idle: ~p~n",[MyName, ets:lookup_element(EtsRef, workersNames, ?DATA_IDX)]), + ?LOG_INFO("~p sent idle to workers: ~p , waiting for confirmation...~n",[MyName, ets:lookup_element(EtsRef, workersNames, ?DATA_IDX)]), {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef, waitforWorkers = Workers , nextState = idle}}; training(cast, _In = {predict}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> ?LOG_ERROR("Wrong request , client ~p can't go from training to predict directly", [MyName]), {next_state, training, State#client_statem_state{etsRef = EtsRef}}; +% ************* NEW *************** +training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> + {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), + gen_statem:cast(WorkerPid, {start_stream, SourceName}), + {keep_state, State}; + +training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> + {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), + gen_statem:cast(WorkerPid, {end_stream, SourceName}), + {keep_state, State}; + training(cast, In = {loss, WorkerName ,SourceName ,LossTensor ,TimeNIF ,BatchID ,BatchTS}, State = #client_statem_state{myName = MyName,etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -298,6 +319,24 @@ predict(cast, In = {sample,Body}, State = #client_statem_state{etsRef = EtsRef}) end, {next_state, predict, State#client_statem_state{etsRef = EtsRef}}; +predict(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> + {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), + gen_statem:cast(WorkerPid, {start_stream, SourceName}), + {keep_state, State}; + +predict(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> + {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), + gen_statem:cast(WorkerPid, {end_stream, SourceName}), + {keep_state, State}; + predict(cast, In = {predictRes,WorkerName, SourceName ,{PredictNerlTensor, NetlTensorType} , TimeTook , BatchID , BatchTS}, State = #client_statem_state{myName = _MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -380,8 +419,8 @@ send_client_is_ready(MyName) -> cast_message_to_workers(EtsRef, Msg) -> ClientStatsEts = get(client_stats_ets), Workers = ets:lookup_element(EtsRef, workersNames, ?ETS_KV_VAL_IDX), - Func = fun(WorkerKey) -> - WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef, WorkerKey), % WorkerKey is the worker name + Func = fun(WorkerName) -> + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef, WorkerName), gen_statem:cast(WorkerPid, Msg), stats:increment_messages_sent(ClientStatsEts) end, diff --git a/src_erl/NerlnetApp/src/Source/sourceStatem.erl b/src_erl/NerlnetApp/src/Source/sourceStatem.erl index ec5e320d..559a088b 100644 --- a/src_erl/NerlnetApp/src/Source/sourceStatem.erl +++ b/src_erl/NerlnetApp/src/Source/sourceStatem.erl @@ -365,12 +365,28 @@ transmitter(TimeInterval_ms, SourceEtsRef, SourcePid ,ClientWorkerPairs, Batches ets:insert(TransmitterEts, {batches_skipped, 0}), ets:insert(TransmitterEts, {current_batch_id, 0}), TransmissionStart = erlang:timestamp(), + % Message to all workrers : "start_stream" + {RouterHost, RouterPort} = ets:lookup_element(TransmitterEts, my_router, ?DATA_IDX), + FuncStart = fun({ClientName, WorkerName}) -> + ToSend = {MyName, ClientName, WorkerName}, + io:format("~p sending start_stream to ~p of worker ~p~n",[MyName, ClientName, WorkerName]), + nerl_tools:http_router_request(RouterHost, RouterPort, [ClientName], atom_to_list(start_stream), ToSend) + end, + lists:foreach(FuncStart, ClientWorkerPairs), case Method of ?SOURCE_POLICY_CASTING_ATOM -> send_method_casting(TransmitterEts, TimeInterval_ms, ClientWorkerPairs, BatchesListToSend); ?SOURCE_POLICY_ROUNDROBIN_ATOM -> send_method_round_robin(TransmitterEts, TimeInterval_ms, ClientWorkerPairs, BatchesListToSend); ?SOURCE_POLICY_RANDOM_ATOM -> send_method_random(TransmitterEts, TimeInterval_ms, ClientWorkerPairs, BatchesListToSend); _Default -> send_method_casting(TransmitterEts, TimeInterval_ms, ClientWorkerPairs, BatchesListToSend) end, + io:format("GOT HEREEEEE~n"), + % Message to workers : "end_stream" + FuncEnd = fun({ClientName, WorkerName}) -> + ToSend = {MyName, ClientName, WorkerName}, + io:format("~p sending end_stream to ~p of worker ~p~n",[MyName, ClientName, WorkerName]), + nerl_tools:http_router_request(RouterHost, RouterPort, [ClientName], atom_to_list(end_stream), ToSend) + end, + lists:foreach(FuncEnd, ClientWorkerPairs), TransmissionTimeTook_sec = timer:now_diff(erlang:timestamp(), TransmissionStart) / 1000000, ErrorBatches = ets:lookup_element(TransmitterEts, batches_issue, ?DATA_IDX), SkippedBatches = ets:lookup_element(TransmitterEts, batches_skipped, ?DATA_IDX), diff --git a/src_erl/NerlnetApp/src/nerlnetApp_app.erl b/src_erl/NerlnetApp/src/nerlnetApp_app.erl index bb058afe..8225aabb 100644 --- a/src_erl/NerlnetApp/src/nerlnetApp_app.erl +++ b/src_erl/NerlnetApp/src/nerlnetApp_app.erl @@ -246,7 +246,9 @@ createClientsAndWorkers() -> {"/clientIdle",clientStateHandler, [idle,ClientStatemPid]}, {"/clientPredict",clientStateHandler, [predict,ClientStatemPid]}, {"/batch",clientStateHandler, [batch,ClientStatemPid]}, - {"/worker_to_worker_msg",clientStateHandler, [worker_to_worker_msg,ClientStatemPid]} + {"/worker_to_worker_msg",clientStateHandler, [worker_to_worker_msg,ClientStatemPid]}, + {"/start_stream", clientStateHandler, [start_stream, ClientStatemPid]}, + {"/end_stream", clientStateHandler, [end_stream, ClientStatemPid]} ]} ]), init_cowboy_start_clear(Client, {DeviceName, Port},NerlClientDispatch) From d6cd1c3b174545dadb7f2a5e4564ba07e53bf839 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Mon, 20 May 2024 14:48:56 +0000 Subject: [PATCH 11/52] [W2W] WIP --- .../src/Bridge/onnWorkers/workerGeneric.erl | 20 +++++++++---------- .../NerlnetApp/src/Client/clientStatem.erl | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index d7c38300..c8b12a47 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -149,8 +149,8 @@ idle(cast, {training}, State = #workerGeneric_state{myName = MyName , distribute % Go from idle to predict idle(cast, {predict}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("@idle got predict , Worker ~p is going to state idle...~n",[MyName]), - worker_controller_empty_message_queue(), + io:format("@idle got predict , Worker ~p is going to state predict...~n",[MyName]), + % worker_controller_empty_message_queue(), update_client_avilable_worker(MyName), DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), predict}), {next_state, predict, State#workerGeneric_state{lastPhase = predict}}; @@ -202,8 +202,8 @@ wait(cast, {predict}, State) -> {next_state, wait, State#workerGeneric_state{nextState = predict}}; %% Worker in wait can't treat incoming message -wait(cast, _BatchData , State = #workerGeneric_state{lastPhase = LastPhase}) -> - io:format("@wait got something , Worker ~p is going to state idle...~n"), +wait(cast, _BatchData , State = #workerGeneric_state{lastPhase = LastPhase, myName= MyName}) -> + io:format("@wait got something , Worker ~p is going to state idle...~n",[MyName]), case LastPhase of train -> ets:update_counter(get(worker_stats_ets), batches_dropped_train , 1); @@ -288,16 +288,16 @@ predict(cast, {sample , SourceName , BatchID , {PredictBatchTensor, Type}}, Stat {next_state, wait, State#workerGeneric_state{nextState = predict , currentBatchID = BatchID}}; predict(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - DistributedBehaviorFunc(start_stream, {get(generic_worker_ets), [SourceName]}), - {next_state, train, State}; + % DistributedBehaviorFunc(start_stream, {get(generic_worker_ets), [SourceName]}), + {next_state, predict, State}; predict(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - DistributedBehaviorFunc(end_stream, {get(generic_worker_ets), [SourceName]}), - {next_state, train, State}; + % DistributedBehaviorFunc(end_stream, {get(generic_worker_ets), [SourceName]}), + {next_state, predict, State}; predict(cast, {worker_done}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - DistributedBehaviorFunc(worker_done, {get(generic_worker_ets),[MyName]}), - {next_state, idle, State}; + % DistributedBehaviorFunc(worker_done, {get(generic_worker_ets),[MyName]}), + {next_state, predict, State}; predict(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> update_client_avilable_worker(MyName), diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 100b95ab..20994d6d 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -184,6 +184,7 @@ idle(cast, In = {training}, State = #client_statem_state{myName = _MyName, etsRe {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = clientWorkersFunctions:get_workers_names(EtsRef), nextState = training}}; idle(cast, In = {predict}, State = #client_statem_state{etsRef = EtsRef}) -> + io:format("Client sending workers to predict state...~n"), ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), @@ -367,7 +368,6 @@ predict(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = # %% The source sends message to main server that it has finished %% The main server updates its' clients to move to state 'idle' predict(cast, In = {idle}, State = #client_statem_state{etsRef = EtsRef , myName = _MyName}) -> - MsgToCast = {idle}, ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), From 895ff6ea753d13b8fae4bf9d05ec8aba07276aaa Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Mon, 20 May 2024 15:04:38 +0000 Subject: [PATCH 12/52] [W2W] WIP --- .../onnWorkers/workerFederatedClient.erl | 41 +++++++++++-------- .../onnWorkers/workerFederatedServer.erl | 1 + 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index b18e0713..55557dc1 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -96,10 +96,12 @@ start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName] SourceName = hd(WorkerData), ThisEts = get_this_client_ets(GenWorkerEts), - ets:update_element(ThisEts, stream_occuring , {?ETS_KEYVAL_VAL_IDX, false}), CastingSources = ets:lookup_element(ThisEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), NewCastingSources = CastingSources -- [SourceName], - ets:update_element(ThisEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}). + case NewCastingSources of + [] -> ets:update_element(ThisEts, stream_occuring , {?ETS_KEYVAL_VAL_IDX, false}); + _ -> ets:update_element(ThisEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}) + end. pre_idle({_GenWorkerEts, _WorkerData}) -> ok. @@ -126,21 +128,26 @@ post_idle({GenWorkerEts, _WorkerData}) -> % After SyncMaxCount , sync_inbox to get the updated model from FedServer pre_train({GenWorkerEts, _NerlTensorWeights}) -> - ThisEts = get_this_client_ets(GenWorkerEts), - SyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_count, ?ETS_KEYVAL_VAL_IDX), - WorkerName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), - ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), - MaxSyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_max_count, ?ETS_KEYVAL_VAL_IDX), - if SyncCount == MaxSyncCount -> - W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:sync_inbox(W2WPid), % waiting for server to average the weights and send it - InboxQueue = w2wCom:get_all_messages(W2WPid), - [UpdateWeightsMsg] = queue:to_list(InboxQueue), - {_FedServer , {update_weights, UpdatedWeights}} = UpdateWeightsMsg, - ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - nerlNIF:call_to_set_weights(ModelID, UpdatedWeights), - ets:update_element(ThisEts, sync_count, {?ETS_KEYVAL_VAL_IDX , 0}); - true -> ets:update_counter(ThisEts, sync_count, 1) + StreamOccuring = ets:lookup_element(get_this_client_ets(GenWorkerEts), stream_occuring, ?ETS_KEYVAL_VAL_IDX), + case StreamOccuring of + true -> + ThisEts = get_this_client_ets(GenWorkerEts), + SyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_count, ?ETS_KEYVAL_VAL_IDX), + WorkerName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), + ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), + MaxSyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_max_count, ?ETS_KEYVAL_VAL_IDX), + if SyncCount == MaxSyncCount -> + W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:sync_inbox(W2WPid), % waiting for server to average the weights and send it + InboxQueue = w2wCom:get_all_messages(W2WPid), + [UpdateWeightsMsg] = queue:to_list(InboxQueue), + {_FedServer , {update_weights, UpdatedWeights}} = UpdateWeightsMsg, + ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), + nerlNIF:call_to_set_weights(ModelID, UpdatedWeights), + ets:update_element(ThisEts, sync_count, {?ETS_KEYVAL_VAL_IDX , 0}); + true -> ets:update_counter(ThisEts, sync_count, 1) + end; + false -> ok end. %% every countLimit batches, send updated weights diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 936cc657..59a9e37a 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -84,6 +84,7 @@ worker_done({GenWorkerEts, WorkerData}) -> ThisEts = get_this_server_ets(GenWorkerEts), TrainingWorkers = ets:lookup_element(ThisEts, training_workers, ?ETS_KEYVAL_VAL_IDX), UpdatedTrainingWorkers = lists:delete(WorkerName, TrainingWorkers), + io:format("Worker ~p Done, UpdatedTrainingWorkers = ~p~n", [WorkerName, UpdatedTrainingWorkers]), ets:update_element(ThisEts, training_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedTrainingWorkers}). From dc7fe281d271505493190bb9d7d594e6552a27be Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Mon, 20 May 2024 15:35:19 +0000 Subject: [PATCH 13/52] [W2W] Distributed Exp WIP --- .../onnWorkers/workerFederatedClient.erl | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 55557dc1..8ca2cc00 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -47,23 +47,24 @@ sync_max_count_init(FedClientEts , ArgsList) -> %% handshake with workers / server at the end of init init({GenWorkerEts, WorkerData}) -> % create an ets for this client and save it to generic worker ets - FedratedClientEts = ets:new(federated_client,[set, public]), - ets:insert(GenWorkerEts, {federated_client_ets, FedratedClientEts}), + FederatedClientEts = ets:new(federated_client,[set, public]), + ets:insert(GenWorkerEts, {federated_client_ets, FederatedClientEts}), {MyName, Args, Token} = WorkerData, ArgsList = parse_args(Args), - sync_max_count_init(FedratedClientEts, ArgsList), + sync_max_count_init(FederatedClientEts, ArgsList), W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), % create fields in this ets - ets:insert(FedratedClientEts, {my_token, Token}), - ets:insert(FedratedClientEts, {my_name, MyName}), - ets:insert(FedratedClientEts, {server_name, none}), % update later - ets:insert(FedratedClientEts, {sync_count, 0}), - ets:insert(FedratedClientEts, {server_update, false}), - ets:insert(FedratedClientEts, {handshake_done, false}), - ets:insert(FedratedClientEts, {handshake_wait, false}), - ets:insert(FedratedClientEts, {w2wcom_pid, W2WPid}), - ets:insert(FedratedClientEts, {casting_sources, []}), - spawn(fun() -> handshake(FedratedClientEts) end). + ets:insert(FederatedClientEts, {my_token, Token}), + ets:insert(FederatedClientEts, {my_name, MyName}), + ets:insert(FederatedClientEts, {server_name, none}), % update later + ets:insert(FederatedClientEts, {sync_count, 0}), + ets:insert(FederatedClientEts, {server_update, false}), + ets:insert(FederatedClientEts, {handshake_done, false}), + ets:insert(FederatedClientEts, {handshake_wait, false}), + ets:insert(FederatedClientEts, {w2wcom_pid, W2WPid}), + ets:insert(FederatedClientEts, {casting_sources, []}), + ets:insert(FederatedClientEts, {stream_occuring, false}), + spawn(fun() -> handshake(FederatedClientEts) end). handshake(FedClientEts) -> W2WPid = ets:lookup_element(FedClientEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), @@ -128,7 +129,8 @@ post_idle({GenWorkerEts, _WorkerData}) -> % After SyncMaxCount , sync_inbox to get the updated model from FedServer pre_train({GenWorkerEts, _NerlTensorWeights}) -> - StreamOccuring = ets:lookup_element(get_this_client_ets(GenWorkerEts), stream_occuring, ?ETS_KEYVAL_VAL_IDX), + ThisEts = get_this_client_ets(GenWorkerEts), + StreamOccuring = ets:lookup_element(ThisEts, stream_occuring, ?ETS_KEYVAL_VAL_IDX), case StreamOccuring of true -> ThisEts = get_this_client_ets(GenWorkerEts), From 8d8d48186a392cc0ff6d791d2ba73780c40bd96b Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Mon, 20 May 2024 16:46:35 +0000 Subject: [PATCH 14/52] [W2W] WIP --- .../onnWorkers/workerFederatedServer.erl | 4 +- .../src/Bridge/onnWorkers/workerGeneric.erl | 2 + .../NerlnetApp/src/Client/clientStatem.erl | 57 +++++++++++++------ .../NerlnetApp/src/Source/sourceStatem.erl | 1 - 4 files changed, 44 insertions(+), 20 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 59a9e37a..01e13a20 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -142,14 +142,14 @@ post_train({GenWorkerEts, WorkerData}) when length(WorkerData) == 0 -> % WorkerD ReceivedWeights = [WorkersWeights || {_WorkerName, {WorkersWeights, _BinaryType}} <- MessagesList], CurrWorkersWeightsList = ets:lookup_element(FedServerEts, weights_list, ?ETS_KEYVAL_VAL_IDX), TotalWorkersWeights = CurrWorkersWeightsList ++ ReceivedWeights, - case length(TotalWorkersWeights) == NumOfTrainingWorkers of % Why not timeout + case length(TotalWorkersWeights) == NumOfTrainingWorkers of % ? Why not timeout true -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), {CurrentModelWeights, BinaryType} = nerlNIF:call_to_get_weights(ModelID), FedServerName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), AllWorkersWeightsList = TotalWorkersWeights ++ [CurrentModelWeights], AvgWeightsNerlTensor = generate_avg_weights(AllWorkersWeightsList, BinaryType), - nerlNIF:call_to_set_weights(ModelID, AvgWeightsNerlTensor), %% update self weights to new model + nerlNIF:call_to_set_weights(ModelID, AvgWeightsNerlTensor), %% update self weights to new model Func = fun(FedClient) -> FedServerName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index c8b12a47..e65bac0b 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -248,6 +248,7 @@ train(cast, {post_train_update}, State = #workerGeneric_state{distributedBehavio {next_state, train, State}; train(cast, {worker_done}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + io:format("Worker ~p got worker_done~n",[MyName]), DistributedBehaviorFunc(worker_done, {get(generic_worker_ets),[MyName]}), {next_state, idle, State}; @@ -256,6 +257,7 @@ train(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = M {next_state, train, State}; train(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + io:format("Worker ~p got end_stream~n",[MyName]), DistributedBehaviorFunc(end_stream, {get(generic_worker_ets), [SourceName]}), {next_state, train, State}; diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 20994d6d..f6a3f155 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -85,6 +85,7 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh ets:insert(EtsRef, {workers_to_sha_map, MyWorkersToShaMap}), ets:insert(EtsRef, {sha_to_models_map , ShaToModelArgsMap}), ets:insert(EtsRef, {w2wcom_pids, #{}}), + ets:insert(EtsRef, {all_workers_done, false}), {MyRouterHost,MyRouterPort} = nerl_tools:getShortPath(MyName,?MAIN_SERVER_ATOM, NerlnetGraph), ets:insert(EtsRef, {my_router,{MyRouterHost,MyRouterPort}}), clientWorkersFunctions:create_workers(MyName , EtsRef , ShaToModelArgsMap , EtsStats), @@ -100,6 +101,7 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh put(client_data, EtsRef), put(ets_stats, EtsStats), put(client_stats_ets , ClientStatsEts), + put(my_pid , self()), {ok, idle, #client_statem_state{myName= MyName, etsRef = EtsRef}}. @@ -181,6 +183,7 @@ idle(cast, In = {training}, State = #client_statem_state{myName = _MyName, etsRe stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), MessageToCast = {training}, cast_message_to_workers(EtsRef, MessageToCast), + ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, false}), {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = clientWorkersFunctions:get_workers_names(EtsRef), nextState = training}}; idle(cast, In = {predict}, State = #client_statem_state{etsRef = EtsRef}) -> @@ -251,20 +254,6 @@ training(cast, In = {sample,Body}, State = #client_statem_state{etsRef = EtsRef} true -> ?LOG_ERROR("Given worker ~p isn't found in client ~p",[WorkerName, ClientName]) end, {next_state, training, State#client_statem_state{etsRef = EtsRef}}; -training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> - ClientStatsEts = get(client_stats_ets), - stats:increment_messages_received(ClientStatsEts), - stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - MessageToCast = {idle}, - cast_message_to_workers(EtsRef, MessageToCast), - Workers = clientWorkersFunctions:get_workers_names(EtsRef), - ?LOG_INFO("~p sent idle to workers: ~p , waiting for confirmation...~n",[MyName, ets:lookup_element(EtsRef, workersNames, ?DATA_IDX)]), - {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef, waitforWorkers = Workers , nextState = idle}}; - -training(cast, _In = {predict}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> - ?LOG_ERROR("Wrong request , client ~p can't go from training to predict directly", [MyName]), - {next_state, training, State#client_statem_state{etsRef = EtsRef}}; - % ************* NEW *************** training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), @@ -277,12 +266,44 @@ training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + ClientStatsEts = get(client_stats_ets), + WorkersOfThisClient = ets:lookup_element(EtsRef, workersNames, ?DATA_IDX), + NumOfTrainingWorkers = ets:lookup_element(EtsRef, num_of_training_workers, ?DATA_IDX), + WorkerOfThisClient = lists:member(WorkerName, WorkersOfThisClient), + if WorkerOfThisClient -> + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), + gen_statem:cast(WorkerPid, {end_stream, SourceName}), + UpdatedNumOfTrainingWorkers = NumOfTrainingWorkers - 1, + ets:update_element(EtsRef, num_of_training_workers, {?DATA_IDX, UpdatedNumOfTrainingWorkers}), + case UpdatedNumOfTrainingWorkers of + 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); + _ -> ok end; + true -> ok + end, + {keep_state, State}; + + +training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), - gen_statem:cast(WorkerPid, {end_stream, SourceName}), - {keep_state, State}; + MessageToCast = {idle}, + WorkersDone = ets:lookup_element(EtsRef , all_workers_done , ?DATA_IDX), + case WorkersDone of + true -> cast_message_to_workers(EtsRef, MessageToCast), + Workers = clientWorkersFunctions:get_workers_names(EtsRef), + ?LOG_INFO("~p sent idle to workers: ~p , waiting for confirmation...~n",[MyName, ets:lookup_element(EtsRef, workersNames, ?DATA_IDX)]), + {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef, waitforWorkers = Workers , nextState = idle}}; + false -> gen_statem:cast(get(my_pid) , {idle}), + {next_state, training, State#client_statem_state{etsRef = EtsRef}} + end; + +training(cast, _In = {predict}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> + ?LOG_ERROR("Wrong request , client ~p can't go from training to predict directly", [MyName]), + {next_state, training, State#client_statem_state{etsRef = EtsRef}}; + training(cast, In = {loss, WorkerName ,SourceName ,LossTensor ,TimeNIF ,BatchID ,BatchTS}, State = #client_statem_state{myName = MyName,etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), @@ -448,6 +469,8 @@ handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data) -> _ -> %% Send to the correct client DestClient = maps:get(ToWorker, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), + % ClientName = ets:lookup_element(EtsRef, myName , ?DATA_IDX), + % io:format("Client ~p passing w2w_msg {~p --> ~p} to ~p: Data ~p~n",[ClientName, FromWorker, ToWorker, DestClient,Data]), MessageBody = {worker_to_worker_msg, FromWorker, ToWorker, Data}, {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(worker_to_worker_msg), MessageBody), diff --git a/src_erl/NerlnetApp/src/Source/sourceStatem.erl b/src_erl/NerlnetApp/src/Source/sourceStatem.erl index 559a088b..cadedcf9 100644 --- a/src_erl/NerlnetApp/src/Source/sourceStatem.erl +++ b/src_erl/NerlnetApp/src/Source/sourceStatem.erl @@ -379,7 +379,6 @@ transmitter(TimeInterval_ms, SourceEtsRef, SourcePid ,ClientWorkerPairs, Batches ?SOURCE_POLICY_RANDOM_ATOM -> send_method_random(TransmitterEts, TimeInterval_ms, ClientWorkerPairs, BatchesListToSend); _Default -> send_method_casting(TransmitterEts, TimeInterval_ms, ClientWorkerPairs, BatchesListToSend) end, - io:format("GOT HEREEEEE~n"), % Message to workers : "end_stream" FuncEnd = fun({ClientName, WorkerName}) -> ToSend = {MyName, ClientName, WorkerName}, From 8c603cecf71a7c2b63fef9197f560a0021e0042e Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Mon, 20 May 2024 17:18:12 +0000 Subject: [PATCH 15/52] [W2W] WIP --- src_erl/NerlnetApp/src/Client/clientStatem.erl | 12 ++++++++---- .../NerlnetApp/src/Client/clientWorkersFunctions.erl | 3 +++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index f6a3f155..333f1f61 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -86,6 +86,7 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh ets:insert(EtsRef, {sha_to_models_map , ShaToModelArgsMap}), ets:insert(EtsRef, {w2wcom_pids, #{}}), ets:insert(EtsRef, {all_workers_done, false}), + ets:insert(EtsRef, {num_of_fed_servers, 0}), {MyRouterHost,MyRouterPort} = nerl_tools:getShortPath(MyName,?MAIN_SERVER_ATOM, NerlnetGraph), ets:insert(EtsRef, {my_router,{MyRouterHost,MyRouterPort}}), clientWorkersFunctions:create_workers(MyName , EtsRef , ShaToModelArgsMap , EtsStats), @@ -93,7 +94,9 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh WorkersNames = clientWorkersFunctions:get_workers_names(EtsRef), Pids = [clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName) || WorkerName <- WorkersNames], [gen_statem:cast(WorkerPid, {pre_idle}) || WorkerPid <- Pids], - + NumOfFedServers = ets:lookup_element(EtsRef, num_of_fed_servers, ?DATA_IDX), % When non-federated exp this value is 0 + ets:insert(EtsRef, {num_of_training_workers, length(ClientWorkers) - NumOfFedServers}), + ets:insert(EtsRef, {training_workers, 0}), % will be updated in idle -> training % update dictionary WorkersEts = ets:lookup_element(EtsRef , workers_ets , ?DATA_IDX), put(workers_ets, WorkersEts), @@ -155,7 +158,6 @@ waitforWorkers(cast, EventContent, State = #client_statem_state{myName = MyName} %% initiating workers when they include federated workers. init stage == handshake between federated worker client and server -%% TODO: make custom_worker_message in all states to send messages from workers to entities (not just client) idle(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = #client_statem_state{etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -184,6 +186,8 @@ idle(cast, In = {training}, State = #client_statem_state{myName = _MyName, etsRe MessageToCast = {training}, cast_message_to_workers(EtsRef, MessageToCast), ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, false}), + NumOfTrainingWorkers = ets:lookup_element(EtsRef, num_of_training_workers, ?DATA_IDX), + ets:update_element(EtsRef, training_workers, {?DATA_IDX, NumOfTrainingWorkers}), {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = clientWorkersFunctions:get_workers_names(EtsRef), nextState = training}}; idle(cast, In = {predict}, State = #client_statem_state{etsRef = EtsRef}) -> @@ -268,7 +272,7 @@ training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = E {SourceName, _ClientName, WorkerName} = binary_to_term(Data), ClientStatsEts = get(client_stats_ets), WorkersOfThisClient = ets:lookup_element(EtsRef, workersNames, ?DATA_IDX), - NumOfTrainingWorkers = ets:lookup_element(EtsRef, num_of_training_workers, ?DATA_IDX), + NumOfTrainingWorkers = ets:lookup_element(EtsRef, training_workers, ?DATA_IDX), WorkerOfThisClient = lists:member(WorkerName, WorkersOfThisClient), if WorkerOfThisClient -> stats:increment_messages_received(ClientStatsEts), @@ -276,7 +280,7 @@ training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = E WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), gen_statem:cast(WorkerPid, {end_stream, SourceName}), UpdatedNumOfTrainingWorkers = NumOfTrainingWorkers - 1, - ets:update_element(EtsRef, num_of_training_workers, {?DATA_IDX, UpdatedNumOfTrainingWorkers}), + ets:update_element(EtsRef, training_workers, {?DATA_IDX, UpdatedNumOfTrainingWorkers}), case UpdatedNumOfTrainingWorkers of 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); _ -> ok end; diff --git a/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl b/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl index d859fea0..1333c781 100644 --- a/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl +++ b/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl @@ -21,6 +21,9 @@ case DistributedSystemType of WorkersMap = ets:lookup_element(ClientEtsRef, workerToClient, ?DATA_IDX), WorkersList = [Worker || {Worker, _Val} <- maps:to_list(WorkersMap)], DistributedBehaviorFunc = fun workerFederatedServer:controller/2, + NumOfFedServers = ets:lookup_element(ClientEtsRef, num_of_fed_servers, ?DATA_IDX), + UpdatedNumOfFedServers = NumOfFedServers + 1, + ets:update_element(ClientEtsRef, num_of_fed_servers, {?DATA_IDX, UpdatedNumOfFedServers}), DistributedWorkerData = {_ServerName = WorkerName , _Args = DistributedSystemArgs, _Token = DistributedSystemToken , _WorkersList = WorkersList} end, {DistributedBehaviorFunc , DistributedWorkerData}. From c9f5b05586ee485a4608aebe0788ee9acf7c98e2 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Mon, 20 May 2024 17:41:48 +0000 Subject: [PATCH 16/52] [W2W] WIP --- src_erl/NerlnetApp/src/Client/clientStatem.erl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 333f1f61..847666a8 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -269,17 +269,19 @@ training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = {keep_state, State}; training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> - {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + {SourceName, ClientName, WorkerName} = binary_to_term(Data), ClientStatsEts = get(client_stats_ets), WorkersOfThisClient = ets:lookup_element(EtsRef, workersNames, ?DATA_IDX), NumOfTrainingWorkers = ets:lookup_element(EtsRef, training_workers, ?DATA_IDX), - WorkerOfThisClient = lists:member(WorkerName, WorkersOfThisClient), + io:format("Client ~p received end_stream to worker ~p , remaining training workers ~p~n",[ClientName, WorkerName, NumOfTrainingWorkers]), + WorkerOfThisClient = lists:member(list_to_atom(WorkerName), WorkersOfThisClient), if WorkerOfThisClient -> stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), gen_statem:cast(WorkerPid, {end_stream, SourceName}), UpdatedNumOfTrainingWorkers = NumOfTrainingWorkers - 1, + io:format("UpdatedNumOfTrainingWorkers = ~p~n",[UpdatedNumOfTrainingWorkers]), ets:update_element(EtsRef, training_workers, {?DATA_IDX, UpdatedNumOfTrainingWorkers}), case UpdatedNumOfTrainingWorkers of 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); From 41dbd1bc6d9ed4478a0ce8b04e0ba5e1fcbbaf82 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Mon, 20 May 2024 22:04:00 +0000 Subject: [PATCH 17/52] [W2W] WIP --- .../Bridge/onnWorkers/workerFederatedClient.erl | 8 +++++++- .../Bridge/onnWorkers/workerFederatedServer.erl | 3 +++ .../src/Bridge/onnWorkers/workerGeneric.erl | 4 +++- src_erl/NerlnetApp/src/Client/clientStatem.erl | 14 +++++++++----- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 8ca2cc00..1158d66d 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -99,8 +99,13 @@ end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [S ThisEts = get_this_client_ets(GenWorkerEts), CastingSources = ets:lookup_element(ThisEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), NewCastingSources = CastingSources -- [SourceName], + io:format("NewCastingSources = ~p~n", [NewCastingSources]), case NewCastingSources of - [] -> ets:update_element(ThisEts, stream_occuring , {?ETS_KEYVAL_VAL_IDX, false}); + [] -> ets:update_element(ThisEts, stream_occuring , {?ETS_KEYVAL_VAL_IDX, false}), + ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), + W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:send_message(W2WPid, MyName, ServerName , {worker_done, []}); _ -> ets:update_element(ThisEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}) end. @@ -131,6 +136,7 @@ post_idle({GenWorkerEts, _WorkerData}) -> pre_train({GenWorkerEts, _NerlTensorWeights}) -> ThisEts = get_this_client_ets(GenWorkerEts), StreamOccuring = ets:lookup_element(ThisEts, stream_occuring, ?ETS_KEYVAL_VAL_IDX), + % io:format("StreamOccuring = ~p~n", [StreamOccuring]), case StreamOccuring of true -> ThisEts = get_this_client_ets(GenWorkerEts), diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 01e13a20..323cfc46 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -136,15 +136,18 @@ post_train({GenWorkerEts, WorkerData}) when length(WorkerData) == 0 -> % WorkerD ThisEts = get_this_server_ets(GenWorkerEts), FedServerEts = get(fed_server_ets), NumOfTrainingWorkers = length(ets:lookup_element(ThisEts, training_workers, ?ETS_KEYVAL_VAL_IDX)), + io:format("NumOfTrainingWorkers = ~p~n",[NumOfTrainingWorkers]), W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), InboxQueue = w2wCom:get_all_messages(W2WPid), MessagesList = queue:to_list(InboxQueue), ReceivedWeights = [WorkersWeights || {_WorkerName, {WorkersWeights, _BinaryType}} <- MessagesList], CurrWorkersWeightsList = ets:lookup_element(FedServerEts, weights_list, ?ETS_KEYVAL_VAL_IDX), TotalWorkersWeights = CurrWorkersWeightsList ++ ReceivedWeights, + io:format("Num of TotalWorkersWeights = ~p~n",[length(TotalWorkersWeights)]), case length(TotalWorkersWeights) == NumOfTrainingWorkers of % ? Why not timeout true -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), + io:format("Averaging model weights...~n"), {CurrentModelWeights, BinaryType} = nerlNIF:call_to_get_weights(ModelID), FedServerName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), AllWorkersWeightsList = TotalWorkersWeights ++ [CurrentModelWeights], diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index e65bac0b..568bdc2d 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -243,7 +243,8 @@ train(cast, {set_weights,Ret_weights_list}, State = #workerGeneric_state{modelID %logger:notice("####end set weights train####~n"), {next_state, train, State}; -train(cast, {post_train_update}, State = #workerGeneric_state{distributedBehaviorFunc = DistributedBehaviorFunc}) -> +train(cast, {post_train_update}, State = #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> + io:format("Worker ~p got post_train_update~n",[MyName]), DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), {next_state, train, State}; @@ -253,6 +254,7 @@ train(cast, {worker_done}, State = #workerGeneric_state{myName = MyName , distri {next_state, idle, State}; train(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + io:format("Worker ~p got start_stream~n",[MyName]), DistributedBehaviorFunc(start_stream, {get(generic_worker_ets), [SourceName]}), {next_state, train, State}; diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 847666a8..44de04bb 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -86,7 +86,7 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh ets:insert(EtsRef, {sha_to_models_map , ShaToModelArgsMap}), ets:insert(EtsRef, {w2wcom_pids, #{}}), ets:insert(EtsRef, {all_workers_done, false}), - ets:insert(EtsRef, {num_of_fed_servers, 0}), + ets:insert(EtsRef, {num_of_fed_servers, 0}), % Will stay 0 if non-federated {MyRouterHost,MyRouterPort} = nerl_tools:getShortPath(MyName,?MAIN_SERVER_ATOM, NerlnetGraph), ets:insert(EtsRef, {my_router,{MyRouterHost,MyRouterPort}}), clientWorkersFunctions:create_workers(MyName , EtsRef , ShaToModelArgsMap , EtsStats), @@ -95,8 +95,8 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh Pids = [clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName) || WorkerName <- WorkersNames], [gen_statem:cast(WorkerPid, {pre_idle}) || WorkerPid <- Pids], NumOfFedServers = ets:lookup_element(EtsRef, num_of_fed_servers, ?DATA_IDX), % When non-federated exp this value is 0 - ets:insert(EtsRef, {num_of_training_workers, length(ClientWorkers) - NumOfFedServers}), - ets:insert(EtsRef, {training_workers, 0}), % will be updated in idle -> training + ets:insert(EtsRef, {num_of_training_workers, length(ClientWorkers) - NumOfFedServers}), % This number will not change + ets:insert(EtsRef, {training_workers, 0}), % will be updated in idle -> training & end_stream % update dictionary WorkersEts = ets:lookup_element(EtsRef , workers_ets , ?DATA_IDX), put(workers_ets, WorkersEts), @@ -179,6 +179,7 @@ idle(cast, _In = {statistics}, State = #client_statem_state{ myName = MyName, et stats:increment_messages_sent(ClientStatsEts), {next_state, idle, State}; +% Main Server triggers this state idle(cast, In = {training}, State = #client_statem_state{myName = _MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -187,7 +188,7 @@ idle(cast, In = {training}, State = #client_statem_state{myName = _MyName, etsRe cast_message_to_workers(EtsRef, MessageToCast), ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, false}), NumOfTrainingWorkers = ets:lookup_element(EtsRef, num_of_training_workers, ?DATA_IDX), - ets:update_element(EtsRef, training_workers, {?DATA_IDX, NumOfTrainingWorkers}), + ets:update_element(EtsRef, training_workers, {?DATA_IDX, NumOfTrainingWorkers}), % Reset the number of training workers {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = clientWorkersFunctions:get_workers_names(EtsRef), nextState = training}}; idle(cast, In = {predict}, State = #client_statem_state{etsRef = EtsRef}) -> @@ -268,6 +269,7 @@ training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = gen_statem:cast(WorkerPid, {start_stream, SourceName}), {keep_state, State}; +% ************* NEW *************** training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, ClientName, WorkerName} = binary_to_term(Data), ClientStatsEts = get(client_stats_ets), @@ -279,7 +281,8 @@ training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = E stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), - gen_statem:cast(WorkerPid, {end_stream, SourceName}), + io:format("Worker ~p Pid is ~p~n",[WorkerName, WorkerPid]), + gen_statem:cast(WorkerPid, {end_stream, SourceName}), % WHY THIS IS NOT WORKING???? UpdatedNumOfTrainingWorkers = NumOfTrainingWorkers - 1, io:format("UpdatedNumOfTrainingWorkers = ~p~n",[UpdatedNumOfTrainingWorkers]), ets:update_element(EtsRef, training_workers, {?DATA_IDX, UpdatedNumOfTrainingWorkers}), @@ -297,6 +300,7 @@ training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), MessageToCast = {idle}, WorkersDone = ets:lookup_element(EtsRef , all_workers_done , ?DATA_IDX), + io:format("Client ~p Workers Done? ~p~n",[MyName, WorkersDone]), case WorkersDone of true -> cast_message_to_workers(EtsRef, MessageToCast), Workers = clientWorkersFunctions:get_workers_names(EtsRef), From 898002824df0d807e782651efdc99df1da9bc9b6 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 21 May 2024 10:39:07 +0000 Subject: [PATCH 18/52] [W2W] WIP --- .../src/Bridge/onnWorkers/w2wCom.erl | 52 ++++++++----- .../src/Bridge/onnWorkers/w2wCom.hrl | 9 ++- .../onnWorkers/workerFederatedClient.erl | 73 +++++++++-------- .../onnWorkers/workerFederatedServer.erl | 78 ++++++++++++------- .../src/Bridge/onnWorkers/workerGeneric.erl | 63 +++++++++++---- .../src/Bridge/onnWorkers/workerNN.erl | 2 - .../NerlnetApp/src/Client/clientStatem.erl | 34 ++++---- .../NerlnetApp/src/Source/sourceStatem.erl | 14 ++-- 8 files changed, 199 insertions(+), 126 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl index aca88671..d62c3fcd 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl @@ -5,18 +5,20 @@ -export([start_link/1]). -export([init/1, handle_cast/2, handle_call/3]). --export([send_message/4, get_all_messages/1 , sync_inbox/1]). % methods that are used by worker +-export([send_message/4, send_message_with_event/5, get_all_messages/1 , sync_inbox/1, sync_inbox_no_limit/1]). % methods that are used by worker --define(ETS_KEYVAL_VAL_IDX, 2). --define(SYNC_INBOX_TIMEOUT, 30000). % 30 seconds --define(DEFAULT_SYNC_INBOX_BUSY_WAITING_SLEEP, 5). % 5 milliseconds + +setup_logger(Module) -> + logger:set_handler_config(default, formatter, {logger_formatter, #{}}), + logger:set_module_level(Module, all). %% @doc Spawns the server and registers the local name (unique) -spec(start_link(args) -> - {ok, Pid :: pid()} | ignore | {error, Reason :: term()}). + {ok, Pid :: pid()} | ignore | {error, Reason :: term()}). start_link(Args = {WorkerName, _ClientStatemPid}) -> - {ok,Gen_Server_Pid} = gen_server:start_link({local, WorkerName}, ?MODULE, Args, []), - Gen_Server_Pid. + setup_logger(?MODULE), + {ok,Gen_Server_Pid} = gen_server:start_link({local, WorkerName}, ?MODULE, Args, []), + Gen_Server_Pid. init({WorkerName, MyClientPid}) -> InboxQueue = queue:new(), @@ -36,28 +38,22 @@ handle_cast(Msg, State) -> io:format("@w2wCom: Wrong message received ~p~n", [Msg]), {noreply, State}. -% This handler also triggers the state machine during training state -handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, {post_train_update, Data}}, _From, State) -> +handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, {msg_with_event, Event, Data}}, _From, State) -> case get(worker_name) of ThisWorkerName -> ok; _ -> throw({error, "The provided worker name is not this worker"}) end, - % Saved messages are of the form: {FromWorkerName, , Data} - Message = {FromWorkerName, Data}, - add_msg_to_inbox_queue(Message), - gen_server:cast(get(gen_worker_pid), {post_train_update}), - {reply, {ok, post_train_update}, State}; - -handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, {worker_done, Data}}, _From, State) -> - case get(worker_name) of - ThisWorkerName -> ok; - _ -> throw({error, "The provided worker name is not this worker"}) + GenWorkerPid = get(gen_worker_pid), + case Event of + post_train_update -> gen_statem:cast(GenWorkerPid, {post_train_update, Data}); + worker_done -> gen_statem:cast(GenWorkerPid, {worker_done, Data}); + start_stream -> gen_statem:cast(GenWorkerPid, {start_stream, Data}); % Data is [SourceName] + end_stream -> gen_statem:cast(GenWorkerPid, {end_stream, Data}) % Data is [SourceName] end, % Saved messages are of the form: {FromWorkerName, , Data} Message = {FromWorkerName, Data}, add_msg_to_inbox_queue(Message), - gen_server:cast(get(gen_worker_pid), {worker_done}), - {reply, {ok, worker_done}, State}; + {reply, {ok, Event}, State}; % Received messages are of the form: {worker_to_worker_msg, FromWorkerName, ThisWorkerName, Data} handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, Data}, _From, State) -> @@ -116,6 +112,16 @@ send_message(W2WPid, FromWorker, TargetWorker, Data) -> {ok, MyClient} = gen_server:call(W2WPid, {get_client_pid}), gen_statem:cast(MyClient, Msg). +send_message_with_event(W2WPid, FromWorker, TargetWorker, Event, Data) -> + ValidEvent = lists:member(Event, ?SUPPORTED_EVENTS), + if ValidEvent -> ok; + true -> ?LOG_ERROR("Event ~p is not supported!!",[Event]), + throw({error, "The provided event is not supported"}) + end, + Msg = {?W2WCOM_ATOM, FromWorker, TargetWorker, {msg_with_event, Event, Data}}, + {ok, MyClient} = gen_server:call(W2WPid, {get_client_pid}), + gen_statem:cast(MyClient, Msg). + timeout_throw(Timeout) -> receive @@ -128,6 +134,10 @@ sync_inbox(W2WPid) -> TimeoutPID = spawn(fun() -> timeout_throw(?SYNC_INBOX_TIMEOUT) end), sync_inbox(TimeoutPID , W2WPid). +sync_inbox_no_limit(W2WPid) -> + TimeoutPID = spawn(fun() -> timeout_throw(?SYNC_INBOX_TIMEOUT_NO_LIMIT) end), + sync_inbox(TimeoutPID , W2WPid). + sync_inbox(TimeoutPID, W2WPid) -> timer:sleep(?DEFAULT_SYNC_INBOX_BUSY_WAITING_SLEEP), {ok , IsInboxEmpty} = gen_server:call(W2WPid, {is_inbox_empty}), diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.hrl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.hrl index a76172fa..1fd26762 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.hrl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.hrl @@ -1,4 +1,11 @@ +-include_lib("kernel/include/logger.hrl"). +-include("workerDefinitions.hrl"). + -define(W2WCOM_INBOX_Q_ATOM, worker_to_worker_inbox_queue). -define(W2WCOM_ATOM, worker_to_worker_msg). --define(W2WCOM_TOKEN_CAST_ATOM, worker_to_worker_token_cast). \ No newline at end of file +-define(W2WCOM_TOKEN_CAST_ATOM, worker_to_worker_token_cast). +-define(SYNC_INBOX_TIMEOUT, 30000). % 30 seconds +-define(SYNC_INBOX_TIMEOUT_NO_LIMIT, 36000000). % 36000 seconds = 10 hours , no limit +-define(DEFAULT_SYNC_INBOX_BUSY_WAITING_SLEEP, 5). % 5 milliseconds +-define(SUPPORTED_EVENTS , [post_train_update, worker_done, start_stream, end_stream]). \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 1158d66d..3405b46d 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -3,7 +3,6 @@ -export([controller/2]). -include("/usr/local/lib/nerlnet-lib/NErlNet/src_erl/NerlnetApp/src/nerl_tools.hrl"). --include("workerDefinitions.hrl"). -include("w2wCom.hrl"). -import(nerlNIF, [call_to_get_weights/2, call_to_set_weights/2]). @@ -14,16 +13,15 @@ controller(FuncName, {GenWorkerEts, WorkerData}) -> case FuncName of - init -> init({GenWorkerEts, WorkerData}); - pre_idle -> pre_idle({GenWorkerEts, WorkerData}); - post_idle -> post_idle({GenWorkerEts, WorkerData}); - pre_train -> pre_train({GenWorkerEts, WorkerData}); - post_train -> post_train({GenWorkerEts, WorkerData}); - pre_predict -> pre_predict({GenWorkerEts, WorkerData}); - post_predict -> post_predict({GenWorkerEts, WorkerData}); - start_stream -> start_stream({GenWorkerEts, WorkerData}); - end_stream -> end_stream({GenWorkerEts, WorkerData}); - worker_done -> worker_done({GenWorkerEts, WorkerData}) + init -> init({GenWorkerEts, WorkerData}); + pre_idle -> pre_idle({GenWorkerEts, WorkerData}); + post_idle -> post_idle({GenWorkerEts, WorkerData}); + pre_train -> pre_train({GenWorkerEts, WorkerData}); + post_train -> post_train({GenWorkerEts, WorkerData}); + pre_predict -> pre_predict({GenWorkerEts, WorkerData}); + post_predict -> post_predict({GenWorkerEts, WorkerData}); + start_stream -> start_stream({GenWorkerEts, WorkerData}); + end_stream -> end_stream({GenWorkerEts, WorkerData}) end. get_this_client_ets(GenWorkerEts) -> @@ -56,7 +54,7 @@ init({GenWorkerEts, WorkerData}) -> % create fields in this ets ets:insert(FederatedClientEts, {my_token, Token}), ets:insert(FederatedClientEts, {my_name, MyName}), - ets:insert(FederatedClientEts, {server_name, none}), % update later + ets:insert(FederatedClientEts, {server_name, []}), % update later ets:insert(FederatedClientEts, {sync_count, 0}), ets:insert(FederatedClientEts, {server_update, false}), ets:insert(FederatedClientEts, {handshake_done, false}), @@ -85,28 +83,36 @@ handshake(FedClientEts) -> end, lists:foreach(Func, MessagesList). -start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName] - SourceName = hd(WorkerData), - ThisEts = get_this_client_ets(GenWorkerEts), - ets:update_element(ThisEts, stream_occuring , {?ETS_KEYVAL_VAL_IDX, true}), - CastingSources = ets:lookup_element(ThisEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), - NewCastingSources = CastingSources ++ [SourceName], - ets:update_element(ThisEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}). - % ***** Add SourcesList ***** +start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName, State] + [_SourceName, State] = WorkerData, + case State of + train -> + ThisEts = get_this_client_ets(GenWorkerEts), + MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), + ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), + W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), + case length(CastingSources) of % Send to server an updater after got start_stream from the first source + 1 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, []); + _ -> ok + end; + predict -> ok + end. end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName] - SourceName = hd(WorkerData), - ThisEts = get_this_client_ets(GenWorkerEts), - CastingSources = ets:lookup_element(ThisEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), - NewCastingSources = CastingSources -- [SourceName], - io:format("NewCastingSources = ~p~n", [NewCastingSources]), - case NewCastingSources of - [] -> ets:update_element(ThisEts, stream_occuring , {?ETS_KEYVAL_VAL_IDX, false}), - ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), - W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:send_message(W2WPid, MyName, ServerName , {worker_done, []}); - _ -> ets:update_element(ThisEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}) + [_SourceName, State] = WorkerData, + case State of + train -> + ThisEts = get_this_client_ets(GenWorkerEts), + MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), + ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), + W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), + case length(CastingSources) of % Send to server an updater after got start_stream from the first source + 0 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, []); + _ -> ok + end; + predict -> ok end. @@ -146,7 +152,7 @@ pre_train({GenWorkerEts, _NerlTensorWeights}) -> MaxSyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_max_count, ?ETS_KEYVAL_VAL_IDX), if SyncCount == MaxSyncCount -> W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:sync_inbox(W2WPid), % waiting for server to average the weights and send it + w2wCom:sync_inbox_no_limit(W2WPid), % waiting for server to average the weights and send it InboxQueue = w2wCom:get_all_messages(W2WPid), [UpdateWeightsMsg] = queue:to_list(InboxQueue), {_FedServer , {update_weights, UpdatedWeights}} = UpdateWeightsMsg, @@ -184,6 +190,5 @@ pre_predict({_GenWorkerEts, WorkerData}) -> WorkerData. %% nothing? post_predict(Data) -> Data. -worker_done({_GenWorkerEts, _WorkerData}) -> ok. diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 323cfc46..c174f5ee 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -2,7 +2,6 @@ -export([controller/2]). --include("workerDefinitions.hrl"). -include("w2wCom.hrl"). -import(nerlNIF,[nerltensor_scalar_multiplication_nif/3, call_to_get_weights/1, call_to_set_weights/2]). @@ -20,16 +19,15 @@ controller(FuncName, {GenWorkerEts, WorkerData}) -> case FuncName of - init -> init({GenWorkerEts, WorkerData}); - pre_idle -> pre_idle({GenWorkerEts, WorkerData}); - post_idle -> post_idle({GenWorkerEts, WorkerData}); - pre_train -> pre_train({GenWorkerEts, WorkerData}); - post_train -> post_train({GenWorkerEts, WorkerData}); - pre_predict -> pre_predict({GenWorkerEts, WorkerData}); - post_predict -> post_predict({GenWorkerEts, WorkerData}); - start_stream -> start_stream({GenWorkerEts, WorkerData}); - end_stream -> end_stream({GenWorkerEts, WorkerData}); - worker_done -> worker_done({GenWorkerEts, WorkerData}) + init -> init({GenWorkerEts, WorkerData}); + pre_idle -> pre_idle({GenWorkerEts, WorkerData}); + post_idle -> post_idle({GenWorkerEts, WorkerData}); + pre_train -> pre_train({GenWorkerEts, WorkerData}); + post_train -> post_train({GenWorkerEts, WorkerData}); + pre_predict -> pre_predict({GenWorkerEts, WorkerData}); + post_predict -> post_predict({GenWorkerEts, WorkerData}); + start_stream -> start_stream({GenWorkerEts, WorkerData}); + end_stream -> end_stream({GenWorkerEts, WorkerData}) end. @@ -64,7 +62,7 @@ init({GenWorkerEts, WorkerData}) -> ets:insert(FederatedServerEts, {w2wcom_pid, W2WPid}), ets:insert(FederatedServerEts, {broadcast_workers_list, BroadcastWorkers}), ets:insert(FederatedServerEts, {fed_clients, []}), - ets:insert(FederatedServerEts, {training_workers , []}), + ets:insert(FederatedServerEts, {active_workers , []}), ets:insert(FederatedServerEts, {sync_count, 0}), ets:insert(FederatedServerEts, {my_name, MyName}), ets:insert(FederatedServerEts, {token , Token}), @@ -72,20 +70,48 @@ init({GenWorkerEts, WorkerData}) -> put(fed_server_ets, FederatedServerEts). -start_stream({_GenWorkerEts, _WorkerData}) -> ok. +start_stream({GenWorkerEts, _WorkerData}) -> + FedServerEts = get_this_server_ets(GenWorkerEts), + W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:sync_inbox(W2WPid), + InboxQueue = w2wCom:get_all_messages(W2WPid), + [Message] = queue:to_list(InboxQueue), + {FromFedClient , [_SourceName]} = Message, + ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + UpdatedActiveWorkers = ActiveWorkers ++ [FromFedClient], + ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), + LengthFedClients = length(ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), + case length(UpdatedActiveWorkers) of + LengthFedClients -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), + Data = {MyName, ClientName, MyName}, % Mimic source behavior to register as an active worker for the client + gen_server:cast(ClientPid, {start_stream, term_to_binary(Data)}); + _ -> ok + end. -end_stream({_GenWorkerEts, _WorkerData}) -> ok. -pre_idle({_GenWorkerEts, _WorkerName}) -> ok. +end_stream({GenWorkerEts, _WorkerData}) -> + FedServerEts = get_this_server_ets(GenWorkerEts), + W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:sync_inbox(W2WPid), + InboxQueue = w2wCom:get_all_messages(W2WPid), + [Message] = queue:to_list(InboxQueue), + {FromFedClient , [_SourceName]} = Message, + ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + UpdatedActiveWorkers = ActiveWorkers -- [FromFedClient], + ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), + case length(UpdatedActiveWorkers) of + 0 -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), + Data = {MyName, ClientName, MyName}, % Mimic source behavior to register as an active worker for the client + gen_server:cast(ClientPid, {end_stream, term_to_binary(Data)}); + _ -> ok + end. -worker_done({GenWorkerEts, WorkerData}) -> - WorkerName = hd(WorkerData), - ThisEts = get_this_server_ets(GenWorkerEts), - TrainingWorkers = ets:lookup_element(ThisEts, training_workers, ?ETS_KEYVAL_VAL_IDX), - UpdatedTrainingWorkers = lists:delete(WorkerName, TrainingWorkers), - io:format("Worker ~p Done, UpdatedTrainingWorkers = ~p~n", [WorkerName, UpdatedTrainingWorkers]), - ets:update_element(ThisEts, training_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedTrainingWorkers}). +pre_idle({_GenWorkerEts, _WorkerName}) -> ok. % Extract all workers in nerlnet network @@ -119,8 +145,6 @@ post_idle({GenWorkerEts, _WorkerName}) -> w2wCom:send_message(W2WPid, FedServerName, FedClient, {handshake_done, MyToken}) end, lists:foreach(MsgFunc, MessagesList), - UpdatedTrainingWorkers = ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX), - ets:update_element(FedServerEts, training_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedTrainingWorkers}), ets:update_element(GenWorkerEts, handshake_done, {?ETS_KEYVAL_VAL_IDX, true}); true -> ok end. @@ -135,16 +159,14 @@ pre_train({_GenWorkerEts, _WorkerData}) -> ok. post_train({GenWorkerEts, WorkerData}) when length(WorkerData) == 0 -> % WorkerData = [] ThisEts = get_this_server_ets(GenWorkerEts), FedServerEts = get(fed_server_ets), - NumOfTrainingWorkers = length(ets:lookup_element(ThisEts, training_workers, ?ETS_KEYVAL_VAL_IDX)), - io:format("NumOfTrainingWorkers = ~p~n",[NumOfTrainingWorkers]), W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), InboxQueue = w2wCom:get_all_messages(W2WPid), MessagesList = queue:to_list(InboxQueue), ReceivedWeights = [WorkersWeights || {_WorkerName, {WorkersWeights, _BinaryType}} <- MessagesList], CurrWorkersWeightsList = ets:lookup_element(FedServerEts, weights_list, ?ETS_KEYVAL_VAL_IDX), TotalWorkersWeights = CurrWorkersWeightsList ++ ReceivedWeights, - io:format("Num of TotalWorkersWeights = ~p~n",[length(TotalWorkersWeights)]), - case length(TotalWorkersWeights) == NumOfTrainingWorkers of % ? Why not timeout + NumOfActiveWorkers = length(ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX)), + case length(TotalWorkersWeights) == NumOfActiveWorkers of % ? Why not timeout true -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), io:format("Averaging model weights...~n"), diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 568bdc2d..011276c8 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -161,17 +161,32 @@ idle(cast, _Param, State = #workerGeneric_state{myName = MyName}) -> %% Waiting for receiving results or loss function %% Got nan or inf from loss function - Error, loss function too big for double -wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState}) -> +wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc}) -> stats:increment_by_value(get(worker_stats_ets), nan_loss_count, 1), gen_statem:cast(get(client_pid),{loss, MyName , SourceName ,nan , TrainTime ,BatchID}), - {next_state, NextState, State}; + DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients + UpdatedNextState = + case NextState of + end_stream -> stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), + update_client_avilable_worker(MyName), + idle; + _ -> train + end, + {next_state, UpdatedNextState, State}; wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc}) -> BatchTimeStamp = erlang:system_time(nanosecond), gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , BatchID , BatchTimeStamp}), DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients - {next_state, NextState, State}; + UpdatedNextState = + case NextState of + end_stream -> stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), + update_client_avilable_worker(MyName), + idle; + _ -> train + end, + {next_state, UpdatedNextState, State}; wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> BatchTimeStamp = erlang:system_time(nanosecond), @@ -183,12 +198,20 @@ wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , S {next_state, NextState, State} end; -wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> +wait(cast, {end_stream , _Data}, State= #workerGeneric_state{myName = MyName}) -> + %logger:notice("Waiting, next state - idle"), + io:format("Worker ~p @wait got end_stream~n",[MyName]), + {next_state, wait, State#workerGeneric_state{nextState = end_stream}}; + +wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc, nextState = NextState}) -> %logger:notice("Waiting, next state - idle"), io:format("Worker ~p @wait is going to state idle...~n",[MyName]), - update_client_avilable_worker(MyName), DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), train}), - {next_state, wait, State#workerGeneric_state{nextState = idle}}; + case NextState of + end_stream -> {next_state, wait, State#workerGeneric_state{nextState = end_stream}}; + _ -> update_client_avilable_worker(MyName), + {next_state, wait, State#workerGeneric_state{nextState = idle}} + end; wait(cast, {training}, State) -> %logger:notice("Waiting, next state - train"), @@ -230,7 +253,7 @@ train(cast, {sample, BatchID ,{<<>>, _Type}}, State) -> train(cast, {sample, SourceName ,BatchID ,{NerlTensorOfSamples, NerlTensorType}}, State = #workerGeneric_state{modelID = ModelId, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData, myName = MyName}) -> % NerlTensor = nerltensor_conversion({NerlTensorOfSamples, Type}, erl_float), MyPid = self(), - DistributedBehaviorFunc(pre_train, {get(generic_worker_ets),DistributedWorkerData}), + DistributedBehaviorFunc(pre_train, {get(generic_worker_ets),DistributedWorkerData}), % Here the model can be updated by the federated server WorkersStatsEts = get(worker_stats_ets), stats:increment_by_value(WorkersStatsEts , batches_received_train , 1), _Pid = spawn(fun()-> nerlNIF:call_to_train(ModelId , {NerlTensorOfSamples, NerlTensorType} ,MyPid , BatchID , SourceName) end), @@ -255,12 +278,12 @@ train(cast, {worker_done}, State = #workerGeneric_state{myName = MyName , distri train(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> io:format("Worker ~p got start_stream~n",[MyName]), - DistributedBehaviorFunc(start_stream, {get(generic_worker_ets), [SourceName]}), + stream_handler(start_stream, train, SourceName, DistributedBehaviorFunc), {next_state, train, State}; train(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> io:format("Worker ~p got end_stream~n",[MyName]), - DistributedBehaviorFunc(end_stream, {get(generic_worker_ets), [SourceName]}), + stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), {next_state, train, State}; train(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> @@ -292,12 +315,14 @@ predict(cast, {sample , SourceName , BatchID , {PredictBatchTensor, Type}}, Stat {next_state, wait, State#workerGeneric_state{nextState = predict , currentBatchID = BatchID}}; predict(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - % DistributedBehaviorFunc(start_stream, {get(generic_worker_ets), [SourceName]}), - {next_state, predict, State}; + io:format("Worker ~p got start_stream~n",[MyName]), + stream_handler(start_stream, predict, SourceName, DistributedBehaviorFunc), + {next_state, train, State}; predict(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - % DistributedBehaviorFunc(end_stream, {get(generic_worker_ets), [SourceName]}), - {next_state, predict, State}; + io:format("Worker ~p got end_stream~n",[MyName]), + stream_handler(end_stream, predict, SourceName, DistributedBehaviorFunc), + {next_state, train, State}; predict(cast, {worker_done}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> % DistributedBehaviorFunc(worker_done, {get(generic_worker_ets),[MyName]}), @@ -322,3 +347,15 @@ worker_controller_message_queue(ReceiveData) -> worker_controller_empty_message_queue() -> ets:update_element(get(generic_worker_ets), controller_message_q, {?ETS_KEYVAL_VAL_IDX , []}). + +stream_handler(StreamPhase , ModelPhase , SourceName , DistributedBehaviorFunc) -> + GenWorkerEts = get(generic_worker_ets), + ets:update_element(GenWorkerEts, stream_occuring , {?ETS_KEYVAL_VAL_IDX, true}), + CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), + NewCastingSources = + case StreamPhase of + start_stream -> CastingSources ++ [SourceName]; + end_stream -> CastingSources -- [SourceName] + end, + ets:update_element(GenWorkerEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}), + DistributedBehaviorFunc(StreamPhase, {GenWorkerEts, [SourceName , ModelPhase]}). \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerNN.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerNN.erl index b68d8a96..69018d6e 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerNN.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerNN.erl @@ -13,7 +13,6 @@ controller(FuncName, {GenWorkerEts, WorkerData}) -> post_predict -> post_predict({GenWorkerEts, WorkerData}); start_stream -> start_stream({GenWorkerEts, WorkerData}); end_stream -> end_stream({GenWorkerEts, WorkerData}); - worker_done -> worker_done({GenWorkerEts, WorkerData}); update -> update({GenWorkerEts, WorkerData}) end. @@ -37,6 +36,5 @@ start_stream({_GenWorkerEts, _WorkerData}) -> ok. end_stream({_GenWorkerEts, _WorkerData}) -> ok. -worker_done({_GenWorkerEts, _WorkerData}) -> ok. diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 44de04bb..5997d428 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -97,6 +97,7 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh NumOfFedServers = ets:lookup_element(EtsRef, num_of_fed_servers, ?DATA_IDX), % When non-federated exp this value is 0 ets:insert(EtsRef, {num_of_training_workers, length(ClientWorkers) - NumOfFedServers}), % This number will not change ets:insert(EtsRef, {training_workers, 0}), % will be updated in idle -> training & end_stream + ets:insert(EtsRef, {active_workers_sources_list, []}), % update dictionary WorkersEts = ets:lookup_element(EtsRef , workers_ets , ?DATA_IDX), put(workers_ets, WorkersEts), @@ -187,8 +188,6 @@ idle(cast, In = {training}, State = #client_statem_state{myName = _MyName, etsRe MessageToCast = {training}, cast_message_to_workers(EtsRef, MessageToCast), ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, false}), - NumOfTrainingWorkers = ets:lookup_element(EtsRef, num_of_training_workers, ?DATA_IDX), - ets:update_element(EtsRef, training_workers, {?DATA_IDX, NumOfTrainingWorkers}), % Reset the number of training workers {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = clientWorkersFunctions:get_workers_names(EtsRef), nextState = training}}; idle(cast, In = {predict}, State = #client_statem_state{etsRef = EtsRef}) -> @@ -262,6 +261,8 @@ training(cast, In = {sample,Body}, State = #client_statem_state{etsRef = EtsRef} % ************* NEW *************** training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), + ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, ListOfActiveWorkersSources ++ [{WorkerName, SourceName}]}), ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), @@ -273,24 +274,17 @@ training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, ClientName, WorkerName} = binary_to_term(Data), ClientStatsEts = get(client_stats_ets), - WorkersOfThisClient = ets:lookup_element(EtsRef, workersNames, ?DATA_IDX), - NumOfTrainingWorkers = ets:lookup_element(EtsRef, training_workers, ?DATA_IDX), - io:format("Client ~p received end_stream to worker ~p , remaining training workers ~p~n",[ClientName, WorkerName, NumOfTrainingWorkers]), - WorkerOfThisClient = lists:member(list_to_atom(WorkerName), WorkersOfThisClient), - if WorkerOfThisClient -> - stats:increment_messages_received(ClientStatsEts), - stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), - io:format("Worker ~p Pid is ~p~n",[WorkerName, WorkerPid]), - gen_statem:cast(WorkerPid, {end_stream, SourceName}), % WHY THIS IS NOT WORKING???? - UpdatedNumOfTrainingWorkers = NumOfTrainingWorkers - 1, - io:format("UpdatedNumOfTrainingWorkers = ~p~n",[UpdatedNumOfTrainingWorkers]), - ets:update_element(EtsRef, training_workers, {?DATA_IDX, UpdatedNumOfTrainingWorkers}), - case UpdatedNumOfTrainingWorkers of - 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); - _ -> ok end; - true -> ok - end, + ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), + UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, SourceName}], + ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), + io:format("Client ~p received end_stream to worker ~p , remaining training workers ~p~n",[ClientName, WorkerName , UpdatedListOfActiveWorkerSources]), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), + gen_statem:cast(WorkerPid, {end_stream, SourceName}), % WHY THIS IS NOT WORKING???? + case length(UpdatedListOfActiveWorkerSources) of + 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); + _ -> ok end, {keep_state, State}; diff --git a/src_erl/NerlnetApp/src/Source/sourceStatem.erl b/src_erl/NerlnetApp/src/Source/sourceStatem.erl index cadedcf9..4325d279 100644 --- a/src_erl/NerlnetApp/src/Source/sourceStatem.erl +++ b/src_erl/NerlnetApp/src/Source/sourceStatem.erl @@ -365,11 +365,11 @@ transmitter(TimeInterval_ms, SourceEtsRef, SourcePid ,ClientWorkerPairs, Batches ets:insert(TransmitterEts, {batches_skipped, 0}), ets:insert(TransmitterEts, {current_batch_id, 0}), TransmissionStart = erlang:timestamp(), - % Message to all workrers : "start_stream" + % Message to all workrers : "start_stream" , TRANSFER TO FUNCTIONS {RouterHost, RouterPort} = ets:lookup_element(TransmitterEts, my_router, ?DATA_IDX), - FuncStart = fun({ClientName, WorkerName}) -> - ToSend = {MyName, ClientName, WorkerName}, - io:format("~p sending start_stream to ~p of worker ~p~n",[MyName, ClientName, WorkerName]), + FuncStart = fun({ClientName, WorkerNameStr}) -> + ToSend = {MyName, ClientName, list_to_atom(WorkerNameStr)}, + io:format("~p sending start_stream to ~p of worker ~p~n",[MyName, ClientName, WorkerNameStr]), nerl_tools:http_router_request(RouterHost, RouterPort, [ClientName], atom_to_list(start_stream), ToSend) end, lists:foreach(FuncStart, ClientWorkerPairs), @@ -380,9 +380,9 @@ transmitter(TimeInterval_ms, SourceEtsRef, SourcePid ,ClientWorkerPairs, Batches _Default -> send_method_casting(TransmitterEts, TimeInterval_ms, ClientWorkerPairs, BatchesListToSend) end, % Message to workers : "end_stream" - FuncEnd = fun({ClientName, WorkerName}) -> - ToSend = {MyName, ClientName, WorkerName}, - io:format("~p sending end_stream to ~p of worker ~p~n",[MyName, ClientName, WorkerName]), + FuncEnd = fun({ClientName, WorkerNameStr}) -> + ToSend = {MyName, ClientName, list_to_atom(WorkerNameStr)}, + io:format("~p sending end_stream to ~p of worker ~p~n",[MyName, ClientName, WorkerNameStr]), nerl_tools:http_router_request(RouterHost, RouterPort, [ClientName], atom_to_list(end_stream), ToSend) end, lists:foreach(FuncEnd, ClientWorkerPairs), From 4c78db84374a6f37151ce9399909cd56c992dcbe Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 21 May 2024 10:45:56 +0000 Subject: [PATCH 19/52] [W2W] Fixed bugs --- src_erl/NerlnetApp/src/Client/clientStatem.erl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 5997d428..c23ef6cd 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -266,7 +266,7 @@ training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), gen_statem:cast(WorkerPid, {start_stream, SourceName}), {keep_state, State}; @@ -280,7 +280,7 @@ training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = E io:format("Client ~p received end_stream to worker ~p , remaining training workers ~p~n",[ClientName, WorkerName , UpdatedListOfActiveWorkerSources]), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), gen_statem:cast(WorkerPid, {end_stream, SourceName}), % WHY THIS IS NOT WORKING???? case length(UpdatedListOfActiveWorkerSources) of 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); @@ -350,7 +350,7 @@ predict(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), gen_statem:cast(WorkerPid, {start_stream, SourceName}), {keep_state, State}; @@ -359,7 +359,7 @@ predict(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = Et ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , list_to_atom(WorkerName)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), gen_statem:cast(WorkerPid, {end_stream, SourceName}), {keep_state, State}; From aa377a3dc3a994f054aaf30e7e92e952b32a84c1 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 21 May 2024 10:47:18 +0000 Subject: [PATCH 20/52] [W2W] Fixed bugs --- src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 011276c8..df69653a 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -73,6 +73,8 @@ init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData ets:insert(GenWorkerEts,{distributed_system_type, DistributedSystemType}), ets:insert(GenWorkerEts,{controller_message_q, []}), %% TODO Deprecated ets:insert(GenWorkerEts,{handshake_done, false}), + ets:insert(GenWorkerEts,{casting_sources, []}), + ets:insert(GenWorkerEts,{stream_occuring, false}), % Worker to Worker communication module - this is a gen_server From f8aed9026a69e8ecddc6bf050f6689658cc61428 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 21 May 2024 10:51:20 +0000 Subject: [PATCH 21/52] [W2W] Fixed bugs --- .../onnWorkers/workerFederatedServer.erl | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index c174f5ee..dcf9577f 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -75,10 +75,14 @@ start_stream({GenWorkerEts, _WorkerData}) -> W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), w2wCom:sync_inbox(W2WPid), InboxQueue = w2wCom:get_all_messages(W2WPid), - [Message] = queue:to_list(InboxQueue), - {FromFedClient , [_SourceName]} = Message, - ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - UpdatedActiveWorkers = ActiveWorkers ++ [FromFedClient], + MessagesList = queue:to_list(InboxQueue), + Func = fun({FromFedClient , _SourceName}) -> + ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + UpdatedActiveWorkers = ActiveWorkers ++ [FromFedClient], + ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}) + end, + lists:foreach(Func, MessagesList), + UpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), LengthFedClients = length(ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), case length(UpdatedActiveWorkers) of @@ -96,10 +100,14 @@ end_stream({GenWorkerEts, _WorkerData}) -> W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), w2wCom:sync_inbox(W2WPid), InboxQueue = w2wCom:get_all_messages(W2WPid), - [Message] = queue:to_list(InboxQueue), - {FromFedClient , [_SourceName]} = Message, - ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - UpdatedActiveWorkers = ActiveWorkers -- [FromFedClient], + MessagesList = queue:to_list(InboxQueue), + Func = fun({FromFedClient , _SourceName}) -> + ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + UpdatedActiveWorkers = ActiveWorkers ++ [FromFedClient], + ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}) + end, + lists:foreach(Func, MessagesList), + UpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), case length(UpdatedActiveWorkers) of 0 -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), From 08bd7cbcf52abdd2b3c1ec8de24963be2adb67a9 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 21 May 2024 10:52:39 +0000 Subject: [PATCH 22/52] [W2W] Fixed bugs --- src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl | 1 + 1 file changed, 1 insertion(+) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index df69653a..719e3729 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -56,6 +56,7 @@ init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData put(worker_stats_ets , WorkerStatsEts), SourceBatchesEts = ets:new(source_batches,[set]), put(source_batches_ets, SourceBatchesEts), + ets:insert(GenWorkerEts,{client_pid, ClientPid}), ets:insert(GenWorkerEts,{w2wcom_pid, W2WPid}), ets:insert(GenWorkerEts,{worker_name, WorkerName}), ets:insert(GenWorkerEts,{model_id, ModelID}), From 9e3bf0265f48c0103fb1bc2a3bbf53f5077e7bcb Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 21 May 2024 10:54:57 +0000 Subject: [PATCH 23/52] [W2W] Fixed bugs --- .../src/Bridge/onnWorkers/workerFederatedServer.erl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index dcf9577f..fa9454e6 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -112,8 +112,8 @@ end_stream({GenWorkerEts, _WorkerData}) -> case length(UpdatedActiveWorkers) of 0 -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), - ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), - Data = {MyName, ClientName, MyName}, % Mimic source behavior to register as an active worker for the client + % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), + Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client gen_server:cast(ClientPid, {end_stream, term_to_binary(Data)}); _ -> ok end. From daaf969fc49b71c838133cd3ca869fab52dd76a0 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 21 May 2024 10:55:59 +0000 Subject: [PATCH 24/52] [W2W] Fixed bugs --- .../src/Bridge/onnWorkers/workerFederatedServer.erl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index fa9454e6..78365bea 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -88,8 +88,8 @@ start_stream({GenWorkerEts, _WorkerData}) -> case length(UpdatedActiveWorkers) of LengthFedClients -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), - ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), - Data = {MyName, ClientName, MyName}, % Mimic source behavior to register as an active worker for the client + % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), + Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client gen_server:cast(ClientPid, {start_stream, term_to_binary(Data)}); _ -> ok end. From 37ddd83ef2d4f2e480d69481ccaf63a1c32455cd Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 21 May 2024 18:56:36 +0000 Subject: [PATCH 25/52] [W2W] Fixed bugs --- .../src/Bridge/onnWorkers/w2wCom.erl | 1 + .../onnWorkers/workerFederatedClient.erl | 6 +- .../onnWorkers/workerFederatedServer.erl | 94 +++++++++++-------- .../src/Bridge/onnWorkers/workerGeneric.erl | 6 +- .../NerlnetApp/src/Client/clientStatem.erl | 4 +- .../NerlnetApp/src/Source/sourceStatem.erl | 2 - 6 files changed, 65 insertions(+), 48 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl index d62c3fcd..470f7ba9 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl @@ -44,6 +44,7 @@ handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, {msg_with_event, Even _ -> throw({error, "The provided worker name is not this worker"}) end, GenWorkerPid = get(gen_worker_pid), + io:format("~p got message with event ~p from ~p~n", [ThisWorkerName, Event, FromWorkerName]), case Event of post_train_update -> gen_statem:cast(GenWorkerPid, {post_train_update, Data}); worker_done -> gen_statem:cast(GenWorkerPid, {worker_done, Data}); diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 3405b46d..1ea37b02 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -93,7 +93,8 @@ start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), case length(CastingSources) of % Send to server an updater after got start_stream from the first source - 1 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, []); + 1 -> io:format("~p sending start_stream msg to ~p~n",[MyName, ServerName]), + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, []); _ -> ok end; predict -> ok @@ -109,7 +110,8 @@ end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [S W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), case length(CastingSources) of % Send to server an updater after got start_stream from the first source - 0 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, []); + 0 -> io:format("~p sending end_stream msg to ~p~n",[MyName, ServerName]), + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, []); _ -> ok end; predict -> ok diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 78365bea..3cf0bf6b 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -72,50 +72,66 @@ init({GenWorkerEts, WorkerData}) -> start_stream({GenWorkerEts, _WorkerData}) -> FedServerEts = get_this_server_ets(GenWorkerEts), - W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:sync_inbox(W2WPid), - InboxQueue = w2wCom:get_all_messages(W2WPid), - MessagesList = queue:to_list(InboxQueue), - Func = fun({FromFedClient , _SourceName}) -> - ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - UpdatedActiveWorkers = ActiveWorkers ++ [FromFedClient], - ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}) - end, - lists:foreach(Func, MessagesList), - UpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), - LengthFedClients = length(ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), - case length(UpdatedActiveWorkers) of - LengthFedClients -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), - % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), - Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client - gen_server:cast(ClientPid, {start_stream, term_to_binary(Data)}); - _ -> ok + CurrUpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + CurrLengthFedClients = length(ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), + case length(CurrUpdatedActiveWorkers) of + CurrLengthFedClients -> ok; + _Else -> + W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:sync_inbox(W2WPid), + InboxQueue = w2wCom:get_all_messages(W2WPid), + MessagesList = queue:to_list(InboxQueue), + io:format("@FedServer MessagesList = ~p~n",[MessagesList]), + Func = fun({FromFedClient , _SourceName}) -> + ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + UpdatedActiveWorkers = ActiveWorkers ++ [FromFedClient], + ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}) + end, + lists:foreach(Func, MessagesList), + UpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), + LengthFedClients = length(ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), + io:format("@FedServer start_stream UpdatedActiveWorkers = ~p , Num Of Fed Clients = ~p~n",[UpdatedActiveWorkers, LengthFedClients]), + case length(UpdatedActiveWorkers) of + LengthFedClients -> io:format("*****Federated Server Is Done With start_stream *****~n"), + ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), + Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client + gen_server:cast(ClientPid, {start_stream, term_to_binary(Data)}); + _ -> start_stream({GenWorkerEts, []}) % If not all messages were received when inbox was synced + end end. end_stream({GenWorkerEts, _WorkerData}) -> FedServerEts = get_this_server_ets(GenWorkerEts), - W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:sync_inbox(W2WPid), - InboxQueue = w2wCom:get_all_messages(W2WPid), - MessagesList = queue:to_list(InboxQueue), - Func = fun({FromFedClient , _SourceName}) -> - ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - UpdatedActiveWorkers = ActiveWorkers ++ [FromFedClient], - ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}) - end, - lists:foreach(Func, MessagesList), - UpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), - case length(UpdatedActiveWorkers) of - 0 -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), - % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), - Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client - gen_server:cast(ClientPid, {end_stream, term_to_binary(Data)}); - _ -> ok + CurrUpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + case CurrUpdatedActiveWorkers of + [] -> ok; % if there are no active workers, no need to do anything + _Else -> + W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:sync_inbox(W2WPid), + InboxQueue = w2wCom:get_all_messages(W2WPid), + MessagesList = queue:to_list(InboxQueue), + Func = fun({FromFedClient , _SourceName}) -> + ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + UpdatedActiveWorkers = ActiveWorkers -- [FromFedClient], + ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}) + end, + lists:foreach(Func, MessagesList), + UpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), + io:format("@FedServer end_stream UpdatedActiveWorkers = ~p~n",[UpdatedActiveWorkers]), + case length(UpdatedActiveWorkers) of + 0 -> io:format("*****Federated Server Is Done With end_stream *****~n"), + ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), + Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client + gen_server:cast(ClientPid, {end_stream, term_to_binary(Data)}); + _ -> end_stream({GenWorkerEts, []}) % If not all messages were received when inbox was synced + end end. diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 719e3729..3f996c29 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -144,7 +144,7 @@ code_change(_OldVsn, StateName, State = #workerGeneric_state{}, _Extra) -> % Go from idle to train idle(cast, {training}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("@idle got training , Worker ~p is going to state idle...~n",[MyName]), + % io:format("@idle got training , Worker ~p is going to state idle...~n",[MyName]), worker_controller_empty_message_queue(), DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), train}), update_client_avilable_worker(MyName), @@ -152,14 +152,14 @@ idle(cast, {training}, State = #workerGeneric_state{myName = MyName , distribute % Go from idle to predict idle(cast, {predict}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("@idle got predict , Worker ~p is going to state predict...~n",[MyName]), + % io:format("@idle got predict , Worker ~p is going to state predict...~n",[MyName]), % worker_controller_empty_message_queue(), update_client_avilable_worker(MyName), DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), predict}), {next_state, predict, State#workerGeneric_state{lastPhase = predict}}; idle(cast, _Param, State = #workerGeneric_state{myName = MyName}) -> - io:format("@idle Worker ~p is going to state idle...~n",[MyName]), + % io:format("@idle Worker ~p is going to state idle...~n",[MyName]), {next_state, idle, State}. %% Waiting for receiving results or loss function diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index c23ef6cd..e8d52c41 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -130,7 +130,7 @@ waitforWorkers(cast, In = {stateChange,WorkerName}, State = #client_statem_state io:format("Client ~p is ready~n",[MyName]), stats:increment_messages_sent(ClientStatsEts), {next_state, NextState, State#client_statem_state{waitforWorkers = []}}; - _ -> io:format("Client ~p is waiting for workers ~p~n",[MyName,NewWaitforWorkers]), + _ -> %io:format("Client ~p is waiting for workers ~p~n",[MyName,NewWaitforWorkers]), {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = NewWaitforWorkers}} end; @@ -294,7 +294,7 @@ training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), MessageToCast = {idle}, WorkersDone = ets:lookup_element(EtsRef , all_workers_done , ?DATA_IDX), - io:format("Client ~p Workers Done? ~p~n",[MyName, WorkersDone]), + % io:format("Client ~p Workers Done? ~p~n",[MyName, WorkersDone]), case WorkersDone of true -> cast_message_to_workers(EtsRef, MessageToCast), Workers = clientWorkersFunctions:get_workers_names(EtsRef), diff --git a/src_erl/NerlnetApp/src/Source/sourceStatem.erl b/src_erl/NerlnetApp/src/Source/sourceStatem.erl index 4325d279..99b492a0 100644 --- a/src_erl/NerlnetApp/src/Source/sourceStatem.erl +++ b/src_erl/NerlnetApp/src/Source/sourceStatem.erl @@ -369,7 +369,6 @@ transmitter(TimeInterval_ms, SourceEtsRef, SourcePid ,ClientWorkerPairs, Batches {RouterHost, RouterPort} = ets:lookup_element(TransmitterEts, my_router, ?DATA_IDX), FuncStart = fun({ClientName, WorkerNameStr}) -> ToSend = {MyName, ClientName, list_to_atom(WorkerNameStr)}, - io:format("~p sending start_stream to ~p of worker ~p~n",[MyName, ClientName, WorkerNameStr]), nerl_tools:http_router_request(RouterHost, RouterPort, [ClientName], atom_to_list(start_stream), ToSend) end, lists:foreach(FuncStart, ClientWorkerPairs), @@ -382,7 +381,6 @@ transmitter(TimeInterval_ms, SourceEtsRef, SourcePid ,ClientWorkerPairs, Batches % Message to workers : "end_stream" FuncEnd = fun({ClientName, WorkerNameStr}) -> ToSend = {MyName, ClientName, list_to_atom(WorkerNameStr)}, - io:format("~p sending end_stream to ~p of worker ~p~n",[MyName, ClientName, WorkerNameStr]), nerl_tools:http_router_request(RouterHost, RouterPort, [ClientName], atom_to_list(end_stream), ToSend) end, lists:foreach(FuncEnd, ClientWorkerPairs), From d82f984dcc4884dfd1feb2b664ea1064c2a6cec4 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 21 May 2024 19:01:37 +0000 Subject: [PATCH 26/52] [W2W] Predict phase update --- .../src/Bridge/onnWorkers/workerGeneric.erl | 13 ++----------- src_erl/NerlnetApp/src/Client/clientStatem.erl | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 3f996c29..f314e9d4 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -158,7 +158,7 @@ idle(cast, {predict}, State = #workerGeneric_state{myName = MyName , distributed DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), predict}), {next_state, predict, State#workerGeneric_state{lastPhase = predict}}; -idle(cast, _Param, State = #workerGeneric_state{myName = MyName}) -> +idle(cast, _Param, State = #workerGeneric_state{myName = _MyName}) -> % io:format("@idle Worker ~p is going to state idle...~n",[MyName]), {next_state, idle, State}. @@ -253,7 +253,7 @@ train(cast, {sample, BatchID ,{<<>>, _Type}}, State) -> {next_state, train, State#workerGeneric_state{nextState = train , currentBatchID = BatchID}}; %% Change SampleListTrain to NerlTensor -train(cast, {sample, SourceName ,BatchID ,{NerlTensorOfSamples, NerlTensorType}}, State = #workerGeneric_state{modelID = ModelId, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData, myName = MyName}) -> +train(cast, {sample, SourceName ,BatchID ,{NerlTensorOfSamples, NerlTensorType}}, State = #workerGeneric_state{modelID = ModelId, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData, myName = _MyName}) -> % NerlTensor = nerltensor_conversion({NerlTensorOfSamples, Type}, erl_float), MyPid = self(), DistributedBehaviorFunc(pre_train, {get(generic_worker_ets),DistributedWorkerData}), % Here the model can be updated by the federated server @@ -274,11 +274,6 @@ train(cast, {post_train_update}, State = #workerGeneric_state{myName = MyName, d DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), {next_state, train, State}; -train(cast, {worker_done}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("Worker ~p got worker_done~n",[MyName]), - DistributedBehaviorFunc(worker_done, {get(generic_worker_ets),[MyName]}), - {next_state, idle, State}; - train(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> io:format("Worker ~p got start_stream~n",[MyName]), stream_handler(start_stream, train, SourceName, DistributedBehaviorFunc), @@ -327,10 +322,6 @@ predict(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = M stream_handler(end_stream, predict, SourceName, DistributedBehaviorFunc), {next_state, train, State}; -predict(cast, {worker_done}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - % DistributedBehaviorFunc(worker_done, {get(generic_worker_ets),[MyName]}), - {next_state, predict, State}; - predict(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> update_client_avilable_worker(MyName), DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), predict}), diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index e8d52c41..4019548b 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -345,8 +345,11 @@ predict(cast, In = {sample,Body}, State = #client_statem_state{etsRef = EtsRef}) end, {next_state, predict, State#client_statem_state{etsRef = EtsRef}}; +% ************* NEW *************** predict(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), + ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, ListOfActiveWorkersSources ++ [{WorkerName, SourceName}]}), ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), @@ -354,13 +357,21 @@ predict(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = gen_statem:cast(WorkerPid, {start_stream, SourceName}), {keep_state, State}; +% ************* NEW *************** predict(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> - {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + {SourceName, ClientName, WorkerName} = binary_to_term(Data), ClientStatsEts = get(client_stats_ets), + ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), + UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, SourceName}], + ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), + io:format("Client ~p received end_stream to worker ~p , remaining training workers ~p~n",[ClientName, WorkerName , UpdatedListOfActiveWorkerSources]), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), - gen_statem:cast(WorkerPid, {end_stream, SourceName}), + gen_statem:cast(WorkerPid, {end_stream, SourceName}), % WHY THIS IS NOT WORKING???? + case length(UpdatedListOfActiveWorkerSources) of + 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); + _ -> ok end, {keep_state, State}; predict(cast, In = {predictRes,WorkerName, SourceName ,{PredictNerlTensor, NetlTensorType} , TimeTook , BatchID , BatchTS}, State = #client_statem_state{myName = _MyName, etsRef = EtsRef}) -> From adcf220ce1baec1c289827b279fb2546ff7cb1f1 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 21 May 2024 19:42:09 +0000 Subject: [PATCH 27/52] [W2W] Fix bug --- .../NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index f314e9d4..b2048915 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -152,7 +152,7 @@ idle(cast, {training}, State = #workerGeneric_state{myName = MyName , distribute % Go from idle to predict idle(cast, {predict}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - % io:format("@idle got predict , Worker ~p is going to state predict...~n",[MyName]), + io:format("@idle got predict , Worker ~p is going to state predict...~n",[MyName]), % worker_controller_empty_message_queue(), update_client_avilable_worker(MyName), DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), predict}), @@ -208,12 +208,11 @@ wait(cast, {end_stream , _Data}, State= #workerGeneric_state{myName = MyName}) - wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc, nextState = NextState}) -> %logger:notice("Waiting, next state - idle"), - io:format("Worker ~p @wait is going to state idle...~n",[MyName]), DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), train}), case NextState of end_stream -> {next_state, wait, State#workerGeneric_state{nextState = end_stream}}; _ -> update_client_avilable_worker(MyName), - {next_state, wait, State#workerGeneric_state{nextState = idle}} + {next_state, idle, State#workerGeneric_state{nextState = idle}} end; wait(cast, {training}, State) -> @@ -315,12 +314,12 @@ predict(cast, {sample , SourceName , BatchID , {PredictBatchTensor, Type}}, Stat predict(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> io:format("Worker ~p got start_stream~n",[MyName]), stream_handler(start_stream, predict, SourceName, DistributedBehaviorFunc), - {next_state, train, State}; + {next_state, predict, State}; predict(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> io:format("Worker ~p got end_stream~n",[MyName]), stream_handler(end_stream, predict, SourceName, DistributedBehaviorFunc), - {next_state, train, State}; + {next_state, predict, State}; predict(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> update_client_avilable_worker(MyName), From 6f6e787e8aff324818edaa5bb7c4e6e5b9d0b9aa Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 21 May 2024 20:32:08 +0000 Subject: [PATCH 28/52] [W2W] Done --- .../dc_fed_dist_2d_3c_2s_3r_6w.json | 12 ++++++------ .../exp_fed_dist_2d_3c_2s_3r_6w.json | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json index 0af11814..0882bc08 100644 --- a/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json +++ b/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json @@ -111,7 +111,7 @@ "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", "modelArgs": "", "_doc_modelArgs": "Extra arguments to model", - "layersSizes": "5,2,2,2,3", + "layersSizes": "5,16,8,4,3", "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", "layerTypesList": "1,3,3,3,5", "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", @@ -120,7 +120,7 @@ "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", - "lossMethod": "2", + "lossMethod": "6", "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", "lr": "0.01", "_doc_lr": "Positve float", @@ -134,7 +134,7 @@ "_doc_infraType": " opennn:0 | wolfengine:1 |", "distributedSystemType": "1", "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", - "distributedSystemArgs": "SyncMaxCount=5", + "distributedSystemArgs": "SyncMaxCount=50", "_doc_distributedSystemArgs": "String", "distributedSystemToken": "9922u", "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" @@ -144,7 +144,7 @@ "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", "modelArgs": "", "_doc_modelArgs": "Extra arguments to model", - "layersSizes": "5,2,2,2,3", + "layersSizes": "5,16,8,4,3", "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", "layerTypesList": "1,3,3,3,5", "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", @@ -153,7 +153,7 @@ "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", - "lossMethod": "2", + "lossMethod": "6", "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", "lr": "0.01", "_doc_lr": "Positve float", @@ -167,7 +167,7 @@ "_doc_infraType": " opennn:0 | wolfengine:1 |", "distributedSystemType": "2", "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", - "distributedSystemArgs": "SyncMaxCount=5", + "distributedSystemArgs": "SyncMaxCount=50", "_doc_distributedSystemArgs": "String", "distributedSystemToken": "9922u", "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" diff --git a/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json index 83011c96..b96649d5 100644 --- a/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json +++ b/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json @@ -16,14 +16,14 @@ { "sourceName": "s1", "startingSample": "0", - "numOfBatches": "200", + "numOfBatches": "250", "workers": "w1,w2,w3,w4", "nerltensorType": "float" }, { "sourceName": "s2", - "startingSample": "20000", - "numOfBatches": "200", + "startingSample": "25000", + "numOfBatches": "250", "workers": "w5,w6", "nerltensorType": "float" } @@ -36,15 +36,15 @@ [ { "sourceName": "s1", - "startingSample": "40000", - "numOfBatches": "50", + "startingSample": "50000", + "numOfBatches": "500", "workers": "w1,w2,w3,w4", "nerltensorType": "float" }, { "sourceName": "s2", - "startingSample": "45000", - "numOfBatches": "50", + "startingSample": "50000", + "numOfBatches": "500", "workers": "w5,w6", "nerltensorType": "float" } From 4960724ceeb5e821a1eef7d999ce2aef782e7082 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Tue, 21 May 2024 21:23:00 +0000 Subject: [PATCH 29/52] [W2W] Fixed averaging --- .../src/Bridge/onnWorkers/w2wCom.erl | 1 - .../onnWorkers/workerFederatedClient.erl | 23 +++++---- .../onnWorkers/workerFederatedServer.erl | 29 +++++------ .../src/Bridge/onnWorkers/workerGeneric.erl | 48 +++++++++---------- .../NerlnetApp/src/Client/clientStatem.erl | 14 ++---- 5 files changed, 47 insertions(+), 68 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl index 470f7ba9..d62c3fcd 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl @@ -44,7 +44,6 @@ handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, {msg_with_event, Even _ -> throw({error, "The provided worker name is not this worker"}) end, GenWorkerPid = get(gen_worker_pid), - io:format("~p got message with event ~p from ~p~n", [ThisWorkerName, Event, FromWorkerName]), case Event of post_train_update -> gen_statem:cast(GenWorkerPid, {post_train_update, Data}); worker_done -> gen_statem:cast(GenWorkerPid, {worker_done, Data}); diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 1ea37b02..5071dbb0 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -84,7 +84,7 @@ handshake(FedClientEts) -> lists:foreach(Func, MessagesList). start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName, State] - [_SourceName, State] = WorkerData, + [SourceName, State] = WorkerData, case State of train -> ThisEts = get_this_client_ets(GenWorkerEts), @@ -93,15 +93,15 @@ start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), case length(CastingSources) of % Send to server an updater after got start_stream from the first source - 1 -> io:format("~p sending start_stream msg to ~p~n",[MyName, ServerName]), - w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, []); + 1 -> ets:update_element(ThisEts, stream_occuring, {?ETS_KEYVAL_VAL_IDX, true}), + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, SourceName); _ -> ok end; predict -> ok end. end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName] - [_SourceName, State] = WorkerData, + [SourceName, State] = WorkerData, case State of train -> ThisEts = get_this_client_ets(GenWorkerEts), @@ -110,8 +110,8 @@ end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [S W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), case length(CastingSources) of % Send to server an updater after got start_stream from the first source - 0 -> io:format("~p sending end_stream msg to ~p~n",[MyName, ServerName]), - w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, []); + 0 -> ets:update_element(ThisEts, stream_occuring, {?ETS_KEYVAL_VAL_IDX, false}), + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, SourceName); _ -> ok end; predict -> ok @@ -144,13 +144,10 @@ post_idle({GenWorkerEts, _WorkerData}) -> pre_train({GenWorkerEts, _NerlTensorWeights}) -> ThisEts = get_this_client_ets(GenWorkerEts), StreamOccuring = ets:lookup_element(ThisEts, stream_occuring, ?ETS_KEYVAL_VAL_IDX), - % io:format("StreamOccuring = ~p~n", [StreamOccuring]), case StreamOccuring of true -> ThisEts = get_this_client_ets(GenWorkerEts), SyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_count, ?ETS_KEYVAL_VAL_IDX), - WorkerName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), - ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), MaxSyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_max_count, ?ETS_KEYVAL_VAL_IDX), if SyncCount == MaxSyncCount -> W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), @@ -168,10 +165,13 @@ pre_train({GenWorkerEts, _NerlTensorWeights}) -> %% every countLimit batches, send updated weights post_train({GenWorkerEts, _WorkerData}) -> - CastingSources = ets:lookup_element(get_this_client_ets(GenWorkerEts), casting_sources, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), + CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), + % io:format("Worker ~p CastingSources ~p~n",[MyName, CastingSources]), case CastingSources of [] -> ok; _ -> + % io:format("Worker ~p Sending weights to server~n",[MyName]), ThisEts = get_this_client_ets(GenWorkerEts), SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), @@ -179,9 +179,8 @@ post_train({GenWorkerEts, _WorkerData}) -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), Weights = nerlNIF:call_to_get_weights(ModelID), ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:send_message(W2WPid, MyName, ServerName , {post_train_update, Weights}); %% ****** NEW - TEST NEEDED ****** + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , post_train_update, Weights); true -> ok end end. diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 3cf0bf6b..9b617274 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -81,7 +81,6 @@ start_stream({GenWorkerEts, _WorkerData}) -> w2wCom:sync_inbox(W2WPid), InboxQueue = w2wCom:get_all_messages(W2WPid), MessagesList = queue:to_list(InboxQueue), - io:format("@FedServer MessagesList = ~p~n",[MessagesList]), Func = fun({FromFedClient , _SourceName}) -> ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), UpdatedActiveWorkers = ActiveWorkers ++ [FromFedClient], @@ -91,10 +90,8 @@ start_stream({GenWorkerEts, _WorkerData}) -> UpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), LengthFedClients = length(ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), - io:format("@FedServer start_stream UpdatedActiveWorkers = ~p , Num Of Fed Clients = ~p~n",[UpdatedActiveWorkers, LengthFedClients]), case length(UpdatedActiveWorkers) of - LengthFedClients -> io:format("*****Federated Server Is Done With start_stream *****~n"), - ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + LengthFedClients -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client @@ -122,10 +119,8 @@ end_stream({GenWorkerEts, _WorkerData}) -> lists:foreach(Func, MessagesList), UpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), - io:format("@FedServer end_stream UpdatedActiveWorkers = ~p~n",[UpdatedActiveWorkers]), case length(UpdatedActiveWorkers) of - 0 -> io:format("*****Federated Server Is Done With end_stream *****~n"), - ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + 0 -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client @@ -180,20 +175,18 @@ pre_train({_GenWorkerEts, _WorkerData}) -> ok. % 2. average them % 3. set new weights to model % 4. send new weights to all workers -post_train({GenWorkerEts, WorkerData}) when length(WorkerData) == 0 -> % WorkerData = [] +post_train({GenWorkerEts, WeightsTensor}) -> ThisEts = get_this_server_ets(GenWorkerEts), FedServerEts = get(fed_server_ets), - W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - InboxQueue = w2wCom:get_all_messages(W2WPid), - MessagesList = queue:to_list(InboxQueue), - ReceivedWeights = [WorkersWeights || {_WorkerName, {WorkersWeights, _BinaryType}} <- MessagesList], CurrWorkersWeightsList = ets:lookup_element(FedServerEts, weights_list, ?ETS_KEYVAL_VAL_IDX), - TotalWorkersWeights = CurrWorkersWeightsList ++ ReceivedWeights, + {WorkerWeights, _BinaryType} = WeightsTensor, + TotalWorkersWeights = CurrWorkersWeightsList ++ [WorkerWeights], NumOfActiveWorkers = length(ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX)), - case length(TotalWorkersWeights) == NumOfActiveWorkers of % ? Why not timeout - true -> + % io:format("NumOfActiveWorkers = ~p~n",[NumOfActiveWorkers]), + case length(TotalWorkersWeights) of + NumOfActiveWorkers -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - io:format("Averaging model weights...~n"), + % io:format("Averaging model weights...~n"), {CurrentModelWeights, BinaryType} = nerlNIF:call_to_get_weights(ModelID), FedServerName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), AllWorkersWeightsList = TotalWorkersWeights ++ [CurrentModelWeights], @@ -204,10 +197,10 @@ post_train({GenWorkerEts, WorkerData}) when length(WorkerData) == 0 -> % WorkerD W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), w2wCom:send_message(W2WPid, FedServerName, FedClient, {update_weights, AvgWeightsNerlTensor}) end, - WorkersList = ets:lookup_element(ThisEts, training_workers, ?ETS_KEYVAL_VAL_IDX), + WorkersList = ets:lookup_element(ThisEts, active_workers, ?ETS_KEYVAL_VAL_IDX), lists:foreach(Func, WorkersList), ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, []}); - false -> ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, TotalWorkersWeights}) + _ -> ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, TotalWorkersWeights}) end. %% nothing? diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index b2048915..76b3dd5f 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -152,7 +152,6 @@ idle(cast, {training}, State = #workerGeneric_state{myName = MyName , distribute % Go from idle to predict idle(cast, {predict}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("@idle got predict , Worker ~p is going to state predict...~n",[MyName]), % worker_controller_empty_message_queue(), update_client_avilable_worker(MyName), DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), predict}), @@ -201,9 +200,8 @@ wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , S {next_state, NextState, State} end; -wait(cast, {end_stream , _Data}, State= #workerGeneric_state{myName = MyName}) -> +wait(cast, {end_stream , _Data}, State= #workerGeneric_state{myName = _MyName}) -> %logger:notice("Waiting, next state - idle"), - io:format("Worker ~p @wait got end_stream~n",[MyName]), {next_state, wait, State#workerGeneric_state{nextState = end_stream}}; wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc, nextState = NextState}) -> @@ -217,18 +215,15 @@ wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehav wait(cast, {training}, State) -> %logger:notice("Waiting, next state - train"), - io:format("@wait got training , Worker is going to state idle...~n"), % gen_statem:cast(ClientPid,{stateChange,WorkerName}), {next_state, wait, State#workerGeneric_state{nextState = train}}; wait(cast, {predict}, State) -> - io:format("@wait got predict , Worker is going to state idle...~n"), %logger:notice("Waiting, next state - predict"), {next_state, wait, State#workerGeneric_state{nextState = predict}}; %% Worker in wait can't treat incoming message -wait(cast, _BatchData , State = #workerGeneric_state{lastPhase = LastPhase, myName= MyName}) -> - io:format("@wait got something , Worker ~p is going to state idle...~n",[MyName]), +wait(cast, _BatchData , State = #workerGeneric_state{lastPhase = LastPhase, myName= _MyName}) -> case LastPhase of train -> ets:update_counter(get(worker_stats_ets), batches_dropped_train , 1); @@ -239,7 +234,6 @@ wait(cast, _BatchData , State = #workerGeneric_state{lastPhase = LastPhase, myNa wait(cast, Data, State) -> % logger:notice("worker ~p in wait cant treat message: ~p\n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), Data]), - io:format("@wait got something2 , Worker is going to state idle...~n"), worker_controller_message_queue(Data), {keep_state, State}. @@ -268,23 +262,19 @@ train(cast, {set_weights,Ret_weights_list}, State = #workerGeneric_state{modelID %logger:notice("####end set weights train####~n"), {next_state, train, State}; -train(cast, {post_train_update}, State = #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("Worker ~p got post_train_update~n",[MyName]), - DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), +train(cast, {post_train_update , Weights}, State = #workerGeneric_state{myName = _MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> + DistributedBehaviorFunc(post_train, {get(generic_worker_ets), Weights}), {next_state, train, State}; -train(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("Worker ~p got start_stream~n",[MyName]), +train(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> stream_handler(start_stream, train, SourceName, DistributedBehaviorFunc), {next_state, train, State}; -train(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("Worker ~p got end_stream~n",[MyName]), +train(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), {next_state, train, State}; train(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("@train Worker ~p is going to state idle...~n",[MyName]), update_client_avilable_worker(MyName), DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), train}), {next_state, idle, State}; @@ -311,13 +301,11 @@ predict(cast, {sample , SourceName , BatchID , {PredictBatchTensor, Type}}, Stat _Pid = spawn(fun()-> nerlNIF:call_to_predict(ModelId , {PredictBatchTensor, Type} , CurrPID , BatchID, SourceName) end), {next_state, wait, State#workerGeneric_state{nextState = predict , currentBatchID = BatchID}}; -predict(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("Worker ~p got start_stream~n",[MyName]), +predict(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> stream_handler(start_stream, predict, SourceName, DistributedBehaviorFunc), {next_state, predict, State}; -predict(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("Worker ~p got end_stream~n",[MyName]), +predict(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> stream_handler(end_stream, predict, SourceName, DistributedBehaviorFunc), {next_state, predict, State}; @@ -345,10 +333,18 @@ stream_handler(StreamPhase , ModelPhase , SourceName , DistributedBehaviorFunc) GenWorkerEts = get(generic_worker_ets), ets:update_element(GenWorkerEts, stream_occuring , {?ETS_KEYVAL_VAL_IDX, true}), CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), - NewCastingSources = - case StreamPhase of - start_stream -> CastingSources ++ [SourceName]; - end_stream -> CastingSources -- [SourceName] - end, - ets:update_element(GenWorkerEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}), + IsSource = case string:substr(atom_to_list(SourceName), 1, 1) of % Could be a FedServer sending to himself + "s" -> true; + _ -> false + end, + case IsSource of + true -> + NewCastingSources = + case StreamPhase of + start_stream -> CastingSources ++ [SourceName]; + end_stream -> CastingSources -- [SourceName] + end, + ets:update_element(GenWorkerEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}); + false -> ok + end, DistributedBehaviorFunc(StreamPhase, {GenWorkerEts, [SourceName , ModelPhase]}). \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 4019548b..20598a66 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -127,7 +127,6 @@ waitforWorkers(cast, In = {stateChange,WorkerName}, State = #client_statem_state stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), case NewWaitforWorkers of % TODO Guy here we need to check for keep alive with workers [] -> send_client_is_ready(MyName), % when all workers done their work - io:format("Client ~p is ready~n",[MyName]), stats:increment_messages_sent(ClientStatsEts), {next_state, NextState, State#client_statem_state{waitforWorkers = []}}; _ -> %io:format("Client ~p is waiting for workers ~p~n",[MyName,NewWaitforWorkers]), @@ -191,7 +190,6 @@ idle(cast, In = {training}, State = #client_statem_state{myName = _MyName, etsRe {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = clientWorkersFunctions:get_workers_names(EtsRef), nextState = training}}; idle(cast, In = {predict}, State = #client_statem_state{etsRef = EtsRef}) -> - io:format("Client sending workers to predict state...~n"), ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), @@ -258,7 +256,6 @@ training(cast, In = {sample,Body}, State = #client_statem_state{etsRef = EtsRef} true -> ?LOG_ERROR("Given worker ~p isn't found in client ~p",[WorkerName, ClientName]) end, {next_state, training, State#client_statem_state{etsRef = EtsRef}}; -% ************* NEW *************** training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), @@ -270,18 +267,16 @@ training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = gen_statem:cast(WorkerPid, {start_stream, SourceName}), {keep_state, State}; -% ************* NEW *************** training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> - {SourceName, ClientName, WorkerName} = binary_to_term(Data), + {SourceName, _ClientName, WorkerName} = binary_to_term(Data), ClientStatsEts = get(client_stats_ets), ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, SourceName}], ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), - io:format("Client ~p received end_stream to worker ~p , remaining training workers ~p~n",[ClientName, WorkerName , UpdatedListOfActiveWorkerSources]), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), - gen_statem:cast(WorkerPid, {end_stream, SourceName}), % WHY THIS IS NOT WORKING???? + gen_statem:cast(WorkerPid, {end_stream, SourceName}), case length(UpdatedListOfActiveWorkerSources) of 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); _ -> ok end, @@ -345,7 +340,6 @@ predict(cast, In = {sample,Body}, State = #client_statem_state{etsRef = EtsRef}) end, {next_state, predict, State#client_statem_state{etsRef = EtsRef}}; -% ************* NEW *************** predict(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), @@ -357,14 +351,12 @@ predict(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = gen_statem:cast(WorkerPid, {start_stream, SourceName}), {keep_state, State}; -% ************* NEW *************** predict(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> - {SourceName, ClientName, WorkerName} = binary_to_term(Data), + {SourceName, _ClientName, WorkerName} = binary_to_term(Data), ClientStatsEts = get(client_stats_ets), ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, SourceName}], ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), - io:format("Client ~p received end_stream to worker ~p , remaining training workers ~p~n",[ClientName, WorkerName , UpdatedListOfActiveWorkerSources]), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), From d5d1c9d35abb5275cf7f17e2e65dcb3fa6eb027d Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 08:53:55 +0000 Subject: [PATCH 30/52] [W2W] WIP --- .../Bridge/onnWorkers/workerFederatedClient.erl | 1 + .../src/Bridge/onnWorkers/workerGeneric.erl | 16 ++++------------ 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 5071dbb0..a2f9c97f 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -152,6 +152,7 @@ pre_train({GenWorkerEts, _NerlTensorWeights}) -> if SyncCount == MaxSyncCount -> W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), w2wCom:sync_inbox_no_limit(W2WPid), % waiting for server to average the weights and send it + io:format("@~p Updated weights received from server~n",[ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX)]), InboxQueue = w2wCom:get_all_messages(W2WPid), [UpdateWeightsMsg] = queue:to_list(InboxQueue), {_FedServer , {update_weights, UpdatedWeights}} = UpdateWeightsMsg, diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 76b3dd5f..a722cea1 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -333,18 +333,10 @@ stream_handler(StreamPhase , ModelPhase , SourceName , DistributedBehaviorFunc) GenWorkerEts = get(generic_worker_ets), ets:update_element(GenWorkerEts, stream_occuring , {?ETS_KEYVAL_VAL_IDX, true}), CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), - IsSource = case string:substr(atom_to_list(SourceName), 1, 1) of % Could be a FedServer sending to himself - "s" -> true; - _ -> false - end, - case IsSource of - true -> - NewCastingSources = - case StreamPhase of + NewCastingSources = + case StreamPhase of start_stream -> CastingSources ++ [SourceName]; end_stream -> CastingSources -- [SourceName] - end, - ets:update_element(GenWorkerEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}); - false -> ok - end, + end, + ets:update_element(GenWorkerEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}), DistributedBehaviorFunc(StreamPhase, {GenWorkerEts, [SourceName , ModelPhase]}). \ No newline at end of file From 2f508c61a3b4c5cb65fc592546a7d4fa55c18550 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 09:09:31 +0000 Subject: [PATCH 31/52] [W2W] WIP --- .../src/Bridge/onnWorkers/w2wCom.erl | 3 +- .../onnWorkers/workerFederatedClient.erl | 39 ++++++++----------- .../onnWorkers/workerFederatedServer.erl | 2 +- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl index d62c3fcd..5cea17ff 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl @@ -45,7 +45,8 @@ handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, {msg_with_event, Even end, GenWorkerPid = get(gen_worker_pid), case Event of - post_train_update -> gen_statem:cast(GenWorkerPid, {post_train_update, Data}); + post_train_update -> io:format("~p got post_train_update~n",[ThisWorkerName]), + gen_statem:cast(GenWorkerPid, {post_train_update, Data}); worker_done -> gen_statem:cast(GenWorkerPid, {worker_done, Data}); start_stream -> gen_statem:cast(GenWorkerPid, {start_stream, Data}); % Data is [SourceName] end_stream -> gen_statem:cast(GenWorkerPid, {end_stream, Data}) % Data is [SourceName] diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index a2f9c97f..1fe1461c 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -143,25 +143,19 @@ post_idle({GenWorkerEts, _WorkerData}) -> % After SyncMaxCount , sync_inbox to get the updated model from FedServer pre_train({GenWorkerEts, _NerlTensorWeights}) -> ThisEts = get_this_client_ets(GenWorkerEts), - StreamOccuring = ets:lookup_element(ThisEts, stream_occuring, ?ETS_KEYVAL_VAL_IDX), - case StreamOccuring of - true -> - ThisEts = get_this_client_ets(GenWorkerEts), - SyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_count, ?ETS_KEYVAL_VAL_IDX), - MaxSyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_max_count, ?ETS_KEYVAL_VAL_IDX), - if SyncCount == MaxSyncCount -> - W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:sync_inbox_no_limit(W2WPid), % waiting for server to average the weights and send it - io:format("@~p Updated weights received from server~n",[ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX)]), - InboxQueue = w2wCom:get_all_messages(W2WPid), - [UpdateWeightsMsg] = queue:to_list(InboxQueue), - {_FedServer , {update_weights, UpdatedWeights}} = UpdateWeightsMsg, - ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - nerlNIF:call_to_set_weights(ModelID, UpdatedWeights), - ets:update_element(ThisEts, sync_count, {?ETS_KEYVAL_VAL_IDX , 0}); - true -> ets:update_counter(ThisEts, sync_count, 1) - end; - false -> ok + SyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_count, ?ETS_KEYVAL_VAL_IDX), + MaxSyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_max_count, ?ETS_KEYVAL_VAL_IDX), + if SyncCount == MaxSyncCount -> + W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:sync_inbox_no_limit(W2WPid), % waiting for server to average the weights and send it + io:format("@~p Updated weights received from server~n",[ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX)]), + InboxQueue = w2wCom:get_all_messages(W2WPid), + [UpdateWeightsMsg] = queue:to_list(InboxQueue), + {_FedServer , {update_weights, UpdatedWeights}} = UpdateWeightsMsg, + ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), + nerlNIF:call_to_set_weights(ModelID, UpdatedWeights), + ets:update_element(ThisEts, sync_count, {?ETS_KEYVAL_VAL_IDX , 0}); + true -> ets:update_counter(ThisEts, sync_count, 1) end. %% every countLimit batches, send updated weights @@ -172,16 +166,17 @@ post_train({GenWorkerEts, _WorkerData}) -> case CastingSources of [] -> ok; _ -> - % io:format("Worker ~p Sending weights to server~n",[MyName]), ThisEts = get_this_client_ets(GenWorkerEts), SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), if SyncCount == MaxSyncCount -> + io:format("SyncCount = MaxSyncCount = ~p~n",[SyncCount]), + io:format("Worker ~p Sending weights to server~n",[MyName]), ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - Weights = nerlNIF:call_to_get_weights(ModelID), + WeightsTensor = nerlNIF:call_to_get_weights(ModelID), ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:send_message_with_event(W2WPid, MyName, ServerName , post_train_update, Weights); + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , post_train_update, WeightsTensor); true -> ok end end. diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 9b617274..803e19cd 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -182,7 +182,7 @@ post_train({GenWorkerEts, WeightsTensor}) -> {WorkerWeights, _BinaryType} = WeightsTensor, TotalWorkersWeights = CurrWorkersWeightsList ++ [WorkerWeights], NumOfActiveWorkers = length(ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX)), - % io:format("NumOfActiveWorkers = ~p~n",[NumOfActiveWorkers]), + io:format("NumOfActiveWorkers = ~p~n",[NumOfActiveWorkers]), case length(TotalWorkersWeights) of NumOfActiveWorkers -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), From d07e62cc180283ea99478ab5c2a27a4464a7bff2 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 10:54:49 +0000 Subject: [PATCH 32/52] [W2W] WIP --- .../src/Bridge/onnWorkers/w2wCom.erl | 7 +-- .../onnWorkers/workerFederatedClient.erl | 11 ++-- .../onnWorkers/workerFederatedServer.erl | 59 ++++++++----------- .../NerlnetApp/src/Client/clientStatem.erl | 4 +- 4 files changed, 36 insertions(+), 45 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl index 5cea17ff..94b5c692 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl @@ -45,11 +45,10 @@ handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, {msg_with_event, Even end, GenWorkerPid = get(gen_worker_pid), case Event of - post_train_update -> io:format("~p got post_train_update~n",[ThisWorkerName]), - gen_statem:cast(GenWorkerPid, {post_train_update, Data}); + post_train_update -> gen_statem:cast(GenWorkerPid, {post_train_update, Data}); worker_done -> gen_statem:cast(GenWorkerPid, {worker_done, Data}); - start_stream -> gen_statem:cast(GenWorkerPid, {start_stream, Data}); % Data is [SourceName] - end_stream -> gen_statem:cast(GenWorkerPid, {end_stream, Data}) % Data is [SourceName] + start_stream -> gen_statem:cast(GenWorkerPid, {start_stream, Data}); + end_stream -> gen_statem:cast(GenWorkerPid, {end_stream, Data}) end, % Saved messages are of the form: {FromWorkerName, , Data} Message = {FromWorkerName, Data}, diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 1fe1461c..639ce06d 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -84,7 +84,7 @@ handshake(FedClientEts) -> lists:foreach(Func, MessagesList). start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName, State] - [SourceName, State] = WorkerData, + [_SourceName, State] = WorkerData, case State of train -> ThisEts = get_this_client_ets(GenWorkerEts), @@ -94,14 +94,15 @@ start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), case length(CastingSources) of % Send to server an updater after got start_stream from the first source 1 -> ets:update_element(ThisEts, stream_occuring, {?ETS_KEYVAL_VAL_IDX, true}), - w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, SourceName); + io:format("FedWorker ~p sending start_stream to server ~p~n",[MyName, ServerName]), + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, MyName); % Server gets FedWorkerName instead of SourceName _ -> ok end; predict -> ok end. end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName] - [SourceName, State] = WorkerData, + [_SourceName, State] = WorkerData, case State of train -> ThisEts = get_this_client_ets(GenWorkerEts), @@ -111,10 +112,10 @@ end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [S CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), case length(CastingSources) of % Send to server an updater after got start_stream from the first source 0 -> ets:update_element(ThisEts, stream_occuring, {?ETS_KEYVAL_VAL_IDX, false}), - w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, SourceName); + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName); _ -> ok end; - predict -> ok + predict -> ok end. diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 803e19cd..716ffaf1 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -70,24 +70,17 @@ init({GenWorkerEts, WorkerData}) -> put(fed_server_ets, FederatedServerEts). -start_stream({GenWorkerEts, _WorkerData}) -> +start_stream({GenWorkerEts, WorkerData}) -> + [WorkerName , _ModelPhase] = WorkerData, FedServerEts = get_this_server_ets(GenWorkerEts), - CurrUpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + CurrActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), CurrLengthFedClients = length(ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), - case length(CurrUpdatedActiveWorkers) of + case length(CurrActiveWorkers) of CurrLengthFedClients -> ok; _Else -> - W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:sync_inbox(W2WPid), - InboxQueue = w2wCom:get_all_messages(W2WPid), - MessagesList = queue:to_list(InboxQueue), - Func = fun({FromFedClient , _SourceName}) -> - ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - UpdatedActiveWorkers = ActiveWorkers ++ [FromFedClient], - ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}) - end, - lists:foreach(Func, MessagesList), - UpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + io:format("FedServer got start_stream from ~p~n",[WorkerName]), + ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + UpdatedActiveWorkers = ActiveWorkers ++ [WorkerName], ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), LengthFedClients = length(ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), case length(UpdatedActiveWorkers) of @@ -96,36 +89,30 @@ start_stream({GenWorkerEts, _WorkerData}) -> % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client gen_server:cast(ClientPid, {start_stream, term_to_binary(Data)}); - _ -> start_stream({GenWorkerEts, []}) % If not all messages were received when inbox was synced + _ -> ok end end. -end_stream({GenWorkerEts, _WorkerData}) -> +end_stream({GenWorkerEts, WorkerData}) -> + [WorkerName , _ModelPhase] = WorkerData, FedServerEts = get_this_server_ets(GenWorkerEts), - CurrUpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - case CurrUpdatedActiveWorkers of + CurrActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + case CurrActiveWorkers of [] -> ok; % if there are no active workers, no need to do anything _Else -> - W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:sync_inbox(W2WPid), - InboxQueue = w2wCom:get_all_messages(W2WPid), - MessagesList = queue:to_list(InboxQueue), - Func = fun({FromFedClient , _SourceName}) -> - ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - UpdatedActiveWorkers = ActiveWorkers -- [FromFedClient], - ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}) - end, - lists:foreach(Func, MessagesList), - UpdatedActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + io:format("ActiveWorkers = ~p , got end stream from ~p removing it..~n",[ActiveWorkers, WorkerName]), + UpdatedActiveWorkers = ActiveWorkers -- [WorkerName], ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), case length(UpdatedActiveWorkers) of - 0 -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), - % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), - Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client - gen_server:cast(ClientPid, {end_stream, term_to_binary(Data)}); - _ -> end_stream({GenWorkerEts, []}) % If not all messages were received when inbox was synced + 0 -> io:format("GOT HEREEEE~n"), + ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), + Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client + gen_server:cast(ClientPid, {end_stream, term_to_binary(Data)}); + _ -> io:format("ActiveWorkers = ~p~n",[UpdatedActiveWorkers]) end end. @@ -164,6 +151,7 @@ post_idle({GenWorkerEts, _WorkerName}) -> w2wCom:send_message(W2WPid, FedServerName, FedClient, {handshake_done, MyToken}) end, lists:foreach(MsgFunc, MessagesList), + io:format("After handshake , FedClients = ~p~n",[ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)]), ets:update_element(GenWorkerEts, handshake_done, {?ETS_KEYVAL_VAL_IDX, true}); true -> ok end. @@ -195,6 +183,7 @@ post_train({GenWorkerEts, WeightsTensor}) -> Func = fun(FedClient) -> FedServerName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + io:format("Sending updated weights to ~p~n",[FedClient]), w2wCom:send_message(W2WPid, FedServerName, FedClient, {update_weights, AvgWeightsNerlTensor}) end, WorkersList = ets:lookup_element(ThisEts, active_workers, ?ETS_KEYVAL_VAL_IDX), diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 20598a66..5a6ed2db 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -128,6 +128,7 @@ waitforWorkers(cast, In = {stateChange,WorkerName}, State = #client_statem_state case NewWaitforWorkers of % TODO Guy here we need to check for keep alive with workers [] -> send_client_is_ready(MyName), % when all workers done their work stats:increment_messages_sent(ClientStatsEts), + io:format("Client ~p is ready~n", [MyName]), {next_state, NextState, State#client_statem_state{waitforWorkers = []}}; _ -> %io:format("Client ~p is waiting for workers ~p~n",[MyName,NewWaitforWorkers]), {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = NewWaitforWorkers}} @@ -276,7 +277,7 @@ training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = E stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), - gen_statem:cast(WorkerPid, {end_stream, SourceName}), + gen_statem:cast(WorkerPid, {end_stream, [SourceName]}), case length(UpdatedListOfActiveWorkerSources) of 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); _ -> ok end, @@ -463,6 +464,7 @@ create_encoded_stats_str(ListStatsEts) -> lists:flatten(lists:map(Func , ListStatsEts)). handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data) -> + io:format("~p sent w2w_msg to ~p~n",[FromWorker, ToWorker]), ClientStatsEts = get(client_stats_ets), WorkersOfThisClient = ets:lookup_element(EtsRef, workersNames, ?DATA_IDX), WorkerOfThisClient = lists:member(ToWorker, WorkersOfThisClient), From 32352feac47878e1c5fda69e848d53e3c45dd8bf Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 11:25:55 +0000 Subject: [PATCH 33/52] [W2W] WIP --- .../src/Bridge/onnWorkers/workerFederatedClient.erl | 4 ---- .../src/Bridge/onnWorkers/workerFederatedServer.erl | 5 +---- src_erl/NerlnetApp/src/Client/clientStatem.erl | 1 - 3 files changed, 1 insertion(+), 9 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 639ce06d..1e795ef2 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -94,7 +94,6 @@ start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), case length(CastingSources) of % Send to server an updater after got start_stream from the first source 1 -> ets:update_element(ThisEts, stream_occuring, {?ETS_KEYVAL_VAL_IDX, true}), - io:format("FedWorker ~p sending start_stream to server ~p~n",[MyName, ServerName]), w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, MyName); % Server gets FedWorkerName instead of SourceName _ -> ok end; @@ -149,7 +148,6 @@ pre_train({GenWorkerEts, _NerlTensorWeights}) -> if SyncCount == MaxSyncCount -> W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), w2wCom:sync_inbox_no_limit(W2WPid), % waiting for server to average the weights and send it - io:format("@~p Updated weights received from server~n",[ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX)]), InboxQueue = w2wCom:get_all_messages(W2WPid), [UpdateWeightsMsg] = queue:to_list(InboxQueue), {_FedServer , {update_weights, UpdatedWeights}} = UpdateWeightsMsg, @@ -171,8 +169,6 @@ post_train({GenWorkerEts, _WorkerData}) -> SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), if SyncCount == MaxSyncCount -> - io:format("SyncCount = MaxSyncCount = ~p~n",[SyncCount]), - io:format("Worker ~p Sending weights to server~n",[MyName]), ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), WeightsTensor = nerlNIF:call_to_get_weights(ModelID), ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 716ffaf1..9867cb13 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -78,7 +78,6 @@ start_stream({GenWorkerEts, WorkerData}) -> case length(CurrActiveWorkers) of CurrLengthFedClients -> ok; _Else -> - io:format("FedServer got start_stream from ~p~n",[WorkerName]), ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), UpdatedActiveWorkers = ActiveWorkers ++ [WorkerName], ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), @@ -98,6 +97,7 @@ end_stream({GenWorkerEts, WorkerData}) -> [WorkerName , _ModelPhase] = WorkerData, FedServerEts = get_this_server_ets(GenWorkerEts), CurrActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + io:format("FedServer got end_stream , CurrActiveWorkers = ~p~n",[CurrActiveWorkers]), case CurrActiveWorkers of [] -> ok; % if there are no active workers, no need to do anything _Else -> @@ -151,7 +151,6 @@ post_idle({GenWorkerEts, _WorkerName}) -> w2wCom:send_message(W2WPid, FedServerName, FedClient, {handshake_done, MyToken}) end, lists:foreach(MsgFunc, MessagesList), - io:format("After handshake , FedClients = ~p~n",[ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)]), ets:update_element(GenWorkerEts, handshake_done, {?ETS_KEYVAL_VAL_IDX, true}); true -> ok end. @@ -170,7 +169,6 @@ post_train({GenWorkerEts, WeightsTensor}) -> {WorkerWeights, _BinaryType} = WeightsTensor, TotalWorkersWeights = CurrWorkersWeightsList ++ [WorkerWeights], NumOfActiveWorkers = length(ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX)), - io:format("NumOfActiveWorkers = ~p~n",[NumOfActiveWorkers]), case length(TotalWorkersWeights) of NumOfActiveWorkers -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), @@ -183,7 +181,6 @@ post_train({GenWorkerEts, WeightsTensor}) -> Func = fun(FedClient) -> FedServerName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - io:format("Sending updated weights to ~p~n",[FedClient]), w2wCom:send_message(W2WPid, FedServerName, FedClient, {update_weights, AvgWeightsNerlTensor}) end, WorkersList = ets:lookup_element(ThisEts, active_workers, ?ETS_KEYVAL_VAL_IDX), diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 5a6ed2db..3e939f13 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -464,7 +464,6 @@ create_encoded_stats_str(ListStatsEts) -> lists:flatten(lists:map(Func , ListStatsEts)). handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data) -> - io:format("~p sent w2w_msg to ~p~n",[FromWorker, ToWorker]), ClientStatsEts = get(client_stats_ets), WorkersOfThisClient = ets:lookup_element(EtsRef, workersNames, ?DATA_IDX), WorkerOfThisClient = lists:member(ToWorker, WorkersOfThisClient), From efe20f3e25740c7968e85fd135fb992c5b4d254f Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 14:18:22 +0000 Subject: [PATCH 34/52] [W2W] WIP --- .../src/Bridge/onnWorkers/workerFederatedClient.erl | 12 ++++++++---- .../src/Bridge/onnWorkers/workerFederatedServer.erl | 6 +++--- .../src/Bridge/onnWorkers/workerGeneric.erl | 5 ++++- src_erl/NerlnetApp/src/Client/clientStatem.erl | 9 +++++---- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 1e795ef2..ba618331 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -94,7 +94,8 @@ start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), case length(CastingSources) of % Send to server an updater after got start_stream from the first source 1 -> ets:update_element(ThisEts, stream_occuring, {?ETS_KEYVAL_VAL_IDX, true}), - w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, MyName); % Server gets FedWorkerName instead of SourceName + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, MyName), % Server gets FedWorkerName instead of SourceName + io:format("~p sent start_stream to ~p~n",[MyName , ServerName]); _ -> ok end; predict -> ok @@ -111,7 +112,8 @@ end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [S CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), case length(CastingSources) of % Send to server an updater after got start_stream from the first source 0 -> ets:update_element(ThisEts, stream_occuring, {?ETS_KEYVAL_VAL_IDX, false}), - w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName); + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName), + io:format("~p sent end_stream to ~p~n",[MyName , ServerName]); _ -> ok end; predict -> ok @@ -148,6 +150,7 @@ pre_train({GenWorkerEts, _NerlTensorWeights}) -> if SyncCount == MaxSyncCount -> W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), w2wCom:sync_inbox_no_limit(W2WPid), % waiting for server to average the weights and send it + io:format("~p done syncing inbox~n",[ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX)]), InboxQueue = w2wCom:get_all_messages(W2WPid), [UpdateWeightsMsg] = queue:to_list(InboxQueue), {_FedServer , {update_weights, UpdatedWeights}} = UpdateWeightsMsg, @@ -163,7 +166,7 @@ post_train({GenWorkerEts, _WorkerData}) -> CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), % io:format("Worker ~p CastingSources ~p~n",[MyName, CastingSources]), case CastingSources of - [] -> ok; + [] -> io:format("~p done training...~n",[MyName]), ok; _ -> ThisEts = get_this_client_ets(GenWorkerEts), SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), @@ -173,7 +176,8 @@ post_train({GenWorkerEts, _WorkerData}) -> WeightsTensor = nerlNIF:call_to_get_weights(ModelID), ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:send_message_with_event(W2WPid, MyName, ServerName , post_train_update, WeightsTensor); + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , post_train_update, WeightsTensor), + io:format("~p sent post_train_update to ~p~n",[MyName , ServerName]); true -> ok end end. diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 9867cb13..2328df46 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -97,7 +97,7 @@ end_stream({GenWorkerEts, WorkerData}) -> [WorkerName , _ModelPhase] = WorkerData, FedServerEts = get_this_server_ets(GenWorkerEts), CurrActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - io:format("FedServer got end_stream , CurrActiveWorkers = ~p~n",[CurrActiveWorkers]), + io:format("FedServer got end_stream from ~p, CurrActiveWorkers = ~p~n",[WorkerName, CurrActiveWorkers]), case CurrActiveWorkers of [] -> ok; % if there are no active workers, no need to do anything _Else -> @@ -112,7 +112,7 @@ end_stream({GenWorkerEts, WorkerData}) -> % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client gen_server:cast(ClientPid, {end_stream, term_to_binary(Data)}); - _ -> io:format("ActiveWorkers = ~p~n",[UpdatedActiveWorkers]) + _ -> ok end end. @@ -172,7 +172,7 @@ post_train({GenWorkerEts, WeightsTensor}) -> case length(TotalWorkersWeights) of NumOfActiveWorkers -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - % io:format("Averaging model weights...~n"), + io:format("Averaging model weights...~n"), {CurrentModelWeights, BinaryType} = nerlNIF:call_to_get_weights(ModelID), FedServerName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), AllWorkersWeightsList = TotalWorkersWeights ++ [CurrentModelWeights], diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index a722cea1..b757c871 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -184,6 +184,7 @@ wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceNam UpdatedNextState = case NextState of end_stream -> stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), + io:format("@wait --> loss SENDING STATE CHANGE~n"), update_client_avilable_worker(MyName), idle; _ -> train @@ -210,6 +211,7 @@ wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehav case NextState of end_stream -> {next_state, wait, State#workerGeneric_state{nextState = end_stream}}; _ -> update_client_avilable_worker(MyName), + io:format("@wait --> idle SENDING STATE CHANGE~n"), {next_state, idle, State#workerGeneric_state{nextState = idle}} end; @@ -270,7 +272,8 @@ train(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = _ stream_handler(start_stream, train, SourceName, DistributedBehaviorFunc), {next_state, train, State}; -train(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> +train(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + io:format("@worker ~p got end_stream from ~p~n",[MyName, SourceName]), stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), {next_state, train, State}; diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 3e939f13..e60d9c09 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -270,6 +270,7 @@ training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + io:format("@client ~p got end stream from ~p~n",[WorkerName, SourceName]), ClientStatsEts = get(client_stats_ets), ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, SourceName}], @@ -292,10 +293,10 @@ training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef WorkersDone = ets:lookup_element(EtsRef , all_workers_done , ?DATA_IDX), % io:format("Client ~p Workers Done? ~p~n",[MyName, WorkersDone]), case WorkersDone of - true -> cast_message_to_workers(EtsRef, MessageToCast), - Workers = clientWorkersFunctions:get_workers_names(EtsRef), - ?LOG_INFO("~p sent idle to workers: ~p , waiting for confirmation...~n",[MyName, ets:lookup_element(EtsRef, workersNames, ?DATA_IDX)]), - {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef, waitforWorkers = Workers , nextState = idle}}; + true -> cast_message_to_workers(EtsRef, MessageToCast), + Workers = clientWorkersFunctions:get_workers_names(EtsRef), + ?LOG_INFO("~p sent idle to workers: ~p , waiting for confirmation...~n",[MyName, ets:lookup_element(EtsRef, workersNames, ?DATA_IDX)]), + {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef, waitforWorkers = Workers , nextState = idle}}; false -> gen_statem:cast(get(my_pid) , {idle}), {next_state, training, State#client_statem_state{etsRef = EtsRef}} end; From 7f05668f23ff7e8aa1b7854d3e1c821f5be9de7c Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 15:00:06 +0000 Subject: [PATCH 35/52] [W2W] Added worker_done --- .../onnWorkers/workerFederatedClient.erl | 4 +++- .../onnWorkers/workerFederatedServer.erl | 6 +++--- .../src/Bridge/onnWorkers/workerGeneric.erl | 6 ++---- .../NerlnetApp/src/Client/clientStatem.erl | 21 ++++++++++++------- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index ba618331..9b9ace4f 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -102,7 +102,7 @@ start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of end. end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName] - [_SourceName, State] = WorkerData, + [SourceName, State] = WorkerData, case State of train -> ThisEts = get_this_client_ets(GenWorkerEts), @@ -113,6 +113,8 @@ end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [S case length(CastingSources) of % Send to server an updater after got start_stream from the first source 0 -> ets:update_element(ThisEts, stream_occuring, {?ETS_KEYVAL_VAL_IDX, false}), w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName), + ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + gen_statem:cast(ClientPid, {worker_done, {MyName, SourceName}}), io:format("~p sent end_stream to ~p~n",[MyName , ServerName]); _ -> ok end; diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 2328df46..de8d9818 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -97,9 +97,11 @@ end_stream({GenWorkerEts, WorkerData}) -> [WorkerName , _ModelPhase] = WorkerData, FedServerEts = get_this_server_ets(GenWorkerEts), CurrActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), io:format("FedServer got end_stream from ~p, CurrActiveWorkers = ~p~n",[WorkerName, CurrActiveWorkers]), case CurrActiveWorkers of - [] -> ok; % if there are no active workers, no need to do anything + [] -> gen_statem:cast(ClientPid, {worker_done, {MyName, MyName}}); _Else -> ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), io:format("ActiveWorkers = ~p , got end stream from ~p removing it..~n",[ActiveWorkers, WorkerName]), @@ -107,8 +109,6 @@ end_stream({GenWorkerEts, WorkerData}) -> ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), case length(UpdatedActiveWorkers) of 0 -> io:format("GOT HEREEEE~n"), - ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client gen_server:cast(ClientPid, {end_stream, term_to_binary(Data)}); diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index b757c871..11a2c9f3 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -183,10 +183,8 @@ wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceNam DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients UpdatedNextState = case NextState of - end_stream -> stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), - io:format("@wait --> loss SENDING STATE CHANGE~n"), - update_client_avilable_worker(MyName), - idle; + end_stream -> stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), + train; _ -> train end, {next_state, UpdatedNextState, State}; diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index e60d9c09..0cecf4d3 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -121,7 +121,7 @@ format_status(_Opt, [_PDict, _StateName, _State]) -> Status = some_term, Status. %% ==============STATES================= waitforWorkers(cast, In = {stateChange,WorkerName}, State = #client_statem_state{myName = MyName,waitforWorkers = WaitforWorkers,nextState = NextState, etsRef = _EtsRef}) -> - NewWaitforWorkers = WaitforWorkers--[WorkerName], + NewWaitforWorkers = WaitforWorkers -- [WorkerName], ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), @@ -268,22 +268,27 @@ training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = gen_statem:cast(WorkerPid, {start_stream, SourceName}), {keep_state, State}; -training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> +training(cast, _In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), io:format("@client ~p got end stream from ~p~n",[WorkerName, SourceName]), ClientStatsEts = get(client_stats_ets), - ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), - UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, SourceName}], - ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), gen_statem:cast(WorkerPid, {end_stream, [SourceName]}), - case length(UpdatedListOfActiveWorkerSources) of - 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); - _ -> ok end, {keep_state, State}; +training(cast, _In = {worker_done, Data}, State = #client_statem_state{etsRef = EtsRef}) -> + {WorkerName, SourceName} = binary_to_term(Data), + io:format("Client got worker_done from ~p~n",[WorkerName]), + ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), + UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, SourceName}], + ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), + case length(UpdatedListOfActiveWorkerSources) of + 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); + _ -> ok + end, + {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef}}; training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), From 41cf5362e1a89e7dd22627d4be0c587ce91dbaa1 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 15:01:07 +0000 Subject: [PATCH 36/52] [W2W] Fixed var --- src_erl/NerlnetApp/src/Client/clientStatem.erl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 0cecf4d3..63183a43 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -268,7 +268,7 @@ training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = gen_statem:cast(WorkerPid, {start_stream, SourceName}), {keep_state, State}; -training(cast, _In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> +training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), io:format("@client ~p got end stream from ~p~n",[WorkerName, SourceName]), ClientStatsEts = get(client_stats_ets), From da5b977f630b20018b516f44c389c9b96a922c49 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 15:04:05 +0000 Subject: [PATCH 37/52] [W2W] Fixed var --- src_erl/NerlnetApp/src/Client/clientStatem.erl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 63183a43..50b32ffc 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -279,7 +279,7 @@ training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = E {keep_state, State}; training(cast, _In = {worker_done, Data}, State = #client_statem_state{etsRef = EtsRef}) -> - {WorkerName, SourceName} = binary_to_term(Data), + {WorkerName, SourceName} = Data, io:format("Client got worker_done from ~p~n",[WorkerName]), ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, SourceName}], From 064970bba608e3a5438c1b700da5103154d4baf9 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 15:11:06 +0000 Subject: [PATCH 38/52] [W2W] Test --- src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 11a2c9f3..8efc03d9 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -181,6 +181,7 @@ wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceNam BatchTimeStamp = erlang:system_time(nanosecond), gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , BatchID , BatchTimeStamp}), DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients + io:format("GOT HEREEEEEEEEEEEEEEEEEEEE~n"), UpdatedNextState = case NextState of end_stream -> stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), @@ -206,6 +207,7 @@ wait(cast, {end_stream , _Data}, State= #workerGeneric_state{myName = _MyName}) wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc, nextState = NextState}) -> %logger:notice("Waiting, next state - idle"), DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), train}), + io:format("SHOULDNT BE HERE~n"), case NextState of end_stream -> {next_state, wait, State#workerGeneric_state{nextState = end_stream}}; _ -> update_client_avilable_worker(MyName), From eef1d02ffb7a30a4438e578a46dc621c84659060 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 15:15:51 +0000 Subject: [PATCH 39/52] [W2W] Test --- .../NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl | 2 +- src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 9b9ace4f..046bb6b4 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -115,7 +115,7 @@ end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [S w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName), ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), gen_statem:cast(ClientPid, {worker_done, {MyName, SourceName}}), - io:format("~p sent end_stream to ~p~n",[MyName , ServerName]); + io:format("~p sent end_stream to ~p and client~n",[MyName , ServerName]); _ -> ok end; predict -> ok diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 8efc03d9..bd9a5a73 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -209,7 +209,8 @@ wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehav DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), train}), io:format("SHOULDNT BE HERE~n"), case NextState of - end_stream -> {next_state, wait, State#workerGeneric_state{nextState = end_stream}}; + end_stream -> io:format("@wait --> idle , NextState = end_stream~n"), + {next_state, wait, State#workerGeneric_state{nextState = end_stream}}; _ -> update_client_avilable_worker(MyName), io:format("@wait --> idle SENDING STATE CHANGE~n"), {next_state, idle, State#workerGeneric_state{nextState = idle}} From 4aec2d8f8ba735734586105a42815693658b6925 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 15:19:00 +0000 Subject: [PATCH 40/52] [W2W] Test --- src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index bd9a5a73..a51ac043 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -273,7 +273,7 @@ train(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = _ stream_handler(start_stream, train, SourceName, DistributedBehaviorFunc), {next_state, train, State}; -train(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> +train(cast, {end_stream , [SourceName]}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> io:format("@worker ~p got end_stream from ~p~n",[MyName, SourceName]), stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), {next_state, train, State}; From b3845eceff59ce272593a82717bebbf77d29bbc1 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 15:38:13 +0000 Subject: [PATCH 41/52] [W2W] WIP --- src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl | 2 +- src_erl/NerlnetApp/src/Client/clientStatem.erl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index a51ac043..c37741f6 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -181,7 +181,6 @@ wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceNam BatchTimeStamp = erlang:system_time(nanosecond), gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , BatchID , BatchTimeStamp}), DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients - io:format("GOT HEREEEEEEEEEEEEEEEEEEEE~n"), UpdatedNextState = case NextState of end_stream -> stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), @@ -202,6 +201,7 @@ wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , S wait(cast, {end_stream , _Data}, State= #workerGeneric_state{myName = _MyName}) -> %logger:notice("Waiting, next state - idle"), + % ******** WERE MISSING THE SOURCE NAME HERE ******** {next_state, wait, State#workerGeneric_state{nextState = end_stream}}; wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc, nextState = NextState}) -> diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 50b32ffc..fe263255 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -270,7 +270,7 @@ training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), - io:format("@client ~p got end stream from ~p~n",[WorkerName, SourceName]), + io:format("@client: Worker ~p got end_stream from ~p~n",[WorkerName, SourceName]), ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), From 70a357377f1a3f336323052ecca7fc34c17fda6b Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 20:43:20 +0000 Subject: [PATCH 42/52] [W2W] WIP --- .../onnWorkers/workerFederatedClient.erl | 30 ++++---- .../onnWorkers/workerFederatedServer.erl | 71 +++++++++---------- .../src/Bridge/onnWorkers/workerGeneric.erl | 59 ++++++--------- .../NerlnetApp/src/Client/clientStatem.erl | 24 ++++--- 4 files changed, 83 insertions(+), 101 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 046bb6b4..ec99896f 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -60,7 +60,7 @@ init({GenWorkerEts, WorkerData}) -> ets:insert(FederatedClientEts, {handshake_done, false}), ets:insert(FederatedClientEts, {handshake_wait, false}), ets:insert(FederatedClientEts, {w2wcom_pid, W2WPid}), - ets:insert(FederatedClientEts, {casting_sources, []}), + ets:insert(FederatedClientEts, {active_streams, []}), ets:insert(FederatedClientEts, {stream_occuring, false}), spawn(fun() -> handshake(FederatedClientEts) end). @@ -91,31 +91,27 @@ start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), - case length(CastingSources) of % Send to server an updater after got start_stream from the first source - 1 -> ets:update_element(ThisEts, stream_occuring, {?ETS_KEYVAL_VAL_IDX, true}), - w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, MyName), % Server gets FedWorkerName instead of SourceName - io:format("~p sent start_stream to ~p~n",[MyName , ServerName]); + ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + case length(ActiveStreams) of % Send to server an updater after got start_stream from the first source + 1 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, MyName), % Server gets FedWorkerName instead of SourceName + io:format("~p sent START_stream to ~p~n",[MyName , ServerName]); _ -> ok end; predict -> ok end. end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName] - [SourceName, State] = WorkerData, + [_SourceName, State] = WorkerData, case State of train -> ThisEts = get_this_client_ets(GenWorkerEts), MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), - case length(CastingSources) of % Send to server an updater after got start_stream from the first source - 0 -> ets:update_element(ThisEts, stream_occuring, {?ETS_KEYVAL_VAL_IDX, false}), - w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName), - ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - gen_statem:cast(ClientPid, {worker_done, {MyName, SourceName}}), - io:format("~p sent end_stream to ~p and client~n",[MyName , ServerName]); + ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + case length(ActiveStreams) of % Send to server an updater after got start_stream from the first source + 0 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName), + io:format("~p sent END_stream to ~p~n",[MyName , ServerName]); _ -> ok end; predict -> ok @@ -165,9 +161,9 @@ pre_train({GenWorkerEts, _NerlTensorWeights}) -> %% every countLimit batches, send updated weights post_train({GenWorkerEts, _WorkerData}) -> MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), - CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), - % io:format("Worker ~p CastingSources ~p~n",[MyName, CastingSources]), - case CastingSources of + ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + % io:format("Worker ~p ActiveStreams ~p~n",[MyName, ActiveStreams]), + case ActiveStreams of [] -> io:format("~p done training...~n",[MyName]), ok; _ -> ThisEts = get_this_client_ets(GenWorkerEts), diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index de8d9818..61d91f3e 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -70,51 +70,48 @@ init({GenWorkerEts, WorkerData}) -> put(fed_server_ets, FederatedServerEts). -start_stream({GenWorkerEts, WorkerData}) -> - [WorkerName , _ModelPhase] = WorkerData, +start_stream({GenWorkerEts, _WorkerData}) -> + % [_FedWorkerName , _ModelPhase] = WorkerData, FedServerEts = get_this_server_ets(GenWorkerEts), - CurrActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - CurrLengthFedClients = length(ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), + CurrActiveWorkers = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + NumOfFedClients = length(ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), case length(CurrActiveWorkers) of - CurrLengthFedClients -> ok; + NumOfFedClients -> ok; _Else -> - ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - UpdatedActiveWorkers = ActiveWorkers ++ [WorkerName], - ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), - LengthFedClients = length(ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), - case length(UpdatedActiveWorkers) of - LengthFedClients -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), - % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), - Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client - gen_server:cast(ClientPid, {start_stream, term_to_binary(Data)}); + case length(CurrActiveWorkers) of + 1 -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client + gen_server:cast(ClientPid, {start_stream, term_to_binary(Data)}); _ -> ok end end. -end_stream({GenWorkerEts, WorkerData}) -> - [WorkerName , _ModelPhase] = WorkerData, - FedServerEts = get_this_server_ets(GenWorkerEts), - CurrActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), - io:format("FedServer got end_stream from ~p, CurrActiveWorkers = ~p~n",[WorkerName, CurrActiveWorkers]), - case CurrActiveWorkers of - [] -> gen_statem:cast(ClientPid, {worker_done, {MyName, MyName}}); - _Else -> - ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), - io:format("ActiveWorkers = ~p , got end stream from ~p removing it..~n",[ActiveWorkers, WorkerName]), - UpdatedActiveWorkers = ActiveWorkers -- [WorkerName], - ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), - case length(UpdatedActiveWorkers) of - 0 -> io:format("GOT HEREEEE~n"), - % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), - Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client - gen_server:cast(ClientPid, {end_stream, term_to_binary(Data)}); - _ -> ok - end - end. +end_stream({_GenWorkerEts, _WorkerData}) -> ok. % All happens in GenWorker stream_handler + +% end_stream({GenWorkerEts, WorkerData}) -> +% [WorkerName , _ModelPhase] = WorkerData, +% FedServerEts = get_this_server_ets(GenWorkerEts), +% CurrActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), +% ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), +% MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), +% io:format("FedServer got end_stream from ~p, CurrActiveWorkers = ~p~n",[WorkerName, CurrActiveWorkers]), +% case CurrActiveWorkers of +% [] -> gen_statem:cast(ClientPid, {worker_done, {MyName, MyName}}); +% _Else -> +% ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), +% io:format("ActiveWorkers = ~p , got end stream from ~p removing it..~n",[ActiveWorkers, WorkerName]), +% UpdatedActiveWorkers = ActiveWorkers -- [WorkerName], +% ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), +% case length(UpdatedActiveWorkers) of +% 0 -> io:format("GOT HEREEEE~n"), +% % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), +% Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client +% gen_server:cast(ClientPid, {end_stream, term_to_binary(Data)}); +% _ -> ok +% end +% end. pre_idle({_GenWorkerEts, _WorkerName}) -> ok. diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index c37741f6..32385399 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -74,7 +74,7 @@ init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData ets:insert(GenWorkerEts,{distributed_system_type, DistributedSystemType}), ets:insert(GenWorkerEts,{controller_message_q, []}), %% TODO Deprecated ets:insert(GenWorkerEts,{handshake_done, false}), - ets:insert(GenWorkerEts,{casting_sources, []}), + ets:insert(GenWorkerEts,{active_streams, []}), ets:insert(GenWorkerEts,{stream_occuring, false}), % Worker to Worker communication module - this is a gen_server @@ -167,27 +167,14 @@ wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneri stats:increment_by_value(get(worker_stats_ets), nan_loss_count, 1), gen_statem:cast(get(client_pid),{loss, MyName , SourceName ,nan , TrainTime ,BatchID}), DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients - UpdatedNextState = - case NextState of - end_stream -> stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), - update_client_avilable_worker(MyName), - idle; - _ -> train - end, - {next_state, UpdatedNextState, State}; + {next_state, NextState, State}; wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc}) -> BatchTimeStamp = erlang:system_time(nanosecond), gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , BatchID , BatchTimeStamp}), DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients - UpdatedNextState = - case NextState of - end_stream -> stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), - train; - _ -> train - end, - {next_state, UpdatedNextState, State}; + {next_state, NextState, State}; wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> BatchTimeStamp = erlang:system_time(nanosecond), @@ -199,22 +186,18 @@ wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , S {next_state, NextState, State} end; -wait(cast, {end_stream , _Data}, State= #workerGeneric_state{myName = _MyName}) -> +wait(cast, {end_stream , Data}, State= #workerGeneric_state{myName = _MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> %logger:notice("Waiting, next state - idle"), - % ******** WERE MISSING THE SOURCE NAME HERE ******** + stream_handler(end_stream, wait, Data, DistributedBehaviorFunc), {next_state, wait, State#workerGeneric_state{nextState = end_stream}}; -wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc, nextState = NextState}) -> + +% CANNOT HAPPEN +wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> %logger:notice("Waiting, next state - idle"), DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), train}), - io:format("SHOULDNT BE HERE~n"), - case NextState of - end_stream -> io:format("@wait --> idle , NextState = end_stream~n"), - {next_state, wait, State#workerGeneric_state{nextState = end_stream}}; - _ -> update_client_avilable_worker(MyName), - io:format("@wait --> idle SENDING STATE CHANGE~n"), - {next_state, idle, State#workerGeneric_state{nextState = idle}} - end; + update_client_avilable_worker(MyName), + {next_state, idle, State#workerGeneric_state{nextState = idle}}; wait(cast, {training}, State) -> %logger:notice("Waiting, next state - train"), @@ -273,8 +256,7 @@ train(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = _ stream_handler(start_stream, train, SourceName, DistributedBehaviorFunc), {next_state, train, State}; -train(cast, {end_stream , [SourceName]}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("@worker ~p got end_stream from ~p~n",[MyName, SourceName]), +train(cast, {end_stream , [SourceName]}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), {next_state, train, State}; @@ -335,12 +317,17 @@ worker_controller_empty_message_queue() -> stream_handler(StreamPhase , ModelPhase , SourceName , DistributedBehaviorFunc) -> GenWorkerEts = get(generic_worker_ets), - ets:update_element(GenWorkerEts, stream_occuring , {?ETS_KEYVAL_VAL_IDX, true}), - CastingSources = ets:lookup_element(GenWorkerEts, casting_sources, ?ETS_KEYVAL_VAL_IDX), - NewCastingSources = + MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), + ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + NewActiveStreams = case StreamPhase of - start_stream -> CastingSources ++ [SourceName]; - end_stream -> CastingSources -- [SourceName] + start_stream -> ActiveStreams ++ [SourceName]; + end_stream -> ActiveStreams -- [SourceName] end, - ets:update_element(GenWorkerEts, casting_sources, {?ETS_KEYVAL_VAL_IDX, NewCastingSources}), - DistributedBehaviorFunc(StreamPhase, {GenWorkerEts, [SourceName , ModelPhase]}). \ No newline at end of file + ets:update_element(GenWorkerEts, active_streams, {?ETS_KEYVAL_VAL_IDX, NewActiveStreams}), + DistributedBehaviorFunc(StreamPhase, {GenWorkerEts, [SourceName , ModelPhase]}), + case length(NewActiveStreams) of % Send to client an update after done with training phase + 0 -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + gen_statem:cast(ClientPid, {worker_done, {MyName, SourceName}}); + _ -> ok + end. \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index fe263255..9276a046 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -259,13 +259,14 @@ training(cast, In = {sample,Body}, State = #client_statem_state{etsRef = EtsRef} training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), - ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), - ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, ListOfActiveWorkersSources ++ [{WorkerName, SourceName}]}), - ClientStatsEts = get(client_stats_ets), - stats:increment_messages_received(ClientStatsEts), - stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), - gen_statem:cast(WorkerPid, {start_stream, SourceName}), + case SourceName of % Only Federated Servers send start_stream with FedServerName == SourceName + WorkerName -> ok; + _ -> ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), + gen_statem:cast(WorkerPid, {start_stream, SourceName}) + end, {keep_state, State}; training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> @@ -285,10 +286,11 @@ training(cast, _In = {worker_done, Data}, State = #client_statem_state{etsRef = UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, SourceName}], ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), case length(UpdatedListOfActiveWorkerSources) of - 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); - _ -> ok - end, - {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef}}; + 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}), + {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef}}; + _ -> {next_state, training, State#client_statem_state{etsRef = EtsRef}} + end; + training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), From c06cada2d29d52437223ecb24e398f22f31e638f Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 21:44:41 +0000 Subject: [PATCH 43/52] [W2W] WIP --- .../Bridge/onnWorkers/workerDefinitions.hrl | 3 +- .../onnWorkers/workerFederatedClient.erl | 15 +++--- .../onnWorkers/workerFederatedServer.erl | 25 +++------ .../src/Bridge/onnWorkers/workerGeneric.erl | 28 +++++----- .../NerlnetApp/src/Client/clientStatem.erl | 54 +++++++++++-------- 5 files changed, 63 insertions(+), 62 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerDefinitions.hrl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerDefinitions.hrl index 588ef3f3..966d54b3 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerDefinitions.hrl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerDefinitions.hrl @@ -2,4 +2,5 @@ -define(ETS_KEYVAL_VAL_IDX, 2). -define(TENSOR_DATA_IDX, 1). --record(workerGeneric_state, {myName , modelID , distributedBehaviorFunc , distributedWorkerData , currentBatchID , nextState , lastPhase}). +-record(workerGeneric_state, {myName , modelID , distributedBehaviorFunc , distributedWorkerData , currentBatchID , nextState , lastPhase, postBatchFunc}). +-define(EMPTY_FUNC, fun() -> ok end). \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index ec99896f..07d85805 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -84,8 +84,9 @@ handshake(FedClientEts) -> lists:foreach(Func, MessagesList). start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName, State] - [_SourceName, State] = WorkerData, - case State of + [_SourceName, ModelPhase] = WorkerData, + FirstMsg = 1, + case ModelPhase of train -> ThisEts = get_this_client_ets(GenWorkerEts), MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), @@ -93,16 +94,16 @@ start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), case length(ActiveStreams) of % Send to server an updater after got start_stream from the first source - 1 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, MyName), % Server gets FedWorkerName instead of SourceName - io:format("~p sent START_stream to ~p~n",[MyName , ServerName]); + FirstMsg -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, MyName), % Server gets FedWorkerName instead of SourceName + io:format("~p sent START_stream to ~p~n",[MyName , ServerName]); _ -> ok end; predict -> ok end. end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName] - [_SourceName, State] = WorkerData, - case State of + [_SourceName, ModelPhase] = WorkerData, + case ModelPhase of train -> ThisEts = get_this_client_ets(GenWorkerEts), MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), @@ -110,7 +111,7 @@ end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [S W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), case length(ActiveStreams) of % Send to server an updater after got start_stream from the first source - 0 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName), + 0 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName), % Mimic source behavior io:format("~p sent END_stream to ~p~n",[MyName , ServerName]); _ -> ok end; diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 61d91f3e..fdb22a04 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -70,23 +70,12 @@ init({GenWorkerEts, WorkerData}) -> put(fed_server_ets, FederatedServerEts). -start_stream({GenWorkerEts, _WorkerData}) -> - % [_FedWorkerName , _ModelPhase] = WorkerData, +start_stream({GenWorkerEts, WorkerData}) -> + [FedWorkerName , _ModelPhase] = WorkerData, FedServerEts = get_this_server_ets(GenWorkerEts), - CurrActiveWorkers = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), - NumOfFedClients = length(ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX)), - case length(CurrActiveWorkers) of - NumOfFedClients -> ok; - _Else -> - case length(CurrActiveWorkers) of - 1 -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), - Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client - gen_server:cast(ClientPid, {start_stream, term_to_binary(Data)}); - _ -> ok - end - end. - + ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + gen_server:cast(ClientPid, {start_stream, {worker, MyName, FedWorkerName}}). end_stream({_GenWorkerEts, _WorkerData}) -> ok. % All happens in GenWorker stream_handler @@ -165,7 +154,7 @@ post_train({GenWorkerEts, WeightsTensor}) -> CurrWorkersWeightsList = ets:lookup_element(FedServerEts, weights_list, ?ETS_KEYVAL_VAL_IDX), {WorkerWeights, _BinaryType} = WeightsTensor, TotalWorkersWeights = CurrWorkersWeightsList ++ [WorkerWeights], - NumOfActiveWorkers = length(ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX)), + NumOfActiveWorkers = length(ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX)), case length(TotalWorkersWeights) of NumOfActiveWorkers -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), @@ -180,7 +169,7 @@ post_train({GenWorkerEts, WeightsTensor}) -> W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), w2wCom:send_message(W2WPid, FedServerName, FedClient, {update_weights, AvgWeightsNerlTensor}) end, - WorkersList = ets:lookup_element(ThisEts, active_workers, ?ETS_KEYVAL_VAL_IDX), + WorkersList = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), lists:foreach(Func, WorkersList), ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, []}); _ -> ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, TotalWorkersWeights}) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 32385399..e5bff1e1 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -90,7 +90,7 @@ init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData exit(nif_failed_to_create) end, DistributedBehaviorFunc(pre_idle,{GenWorkerEts, DistributedWorkerData}), - {ok, idle, #workerGeneric_state{myName = WorkerName , modelID = ModelID , distributedBehaviorFunc = DistributedBehaviorFunc , distributedWorkerData = DistributedWorkerData}}. + {ok, idle, #workerGeneric_state{myName = WorkerName , modelID = ModelID , distributedBehaviorFunc = DistributedBehaviorFunc , distributedWorkerData = DistributedWorkerData, postBatchFunc = ?EMPTY_FUNC}}. %% @private %% @doc This function is called by a gen_statem when it needs to find out @@ -163,18 +163,20 @@ idle(cast, _Param, State = #workerGeneric_state{myName = _MyName}) -> %% Waiting for receiving results or loss function %% Got nan or inf from loss function - Error, loss function too big for double -wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc}) -> +wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc, postBatchFunc = PostBatchFunc}) -> stats:increment_by_value(get(worker_stats_ets), nan_loss_count, 1), gen_statem:cast(get(client_pid),{loss, MyName , SourceName ,nan , TrainTime ,BatchID}), DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients - {next_state, NextState, State}; + PostBatchFunc(), + {next_state, NextState, State = #workerGeneric_state{postBatchFunc = ?EMPTY_FUNC}}; -wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc}) -> +wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc, postBatchFunc = PostBatchFunc}) -> BatchTimeStamp = erlang:system_time(nanosecond), gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , BatchID , BatchTimeStamp}), DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients - {next_state, NextState, State}; + PostBatchFunc(), + {next_state, NextState, State = #workerGeneric_state{postBatchFunc = ?EMPTY_FUNC}}; wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> BatchTimeStamp = erlang:system_time(nanosecond), @@ -186,10 +188,10 @@ wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , S {next_state, NextState, State} end; -wait(cast, {end_stream , Data}, State= #workerGeneric_state{myName = _MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> +wait(cast, {end_stream , Data}, State = #workerGeneric_state{myName = _MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> %logger:notice("Waiting, next state - idle"), - stream_handler(end_stream, wait, Data, DistributedBehaviorFunc), - {next_state, wait, State#workerGeneric_state{nextState = end_stream}}; + Func = fun() -> stream_handler(end_stream, wait, Data, DistributedBehaviorFunc) end, + {next_state, wait, State = #workerGeneric_state{postBatchFunc = Func}}; % CANNOT HAPPEN @@ -315,19 +317,19 @@ worker_controller_message_queue(ReceiveData) -> worker_controller_empty_message_queue() -> ets:update_element(get(generic_worker_ets), controller_message_q, {?ETS_KEYVAL_VAL_IDX , []}). -stream_handler(StreamPhase , ModelPhase , SourceName , DistributedBehaviorFunc) -> +stream_handler(StreamPhase , ModelPhase , StreamName , DistributedBehaviorFunc) -> GenWorkerEts = get(generic_worker_ets), MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), NewActiveStreams = case StreamPhase of - start_stream -> ActiveStreams ++ [SourceName]; - end_stream -> ActiveStreams -- [SourceName] + start_stream -> ActiveStreams ++ [StreamName]; + end_stream -> ActiveStreams -- [StreamName] end, ets:update_element(GenWorkerEts, active_streams, {?ETS_KEYVAL_VAL_IDX, NewActiveStreams}), - DistributedBehaviorFunc(StreamPhase, {GenWorkerEts, [SourceName , ModelPhase]}), + DistributedBehaviorFunc(StreamPhase, {GenWorkerEts, [StreamName , ModelPhase]}), case length(NewActiveStreams) of % Send to client an update after done with training phase 0 -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - gen_statem:cast(ClientPid, {worker_done, {MyName, SourceName}}); + gen_statem:cast(ClientPid, {worker_done, {MyName, StreamName}}); _ -> ok end. \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 9276a046..ae579cd6 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -97,7 +97,7 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh NumOfFedServers = ets:lookup_element(EtsRef, num_of_fed_servers, ?DATA_IDX), % When non-federated exp this value is 0 ets:insert(EtsRef, {num_of_training_workers, length(ClientWorkers) - NumOfFedServers}), % This number will not change ets:insert(EtsRef, {training_workers, 0}), % will be updated in idle -> training & end_stream - ets:insert(EtsRef, {active_workers_sources_list, []}), + ets:insert(EtsRef, {active_workers_streams, []}), % update dictionary WorkersEts = ets:lookup_element(EtsRef , workers_ets , ?DATA_IDX), put(workers_ets, WorkersEts), @@ -150,6 +150,7 @@ waitforWorkers(cast, In = {NewState}, State = #client_statem_state{myName = _MyN cast_message_to_workers(EtsRef, {NewState}), %% This function increments the number of sent messages in stats ets {next_state, waitforWorkers, State#client_statem_state{nextState = NewState, waitforWorkers = Workers}}; + waitforWorkers(cast, EventContent, State = #client_statem_state{myName = MyName}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -257,16 +258,22 @@ training(cast, In = {sample,Body}, State = #client_statem_state{etsRef = EtsRef} true -> ?LOG_ERROR("Given worker ~p isn't found in client ~p",[WorkerName, ClientName]) end, {next_state, training, State#client_statem_state{etsRef = EtsRef}}; +% This action is used for start_stream triggered from a clients' worker and not source +training(cast, {start_stream , {worker, WorkerName, TargetName}}, State = #client_statem_state{etsRef = EtsRef}) -> + ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, ListOfActiveWorkersSources ++ [{WorkerName, TargetName}]}), + {keep_state, State}; + +% This action is used for start_stream triggered from a source per worker training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), - case SourceName of % Only Federated Servers send start_stream with FedServerName == SourceName - WorkerName -> ok; - _ -> ClientStatsEts = get(client_stats_ets), - stats:increment_messages_received(ClientStatsEts), - stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), - gen_statem:cast(WorkerPid, {start_stream, SourceName}) - end, + ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, ListOfActiveWorkersSources ++ [{WorkerName, SourceName}]}), + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), + gen_statem:cast(WorkerPid, {start_stream, SourceName}), {keep_state, State}; training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> @@ -280,18 +287,19 @@ training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = E {keep_state, State}; training(cast, _In = {worker_done, Data}, State = #client_statem_state{etsRef = EtsRef}) -> - {WorkerName, SourceName} = Data, + {WorkerName, StreamName} = Data, io:format("Client got worker_done from ~p~n",[WorkerName]), - ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), - UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, SourceName}], - ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), + ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), + UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, StreamName}], + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), case length(UpdatedListOfActiveWorkerSources) of - 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}), - {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef}}; - _ -> {next_state, training, State#client_statem_state{etsRef = EtsRef}} - end; - + 0 -> io:format("All workers sent worker_done~n"), + ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); + _ -> ok + end, + {next_state, training, State#client_statem_state{etsRef = EtsRef}}; +% From MainServer training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -304,7 +312,7 @@ training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef Workers = clientWorkersFunctions:get_workers_names(EtsRef), ?LOG_INFO("~p sent idle to workers: ~p , waiting for confirmation...~n",[MyName, ets:lookup_element(EtsRef, workersNames, ?DATA_IDX)]), {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef, waitforWorkers = Workers , nextState = idle}}; - false -> gen_statem:cast(get(my_pid) , {idle}), + false -> gen_statem:cast(get(my_pid) , {idle}), % Trigger this action until all workers are done {next_state, training, State#client_statem_state{etsRef = EtsRef}} end; @@ -351,8 +359,8 @@ predict(cast, In = {sample,Body}, State = #client_statem_state{etsRef = EtsRef}) predict(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), - ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), - ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, ListOfActiveWorkersSources ++ [{WorkerName, SourceName}]}), + ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, ListOfActiveWorkersSources ++ [{WorkerName, SourceName}]}), ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), @@ -363,9 +371,9 @@ predict(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = predict(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), ClientStatsEts = get(client_stats_ets), - ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_sources_list, ?DATA_IDX), + ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, SourceName}], - ets:update_element(EtsRef, active_workers_sources_list, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), From 55070ab621c6b351efb89097e1dfd583b8e1e538 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 21:49:47 +0000 Subject: [PATCH 44/52] [W2W] WIP --- src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index e5bff1e1..8956fca3 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -168,7 +168,7 @@ wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneri gen_statem:cast(get(client_pid),{loss, MyName , SourceName ,nan , TrainTime ,BatchID}), DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients PostBatchFunc(), - {next_state, NextState, State = #workerGeneric_state{postBatchFunc = ?EMPTY_FUNC}}; + {next_state, NextState, State#workerGeneric_state{postBatchFunc = ?EMPTY_FUNC}}; wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc, postBatchFunc = PostBatchFunc}) -> @@ -176,7 +176,7 @@ wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceNam gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , BatchID , BatchTimeStamp}), DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients PostBatchFunc(), - {next_state, NextState, State = #workerGeneric_state{postBatchFunc = ?EMPTY_FUNC}}; + {next_state, NextState, State#workerGeneric_state{postBatchFunc = ?EMPTY_FUNC}}; wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> BatchTimeStamp = erlang:system_time(nanosecond), @@ -191,7 +191,7 @@ wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , S wait(cast, {end_stream , Data}, State = #workerGeneric_state{myName = _MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> %logger:notice("Waiting, next state - idle"), Func = fun() -> stream_handler(end_stream, wait, Data, DistributedBehaviorFunc) end, - {next_state, wait, State = #workerGeneric_state{postBatchFunc = Func}}; + {next_state, wait, State#workerGeneric_state{postBatchFunc = Func}}; % CANNOT HAPPEN From 269225f4c98c6e5ce30cd1ea68cc7159d6c9d703 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 21:57:50 +0000 Subject: [PATCH 45/52] [W2W] WIP --- .../NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 8956fca3..293241fa 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -254,12 +254,12 @@ train(cast, {post_train_update , Weights}, State = #workerGeneric_state{myName = DistributedBehaviorFunc(post_train, {get(generic_worker_ets), Weights}), {next_state, train, State}; -train(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - stream_handler(start_stream, train, SourceName, DistributedBehaviorFunc), +train(cast, {start_stream , StreamName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + stream_handler(start_stream, train, StreamName, DistributedBehaviorFunc), {next_state, train, State}; -train(cast, {end_stream , [SourceName]}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - stream_handler(end_stream, train, SourceName, DistributedBehaviorFunc), +train(cast, {end_stream , StreamName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + stream_handler(end_stream, train, StreamName, DistributedBehaviorFunc), {next_state, train, State}; train(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> From f8270e3e4551e715a334ab3b799f673378214f46 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 22:00:44 +0000 Subject: [PATCH 46/52] [W2W] WIP --- src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 293241fa..abfda097 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -258,7 +258,8 @@ train(cast, {start_stream , StreamName}, State = #workerGeneric_state{myName = _ stream_handler(start_stream, train, StreamName, DistributedBehaviorFunc), {next_state, train, State}; -train(cast, {end_stream , StreamName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> +train(cast, {end_stream , StreamName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + io:format("Worker ~p got end stream from ~p~n",[MyName, StreamName]), stream_handler(end_stream, train, StreamName, DistributedBehaviorFunc), {next_state, train, State}; From 2a04137683b3bed916e0eba24db1a9c28b6a5e04 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 22:01:50 +0000 Subject: [PATCH 47/52] [W2W] WIP --- src_erl/NerlnetApp/src/Client/clientStatem.erl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index ae579cd6..9e2a0528 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -283,7 +283,7 @@ training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = E stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), - gen_statem:cast(WorkerPid, {end_stream, [SourceName]}), + gen_statem:cast(WorkerPid, {end_stream, SourceName}), {keep_state, State}; training(cast, _In = {worker_done, Data}, State = #client_statem_state{etsRef = EtsRef}) -> From f4ae9f017cec596cd7dc3aff14b7311c6d5d19f8 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 22:27:10 +0000 Subject: [PATCH 48/52] [W2W] WIP --- .../onnWorkers/workerFederatedServer.erl | 35 ++++++------------- .../src/Bridge/onnWorkers/workerGeneric.erl | 29 +++++++-------- .../NerlnetApp/src/Client/clientStatem.erl | 1 + 3 files changed, 25 insertions(+), 40 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index fdb22a04..53512dde 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -77,30 +77,17 @@ start_stream({GenWorkerEts, WorkerData}) -> MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), gen_server:cast(ClientPid, {start_stream, {worker, MyName, FedWorkerName}}). -end_stream({_GenWorkerEts, _WorkerData}) -> ok. % All happens in GenWorker stream_handler - -% end_stream({GenWorkerEts, WorkerData}) -> -% [WorkerName , _ModelPhase] = WorkerData, -% FedServerEts = get_this_server_ets(GenWorkerEts), -% CurrActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), -% ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), -% MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), -% io:format("FedServer got end_stream from ~p, CurrActiveWorkers = ~p~n",[WorkerName, CurrActiveWorkers]), -% case CurrActiveWorkers of -% [] -> gen_statem:cast(ClientPid, {worker_done, {MyName, MyName}}); -% _Else -> -% ActiveWorkers = ets:lookup_element(FedServerEts, active_workers, ?ETS_KEYVAL_VAL_IDX), -% io:format("ActiveWorkers = ~p , got end stream from ~p removing it..~n",[ActiveWorkers, WorkerName]), -% UpdatedActiveWorkers = ActiveWorkers -- [WorkerName], -% ets:update_element(FedServerEts, active_workers, {?ETS_KEYVAL_VAL_IDX, UpdatedActiveWorkers}), -% case length(UpdatedActiveWorkers) of -% 0 -> io:format("GOT HEREEEE~n"), -% % ClientName = ets:lookup_element(GenWorkerEts, client_name, ?ETS_KEYVAL_VAL_IDX), -% Data = {MyName, MyName, MyName}, % Mimic source behavior to register as an active worker for the client -% gen_server:cast(ClientPid, {end_stream, term_to_binary(Data)}); -% _ -> ok -% end -% end. +end_stream({GenWorkerEts, WorkerData}) -> % Federated server takes the control of popping the stream from the active streams list + [FedWorkerName , _ModelPhase] = WorkerData, + FedServerEts = get_this_server_ets(GenWorkerEts), + MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + gen_statem:cast(ClientPid, {worker_done, {MyName, FedWorkerName}}), + ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + case ActiveStreams of + [] -> ets:update_element(FedServerEts, active_streams, {?ETS_KEYVAL_VAL_IDX, none}); + _ -> ok + end. pre_idle({_GenWorkerEts, _WorkerName}) -> ok. diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index abfda097..31171c7a 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -145,7 +145,7 @@ code_change(_OldVsn, StateName, State = #workerGeneric_state{}, _Extra) -> % Go from idle to train idle(cast, {training}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> % io:format("@idle got training , Worker ~p is going to state idle...~n",[MyName]), - worker_controller_empty_message_queue(), + ets:update_element(get(generic_worker_ets), active_streams, {?ETS_KEYVAL_VAL_IDX, []}), DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), train}), update_client_avilable_worker(MyName), {next_state, train, State#workerGeneric_state{lastPhase = train}}; @@ -153,6 +153,7 @@ idle(cast, {training}, State = #workerGeneric_state{myName = MyName , distribute % Go from idle to predict idle(cast, {predict}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> % worker_controller_empty_message_queue(), + ets:update_element(get(generic_worker_ets), active_streams, {?ETS_KEYVAL_VAL_IDX, []}), update_client_avilable_worker(MyName), DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), predict}), {next_state, predict, State#workerGeneric_state{lastPhase = predict}}; @@ -222,7 +223,8 @@ wait(cast, _BatchData , State = #workerGeneric_state{lastPhase = LastPhase, myNa wait(cast, Data, State) -> % logger:notice("worker ~p in wait cant treat message: ~p\n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), Data]), - worker_controller_message_queue(Data), + ?LOG_ERROR("Got unknown message in wait state: ~p~n",[Data]), + throw("Got unknown message in wait state"), {keep_state, State}. @@ -268,10 +270,10 @@ train(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBe DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), train}), {next_state, idle, State}; -train(cast, Data, State = #workerGeneric_state{myName = MyName}) -> +train(cast, Data, State = #workerGeneric_state{myName = _MyName}) -> % logger:notice("worker ~p in wait cant treat message: ~p\n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), Data]), - io:format("~p Got unknown message in train state: ~p~n",[MyName , Data]), - worker_controller_message_queue(Data), + ?LOG_ERROR("Got unknown message in train state: ~p~n",[Data]), + throw("Got unknown message in train state"), {keep_state, State}. %% State predict @@ -304,20 +306,14 @@ predict(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributed {next_state, idle, State}; predict(cast, Data, State) -> - worker_controller_message_queue(Data), + ?LOG_ERROR("Got unknown message in predict state: ~p~n",[Data]), + throw("Got unknown message in predict state"), {next_state, predict, State}. %% Updates the client that worker is available update_client_avilable_worker(MyName) -> gen_statem:cast(get(client_pid),{stateChange,MyName}). -worker_controller_message_queue(ReceiveData) -> - Queue = ets:lookup_element(get(generic_worker_ets), controller_message_q, ?ETS_KEYVAL_VAL_IDX), - ets:update_element(get(generic_worker_ets), controller_message_q, {?ETS_KEYVAL_VAL_IDX , Queue++[ReceiveData]}). - -worker_controller_empty_message_queue() -> - ets:update_element(get(generic_worker_ets), controller_message_q, {?ETS_KEYVAL_VAL_IDX , []}). - stream_handler(StreamPhase , ModelPhase , StreamName , DistributedBehaviorFunc) -> GenWorkerEts = get(generic_worker_ets), MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), @@ -329,8 +325,9 @@ stream_handler(StreamPhase , ModelPhase , StreamName , DistributedBehaviorFunc) end, ets:update_element(GenWorkerEts, active_streams, {?ETS_KEYVAL_VAL_IDX, NewActiveStreams}), DistributedBehaviorFunc(StreamPhase, {GenWorkerEts, [StreamName , ModelPhase]}), - case length(NewActiveStreams) of % Send to client an update after done with training phase - 0 -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + UpdatedActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + case UpdatedActiveStreams of % Send to client an update after done with training phase + [] -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), gen_statem:cast(ClientPid, {worker_done, {MyName, StreamName}}); - _ -> ok + _ -> ok end. \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 9e2a0528..1b7932b1 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -291,6 +291,7 @@ training(cast, _In = {worker_done, Data}, State = #client_statem_state{etsRef = io:format("Client got worker_done from ~p~n",[WorkerName]), ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, StreamName}], + io:format("~p Sent worker_done with ~p, UpdatedListOfActiveWorkerSources = ~p~n",[WorkerName, StreamName, UpdatedListOfActiveWorkerSources]), ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), case length(UpdatedListOfActiveWorkerSources) of 0 -> io:format("All workers sent worker_done~n"), From 40b1dfa30b9351cbc15901c5460712b2d98bb602 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 22:36:03 +0000 Subject: [PATCH 49/52] [W2W] Test --- .../onnWorkers/workerFederatedClient.erl | 12 ++-- .../onnWorkers/workerFederatedServer.erl | 2 +- .../src/Bridge/onnWorkers/workerGeneric.erl | 3 +- .../NerlnetApp/src/Client/clientStatem.erl | 57 ++++++++++++------- 4 files changed, 43 insertions(+), 31 deletions(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 07d85805..22862ecf 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -94,8 +94,7 @@ start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), case length(ActiveStreams) of % Send to server an updater after got start_stream from the first source - FirstMsg -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, MyName), % Server gets FedWorkerName instead of SourceName - io:format("~p sent START_stream to ~p~n",[MyName , ServerName]); + FirstMsg -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, MyName); % Server gets FedWorkerName instead of SourceName _ -> ok end; predict -> ok @@ -111,8 +110,7 @@ end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [S W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), case length(ActiveStreams) of % Send to server an updater after got start_stream from the first source - 0 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName), % Mimic source behavior - io:format("~p sent END_stream to ~p~n",[MyName , ServerName]); + 0 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName); % Mimic source behavior _ -> ok end; predict -> ok @@ -149,7 +147,6 @@ pre_train({GenWorkerEts, _NerlTensorWeights}) -> if SyncCount == MaxSyncCount -> W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), w2wCom:sync_inbox_no_limit(W2WPid), % waiting for server to average the weights and send it - io:format("~p done syncing inbox~n",[ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX)]), InboxQueue = w2wCom:get_all_messages(W2WPid), [UpdateWeightsMsg] = queue:to_list(InboxQueue), {_FedServer , {update_weights, UpdatedWeights}} = UpdateWeightsMsg, @@ -165,7 +162,7 @@ post_train({GenWorkerEts, _WorkerData}) -> ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), % io:format("Worker ~p ActiveStreams ~p~n",[MyName, ActiveStreams]), case ActiveStreams of - [] -> io:format("~p done training...~n",[MyName]), ok; + [] -> ok; _ -> ThisEts = get_this_client_ets(GenWorkerEts), SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), @@ -175,8 +172,7 @@ post_train({GenWorkerEts, _WorkerData}) -> WeightsTensor = nerlNIF:call_to_get_weights(ModelID), ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), - w2wCom:send_message_with_event(W2WPid, MyName, ServerName , post_train_update, WeightsTensor), - io:format("~p sent post_train_update to ~p~n",[MyName , ServerName]); + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , post_train_update, WeightsTensor); true -> ok end end. diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 53512dde..86b9fba3 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -145,7 +145,7 @@ post_train({GenWorkerEts, WeightsTensor}) -> case length(TotalWorkersWeights) of NumOfActiveWorkers -> ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - io:format("Averaging model weights...~n"), + % io:format("Averaging model weights...~n"), {CurrentModelWeights, BinaryType} = nerlNIF:call_to_get_weights(ModelID), FedServerName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), AllWorkersWeightsList = TotalWorkersWeights ++ [CurrentModelWeights], diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index 31171c7a..0fd5ee9f 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -260,8 +260,7 @@ train(cast, {start_stream , StreamName}, State = #workerGeneric_state{myName = _ stream_handler(start_stream, train, StreamName, DistributedBehaviorFunc), {next_state, train, State}; -train(cast, {end_stream , StreamName}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> - io:format("Worker ~p got end stream from ~p~n",[MyName, StreamName]), +train(cast, {end_stream , StreamName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> stream_handler(end_stream, train, StreamName, DistributedBehaviorFunc), {next_state, train, State}; diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 1b7932b1..6af840df 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -278,7 +278,6 @@ training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), - io:format("@client: Worker ~p got end_stream from ~p~n",[WorkerName, SourceName]), ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), @@ -288,14 +287,11 @@ training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = E training(cast, _In = {worker_done, Data}, State = #client_statem_state{etsRef = EtsRef}) -> {WorkerName, StreamName} = Data, - io:format("Client got worker_done from ~p~n",[WorkerName]), ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, StreamName}], - io:format("~p Sent worker_done with ~p, UpdatedListOfActiveWorkerSources = ~p~n",[WorkerName, StreamName, UpdatedListOfActiveWorkerSources]), ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), case length(UpdatedListOfActiveWorkerSources) of - 0 -> io:format("All workers sent worker_done~n"), - ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); + 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); _ -> ok end, {next_state, training, State#client_statem_state{etsRef = EtsRef}}; @@ -314,7 +310,7 @@ training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef ?LOG_INFO("~p sent idle to workers: ~p , waiting for confirmation...~n",[MyName, ets:lookup_element(EtsRef, workersNames, ?DATA_IDX)]), {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef, waitforWorkers = Workers , nextState = idle}}; false -> gen_statem:cast(get(my_pid) , {idle}), % Trigger this action until all workers are done - {next_state, training, State#client_statem_state{etsRef = EtsRef}} + {keep_state, State} end; training(cast, _In = {predict}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> @@ -358,6 +354,13 @@ predict(cast, In = {sample,Body}, State = #client_statem_state{etsRef = EtsRef}) end, {next_state, predict, State#client_statem_state{etsRef = EtsRef}}; +% This action is used for start_stream triggered from a clients' worker and not source +predict(cast, {start_stream , {worker, WorkerName, TargetName}}, State = #client_statem_state{etsRef = EtsRef}) -> + ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, ListOfActiveWorkersSources ++ [{WorkerName, TargetName}]}), + {keep_state, State}; + +% This action is used for start_stream triggered from a source per worker predict(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), @@ -372,18 +375,40 @@ predict(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = predict(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> {SourceName, _ClientName, WorkerName} = binary_to_term(Data), ClientStatsEts = get(client_stats_ets), - ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), - UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, SourceName}], - ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), - gen_statem:cast(WorkerPid, {end_stream, SourceName}), % WHY THIS IS NOT WORKING???? + gen_statem:cast(WorkerPid, {end_stream, SourceName}), + {keep_state, State}; + +predict(cast, _In = {worker_done, Data}, State = #client_statem_state{etsRef = EtsRef}) -> + {WorkerName, StreamName} = Data, + ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), + UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, StreamName}], + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), case length(UpdatedListOfActiveWorkerSources) of - 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); - _ -> ok end, + 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); + _ -> ok + end, {keep_state, State}; +% From MainServer +predict(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + MessageToCast = {idle}, + WorkersDone = ets:lookup_element(EtsRef , all_workers_done , ?DATA_IDX), + % io:format("Client ~p Workers Done? ~p~n",[MyName, WorkersDone]), + case WorkersDone of + true -> cast_message_to_workers(EtsRef, MessageToCast), + Workers = clientWorkersFunctions:get_workers_names(EtsRef), + ?LOG_INFO("~p sent idle to workers: ~p , waiting for confirmation...~n",[MyName, ets:lookup_element(EtsRef, workersNames, ?DATA_IDX)]), + {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef, waitforWorkers = Workers , nextState = idle}}; + false -> gen_statem:cast(get(my_pid) , {idle}), % Trigger this action until all workers are done + {keep_state, State} + end; + predict(cast, In = {predictRes,WorkerName, SourceName ,{PredictNerlTensor, NetlTensorType} , TimeTook , BatchID , BatchTS}, State = #client_statem_state{myName = _MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -413,14 +438,6 @@ predict(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = # %% The source sends message to main server that it has finished %% The main server updates its' clients to move to state 'idle' -predict(cast, In = {idle}, State = #client_statem_state{etsRef = EtsRef , myName = _MyName}) -> - MsgToCast = {idle}, - ClientStatsEts = get(client_stats_ets), - stats:increment_messages_received(ClientStatsEts), - stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - cast_message_to_workers(EtsRef, MsgToCast), - Workers = clientWorkersFunctions:get_workers_names(EtsRef), - {next_state, waitforWorkers, State#client_statem_state{nextState = idle, waitforWorkers = Workers, etsRef = EtsRef}}; predict(cast, EventContent, State = #client_statem_state{etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), From f81736843075d661bd078567f78900c89f98af1c Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 22:46:29 +0000 Subject: [PATCH 50/52] [W2W] Fixed end_stream while state wait --- .../NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 22862ecf..5e5c185f 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -113,7 +113,8 @@ end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [S 0 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName); % Mimic source behavior _ -> ok end; - predict -> ok + predict -> ok; + wait -> ok end. From a5806380f6263eaa93496fc576cfd7465271004a Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 23:14:23 +0000 Subject: [PATCH 51/52] [W2W] Done --- .../dc_dist_2d_3c_2s_3r_6w.json | 143 ++++++++++++++++++ .../dc_fed_dist_2d_3c_2s_3r_6w.json | 24 +-- .../exp_dist_2d_3c_2s_3r_6w.json | 54 +++++++ .../NerlnetApp/src/Client/clientStatem.erl | 1 - 4 files changed, 209 insertions(+), 13 deletions(-) create mode 100644 inputJsonsFiles/DistributedConfig/dc_dist_2d_3c_2s_3r_6w.json create mode 100644 inputJsonsFiles/experimentsFlow/exp_dist_2d_3c_2s_3r_6w.json diff --git a/inputJsonsFiles/DistributedConfig/dc_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/DistributedConfig/dc_dist_2d_3c_2s_3r_6w.json new file mode 100644 index 00000000..958ea78f --- /dev/null +++ b/inputJsonsFiles/DistributedConfig/dc_dist_2d_3c_2s_3r_6w.json @@ -0,0 +1,143 @@ +{ + "nerlnetSettings": { + "frequency": "100", + "batchSize": "100" + }, + "mainServer": { + "port": "8900", + "args": "" + }, + "apiServer": { + "port": "8901", + "args": "" + }, + "devices": [ + { + "name": "c0vm0", + "ipv4": "10.0.0.5", + "entities": "mainServer,c1,c2,r1,r2,s1,apiServer" + }, + { + "name": "c0vm1", + "ipv4": "10.0.0.4", + "entities": "c3,r3,s2" + } + ], + "routers": [ + { + "name": "r1", + "port": "8905", + "policy": "0" + }, + { + "name": "r2", + "port": "8906", + "policy": "0" + }, + { + "name": "r3", + "port": "8901", + "policy": "0" + } + ], + "sources": [ + { + "name": "s1", + "port": "8904", + "frequency": "200", + "policy": "0", + "epochs": "1", + "type": "0" + }, + { + "name": "s2", + "port": "8902", + "frequency": "200", + "policy": "0", + "epochs": "1", + "type": "0" + } + ], + "clients": [ + { + "name": "c1", + "port": "8902", + "workers": "w1,w2,ws" + }, + { + "name": "c2", + "port": "8903", + "workers": "w3,w4" + }, + { + "name": "c3", + "port": "8900", + "workers": "w5,w6" + } + ], + "workers": [ + { + "name": "w1", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w2", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "ws", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w3", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w4", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w5", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w6", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + } + ], + "model_sha": { + "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896": { + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,6,6,4,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,3", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", + "layers_functions": "1,8,8,8,11", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.001", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "0", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "SyncMaxCount=10", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" + } + } +} \ No newline at end of file diff --git a/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json index 0882bc08..bf8bc7b9 100644 --- a/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json +++ b/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json @@ -18,8 +18,8 @@ "entities": "mainServer,c1,c2,r1,r2,s1,apiServer" }, { - "name": "c0vm7", - "ipv4": "10.0.0.12", + "name": "c0vm1", + "ipv4": "10.0.0.4", "entities": "c3,r3,s2" } ], @@ -45,7 +45,7 @@ "name": "s1", "port": "8904", "frequency": "200", - "policy": "0", + "policy": "1", "epochs": "1", "type": "0" }, @@ -53,7 +53,7 @@ "name": "s2", "port": "8902", "frequency": "200", - "policy": "0", + "policy": "1", "epochs": "1", "type": "0" } @@ -113,14 +113,14 @@ "_doc_modelArgs": "Extra arguments to model", "layersSizes": "5,16,8,4,3", "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", - "layerTypesList": "1,3,3,3,5", + "layerTypesList": "1,3,3,3,3", "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", - "layers_functions": "1,6,6,11,4", + "layers_functions": "1,6,6,11,11", "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", - "lossMethod": "6", + "lossMethod": "2", "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", "lr": "0.01", "_doc_lr": "Positve float", @@ -134,7 +134,7 @@ "_doc_infraType": " opennn:0 | wolfengine:1 |", "distributedSystemType": "1", "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", - "distributedSystemArgs": "SyncMaxCount=50", + "distributedSystemArgs": "SyncMaxCount=10", "_doc_distributedSystemArgs": "String", "distributedSystemToken": "9922u", "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" @@ -146,14 +146,14 @@ "_doc_modelArgs": "Extra arguments to model", "layersSizes": "5,16,8,4,3", "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", - "layerTypesList": "1,3,3,3,5", + "layerTypesList": "1,3,3,3,3", "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", - "layers_functions": "1,6,6,11,4", + "layers_functions": "1,6,6,11,11", "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", - "lossMethod": "6", + "lossMethod": "2", "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", "lr": "0.01", "_doc_lr": "Positve float", @@ -167,7 +167,7 @@ "_doc_infraType": " opennn:0 | wolfengine:1 |", "distributedSystemType": "2", "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", - "distributedSystemArgs": "SyncMaxCount=50", + "distributedSystemArgs": "SyncMaxCount=10", "_doc_distributedSystemArgs": "String", "distributedSystemToken": "9922u", "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" diff --git a/inputJsonsFiles/experimentsFlow/exp_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/experimentsFlow/exp_dist_2d_3c_2s_3r_6w.json new file mode 100644 index 00000000..b96649d5 --- /dev/null +++ b/inputJsonsFiles/experimentsFlow/exp_dist_2d_3c_2s_3r_6w.json @@ -0,0 +1,54 @@ +{ + "experimentName": "synthetic_3_gausians", + "experimentType": "classification", + "batchSize": 100, + "csvFilePath": "/tmp/nerlnet/data/NerlnetData-master/nerlnet/synthetic_norm/synthetic_full.csv", + "numOfFeatures": "5", + "numOfLabels": "3", + "headersNames": "Norm(0:1),Norm(4:1),Norm(10:3)", + "Phases": + [ + { + "phaseName": "training_phase", + "phaseType": "training", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "0", + "numOfBatches": "250", + "workers": "w1,w2,w3,w4", + "nerltensorType": "float" + }, + { + "sourceName": "s2", + "startingSample": "25000", + "numOfBatches": "250", + "workers": "w5,w6", + "nerltensorType": "float" + } + ] + }, + { + "phaseName": "prediction_phase", + "phaseType": "prediction", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "50000", + "numOfBatches": "500", + "workers": "w1,w2,w3,w4", + "nerltensorType": "float" + }, + { + "sourceName": "s2", + "startingSample": "50000", + "numOfBatches": "500", + "workers": "w5,w6", + "nerltensorType": "float" + } + ] + } +] +} \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 6af840df..476fdd1e 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -128,7 +128,6 @@ waitforWorkers(cast, In = {stateChange,WorkerName}, State = #client_statem_state case NewWaitforWorkers of % TODO Guy here we need to check for keep alive with workers [] -> send_client_is_ready(MyName), % when all workers done their work stats:increment_messages_sent(ClientStatsEts), - io:format("Client ~p is ready~n", [MyName]), {next_state, NextState, State#client_statem_state{waitforWorkers = []}}; _ -> %io:format("Client ~p is waiting for workers ~p~n",[MyName,NewWaitforWorkers]), {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = NewWaitforWorkers}} From 35cc41eee0b176a0acd20f54e0c6c677ba2163c9 Mon Sep 17 00:00:00 2001 From: GuyPerets106 Date: Wed, 22 May 2024 23:21:22 +0000 Subject: [PATCH 52/52] [W2W] Done --- .../dc_fed_dist_2d_3c_2s_3r_6w.json | 12 ++++++------ src_erl/NerlnetApp/src/nerlnetApp_app.erl | 2 +- src_py/nerlPlanner/Definitions.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json index bf8bc7b9..ffa02d36 100644 --- a/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json +++ b/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json @@ -111,18 +111,18 @@ "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", "modelArgs": "", "_doc_modelArgs": "Extra arguments to model", - "layersSizes": "5,16,8,4,3", + "layersSizes": "5,6,6,4,3", "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", "layerTypesList": "1,3,3,3,3", "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", - "layers_functions": "1,6,6,11,11", + "layers_functions": "1,8,8,8,11", "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", "lossMethod": "2", "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", - "lr": "0.01", + "lr": "0.001", "_doc_lr": "Positve float", "epochs": "1", "_doc_epochs": "Positve Integer", @@ -144,18 +144,18 @@ "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", "modelArgs": "", "_doc_modelArgs": "Extra arguments to model", - "layersSizes": "5,16,8,4,3", + "layersSizes": "5,6,6,4,3", "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", "layerTypesList": "1,3,3,3,3", "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", - "layers_functions": "1,6,6,11,11", + "layers_functions": "1,8,8,8,11", "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", "lossMethod": "2", "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", - "lr": "0.01", + "lr": "0.001", "_doc_lr": "Positve float", "epochs": "1", "_doc_epochs": "Positve Integer", diff --git a/src_erl/NerlnetApp/src/nerlnetApp_app.erl b/src_erl/NerlnetApp/src/nerlnetApp_app.erl index 8225aabb..2f228d78 100644 --- a/src_erl/NerlnetApp/src/nerlnetApp_app.erl +++ b/src_erl/NerlnetApp/src/nerlnetApp_app.erl @@ -20,7 +20,7 @@ -behaviour(application). -include("nerl_tools.hrl"). --define(NERLNET_APP_VERSION, "1.4.3"). +-define(NERLNET_APP_VERSION, "1.5.0"). -define(NERLPLANNER_TESTED_VERSION,"1.0.2"). -export([start/2, stop/1]). diff --git a/src_py/nerlPlanner/Definitions.py b/src_py/nerlPlanner/Definitions.py index 11cc21e0..d70512d3 100644 --- a/src_py/nerlPlanner/Definitions.py +++ b/src_py/nerlPlanner/Definitions.py @@ -2,7 +2,7 @@ from logger import * VERSION = "1.0.2" -NERLNET_VERSION_TESTED_WITH = "1.4.2" +NERLNET_VERSION_TESTED_WITH = "1.5.0" NERLNET_TMP_PATH = "/tmp/nerlnet" NERLNET_GRAPHVIZ_OUTPUT_DIR = f"{NERLNET_TMP_PATH}/nerlplanner" NERLNET_GLOBAL_PATH = "/usr/local/lib/nerlnet-lib/NErlNet"