Skip to content

Commit

Permalink
[TORCH] Nifpp impelmentation of get
Browse files Browse the repository at this point in the history
  • Loading branch information
leondavi committed Jul 4, 2024
1 parent 57f5753 commit 7fbf81d
Show file tree
Hide file tree
Showing 12 changed files with 180 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src_cpp/torchBridge/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,13 @@ add_definitions( -D LOGGER_ENABLE_COLORS=1 )
add_definitions( -D LOGGER_ENABLE_COLORS_ON_USER_HEADER=0 )

set(SRC_CODE
"nerltensorTorchDefs.h"
"nifppNerlTensorTorch.h"
"torchNIF.h"
"torchNIF.cpp"
"NerlWorkerTorch.h"
"NerlWorkerTorch.cpp"
"NerlWorkerTorchNIF.h"
)

add_library(${PROJECT_NAME} SHARED ${SRC_CODE})
Expand Down
1 change: 1 addition & 0 deletions src_cpp/torchBridge/NerlWorkerTorch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include "NerlWorkerTorch.h"
19 changes: 19 additions & 0 deletions src_cpp/torchBridge/NerlWorkerTorch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include <cassert>
#include <Logger.h>

#include "../opennn/opennn/opennn.h"
#include "../common/nerlWorker.h"
#include "worker_definitions_ag.h"


namespace nerlnet
{

class NerlWorkerTorch : public NerlWorker
{

};

} // namespace nerlnet
2 changes: 2 additions & 0 deletions src_cpp/torchBridge/NerlWorkerTorchNIF.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#pragma once

12 changes: 12 additions & 0 deletions src_cpp/torchBridge/nerltensorTorchDefs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#include <nerltensor.h>
#include <torch/torch.h>

namespace nerlnet
{

using TorchTensor = torch::Tensor;

enum {DIMS_CASE_1D,DIMS_CASE_2D,DIMS_CASE_3D};
enum {DIMS_X_IDX,DIMS_Y_IDX,DIMS_Z_IDX,DIMS_TOTAL};

} // namespace nerlnet
89 changes: 89 additions & 0 deletions src_cpp/torchBridge/nifppNerlTensorTorch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#pragma once

#include "nerltensorTorchDefs.h"
#include "nifpp.h"

namespace nifpp
{
using namespace nerlnet;

struct nerltensor_dims
{
int dimx;
int dimy;
int dimz;
int total_size;
int dims_case;
};

// Declarations
template<typename BasicType> int get_nerltensor_dims(ErlNifEnv *env , ERL_NIF_TERM bin_term, nerltensor_dims &dims_info);
template<typename BasicType> int get_nerltensor(ErlNifEnv *env , ERL_NIF_TERM bin_term, TorchTensor &tensor, torch::ScalarType torch_dtype);

// Definitions
template<typename BasicType> int get_nerltensor_dims(ErlNifEnv *env , ERL_NIF_TERM bin_term, nerltensor_dims &dims_info)
{
ErlNifBinary bin;
int ret = enif_inspect_binary(env, bin_term, &bin);
assert(ret != 0);

std::vector<BasicType> dims;
// extract dims and data size
dims.resize(DIMS_TOTAL);
memcpy(dims.data(), bin.data, DIMS_TOTAL * sizeof(BasicType));

dims_info.total_size = 1;
for (int i=0; i < DIMS_TOTAL; i++)
{
dims_info.total_size *= dims[i];
if (dims[i] > 1)
{
dims_info.dims_case = i;
}
}
assert(("Negative Or zero value of dimension", dims_info.total_size > 0));


dims_info.dimx = static_cast<int>(dims[DIMS_X_IDX]);
dims_info.dimy = static_cast<int>(dims[DIMS_Y_IDX]);
dims_info.dimz = static_cast<int>(dims[DIMS_Z_IDX]);
}


template<typename BasicType> int get_nerltensor(ErlNifEnv *env , ERL_NIF_TERM bin_term, TorchTensor &tensor, torch::ScalarType torch_dtype)
{
ErlNifBinary bin;
int ret = enif_inspect_binary(env, bin_term, &bin);
assert(ret != 0);

// extract dims and data size
nerltensor_dims dims_info;
get_nerltensor_dims<BasicType>(env, bin_term, dims_info);

switch (dims_info.dims_case)
{
case DIMS_CASE_1D:
{
tensor = torch::zeros(dims_info.dimx, torch_dtype);
break;
}
case DIMS_CASE_2D:
{
tensor = torch::zeros({dims_info.dimx, dims_info.dimy}, torch_dtype);
break;
}
case DIMS_CASE_3D:
{
tensor = torch::zeros({dims_info.dimx, dims_info.dimy, dims_info.dimz}, torch_dtype);
break;
}
}

assert((sizeof(BasicType) == tensor.element_size(), "Size of BasicType and torch tensor element size mismatch"));

// copy data from nerltensor to torch tensor
int skip_dims_bytes = (DIMS_TOTAL * sizeof(BasicType));
std::memcpy(tensor.data_ptr(),bin.data + skip_dims_bytes, sizeof(BasicType)*tensor.numel());
}

}
1 change: 1 addition & 0 deletions src_cpp/torchBridge/torchNIF.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include "torchNIF.h"
6 changes: 5 additions & 1 deletion src_cpp/torchBridge/torchNIF.h
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
#include <torch/torch.h>
#pragma once

#include "nifppNerlTensorTorch.h"
#include "nifpp.h"

3 changes: 3 additions & 0 deletions src_erl/NerlnetApp/src/Bridge/torchWorkers/torchDefs.hrl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

-define(NERLNET_LIB_PATH,"/usr/local/lib/nerlnet-lib/NErlNet").
-include("/usr/local/lib/nerlnet-lib/NErlNet/src_erl/NerlnetApp/src/Bridge/nerlTensor.hrl").
15 changes: 15 additions & 0 deletions src_erl/NerlnetApp/src/Bridge/torchWorkers/torchNIF.erl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-module(torchNIF).

-include_lib("kernel/include/logger.hrl").
-include("torchDefs.hrl").

-author("David Leon").



% -export([init/0,nif_preload/0,get_active_models_ids_list/0, train_nif/3,update_nerlworker_train_params_nif/6,call_to_train/5,predict_nif/3,call_to_predict/5,get_weights_nif/1,printTensor/2]).
% -export([call_to_get_weights/1,call_to_set_weights/2]).
% -export([decode_nif/2, nerltensor_binary_decode/2]).
% -export([encode_nif/2, nerltensor_encode/5, nerltensor_conversion/2, get_all_binary_types/0, get_all_nerltensor_list_types/0]).
% -export([erl_type_conversion/1]).

26 changes: 26 additions & 0 deletions src_erl/NerlnetApp/src/Bridge/torchWorkers/torchTests.erl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
-module(torchTests).
-author("David Leon").

-include_lib("kernel/include/logger.hrl").
-include("torchTestsDefs.hrl").

-define(NERLTEST_PRINT_STR, "[NERLTEST] ").

-export([run_tests/0]).

nerltest_print(String) ->
logger:notice(?NERLTEST_PRINT_STR++String).

test_envelope(Func, TestName, Rounds) ->
nerltest_print(nerl:string_format("~p test starts for ~p rounds",[TestName, Rounds])),
{TimeTookMicro, _RetVal} = timer:tc(Func, [Rounds]),
nerltest_print(nerl:string_format("Elapsed: ~p~p",[TimeTookMicro / 1000, ms])), ok.

test_envelope_nif_performance(Func, TestName, Rounds) ->
nerltest_print(nerl:string_format("~p test starts for ~p rounds",[TestName, Rounds])),
{TimeTookMicro, AccPerfromance} = timer:tc(Func, [Rounds]),
AveragedPerformance = AccPerfromance/Rounds,
nerltest_print(nerl:string_format("Elapsed: ~p~p Average nif performance: ~.3f~p",[TimeTookMicro/1000,ms, AveragedPerformance, ms])), ok.

run_tests()->
nerl:logger_settings(nerlTests).
2 changes: 2 additions & 0 deletions src_erl/NerlnetApp/src/Bridge/torchWorkers/torchTestsDefs.hrl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

-include("/usr/local/lib/nerlnet-lib/NErlNet/src_erl/NerlnetApp/src/Bridge/neural_networks_testing_models.hrl").

0 comments on commit 7fbf81d

Please sign in to comment.