Skip to content

Commit

Permalink
Removing torch includes from pytorch runner/serdes headers
Browse files Browse the repository at this point in the history
  • Loading branch information
RajivChitale committed Feb 11, 2025
1 parent e6ae5a1 commit ba3563c
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 116 deletions.
10 changes: 3 additions & 7 deletions MLModelRunner/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ if(NOT PROTOS_DIRECTORY STREQUAL "")
add_subdirectory(gRPCModelRunner)
endif()
add_subdirectory(ONNXModelRunner)
add_subdirectory(PTModelRunner)
add_subdirectory(C)

# # For up-to-date instructions for installing the TFLite dependency, refer to
Expand Down Expand Up @@ -44,13 +45,8 @@ else()
add_library(ModelRunnerLib OBJECT MLModelRunner.cpp PipeModelRunner.cpp)
endif(LLVM_MLBRIDGE)

target_link_libraries(ModelRunnerLib PUBLIC ModelRunnerUtils ONNXModelRunnerLib)
target_link_libraries(ModelRunnerLib PUBLIC ModelRunnerUtils ONNXModelRunnerLib PTModelRunnerLib)

if(NOT PROTOS_DIRECTORY STREQUAL "")
target_link_libraries(ModelRunnerLib PUBLIC gRPCModelRunnerLib)
endif()
set_property(TARGET ModelRunnerLib PROPERTY POSITION_INDEPENDENT_CODE 1)

find_package(Torch REQUIRED)
target_link_libraries(ModelRunnerLib PRIVATE ${TORCH_LIBRARIES})
target_compile_options (ModelRunnerLib PRIVATE -fexceptions)
endif()
13 changes: 13 additions & 0 deletions MLModelRunner/PTModelRunner/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Build PTModelRunnerLib, the PyTorch (AOTInductor) backend for MLModelRunner.
if(LLVM_MLBRIDGE)
  add_llvm_library(PTModelRunnerLib PTModelRunner.cpp)
else()
  add_library(PTModelRunnerLib OBJECT PTModelRunner.cpp)
endif(LLVM_MLBRIDGE)

# Torch is linked and included PRIVATE so that dependents of PTModelRunnerLib
# do not inherit the libtorch link interface or its include directories.
find_package(Torch REQUIRED)
target_link_libraries(PTModelRunnerLib PRIVATE ${TORCH_LIBRARIES})
# libtorch headers use exceptions; enable them for this target only.
target_compile_options(PTModelRunnerLib PRIVATE -fexceptions)
target_include_directories(PTModelRunnerLib PRIVATE ${TORCH_INCLUDE_DIRS})
70 changes: 70 additions & 0 deletions MLModelRunner/PTModelRunner/PTModelRunner.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
//=== PTModelRunner.cpp - PTModelRunner Implementation ---*- C++ -*-===//
//
// Part of the MLCompilerBridge Project
//
//===------------------===//

#include "MLModelRunner/PTModelRunner.h"

#include "MLModelRunner/MLModelRunner.h"
#include "SerDes/TensorSpec.h"
// #include "SerDes/baseSerDes.h"
#include "SerDes/pytorchSerDes.h"
#include "llvm/Support/ErrorHandling.h"
#include <torch/torch.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h> // or model_container_runner_cuda.h for CUDA

#include <memory>
#include <vector>

using TensorVec = std::vector<torch::Tensor>;

namespace MLBridge
{

/// Constructs a PTModelRunner that evaluates an AOTInductor-compiled PyTorch
/// model loaded from `modelPath`.
/// NOTE(review): both `SerDes` and `CompiledModel` are allocated with raw
/// `new` and stored in (presumably) raw-pointer members -- confirm they are
/// released in the destructor, otherwise this leaks.
PTModelRunner::PTModelRunner(const std::string &modelPath, llvm::LLVMContext &Ctx)
: MLModelRunner(MLModelRunner::Kind::PTAOT, BaseSerDes::Kind::Pytorch, &Ctx)
{
this->SerDes = new PytorchSerDes();

// NOTE(review): this InferenceMode RAII guard only covers the remainder of
// the constructor body; it does NOT apply to later evaluateUntyped() calls.
// Confirm whether it was intended to wrap the model run() instead.
c10::InferenceMode mode;
this->CompiledModel = new torch::inductor::AOTIModelContainerRunnerCpu(modelPath);
}



void *PTModelRunner::evaluateUntyped()
{

if ((*reinterpret_cast<TensorVec*>(this->SerDes->getRequest())).empty())
{
llvm::errs() << "Input vector is empty.\n";
return nullptr;
}

try
{

std::vector<torch::Tensor> *outputTensors = reinterpret_cast<std::vector<torch::Tensor>*>(this->SerDes->getResponse());
auto outputs = reinterpret_cast<torch::inductor::AOTIModelContainerRunnerCpu*>(this->CompiledModel)->run((*reinterpret_cast<TensorVec*>(this->SerDes->getRequest())));
for (auto i = outputs.begin(); i != outputs.end(); ++i)
(*(outputTensors)).push_back(*i);
void *rawData = this->SerDes->deserializeUntyped(outputTensors);
return rawData;
}
catch (const c10::Error &e)
{
llvm::errs() << "Error during model evaluation: " << e.what() << "\n";
return nullptr;
}
}

/// Recursively stages each (name, value) feature pair into the SerDes via
/// setFeature, peeling one pair off the parameter pack per call.
/// NOTE(review): the recursion must terminate in a zero-argument
/// populateFeatures() overload that is not visible in this file -- presumably
/// declared in MLModelRunner/PTModelRunner.h; confirm it exists, otherwise
/// instantiation fails to compile for any non-empty pack.
template <typename U, typename T, typename... Types>
void PTModelRunner::populateFeatures(const std::pair<U, T> &var1,
const std::pair<U, Types> &...var2)
{
SerDes->setFeature(var1.first, var1.second);
PTModelRunner::populateFeatures(var2...);
}

} // namespace MLBridge
11 changes: 5 additions & 6 deletions SerDes/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
find_package(Torch REQUIRED)

add_subdirectory(pytorchSerDes)

set(protobuf_MODULE_COMPATIBLE TRUE)
find_package(Protobuf CONFIG REQUIRED)

Expand All @@ -9,18 +11,15 @@ TensorSpec.cpp
jsonSerDes.cpp
bitstreamSerDes.cpp
protobufSerDes.cpp
pytorchSerDes.cpp
tensorflowSerDes.cpp
JSON.cpp
)

else()
add_library(SerDesLib OBJECT TensorSpec.cpp jsonSerDes.cpp bitstreamSerDes.cpp protobufSerDes.cpp tensorflowSerDes.cpp JSON.cpp pytorchSerDes.cpp)

# add_library(SerDesCLib OBJECT TensorSpec.cpp jsonSerDes.cpp bitstreamSerDes.cpp JSON.cpp)
add_library(SerDesLib OBJECT TensorSpec.cpp jsonSerDes.cpp bitstreamSerDes.cpp protobufSerDes.cpp tensorflowSerDes.cpp JSON.cpp )
endif()

target_link_libraries(SerDesLib PRIVATE ${TORCH_LIBRARIES})
target_compile_options(SerDesLib PRIVATE -fexceptions)
target_link_libraries(SerDesLib PUBLIC PyTorchSerDesLib)

target_include_directories(SerDesLib PUBLIC ${TENSORFLOW_AOT_PATH}/include)
target_link_libraries(SerDesLib PRIVATE tf_xla_runtime)
19 changes: 19 additions & 0 deletions SerDes/pytorchSerDes/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Build PyTorchSerDesLib, the libtorch-backed serializer/deserializer,
# isolated here so that the rest of SerDes does not compile against torch.
find_package(Torch REQUIRED)

if(LLVM_MLBRIDGE)
  add_llvm_library(PyTorchSerDesLib
    pytorchSerDes.cpp
  )
else()
  add_library(PyTorchSerDesLib OBJECT pytorchSerDes.cpp)
endif()

# Link torch PRIVATE so dependents do not inherit the libtorch link interface.
target_link_libraries(PyTorchSerDesLib PRIVATE ${TORCH_LIBRARIES})
# libtorch headers use exceptions; enable them for this target only.
target_compile_options(PyTorchSerDesLib PRIVATE -fexceptions)
# Include dirs are PUBLIC -- TODO(review): confirm dependents still need torch
# headers after the header cleanup; otherwise make this PRIVATE as well.
target_include_directories(PyTorchSerDesLib PUBLIC ${TORCH_INCLUDE_DIRS})
75 changes: 50 additions & 25 deletions SerDes/pytorchSerDes.cpp → SerDes/pytorchSerDes/pytorchSerDes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,59 +6,72 @@

#include "SerDes/pytorchSerDes.h"
#include "SerDes/baseSerDes.h"
#include <torch/torch.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>

using TensorVec = std::vector<torch::Tensor>;

namespace MLBridge {

PytorchSerDes::PytorchSerDes() : BaseSerDes(BaseSerDes::Kind::Pytorch) {
// inputTensors = std::make_shared<std::vector<torch::Tensor>>();
// outputTensors = new std::vector<torch::Tensor>();

// RequestVoid = std::make_shared<std::vector<torch::Tensor>>();
RequestVoid = new TensorVec();
ResponseVoid = new TensorVec();
}

void PytorchSerDes::setFeature(const std::string &Name, const int Value) {
auto tensor = torch::tensor({Value}, torch::kInt32);
this->inputTensors->push_back(tensor.clone());
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
}

void PytorchSerDes::setFeature(const std::string &Name, const long Value) {
auto tensor = torch::tensor({Value}, torch::kInt64);
this->inputTensors->push_back(tensor.clone());
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
}

void PytorchSerDes::setFeature(const std::string &Name, const float Value) {
auto tensor = torch::tensor({Value}, torch::kFloat32);
this->inputTensors->push_back(tensor.clone());
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
}

void PytorchSerDes::setFeature(const std::string &Name, const double Value) {
auto tensor = torch::tensor({Value}, torch::kFloat64);
this->inputTensors->push_back(tensor.clone());
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
}

void PytorchSerDes::setFeature(const std::string &Name, const std::string Value) {
std::vector<int8_t> encoded_str(Value.begin(), Value.end());
auto tensor = torch::tensor(encoded_str, torch::kInt8);
this->inputTensors->push_back(tensor.clone());
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
}

void PytorchSerDes::setFeature(const std::string &Name, const bool Value) {
auto tensor = torch::tensor({Value}, torch::kBool);
this->inputTensors->push_back(tensor.clone());
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
}

void PytorchSerDes::setFeature(const std::string &Name, const std::vector<int> &Value) {
auto tensor = torch::tensor(Value, torch::kInt32);
this->inputTensors->push_back(tensor.clone());
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
}

void PytorchSerDes::setFeature(const std::string &Name, const std::vector<long> &Value) {
auto tensor = torch::tensor(Value, torch::kInt64);
this->inputTensors->push_back(tensor.clone());
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
}

void PytorchSerDes::setFeature(const std::string &Name, const std::vector<float> &Value) {
auto tensor = torch::tensor(Value, torch::kFloat32);
tensor = tensor.reshape({1, Value.size()});
inputTensors->push_back(tensor.clone());
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
}

void PytorchSerDes::setFeature(const std::string &Name, const std::vector<double> &Value) {
auto tensor = torch::tensor(Value, torch::kFloat64);
this->inputTensors->push_back(tensor.clone());
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
}

void PytorchSerDes::setFeature(const std::string &Name, const std::vector<std::string> &Value) {
Expand All @@ -68,21 +81,21 @@ void PytorchSerDes::setFeature(const std::string &Name, const std::vector<std::s
flat_vec.push_back('\0'); // Null-terminate each string
}
auto tensor = torch::tensor(flat_vec, torch::kInt8);
this->inputTensors->push_back(tensor.clone());
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
}

void PytorchSerDes::setFeature(const std::string &Name, const std::vector<bool> &Value) {
std::vector<uint8_t> bool_vec(Value.begin(), Value.end());
auto tensor = torch::tensor(bool_vec, torch::kUInt8);
this->inputTensors->push_back(tensor.clone());
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
}

// void PytorchSerDes::setRequest(void *Request) {
// CompiledModel = reinterpret_cast<torch::inductor::AOTIModelContainerRunnerCpu *>(Request);
// }

void PytorchSerDes::cleanDataStructures() {
this->inputTensors->clear(); // Clear the input vector
reinterpret_cast<TensorVec*>(this->RequestVoid)->clear(); // Clear the input vector
}

void *PytorchSerDes::deserializeUntyped(void *Data) {
Expand All @@ -91,7 +104,7 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
}

// Assume Data is a pointer to a vector of tensors
std::vector<torch::Tensor> *serializedTensors = reinterpret_cast<std::vector<torch::Tensor> *>(Data);
std::vector<torch::Tensor> *serializedTensors = reinterpret_cast<TensorVec *>(Data);

if (serializedTensors->empty()) {
return nullptr;
Expand All @@ -100,22 +113,22 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
auto type_vect = serializedTensors->at(0).dtype();

if (type_vect == torch::kInt32) {
return copyTensorToVect<int32_t>(serializedTensors);
return copyTensorToVect<int32_t>(Data);
}
else if (type_vect == torch::kInt64) {
return copyTensorToVect<int64_t>(serializedTensors);
return copyTensorToVect<int64_t>(Data);
}
else if (type_vect == torch::kFloat32) {
return copyTensorToVect<float>(serializedTensors);
return copyTensorToVect<float>(Data);
}
else if (type_vect == torch::kFloat64) {
return copyTensorToVect<double>(serializedTensors);
return copyTensorToVect<double>(Data);
}
else if (type_vect == torch::kBool) {
return copyTensorToVect<bool>(serializedTensors);
return copyTensorToVect<bool>(Data);
}
else if (type_vect == torch::kInt8) {
return copyTensorToVect<int8_t>(serializedTensors);
return copyTensorToVect<int8_t>(Data);
}
else {
llvm::errs() << "Unsupported tensor dtype.\n";
Expand All @@ -124,13 +137,25 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
}

void *PytorchSerDes::getSerializedData() {
std::vector<torch::Tensor> serializedData = *(this->outputTensors);
return this->ResponseVoid; // TODO - check
// TensorVec serializedData = *reinterpret_cast<TensorVec*>(this->ResponseVoid);

// Allocate memory for the output and copy the serialized data
auto *output = new std::vector<torch::Tensor>(serializedData);
return static_cast<void *>(output);
// // Allocate memory for the output and copy the serialized data
// auto *output = new TensorVec(serializedData);
// return static_cast<void *>(output);
}

/// Flattens every tensor in the given tensor vector (passed as an opaque
/// void* holding a TensorVec*) into one newly-allocated contiguous
/// std::vector<T>. The caller takes ownership of the returned vector.
/// NOTE(review): tensor.data_ptr<T>() assumes each tensor is contiguous and
/// of dtype T -- confirm callers (deserializeUntyped) guarantee this.
template <typename T>
std::vector<T> *PytorchSerDes::copyTensorToVect(void *serializedTensors) {
auto *ret = new std::vector<T>();
for (const auto &tensor : *reinterpret_cast<TensorVec*>(serializedTensors)) {
ret->insert(ret->end(), tensor.data_ptr<T>(), tensor.data_ptr<T>() + tensor.numel());
}
return ret;
}

// Accessors for the opaque request/response buffers (actually TensorVec*),
// exposed as void* so the header needs no torch includes (per this commit).
void *PytorchSerDes::getRequest() { return this->RequestVoid; }
void *PytorchSerDes::getResponse() { return this->ResponseVoid; }


} // namespace MLBridge
} // namespace MLBridge
Loading

0 comments on commit ba3563c

Please sign in to comment.