Skip to content

Commit

Permalink
Build light weight PyRuntime without llvm or onnx-mlir (onnx#3044)
Browse files Browse the repository at this point in the history
* pass test

Signed-off-by: Chen Tong <[email protected]>

* package

Signed-off-by: Chen Tong <[email protected]>

* clean makefile

Signed-off-by: Chen Tong <[email protected]>

* document

Signed-off-by: Chen Tong <[email protected]>

* fix MLIR.cmake

Signed-off-by: Chen Tong <[email protected]>

* fix script

Signed-off-by: Chen Tong <[email protected]>

* fix

Signed-off-by: Chen Tong <[email protected]>

* add comments

Signed-off-by: Chen Tong <[email protected]>

* LIGHT

Signed-off-by: Chen Tong <[email protected]>

---------

Signed-off-by: Chen Tong <[email protected]>
  • Loading branch information
chentong319 authored and christopherlmunoz committed Jan 30, 2025
1 parent 9d8898b commit e24fd5c
Show file tree
Hide file tree
Showing 23 changed files with 786 additions and 61 deletions.
57 changes: 35 additions & 22 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ option(ONNX_MLIR_ENABLE_STABLEHLO "Enable StableHLO support." ON)
option(ONNX_MLIR_ENABLE_WERROR "Enable warnings as errors." OFF)
option(ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS "Suppress warning in third_party code." ON)
option(ONNX_MLIR_ENABLE_JAVA "Set to ON for building the Java runtime, tools, and tests" ON)
option(ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT "Set to ON for building Python driver of running the compiled model without llvm-project." OFF)

set(CMAKE_CXX_STANDARD 17)

Expand Down Expand Up @@ -73,8 +74,10 @@ set(ONNX_MLIR_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
set(ONNX_MLIR_VENDOR ${PACKAGE_VENDOR} CACHE STRING
"Vendor-specific text for showing with version information.")

include(CTest)
include(ExternalProject)
if(NOT ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
include(CTest)
include(ExternalProject)
endif()
include(MLIR.cmake)

# MLIR.cmake calls find_package(MLIR) which sets LLVM_MINIMUM_PYTHON_VERSION
Expand Down Expand Up @@ -159,23 +162,29 @@ endif()
set(CMAKE_MESSAGE_LOG_LEVEL NOTICE)

# Add third party subdirectories and define options appropriate to run their cmakes.
set(pybind11_FIND_QUIETLY ON)
add_subdirectory(third_party/onnx)
add_subdirectory(third_party/pybind11)
add_subdirectory(third_party/rapidcheck)
if (ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
add_subdirectory(third_party/onnx)
add_subdirectory(third_party/pybind11)
else()
set(pybind11_FIND_QUIETLY ON)
add_subdirectory(third_party/onnx)
add_subdirectory(third_party/pybind11)

if (ONNX_MLIR_ENABLE_STABLEHLO)
add_subdirectory(third_party/stablehlo EXCLUDE_FROM_ALL)
endif()
add_subdirectory(third_party/rapidcheck)

if (NOT TARGET benchmark)
set(BENCHMARK_USE_BUNDLED_GTEST OFF)
set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
set(BENCHMARK_ENABLE_TESTING OFF)
set(BENCHMARK_ENABLE_WERROR OFF)
# Since LLVM requires C++11 (or higher) it is safe to assume that std::regex is available.
set(HAVE_STD_REGEX ON CACHE BOOL "OK" FORCE)
add_subdirectory(third_party/benchmark)
if (ONNX_MLIR_ENABLE_STABLEHLO)
add_subdirectory(third_party/stablehlo EXCLUDE_FROM_ALL)
endif()

if (NOT TARGET benchmark)
set(BENCHMARK_USE_BUNDLED_GTEST OFF)
set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
set(BENCHMARK_ENABLE_TESTING OFF)
set(BENCHMARK_ENABLE_WERROR OFF)
# Since LLVM requires C++11 (or higher) it is safe to assume that std::regex is available.
set(HAVE_STD_REGEX ON CACHE BOOL "OK" FORCE)
add_subdirectory(third_party/benchmark)
endif()
endif()

# All libraries and executables coming from llvm or ONNX-MLIR have had their
Expand Down Expand Up @@ -207,8 +216,12 @@ if (ONNX_MLIR_ENABLE_STABLEHLO)
add_compile_definitions(ONNX_MLIR_ENABLE_STABLEHLO)
endif()

add_subdirectory(utils)
add_subdirectory(include)
add_subdirectory(src)
add_subdirectory(docs)
add_subdirectory(test)
if (ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
add_subdirectory(src)
else()
add_subdirectory(utils)
add_subdirectory(include)
add_subdirectory(src)
add_subdirectory(docs)
add_subdirectory(test)
endif()
54 changes: 32 additions & 22 deletions MLIR.cmake
Original file line number Diff line number Diff line change
@@ -1,33 +1,41 @@
# SPDX-License-Identifier: Apache-2.0

# Must unset LLVM_DIR in cache. Otherwise, when MLIR_DIR changes LLVM_DIR
# won't change accordingly.
unset(LLVM_DIR CACHE)
if (NOT DEFINED MLIR_DIR)
message(FATAL_ERROR "MLIR_DIR is not configured but it is required. "
"Set the cmake option MLIR_DIR, e.g.,\n"
" cmake -DMLIR_DIR=/path/to/llvm-project/build/lib/cmake/mlir ..\n"
)
endif()
if (ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
# No-op stand-in for llvm_update_compile_flags(<name>).
# This function is defined in llvm_project.
# Define a dummy function for PYRUNTIME_LIGHT, where the llvm_project CMake
# modules are not loaded; the `name` argument is accepted and ignored, so
# callers of this function keep working but no compile flags are changed.
# If needed, the definition from llvm_project can be copied.
function(llvm_update_compile_flags name)
endfunction()
else()
# Must unset LLVM_DIR in cache. Otherwise, when MLIR_DIR changes LLVM_DIR
# won't change accordingly.
unset(LLVM_DIR CACHE)
if (NOT DEFINED MLIR_DIR)
message(FATAL_ERROR "MLIR_DIR is not configured but it is required. "
"Set the cmake option MLIR_DIR, e.g.,\n"
" cmake -DMLIR_DIR=/path/to/llvm-project/build/lib/cmake/mlir ..\n"
)
endif()

find_package(MLIR REQUIRED CONFIG)
find_package(MLIR REQUIRED CONFIG)

message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")

list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")

include(TableGen)
include(AddLLVM)
include(AddMLIR)
include(TableGen)
include(AddLLVM)
include(AddMLIR)

include(HandleLLVMOptions)
include(HandleLLVMOptions)

include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})

add_definitions(${LLVM_DEFINITIONS})
add_definitions(${LLVM_DEFINITIONS})
endif()

set(BUILD_SHARED_LIBS ${LLVM_ENABLE_SHARED_LIBS} CACHE BOOL "" FORCE)
message(STATUS "BUILD_SHARED_LIBS : " ${BUILD_SHARED_LIBS})
Expand Down Expand Up @@ -158,7 +166,9 @@ function(add_onnx_mlir_library name)
)

if (NOT ARG_EXCLUDE_FROM_OM_LIBS)
set_property(GLOBAL APPEND PROPERTY ONNX_MLIR_LIBS ${name})
if (NOT ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
set_property(GLOBAL APPEND PROPERTY ONNX_MLIR_LIBS ${name})
endif()
endif()

add_library(${name} ${ARG_UNPARSED_ARGUMENTS})
Expand Down
39 changes: 39 additions & 0 deletions docs/BuildPyRuntimeLit.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# How to build and use PyRuntime lit

## Purpose

PyRuntime lit is a different way to build the original PyRuntime (src/Runtime/python).
All heavy dependencies, such as llvm_project and the onnx-mlir compiler, are removed. The purpose is to easily build the python driver for the model execution on
different systems. Currently, only OMTensorUtils (src/Runtime), the Python driver (src/Runtime/python), third_party/onnx and third_party/pybind11 are built.

The build of PyRuntime lit is controlled by a CMake option: ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT. Without this option passed to cmake, the whole system remains the same.

## Functionalities
1. Build the python driver without llvm_project and onnx-mlir compiler built.
2. The python driver can be used with utils/RunONNXModel.py, or onnxmlir python package.
3. With PyRuntime lit, the compiler is not built locally, so the docker image of onnx-mlir has to be used to compile the model. The onnxmlir package contains
the python code that uses the python docker package to perform the compilation. Alternatively, the old script, onnx-mlir/docker/onnx-mlir.py, can fulfill the same task with subprocess and the docker CLI.

## How to use
You can find the script for build and run at "onnx-mlir/utils/build-pyruntime-lit.sh".
```
#!/bin/bash
# Assume you are in an empty directory for build in cloned onnx-mlir.
# Usually it is "your_path/onnx-mlir/build"
# then you can run this script as "../utils/build-pyruntime-lit.sh"
cmake .. -DONNX_MLIR_ENABLE_PYRUNTIME_LIGHT=ON
make
make OMCreatePyRuntimePackage
# Install the package
pip3 install -e src/Runtime/python/onnxmlir
# -e is necessary for current package. Need to add resource description
# to install the pre-compiled binary
# Run test case
cd src/Runtime/python/onnxmlir/tests
python3 test_1.py
# Current limitation on where the model is
```
10 changes: 10 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# SPDX-License-Identifier: Apache-2.0

# Light-weight PyRuntime build: configure only the Runtime subtree and stop,
# so that none of the remaining subdirectories listed below are visited.
if (ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
  # Must come before add_subdirectory(): a directory-scope definition is only
  # inherited by subdirectories that are added after it is set. Placed after
  # add_subdirectory(Runtime) and right before return(), it reached no target.
  add_compile_definitions(ENABLE_PYRUNTIME_LIGHT)

  add_subdirectory(Runtime)

  # Accelerators introduces a target AcceleratorsInc. Define a dummy one here
  # so that references to that target still resolve in this build mode.
  add_custom_target(AcceleratorsInc
    # Use CMake's built-in echo so the command does not depend on the shell.
    COMMAND ${CMAKE_COMMAND} -E echo "This is the dummy definition for AcceleratorsInc"
    VERBATIM
  )
  return()
endif()

add_subdirectory(Accelerators)
add_subdirectory(Interface)
add_subdirectory(Dialect)
Expand Down
14 changes: 14 additions & 0 deletions src/Runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@

# SPDX-License-Identifier: Apache-2.0

# The jni and omp runtime directories are only configured for the regular
# build; the light-weight PyRuntime build skips both of them.
if (NOT ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
  foreach(om_runtime_subdir jni omp)
    add_subdirectory(${om_runtime_subdir})
  endforeach()
endif()
add_subdirectory(python)

# TODO: should add for each accelerator its subdirectory that implements InitAccel##name
Expand Down Expand Up @@ -65,6 +67,17 @@ set_target_properties(OMTensorUtils
POSITION_INDEPENDENT_CODE TRUE
)

if (ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
add_compile_definitions(ENABLE_PYRUNTIME_LIGHT)
add_onnx_mlir_library(OMExecutionSession
ExecutionSession.cpp

EXCLUDE_FROM_OM_LIBS

LINK_LIBS PUBLIC
OMTensorUtils
)
else()
add_onnx_mlir_library(OMExecutionSession
ExecutionSession.cpp

Expand All @@ -74,6 +87,7 @@ add_onnx_mlir_library(OMExecutionSession
OMTensorUtils
LLVMSupport
)
endif()
set_target_properties(OMExecutionSession
PROPERTIES
POSITION_INDEPENDENT_CODE TRUE
Expand Down
46 changes: 46 additions & 0 deletions src/Runtime/ExecutionSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
#include <sstream>
#include <vector>

#ifndef ENABLE_PYRUNTIME_LIGHT
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Path.h"
#else
#include <dlfcn.h>
#endif

#include "ExecutionSession.hpp"
#include "OMTensorListHelper.hpp"
Expand All @@ -44,16 +48,24 @@ void ExecutionSession::Init(

// If there is no tag, use the model filename without extension as a tag.
if (tag == "") {
// ToFix: equivalent implementation of llvm utilities.
// This would not be an urgent issue, because tag is usually "NONE"
#ifndef ENABLE_PYRUNTIME_LIGHT
std::string fname = llvm::sys::path::filename(sharedLibPath).str();
llvm::SmallString<256> fnameWithoutExt(fname);
llvm::sys::path::replace_extension(fnameWithoutExt, "");
tag = fnameWithoutExt.str().lower();
#endif
}

// tag = "NONE" to use functions without tag.
std::string lowDashTag;
// ToFix: equivalent implementation of llvm::StringRef
#ifndef ENABLE_PYRUNTIME_LIGHT
// Assume tag is always NONE
if (!llvm::StringRef(tag).equals_insensitive("NONE"))
lowDashTag = "_" + tag;
#endif

#if defined(_WIN32)
// Use functions without tags on Windows since we cannot define at compile
Expand All @@ -63,31 +75,55 @@ void ExecutionSession::Init(
#endif

// Init symbols used by execution session.
#ifndef ENABLE_PYRUNTIME_LIGHT
_sharedLibraryHandle =
llvm::sys::DynamicLibrary::getLibrary(sharedLibPath.c_str());
if (!_sharedLibraryHandle.isValid())
throw std::runtime_error(reportLibraryOpeningError(sharedLibPath));
#else
// Copy code from llvm/lib/Support/DynamicLibrary.cpp, especially the flags
// ToFix: copy the lock related code too.
_sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY | RTLD_GLOBAL);
if (!_sharedLibraryHandle)
throw std::runtime_error(reportLibraryOpeningError(sharedLibPath));
#endif

std::string queryEntryPointsNameWithTag = _queryEntryPointsName + lowDashTag;
#ifndef ENABLE_PYRUNTIME_LIGHT
_queryEntryPointsFunc = reinterpret_cast<queryEntryPointsFuncType>(
_sharedLibraryHandle.getAddressOfSymbol(
queryEntryPointsNameWithTag.c_str()));
#else
_queryEntryPointsFunc = reinterpret_cast<queryEntryPointsFuncType>(
dlsym(_sharedLibraryHandle, queryEntryPointsNameWithTag.c_str()));
#endif

if (!_queryEntryPointsFunc)
throw std::runtime_error(
reportSymbolLoadingError(queryEntryPointsNameWithTag));

std::string inputSignatureNameWithTag = _inputSignatureName + lowDashTag;
#ifndef ENABLE_PYRUNTIME_LIGHT
_inputSignatureFunc = reinterpret_cast<signatureFuncType>(
_sharedLibraryHandle.getAddressOfSymbol(
inputSignatureNameWithTag.c_str()));
#else
_inputSignatureFunc = reinterpret_cast<signatureFuncType>(
dlsym(_sharedLibraryHandle, inputSignatureNameWithTag.c_str()));
#endif
if (!_inputSignatureFunc)
throw std::runtime_error(
reportSymbolLoadingError(inputSignatureNameWithTag));

std::string outputSignatureNameWithTag = _outputSignatureName + lowDashTag;
#ifndef ENABLE_PYRUNTIME_LIGHT
_outputSignatureFunc = reinterpret_cast<signatureFuncType>(
_sharedLibraryHandle.getAddressOfSymbol(
outputSignatureNameWithTag.c_str()));
#else
_outputSignatureFunc = reinterpret_cast<signatureFuncType>(
dlsym(_sharedLibraryHandle, outputSignatureNameWithTag.c_str()));
#endif
if (!_outputSignatureFunc)
throw std::runtime_error(
reportSymbolLoadingError(outputSignatureNameWithTag));
Expand All @@ -114,8 +150,13 @@ void ExecutionSession::Init(
}

// Destructor: releases the model shared library that Init() opened.
// In the regular build the handle is an llvm::sys::DynamicLibrary; in the
// ENABLE_PYRUNTIME_LIGHT build it is the raw pointer returned by dlopen().
ExecutionSession::~ExecutionSession() {
#ifndef ENABLE_PYRUNTIME_LIGHT
  if (_sharedLibraryHandle.isValid())
    llvm::sys::DynamicLibrary::closeLibrary(_sharedLibraryHandle);
#else
  // Close only a handle that was actually obtained from dlopen(). The previous
  // check was inverted: it passed a null handle to dlclose() and never
  // released a successfully opened library.
  // NOTE(review): assumes the member is null when Init() never ran or
  // failed; confirm against its declaration in ExecutionSession.hpp.
  if (_sharedLibraryHandle)
    dlclose(_sharedLibraryHandle);
#endif
}

// =============================================================================
Expand All @@ -132,8 +173,13 @@ const std::string *ExecutionSession::queryEntryPoints(
void ExecutionSession::setEntryPoint(const std::string &entryPointName) {
if (!isInitialized)
throw std::runtime_error(reportInitError());
#ifndef ENABLE_PYRUNTIME_LIGHT
_entryPointFunc = reinterpret_cast<entryPointFuncType>(
_sharedLibraryHandle.getAddressOfSymbol(entryPointName.c_str()));
#else
_entryPointFunc = reinterpret_cast<entryPointFuncType>(
dlsym(_sharedLibraryHandle, entryPointName.c_str()));
#endif
if (!_entryPointFunc)
throw std::runtime_error(reportSymbolLoadingError(entryPointName));
_entryPointName = entryPointName;
Expand Down
Loading

0 comments on commit e24fd5c

Please sign in to comment.