diff --git a/inputJsonsFiles/ConnectionMap/conn_fed_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/ConnectionMap/conn_fed_dist_2d_3c_2s_3r_6w.json new file mode 100644 index 00000000..9ffce810 --- /dev/null +++ b/inputJsonsFiles/ConnectionMap/conn_fed_dist_2d_3c_2s_3r_6w.json @@ -0,0 +1,8 @@ +{ + "connectionsMap": + { + "r1":["mainServer", "r2" , "c2" , "r3"], + "r2":["r1", "s1" , "c1" , "r3"], + "r3":["r1", "r2" , "s2" , "c3"] + } +} diff --git a/inputJsonsFiles/ConnectionMap/conn_fed_synt_1d_2c_2r_1s_4w_1ws.json b/inputJsonsFiles/ConnectionMap/conn_fed_synt_1d_2c_2r_1s_4w_1ws.json new file mode 100644 index 00000000..db39f106 --- /dev/null +++ b/inputJsonsFiles/ConnectionMap/conn_fed_synt_1d_2c_2r_1s_4w_1ws.json @@ -0,0 +1,7 @@ +{ + "connectionsMap": + { + "r1":["mainServer", "r2" , "c2"], + "r2":["r1", "s1" , "c1"] + } +} diff --git a/inputJsonsFiles/DistributedConfig/dc_AEC_1d_2c_1s_4r_4w.json b/inputJsonsFiles/DistributedConfig/dc_AEC_1d_2c_1s_4r_4w.json new file mode 100644 index 00000000..753ac010 --- /dev/null +++ b/inputJsonsFiles/DistributedConfig/dc_AEC_1d_2c_1s_4r_4w.json @@ -0,0 +1,67 @@ +{ + "nerlnetSettings": { + "frequency": "200", + "batchSize": "100" + }, + "mainServer": { + "port": "8081", + "args": "" + }, + "apiServer": { + "port": "8082", + "args": "" + }, + "devices": [ + { + "name": "pc1", + "ipv4": "10.211.55.3", + "entities": "c1,c2,r2,r1,r3,r4,s1,apiServer,mainServer" + } + ], + "routers": [ + { + "name": "r1", + "port": "8086", + "policy": "0" + }, + { + "name": "r2", + "port": "8087", + "policy": "0" + }, + { + "name": "r3", + "port": "8088", + "policy": "0" + }, + { + "name": "r4", + "port": "8089", + "policy": "0" + } + ], + "sources": [ + { + "name": "s1", + "port": "8085", + "frequency": "200", + "policy": "0", + "epochs": "1", + "type": "0" + } + ], + "clients": [ + { + "name": "c1", + "port": "8083", + "workers": "" + }, + { + "name": "c2", + "port": "8084", + "workers": "" + } + ], + "workers": [], + "model_sha": {} +} \ No newline at end of file diff --git a/inputJsonsFiles/DistributedConfig/dc_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/DistributedConfig/dc_dist_2d_3c_2s_3r_6w.json new file mode 100644 index 00000000..958ea78f --- /dev/null +++ b/inputJsonsFiles/DistributedConfig/dc_dist_2d_3c_2s_3r_6w.json @@ -0,0 +1,143 @@ +{ + "nerlnetSettings": { + "frequency": "100", + "batchSize": "100" + }, + "mainServer": { + "port": "8900", + "args": "" + }, + "apiServer": { + "port": "8901", + "args": "" + }, + "devices": [ + { + "name": "c0vm0", + "ipv4": "10.0.0.5", + "entities": "mainServer,c1,c2,r1,r2,s1,apiServer" + }, + { + "name": "c0vm1", + "ipv4": "10.0.0.4", + "entities": "c3,r3,s2" + } + ], + "routers": [ + { + "name": "r1", + "port": "8905", + "policy": "0" + }, + { + "name": "r2", + "port": "8906", + "policy": "0" + }, + { + "name": "r3", + "port": "8901", + "policy": "0" + } + ], + "sources": [ + { + "name": "s1", + "port": "8904", + "frequency": "200", + "policy": "0", + "epochs": "1", + "type": "0" + }, + { + "name": "s2", + "port": "8902", + "frequency": "200", + "policy": "0", + "epochs": "1", + "type": "0" + } + ], + "clients": [ + { + "name": "c1", + "port": "8902", + "workers": "w1,w2,ws" + }, + { + "name": "c2", + "port": "8903", + "workers": "w3,w4" + }, + { + "name": "c3", + "port": "8900", + "workers": "w5,w6" + } + ], + "workers": [ + { + "name": "w1", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w2", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" 
+ }, + { + "name": "ws", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w3", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w4", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w5", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w6", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + } + ], + "model_sha": { + "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896": { + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,6,6,4,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,3", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", + "layers_functions": "1,8,8,8,11", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.001", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "0", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "SyncMaxCount=10", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" + } + } +} \ No newline at end of file diff --git a/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json new file mode 100644 index 00000000..ffa02d36 --- /dev/null +++ b/inputJsonsFiles/DistributedConfig/dc_fed_dist_2d_3c_2s_3r_6w.json @@ -0,0 +1,176 @@ +{ + "nerlnetSettings": { + "frequency": "100", + "batchSize": "100" + }, + "mainServer": { + "port": "8900", + "args": "" + }, + "apiServer": { + "port": "8901", + "args": "" + }, + "devices": [ + { + "name": "c0vm0", + "ipv4": "10.0.0.5", + "entities": "mainServer,c1,c2,r1,r2,s1,apiServer" + }, + { + "name": "c0vm1", + "ipv4": "10.0.0.4", + "entities": "c3,r3,s2" + } + ], + "routers": [ + { + "name": "r1", + "port": "8905", + "policy": "0" + }, + { + "name": "r2", + "port": "8906", + "policy": "0" + }, + { + "name": "r3", + "port": "8901", + "policy": "0" + } + ], + "sources": [ + { + "name": "s1", + "port": "8904", + "frequency": "200", + "policy": "1", + "epochs": "1", + "type": "0" + }, + { + "name": "s2", + "port": "8902", + "frequency": "200", + 
"policy": "1", + "epochs": "1", + "type": "0" + } + ], + "clients": [ + { + "name": "c1", + "port": "8902", + "workers": "w1,w2,ws" + }, + { + "name": "c2", + "port": "8903", + "workers": "w3,w4" + }, + { + "name": "c3", + "port": "8900", + "workers": "w5,w6" + } + ], + "workers": [ + { + "name": "w1", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w2", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "ws", + "model_sha": "c081daf49b8332585243b68cb828ebc9b947528601a6852688cea0312b3e3914" + }, + { + "name": "w3", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w4", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w5", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + }, + { + "name": "w6", + "model_sha": "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896" + } + ], + "model_sha": { + "0771693392e898393c9b2b8235497537b5fbed1fd0c9a5a7ec6aab665d2c1896": { + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,6,6,4,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,3", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", + "layers_functions": "1,8,8,8,11", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.001", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "1", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "SyncMaxCount=10", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" + }, + "c081daf49b8332585243b68cb828ebc9b947528601a6852688cea0312b3e3914": { + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,6,6,4,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,3", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | 
LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", + "layers_functions": "1,8,8,8,11", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.001", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "2", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "SyncMaxCount=10", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" + } + } +} \ No newline at end of file diff --git a/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json b/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json new file mode 100644 index 00000000..7814c903 --- /dev/null +++ b/inputJsonsFiles/DistributedConfig/dc_fed_synt_1d_2c_2r_1s_4w_1ws.json @@ -0,0 +1,145 @@ +{ + "nerlnetSettings": { + "frequency": "100", + "batchSize": "100" + }, + "mainServer": { + "port": "8900", + "args": "" + }, + "apiServer": { + "port": "8901", + "args": "" + }, + "devices": [ + { + "name": "c0vm0", + "ipv4": "10.0.0.5", + "entities": "mainServer,c1,c2,r1,r2,s1,apiServer" + } + ], + "routers": [ + { + "name": "r1", + "port": "8905", + "policy": "0" + }, + { + "name": "r2", + "port": "8906", + "policy": "0" + } + ], + "sources": [ + { + "name": "s1", + "port": "8904", + "frequency": "200", + "policy": "0", + "epochs": "1", + "type": "0" + } + ], + "clients": [ + { + "name": "c1", + "port": "8902", + "workers": "w1,w2,ws" + }, + { + "name": "c2", + "port": "8903", + "workers": "w3,w4" + } + ], + "workers": [ + { + "name": "w1", + "model_sha": "7c0c5327ad2632a8a1107ed60f03b5bb49fc098332e7b91a12f214d045c6dd74" + }, + { + "name": "w2", + "model_sha": "7c0c5327ad2632a8a1107ed60f03b5bb49fc098332e7b91a12f214d045c6dd74" + }, + { + "name": "ws", + "model_sha": "24cfe345509ff1d121e437fc0baf3fb8feba88dda87db11b7c9c7aaff065c40b" + }, + { + "name": "w3", + "model_sha": "7c0c5327ad2632a8a1107ed60f03b5bb49fc098332e7b91a12f214d045c6dd74" + }, + { + "name": "w4", + "model_sha": "7c0c5327ad2632a8a1107ed60f03b5bb49fc098332e7b91a12f214d045c6dd74" + } + ], + "model_sha": { + "7c0c5327ad2632a8a1107ed60f03b5bb49fc098332e7b91a12f214d045c6dd74": { + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,2,2,2,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,5", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | 
Unscaling:8 | Bounding:9 |", + "layers_functions": "1,6,6,11,4", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.01", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "1", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "SyncMaxCount=5", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" + }, + "24cfe345509ff1d121e437fc0baf3fb8feba88dda87db11b7c9c7aaff065c40b": { + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,2,2,2,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,5", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Bounding:9 |", + "layers_functions": "1,6,6,11,4", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.01", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "2", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "SyncMaxCount=5", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" + } + } +} \ No newline at end of file diff --git a/inputJsonsFiles/Workers/worker_ae_classifier.json b/inputJsonsFiles/Workers/worker_ae_classifier.json new file mode 100644 index 00000000..96d59122 --- /dev/null +++ b/inputJsonsFiles/Workers/worker_ae_classifier.json @@ -0,0 +1,33 @@ +{ + "modelType": "9", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | 
text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "11,6,4,6,11", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,3", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |", + "layers_functions": "1,7,7,7,11", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.01", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "0", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "none", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "none", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" +} \ No newline at end of file diff --git a/inputJsonsFiles/Workers/worker_fed_client.json b/inputJsonsFiles/Workers/worker_fed_client.json new file mode 100644 index 00000000..28964195 --- /dev/null +++ b/inputJsonsFiles/Workers/worker_fed_client.json @@ -0,0 +1,33 @@ +{ + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,10,5,3,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,5", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Bounding:9 |", + "layers_functions": "1,6,6,11,4", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.01", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "1", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": 
"none", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" +} \ No newline at end of file diff --git a/inputJsonsFiles/Workers/worker_fed_server.json b/inputJsonsFiles/Workers/worker_fed_server.json new file mode 100644 index 00000000..d9eb7758 --- /dev/null +++ b/inputJsonsFiles/Workers/worker_fed_server.json @@ -0,0 +1,33 @@ +{ + "modelType": "0", + "_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image_classification:4 | text_classification:5 | text_generation:6 | auto_association:7 | autoencoder:8 | ae_classifier:9 |", + "modelArgs": "", + "_doc_modelArgs": "Extra arguments to model", + "layersSizes": "5,10,5,3,3", + "_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]", + "layerTypesList": "1,3,3,3,5", + "_doc_LayerTypes": " Default:0 | Scaling:1 | Conv:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Bounding:9 |", + "layers_functions": "1,6,6,11,4", + "_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |", + "_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |", + "_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |", + "_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |", + "lossMethod": "2", + "_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |", + "lr": "0.01", + "_doc_lr": "Positve float", + "epochs": "1", + "_doc_epochs": "Positve Integer", + "optimizer": "5", + "_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |", + "optimizerArgs": "none", + "_doc_optimizerArgs": "String", + "infraType": "0", + "_doc_infraType": " opennn:0 | wolfengine:1 |", + "distributedSystemType": "2", + "_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |", + "distributedSystemArgs": "none", + "_doc_distributedSystemArgs": "String", + "distributedSystemToken": "9922u", + "_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server" +} \ No newline at end of file diff --git a/inputJsonsFiles/experimentsFlow/exp_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/experimentsFlow/exp_dist_2d_3c_2s_3r_6w.json new file mode 100644 index 00000000..b96649d5 --- /dev/null +++ b/inputJsonsFiles/experimentsFlow/exp_dist_2d_3c_2s_3r_6w.json @@ -0,0 +1,54 @@ +{ + "experimentName": "synthetic_3_gausians", + "experimentType": "classification", + "batchSize": 100, + "csvFilePath": "/tmp/nerlnet/data/NerlnetData-master/nerlnet/synthetic_norm/synthetic_full.csv", + "numOfFeatures": "5", + "numOfLabels": "3", + "headersNames": "Norm(0:1),Norm(4:1),Norm(10:3)", + "Phases": + [ + { + "phaseName": "training_phase", + "phaseType": "training", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "0", + "numOfBatches": "250", + "workers": "w1,w2,w3,w4", + "nerltensorType": "float" + }, + { + "sourceName": "s2", + "startingSample": "25000", + "numOfBatches": "250", + "workers": "w5,w6", + "nerltensorType": "float" + } + ] + }, + { + "phaseName": "prediction_phase", + "phaseType": "prediction", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "50000", + "numOfBatches": "500", + "workers": "w1,w2,w3,w4", + "nerltensorType": "float" + }, + { + "sourceName": "s2", + "startingSample": "50000", + 
"numOfBatches": "500", + "workers": "w5,w6", + "nerltensorType": "float" + } + ] + } +] +} \ No newline at end of file diff --git a/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json b/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json new file mode 100644 index 00000000..b96649d5 --- /dev/null +++ b/inputJsonsFiles/experimentsFlow/exp_fed_dist_2d_3c_2s_3r_6w.json @@ -0,0 +1,54 @@ +{ + "experimentName": "synthetic_3_gausians", + "experimentType": "classification", + "batchSize": 100, + "csvFilePath": "/tmp/nerlnet/data/NerlnetData-master/nerlnet/synthetic_norm/synthetic_full.csv", + "numOfFeatures": "5", + "numOfLabels": "3", + "headersNames": "Norm(0:1),Norm(4:1),Norm(10:3)", + "Phases": + [ + { + "phaseName": "training_phase", + "phaseType": "training", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "0", + "numOfBatches": "250", + "workers": "w1,w2,w3,w4", + "nerltensorType": "float" + }, + { + "sourceName": "s2", + "startingSample": "25000", + "numOfBatches": "250", + "workers": "w5,w6", + "nerltensorType": "float" + } + ] + }, + { + "phaseName": "prediction_phase", + "phaseType": "prediction", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "50000", + "numOfBatches": "500", + "workers": "w1,w2,w3,w4", + "nerltensorType": "float" + }, + { + "sourceName": "s2", + "startingSample": "50000", + "numOfBatches": "500", + "workers": "w5,w6", + "nerltensorType": "float" + } + ] + } +] +} \ No newline at end of file diff --git a/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json b/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json new file mode 100644 index 00000000..7c7c32ff --- /dev/null +++ b/inputJsonsFiles/experimentsFlow/exp_fed_synt_1d_2c_2r_1s_4w_1ws.json @@ -0,0 +1,41 @@ +{ + "experimentName": "synthetic_3_gausians", + "experimentType": "classification", + "batchSize": 100, + "csvFilePath": "/tmp/nerlnet/data/NerlnetData-master/nerlnet/synthetic_norm/synthetic_full.csv", + "numOfFeatures": "5", + "numOfLabels": "3", + "headersNames": "Norm(0:1),Norm(4:1),Norm(10:3)", + "Phases": + [ + { + "phaseName": "training_phase", + "phaseType": "training", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "0", + "numOfBatches": "500", + "workers": "w1,w2,w3,w4", + "nerltensorType": "float" + } + ] + }, + { + "phaseName": "prediction_phase", + "phaseType": "prediction", + "sourcePieces": + [ + { + "sourceName": "s1", + "startingSample": "50000", + "numOfBatches": "500", + "workers": "w1,w2,w3,w4", + "nerltensorType": "float" + } + ] + } + ] +} + diff --git a/inputJsonsFiles/experimentsFlow/exp_test_synt_1d_2c_1s_4r_4w new.json b/inputJsonsFiles/experimentsFlow/exp_test_synt_1d_2c_1s_4r_4w new.json index ae583eec..f05f5f01 100644 --- a/inputJsonsFiles/experimentsFlow/exp_test_synt_1d_2c_1s_4r_4w new.json +++ b/inputJsonsFiles/experimentsFlow/exp_test_synt_1d_2c_1s_4r_4w new.json @@ -36,6 +36,6 @@ } ] } - ] + ] } diff --git a/src_erl/NerlnetApp/src/Bridge/nerlNIF.erl b/src_erl/NerlnetApp/src/Bridge/nerlNIF.erl index f5250170..823de1e4 100644 --- a/src_erl/NerlnetApp/src/Bridge/nerlNIF.erl +++ b/src_erl/NerlnetApp/src/Bridge/nerlNIF.erl @@ -3,7 +3,7 @@ -include("nerlTensor.hrl"). -export([init/0,nif_preload/0,get_active_models_ids_list/0, train_nif/3,update_nerlworker_train_params_nif/6,call_to_train/5,predict_nif/3,call_to_predict/5,get_weights_nif/1,printTensor/2]). --export([call_to_get_weights/2,call_to_set_weights/2]). +-export([call_to_get_weights/1,call_to_set_weights/2]). 
-export([decode_nif/2, nerltensor_binary_decode/2]). -export([encode_nif/2, nerltensor_encode/5, nerltensor_conversion/2, get_all_binary_types/0, get_all_nerltensor_list_types/0]). -export([erl_type_conversion/1]). @@ -77,21 +77,22 @@ call_to_predict(ModelID, {BatchTensor, Type}, WorkerPid, BatchID , SourceName)-> gen_statem:cast(WorkerPid,{predictRes, nan, BatchID , SourceName}) end. -call_to_get_weights(ThisEts, ModelID)-> +% This function calls get_weights_nif() and waits for the result in a receive block +% Returns {NerlTensorWeights , BinaryType} +call_to_get_weights(ModelID)-> try ?LOG_INFO("Calling get weights in model ~p~n",{ModelID}), _RetVal = get_weights_nif(ModelID), - recv_call_loop(ThisEts) + recv_call_loop() catch Err:E -> ?LOG_ERROR("Couldn't get weights from worker~n~p~n",{Err,E}), [] end. %% Sometimes the receive loop gets OTP calls that it is not supposed to, at high frequency; wait for the nerltensor of weights -recv_call_loop(ThisEts) -> +recv_call_loop() -> receive {'$gen_cast', _Any} -> ?LOG_WARNING("Missed batch in call of get_weights"), - ets:update_counter(ThisEts, missedBatches, 1), - recv_call_loop(ThisEts); + recv_call_loop(); NerlTensorWeights -> NerlTensorWeights end. diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl index cdfd020e..94b5c692 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl @@ -5,67 +5,143 @@ -export([start_link/1]). -export([init/1, handle_cast/2, handle_call/3]). --export([send_message/3, get_all_messages/0]). % methods that are used by worker +-export([send_message/4, send_message_with_event/5, get_all_messages/1 , sync_inbox/1, sync_inbox_no_limit/1]). % methods that are used by the worker + + +setup_logger(Module) -> + logger:set_handler_config(default, formatter, {logger_formatter, #{}}), + logger:set_module_level(Module, all). %% @doc Spawns the server and registers the local name (unique) -spec(start_link(args) -> - {ok, Pid :: pid()} | ignore | {error, Reason :: term()}). + {ok, Pid :: pid()} | ignore | {error, Reason :: term()}). start_link(Args = {WorkerName, _ClientStatemPid}) -> - {ok,Gen_Server_Pid} = gen_server:start_link({local, WorkerName}, ?MODULE, Args, []), - Gen_Server_Pid. + setup_logger(?MODULE), + {ok,Gen_Server_Pid} = gen_server:start_link({local, WorkerName}, ?MODULE, Args, []), + Gen_Server_Pid. -init({WorkerName, ClientStatemPid}) -> +init({WorkerName, MyClientPid}) -> InboxQueue = queue:new(), W2wEts = ets:new(w2w_ets, [set]), put(worker_name, WorkerName), - put(client_statem_pid, ClientStatemPid), + put(client_statem_pid, MyClientPid), + % TODO Send init message to client with the {WorkerName , W2WCOMM_PID} put(w2w_ets, W2wEts), ets:insert(W2wEts, {inbox_queue, InboxQueue}), {ok, []}. -% Messages are of the form: {FromWorkerName, Data} -handle_cast({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, Data}, State) -> +handle_cast({update_gen_worker_pid, GenWorkerPid}, State) -> + put(gen_worker_pid, GenWorkerPid), + {noreply, State}; + +handle_cast(Msg, State) -> + io:format("@w2wCom: Unexpected message received ~p~n", [Msg]), + {noreply, State}.
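The receive path below moves worker-to-worker traffic from handle_cast to handle_call, so a sender now gets a synchronous ack instead of a fire-and-forget cast. A minimal sketch of the call as seen from the sender's side (TargetW2WPid and Payload are hypothetical; in Nerlnet the message normally travels through the client statem via send_message/4 rather than a direct call):

    %% ?W2WCOM_ATOM expands to worker_to_worker_msg (see w2wCom.hrl);
    %% the receiving w2wCom queues {w1, Payload} and replies with an ack.
    {ok, "Message received"} = gen_server:call(TargetW2WPid, {worker_to_worker_msg, w1, w2, Payload}).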
+ +handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, {msg_with_event, Event, Data}}, _From, State) -> case get(worker_name) of ThisWorkerName -> ok; _ -> throw({error, "The provided worker name is not this worker"}) end, + GenWorkerPid = get(gen_worker_pid), + case Event of + post_train_update -> gen_statem:cast(GenWorkerPid, {post_train_update, Data}); + worker_done -> gen_statem:cast(GenWorkerPid, {worker_done, Data}); + start_stream -> gen_statem:cast(GenWorkerPid, {start_stream, Data}); + end_stream -> gen_statem:cast(GenWorkerPid, {end_stream, Data}) + end, + % Saved messages are of the form: {FromWorkerName, Data} Message = {FromWorkerName, Data}, add_msg_to_inbox_queue(Message), - io:format("Worker ~p received message from ~p: ~p~n", [ThisWorkerName, FromWorkerName, Data]), %TODO remove - {noreply, State}; + {reply, {ok, Event}, State}; + +% Received messages are of the form: {worker_to_worker_msg, FromWorkerName, ThisWorkerName, Data} +handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, Data}, _From, State) -> + case get(worker_name) of + ThisWorkerName -> ok; + _ -> throw({error, "The provided worker name is not this worker"}) + end, + % Saved messages are of the form: {FromWorkerName, Data} + Message = {FromWorkerName, Data}, + add_msg_to_inbox_queue(Message), + {reply, {ok, "Message received"}, State}; % Token messages are tuples of: {FromWorkerName, Token, Data} -handle_cast({?W2WCOM_TOKEN_CAST_ATOM, FromWorkerName, ThisWorkerName, Token, Data}, State) -> +handle_call({?W2WCOM_TOKEN_CAST_ATOM, FromWorkerName, ThisWorkerName, Token, Data}, _From, State) -> case get(worker_name) of ThisWorkerName -> ok; _ -> throw({error, "The provided worker name is not this worker"}) end, Message = {FromWorkerName, Token, Data}, add_msg_to_inbox_queue(Message), - io:format("Worker ~p received token message from ~p: ~p~n", [ThisWorkerName, FromWorkerName, Data]), %TODO remove - {noreply, State}; + {reply, {ok, "Message received"}, State}; -handle_cast(_Msg, State) -> - {noreply, State}. + +handle_call({is_inbox_empty}, _From, State) -> + W2WEts = get(w2w_ets), + InboxQueue = ets:lookup_element(W2WEts, inbox_queue, ?ETS_KEYVAL_VAL_IDX), + IsInboxEmpty = queue:len(InboxQueue) == 0, + {reply, {ok, IsInboxEmpty}, State}; + +handle_call({get_inbox_queue}, _From, State) -> + W2WEts = get(w2w_ets), + NewEmptyQueue = queue:new(), + InboxQueue = ets:lookup_element(W2WEts, inbox_queue, ?ETS_KEYVAL_VAL_IDX), + ets:update_element(W2WEts, inbox_queue, {?ETS_KEYVAL_VAL_IDX, NewEmptyQueue}), + {reply, {ok, InboxQueue}, State}; + +handle_call({get_client_pid}, _From, State) -> + {reply, {ok, get(client_statem_pid)}, State}; handle_call(_Call, _From, State) -> {noreply, State}. -get_all_messages() -> - W2WEts = get(w2w_ets), - {_, InboxQueue} = ets:lookup(W2WEts, inbox_queue), +get_all_messages(W2WPid) -> % Returns the InboxQueue and flushes it + {ok , InboxQueue} = gen_server:call(W2WPid, {get_inbox_queue}), InboxQueue. -add_msg_to_inbox_queue(Message) -> +add_msg_to_inbox_queue(Message) -> % Only the w2wCom process executes this function W2WEts = get(w2w_ets), - {_, InboxQueue} = ets:lookup(W2WEts, inbox_queue), + InboxQueue = ets:lookup_element(W2WEts, inbox_queue, ?ETS_KEYVAL_VAL_IDX), InboxQueueUpdated = queue:in(Message, InboxQueue), - ets:insert(W2WEts, {inbox_queue, InboxQueueUpdated}). + ets:update_element(W2WEts, inbox_queue, {?ETS_KEYVAL_VAL_IDX, InboxQueueUpdated}).
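Taken together, these handlers implement a flushing inbox: a delivered message is appended to the queue stored under inbox_queue in w2w_ets, and reading it through {get_inbox_queue} empties it. A minimal round-trip sketch, assuming W2WPid is the w2wCom process of worker w2:

    {ok, _Ack} = gen_server:call(W2WPid, {worker_to_worker_msg, w1, w2, hello}),
    InboxQueue = w2wCom:get_all_messages(W2WPid), % drains the inbox
    [{w1, hello}] = queue:to_list(InboxQueue).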
+ + +send_message(W2WPid, FromWorker, TargetWorker, Data) -> + Msg = {?W2WCOM_ATOM, FromWorker, TargetWorker, Data}, + {ok, MyClient} = gen_server:call(W2WPid, {get_client_pid}), + gen_statem:cast(MyClient, Msg). + +send_message_with_event(W2WPid, FromWorker, TargetWorker, Event, Data) -> + ValidEvent = lists:member(Event, ?SUPPORTED_EVENTS), + if ValidEvent -> ok; + true -> ?LOG_ERROR("Event ~p is not supported!!",[Event]), + throw({error, "The provided event is not supported"}) + end, + Msg = {?W2WCOM_ATOM, FromWorker, TargetWorker, {msg_with_event, Event, Data}}, + {ok, MyClient} = gen_server:call(W2WPid, {get_client_pid}), + gen_statem:cast(MyClient, Msg). + + +timeout_throw(Timeout) -> + receive + stop -> ok; + _ -> timeout_throw(Timeout) + after Timeout -> throw("Timeout reached") + end. + +sync_inbox(W2WPid) -> + TimeoutPID = spawn(fun() -> timeout_throw(?SYNC_INBOX_TIMEOUT) end), + sync_inbox(TimeoutPID , W2WPid). -send_message(FromWorkerName, ToWorkerName, Data) -> - Msg = {?W2WCOM_ATOM, FromWorkerName, ToWorkerName, Data}, - MyClient = client_name, % TODO - gen_server:cast(MyClient, Msg). +sync_inbox_no_limit(W2WPid) -> + TimeoutPID = spawn(fun() -> timeout_throw(?SYNC_INBOX_TIMEOUT_NO_LIMIT) end), + sync_inbox(TimeoutPID , W2WPid). - - \ No newline at end of file +sync_inbox(TimeoutPID, W2WPid) -> + timer:sleep(?DEFAULT_SYNC_INBOX_BUSY_WAITING_SLEEP), + {ok , IsInboxEmpty} = gen_server:call(W2WPid, {is_inbox_empty}), + if + IsInboxEmpty -> sync_inbox(TimeoutPID, W2WPid); + true -> TimeoutPID ! stop + end. \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.hrl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.hrl index a76172fa..1fd26762 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.hrl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.hrl @@ -1,4 +1,11 @@ +-include_lib("kernel/include/logger.hrl"). +-include("workerDefinitions.hrl"). + -define(W2WCOM_INBOX_Q_ATOM, worker_to_worker_inbox_queue). -define(W2WCOM_ATOM, worker_to_worker_msg). --define(W2WCOM_TOKEN_CAST_ATOM, worker_to_worker_token_cast). \ No newline at end of file +-define(W2WCOM_TOKEN_CAST_ATOM, worker_to_worker_token_cast). +-define(SYNC_INBOX_TIMEOUT, 30000). % 30 seconds +-define(SYNC_INBOX_TIMEOUT_NO_LIMIT, 36000000). % 36000 seconds = 10 hours , no limit +-define(DEFAULT_SYNC_INBOX_BUSY_WAITING_SLEEP, 5). % 5 milliseconds +-define(SUPPORTED_EVENTS , [post_train_update, worker_done, start_stream, end_stream]). \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerDefinitions.hrl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerDefinitions.hrl index 588ef3f3..966d54b3 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerDefinitions.hrl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerDefinitions.hrl @@ -2,4 +2,5 @@ -define(ETS_KEYVAL_VAL_IDX, 2). -define(TENSOR_DATA_IDX, 1). --record(workerGeneric_state, {myName , modelID , distributedBehaviorFunc , distributedWorkerData , currentBatchID , nextState , lastPhase}). +-record(workerGeneric_state, {myName , modelID , distributedBehaviorFunc , distributedWorkerData , currentBatchID , nextState , lastPhase, postBatchFunc}). +-define(EMPTY_FUNC, fun() -> ok end). 
\ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl index 90a80bc8..5e5c185f 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedClient.erl @@ -3,108 +3,179 @@ -export([controller/2]). -include("/usr/local/lib/nerlnet-lib/NErlNet/src_erl/NerlnetApp/src/nerl_tools.hrl"). --include("workerDefinitions.hrl"). +-include("w2wCom.hrl"). + +-import(nerlNIF, [call_to_get_weights/2, call_to_set_weights/2]). -define(WORKER_FEDERATED_CLIENT_ETS_FIELDS, [my_name, client_pid, server_name, sync_max_count, sync_count]). -define(FEDERATED_CLIENT_ETS_KEY_IN_GENWORKER_ETS, fedrated_client_ets). - -% %% Federated mode -% wait(cast, {loss, {LOSS_FUNC,Time_NIF}}, State = #workerGeneric_state{clientPid = ClientPid,ackClient = AckClient, myName = MyName, nextState = NextState, count = Count, countLimit = CountLimit, modelId = Mid}) -> -% % {LOSS_FUNC,_TimeCpp} = LossAndTime, -% if Count == CountLimit -> -% % Get weights -% Ret_weights = nerlNIF:call_to_get_weights(Mid), -% % Ret_weights_tuple = niftest:call_to_get_weights(Mid), -% % {Weights,Bias,Biases_sizes_list,Wheights_sizes_list} = Ret_weights_tuple, - -% % ListToSend = [Weights,Bias,Biases_sizes_list,Wheights_sizes_list], - -% % Send weights and loss value -% gen_statem:cast(ClientPid,{loss, federated_weights, MyName, LOSS_FUNC, Ret_weights}), %% TODO Add Time and Time_NIF to the cast -% checkAndAck(MyName,ClientPid,AckClient), -% % Reset count and go to state train -% {next_state, NextState, State#workerNN_state{ackClient = 0, count = 0}}; - -% true -> -% %% Send back the loss value -% gen_statem:cast(ClientPid,{loss, MyName, LOSS_FUNC,Time_NIF/1000}), %% TODO Add Time and Time_NIF to the cast -% checkAndAck(MyName,ClientPid,AckClient), - - -% {next_state, NextState, State#workerNN_state{ackClient = 0, count = Count + 1}} -% end. - -%% Data = -record(workerFederatedClient, {syncCount, syncMaxCount, serverAddr}). +-define(DEFAULT_SYNC_MAX_COUNT_ARG, 100). controller(FuncName, {GenWorkerEts, WorkerData}) -> case FuncName of - init -> init({GenWorkerEts, WorkerData}); - pre_idle -> pre_idle({GenWorkerEts, WorkerData}); - post_idle -> post_idle({GenWorkerEts, WorkerData}); - pre_train -> pre_train({GenWorkerEts, WorkerData}); - post_train -> post_train({GenWorkerEts, WorkerData}); - pre_predict -> pre_predict({GenWorkerEts, WorkerData}); - post_predict-> post_predict({GenWorkerEts, WorkerData}); - update -> update({GenWorkerEts, WorkerData}) + init -> init({GenWorkerEts, WorkerData}); + pre_idle -> pre_idle({GenWorkerEts, WorkerData}); + post_idle -> post_idle({GenWorkerEts, WorkerData}); + pre_train -> pre_train({GenWorkerEts, WorkerData}); + post_train -> post_train({GenWorkerEts, WorkerData}); + pre_predict -> pre_predict({GenWorkerEts, WorkerData}); + post_predict -> post_predict({GenWorkerEts, WorkerData}); + start_stream -> start_stream({GenWorkerEts, WorkerData}); + end_stream -> end_stream({GenWorkerEts, WorkerData}) end. get_this_client_ets(GenWorkerEts) -> ets:lookup_element(GenWorkerEts, federated_client_ets, ?ETS_KEYVAL_VAL_IDX). -%% handshake with workers / server +parse_args(Args) -> + ArgsList = string:split(Args, "," , all), + Func = fun(Arg) -> + [Key, Val] = string:split(Arg, "="), + {Key, Val} + end, + lists:map(Func, ArgsList). % Returns list of tuples [{Key, Val}, ...] 
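parse_args/1 turns the distributedSystemArgs string from the distributed configuration into key/value pairs; for example, the "SyncMaxCount=10" value used in the DC jsons above parses as:

    parse_args("SyncMaxCount=10") =:= [{"SyncMaxCount", "10"}].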
+ +sync_max_count_init(FedClientEts , ArgsList) -> + case lists:keyfind("SyncMaxCount", 1, ArgsList) of + false -> ValInt = ?DEFAULT_SYNC_MAX_COUNT_ARG; + {_, Val} -> ValInt = list_to_integer(Val) % Val is a list (string) in the json so needs to be converted + end, + ets:insert(FedClientEts, {sync_max_count, ValInt}). + +%% handshake with workers / server at the end of init init({GenWorkerEts, WorkerData}) -> % create an ets for this client and save it to generic worker ets - FedratedClientEts = ets:new(federated_client,[set]), - ets:insert(GenWorkerEts, {federated_client_ets, FedratedClientEts}), - {SyncMaxCount, MyName, ServerName} = WorkerData, + FederatedClientEts = ets:new(federated_client,[set, public]), + ets:insert(GenWorkerEts, {federated_client_ets, FederatedClientEts}), + {MyName, Args, Token} = WorkerData, + ArgsList = parse_args(Args), + sync_max_count_init(FederatedClientEts, ArgsList), + W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), % create fields in this ets - ets:insert(FedratedClientEts, {my_name, MyName}), - ets:insert(FedratedClientEts, {server_name, ServerName}), - ets:insert(FedratedClientEts, {sync_max_count, SyncMaxCount}), - ets:insert(FedratedClientEts, {sync_count, SyncMaxCount}), - ets:insert(FedratedClientEts, {server_update, false}), - io:format("finished init in ~p~n",[MyName]). - -pre_idle({GenWorkerEts, _WorkerData}) -> - ThisEts = get_this_client_ets(GenWorkerEts), - %% send to server that this worker is part of the federated workers - ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), - ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), - gen_statem:cast(ClientPID,{custom_worker_message,{MyName, ServerName}}), - io:format("sent ~p init message: ~p~n",[ServerName, {MyName, ServerName}]). - -post_idle({_GenWorkerEts, _WorkerData}) -> ok. - -%% set weights from fedserver -pre_train({_GenWorkerEts, _WorkerData}) -> ok. - % ThisEts = get_this_client_ets(GenWorkerEts), - % ToUpdate = ets:lookup_element(ThisEts, server_update, ?ETS_KEYVAL_VAL_IDX), - % if ToUpdate -> - % ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - % nerlNIF:call_to_set_weights(ModelID, Weights); - % true -> nothing - % end. + ets:insert(FederatedClientEts, {my_token, Token}), + ets:insert(FederatedClientEts, {my_name, MyName}), + ets:insert(FederatedClientEts, {server_name, []}), % update later + ets:insert(FederatedClientEts, {sync_count, 0}), + ets:insert(FederatedClientEts, {server_update, false}), + ets:insert(FederatedClientEts, {handshake_done, false}), + ets:insert(FederatedClientEts, {handshake_wait, false}), + ets:insert(FederatedClientEts, {w2wcom_pid, W2WPid}), + ets:insert(FederatedClientEts, {active_streams, []}), + ets:insert(FederatedClientEts, {stream_occuring, false}), + spawn(fun() -> handshake(FederatedClientEts) end). 
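For reference, the generic worker reaches this module through controller/2, and init expects WorkerData as {MyName, Args, Token}, where Args is the raw distributedSystemArgs string and Token the distributedSystemToken from the DC json. A hypothetical invocation (worker name, args and token borrowed from the example configs above; the exact term types the generic worker passes are not shown in this diff):

    workerFederatedClient:controller(init, {GenWorkerEts, {w1, "SyncMaxCount=10", "9922u"}}).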
+ +handshake(FedClientEts) -> + W2WPid = ets:lookup_element(FedClientEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:sync_inbox(W2WPid), + InboxQueue = w2wCom:get_all_messages(W2WPid), + MessagesList = queue:to_list(InboxQueue), + Func = + fun({FedServer , {handshake, ServerToken}}) -> + ets:insert(FedClientEts, {server_name, FedServer}), + ets:insert(FedClientEts, {my_token , ServerToken}), + MyToken = ets:lookup_element(FedClientEts, my_token, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(FedClientEts, my_name, ?ETS_KEYVAL_VAL_IDX), + if + ServerToken =/= MyToken -> not_my_server; + true -> w2wCom:send_message(W2WPid, MyName, FedServer, {handshake, MyToken}), + ets:update_element(FedClientEts, handshake_wait, {?ETS_KEYVAL_VAL_IDX, true}) + end + end, + lists:foreach(Func, MessagesList). + +start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName, State] + [_SourceName, ModelPhase] = WorkerData, + FirstMsg = 1, + case ModelPhase of + train -> + ThisEts = get_this_client_ets(GenWorkerEts), + MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), + ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), + W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + case length(ActiveStreams) of % Send to server an updater after got start_stream from the first source + FirstMsg -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, MyName); % Server gets FedWorkerName instead of SourceName + _ -> ok + end; + predict -> ok + end. -%% every countLimit batches, send updated weights -post_train({GenWorkerEts, _WorkerData}) -> +end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName] + [_SourceName, ModelPhase] = WorkerData, + case ModelPhase of + train -> + ThisEts = get_this_client_ets(GenWorkerEts), + MyName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), + ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), + W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + case length(ActiveStreams) of % Send to server an updater after got start_stream from the first source + 0 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName); % Mimic source behavior + _ -> ok + end; + predict -> ok; + wait -> ok + end. + + +pre_idle({_GenWorkerEts, _WorkerData}) -> ok. + +post_idle({GenWorkerEts, _WorkerData}) -> + FedClientEts = get_this_client_ets(GenWorkerEts), + W2WPid = ets:lookup_element(FedClientEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + Token = ets:lookup_element(FedClientEts, my_token, ?ETS_KEYVAL_VAL_IDX), + HandshakeWait = ets:lookup_element(FedClientEts, handshake_wait, ?ETS_KEYVAL_VAL_IDX), + case HandshakeWait of + true -> HandshakeDone = ets:lookup_element(FedClientEts, handshake_done, ?ETS_KEYVAL_VAL_IDX), + case HandshakeDone of + false -> + w2wCom:sync_inbox(W2WPid), + InboxQueue = w2wCom:get_all_messages(W2WPid), + ets:update_element(FedClientEts, handshake_done, {?ETS_KEYVAL_VAL_IDX, true}), + [{_FedServer, {handshake_done, Token}}] = queue:to_list(InboxQueue); + true -> ok + end; + false -> post_idle({GenWorkerEts, _WorkerData}) % busy waiting until handshake is done + end. 
+    end.
+ + + +% After SyncMaxCount , sync_inbox to get the updated model from FedServer +pre_train({GenWorkerEts, _NerlTensorWeights}) -> ThisEts = get_this_client_ets(GenWorkerEts), - SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), - if SyncCount == 0 -> + SyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_count, ?ETS_KEYVAL_VAL_IDX), + MaxSyncCount = ets:lookup_element(get_this_client_ets(GenWorkerEts), sync_max_count, ?ETS_KEYVAL_VAL_IDX), + if SyncCount == MaxSyncCount -> + W2WPid = ets:lookup_element(get_this_client_ets(GenWorkerEts), w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:sync_inbox_no_limit(W2WPid), % waiting for server to average the weights and send it + InboxQueue = w2wCom:get_all_messages(W2WPid), + [UpdateWeightsMsg] = queue:to_list(InboxQueue), + {_FedServer , {update_weights, UpdatedWeights}} = UpdateWeightsMsg, ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - Weights = nerlNIF:call_to_get_weights(GenWorkerEts, ModelID), - ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), - MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), - % io:format("Worker ~p entering update and got weights ~p~n",[MyName, Weights]), - ets:update_counter(ThisEts, sync_count, MaxSyncCount), - % io:format("Worker ~p entering update~n",[MyName]), - gen_statem:cast(ClientPID, {update, {MyName, ServerName, Weights}}), - _ToUpdate = true; - true -> - ets:update_counter(ThisEts, sync_count, -1), - _ToUpdate = false + nerlNIF:call_to_set_weights(ModelID, UpdatedWeights), + ets:update_element(ThisEts, sync_count, {?ETS_KEYVAL_VAL_IDX , 0}); + true -> ets:update_counter(ThisEts, sync_count, 1) + end. + +%% every countLimit batches, send updated weights +post_train({GenWorkerEts, _WorkerData}) -> + MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), + ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + % io:format("Worker ~p ActiveStreams ~p~n",[MyName, ActiveStreams]), + case ActiveStreams of + [] -> ok; + _ -> + ThisEts = get_this_client_ets(GenWorkerEts), + SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), + MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), + if SyncCount == MaxSyncCount -> + ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), + WeightsTensor = nerlNIF:call_to_get_weights(ModelID), + ServerName = ets:lookup_element(ThisEts, server_name, ?ETS_KEYVAL_VAL_IDX), + W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:send_message_with_event(W2WPid, MyName, ServerName , post_train_update, WeightsTensor); + true -> ok + end end. %% nothing? @@ -113,20 +184,5 @@ pre_predict({_GenWorkerEts, WorkerData}) -> WorkerData. %% nothing? post_predict(Data) -> Data. -%% gets weights from federated server -update({GenWorkerEts, NerlTensorWeights}) -> - % ThisEts = get_this_client_ets(GenWorkerEts), - ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - nerlNIF:call_to_set_weights(ModelID, NerlTensorWeights). - % io:format("updated weights in worker ~p~n",[ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX)]). 
- -%%------------------------------------------ -% worker_event_polling(0) -> ?LOG_ERROR("worker event polling takes too long!"); -% worker_event_polling(Weights) -> -% if length(Weights) == 1 -> Weights; -% length(Weights) > 1 -> ?LOG_ERROR("more than 1 messages pending!"); -% true -> %% wait for info to update -% receive _ -> non -% after 1 -> worker_event_polling(T-1) -% end -% end. \ No newline at end of file + + diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl index 877d534a..86b9fba3 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl @@ -2,122 +2,179 @@ -export([controller/2]). --include("workerDefinitions.hrl"). +-include("w2wCom.hrl"). --import(nerlNIF,[nerltensor_scalar_multiplication_nif/3]). +-import(nerlNIF,[nerltensor_scalar_multiplication_nif/3, call_to_get_weights/1, call_to_set_weights/2]). -import(nerlTensor,[sum_nerltensors_lists/2]). +-import(w2wCom,[send_message/4, get_all_messages/1]). + -define(ETS_WID_IDX, 1). -define(ETS_TYPE_IDX, 2). -define(ETS_WEIGHTS_AND_BIAS_NERLTENSOR_IDX, 3). -define(ETS_NERLTENSOR_TYPE_IDX, 2). +-define(DEFAULT_SYNC_MAX_COUNT_ARG, 1). +-define(HANDSHAKE_TIMEOUT, 2000). % 2 seconds controller(FuncName, {GenWorkerEts, WorkerData}) -> case FuncName of - init -> init({GenWorkerEts, WorkerData}); - pre_idle -> pre_idle({GenWorkerEts, WorkerData}); - post_idle -> post_idle({GenWorkerEts, WorkerData}); - pre_train -> pre_train({GenWorkerEts, WorkerData}); - post_train -> post_train({GenWorkerEts, WorkerData}); - pre_predict -> pre_predict({GenWorkerEts, WorkerData}); - post_predict -> post_predict({GenWorkerEts, WorkerData}); - update -> update({GenWorkerEts, WorkerData}) + init -> init({GenWorkerEts, WorkerData}); + pre_idle -> pre_idle({GenWorkerEts, WorkerData}); + post_idle -> post_idle({GenWorkerEts, WorkerData}); + pre_train -> pre_train({GenWorkerEts, WorkerData}); + post_train -> post_train({GenWorkerEts, WorkerData}); + pre_predict -> pre_predict({GenWorkerEts, WorkerData}); + post_predict -> post_predict({GenWorkerEts, WorkerData}); + start_stream -> start_stream({GenWorkerEts, WorkerData}); + end_stream -> end_stream({GenWorkerEts, WorkerData}) end. + +% Since init/1 now stores this ets handle with put/2, this lookup helper is redundant get_this_server_ets(GenWorkerEts) -> ets:lookup_element(GenWorkerEts, federated_server_ets, ?ETS_KEYVAL_VAL_IDX). +parse_args(Args) -> + ArgsList = string:split(Args, "," , all), + Func = fun(Arg) -> + [Key, Val] = string:split(Arg, "="), + {Key, Val} + end, + lists:map(Func, ArgsList). % Returns a list of tuples [{Key, Val}, ...] + +sync_max_count_init(FedServerEts , ArgsList) -> + case lists:keyfind("SyncMaxCount", 1, ArgsList) of + false -> ValInt = ?DEFAULT_SYNC_MAX_COUNT_ARG; + {_, Val} -> ValInt = list_to_integer(Val) + end, + ets:insert(FedServerEts, {sync_max_count, ValInt}).
+ %% handshake with workers / server init({GenWorkerEts, WorkerData}) -> - Type = float, % update from data - {SyncMaxCount, MyName, WorkersNamesList} = WorkerData, FederatedServerEts = ets:new(federated_server,[set]), + {MyName, Args, Token , WorkersList} = WorkerData, + BroadcastWorkers = WorkersList -- [MyName], + ArgsList = parse_args(Args), + sync_max_count_init(FederatedServerEts, ArgsList), ets:insert(GenWorkerEts, {federated_server_ets, FederatedServerEts}), - ets:insert(FederatedServerEts, {workers, [MyName]}), %% start with only self in list, get others in network thru handshake - ets:insert(FederatedServerEts, {sync_max_count, SyncMaxCount}), - ets:insert(FederatedServerEts, {sync_count, SyncMaxCount}), + W2WPid = ets:lookup_element(GenWorkerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + ets:insert(FederatedServerEts, {w2wcom_pid, W2WPid}), + ets:insert(FederatedServerEts, {broadcast_workers_list, BroadcastWorkers}), + ets:insert(FederatedServerEts, {fed_clients, []}), + ets:insert(FederatedServerEts, {active_workers , []}), + ets:insert(FederatedServerEts, {sync_count, 0}), ets:insert(FederatedServerEts, {my_name, MyName}), - ets:insert(FederatedServerEts, {nerltensor_type, Type}). + ets:insert(FederatedServerEts, {token , Token}), + ets:insert(FederatedServerEts, {weights_list, []}), + put(fed_server_ets, FederatedServerEts). + + +start_stream({GenWorkerEts, WorkerData}) -> + [FedWorkerName , _ModelPhase] = WorkerData, + FedServerEts = get_this_server_ets(GenWorkerEts), + ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + gen_server:cast(ClientPid, {start_stream, {worker, MyName, FedWorkerName}}). + +end_stream({GenWorkerEts, WorkerData}) -> % Federated server takes the control of popping the stream from the active streams list + [FedWorkerName , _ModelPhase] = WorkerData, + FedServerEts = get_this_server_ets(GenWorkerEts), + MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + gen_statem:cast(ClientPid, {worker_done, {MyName, FedWorkerName}}), + ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + case ActiveStreams of + [] -> ets:update_element(FedServerEts, active_streams, {?ETS_KEYVAL_VAL_IDX, none}); + _ -> ok + end. -pre_idle({GenWorkerEts, WorkerName}) -> ok. -post_idle({GenWorkerEts, WorkerName}) -> - ThisEts = get_this_server_ets(GenWorkerEts), - io:format("adding worker ~p to fed workers~n",[WorkerName]), - Workers = ets:lookup_element(ThisEts, workers, ?ETS_KEYVAL_VAL_IDX), - ets:insert(ThisEts, {workers, Workers++[WorkerName]}). +pre_idle({_GenWorkerEts, _WorkerName}) -> ok. 
+ + +% Extract all workers in nerlnet network +% Send handshake message to all workers +% Wait for all workers to send handshake message back +post_idle({GenWorkerEts, _WorkerName}) -> + HandshakeDone = ets:lookup_element(GenWorkerEts, handshake_done, ?ETS_KEYVAL_VAL_IDX), + case HandshakeDone of + false -> + FedServerEts = get(fed_server_ets), + FedServerName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + WorkersList = ets:lookup_element(FedServerEts, broadcast_workers_list, ?ETS_KEYVAL_VAL_IDX), + W2WPid = ets:lookup_element(FedServerEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + MyToken = ets:lookup_element(FedServerEts, token, ?ETS_KEYVAL_VAL_IDX), + Func = fun(FedClient) -> + w2wCom:send_message(W2WPid, FedServerName, FedClient, {handshake, MyToken}) + end, + lists:foreach(Func, WorkersList), + timer:sleep(?HANDSHAKE_TIMEOUT), + InboxQueue = w2wCom:get_all_messages(W2WPid), + IsEmpty = queue:len(InboxQueue) == 0, + if IsEmpty == true -> + throw("Handshake failed, none of the workers responded in time"); + true -> ok + end, + MessagesList = queue:to_list(InboxQueue), + MsgFunc = + fun({FedClient, {handshake, _Token}}) -> + FedClients = ets:lookup_element(FedServerEts, fed_clients, ?ETS_KEYVAL_VAL_IDX), + ets:update_element(FedServerEts, fed_clients, {?ETS_KEYVAL_VAL_IDX , [FedClient] ++ FedClients}), + w2wCom:send_message(W2WPid, FedServerName, FedClient, {handshake_done, MyToken}) + end, + lists:foreach(MsgFunc, MessagesList), + ets:update_element(GenWorkerEts, handshake_done, {?ETS_KEYVAL_VAL_IDX, true}); + true -> ok + end. %% Send updated weights if set -pre_train({GenWorkerEts, WorkerData}) -> ok. +pre_train({_GenWorkerEts, _WorkerData}) -> ok. -%% calculate avg of weights when set -post_train({GenWorkerEts, WorkerData}) -> +% 1. get weights from all workers +% 2. average them +% 3. set new weights to model +% 4. 
send new weights to all workers +post_train({GenWorkerEts, WeightsTensor}) -> ThisEts = get_this_server_ets(GenWorkerEts), - SyncCount = ets:lookup_element(ThisEts, sync_count, ?ETS_KEYVAL_VAL_IDX), - if SyncCount == 0 -> - ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - Weights = nerlNIF:call_to_get_weights(GenWorkerEts, ModelID), - ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), - gen_statem:cast(ClientPID, {update, {MyName, MyName, Weights}}), - MaxSyncCount = ets:lookup_element(ThisEts, sync_max_count, ?ETS_KEYVAL_VAL_IDX), - ets:update_counter(ThisEts, sync_count, MaxSyncCount), - ToUpdate = true; - true -> - ets:update_counter(ThisEts, sync_count, -1), - ToUpdate = false + FedServerEts = get(fed_server_ets), + CurrWorkersWeightsList = ets:lookup_element(FedServerEts, weights_list, ?ETS_KEYVAL_VAL_IDX), + {WorkerWeights, _BinaryType} = WeightsTensor, + TotalWorkersWeights = CurrWorkersWeightsList ++ [WorkerWeights], + NumOfActiveWorkers = length(ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX)), + case length(TotalWorkersWeights) of + NumOfActiveWorkers -> + ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), + % io:format("Averaging model weights...~n"), + {CurrentModelWeights, BinaryType} = nerlNIF:call_to_get_weights(ModelID), + FedServerName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX), + AllWorkersWeightsList = TotalWorkersWeights ++ [CurrentModelWeights], + AvgWeightsNerlTensor = generate_avg_weights(AllWorkersWeightsList, BinaryType), + nerlNIF:call_to_set_weights(ModelID, AvgWeightsNerlTensor), %% update self weights to new model + Func = fun(FedClient) -> + FedServerName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX), + W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX), + w2wCom:send_message(W2WPid, FedServerName, FedClient, {update_weights, AvgWeightsNerlTensor}) + end, + WorkersList = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + lists:foreach(Func, WorkersList), + ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, []}); + _ -> ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, TotalWorkersWeights}) end. - % ThisEts = get_this_server_ets(GenWorkerEts), - % Weights = generate_avg_weights(ThisEts), - - % gen_statem:cast({update, Weights}). %TODO complete send to all workers in lists:foreach - + %% nothing? -pre_predict({GenWorkerEts, WorkerData}) -> ok. +pre_predict({_GenWorkerEts, _WorkerData}) -> ok. %% nothing? -post_predict({GenWorkerEts, WorkerData}) -> ok. +post_predict({_GenWorkerEts, _WorkerData}) -> ok. 
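The aggregation above amounts to an element-wise mean over the collected workers' weights plus the server's own current weights. A stand-in sketch on plain Erlang float lists, for intuition only (the real path uses nerlTensor:sum_nerltensors_lists/2 and the scalar-multiplication NIF on encoded nerltensors):

avg_weights(WeightLists) ->
    K = length(WeightLists),
    % element-wise sum of all weight lists
    Sum = lists:foldl(fun(W, Acc) -> lists:zipwith(fun erlang:'+'/2, W, Acc) end,
                      hd(WeightLists), tl(WeightLists)),
    % divide each summed element by the number of contributors
    [X / K || X <- Sum].
%% avg_weights([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) -> [3.0, 4.0]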
-%% FedServer keeps an ets list of tuples: {WorkerName, worker, WeightsAndBiasNerlTensor} -%% in update get weights of clients, if got from all => avg and send back -update({GenWorkerEts, WorkerData}) -> - {WorkerName, Me, NerlTensorWeights} = WorkerData, - ThisEts = get_this_server_ets(GenWorkerEts), - %% update weights in ets - ets:insert(ThisEts, {WorkerName, worker, NerlTensorWeights}), - - %% check if there are queued messages, and treat them accordingly - MessageQueue = ets:lookup_element(GenWorkerEts, controller_message_q, ?ETS_KEYVAL_VAL_IDX), - % io:format("MessageQ=~p~n",[MessageQueue]), - [ets:insert(ThisEts, {WorkerName, worker, NerlTensorWeights}) || {Action, WorkerName, To, NerlTensorWeights} <- MessageQueue, Action == update], - % reset q - ets:delete(GenWorkerEts, controller_message_q), - ets:insert(GenWorkerEts, {controller_message_q, []}), - - %% check if got all weights of workers - WorkersList = ets:lookup_element(ThisEts, workers, ?ETS_KEYVAL_VAL_IDX), - GotWorkers = [ element(?ETS_WID_IDX, Attr) || Attr <- ets:tab2list(ThisEts), element(?ETS_TYPE_IDX, Attr) == worker], - % io:format("My workers=~p, have vectors from=~p~n",[WorkersList,GotWorkers]), - WaitingFor = WorkersList -- GotWorkers, - - if WaitingFor == [] -> - AvgWeightsNerlTensor = generate_avg_weights(ThisEts), - % io:format("AvgWeights = ~p~n",[AvgWeightsNerlTensor]), - ModelID = ets:lookup_element(GenWorkerEts, model_id, ?ETS_KEYVAL_VAL_IDX), - nerlNIF:call_to_set_weights(ModelID, AvgWeightsNerlTensor), %% update self weights to new model - [ets:delete(ThisEts, OldWorkerName) || OldWorkerName <- WorkersList ],%% delete old tensors for next aggregation phase - ClientPID = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), - gen_statem:cast(ClientPID, {custom_worker_message, WorkersList, AvgWeightsNerlTensor}), - false; - true -> true end. %% return StillUpdate = true - - -generate_avg_weights(FedEts) -> - BinaryType = ets:lookup_element(FedEts, nerltensor_type, ?ETS_NERLTENSOR_TYPE_IDX), - ListOfWorkersNerlTensors = [ element(?TENSOR_DATA_IDX, element(?ETS_WEIGHTS_AND_BIAS_NERLTENSOR_IDX, Attr)) || Attr <- ets:tab2list(FedEts), element(?ETS_TYPE_IDX, Attr) == worker], - % io:format("Tensors to sum = ~p~n",[ListOfWorkersNerlTensors]), - NerlTensors = length(ListOfWorkersNerlTensors), - [FinalSumNerlTensor] = nerlTensor:sum_nerltensors_lists(ListOfWorkersNerlTensors, BinaryType), + +generate_avg_weights(AllWorkersWeightsList, BinaryType) -> + % io:format("AllWorkersWeightsList = ~p~n",[AllWorkersWeightsList]), + NumNerlTensors = length(AllWorkersWeightsList), + if + NumNerlTensors > 1 -> [FinalSumNerlTensor] = nerlTensor:sum_nerltensors_lists(AllWorkersWeightsList, BinaryType); + true -> FinalSumNerlTensor = hd(AllWorkersWeightsList) + end, % io:format("Summed = ~p~n",[FinalSumNerlTensor]), - nerlNIF:nerltensor_scalar_multiplication_nif(FinalSumNerlTensor, BinaryType, 1.0/NerlTensors). + nerlNIF:nerltensor_scalar_multiplication_nif(FinalSumNerlTensor, BinaryType, 1.0/NumNerlTensors). diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl index e4cd8f50..0fd5ee9f 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl @@ -22,7 +22,7 @@ -export([init/1, format_status/2, state_name/3, handle_event/4, terminate/3, code_change/4, callback_mode/0]). %% States functions --export([idle/3, train/3, predict/3, wait/3, update/3]). 
+-export([idle/3, train/3, predict/3, wait/3]). %% ackClient :: need to tell mainserver that worker is safe and going to new state after wait state @@ -46,16 +46,18 @@ start_link(ARGS) -> %% @doc Whenever a gen_statem is started using gen_statem:start/[3,4] or %% gen_statem:start_link/[3,4], this function is called by the new process to initialize. %% distributedBehaviorFunc is the special behavior of the worker regarding the distributed system e.g. federated client/server -init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , ClientPid , WorkerStatsEts}) -> +init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , ClientPid , WorkerStatsEts , W2WPid}) -> nerl_tools:setup_logger(?MODULE), {ModelID , ModelType , ModelArgs , LayersSizes, LayersTypes, LayersFunctionalityCodes, LearningRate , Epochs, - OptimizerType, OptimizerArgs , LossMethod , DistributedSystemType , DistributedSystemArgs} = WorkerArgs, - GenWorkerEts = ets:new(generic_worker,[set]), + OptimizerType, OptimizerArgs , LossMethod , DistributedSystemType , DistributedSystemArgs} = WorkerArgs, + GenWorkerEts = ets:new(generic_worker,[set, public]), put(generic_worker_ets, GenWorkerEts), put(client_pid, ClientPid), put(worker_stats_ets , WorkerStatsEts), SourceBatchesEts = ets:new(source_batches,[set]), put(source_batches_ets, SourceBatchesEts), + ets:insert(GenWorkerEts,{client_pid, ClientPid}), + ets:insert(GenWorkerEts,{w2wcom_pid, W2WPid}), ets:insert(GenWorkerEts,{worker_name, WorkerName}), ets:insert(GenWorkerEts,{model_id, ModelID}), ets:insert(GenWorkerEts,{model_type, ModelType}), @@ -70,10 +72,12 @@ init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData ets:insert(GenWorkerEts,{optimizer_args, OptimizerArgs}), ets:insert(GenWorkerEts,{distributed_system_args, DistributedSystemArgs}), ets:insert(GenWorkerEts,{distributed_system_type, DistributedSystemType}), - ets:insert(GenWorkerEts,{controller_message_q, []}), %% empty Queue TODO Deprecated - % Worker to Worker communication gen_server - W2wComPid = w2wCom:start_link({WorkerName, ClientPid}), - put(w2wcom_pid, W2wComPid), + ets:insert(GenWorkerEts,{controller_message_q, []}), %% TODO Deprecated + ets:insert(GenWorkerEts,{handshake_done, false}), + ets:insert(GenWorkerEts,{active_streams, []}), + ets:insert(GenWorkerEts,{stream_occuring, false}), + % Worker to Worker communication module - a gen_server started by the client and passed in as W2WPid + Res = nerlNIF:new_nerlworker_nif(ModelID , ModelType, ModelArgs, LayersSizes, LayersTypes, LayersFunctionalityCodes, LearningRate, Epochs, OptimizerType, OptimizerArgs, LossMethod , DistributedSystemType , DistributedSystemArgs), @@ -85,8 +89,8 @@ init({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData ?LOG_ERROR("Failed to create worker ~p\n",[WorkerName]), exit(nif_failed_to_create) end, - - {ok, idle, #workerGeneric_state{myName = WorkerName , modelID = ModelID , distributedBehaviorFunc = DistributedBehaviorFunc , distributedWorkerData = DistributedWorkerData}}. + DistributedBehaviorFunc(pre_idle,{GenWorkerEts, DistributedWorkerData}), + {ok, idle, #workerGeneric_state{myName = WorkerName , modelID = ModelID , distributedBehaviorFunc = DistributedBehaviorFunc , distributedWorkerData = DistributedWorkerData, postBatchFunc = ?EMPTY_FUNC}}.
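Every entry stored in GenWorkerEts above follows the {Key, Value} convention and is read back with ets:lookup_element(Tab, Key, ?ETS_KEYVAL_VAL_IDX); a minimal self-contained illustration:

Tab = ets:new(generic_worker, [set, public]),
true = ets:insert(Tab, {handshake_done, false}),
false = ets:lookup_element(Tab, handshake_done, 2). % 2 is ?ETS_KEYVAL_VAL_IDX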
%% @private %% @doc This function is called by a gen_statem when it needs to find out @@ -138,47 +142,42 @@ code_change(_OldVsn, StateName, State = #workerGeneric_state{}, _Extra) -> %% State idle -%% Event from clientStatem -idle(cast, {pre_idle}, State = #workerGeneric_state{myName = _MyName,distributedBehaviorFunc = DistributedBehaviorFunc}) -> - DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), empty}), - {next_state, idle, State}; - -%% Event from clientStatem -idle(cast, {post_idle, From}, State = #workerGeneric_state{myName = _MyName,distributedBehaviorFunc = DistributedBehaviorFunc}) -> - DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), From}), - {next_state, idle, State}; - % Go from idle to train -idle(cast, {training}, State = #workerGeneric_state{myName = MyName}) -> - worker_controller_empty_message_queue(), +idle(cast, {training}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + % io:format("@idle got training , Worker ~p is going to state idle...~n",[MyName]), + ets:update_element(get(generic_worker_ets), active_streams, {?ETS_KEYVAL_VAL_IDX, []}), + DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), train}), update_client_avilable_worker(MyName), {next_state, train, State#workerGeneric_state{lastPhase = train}}; % Go from idle to predict -idle(cast, {predict}, State = #workerGeneric_state{myName = MyName}) -> - worker_controller_empty_message_queue(), +idle(cast, {predict}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + % worker_controller_empty_message_queue(), + ets:update_element(get(generic_worker_ets), active_streams, {?ETS_KEYVAL_VAL_IDX, []}), update_client_avilable_worker(MyName), + DistributedBehaviorFunc(post_idle, {get(generic_worker_ets), predict}), {next_state, predict, State#workerGeneric_state{lastPhase = predict}}; -idle(cast, _Param, State) -> - % io:fwrite("Same state idle, command: ~p\n",[Param]), +idle(cast, _Param, State = #workerGeneric_state{myName = _MyName}) -> + % io:format("@idle Worker ~p is going to state idle...~n",[MyName]), {next_state, idle, State}. 
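The client drives a worker out of idle with plain gen_statem casts; a sketch, assuming WorkerPid is the pid returned by workerGeneric:start_link/1:

gen_statem:cast(WorkerPid, {training}), % idle -> train: resets active_streams, runs post_idle(train)
gen_statem:cast(WorkerPid, {predict}).  % idle -> predict: same bookkeeping for the predict phase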
%% Waiting for receiving results or loss function %% Got nan or inf from loss function - Error, loss function too big for double -wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState}) -> +wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc, postBatchFunc = PostBatchFunc}) -> stats:increment_by_value(get(worker_stats_ets), nan_loss_count, 1), gen_statem:cast(get(client_pid),{loss, MyName , SourceName ,nan , TrainTime ,BatchID}), - {next_state, NextState, State}; + DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% The first call sends an empty list; it is later filled in by the federated server and clients + PostBatchFunc(), + {next_state, NextState, State#workerGeneric_state{postBatchFunc = ?EMPTY_FUNC}}; -wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> +wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc, postBatchFunc = PostBatchFunc}) -> BatchTimeStamp = erlang:system_time(nanosecond), gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , BatchID , BatchTimeStamp}), - ToUpdate = DistributedBehaviorFunc(post_train, {get(generic_worker_ets),DistributedWorkerData}), - if ToUpdate -> {next_state, update, State#workerGeneric_state{nextState=NextState}}; - true -> {next_state, NextState, State} - end; + DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% The first call sends an empty list; it is later filled in by the federated server and clients + PostBatchFunc(), + {next_state, NextState, State#workerGeneric_state{postBatchFunc = ?EMPTY_FUNC}}; wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, nextState = NextState, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> BatchTimeStamp = erlang:system_time(nanosecond), @@ -190,9 +189,18 @@ wait(cast, {predictRes, PredNerlTensor, PredNerlTensorType, TimeNif, BatchID , S {next_state, NextState, State} end; -wait(cast, {idle}, State) -> +wait(cast, {end_stream , Data}, State = #workerGeneric_state{myName = _MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> %logger:notice("Waiting, next state - idle"), - {next_state, wait, State#workerGeneric_state{nextState = idle}}; + Func = fun() -> stream_handler(end_stream, wait, Data, DistributedBehaviorFunc) end, + {next_state, wait, State#workerGeneric_state{postBatchFunc = Func}}; + + +% Should not happen in the normal flow; handled defensively +wait(cast, {idle}, State= #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> + %logger:notice("Waiting, next state - idle"), + DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), train}), + update_client_avilable_worker(MyName), + {next_state, idle, State#workerGeneric_state{nextState = idle}}; wait(cast, {training}, State) -> %logger:notice("Waiting, next state - train"), @@ -204,7 +212,7 @@ wait(cast, {predict}, State) -> {next_state, wait,
State#workerGeneric_state{nextState = predict}}; %% Worker in wait can't treat incoming message -wait(cast, _BatchData , State = #workerGeneric_state{lastPhase = LastPhase}) -> +wait(cast, _BatchData , State = #workerGeneric_state{lastPhase = LastPhase, myName= _MyName}) -> case LastPhase of train -> ets:update_counter(get(worker_stats_ets), batches_dropped_train , 1); @@ -215,85 +223,56 @@ wait(cast, _BatchData , State = #workerGeneric_state{lastPhase = LastPhase}) -> wait(cast, Data, State) -> % logger:notice("worker ~p in wait cant treat message: ~p\n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), Data]), - worker_controller_message_queue(Data), + ?LOG_ERROR("Got unknown message in wait state: ~p~n",[Data]), + throw("Got unknown message in wait state"), {keep_state, State}. -%% treated runaway message in nerlNIF:call_to_fet_weights -% update(info, Data, State) -> -% ?LOG_NOTICE(?LOG_HEADER++"Worker ~p got data thru info: ~p\n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), Data]), -% ?LOG_INFO("Worker ets is: ~p",[ets:match_object(get(generic_worker_ets), {'$0', '$1'})]), -% {keep_state, State}; - -%% TODO FIX CONTROLLER -update(cast, {update, _From, NerltensorWeights}, State = #workerGeneric_state{distributedBehaviorFunc = DistributedBehaviorFunc, nextState = NextState}) -> - ?LOG_WARNING("************* Unrecognized update method , next state: ~p **************" , [NextState]), - DistributedBehaviorFunc(update, {get(generic_worker_ets), NerltensorWeights}), - {next_state, NextState, State}; - -%% Worker updates its' client that it is available (in idle state) -update(cast, {idle}, State = #workerGeneric_state{myName = MyName}) -> - update_client_avilable_worker(MyName), - {next_state, idle, State#workerGeneric_state{nextState = idle}}; - - -%% TODO Guy MOVE THIS FUNCTION TO CONTROLLER -update(cast, Data, State = #workerGeneric_state{distributedBehaviorFunc = DistributedBehaviorFunc, nextState = NextState}) -> - % io:format("worker ~p got ~p~n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), Data]), - case Data of - %% FedClient update avg weights - {update, "server", _Me, NerltensorWeights} -> - DistributedBehaviorFunc(update, {get(generic_worker_ets), NerltensorWeights}), - % io:format("worker ~p updated model and going to ~p state~n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), NextState]), - {next_state, NextState, State}; - %% FedServer get weights from clients - {update, WorkerName, Me, NerlTensorWeights} -> - StillUpdate = DistributedBehaviorFunc(update, {get(generic_worker_ets), {WorkerName, Me, NerlTensorWeights}}), - if StillUpdate -> - % io:format("worker ~p in update waiting to go to ~p state~n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), NextState]), - {keep_state, State#workerGeneric_state{nextState=NextState}}; - true -> - {next_state, NextState, State#workerGeneric_state{}} - end; - %% got sample from source. discard and add missed count TODO: add to Q - {sample, _Tensor} -> - %%ets:update_counter(get(generic_worker_ets), missedBatches, 1), - {keep_state, State} - end. 
- %% State train train(cast, {sample, BatchID ,{<<>>, _Type}}, State) -> - ?LOG_WARNING("Empty sample received , batch id: ~p",[BatchID]), + ?LOG_WARNING("Empty sample received , batch id: ~p~n",[BatchID]), WorkerStatsEts = get(worker_stats_ets), stats:increment_by_value(WorkerStatsEts , empty_batches , 1), {next_state, train, State#workerGeneric_state{nextState = train , currentBatchID = BatchID}}; %% Change SampleListTrain to NerlTensor -train(cast, {sample, SourceName ,BatchID ,{NerlTensorOfSamples, NerlTensorType}}, State = #workerGeneric_state{modelID = ModelId, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData}) -> +train(cast, {sample, SourceName ,BatchID ,{NerlTensorOfSamples, NerlTensorType}}, State = #workerGeneric_state{modelID = ModelId, distributedBehaviorFunc = DistributedBehaviorFunc, distributedWorkerData = DistributedWorkerData, myName = _MyName}) -> % NerlTensor = nerltensor_conversion({NerlTensorOfSamples, Type}, erl_float), MyPid = self(), - NewWorkerData = DistributedBehaviorFunc(pre_train, {get(generic_worker_ets),DistributedWorkerData}), + DistributedBehaviorFunc(pre_train, {get(generic_worker_ets),DistributedWorkerData}), % Here the model can be updated by the federated server WorkersStatsEts = get(worker_stats_ets), stats:increment_by_value(WorkersStatsEts , batches_received_train , 1), _Pid = spawn(fun()-> nerlNIF:call_to_train(ModelId , {NerlTensorOfSamples, NerlTensorType} ,MyPid , BatchID , SourceName) end), - {next_state, wait, State#workerGeneric_state{nextState = train, distributedWorkerData = NewWorkerData , currentBatchID = BatchID}}; + {next_state, wait, State#workerGeneric_state{nextState = train, currentBatchID = BatchID}}; %% TODO: implement send model and weights by demand (Tensor / XML) train(cast, {set_weights,Ret_weights_list}, State = #workerGeneric_state{modelID = ModelId}) -> %% Set weights - %io:format("####sending new weights to workers####~n"), nerlNIF:call_to_set_weights(ModelId, Ret_weights_list), %% TODO wrong usage %logger:notice("####end set weights train####~n"), {next_state, train, State}; +train(cast, {post_train_update , Weights}, State = #workerGeneric_state{myName = _MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) -> + DistributedBehaviorFunc(post_train, {get(generic_worker_ets), Weights}), + {next_state, train, State}; + +train(cast, {start_stream , StreamName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + stream_handler(start_stream, train, StreamName, DistributedBehaviorFunc), + {next_state, train, State}; + +train(cast, {end_stream , StreamName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + stream_handler(end_stream, train, StreamName, DistributedBehaviorFunc), + {next_state, train, State}; -train(cast, {idle}, State = #workerGeneric_state{myName = MyName}) -> +train(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> update_client_avilable_worker(MyName), + DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), train}), {next_state, idle, State}; -train(cast, Data, State) -> +train(cast, Data, State = #workerGeneric_state{myName = _MyName}) -> % logger:notice("worker ~p in wait cant treat message: ~p\n",[ets:lookup_element(get(generic_worker_ets), worker_name, ?ETS_KEYVAL_VAL_IDX), Data]), - worker_controller_message_queue(Data), + ?LOG_ERROR("Got unknown message in train state: 
~p~n",[Data]), + throw("Got unknown message in train state"), {keep_state, State}. %% State predict @@ -309,25 +288,45 @@ predict(cast, {sample , SourceName , BatchID , {PredictBatchTensor, Type}}, Stat DistributedBehaviorFunc(pre_predict, {get(generic_worker_ets),DistributedWorkerData}), WorkersStatsEts = get(worker_stats_ets), stats:increment_by_value(WorkersStatsEts , batches_received_predict , 1), - %% io:format("Pred Tensor: ~p~n",[nerlNIF:nerltensor_conversion({PredictBatchTensor , Type} , nerlNIF:erl_type_conversion(Type))]), _Pid = spawn(fun()-> nerlNIF:call_to_predict(ModelId , {PredictBatchTensor, Type} , CurrPID , BatchID, SourceName) end), {next_state, wait, State#workerGeneric_state{nextState = predict , currentBatchID = BatchID}}; -predict(cast, {idle}, State = #workerGeneric_state{myName = MyName}) -> +predict(cast, {start_stream , SourceName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + stream_handler(start_stream, predict, SourceName, DistributedBehaviorFunc), + {next_state, predict, State}; + +predict(cast, {end_stream , SourceName}, State = #workerGeneric_state{myName = _MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> + stream_handler(end_stream, predict, SourceName, DistributedBehaviorFunc), + {next_state, predict, State}; + +predict(cast, {idle}, State = #workerGeneric_state{myName = MyName , distributedBehaviorFunc = DistributedBehaviorFunc}) -> update_client_avilable_worker(MyName), + DistributedBehaviorFunc(pre_idle, {get(generic_worker_ets), predict}), {next_state, idle, State}; predict(cast, Data, State) -> - worker_controller_message_queue(Data), + ?LOG_ERROR("Got unknown message in predict state: ~p~n",[Data]), + throw("Got unknown message in predict state"), {next_state, predict, State}. %% Updates the client that worker is available update_client_avilable_worker(MyName) -> gen_statem:cast(get(client_pid),{stateChange,MyName}). -worker_controller_message_queue(ReceiveData) -> - Queue = ets:lookup_element(get(generic_worker_ets), controller_message_q, ?ETS_KEYVAL_VAL_IDX), - ets:update_element(get(generic_worker_ets), controller_message_q, {?ETS_KEYVAL_VAL_IDX , Queue++[ReceiveData]}). - -worker_controller_empty_message_queue() -> - ets:update_element(get(generic_worker_ets), controller_message_q, {?ETS_KEYVAL_VAL_IDX , []}). +stream_handler(StreamPhase , ModelPhase , StreamName , DistributedBehaviorFunc) -> + GenWorkerEts = get(generic_worker_ets), + MyName = ets:lookup_element(GenWorkerEts, worker_name, ?ETS_KEYVAL_VAL_IDX), + ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + NewActiveStreams = + case StreamPhase of + start_stream -> ActiveStreams ++ [StreamName]; + end_stream -> ActiveStreams -- [StreamName] + end, + ets:update_element(GenWorkerEts, active_streams, {?ETS_KEYVAL_VAL_IDX, NewActiveStreams}), + DistributedBehaviorFunc(StreamPhase, {GenWorkerEts, [StreamName , ModelPhase]}), + UpdatedActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX), + case UpdatedActiveStreams of % Send to client an update after done with training phase + [] -> ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX), + gen_statem:cast(ClientPid, {worker_done, {MyName, StreamName}}); + _ -> ok + end. 
\ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerNN.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerNN.erl index 3e3e8be4..69018d6e 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerNN.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/workerNN.erl @@ -11,6 +11,8 @@ controller(FuncName, {GenWorkerEts, WorkerData}) -> post_train -> post_train({GenWorkerEts, WorkerData}); pre_predict -> pre_predict({GenWorkerEts, WorkerData}); post_predict -> post_predict({GenWorkerEts, WorkerData}); + start_stream -> start_stream({GenWorkerEts, WorkerData}); + end_stream -> end_stream({GenWorkerEts, WorkerData}); update -> update({GenWorkerEts, WorkerData}) end. @@ -30,5 +32,9 @@ post_predict({_GenWorkerEts, _WorkerData}) -> ok. update({_GenWorkerEts, _WorkerData}) -> ok. +start_stream({_GenWorkerEts, _WorkerData}) -> ok. + +end_stream({_GenWorkerEts, _WorkerData}) -> ok. + diff --git a/src_erl/NerlnetApp/src/Client/clientStateHandler.erl b/src_erl/NerlnetApp/src/Client/clientStateHandler.erl index e1765a4d..5d7a933b 100644 --- a/src_erl/NerlnetApp/src/Client/clientStateHandler.erl +++ b/src_erl/NerlnetApp/src/Client/clientStateHandler.erl @@ -17,20 +17,16 @@ init(Req0, [Action,Client_StateM_Pid]) -> {ok,Body,_} = cowboy_req:read_body(Req0), -%% io:format("client state_handler got body:~p~n",[Body]), case Action of - custom_worker_message -> - case binary_to_term(Body) of - {To, custom_worker_message, Data} -> %% handshake - gen_statem:cast(Client_StateM_Pid,{custom_worker_message,Data}); - {From, update, Data} -> %% updating weights - gen_statem:cast(Client_StateM_Pid,{update,Data}) - end; + worker_to_worker_msg -> {worker_to_worker_msg , From , To , Data} = binary_to_term(Body), + gen_statem:cast(Client_StateM_Pid,{worker_to_worker_msg , From , To , Data}); batch -> gen_statem:cast(Client_StateM_Pid,{sample,Body}); idle -> gen_statem:cast(Client_StateM_Pid,{idle}); training -> gen_statem:cast(Client_StateM_Pid,{training}); predict -> gen_statem:cast(Client_StateM_Pid,{predict}); - statistics -> gen_statem:cast(Client_StateM_Pid,{statistics}) + statistics -> gen_statem:cast(Client_StateM_Pid,{statistics}); + start_stream -> gen_statem:cast(Client_StateM_Pid,{start_stream, Body}); + end_stream -> gen_statem:cast(Client_StateM_Pid,{end_stream, Body}) end, %% reply ACKnowledge to main server for initiating, later send finished initiating http_request from client_stateM diff --git a/src_erl/NerlnetApp/src/Client/clientStatem.erl b/src_erl/NerlnetApp/src/Client/clientStatem.erl index 3b3b0334..476fdd1e 100644 --- a/src_erl/NerlnetApp/src/Client/clientStatem.erl +++ b/src_erl/NerlnetApp/src/Client/clientStatem.erl @@ -26,6 +26,7 @@ -define(ETS_KV_VAL_IDX, 2). % key value pairs --> value index is 2 -define(WORKER_PID_IDX, 1). +-define(W2W_PID_IDX, 2). -define(SERVER, ?MODULE). 
%% client ETS table: {WorkerName, WorkerPid, WorkerArgs, TimingTuple} @@ -71,26 +72,32 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh inets:start(), ?LOG_INFO("Client ~p is connected to: ~p~n",[MyName, [digraph:vertex(NerlnetGraph,Vertex) || Vertex <- digraph:out_neighbours(NerlnetGraph,MyName)]]), % nerl_tools:start_connection([digraph:vertex(NerlnetGraph,Vertex) || Vertex <- digraph:out_neighbours(NerlnetGraph,MyName)]), - EtsRef = ets:new(client_data, [set]), %% client_data is responsible for functional attributes + EtsRef = ets:new(client_data, [set, public]), %% client_data is responsible for functional attributes EtsStats = ets:new(ets_stats, [set]), %% ets_stats is responsible for holding all the ets stats (client + workers) ClientStatsEts = stats:generate_stats_ets(), %% client stats ets inside ets_stats ets:insert(EtsStats, {MyName, ClientStatsEts}), put(ets_stats, EtsStats), - ets:insert(EtsRef, {workerToClient, WorkerToClientMap}), - ets:insert(EtsRef, {workersNames, ClientWorkers}), + ets:insert(EtsRef, {workerToClient, WorkerToClientMap}), % All workers in the network (map to their client) + ets:insert(EtsRef, {workersNames, ClientWorkers}), % All THIS Client's workers ets:insert(EtsRef, {nerlnetGraph, NerlnetGraph}), ets:insert(EtsRef, {myName, MyName}), MyWorkersToShaMap = maps:filter(fun(Worker , _SHA) -> lists:member(Worker , ClientWorkers) end , WorkerShaMap), ets:insert(EtsRef, {workers_to_sha_map, MyWorkersToShaMap}), ets:insert(EtsRef, {sha_to_models_map , ShaToModelArgsMap}), + ets:insert(EtsRef, {w2wcom_pids, #{}}), + ets:insert(EtsRef, {all_workers_done, false}), + ets:insert(EtsRef, {num_of_fed_servers, 0}), % Will stay 0 if non-federated {MyRouterHost,MyRouterPort} = nerl_tools:getShortPath(MyName,?MAIN_SERVER_ATOM, NerlnetGraph), ets:insert(EtsRef, {my_router,{MyRouterHost,MyRouterPort}}), - clientWorkersFunctions:create_workers(MyName , EtsRef , ShaToModelArgsMap , EtsStats), %% send pre_idle signal to workers WorkersNames = clientWorkersFunctions:get_workers_names(EtsRef), - [gen_statem:cast(clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), {pre_idle}) || WorkerName <- WorkersNames], - + Pids = [clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName) || WorkerName <- WorkersNames], + [gen_statem:cast(WorkerPid, {pre_idle}) || WorkerPid <- Pids], + NumOfFedServers = ets:lookup_element(EtsRef, num_of_fed_servers, ?DATA_IDX), % When the experiment is non-federated this value is 0 + ets:insert(EtsRef, {num_of_training_workers, length(ClientWorkers) - NumOfFedServers}), % This number will not change + ets:insert(EtsRef, {training_workers, 0}), % will be updated in idle -> training & end_stream + ets:insert(EtsRef, {active_workers_streams, []}), % update dictionary WorkersEts = ets:lookup_element(EtsRef , workers_ets , ?DATA_IDX), put(workers_ets, WorkersEts), @@ -98,6 +105,7 @@ init({MyName,NerlnetGraph, ClientWorkers , WorkerShaMap , WorkerToClientMap , Sh put(client_data, EtsRef), put(ets_stats, EtsStats), put(client_stats_ets , ClientStatsEts), + put(my_pid , self()), {ok, idle, #client_statem_state{myName= MyName, etsRef = EtsRef}}. @@ -113,7 +121,7 @@ format_status(_Opt, [_PDict, _StateName, _State]) -> Status = some_term, Status.
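Worked example of the counters set up above, for a hypothetical client that owns ClientWorkers = [w1, w2, ws] where ws is a federated server: create_workers/4 (via get_distributed_worker_behavior/5) bumps num_of_fed_servers to 1, so num_of_training_workers = 3 - 1 = 2 and only the two non-server workers are counted as training workers.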
%% ==============STATES================= waitforWorkers(cast, In = {stateChange,WorkerName}, State = #client_statem_state{myName = MyName,waitforWorkers = WaitforWorkers,nextState = NextState, etsRef = _EtsRef}) -> - NewWaitforWorkers = WaitforWorkers--[WorkerName], + NewWaitforWorkers = WaitforWorkers -- [WorkerName], ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), @@ -121,9 +129,17 @@ waitforWorkers(cast, In = {stateChange,WorkerName}, State = #client_statem_state [] -> send_client_is_ready(MyName), % when all workers done their work stats:increment_messages_sent(ClientStatsEts), {next_state, NextState, State#client_statem_state{waitforWorkers = []}}; - _-> {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = NewWaitforWorkers}} + _ -> %io:format("Client ~p is waiting for workers ~p~n",[MyName,NewWaitforWorkers]), + {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = NewWaitforWorkers}} end; +waitforWorkers(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = #client_statem_state{etsRef = EtsRef}) -> + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data), + {keep_state, State}; + waitforWorkers(cast, In = {NewState}, State = #client_statem_state{myName = _MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -133,6 +149,7 @@ waitforWorkers(cast, In = {NewState}, State = #client_statem_state{myName = _MyN cast_message_to_workers(EtsRef, {NewState}), %% This function increments the number of sent messages in stats ets {next_state, waitforWorkers, State#client_statem_state{nextState = NewState, waitforWorkers = Workers}}; + waitforWorkers(cast, EventContent, State = #client_statem_state{myName = MyName}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -142,48 +159,35 @@ waitforWorkers(cast, EventContent, State = #client_statem_state{myName = MyName} %% initiating workers when they include federated workers. 
init stage == handshake between federated worker client and server -%% TODO: make custom_worker_message in all states to send messages from workers to entities (not just client) -idle(cast, In = {custom_worker_message, {From, To}}, State = #client_statem_state{etsRef = EtsRef}) -> +idle(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = #client_statem_state{etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - WorkerOfThisClient = ets:member(EtsRef, To), - if WorkerOfThisClient -> - TargetWorkerPID = ets:lookup_element(EtsRef, To, ?WORKER_PID_IDX), - gen_statem:cast(TargetWorkerPID,{post_idle,From}), - stats:increment_messages_sent(ClientStatsEts); - true -> - %% send to FedServer that worker From is connecting to it - DestClient = maps:get(To, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), - MessageBody = {DestClient, custom_worker_message, {From, To}}, - {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), - nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(custom_worker_message), term_to_binary(MessageBody)), - stats:increment_messages_sent(ClientStatsEts), - stats:increment_bytes_sent(ClientStatsEts , nerl_tools:calculate_size(MessageBody)) - end, + handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data), {keep_state, State}; idle(cast, _In = {statistics}, State = #client_statem_state{ myName = MyName, etsRef = EtsRef}) -> EtsStats = get(ets_stats), ClientStatsEts = get(client_stats_ets), ClientStatsEncStr = stats:encode_ets_to_http_bin_str(ClientStatsEts), - %ClientStatsToSend = atom_to_list(MyName) ++ ?API_SERVER_WITHIN_ENTITY_SEPERATOR ++ ClientStatsEncStr ++ ?API_SERVER_ENTITY_SEPERATOR, stats:increment_messages_received(ClientStatsEts), ListStatsEts = ets:tab2list(EtsStats) -- [{MyName , ClientStatsEts}], WorkersStatsEncStr = create_encoded_stats_str(ListStatsEts), DataToSend = ClientStatsEncStr ++ WorkersStatsEncStr, - %% io:format("DataToSend: ~p~n",[DataToSend]), StatsBody = {MyName , DataToSend}, {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), nerl_tools:http_router_request(RouterHost, RouterPort, [?MAIN_SERVER_ATOM], atom_to_list(statistics), StatsBody), stats:increment_messages_sent(ClientStatsEts), {next_state, idle, State}; +% Main Server triggers this state idle(cast, In = {training}, State = #client_statem_state{myName = _MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), - stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), MessageToCast = {training}, + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + MessageToCast = {training}, cast_message_to_workers(EtsRef, MessageToCast), + ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, false}), {next_state, waitforWorkers, State#client_statem_state{waitforWorkers = clientWorkersFunctions:get_workers_names(EtsRef), nextState = training}}; idle(cast, In = {predict}, State = #client_statem_state{etsRef = EtsRef}) -> @@ -222,23 +226,11 @@ training(cast, MessageIn = {update, {From, To, Data}}, State = #client_statem_st {keep_state, State}; -%% This is a generic way to move data from worker to worker -%% TODO fix variables names to make it more generic -%% federated server sends AvgWeights to workers -training(cast, InMessage = {custom_worker_message, 
WorkersList, WeightsTensor}, State = #client_statem_state{etsRef = EtsRef}) -> +training(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = #client_statem_state{etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), - stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(InMessage)), - Func = fun(WorkerName) -> - DestClient = maps:get(WorkerName, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), - MessageBody = term_to_binary({DestClient, update, {_FedServer = "server", WorkerName, WeightsTensor}}), % TODO - fix client should not be aware of the data of custom worker message - - {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), - nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(custom_worker_message), MessageBody), - stats:increment_messages_sent(ClientStatsEts), - stats:increment_bytes_sent(ClientStatsEts , nerl_tools:calculate_size(MessageBody)) - end, - lists:foreach(Func, WorkersList), % can be optimized with broadcast instead of unicast + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data), {keep_state, State}; % TODO Validate this state - sample and empty list @@ -265,20 +257,66 @@ training(cast, In = {sample,Body}, State = #client_statem_state{etsRef = EtsRef} true -> ?LOG_ERROR("Given worker ~p isn't found in client ~p",[WorkerName, ClientName]) end, {next_state, training, State#client_statem_state{etsRef = EtsRef}}; +% This action is used for start_stream triggered from a client's worker and not a source +training(cast, {start_stream , {worker, WorkerName, TargetName}}, State = #client_statem_state{etsRef = EtsRef}) -> + ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, ListOfActiveWorkersSources ++ [{WorkerName, TargetName}]}), + {keep_state, State}; + +% This action is used for start_stream triggered from a source per worker +training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> + {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, ListOfActiveWorkersSources ++ [{WorkerName, SourceName}]}), + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), + gen_statem:cast(WorkerPid, {start_stream, SourceName}), + {keep_state, State}; + +training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> + {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), + gen_statem:cast(WorkerPid, {end_stream, SourceName}), + {keep_state, State}; + +training(cast, _In = {worker_done, Data}, State = #client_statem_state{etsRef = EtsRef}) -> + {WorkerName, StreamName} = Data, + ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), 
UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, StreamName}], + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), + case length(UpdatedListOfActiveWorkerSources) of + 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); + _ -> ok + end, + {next_state, training, State#client_statem_state{etsRef = EtsRef}}; + +% From MainServer training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), MessageToCast = {idle}, - cast_message_to_workers(EtsRef, MessageToCast), - Workers = clientWorkersFunctions:get_workers_names(EtsRef), - ?LOG_INFO("~p setting workers at idle: ~p~n",[MyName, ets:lookup_element(EtsRef, workersNames, ?DATA_IDX)]), - {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef, waitforWorkers = Workers , nextState = idle}}; + WorkersDone = ets:lookup_element(EtsRef , all_workers_done , ?DATA_IDX), + % io:format("Client ~p Workers Done? ~p~n",[MyName, WorkersDone]), + case WorkersDone of + true -> cast_message_to_workers(EtsRef, MessageToCast), + Workers = clientWorkersFunctions:get_workers_names(EtsRef), + ?LOG_INFO("~p sent idle to workers: ~p , waiting for confirmation...~n",[MyName, ets:lookup_element(EtsRef, workersNames, ?DATA_IDX)]), + {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef, waitforWorkers = Workers , nextState = idle}}; + false -> gen_statem:cast(get(my_pid) , {idle}), % Trigger this action until all workers are done + {keep_state, State} end; training(cast, _In = {predict}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> ?LOG_ERROR("Wrong request , client ~p can't go from training to predict directly", [MyName]), {next_state, training, State#client_statem_state{etsRef = EtsRef}}; + training(cast, In = {loss, WorkerName ,SourceName ,LossTensor ,TimeNIF ,BatchID ,BatchTS}, State = #client_statem_state{myName = MyName,etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -315,6 +353,61 @@ predict(cast, In = {sample,Body}, State = #client_statem_state{etsRef = EtsRef}) end, {next_state, predict, State#client_statem_state{etsRef = EtsRef}}; +% This action is used for start_stream triggered from a client's worker and not a source +predict(cast, {start_stream , {worker, WorkerName, TargetName}}, State = #client_statem_state{etsRef = EtsRef}) -> + ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, ListOfActiveWorkersSources ++ [{WorkerName, TargetName}]}), + {keep_state, State}; + +% This action is used for start_stream triggered from a source per worker +predict(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> + {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, ListOfActiveWorkersSources ++ [{WorkerName, SourceName}]}), + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , 
WorkerName), + gen_statem:cast(WorkerPid, {start_stream, SourceName}), + {keep_state, State}; + +predict(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) -> + {SourceName, _ClientName, WorkerName} = binary_to_term(Data), + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef , WorkerName), + gen_statem:cast(WorkerPid, {end_stream, SourceName}), + {keep_state, State}; + +predict(cast, _In = {worker_done, Data}, State = #client_statem_state{etsRef = EtsRef}) -> + {WorkerName, StreamName} = Data, + ListOfActiveWorkerSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX), + UpdatedListOfActiveWorkerSources = ListOfActiveWorkerSources -- [{WorkerName, StreamName}], + ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, UpdatedListOfActiveWorkerSources}), + case length(UpdatedListOfActiveWorkerSources) of + 0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true}); + _ -> ok + end, + {keep_state, State}; + +% From MainServer +predict(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef = EtsRef}) -> + ClientStatsEts = get(client_stats_ets), + stats:increment_messages_received(ClientStatsEts), + stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), + MessageToCast = {idle}, + WorkersDone = ets:lookup_element(EtsRef , all_workers_done , ?DATA_IDX), + % io:format("Client ~p Workers Done? ~p~n",[MyName, WorkersDone]), + case WorkersDone of + true -> cast_message_to_workers(EtsRef, MessageToCast), + Workers = clientWorkersFunctions:get_workers_names(EtsRef), + ?LOG_INFO("~p sent idle to workers: ~p , waiting for confirmation...~n",[MyName, ets:lookup_element(EtsRef, workersNames, ?DATA_IDX)]), + {next_state, waitforWorkers, State#client_statem_state{etsRef = EtsRef, waitforWorkers = Workers , nextState = idle}}; + false -> gen_statem:cast(get(my_pid) , {idle}), % Trigger this action until all workers are done + {keep_state, State} + end; + predict(cast, In = {predictRes,WorkerName, SourceName ,{PredictNerlTensor, NetlTensorType} , TimeTook , BatchID , BatchTS}, State = #client_statem_state{myName = _MyName, etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), @@ -335,17 +428,15 @@ predict(cast,_In = {training}, State = #client_statem_state{myName = MyName}) -> ?LOG_ERROR("client ~p got training request in predict state",[MyName]), {next_state, predict, State#client_statem_state{nextState = predict}}; -%% The source sends message to main server that it has finished -%% The main server updates its' clients to move to state 'idle' -predict(cast, In = {idle}, State = #client_statem_state{etsRef = EtsRef , myName = _MyName}) -> - - MsgToCast = {idle}, +predict(cast, In = {worker_to_worker_msg, FromWorker, ToWorker, Data}, State = #client_statem_state{etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), stats:increment_messages_received(ClientStatsEts), stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)), - cast_message_to_workers(EtsRef, MsgToCast), - Workers = clientWorkersFunctions:get_workers_names(EtsRef), - {next_state, waitforWorkers, State#client_statem_state{nextState = idle, waitforWorkers = Workers, etsRef = EtsRef}}; + handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data), + {keep_state, State}; + +%% The source 
sends a message to the main server that it has finished +%% The main server updates its clients to move to state 'idle' predict(cast, EventContent, State = #client_statem_state{etsRef = EtsRef}) -> ClientStatsEts = get(client_stats_ets), @@ -390,8 +481,8 @@ send_client_is_ready(MyName) -> cast_message_to_workers(EtsRef, Msg) -> ClientStatsEts = get(client_stats_ets), Workers = ets:lookup_element(EtsRef, workersNames, ?ETS_KV_VAL_IDX), - Func = fun(WorkerKey) -> - WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef, WorkerKey), % WorkerKey is the worker name + Func = fun(WorkerName) -> + WorkerPid = clientWorkersFunctions:get_worker_pid(EtsRef, WorkerName), gen_statem:cast(WorkerPid, Msg), stats:increment_messages_sent(ClientStatsEts) end, @@ -403,4 +494,27 @@ create_encoded_stats_str(ListStatsEts) -> %% |w1&bytes_sent:6.0:float#bad_messages:0:int....| ?API_SERVER_ENTITY_SEPERATOR ++ atom_to_list(WorkerName) ++ ?WORKER_SEPERATOR ++ WorkerEncStatsStr end, - lists:flatten(lists:map(Func , ListStatsEts)). \ No newline at end of file + lists:flatten(lists:map(Func , ListStatsEts)). + +handle_w2w_msg(EtsRef, FromWorker, ToWorker, Data) -> + ClientStatsEts = get(client_stats_ets), + WorkersOfThisClient = ets:lookup_element(EtsRef, workersNames, ?DATA_IDX), + WorkerOfThisClient = lists:member(ToWorker, WorkersOfThisClient), + case WorkerOfThisClient of + true -> + % Extract W2WPID from Ets + W2WPidsMap = ets:lookup_element(EtsRef, w2wcom_pids, ?DATA_IDX), + TargetWorkerW2WPID = maps:get(ToWorker, W2WPidsMap), + {ok, _Reply} = gen_server:call(TargetWorkerW2WPID, {worker_to_worker_msg, FromWorker, ToWorker, Data}), + stats:increment_messages_sent(ClientStatsEts); + _ -> + %% Send to the correct client + DestClient = maps:get(ToWorker, ets:lookup_element(EtsRef, workerToClient, ?ETS_KV_VAL_IDX)), + % ClientName = ets:lookup_element(EtsRef, myName , ?DATA_IDX), + % io:format("Client ~p passing w2w_msg {~p --> ~p} to ~p: Data ~p~n",[ClientName, FromWorker, ToWorker, DestClient,Data]), + MessageBody = {worker_to_worker_msg, FromWorker, ToWorker, Data}, + {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), + nerl_tools:http_router_request(RouterHost, RouterPort, [DestClient], atom_to_list(worker_to_worker_msg), MessageBody), + stats:increment_messages_sent(ClientStatsEts), + stats:increment_bytes_sent(ClientStatsEts , nerl_tools:calculate_size(MessageBody)) + end. \ No newline at end of file diff --git a/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl b/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl index ffcbf9cf..1333c781 100644 --- a/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl +++ b/src_erl/NerlnetApp/src/Client/clientWorkersFunctions.erl @@ -8,7 +8,7 @@ -export([create_workers/4]). -export([get_worker_pid/2 , get_worker_stats_ets/2 , get_workers_names/1]). -get_distributed_worker_behavior(DistributedSystemType , WorkerName , DistributedSystemArgs , DistributedSystemToken) -> +get_distributed_worker_behavior(ClientEtsRef, DistributedSystemType , WorkerName , DistributedSystemArgs , DistributedSystemToken) -> case DistributedSystemType of ?DC_DISTRIBUTED_SYSTEM_TYPE_NONE_IDX_STR -> DistributedBehaviorFunc = fun workerNN:controller/2, @@ -18,8 +18,13 @@ case DistributedSystemType of DistributedWorkerData = {_WorkerName = WorkerName , _Args = DistributedSystemArgs, _Token = DistributedSystemToken}; %% Parse args eg. 
batch_sync_count ?DC_DISTRIBUTED_SYSTEM_TYPE_FEDSERVERAVG_IDX_STR -> + WorkersMap = ets:lookup_element(ClientEtsRef, workerToClient, ?DATA_IDX), + WorkersList = [Worker || {Worker, _Val} <- maps:to_list(WorkersMap)], DistributedBehaviorFunc = fun workerFederatedServer:controller/2, - DistributedWorkerData = {_ServerName = WorkerName , _Args = DistributedSystemArgs, _Token = DistributedSystemToken, _WorkersNamesList = []} + NumOfFedServers = ets:lookup_element(ClientEtsRef, num_of_fed_servers, ?DATA_IDX), + UpdatedNumOfFedServers = NumOfFedServers + 1, + ets:update_element(ClientEtsRef, num_of_fed_servers, {?DATA_IDX, UpdatedNumOfFedServers}), + DistributedWorkerData = {_ServerName = WorkerName , _Args = DistributedSystemArgs, _Token = DistributedSystemToken , _WorkersList = WorkersList} end, {DistributedBehaviorFunc , DistributedWorkerData}. @@ -43,13 +48,18 @@ create_workers(ClientName, ClientEtsRef , ShaToModelArgsMap , EtsStats) -> MyClientPid = self(), % TODO add documentation about this case of % move this case to module called client_controller - {DistributedBehaviorFunc , DistributedWorkerData} = get_distributed_worker_behavior(DistributedSystemType , WorkerName , DistributedSystemArgs , DistributedSystemToken), + {DistributedBehaviorFunc , DistributedWorkerData} = get_distributed_worker_behavior(ClientEtsRef, DistributedSystemType , WorkerName , DistributedSystemArgs , DistributedSystemToken), + W2wComPid = w2wCom:start_link({WorkerName, MyClientPid}), % TODO Switch to monitor instead of link WorkerArgs = {ModelID , ModelType , ModelArgs , LayersSizes, LayersTypes, LayersFunctions, LearningRate , Epochs, Optimizer, OptimizerArgs , LossMethod , DistributedSystemType , DistributedSystemArgs}, - WorkerPid = workerGeneric:start_link({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , MyClientPid , WorkerStatsETS}), + WorkerPid = workerGeneric:start_link({WorkerName , WorkerArgs , DistributedBehaviorFunc , DistributedWorkerData , MyClientPid , WorkerStatsETS , W2wComPid}), + gen_server:cast(W2wComPid, {update_gen_worker_pid, WorkerPid}), ets:insert(WorkersETS, {WorkerName, {WorkerPid, WorkerArgs}}), ets:insert(EtsStats, {WorkerName, WorkerStatsETS}), + W2WPidMap = ets:lookup_element(ClientEtsRef, w2wcom_pids, ?DATA_IDX), + W2WPidMapNew = maps:put(WorkerName, W2wComPid, W2WPidMap), + ets:update_element(ClientEtsRef, w2wcom_pids, {?DATA_IDX, W2WPidMapNew}), WorkerName end, diff --git a/src_erl/NerlnetApp/src/Source/sourceStatem.erl b/src_erl/NerlnetApp/src/Source/sourceStatem.erl index ecc44850..99b492a0 100644 --- a/src_erl/NerlnetApp/src/Source/sourceStatem.erl +++ b/src_erl/NerlnetApp/src/Source/sourceStatem.erl @@ -124,7 +124,7 @@ idle(cast, {batchList, WorkersList, NumOfBatches, NerlTensorType, Data}, State) ?LOG_NOTICE("Source ~p, workers are: ~p", [MyName, WorkersList]), ?LOG_NOTICE("Source ~p, sample size: ~p", [MyName, SampleSize]), ets:update_element(EtsRef, sample_size, [{?DATA_IDX, SampleSize}]), - ?LOG_INFO("Source ~p updated transmission list, total avilable batches to send: ~p~n",[MyName, NumOfBatches]), + ?LOG_INFO("Source ~p updated transmission list, total available batches to send: ~p~n",[MyName, NumOfBatches]), %% send an ACK to mainserver that the CSV file is ready {RouterHost,RouterPort} = ets:lookup_element(EtsRef, my_router, ?DATA_IDX), nerl_tools:http_router_request(RouterHost, RouterPort, [?MAIN_SERVER_ATOM], atom_to_list(dataReady), MyName), @@ -365,12 +365,25 @@ transmitter(TimeInterval_ms, SourceEtsRef, SourcePid ,ClientWorkerPairs, 
Batches ets:insert(TransmitterEts, {batches_skipped, 0}), ets:insert(TransmitterEts, {current_batch_id, 0}), TransmissionStart = erlang:timestamp(), + % Message to all workrers : "start_stream" , TRANSFER TO FUNCTIONS + {RouterHost, RouterPort} = ets:lookup_element(TransmitterEts, my_router, ?DATA_IDX), + FuncStart = fun({ClientName, WorkerNameStr}) -> + ToSend = {MyName, ClientName, list_to_atom(WorkerNameStr)}, + nerl_tools:http_router_request(RouterHost, RouterPort, [ClientName], atom_to_list(start_stream), ToSend) + end, + lists:foreach(FuncStart, ClientWorkerPairs), case Method of ?SOURCE_POLICY_CASTING_ATOM -> send_method_casting(TransmitterEts, TimeInterval_ms, ClientWorkerPairs, BatchesListToSend); ?SOURCE_POLICY_ROUNDROBIN_ATOM -> send_method_round_robin(TransmitterEts, TimeInterval_ms, ClientWorkerPairs, BatchesListToSend); ?SOURCE_POLICY_RANDOM_ATOM -> send_method_random(TransmitterEts, TimeInterval_ms, ClientWorkerPairs, BatchesListToSend); _Default -> send_method_casting(TransmitterEts, TimeInterval_ms, ClientWorkerPairs, BatchesListToSend) end, + % Message to workers : "end_stream" + FuncEnd = fun({ClientName, WorkerNameStr}) -> + ToSend = {MyName, ClientName, list_to_atom(WorkerNameStr)}, + nerl_tools:http_router_request(RouterHost, RouterPort, [ClientName], atom_to_list(end_stream), ToSend) + end, + lists:foreach(FuncEnd, ClientWorkerPairs), TransmissionTimeTook_sec = timer:now_diff(erlang:timestamp(), TransmissionStart) / 1000000, ErrorBatches = ets:lookup_element(TransmitterEts, batches_issue, ?DATA_IDX), SkippedBatches = ets:lookup_element(TransmitterEts, batches_skipped, ?DATA_IDX), diff --git a/src_erl/NerlnetApp/src/nerlnetApp_app.erl b/src_erl/NerlnetApp/src/nerlnetApp_app.erl index 4c9836ff..2f228d78 100644 --- a/src_erl/NerlnetApp/src/nerlnetApp_app.erl +++ b/src_erl/NerlnetApp/src/nerlnetApp_app.erl @@ -20,7 +20,7 @@ -behaviour(application). -include("nerl_tools.hrl"). --define(NERLNET_APP_VERSION, "1.4.3"). +-define(NERLNET_APP_VERSION, "1.5.0"). -define(NERLPLANNER_TESTED_VERSION,"1.0.2"). -export([start/2, stop/1]). @@ -245,7 +245,10 @@ createClientsAndWorkers() -> {"/clientTraining",clientStateHandler, [training,ClientStatemPid]}, {"/clientIdle",clientStateHandler, [idle,ClientStatemPid]}, {"/clientPredict",clientStateHandler, [predict,ClientStatemPid]}, - {"/batch",clientStateHandler, [batch,ClientStatemPid]} + {"/batch",clientStateHandler, [batch,ClientStatemPid]}, + {"/worker_to_worker_msg",clientStateHandler, [worker_to_worker_msg,ClientStatemPid]}, + {"/start_stream", clientStateHandler, [start_stream, ClientStatemPid]}, + {"/end_stream", clientStateHandler, [end_stream, ClientStatemPid]} ]} ]), init_cowboy_start_clear(Client, {DeviceName, Port},NerlClientDispatch) diff --git a/src_py/apiServer/apiServer.py b/src_py/apiServer/apiServer.py index 89267411..ca83815f 100644 --- a/src_py/apiServer/apiServer.py +++ b/src_py/apiServer/apiServer.py @@ -210,11 +210,11 @@ def list_datasets(self): repo_csv_files = [file for file in files if file.endswith('.csv')] datasets[repo["id"]] = repo_csv_files for i , (repo_name , files) in enumerate(datasets.items()): - LOG_INFO(f'{i}. {repo_name}: {files}') + print(f'{i}. {repo_name}: {files}') except utils._errors.RepositoryNotFoundError: LOG_INFO(f"Failed to find the repository '{repo}'. 
Check your '{HF_DATA_REPO_PATHS_JSON}' file or network access.") - def download_dataset(self, repo_idx : int | list[int], download_dir_path : str = DEFAULT_NERLNET_TMP_DATA_DIR): + def download_dataset(self, repo_idx : int, download_dir_path : str = DEFAULT_NERLNET_TMP_DATA_DIR): with open(HF_DATA_REPO_PATHS_JSON) as file: repo_ids = json.load(file) try: diff --git a/src_py/nerlPlanner/Definitions.py b/src_py/nerlPlanner/Definitions.py index 11cc21e0..d70512d3 100644 --- a/src_py/nerlPlanner/Definitions.py +++ b/src_py/nerlPlanner/Definitions.py @@ -2,7 +2,7 @@ from logger import * VERSION = "1.0.2" -NERLNET_VERSION_TESTED_WITH = "1.4.2" +NERLNET_VERSION_TESTED_WITH = "1.5.0" NERLNET_TMP_PATH = "/tmp/nerlnet" NERLNET_GRAPHVIZ_OUTPUT_DIR = f"{NERLNET_TMP_PATH}/nerlplanner" NERLNET_GLOBAL_PATH = "/usr/local/lib/nerlnet-lib/NErlNet"
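
The clientStatem.erl and clientWorkersFunctions.erl changes above give every worker a companion w2wCom gen_server: create_workers/4 spawns it, records it in the client's w2wcom_pids ETS map, and handle_w2w_msg/4 later delivers a worker-to-worker message locally through that pid, or forwards it through the router when the destination worker belongs to another client. The w2wCom module itself is not part of this diff, so the following is only a minimal sketch of its receiving side: the state record and inbox are my own assumptions; only the {worker_to_worker_msg, ...} call shape, the {update_gen_worker_pid, ...} cast, and start_link/1 returning a bare pid come from the call sites above.

%% Sketch only -- the real w2wCom module is not shown in this diff. The state
%% record and inbox are assumptions; the message shapes match the call sites.
-module(w2wCom_sketch).
-behaviour(gen_server).
-export([start_link/1, init/1, handle_call/3, handle_cast/2]).

-record(state, {worker_name, client_pid, gen_worker_pid, inbox = []}).

%% create_workers/4 binds the result directly to W2wComPid, so return a bare pid.
start_link({WorkerName, ClientPid}) ->
    {ok, Pid} = gen_server:start_link(?MODULE, {WorkerName, ClientPid}, []),
    Pid.

init({WorkerName, ClientPid}) ->
    {ok, #state{worker_name = WorkerName, client_pid = ClientPid}}.

%% Local delivery path: handle_w2w_msg/4 calls this when the destination
%% worker lives on the same client; {ok, _Reply} is matched at the call site.
handle_call({worker_to_worker_msg, FromWorker, _ToWorker, Data}, _From,
            State = #state{inbox = Inbox}) ->
    {reply, {ok, received}, State#state{inbox = Inbox ++ [{FromWorker, Data}]}}.

%% create_workers/4 casts the generic worker pid right after spawning it.
handle_cast({update_gen_worker_pid, WorkerPid}, State) ->
    {noreply, State#state{gen_worker_pid = WorkerPid}}.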
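
The federated-server branch and create_workers/4 both assume the client ETS already holds num_of_fed_servers and w2wcom_pids entries; their initialization happens elsewhere and is not part of this diff. A sketch of the seeding those lookups imply, with hypothetical module and function names (only the ETS key names come from the code above, and ?DATA_IDX is assumed to address the value position of each {Key, Value} record):

%% Hypothetical seeding helper; where the client actually does this is not shown.
-module(client_ets_seed_sketch).
-export([seed_worker_bookkeeping/1]).

seed_worker_bookkeeping(ClientEtsRef) ->
    % bumped by one for each federated-server worker this client creates
    ets:insert(ClientEtsRef, {num_of_fed_servers, 0}),
    % WorkerName => w2wCom pid, filled in by create_workers/4
    ets:insert(ClientEtsRef, {w2wcom_pids, #{}}).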
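
In sourceStatem:transmitter/5 the new start_stream and end_stream blocks differ only in the phase atom, and the inline comment already marks them for extraction into a function. A minimal sketch of that refactor, meant to sit in sourceStatem.erl next to transmitter/5 (the function name is hypothetical; nerl_tools:http_router_request/5 and the {SourceName, ClientName, WorkerAtom} payload are exactly as in the diff):

%% Extraction of the duplicated FuncStart/FuncEnd closures above; the phase
%% atom (start_stream | end_stream) doubles as the HTTP route name.
notify_stream_phase(Phase, MyName, TransmitterEts, ClientWorkerPairs)
        when Phase =:= start_stream; Phase =:= end_stream ->
    {RouterHost, RouterPort} = ets:lookup_element(TransmitterEts, my_router, ?DATA_IDX),
    Func = fun({ClientName, WorkerNameStr}) ->
        ToSend = {MyName, ClientName, list_to_atom(WorkerNameStr)},
        nerl_tools:http_router_request(RouterHost, RouterPort, [ClientName],
                                       atom_to_list(Phase), ToSend)
    end,
    lists:foreach(Func, ClientWorkerPairs).

transmitter/5 would then call notify_stream_phase(start_stream, MyName, TransmitterEts, ClientWorkerPairs) just before the case Method of dispatch and the end_stream variant right after it.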
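
The three routes added to the client dispatch table in nerlnetApp_app.erl all reuse clientStateHandler, whose source is not in this diff. For orientation, a rough sketch of a handler wired the same way using the standard cowboy 2.x init/2 API; only the [Action, ClientStatemPid] initial state comes from the dispatch table above, and how the body is decoded and handed to the statem is an assumption:

%% Rough sketch of a clientStateHandler-style handler (cowboy 2.x).
%% Only the [Action, ClientStatemPid] init state is taken from the dispatch
%% table; the cast shape is an assumption.
-module(client_state_handler_sketch).
-export([init/2]).

init(Req0, [Action, ClientStatemPid] = State) ->
    {ok, Body, Req1} = cowboy_req:read_body(Req0),    % request body as a binary
    gen_statem:cast(ClientStatemPid, {Action, Body}), % hand off to the client statem
    Req2 = cowboy_req:reply(200, #{}, <<"ok">>, Req1),
    {ok, Req2, State}.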