Skip to content

Commit

Permalink
[TORCH] implement make_tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
leondavi committed Jul 5, 2024
1 parent 7fbf81d commit dc27e6c
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src_cpp/torchBridge/nifppNerlTensorTorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ namespace nifpp
// Declarations
template<typename BasicType> int get_nerltensor_dims(ErlNifEnv *env , ERL_NIF_TERM bin_term, nerltensor_dims &dims_info);
template<typename BasicType> int get_nerltensor(ErlNifEnv *env , ERL_NIF_TERM bin_term, TorchTensor &tensor, torch::ScalarType torch_dtype);
template<typename BasicType> void make_tensor(ErlNifEnv *env , nifpp::TERM &ret_bin_term, TorchTensor &tensor);


// Definitions
template<typename BasicType> int get_nerltensor_dims(ErlNifEnv *env , ERL_NIF_TERM bin_term, nerltensor_dims &dims_info)
Expand Down Expand Up @@ -86,4 +88,32 @@ namespace nifpp
std::memcpy(tensor.data_ptr(),bin.data + skip_dims_bytes, sizeof(BasicType)*tensor.numel());
}

template<typename BasicType> void make_tensor(ErlNifEnv *env , nifpp::TERM &ret_bin_term, TorchTensor &tensor)
{
std::vector<BasicType> dims;
dims.resize(DIMS_TOTAL);
for (int dim=0; dim < DIMS_TOTAL; dim++)
{
if (dim < tensor.sizes().Length())
{
dims[dim] = static_cast<BasicType>(tensor.sizes()[dim]);
}
else
{
dims[dim] = 1;
}
}
size_t dims_size = DIMS_TOTAL * sizeof(BasicType);
size_t data_size = tensor.numel() * sizeof(BasicType);

nifpp::binary nifpp_bin(dims_size + data_size);

assert((sizeof(BasicType) == tensor.element_size(), "Size of BasicType and torch tensor element size mismatch"));

std::memcpy(nifpp_bin.data, dims.data(), dims_size);
std::memcpy(nifpp_bin.data + dims_size, tensor.data_ptr(), data_size);

ret_bin_term = nifpp:make(env, nifpp_bin);
}

}

0 comments on commit dc27e6c

Please sign in to comment.