-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
180 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
#include "NerlWorkerTorch.h" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#pragma once | ||
|
||
#include <cassert> | ||
#include <Logger.h> | ||
|
||
#include "../opennn/opennn/opennn.h" | ||
#include "../common/nerlWorker.h" | ||
#include "worker_definitions_ag.h" | ||
|
||
|
||
namespace nerlnet | ||
{ | ||
|
||
class NerlWorkerTorch : public NerlWorker | ||
{ | ||
|
||
}; | ||
|
||
} // namespace nerlnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#pragma once | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#include <nerltensor.h> | ||
#include <torch/torch.h> | ||
|
||
namespace nerlnet | ||
{ | ||
|
||
using TorchTensor = torch::Tensor; | ||
|
||
enum {DIMS_CASE_1D,DIMS_CASE_2D,DIMS_CASE_3D}; | ||
enum {DIMS_X_IDX,DIMS_Y_IDX,DIMS_Z_IDX,DIMS_TOTAL}; | ||
|
||
} // namespace nerlnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
#pragma once | ||
|
||
#include "nerltensorTorchDefs.h" | ||
#include "nifpp.h" | ||
|
||
namespace nifpp | ||
{ | ||
using namespace nerlnet; | ||
|
||
struct nerltensor_dims | ||
{ | ||
int dimx; | ||
int dimy; | ||
int dimz; | ||
int total_size; | ||
int dims_case; | ||
}; | ||
|
||
// Declarations | ||
template<typename BasicType> int get_nerltensor_dims(ErlNifEnv *env , ERL_NIF_TERM bin_term, nerltensor_dims &dims_info); | ||
template<typename BasicType> int get_nerltensor(ErlNifEnv *env , ERL_NIF_TERM bin_term, TorchTensor &tensor, torch::ScalarType torch_dtype); | ||
|
||
// Definitions | ||
template<typename BasicType> int get_nerltensor_dims(ErlNifEnv *env , ERL_NIF_TERM bin_term, nerltensor_dims &dims_info) | ||
{ | ||
ErlNifBinary bin; | ||
int ret = enif_inspect_binary(env, bin_term, &bin); | ||
assert(ret != 0); | ||
|
||
std::vector<BasicType> dims; | ||
// extract dims and data size | ||
dims.resize(DIMS_TOTAL); | ||
memcpy(dims.data(), bin.data, DIMS_TOTAL * sizeof(BasicType)); | ||
|
||
dims_info.total_size = 1; | ||
for (int i=0; i < DIMS_TOTAL; i++) | ||
{ | ||
dims_info.total_size *= dims[i]; | ||
if (dims[i] > 1) | ||
{ | ||
dims_info.dims_case = i; | ||
} | ||
} | ||
assert(("Negative Or zero value of dimension", dims_info.total_size > 0)); | ||
|
||
|
||
dims_info.dimx = static_cast<int>(dims[DIMS_X_IDX]); | ||
dims_info.dimy = static_cast<int>(dims[DIMS_Y_IDX]); | ||
dims_info.dimz = static_cast<int>(dims[DIMS_Z_IDX]); | ||
} | ||
|
||
|
||
template<typename BasicType> int get_nerltensor(ErlNifEnv *env , ERL_NIF_TERM bin_term, TorchTensor &tensor, torch::ScalarType torch_dtype) | ||
{ | ||
ErlNifBinary bin; | ||
int ret = enif_inspect_binary(env, bin_term, &bin); | ||
assert(ret != 0); | ||
|
||
// extract dims and data size | ||
nerltensor_dims dims_info; | ||
get_nerltensor_dims<BasicType>(env, bin_term, dims_info); | ||
|
||
switch (dims_info.dims_case) | ||
{ | ||
case DIMS_CASE_1D: | ||
{ | ||
tensor = torch::zeros(dims_info.dimx, torch_dtype); | ||
break; | ||
} | ||
case DIMS_CASE_2D: | ||
{ | ||
tensor = torch::zeros({dims_info.dimx, dims_info.dimy}, torch_dtype); | ||
break; | ||
} | ||
case DIMS_CASE_3D: | ||
{ | ||
tensor = torch::zeros({dims_info.dimx, dims_info.dimy, dims_info.dimz}, torch_dtype); | ||
break; | ||
} | ||
} | ||
|
||
assert((sizeof(BasicType) == tensor.element_size(), "Size of BasicType and torch tensor element size mismatch")); | ||
|
||
// copy data from nerltensor to torch tensor | ||
int skip_dims_bytes = (DIMS_TOTAL * sizeof(BasicType)); | ||
std::memcpy(tensor.data_ptr(),bin.data + skip_dims_bytes, sizeof(BasicType)*tensor.numel()); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
#include "torchNIF.h" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,5 @@ | ||
#include <torch/torch.h> | ||
#pragma once | ||
|
||
#include "nifppNerlTensorTorch.h" | ||
#include "nifpp.h" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
|
||
-define(NERLNET_LIB_PATH,"/usr/local/lib/nerlnet-lib/NErlNet"). | ||
-include("/usr/local/lib/nerlnet-lib/NErlNet/src_erl/NerlnetApp/src/Bridge/nerlTensor.hrl"). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
-module(torchNIF). | ||
|
||
-include_lib("kernel/include/logger.hrl"). | ||
-include("torchDefs.hrl"). | ||
|
||
-author("David Leon"). | ||
|
||
|
||
|
||
% -export([init/0,nif_preload/0,get_active_models_ids_list/0, train_nif/3,update_nerlworker_train_params_nif/6,call_to_train/5,predict_nif/3,call_to_predict/5,get_weights_nif/1,printTensor/2]). | ||
% -export([call_to_get_weights/1,call_to_set_weights/2]). | ||
% -export([decode_nif/2, nerltensor_binary_decode/2]). | ||
% -export([encode_nif/2, nerltensor_encode/5, nerltensor_conversion/2, get_all_binary_types/0, get_all_nerltensor_list_types/0]). | ||
% -export([erl_type_conversion/1]). | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
-module(torchTests). | ||
-author("David Leon"). | ||
|
||
-include_lib("kernel/include/logger.hrl"). | ||
-include("torchTestsDefs.hrl"). | ||
|
||
-define(NERLTEST_PRINT_STR, "[NERLTEST] "). | ||
|
||
-export([run_tests/0]). | ||
|
||
nerltest_print(String) -> | ||
logger:notice(?NERLTEST_PRINT_STR++String). | ||
|
||
test_envelope(Func, TestName, Rounds) -> | ||
nerltest_print(nerl:string_format("~p test starts for ~p rounds",[TestName, Rounds])), | ||
{TimeTookMicro, _RetVal} = timer:tc(Func, [Rounds]), | ||
nerltest_print(nerl:string_format("Elapsed: ~p~p",[TimeTookMicro / 1000, ms])), ok. | ||
|
||
test_envelope_nif_performance(Func, TestName, Rounds) -> | ||
nerltest_print(nerl:string_format("~p test starts for ~p rounds",[TestName, Rounds])), | ||
{TimeTookMicro, AccPerfromance} = timer:tc(Func, [Rounds]), | ||
AveragedPerformance = AccPerfromance/Rounds, | ||
nerltest_print(nerl:string_format("Elapsed: ~p~p Average nif performance: ~.3f~p",[TimeTookMicro/1000,ms, AveragedPerformance, ms])), ok. | ||
|
||
run_tests()-> | ||
nerl:logger_settings(nerlTests). |
2 changes: 2 additions & 0 deletions
2
src_erl/NerlnetApp/src/Bridge/torchWorkers/torchTestsDefs.hrl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
-include("/usr/local/lib/nerlnet-lib/NErlNet/src_erl/NerlnetApp/src/Bridge/neural_networks_testing_models.hrl"). |