[PyTorch] update CMake to build libtorch lite (pytorch#51419)
Summary:
Pull Request resolved: pytorch#51419

## Summary

1. Add an option `BUILD_LITE_INTERPRETER` in `caffe2/CMakeLists.txt`, with `OFF` as the default.
2. Update `build_android.sh` with an argument to switch `BUILD_LITE_INTERPRETER`, also `OFF` by default.
3. Add a mini demo app `lite_interpreter_demo`, linked against the `libtorch` library, which can be used for quick testing.

## Test Plan
Built the lite interpreter version of libtorch and tested it with the Image Segmentation demo app ([android version](https://github.com/pytorch/android-demo-app/tree/master/ImageSegmentation)/[ios version](https://github.com/pytorch/ios-demo-app/tree/master/ImageSegmentation)).

### Android
1. **Prepare model**: Prepare the lite interpreter version of the model by running the script below, which generates the scripted model `deeplabv3_scripted.pt` and the lite version `deeplabv3_scripted.ptl`:
```
import torch

model = torch.hub.load('pytorch/vision:v0.7.0', 'deeplabv3_resnet50', pretrained=True)
model.eval()

scripted_module = torch.jit.script(model)
# Export the full JIT version of the model (not compatible with the lite interpreter); kept here for comparison
scripted_module.save("deeplabv3_scripted.pt")
# Export lite interpreter version model (compatible with lite interpreter)
scripted_module._save_for_lite_interpreter("deeplabv3_scripted.ptl")

```
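Optionally, sanity-check the exported lite model from Python before bundling it into the app. This is a minimal sketch, assuming a PyTorch build that exposes `torch.jit.mobile._load_for_lite_interpreter`:
```
import torch
from torch.jit.mobile import _load_for_lite_interpreter

# Load the lite-interpreter model back and run a dummy forward pass.
lite_module = _load_for_lite_interpreter("deeplabv3_scripted.ptl")
output = lite_module(torch.rand(1, 3, 224, 224))
print(output["out"].shape)  # deeplabv3 returns a dict with an "out" tensor
```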
2. **Build libtorch lite for android**: Build libtorch for android for all 4 android ABIs (armeabi-v7a, arm64-v8a, x86, x86_64) with `BUILD_LITE_INTERPRETER=1 ./scripts/build_pytorch_android.sh`. This PR was tested on a Pixel 4 emulator with x86, so the command `BUILD_LITE_INTERPRETER=1 ./scripts/build_pytorch_android.sh x86` specifies that ABI to save build time. After the build finishes, it prints the library paths:
```
...
BUILD SUCCESSFUL in 55s
134 actionable tasks: 22 executed, 112 up-to-date
+ find /Users/chenlai/pytorch/android -type f -name '*aar'
+ xargs ls -lah
-rw-r--r--  1 chenlai  staff    13M Feb 11 11:48 /Users/chenlai/pytorch/android/pytorch_android/build/outputs/aar/pytorch_android-release.aar
-rw-r--r--  1 chenlai  staff    36K Feb  9 16:45 /Users/chenlai/pytorch/android/pytorch_android_torchvision/build/outputs/aar/pytorch_android_torchvision-release.aar
```
3. **Use the PyTorch Android libraries built from source in the ImageSegmentation app**: Create a folder `libs` in the app directory, so the path from the repository root is `ImageSegmentation/app/libs`. Copy `pytorch_android-release.aar` to `ImageSegmentation/app/libs/pytorch_android-release.aar`, and copy `pytorch_android_torchvision.aar` (downloaded from [here](https://oss.sonatype.org/#nexus-search;quick~torchvision_android)) to `ImageSegmentation/app/libs/pytorch_android_torchvision.aar`. Then update the `dependencies` section of `ImageSegmentation/app/build.gradle` to:
```
dependencies {
    implementation 'androidx.appcompat:appcompat:1.2.0'
    implementation 'androidx.constraintlayout:constraintlayout:2.0.2'
    testImplementation 'junit:junit:4.12'
    androidTestImplementation 'androidx.test.ext:junit:1.1.2'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'

    implementation(name:'pytorch_android-release', ext:'aar')
    implementation(name:'pytorch_android_torchvision', ext:'aar')

    implementation 'com.android.support:appcompat-v7:28.0.0'
    implementation 'com.facebook.fbjni:fbjni-java-only:0.0.3'
}
```
Update the `allprojects` section of `ImageSegmentation/build.gradle` to:
```

allprojects {
    repositories {
        google()
        jcenter()
        flatDir {
            dirs 'libs'
        }
    }
}
```
4. **Update the model loader API**: Update `ImageSegmentation/app/src/main/java/org/pytorch/imagesegmentation/MainActivity.java` as follows:
4.1 Add a new import: `import org.pytorch.LiteModuleLoader;`
4.2 Replace the call that loads the PyTorch lite model:
```
//            mModule = Module.load(MainActivity.assetFilePath(getApplicationContext(), "deeplabv3_scripted.pt"));
            mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "deeplabv3_scripted.ptl"));
```
5. **Test app**: Build and run the ImageSegmentation app in Android Studio:
![image](https://user-images.githubusercontent.com/16430979/107696279-9cea5900-6c66-11eb-8286-4d1d68abff61.png)

### iOS
1. **Prepare model**: Same as Android.
2. **Build libtorch lite for iOS**: `BUILD_PYTORCH_MOBILE=1 IOS_PLATFORM=SIMULATOR BUILD_LITE_INTERPRETER=1 ./scripts/build_ios.sh`
3. **Remove CocoaPods from the project**: run `pod deintegrate`.
4. **Link the ImageSegmentation demo app with the custom-built library**:
Open your project in Xcode, go to your project target's **Build Phases - Link Binaries With Libraries**, click the **+** sign, and add all the library files located in `build_ios/install/lib`. Navigate to the project **Build Settings**, set **Header Search Paths** to `build_ios/install/include` and **Library Search Paths** to `build_ios/install/lib`.
In the build settings, search for **Other Linker Flags** and add the custom linker flag below:
```
-all_load
```
Finally, disable bitcode for your target by going to **Build Settings**, searching for **Enable Bitcode**, and setting the value to **No**.

5. **Update library and API**
5.1 Update `TorchModule.mm`
To use the custom-built libraries in the project, replace `#import <LibTorch/LibTorch.h>` (in `TorchModule.mm`), which is needed when using LibTorch via CocoaPods, with the code below:

```
//#import <LibTorch/LibTorch.h>
#include "ATen/ATen.h"
#include "caffe2/core/timer.h"
#include "caffe2/utils/string_utils.h"
#include "torch/csrc/autograd/grad_mode.h"
#include "torch/script.h"
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/observer.h>
```
5.2 Update `ViewController.swift`
```
//        if let filePath = Bundle.main.path(forResource:
//            "deeplabv3_scripted", ofType: "pt"),
//            let module = TorchModule(fileAtPath: filePath) {
//            return module
//        } else {
//            fatalError("Can't find the model file!")
//        }
        if let filePath = Bundle.main.path(forResource:
            "deeplabv3_scripted", ofType: "ptl"),
            let module = TorchModule(fileAtPath: filePath) {
            return module
        } else {
            fatalError("Can't find the model file!")
        }
```

### Unit test
Add `test/cpp/lite_interpreter_runtime`, with one unit test `test_lite_interpreter_runtime.cpp` and a light model `sequence.ptl`, to test the `_load_for_mobile()`, `bc.find_method()` and `bc.forward()` functions (see the sketch below for how such a model can be generated).
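For reference, a model like `sequence.ptl` can be produced with a script along these lines, mirroring the Python source embedded as comments in the test file (a sketch only; the exact generator script is not part of this diff):
```
import torch

# Modules mirroring the source quoted in test_lite_interpreter_runtime.cpp.
class A(torch.nn.Module):
    def forward(self, x):
        return x + 1

class B(torch.nn.Module):
    def forward(self, x):
        return x + 2

class C(torch.nn.Module):
    def __init__(self):
        super(C, self).__init__()
        self.A0 = A()
        self.B0 = B()

    def forward(self, x):
        return self.A0.forward(self.B0.forward(x))

# Script the module and export it in the lite interpreter format,
# so that (x + 2) + 1 = x + 3, i.e. input 1 yields 4 as the test expects.
scripted = torch.jit.script(C())
scripted._save_for_lite_interpreter("sequence.ptl")
```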

### Size
**With the change:**
Android:
x86: `pytorch_android-release.aar` (**13.8 MB**)

iOS:
`pytorch/build_ios/install/lib` (lib: **66 MB**):
```
(base) chenlai@chenlai-mp lib % ls -lh
total 135016
-rw-r--r--  1 chenlai  staff   3.3M Feb 15 20:45 libXNNPACK.a
-rw-r--r--  1 chenlai  staff   965K Feb 15 20:45 libc10.a
-rw-r--r--  1 chenlai  staff   4.6K Feb 15 20:45 libclog.a
-rw-r--r--  1 chenlai  staff    42K Feb 15 20:45 libcpuinfo.a
-rw-r--r--  1 chenlai  staff    39K Feb 15 20:45 libcpuinfo_internals.a
-rw-r--r--  1 chenlai  staff   1.5M Feb 15 20:45 libeigen_blas.a
-rw-r--r--  1 chenlai  staff   148K Feb 15 20:45 libfmt.a
-rw-r--r--  1 chenlai  staff    44K Feb 15 20:45 libpthreadpool.a
-rw-r--r--  1 chenlai  staff   166K Feb 15 20:45 libpytorch_qnnpack.a
-rw-r--r--  1 chenlai  staff   384B Feb 15 21:19 libtorch.a
-rw-r--r--  1 chenlai  staff    60M Feb 15 20:47 libtorch_cpu.a
```
`pytorch/build_ios/install`:
```
(base) chenlai@chenlai-mp install % du -sh *
 14M	include
 66M	lib
2.8M	share
```

**Master (baseline):**
Android:
x86: `pytorch_android-release.aar` (**16.2 MB**)

iOS:
`pytorch/build_ios/install/lib` (lib: **84 MB**):
```
(base) chenlai@chenlai-mp lib % ls -lh
total 172032
-rw-r--r--  1 chenlai  staff   3.3M Feb 17 22:18 libXNNPACK.a
-rw-r--r--  1 chenlai  staff   969K Feb 17 22:18 libc10.a
-rw-r--r--  1 chenlai  staff   4.6K Feb 17 22:18 libclog.a
-rw-r--r--  1 chenlai  staff    42K Feb 17 22:18 libcpuinfo.a
-rw-r--r--  1 chenlai  staff   1.5M Feb 17 22:18 libeigen_blas.a
-rw-r--r--  1 chenlai  staff    44K Feb 17 22:18 libpthreadpool.a
-rw-r--r--  1 chenlai  staff   166K Feb 17 22:18 libpytorch_qnnpack.a
-rw-r--r--  1 chenlai  staff   384B Feb 17 22:19 libtorch.a
-rw-r--r--  1 chenlai  staff    78M Feb 17 22:19 libtorch_cpu.a
```
`pytorch/build_ios/install`:
```
(base) chenlai@chenlai-mp install % du -sh *
 14M	include
 84M	lib
2.8M	share
```

Test Plan: Imported from OSS

Reviewed By: iseeyuan

Differential Revision: D26518778

Pulled By: cccclai

fbshipit-source-id: 4503ffa1f150ecc309ed39fb0549e8bd046a3f9c
cccclai authored and facebook-github-bot committed Feb 21, 2021
1 parent a935118 commit 14f7bf0
Showing 16 changed files with 220 additions and 28 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -160,6 +160,7 @@ option(BUILD_DOCS "Build Caffe2 documentation" OFF)
option(BUILD_CUSTOM_PROTOBUF "Build and use Caffe2's own protobuf under third_party" ON)
option(BUILD_PYTHON "Build Python binaries" ON)
option(BUILD_CAFFE2 "Master flag to build Caffe2" ON)
option(BUILD_LITE_INTERPRETER "Master flag to build Lite Interpreter" OFF)
cmake_dependent_option(
BUILD_CAFFE2_OPS "Build Caffe2 operators" ON
"BUILD_CAFFE2" OFF)
49 changes: 32 additions & 17 deletions android/pytorch_android/CMakeLists.txt
@@ -1,5 +1,16 @@
cmake_minimum_required(VERSION 3.4.1)
project(pytorch_jni CXX)
option(BUILD_LITE_INTERPRETER "Master flag to build pytorch_jni_lite" OFF)
message(
STATUS
"BUILD_LITE_INTERPRETER (pytorch_jni_lite): ${BUILD_LITE_INTERPRETER}")

if(BUILD_LITE_INTERPRETER)
project(pytorch_jni_lite CXX)
set(PYTORCH_JNI_TARGET pytorch_jni_lite)
else()
project(pytorch_jni CXX)
set(PYTORCH_JNI_TARGET pytorch_jni)
endif()

include(GNUInstallDirs)

@@ -45,15 +56,21 @@ configure_file(
${pytorch_android_DIR}/cmake_macros.h.in
${pytorch_android_DIR}/cmake_macros.h)

file(GLOB pytorch_android_SOURCES
${pytorch_android_DIR}/pytorch_jni_jit.cpp
${pytorch_android_DIR}/pytorch_jni_common.cpp
${pytorch_android_DIR}/pytorch_jni_common.h
)

add_library(pytorch_jni SHARED
${pytorch_android_SOURCES}
)
if(BUILD_LITE_INTERPRETER)
file(GLOB pytorch_android_SOURCES
${pytorch_android_DIR}/pytorch_jni_lite.cpp
${pytorch_android_DIR}/pytorch_jni_common.cpp
${pytorch_android_DIR}/pytorch_jni_common.h
)
else()
file(GLOB pytorch_android_SOURCES
${pytorch_android_DIR}/pytorch_jni_jit.cpp
${pytorch_android_DIR}/pytorch_jni_common.cpp
${pytorch_android_DIR}/pytorch_jni_common.h
)
endif()
add_library(${PYTORCH_JNI_TARGET} SHARED ${pytorch_android_SOURCES})

if(APPLE)
# Need to add rpath so dlopen can find dependencies.
@@ -63,13 +80,11 @@ if(APPLE)
$<TARGET_FILE:pytorch_jni>)
endif()

target_compile_options(pytorch_jni PRIVATE
target_compile_options(${PYTORCH_JNI_TARGET} PRIVATE
-fexceptions
)

target_include_directories(pytorch_jni BEFORE
PUBLIC $<BUILD_INTERFACE:${libtorch_include_DIR}>
)
target_include_directories(${PYTORCH_JNI_TARGET} BEFORE
PUBLIC $<BUILD_INTERFACE:${libtorch_include_DIR}>)

set(fbjni_DIR ${CMAKE_CURRENT_LIST_DIR}/../libs/fbjni/)
set(fbjni_BUILD_DIR ${CMAKE_BINARY_DIR}/fbjni/${BUILD_SUBDIR})
@@ -153,14 +168,14 @@ if(USE_VULKAN)
list(APPEND pytorch_jni_LIBS ${Vulkan_LIBS})
endif()

target_link_libraries(pytorch_jni ${pytorch_jni_LIBS})
target_link_libraries(${PYTORCH_JNI_TARGET} ${pytorch_jni_LIBS})

install(TARGETS pytorch_jni
install(TARGETS ${PYTORCH_JNI_TARGET}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) #For windows

if(MSVC)
install(FILES $<TARGET_PDB_FILE:pytorch_jni> DESTINATION ${CMAKE_INSTALL_LIBDIR} OPTIONAL)
install(TARGETS pytorch_jni DESTINATION ${CMAKE_INSTALL_LIBDIR})
install(TARGETS ${PYTORCH_JNI_TARGET} DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif()
9 changes: 9 additions & 0 deletions android/pytorch_android/build.gradle
@@ -17,7 +17,11 @@ android {
}
externalNativeBuild {
cmake {
if(System.env.BUILD_LITE_INTERPRETER == '1') {
arguments "-DANDROID_STL=c++_shared", "-DBUILD_LITE_INTERPRETER=ON"
} else {
arguments "-DANDROID_STL=c++_shared"
}
}
}
}
@@ -33,8 +37,13 @@ android {
sourceSets {
main {
java {
if(System.env.BUILD_LITE_INTERPRETER == '1') {
println 'Build pytorch_jni_lite'
} else {
println 'Build pytorch_jni'
exclude 'org/pytorch/LiteModuleLoader.java'
exclude 'org/pytorch/LiteNativePeer.java'
}
}
jniLibs.srcDirs = ['src/main/jniLibs']
}
4 changes: 4 additions & 0 deletions android/pytorch_android/src/main/java/org/pytorch/LiteNativePeer.java
@@ -2,9 +2,13 @@

import com.facebook.jni.HybridData;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;

class LiteNativePeer implements INativePeer {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_jni_lite");
PyTorchCodegenLoader.loadNativeLibs();
}
2 changes: 2 additions & 0 deletions aten/src/ATen/CMakeLists.txt
@@ -441,6 +441,8 @@ install(FILES ${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml

if(ATEN_NO_TEST)
message("disable test because ATEN_NO_TEST is set")
elseif(BUILD_LITE_INTERPRETER)
message("disable aten test when BUILD_LITE_INTERPRETER is enabled")
else()
add_subdirectory(test)
endif()
49 changes: 39 additions & 10 deletions caffe2/CMakeLists.txt
@@ -402,7 +402,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.cpp"
)

if(NOT INTERN_DISABLE_AUTOGRAD)
if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER)
list(APPEND GENERATED_CXX_TORCH
"${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType_0.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType_1.cpp"
@@ -501,7 +501,22 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)

set(TORCH_SRCS ${GENERATED_CXX_TORCH})
list(APPEND TORCH_SRCS ${GENERATED_H_TORCH})
append_filelist("libtorch_cmake_sources" TORCH_SRCS)
list(APPEND LIBTORCH_CMAKE_SRCS "")

# Switch between the full jit interpreter and lite interpreter
if(BUILD_LITE_INTERPRETER)
append_filelist("libtorch_lite_cmake_sources" LIBTORCH_CMAKE_SRCS)
else()
append_filelist("libtorch_cmake_sources" LIBTORCH_CMAKE_SRCS)
endif()
list(APPEND TORCH_SRCS ${LIBTORCH_CMAKE_SRCS})

if(PRINT_CMAKE_DEBUG_INFO)
message(STATUS "Interpreter sources: ")
foreach(tmp ${LIBTORCH_CMAKE_SRCS})
message(STATUS " " ${tmp})
endforeach()
endif()

# Required workaround for LLVM 9 includes.
if(NOT MSVC)
@@ -535,14 +550,14 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/utils/out_types.cpp
)

if(NOT INTERN_DISABLE_AUTOGRAD)
if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER)
list(APPEND TORCH_SRCS
${TORCH_SRC_DIR}/csrc/autograd/TraceTypeManual.cpp
${TORCH_SRC_DIR}/csrc/autograd/VariableTypeManual.cpp
)
endif()

if(NOT INTERN_BUILD_MOBILE)
if(NOT INTERN_BUILD_MOBILE AND NOT BUILD_LITE_INTERPRETER)
list(APPEND TORCH_SRCS
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/onnx.cpp
@@ -619,7 +634,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
install(TARGETS caffe2_nvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}")
endif()

if(NOT NO_API)
if(NOT NO_API AND NOT BUILD_LITE_INTERPRETER)
list(APPEND TORCH_SRCS
${TORCH_SRC_DIR}/csrc/api/src/cuda.cpp
${TORCH_SRC_DIR}/csrc/api/src/data/datasets/mnist.cpp
@@ -946,14 +961,24 @@ endif()


if(BUILD_TEST)
add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
add_subdirectory(${TORCH_ROOT}/test/cpp/tensorexpr ${CMAKE_BINARY_DIR}/test_tensorexpr)
if(USE_DISTRIBUTED AND NOT WIN32)
add_subdirectory(${TORCH_ROOT}/test/cpp/rpc ${CMAKE_BINARY_DIR}/test_cpp_rpc)
if(BUILD_LITE_INTERPRETER)
add_subdirectory(
${TORCH_ROOT}/test/cpp/lite_interpreter_runtime
${CMAKE_BINARY_DIR}/test_lite_interpreter_runtime
)
else()
add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
add_subdirectory(
${TORCH_ROOT}/test/cpp/tensorexpr
${CMAKE_BINARY_DIR}/test_tensorexpr
)
if(USE_DISTRIBUTED AND NOT WIN32)
add_subdirectory(${TORCH_ROOT}/test/cpp/rpc ${CMAKE_BINARY_DIR}/test_cpp_rpc)
endif()
endif()
endif()

if(BUILD_TEST AND NOT NO_API)
if(BUILD_TEST AND NOT NO_API AND NOT BUILD_LITE_INTERPRETER)
add_subdirectory(${TORCH_ROOT}/test/cpp/api ${CMAKE_BINARY_DIR}/test_api)
add_subdirectory(${TORCH_ROOT}/test/cpp/dist_autograd ${CMAKE_BINARY_DIR}/dist_autograd)
endif()
@@ -1069,6 +1094,10 @@ if(USE_ROCM)
)
endif()

if(BUILD_LITE_INTERPRETER)
target_compile_definitions(torch_cpu PRIVATE DBUILD_LITE_INTERPRETER)
endif()

# Pass USE_DISTRIBUTED to torch_cpu, as some codes in jit/pickler.cpp and
# jit/unpickler.cpp need to be compiled only when USE_DISTRIBUTED is set
if(USE_DISTRIBUTED)
1 change: 1 addition & 0 deletions cmake/Summary.cmake
@@ -48,6 +48,7 @@ function(caffe2_print_configuration_summary)
message(STATUS " BUILD_TEST : ${BUILD_TEST}")
message(STATUS " BUILD_JNI : ${BUILD_JNI}")
message(STATUS " BUILD_MOBILE_AUTOGRAD : ${BUILD_MOBILE_AUTOGRAD}")
message(STATUS " BUILD_LITE_INTERPRETER: ${BUILD_LITE_INTERPRETER}")
if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
message(STATUS " CROSS_COMPILING_MACOSX : ${CROSS_COMPILING_MACOSX}")
endif()
8 changes: 8 additions & 0 deletions scripts/build_android.sh
@@ -103,6 +103,14 @@ fi
# Don't build artifacts we don't need
CMAKE_ARGS+=("-DBUILD_TEST=OFF")
CMAKE_ARGS+=("-DBUILD_BINARY=OFF")

# If there exists env variable and it equals to 1, build lite interpreter.
# cmd: BUILD_LITE_INTERPRETER=1 ./scripts/build_android.sh
if [ "${BUILD_LITE_INTERPRETER}" == 1 ]; then
CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON")
else
CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF")
fi
CMAKE_ARGS+=("-DBUILD_MOBILE_BENCHMARK=$BUILD_MOBILE_BENCHMARK")
CMAKE_ARGS+=("-DBUILD_MOBILE_TEST=$BUILD_MOBILE_TEST")
CMAKE_ARGS+=("-DBUILD_PYTHON=OFF")
6 changes: 6 additions & 0 deletions scripts/build_ios.sh
@@ -78,6 +78,12 @@ if [ -n "${IOS_ARCH:-}" ]; then
CMAKE_ARGS+=("-DIOS_ARCH=${IOS_ARCH}")
fi

if [ "${BUILD_LITE_INTERPRETER}" == 1 ]; then
CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON")
else
CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF")
fi

# Don't build binaries or tests (only the library)
CMAKE_ARGS+=("-DBUILD_TEST=OFF")
CMAKE_ARGS+=("-DBUILD_BINARY=OFF")
25 changes: 25 additions & 0 deletions test/cpp/lite_interpreter_runtime/CMakeLists.txt
@@ -0,0 +1,25 @@
set(
LITE_INTERPRETER_RUNTIME_TEST_DIR
"${TORCH_ROOT}/test/cpp/lite_interpreter_runtime")
set(LITE_INTERPRETER_RUNTIME_TEST_DIR
${TORCH_ROOT}/test/cpp/lite_interpreter_runtime/main.cpp
${TORCH_ROOT}/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp
)

add_executable(
test_lite_interpreter_runtime
${LITE_INTERPRETER_RUNTIME_TEST_DIR})
target_include_directories(
test_lite_interpreter_runtime PRIVATE
${ATen_CPU_INCLUDE})
target_link_libraries(test_lite_interpreter_runtime PRIVATE torch gtest)

if(INSTALL_TEST)
install(TARGETS test_lite_interpreter_runtime DESTINATION bin)
# Install PDB files for MSVC builds
if(MSVC AND BUILD_SHARED_LIBS)
install(
FILES $<TARGET_PDB_FILE:test_lite_interpreter_runtime>
DESTINATION bin OPTIONAL)
endif()
endif()
Binary file added test/cpp/lite_interpreter_runtime/light_model.ptl
24 changes: 24 additions & 0 deletions test/cpp/lite_interpreter_runtime/main.cpp
@@ -0,0 +1,24 @@

#include <gtest/gtest.h>
#include <iostream>
#include <string>
#include <torch/csrc/jit/mobile/import.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/autograd/generated/variable_factories.h>

std::string add_negative_flag(const std::string& flag) {
std::string filter = ::testing::GTEST_FLAG(filter);
if (filter.find('-') == std::string::npos) {
filter.push_back('-');
} else {
filter.push_back(':');
}
filter += flag;
return filter;
}
int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA");

return RUN_ALL_TESTS();
}
Binary file added test/cpp/lite_interpreter_runtime/sequence.ptl
55 changes: 55 additions & 0 deletions test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp
@@ -0,0 +1,55 @@
#include <gtest/gtest.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/torch.h>

#include <unordered_set>

namespace torch {
namespace jit {
namespace mobile {

TEST(RunTimeTest, LoadAndForward) {
// Load check in model: sequence.ptl
std::string filePath(__FILE__);
auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
testModelFile.append("sequence.ptl");

// sequence.ptl source code:
// class A(torch.nn.Module):
// def __init__(self):
// super(A, self).__init__()
//
// def forward(self, x):
// return x + 1
//
// class B(torch.nn.Module):
// def __init__(self):
// super(B, self).__init__()
//
// def forward(self, x):
// return x + 2
//
// class C(torch.nn.Module):
// def __init__(self):
// super(C, self).__init__()
// self.A0 = A()
// self.B0 = B()
//
// def forward(self, x):
// return self.A0.forward(self.B0.forward(x))

Module bc = _load_for_mobile(testModelFile);
auto forward_method = bc.find_method("forward");
std::vector<c10::IValue> input{c10::IValue(at::tensor(1))};
const auto result = bc.forward(input);
const auto expected_result = c10::IValue(at::tensor(4));
ASSERT_EQ(result, expected_result);
}

} // namespace mobile
} // namespace jit
} // namespace torch
