diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index 4ccb6d71591ef..50ec6bf904a21 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -11,7 +11,10 @@ target_link_libraries(torchbind_test torch) add_library(jitbackend_test SHARED ${JIT_TEST_ROOT}/test_backend_lib.cpp) target_link_libraries(jitbackend_test torch) -add_library(backend_with_compiler SHARED ${JIT_TEST_ROOT}/test_backend_compiler_lib.cpp) +add_library(backend_with_compiler SHARED + ${JIT_TEST_ROOT}/test_backend_compiler_lib.cpp + ${JIT_TEST_ROOT}/test_backend_compiler_preprocess.cpp + ) target_link_libraries(backend_with_compiler torch) if(INSTALL_TEST) diff --git a/test/cpp/jit/test_backend_compiler_lib.cpp b/test/cpp/jit/test_backend_compiler_lib.cpp index caf4e3f0494d6..a6007790cdc55 100644 --- a/test/cpp/jit/test_backend_compiler_lib.cpp +++ b/test/cpp/jit/test_backend_compiler_lib.cpp @@ -106,56 +106,8 @@ class BackendWithCompiler : public PyTorchBackendInterface { }; namespace { -// For this backend, the actual compilation happens in preprocess function AOT. -// Put here for demonstration of backend -// as a whole piece. It's used when compilation is required. A dummy function -// can be passed when there's no usage of compilation in runtime backend lib. -c10::IValue preprocess( - const Module& mod, - const c10::Dict& method_compile_spec) { - // The output of this process would produce a dictionary - // Key: method name. - // Val: compiled blob (represented by a string). - c10::Dict compiled(StringType::get(), StringType::get()); - for (const auto& method : mod.get_methods()) { - const auto graph = method.function().graph()->copy(); - auto key = method.name(); - std::stringstream ss; - for (const auto& node : graph->nodes()) { - switch (node->kind()) { - case prim::Constant: - ss << node->kind().toDisplayString() << "#" - << toIValue(node->output()).value(); - break; - case aten::add: - ss << node->kind().toQualString(); - break; - case aten::sub: - ss << node->kind().toQualString(); - break; - default: - TORCH_CHECK( - false, - "The node of ", - node->kind().toQualString(), - " is not supported in this compiler. Source code: ", - node->sourceRange().str()); - break; - } - ss << ","; - } - std::string blob = ss.str(); - if (!blob.empty()) { - blob.pop_back(); - } - compiled.insert(method.name(), blob); - } - return compiled; -} - -static auto cls = torch::jit::backend( - "backend_with_compiler_demo", - preprocess); +constexpr auto backend_name = "backend_with_compiler_demo"; +static auto cls = torch::jit::backend(backend_name); } // namespace } // namespace jit diff --git a/test/cpp/jit/test_backend_compiler_preprocess.cpp b/test/cpp/jit/test_backend_compiler_preprocess.cpp new file mode 100644 index 0000000000000..9f545d9ce3357 --- /dev/null +++ b/test/cpp/jit/test_backend_compiler_preprocess.cpp @@ -0,0 +1,59 @@ +#include +#include + +namespace torch { +namespace jit { +namespace { +// For this backend, the actual compilation happens in preprocess function AOT. +// Put here for demonstration of backend +// as a whole piece. It's used when compilation is required. A dummy function +// can be passed when there's no usage of compilation in runtime backend lib. +c10::IValue preprocess( + const Module& mod, + const c10::Dict& method_compile_spec) { + // The output of this process would produce a dictionary + // Key: method name. + // Val: compiled blob (represented by a string). + c10::Dict compiled(StringType::get(), StringType::get()); + for (const auto& method : mod.get_methods()) { + const auto graph = method.function().graph()->copy(); + auto key = method.name(); + std::stringstream ss; + for (const auto& node : graph->nodes()) { + switch (node->kind()) { + case prim::Constant: + ss << node->kind().toDisplayString() << "#" + << toIValue(node->output()).value(); + break; + case aten::add: + ss << node->kind().toQualString(); + break; + case aten::sub: + ss << node->kind().toQualString(); + break; + default: + TORCH_CHECK( + false, + "The node of ", + node->kind().toQualString(), + " is not supported in this compiler. Source code: ", + node->sourceRange().str()); + break; + } + ss << ","; + } + std::string blob = ss.str(); + if (!blob.empty()) { + blob.pop_back(); + } + compiled.insert(method.name(), blob); + } + return compiled; +} + +constexpr auto backend_name = "backend_with_compiler_demo"; +static auto pre_reg = backend_preprocess_register(backend_name, preprocess); +} // namespace + +} // namespace jit +} // namespace torch diff --git a/test/cpp/jit/test_backend_lib.cpp b/test/cpp/jit/test_backend_lib.cpp index 5c10308f40cdf..bb234ead4b9d1 100644 --- a/test/cpp/jit/test_backend_lib.cpp +++ b/test/cpp/jit/test_backend_lib.cpp @@ -1,4 +1,5 @@ #include +#include namespace torch { namespace jit { @@ -73,12 +74,17 @@ c10::IValue preprocess( return mod._ivalue(); } +constexpr auto backend_name = "test_backend"; static auto cls_available = - torch::jit::backend>("test_backend", preprocess); -static auto cls_unavailable = torch::jit::backend>( - "test_backend_unavailable", - preprocess); -} // namespace + torch::jit::backend>(backend_name); +static auto pre_reg = backend_preprocess_register(backend_name, preprocess); + +constexpr auto backend_unavailable_name = "test_backend_unavailable"; +static auto cls_unavailable = + torch::jit::backend>(backend_unavailable_name); +static auto pre_reg_unavailable = + backend_preprocess_register(backend_unavailable_name, preprocess); +} // namespace } // namespace jit } // namespace torch diff --git a/test/cpp/lite_interpreter_runtime/CMakeLists.txt b/test/cpp/lite_interpreter_runtime/CMakeLists.txt index cf43c9083f35d..c68ea8869b4b1 100644 --- a/test/cpp/lite_interpreter_runtime/CMakeLists.txt +++ b/test/cpp/lite_interpreter_runtime/CMakeLists.txt @@ -6,13 +6,21 @@ set(LITE_INTERPRETER_RUNTIME_TEST_DIR ${TORCH_ROOT}/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp ) +add_library(backend_with_compiler_runtime SHARED + ${TORCH_ROOT}/test/cpp/jit/test_backend_compiler_lib.cpp + ${TORCH_ROOT}/torch/csrc/jit/backends/backend_interface.cpp + ) +target_link_libraries(backend_with_compiler_runtime PRIVATE torch) + add_executable( test_lite_interpreter_runtime ${LITE_INTERPRETER_RUNTIME_TEST_DIR}) target_include_directories( test_lite_interpreter_runtime PRIVATE - ${ATen_CPU_INCLUDE}) -target_link_libraries(test_lite_interpreter_runtime PRIVATE torch gtest) + ${ATen_CPU_INCLUDE} +) + +target_link_libraries(test_lite_interpreter_runtime PRIVATE torch gtest backend_with_compiler_runtime) if(INSTALL_TEST) install(TARGETS test_lite_interpreter_runtime DESTINATION bin) diff --git a/test/cpp/lite_interpreter_runtime/delegate_test.ptl b/test/cpp/lite_interpreter_runtime/delegate_test.ptl new file mode 100644 index 0000000000000..1e65c14e2ab7a Binary files /dev/null and b/test/cpp/lite_interpreter_runtime/delegate_test.ptl differ diff --git a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp index 852a28f7a4361..ae58b99676a6f 100644 --- a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp +++ b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp @@ -49,6 +49,28 @@ TEST(RunTimeTest, LoadAndForward) { ASSERT_EQ(result, expected_result); } +TEST(RunTimeTest, Delegate) { + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + // "delegate_test.ptl" is generated from test/cpp/jit/test_backend.cpp, + // BackendTest.TestCompiler. This test is on target runtime. It has + // model running capability, but no compilation and serialization. + // The mobile model delegated to the "backend_with_compiler_demo" backend + // The model is from the jit code: + // Module m("m"); + // m.define(R"( + // def forward(self, x, h): + // return x + h + // )"); + testModelFile.append("delegate_test.ptl"); + auto mlm = _load_for_mobile(testModelFile); + std::vector inputs; + inputs.emplace_back(2.0 * at::ones({})); + inputs.emplace_back(1.0 * at::ones({})); + + auto mres = mlm.forward(inputs); + AT_ASSERT(mres.toTensor().equal(3 * at::ones({}))); +} } // namespace mobile } // namespace jit } // namespace torch diff --git a/test/custom_backend/custom_backend.cpp b/test/custom_backend/custom_backend.cpp index dd67f4db2170d..ddcdc7bca0268 100644 --- a/test/custom_backend/custom_backend.cpp +++ b/test/custom_backend/custom_backend.cpp @@ -1,10 +1,12 @@ #include "custom_backend.h" +#include namespace torch { namespace custom_backend { namespace { constexpr auto kBackendName = "custom_backend"; -static auto cls = torch::jit::backend(kBackendName, preprocess); +static auto cls = torch::jit::backend(kBackendName); +static auto pre_reg = torch::jit::backend_preprocess_register(kBackendName, preprocess); } std::string getBackendName() { diff --git a/test/custom_backend/custom_backend.h b/test/custom_backend/custom_backend.h index b1f8ca13609dc..4829d2169d521 100644 --- a/test/custom_backend/custom_backend.h +++ b/test/custom_backend/custom_backend.h @@ -1,4 +1,5 @@ #include +#include namespace torch { namespace custom_backend { diff --git a/torch/csrc/jit/backends/backend.h b/torch/csrc/jit/backends/backend.h index 2d648d81bf298..995cf373a9f6e 100644 --- a/torch/csrc/jit/backends/backend.h +++ b/torch/csrc/jit/backends/backend.h @@ -2,12 +2,86 @@ #include #include -#include #include #include namespace torch { namespace jit { +namespace { +c10::FunctionSchema getIsAvailableSchema() { + c10::Argument self("self", c10::AnyType::get()); + c10::Argument available("available", c10::BoolType::get()); + c10::FunctionSchema preprocessor_schema( + "is_available", + /*overload_name=*/"", + /*arguments=*/{self}, + /*returns=*/{available}); + return preprocessor_schema; +} + +constexpr static auto kBackendsNamespace = "__backends__"; + +c10::FunctionSchema getCompileSchema() { + c10::Argument self("self", c10::AnyType::get()); + c10::Argument mod("processed", c10::AnyType::get()); + auto any_dict_ty = + c10::DictType::create(c10::StringType::get(), c10::AnyType::get()); + c10::Argument method_compile_spec("method_compile_spec", any_dict_ty); + c10::Argument handles("handles", any_dict_ty); + + c10::FunctionSchema compile_schema( + "compile", + /*overload_name=*/"", + /*arguments=*/{self, mod, method_compile_spec}, + /*returns=*/{handles}); + return compile_schema; +} + +c10::FunctionSchema getExecuteSchema() { + auto any_list_ty = c10::ListType::create(c10::AnyType::get()); + c10::Argument self("self", c10::AnyType::get()); + c10::Argument handle("handle", c10::AnyType::get()); + c10::Argument input("input", any_list_ty); + c10::Argument output("output", any_list_ty); + return c10::FunctionSchema( + "execute", + /*overload_name=*/"", + /*arguments=*/{self, handle, input}, + /*returns=*/{output}); +} + +template +std::function getIsAvailableFunc() { + return [](Stack& stack) { + auto self = pop(stack).toCustomClass(); + auto ret = self->is_available(); + push(stack, ret); + }; +} + +template +std::function getCompileFunc() { + return [](Stack& stack) { + auto method_compile_spec = pop(stack).toGenericDict(); + auto processed = pop(stack); + auto self = pop(stack).toCustomClass(); + auto ret = self->compile(processed, method_compile_spec); + push(stack, ret); + }; +} + +template +std::function getExecuteFunc() { + return [](Stack& stack) { + auto args = pop(stack); + auto handle = pop(stack); + auto self = pop(stack); + auto backend = self.toCustomClass(); + auto res = backend->execute(handle, args.toList()); + push(stack, res); + }; +} +} // namespace // Static registration API for backends. template @@ -20,26 +94,22 @@ class backend { public: // Registers a new backend with /p name, and the given /p preprocess // function. - backend( - const std::string& name, - const detail::BackendPreprocessFunction& preprocess) - : backend_name_(name) { - detail::registerBackendPreprocessFunction(name, preprocess); + backend(const std::string& name) : backend_name_(name) { static auto cls = - torch::class_(detail::kBackendsNamespace, name) + torch::class_(kBackendsNamespace, name) .def(torch::init<>()) ._def_unboxed( "is_available", - detail::getIsAvailableFunc(), - detail::getIsAvailableSchema()) + getIsAvailableFunc(), + getIsAvailableSchema()) ._def_unboxed( "compile", - detail::getCompileFunc(), - detail::getCompileSchema()) + getCompileFunc(), + getCompileSchema()) ._def_unboxed( "execute", - detail::getExecuteFunc(), - detail::getExecuteSchema()); + getExecuteFunc(), + getExecuteSchema()); } }; diff --git a/torch/csrc/jit/backends/backend_detail.cpp b/torch/csrc/jit/backends/backend_detail.cpp index 5dbd3dcdec053..9d0e44f451794 100644 --- a/torch/csrc/jit/backends/backend_detail.cpp +++ b/torch/csrc/jit/backends/backend_detail.cpp @@ -1,7 +1,7 @@ #include -#include #include +#include #include #include @@ -10,46 +10,6 @@ namespace torch { namespace jit { namespace detail { -c10::FunctionSchema getIsAvailableSchema() { - c10::Argument self("self", c10::AnyType::get()); - c10::Argument available("available", c10::BoolType::get()); - c10::FunctionSchema preprocessor_schema( - "is_available", - /*overload_name=*/"", - /*arguments=*/{self}, - /*returns=*/{available}); - return preprocessor_schema; -} - -c10::FunctionSchema getCompileSchema() { - c10::Argument self("self", c10::AnyType::get()); - c10::Argument mod("processed", c10::AnyType::get()); - auto any_dict_ty = - c10::DictType::create(c10::StringType::get(), c10::AnyType::get()); - c10::Argument method_compile_spec("method_compile_spec", any_dict_ty); - c10::Argument handles("handles", any_dict_ty); - - c10::FunctionSchema compile_schema( - "compile", - /*overload_name=*/"", - /*arguments=*/{self, mod, method_compile_spec}, - /*returns=*/{handles}); - return compile_schema; -} - -c10::FunctionSchema getExecuteSchema() { - auto any_list_ty = c10::ListType::create(c10::AnyType::get()); - c10::Argument self("self", c10::AnyType::get()); - c10::Argument handle("handle", c10::AnyType::get()); - c10::Argument input("input", any_list_ty); - c10::Argument output("output", any_list_ty); - return c10::FunctionSchema( - "execute", - /*overload_name=*/"", - /*arguments=*/{self, handle, input}, - /*returns=*/{output}); -} - namespace { std::unordered_map& backendPreprocessFunctions() { @@ -93,7 +53,7 @@ Module codegen_backend_module( {"__torch__", "torch", "classes", - detail::kBackendsNamespace, + kBackendsNamespace, backend_name}); // TODO: Validate method_compile_spec. diff --git a/torch/csrc/jit/backends/backend_detail.h b/torch/csrc/jit/backends/backend_detail.h index 1d75378017566..ffc9ce07b34ef 100644 --- a/torch/csrc/jit/backends/backend_detail.h +++ b/torch/csrc/jit/backends/backend_detail.h @@ -3,7 +3,6 @@ #include #include -#include #include @@ -11,44 +10,6 @@ namespace torch { namespace jit { namespace detail { -constexpr static auto kBackendsNamespace = "__backends__"; - -c10::FunctionSchema TORCH_API getIsAvailableSchema(); -c10::FunctionSchema TORCH_API getCompileSchema(); -c10::FunctionSchema TORCH_API getExecuteSchema(); - -template -std::function getIsAvailableFunc() { - return [](Stack& stack) { - auto self = pop(stack).toCustomClass(); - auto ret = self->is_available(); - push(stack, ret); - }; -} - -template -std::function getCompileFunc() { - return [](Stack& stack) { - auto method_compile_spec = pop(stack).toGenericDict(); - auto processed = pop(stack); - auto self = pop(stack).toCustomClass(); - auto ret = self->compile(processed, method_compile_spec); - push(stack, ret); - }; -} - -template -std::function getExecuteFunc() { - return [](Stack& stack) { - auto args = pop(stack); - auto handle = pop(stack); - auto self = pop(stack); - auto backend = self.toCustomClass(); - auto res = backend->execute(handle, args.toList()); - push(stack, res); - }; -} - using BackendPreprocessFunction = std::function&)>; diff --git a/torch/csrc/jit/backends/backend_preprocess.h b/torch/csrc/jit/backends/backend_preprocess.h new file mode 100644 index 0000000000000..0a256134aa96e --- /dev/null +++ b/torch/csrc/jit/backends/backend_preprocess.h @@ -0,0 +1,18 @@ +#pragma once + +#include +namespace torch { +namespace jit { +class backend_preprocess_register { + std::string backend_name_; + + public: + backend_preprocess_register( + const std::string& name, + const detail::BackendPreprocessFunction& preprocess) + : backend_name_(name) { + detail::registerBackendPreprocessFunction(name, preprocess); + } +}; +} // namespace jit +} // namespace torch