Skip to content

Commit

Permalink
[ModelDB] Add distributed token (#349)
Browse files Browse the repository at this point in the history
- Add distributed token to NerlnetAPP and ApiServer
- remove redundant import from nerl_model_db
  • Loading branch information
leondavi authored Jun 8, 2024
1 parent bd9b9b5 commit 1eb8229
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 26 deletions.
16 changes: 11 additions & 5 deletions src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ start_link(ARGS) ->
%% distributedBehaviorFunc is the special behavior of the worker regarding the distributed system, e.g. federated client/server
init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , ClientPid , WorkerStatsEts , W2WPid}) ->
nerl_tools:setup_logger(?MODULE),
{ModelID , ModelType , ModelArgs , LayersSizes, LayersTypes, LayersFunctionalityCodes, LearningRate , Epochs,
OptimizerType, OptimizerArgs , LossMethod , DistributedSystemType , DistributedSystemArgs} = WorkerArgs,
{ModelID , ModelType , ModelArgs , LayersSizes,
LayersTypes, LayersFunctionalityCodes, LearningRate , Epochs,
OptimizerType, OptimizerArgs , LossMethod , DistributedSystemType ,
DistributedSystemToken, DistributedSystemArgs} = WorkerArgs,
GenWorkerEts = ets:new(generic_worker,[set, public]),
put(generic_worker_ets, GenWorkerEts),
put(client_pid, ClientPid),
Expand All @@ -71,6 +73,7 @@ init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData
ets:insert(GenWorkerEts,{optimizer, OptimizerType}),
ets:insert(GenWorkerEts,{optimizer_args, OptimizerArgs}),
ets:insert(GenWorkerEts,{distributed_system_args, DistributedSystemArgs}),
ets:insert(GenWorkerEts,{distributed_system_token, DistributedSystemToken}),
ets:insert(GenWorkerEts,{distributed_system_type, DistributedSystemType}),
ets:insert(GenWorkerEts,{controller_message_q, []}), %% TODO Deprecated
ets:insert(GenWorkerEts,{handshake_done, false}),
Expand Down Expand Up @@ -166,22 +169,25 @@ idle(cast, _Param, State = #workerGeneric_state{myName = _MyName}) ->
%% Got nan or inf from loss function - Error, loss function too big for double
wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc, postBatchFunc = PostBatchFunc}) ->
stats:increment_by_value(get(worker_stats_ets), nan_loss_count, 1),
gen_statem:cast(get(client_pid),{loss, MyName , SourceName ,nan , TrainTime ,BatchID}),
WorkerToken = ets:lookup_element(get(generic_worker_ets), distributed_system_token, ?ETS_KEYVAL_VAL_IDX),
gen_statem:cast(get(client_pid),{loss, MyName , SourceName ,nan , TrainTime, WorkerToken ,BatchID}),
DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients
PostBatchFunc(),
{next_state, NextState, State#workerGeneric_state{postBatchFunc = ?EMPTY_FUNC}};


wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc, postBatchFunc = PostBatchFunc}) ->
BatchTimeStamp = erlang:system_time(nanosecond),
gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , BatchID , BatchTimeStamp}),
WorkerToken = ets:lookup_element(get(generic_worker_ets), distributed_system_token, ?ETS_KEYVAL_VAL_IDX),
gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , WorkerToken, BatchID , BatchTimeStamp}),
DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients
PostBatchFunc(),
{next_state, NextState, State#workerGeneric_state{postBatchFunc = ?EMPTY_FUNC}};

wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) ->
BatchTimeStamp = erlang:system_time(nanosecond),
gen_statem:cast(get(client_pid),{predictRes,MyName, SourceName, {PredNerlTensor, PredNerlTensorType}, TimeNif , BatchID , BatchTimeStamp}),
WorkerToken = ets:lookup_element(get(generic_worker_ets), distributed_system_token, ?ETS_KEYVAL_VAL_IDX),
gen_statem:cast(get(client_pid),{predictRes,MyName, SourceName, {PredNerlTensor, PredNerlTensorType}, TimeNif , WorkerToken, BatchID , BatchTimeStamp}),
Update = DistributedBehaviorFunc(post_predict, {get(generic_worker_ets),DistributedWorkerData}),
if Update ->
{next_state, update, State#workerGeneric_state{nextState=NextState}};
Expand Down
8 changes: 4 additions & 4 deletions src_erl/NerlnetApp/src/Client/clientStatem.erl
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,12 @@ training(cast, _In = {predict}, State = #client_statem_state{myName = MyName, et
{next_state, training, State#client_statem_state{etsRef = EtsRef}};


training(cast, In = {loss, WorkerName ,SourceName ,LossTensor ,TimeNIF ,BatchID ,BatchTS}, State = #client_statem_state{myName = MyName,etsRef = EtsRef}) ->
training(cast, In = {loss, WorkerName ,SourceName ,LossTensor ,TimeNIF , WorkerToken,BatchID ,BatchTS}, State = #client_statem_state{myName = MyName,etsRef = EtsRef}) ->
ClientStatsEts = get(client_stats_ets),
stats:increment_messages_received(ClientStatsEts),
stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)),
{RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX),
MessageBody = {WorkerName , SourceName , LossTensor , TimeNIF , BatchID , BatchTS},
MessageBody = {WorkerName , SourceName , LossTensor , TimeNIF , WorkerToken, BatchID , BatchTS},
nerl_tools:http_router_request(RouterHost, RouterPort, [?MAIN_SERVER_ATOM], atom_to_list(lossFunction), MessageBody), %% Change lossFunction atom to lossValue
stats:increment_messages_sent(ClientStatsEts),
stats:increment_bytes_sent(ClientStatsEts , nerl_tools:calculate_size(MessageBody)),
Expand Down Expand Up @@ -408,13 +408,13 @@ predict(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef
{keep_state, State}
end;

predict(cast, In = {predictRes,WorkerName, SourceName ,{PredictNerlTensor, NetlTensorType} , TimeTook , BatchID , BatchTS}, State = #client_statem_state{myName = _MyName, etsRef = EtsRef}) ->
predict(cast, In = {predictRes,WorkerName, SourceName ,{PredictNerlTensor, NetlTensorType} , TimeTook , WorkerToken, BatchID , BatchTS}, State = #client_statem_state{myName = _MyName, etsRef = EtsRef}) ->
ClientStatsEts = get(client_stats_ets),
stats:increment_messages_received(ClientStatsEts),
stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)),

{RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX),
MessageBody = {WorkerName, SourceName, {PredictNerlTensor , NetlTensorType}, TimeTook, BatchID, BatchTS}, %% SHOULD INCLUDE TYPE?
MessageBody = {WorkerName, SourceName, {PredictNerlTensor , NetlTensorType}, TimeTook, WorkerToken, BatchID, BatchTS}, %% SHOULD INCLUDE TYPE?
nerl_tools:http_router_request(RouterHost, RouterPort, [?MAIN_SERVER_ATOM], atom_to_list(predictRes), MessageBody),

stats:increment_messages_sent(ClientStatsEts),
Expand Down
2 changes: 1 addition & 1 deletion src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ create_workers(ClientName, ClientEtsRef , ShaToModelArgsMap , EtsStats) ->
W2wComPid = w2wCom:start_link({WorkerName, MyClientPid}), % TODO Switch to monitor instead of link

WorkerArgs = {ModelID , ModelType , ModelArgs , LayersSizes, LayersTypes, LayersFunctions, LearningRate , Epochs,
Optimizer, OptimizerArgs , LossMethod , DistributedSystemType , DistributedSystemArgs},
Optimizer, OptimizerArgs , LossMethod , DistributedSystemType , DistributedSystemToken, DistributedSystemArgs},
WorkerPid = workerGeneric:start_link({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , MyClientPid , WorkerStatsETS , W2wComPid}),
gen_server:cast(W2wComPid, {update_gen_worker_pid, WorkerPid}),
ets:insert(WorkersETS, {WorkerName, {WorkerPid, WorkerArgs}}),
Expand Down
14 changes: 10 additions & 4 deletions src_erl/NerlnetApp/src/MainServer/mainGenserver.erl
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,14 @@ handle_cast({lossFunction,Body}, State = #main_genserver_state{myName = MyName})
stats:increment_messages_received(StatsEts),
try
case binary_to_term(Body) of
{WorkerName , SourceName , {LossNerlTensor , LossNerlTensorType} , TimeNIF , BatchID , BatchTS} ->
{WorkerName , SourceName , {LossNerlTensor , LossNerlTensorType} , TimeNIF , WorkerToken, BatchID , BatchTS} ->
Key = atom_to_list(WorkerName) ++ ?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++ atom_to_list(SourceName) ++
?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++ integer_to_list(BatchID) ++ ?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++
integer_to_list(BatchTS) ++ ?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++ float_to_list(TimeNIF) ++ ?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++
integer_to_list(BatchTS) ++ ?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++ float_to_list(TimeNIF) ++
?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++ WorkerToken ++ ?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++
atom_to_list(LossNerlTensorType),
% data is encoded in key with separators as follows:
% WorkerName + SourceName + BatchID + BatchTS + TimeNIF + WorkerToken + LossNerlTensorType
store_phase_result_data_to_send_ets(Key, binary_to_list(LossNerlTensor));
_ELSE ->
?LOG_ERROR("~p Wrong loss function pattern received from client and its worker ~p", [MyName, Body])
Expand All @@ -348,11 +351,14 @@ handle_cast({predictRes,Body}, State) ->
_BatchSize = ets:lookup_element(get(main_server_ets), batch_size, ?DATA_IDX),
stats:increment_messages_received(StatsEts),
try
{WorkerName, SourceName, {NerlTensor, NerlTensorType}, TimeNIF , BatchID, BatchTS} = binary_to_term(Body),
{WorkerName, SourceName, {NerlTensor, NerlTensorType}, TimeNIF , WorkerToken, BatchID, BatchTS} = binary_to_term(Body),
Key = atom_to_list(WorkerName) ++ ?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++ atom_to_list(SourceName) ++
?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++ integer_to_list(BatchID) ++ ?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++
integer_to_list(BatchTS) ++ ?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++ float_to_list(TimeNIF) ++ ?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++
integer_to_list(BatchTS) ++ ?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++ float_to_list(TimeNIF) ++
?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++ WorkerToken ++ ?PHASE_RES_VALUES_IN_KEY_SEPARATOR ++
atom_to_list(NerlTensorType),
% data is encoded in key with separators as follows:
% WorkerName + SourceName + BatchID + BatchTS + TimeNIF + WorkerToken + NerlTensorType
store_phase_result_data_to_send_ets(Key, binary_to_list(NerlTensor))
catch Err:E ->
?LOG_ERROR(?LOG_HEADER++"Error receiving predict result ~p",[{Err,E}])
Expand Down
7 changes: 7 additions & 0 deletions src_py/apiServer/NerlComDB.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
################################################
# Nerlnet - 2024 GPL-3.0 license
# Authors: Ohad Adi, Noa Shapira, David Leon #
################################################

from networkComponents import MAIN_SERVER_STR , NetworkComponents
from abc import ABC


class EntityComDB(ABC): # Abstract Class
def __init__(self):
# based on stats.erl
Expand Down
6 changes: 4 additions & 2 deletions src_py/apiServer/apiServer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
################################################
# Nerlnet - 2023 GPL-3.0 license
# Authors: Haran Cohen, David Leon, Dor Yerchi #
# Nerlnet - 2024 GPL-3.0 license
# Authors: Noa Shapira, Ohad Adi, David Leon
# Guy Perets, Haran Cohen, Dor Yerchi
################################################

import time
import threading
import sys
Expand Down
5 changes: 5 additions & 0 deletions src_py/apiServer/apiServerHelp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
################################################
# Nerlnet - 2024 GPL-3.0 license
# Authors: Noa Shapira, Ohad Adi, David Leon
# Guy Perets, Haran Cohen, Dor Yerchi
################################################

API_SERVER_HELP_STR = """
__________NERLNET CHECKLIST__________
Expand Down
16 changes: 12 additions & 4 deletions src_py/apiServer/decoderHttpMainServer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@

################################################
# Nerlnet - 2024 GPL-3.0 license
# Authors: Ohad Adi, Noa Shapira, David Leon
# Guy Perets
################################################

import numpy as np
from decoderHttpMainServerDefs import *
from definitions import NERLTENSOR_TYPE_LIST
Expand Down Expand Up @@ -65,24 +71,26 @@ def parse_key_string(key_string: str) -> tuple:
BATCH_ID_IDX = 2
BATCH_TS_IDX = 3
DURATION_IDX = 4 # TimeNIF
NERLTENSOR_TYPE_IDX = 5
WORKER_DISTRIBUTED_TOKEN_IDX = 5
NERLTENSOR_TYPE_IDX = 6

definitions_list = key_string.split(SEP_ENTITY_HASH_STATS)
worker_name = definitions_list[WORKER_NAME_IDX]
source_name = definitions_list[SOURCE_NAME_IDX]
batch_id = definitions_list[BATCH_ID_IDX]
batch_ts = definitions_list[BATCH_TS_IDX]
duration = definitions_list[DURATION_IDX]
distributed_token = definitions_list[WORKER_DISTRIBUTED_TOKEN_IDX]
nerltensor_type = definitions_list[NERLTENSOR_TYPE_IDX]

return worker_name, source_name, batch_id, batch_ts, duration, nerltensor_type
return worker_name, source_name, batch_id, batch_ts, duration, distributed_token, nerltensor_type


def decode_phase_result_data_json_from_main_server(input_json_dict : dict) -> list:
decoded_data = []
DIMS_LENGTH = 3
for key_string, nerltensor in input_json_dict.items():
worker_name, source_name, batch_id, batch_ts, duration, nerltensor_type = parse_key_string(key_string)
worker_name, source_name, batch_id, batch_ts, duration, distributed_token, nerltensor_type = parse_key_string(key_string)
duration = int(float(duration)) # from here duration is int in micro seconds

# nerltensor to numpy tensor conversion
Expand All @@ -104,5 +112,5 @@ def decode_phase_result_data_json_from_main_server(input_json_dict : dict) -> li
np_tensor = np_tensor[DIMS_LENGTH:]
np_tensor = np_tensor.reshape(dims) # reshaped

decoded_data.append((worker_name, source_name, duration, batch_id, batch_ts, np_tensor))
decoded_data.append((worker_name, source_name, duration, batch_id, batch_ts, distributed_token, np_tensor))
return decoded_data
4 changes: 2 additions & 2 deletions src_py/apiServer/experiment_phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def process_experiment_phase_data(self):
assert len(self.raw_data_buffer) == 1, "Expecting only one raw_data in buffer of a single phase"
list_of_decoded_data = decode_phase_result_data_json_from_main_server(self.raw_data_buffer[0])
for decoded_data in list_of_decoded_data:
worker_name, source_name, duration, batch_id, batch_ts, np_tensor = decoded_data
worker_name, source_name, duration, batch_id, batch_ts, distributed_token, np_tensor = decoded_data
client_name = self.network_componenets.get_client_name_by_worker_name(worker_name)
self.nerl_model_db.get_client(client_name).get_worker(worker_name).create_batch(batch_id, source_name, np_tensor, duration, batch_ts)
self.nerl_model_db.get_client(client_name).get_worker(worker_name).create_batch(batch_id, source_name, np_tensor, duration, distributed_token, batch_ts)

self.clean_raw_data_buffer()

Expand Down
18 changes: 14 additions & 4 deletions src_py/apiServer/nerl_model_db.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
################################################
# Nerlnet - 2024 GPL-3.0 license
# Authors: Ohad Adi, Noa Shapira, David Leon
################################################

from logger import *
import numpy as np

class BatchDB():
def __init__(self, batch_id, source_name, tensor_data, duration, batch_timestamp):
def __init__(self, batch_id, source_name, tensor_data, duration, distributed_token, batch_timestamp):
self.batch_id = batch_id
self.source_name = source_name
self.tensor_data = tensor_data
self.duration = duration
self.batch_timestamp = batch_timestamp
self.distributed_token = distributed_token

def get_source_name(self):
return self.source_name
Expand All @@ -17,19 +22,24 @@ def get_batch_id(self):

def get_tensor_data(self):
return self.tensor_data

def get_distributed_token(self):
return self.distributed_token


class WorkerModelDB():
def __init__(self, worker_name):
self.batches_dict = {}
self.batches_ts_dict = {}
self.warn_override = False
self.worker_name = worker_name

def create_batch(self, batch_id, source_name, tensor_data, duration, batch_timestamp):
def create_batch(self, batch_id, source_name, tensor_data, duration, distributed_token, batch_timestamp):
if batch_id in self.batches_dict:
if not self.warn_override:
LOG_WARNING(f"Override batches from batch id: {batch_id} in worker {self.worker_name} in source {source_name}.")
self.warn_override = True
self.batches_dict[(source_name, batch_id)] = BatchDB(batch_id, source_name, tensor_data, duration, batch_timestamp)
self.batches_dict[(source_name, batch_id)] = BatchDB(batch_id, source_name, tensor_data, duration, distributed_token, batch_timestamp)
self.batches_ts_dict[batch_timestamp] = self.batches_dict[(source_name, batch_id)]

def get_batch(self, source_name, batch_id):
Expand Down

0 comments on commit 1eb8229

Please sign in to comment.