diff --git a/NerlnetBuild.sh b/NerlnetBuild.sh index d732e0a4..047ea709 100755 --- a/NerlnetBuild.sh +++ b/NerlnetBuild.sh @@ -155,10 +155,12 @@ if command -v python3 >/dev/null 2>&1; then set -e AUTOGENERATED_WORKER_DEFINITIONS_PATH="`pwd`/src_cpp/opennnBridge/worker_definitions_ag.h" AUTOGENERATED_WORKER_DEFINITIONS_PATH_HRL="`pwd`/src_erl/NerlnetApp/src/worker_definitions_ag.hrl" + AUTOGENERATED_DC_DEFINITIONS_PATH_HRL="`pwd`/src_erl/NerlnetApp/src/dc_definitions_ag.hrl" echo "$NERLNET_BUILD_PREFIX Generate auto-generated files" python3 src_py/nerlPlanner/CppHeadersExporter.py --output $AUTOGENERATED_WORKER_DEFINITIONS_PATH #--debug - python3 src_py/nerlPlanner/ErlHeadersExporter.py --output $AUTOGENERATED_WORKER_DEFINITIONS_PATH_HRL #--debug + python3 src_py/nerlPlanner/ErlHeadersExporter.py --gen_worker_fields_hrl --output $AUTOGENERATED_WORKER_DEFINITIONS_PATH_HRL #--debug + python3 src_py/nerlPlanner/ErlHeadersExporter.py --gen_dc_fields_hrl --output $AUTOGENERATED_DC_DEFINITIONS_PATH_HRL #--debug set +e else echo "$NERLNET_BUILD_PREFIX Python 3 is not installed" diff --git a/src_erl/NerlnetApp/src/dc_definitions_ag.hrl b/src_erl/NerlnetApp/src/dc_definitions_ag.hrl new file mode 100644 index 00000000..3970c047 --- /dev/null +++ b/src_erl/NerlnetApp/src/dc_definitions_ag.hrl @@ -0,0 +1,44 @@ +% This is an auto generated .hrl file +% Generated by Nerlplanner version: 1.0.0 + +-define(DC_KEY_NERLNET_SETTINGS_ATOM,nerlnetSettings). +-define(DC_KEY_FREQUENCY_ATOM,frequency). +-define(DC_KEY_BATCH_SIZE_ATOM,batchSize). +-define(DC_KEY_DEVICES_ATOM,devices). +-define(DC_KEY_CLIENTS_ATOM,clients). +-define(DC_KEY_WORKERS_ATOM,workers). +-define(DC_KEY_MODEL_SHA_ATOM,model_sha). +-define(DC_KEY_SOURCES_ATOM,sources). +-define(DC_KEY_ROUTERS_ATOM,routers). +-define(DC_NAME_FIELD_ATOM,name). +-define(DC_WORKER_MODEL_SHA_FIELD_ATOM,model_sha). +-define(DC_IPV4_FIELD_ATOM,ipv4). +-define(DC_PORT_FIELD_ATOM,port). +-define(DC_ARGS_FIELD_ATOM,args). +-define(DC_ENTITIES_FIELD_ATOM,entities). +-define(DC_POLICY_FIELD_ATOM,policy). +-define(DC_EPOCHS_FIELD_ATOM,epochs). +-define(DC_TYPE_FIELD_ATOM,type). +-define(DC_FREQUENCY_FIELD_ATOM,frequency). +-define(DC_WORKERS_FIELD_ATOM,workers). + +-define(DC_KEY_NERLNET_SETTINGS_STR,"nerlnetSettings"). +-define(DC_KEY_FREQUENCY_STR,"frequency"). +-define(DC_KEY_BATCH_SIZE_STR,"batchSize"). +-define(DC_KEY_DEVICES_STR,"devices"). +-define(DC_KEY_CLIENTS_STR,"clients"). +-define(DC_KEY_WORKERS_STR,"workers"). +-define(DC_KEY_MODEL_SHA_STR,"model_sha"). +-define(DC_KEY_SOURCES_STR,"sources"). +-define(DC_KEY_ROUTERS_STR,"routers"). +-define(DC_NAME_FIELD_STR,"name"). +-define(DC_WORKER_MODEL_SHA_FIELD_STR,"model_sha"). +-define(DC_IPV4_FIELD_STR,"ipv4"). +-define(DC_PORT_FIELD_STR,"port"). +-define(DC_ARGS_FIELD_STR,"args"). +-define(DC_ENTITIES_FIELD_STR,"entities"). +-define(DC_POLICY_FIELD_STR,"policy"). +-define(DC_EPOCHS_FIELD_STR,"epochs"). +-define(DC_TYPE_FIELD_STR,"type"). +-define(DC_FREQUENCY_FIELD_STR,"frequency"). +-define(DC_WORKERS_FIELD_STR,"workers"). diff --git a/src_erl/NerlnetApp/src/worker_definitions_ag.hrl b/src_erl/NerlnetApp/src/worker_definitions_ag.hrl index 6b7b4724..f6663497 100644 --- a/src_erl/NerlnetApp/src/worker_definitions_ag.hrl +++ b/src_erl/NerlnetApp/src/worker_definitions_ag.hrl @@ -1,11 +1,11 @@ % This is an auto generated .hrl file % Generated by Nerlplanner version: 1.0.0 --define(KEY_MODEL_TYPE,modelType). --define(KEY_LAYER_SIZES_LIST,layersSizes). --define(KEY_LAYER_TYPES_LIST,layerTypesList). --define(KEY_LAYERS_FUNCTIONS,layers_functions). --define(KEY_LOSS_METHOD,lossMethod). --define(KEY_LEARNING_RATE,lr). --define(KEY_EPOCHS,epochs). --define(KEY_OPTIMIZER_TYPE,optimizer). +-define(WORKER_KEY_MODEL_TYPE,modelType). +-define(WORKER_KEY_LAYER_SIZES_LIST,layersSizes). +-define(WORKER_KEY_LAYER_TYPES_LIST,layerTypesList). +-define(WORKER_KEY_LAYERS_FUNCTIONS,layers_functions). +-define(WORKER_KEY_LOSS_METHOD,lossMethod). +-define(WORKER_KEY_LEARNING_RATE,lr). +-define(WORKER_KEY_EPOCHS,epochs). +-define(WORKER_KEY_OPTIMIZER_TYPE,optimizer). diff --git a/src_py/nerlPlanner/ErlHeadersExporter.py b/src_py/nerlPlanner/ErlHeadersExporter.py index ba697cee..ea7087bc 100644 --- a/src_py/nerlPlanner/ErlHeadersExporter.py +++ b/src_py/nerlPlanner/ErlHeadersExporter.py @@ -1,6 +1,7 @@ import argparse import os from ErlHeadersExporterDefs import * +from JsonDistributedConfigDefs import * from JsonElementWorkerDefinitions import * from Definitions import VERSION as NERLPLANNER_VERSION @@ -11,6 +12,10 @@ def gen_erlang_exporter_logger(message : str): if DEBUG: print(f'[NERLPLANNER][AUTO_HEADER_GENERATOR][DEBUG] {message}') +def path_validator(path : str): + if os.path.dirname(path): + os.makedirs(os.path.dirname(path), exist_ok=True) + def gen_worker_fields_hrl(header_path : str, debug : bool = False): global DEBUG DEBUG = debug @@ -29,13 +34,13 @@ def gen_worker_fields_hrl(header_path : str, debug : bool = False): 'KEY_LAYER_TYPES_LIST', 'KEY_LAYERS_FUNCTIONS', 'KEY_LOSS_METHOD', 'KEY_LEARNING_RATE', 'KEY_EPOCHS', 'KEY_OPTIMIZER_TYPE'] + fields_list_strs = [f'WORKER_{x}' for x in fields_list_strs] fields_list_defs = [ Definition(fields_list_strs[idx], f'{Definition.assert_not_atom(fields_list_vals[idx])}') for idx in range(len(fields_list_vals))] [gen_erlang_exporter_logger(x.generate_code()) for x in fields_list_defs] - if os.path.dirname(header_path): - os.makedirs(os.path.dirname(header_path), exist_ok=True) + path_validator(header_path) with open(header_path, 'w') as f: f.write(auto_generated_header.generate_code()) @@ -53,16 +58,50 @@ def gen_dc_fields_hrl(header_path : str, debug : bool = False): nerlplanner_version = Comment(f'Generated by Nerlplanner version: {NERLPLANNER_VERSION}') gen_erlang_exporter_logger(nerlplanner_version.generate_code()) - #TODO + fields_list_vals_atoms = [KEY_NERLNET_SETTINGS, KEY_FREQUENCY, KEY_BATCH_SIZE, + KEY_DEVICES, KEY_CLIENTS, KEY_WORKERS, KEY_MODEL_SHA, + KEY_SOURCES, KEY_ROUTERS, NAME_FIELD, WORKER_MODEL_SHA_FIELD, + IPV4_FIELD, PORT_FIELD, ARGS_FIELD, ENTITIES_FIELD, + POLICY_FIELD, EPOCHS_FIELD, TYPE_FIELD, FREQUENCY_FIELD, + WORKERS_FIELD] + fields_list_vals_strs = [f'"{x}"' for x in fields_list_vals_atoms] + fields_list_strs = ['KEY_NERLNET_SETTINGS', 'KEY_FREQUENCY', 'KEY_BATCH_SIZE', + 'KEY_DEVICES', 'KEY_CLIENTS', 'KEY_WORKERS', 'KEY_MODEL_SHA', + 'KEY_SOURCES', 'KEY_ROUTERS', 'NAME_FIELD', 'WORKER_MODEL_SHA_FIELD', + 'IPV4_FIELD', 'PORT_FIELD', 'ARGS_FIELD', 'ENTITIES_FIELD', + 'POLICY_FIELD', 'EPOCHS_FIELD', 'TYPE_FIELD', 'FREQUENCY_FIELD', + 'WORKERS_FIELD'] + fields_list_strs_atom = [f'DC_{x}_ATOM' for x in fields_list_strs] + fields_list_strs_string = [f'DC_{x}_STR' for x in fields_list_strs] + + fields_list_defs_atoms = [ Definition(fields_list_strs_atom[idx], f'{fields_list_vals_atoms[idx]}') for idx in range(len(fields_list_strs))] + [gen_erlang_exporter_logger(x.generate_code()) for x in fields_list_defs_atoms] + + fields_list_defs_strings = [ Definition(fields_list_strs_string[idx], f'{fields_list_vals_strs[idx]}') for idx in range(len(fields_list_strs))] + [gen_erlang_exporter_logger(x.generate_code()) for x in fields_list_defs_strings] + + path_validator(header_path) + with open(header_path, 'w') as f: + f.write(auto_generated_header.generate_code()) + f.write(nerlplanner_version.generate_code()) + f.write(EMPTY_LINE) + [f.write(x.generate_code()) for x in fields_list_defs_atoms] + f.write(EMPTY_LINE) + [f.write(x.generate_code()) for x in fields_list_defs_strings] def main(): parser = argparse.ArgumentParser(description='Generate C++ header file for nerlPlanner') parser.add_argument('-o', '--output', help='output header file path', required=True) parser.add_argument('-d', '--debug', help='debug mode', action='store_true') + parser.add_argument('--gen_worker_fields_hrl', help='debug mode', action='store_true') + parser.add_argument('--gen_dc_fields_hrl', help='debug mode', action='store_true') + args = parser.parse_args() - gen_worker_fields_hrl(args.output, args.debug) - gen_dc_fields_hrl(args.output, args.debug) + if args.gen_worker_fields_hrl: + gen_worker_fields_hrl(args.output, args.debug) + if args.gen_dc_fields_hrl: + gen_dc_fields_hrl(args.output, args.debug) if __name__=="__main__": main() diff --git a/src_py/nerlPlanner/JsonDistributedConfigDefs.py b/src_py/nerlPlanner/JsonDistributedConfigDefs.py index f80e66d3..0a9fde15 100644 --- a/src_py/nerlPlanner/JsonDistributedConfigDefs.py +++ b/src_py/nerlPlanner/JsonDistributedConfigDefs.py @@ -1,13 +1,26 @@ +# Any change of this file influences autogenerated Erlang files +# Please change version of Nerlnet Planner if this file is changed -KEY_NERLNET_SETTINGS = "NerlNetSettings" +# The following definitions are also treated as atoms in Erlang +# Definition must start with lower case letter +KEY_NERLNET_SETTINGS = "nerlnetSettings" KEY_FREQUENCY = "frequency" KEY_BATCH_SIZE = "batchSize" KEY_DEVICES = "devices" KEY_CLIENTS = "clients" KEY_WORKERS = "workers" -KEY_MODEL_SHA = "model-sha" +KEY_MODEL_SHA = "model_sha" KEY_SOURCES = "sources" KEY_ROUTERS = "routers" NAME_FIELD = "name" -WORKER_MODEL_SHA_FIELD = "model-sha" \ No newline at end of file +WORKER_MODEL_SHA_FIELD = "model_sha" +IPV4_FIELD = "ipv4" +PORT_FIELD = "port" +ARGS_FIELD = "args" +ENTITIES_FIELD = "entities" +POLICY_FIELD = "policy" +EPOCHS_FIELD = "epochs" +TYPE_FIELD = "type" +FREQUENCY_FIELD = "frequency" +WORKERS_FIELD = "workers" diff --git a/tests/inputJsonsFiles/dc_test_synt_1d_2c_1s_4r_4w.json b/tests/inputJsonsFiles/dc_test_synt_1d_2c_1s_4r_4w.json index 1a75cb00..9ae6b45c 100644 --- a/tests/inputJsonsFiles/dc_test_synt_1d_2c_1s_4r_4w.json +++ b/tests/inputJsonsFiles/dc_test_synt_1d_2c_1s_4r_4w.json @@ -1,5 +1,5 @@ { - "NerlNetSettings": { + "nerlnetSettings": { "frequency": "60", "batchSize": "50" }, @@ -65,22 +65,22 @@ "workers": [ { "name": "w1", - "model-sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c" + "model_sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c" }, { "name": "w2", - "model-sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c" + "model_sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c" }, { "name": "w3", - "model-sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c" + "model_sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c" }, { "name": "w4", - "model-sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c" + "model_sha": "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c" } ], - "model-sha": { + "model_sha": { "5396cc8dbd1407c408021a16bb4e014e780de95cdf62680c19cac81139ad791c": { "modelType": "5", "_doc_modelType": " approximation:1 | classification:2 | forecasting:3 | encoder_decoder:4 | nn:5 | autoencoder:6 | ae-classifier:7 | fed-client:8 | fed-server:9 |",