Skip to content

Commit

Permalink
[PyTorch] Lite interpreter with a backend delegate (pytorch#54462)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#54462

Unclean files during sync - Sat Mar 20 04:00:02 PDT 2021

Unclean files during sync - Sun Mar 21 04:00:01 PDT 2021
ghstack-source-id: 124585992

Test Plan:
```
buck run xplat/caffe2/fb/test/delegate:interpreter_test -- --model_file_path=/path/to/mobile_model.ptl
```

Reviewed By: raziel

Differential Revision: D27232309

fbshipit-source-id: 8504a3185339d73bfa6e924485c4745acf269cec
  • Loading branch information
iseeyuan authored and facebook-github-bot committed Apr 6, 2021
1 parent 7d9a619 commit 3551bd3
Show file tree
Hide file tree
Showing 13 changed files with 215 additions and 153 deletions.
5 changes: 4 additions & 1 deletion test/cpp/jit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ target_link_libraries(torchbind_test torch)
add_library(jitbackend_test SHARED ${JIT_TEST_ROOT}/test_backend_lib.cpp)
target_link_libraries(jitbackend_test torch)

add_library(backend_with_compiler SHARED ${JIT_TEST_ROOT}/test_backend_compiler_lib.cpp)
add_library(backend_with_compiler SHARED
${JIT_TEST_ROOT}/test_backend_compiler_lib.cpp
${JIT_TEST_ROOT}/test_backend_compiler_preprocess.cpp
)
target_link_libraries(backend_with_compiler torch)

if(INSTALL_TEST)
Expand Down
52 changes: 2 additions & 50 deletions test/cpp/jit/test_backend_compiler_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,56 +106,8 @@ class BackendWithCompiler : public PyTorchBackendInterface {
};

namespace {
// For this backend, the actual compilation happens in preprocess function AOT.
// Put here for demonstration of backend
// as a whole piece. It's used when compilation is required. A dummy function
// can be passed when there's no usage of compilation in runtime backend lib.
// AOT "compilation" step for the demo backend: lowers every method of `mod`
// into a comma-separated string blob of supported ops.
// NOTE(review): `method_compile_spec` is accepted but never read here.
c10::IValue preprocess(
    const Module& mod,
    const c10::Dict<IValue, IValue>& method_compile_spec) {
  // The output of this process would produce a dictionary
  // Key: method name.
  // Val: compiled blob (represented by a string).
  c10::Dict<IValue, IValue> compiled(StringType::get(), StringType::get());
  for (const auto& method : mod.get_methods()) {
    // Copy the graph so the original module is left untouched.
    const auto graph = method.function().graph()->copy();
    // NOTE(review): `key` is computed but unused; `method.name()` is
    // re-queried at the insert below.
    auto key = method.name();
    std::stringstream ss;
    for (const auto& node : graph->nodes()) {
      switch (node->kind()) {
        case prim::Constant:
          // Constants are encoded as "<display name>#<value>".
          ss << node->kind().toDisplayString() << "#"
             << toIValue(node->output()).value();
          break;
        case aten::add:
          // Supported binary ops are encoded by their qualified name only.
          ss << node->kind().toQualString();
          break;
        case aten::sub:
          ss << node->kind().toQualString();
          break;
        default:
          // Any op outside the tiny supported set fails hard at export time.
          TORCH_CHECK(
              false,
              "The node of ",
              node->kind().toQualString(),
              " is not supported in this compiler. Source code: ",
              node->sourceRange().str());
          break;
      }
      ss << ",";
    }
    std::string blob = ss.str();
    // Strip the trailing comma separator, if any ops were emitted.
    if (!blob.empty()) {
      blob.pop_back();
    }
    compiled.insert(method.name(), blob);
  }
  return compiled;
}

static auto cls = torch::jit::backend<BackendWithCompiler>(
"backend_with_compiler_demo",
preprocess);
constexpr auto backend_name = "backend_with_compiler_demo";
static auto cls = torch::jit::backend<BackendWithCompiler>(backend_name);
} // namespace

} // namespace jit
Expand Down
59 changes: 59 additions & 0 deletions test/cpp/jit/test_backend_compiler_preprocess.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>

namespace torch {
namespace jit {
namespace {
// AOT compilation entry point for this demo backend: lowers each method of
// `mod` into a comma-separated "instruction" blob that the runtime half of
// the backend (test_backend_compiler_lib.cpp) interprets. Put here for
// demonstration of the backend as a whole piece; a dummy function can be
// registered when a backend needs no ahead-of-time compilation.
//
// Returns a Dict[str, str] mapping method name -> compiled blob. Only
// prim::Constant, aten::add and aten::sub nodes are supported; any other op
// aborts with a TORCH_CHECK failure.
c10::IValue preprocess(
    const Module& mod,
    const c10::Dict<IValue, IValue>& method_compile_spec) {
  // This demo compiler ignores the compile spec, but the parameter is part
  // of the preprocess interface.
  (void)method_compile_spec;
  c10::Dict<IValue, IValue> compiled(StringType::get(), StringType::get());
  for (const auto& method : mod.get_methods()) {
    // Copy the graph so the original module is left untouched.
    const auto graph = method.function().graph()->copy();
    std::stringstream ss;
    for (const auto& node : graph->nodes()) {
      switch (node->kind()) {
        case prim::Constant:
          // Constants are encoded as "<display name>#<value>".
          ss << node->kind().toDisplayString() << "#"
             << toIValue(node->output()).value();
          break;
        case aten::add:
        case aten::sub:
          // Supported binary ops are encoded by their qualified name only.
          ss << node->kind().toQualString();
          break;
        default:
          // Any op outside the tiny supported set fails hard at export time.
          TORCH_CHECK(
              false,
              "The node of ",
              node->kind().toQualString(),
              " is not supported in this compiler. Source code: ",
              node->sourceRange().str());
          break;
      }
      ss << ",";
    }
    std::string blob = ss.str();
    // Strip the trailing comma separator, if any ops were emitted.
    if (!blob.empty()) {
      blob.pop_back();
    }
    compiled.insert(method.name(), blob);
  }
  return compiled;
}

constexpr auto backend_name = "backend_with_compiler_demo";
static auto pre_reg = backend_preprocess_register(backend_name, preprocess);
} // namespace

} // namespace jit
} // namespace torch
16 changes: 11 additions & 5 deletions test/cpp/jit/test_backend_lib.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>

namespace torch {
namespace jit {
Expand Down Expand Up @@ -73,12 +74,17 @@ c10::IValue preprocess(
return mod._ivalue();
}

constexpr auto backend_name = "test_backend";
static auto cls_available =
torch::jit::backend<TestBackend<true>>("test_backend", preprocess);
static auto cls_unavailable = torch::jit::backend<TestBackend<false>>(
"test_backend_unavailable",
preprocess);
} // namespace
torch::jit::backend<TestBackend<true>>(backend_name);
static auto pre_reg = backend_preprocess_register(backend_name, preprocess);

constexpr auto backend_unavailable_name = "test_backend_unavailable";
static auto cls_unavailable =
torch::jit::backend<TestBackend<false>>(backend_unavailable_name);
static auto pre_reg_unavailable =
backend_preprocess_register(backend_unavailable_name, preprocess);

} // namespace
} // namespace jit
} // namespace torch
12 changes: 10 additions & 2 deletions test/cpp/lite_interpreter_runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,21 @@ set(LITE_INTERPRETER_RUNTIME_TEST_DIR
${TORCH_ROOT}/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp
)

add_library(backend_with_compiler_runtime SHARED
${TORCH_ROOT}/test/cpp/jit/test_backend_compiler_lib.cpp
${TORCH_ROOT}/torch/csrc/jit/backends/backend_interface.cpp
)
target_link_libraries(backend_with_compiler_runtime PRIVATE torch)

add_executable(
test_lite_interpreter_runtime
${LITE_INTERPRETER_RUNTIME_TEST_DIR})
target_include_directories(
test_lite_interpreter_runtime PRIVATE
${ATen_CPU_INCLUDE})
target_link_libraries(test_lite_interpreter_runtime PRIVATE torch gtest)
${ATen_CPU_INCLUDE}
)

target_link_libraries(test_lite_interpreter_runtime PRIVATE torch gtest backend_with_compiler_runtime)

if(INSTALL_TEST)
install(TARGETS test_lite_interpreter_runtime DESTINATION bin)
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,28 @@ TEST(RunTimeTest, LoadAndForward) {
ASSERT_EQ(result, expected_result);
}

TEST(RunTimeTest, Delegate) {
  // "delegate_test.ptl" lives next to this source file. It is generated by
  // test/cpp/jit/test_backend.cpp, BackendTest.TestCompiler, and is
  // delegated to the "backend_with_compiler_demo" backend. This test runs
  // on the target runtime only: model execution, no compilation or
  // serialization. The model comes from the jit code:
  //   Module m("m");
  //   m.define(R"(
  //     def forward(self, x, h):
  //         return x + h
  //   )");
  std::string currentFile(__FILE__);
  const auto dirLen = currentFile.find_last_of("/\\") + 1;
  auto modelPath = currentFile.substr(0, dirLen);
  modelPath.append("delegate_test.ptl");

  auto module = _load_for_mobile(modelPath);
  std::vector<IValue> args;
  args.emplace_back(2.0 * at::ones({}));
  args.emplace_back(1.0 * at::ones({}));

  auto output = module.forward(args);
  AT_ASSERT(output.toTensor().equal(3 * at::ones({})));
}
} // namespace mobile
} // namespace jit
} // namespace torch
4 changes: 3 additions & 1 deletion test/custom_backend/custom_backend.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include "custom_backend.h"
#include <torch/csrc/jit/backends/backend_preprocess.h>

namespace torch {
namespace custom_backend {
namespace {
constexpr auto kBackendName = "custom_backend";
static auto cls = torch::jit::backend<CustomBackend>(kBackendName, preprocess);
static auto cls = torch::jit::backend<CustomBackend>(kBackendName);
static auto pre_reg = torch::jit::backend_preprocess_register(kBackendName, preprocess);
}

std::string getBackendName() {
Expand Down
1 change: 1 addition & 0 deletions test/custom_backend/custom_backend.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/api/module.h>

namespace torch {
namespace custom_backend {
Expand Down
96 changes: 83 additions & 13 deletions torch/csrc/jit/backends/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,86 @@

#include <ATen/core/builtin_function.h>
#include <ATen/core/stack.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_interface.h>
#include <torch/custom_class.h>

namespace torch {
namespace jit {
namespace {
// Builds the schema for a backend's `is_available` method:
//   is_available(self: Any) -> (available: bool)
c10::FunctionSchema getIsAvailableSchema() {
  return c10::FunctionSchema(
      "is_available",
      /*overload_name=*/"",
      /*arguments=*/{c10::Argument("self", c10::AnyType::get())},
      /*returns=*/{c10::Argument("available", c10::BoolType::get())});
}

constexpr static auto kBackendsNamespace = "__backends__";

// Builds the schema for a backend's `compile` method:
//   compile(self: Any, processed: Any,
//           method_compile_spec: Dict[str, Any]) -> (handles: Dict[str, Any])
c10::FunctionSchema getCompileSchema() {
  const auto any_dict_ty =
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
  return c10::FunctionSchema(
      "compile",
      /*overload_name=*/"",
      /*arguments=*/
      {c10::Argument("self", c10::AnyType::get()),
       c10::Argument("processed", c10::AnyType::get()),
       c10::Argument("method_compile_spec", any_dict_ty)},
      /*returns=*/{c10::Argument("handles", any_dict_ty)});
}

// Builds the schema for a backend's `execute` method:
//   execute(self: Any, handle: Any, input: List[Any]) -> (output: List[Any])
c10::FunctionSchema getExecuteSchema() {
  const auto any_list_ty = c10::ListType::create(c10::AnyType::get());
  return c10::FunctionSchema(
      "execute",
      /*overload_name=*/"",
      /*arguments=*/
      {c10::Argument("self", c10::AnyType::get()),
       c10::Argument("handle", c10::AnyType::get()),
       c10::Argument("input", any_list_ty)},
      /*returns=*/{c10::Argument("output", any_list_ty)});
}

// Returns the unboxed implementation of `is_available` for
// TBackendInterface: pops `self` off the stack and pushes the backend's
// availability flag.
template <typename TBackendInterface>
std::function<void(Stack&)> getIsAvailableFunc() {
  return [](Stack& stack) {
    const auto backend = pop(stack).toCustomClass<TBackendInterface>();
    push(stack, backend->is_available());
  };
}

// Returns the unboxed implementation of `compile` for TBackendInterface.
// Arguments are popped in reverse of the schema order
// (self, processed, method_compile_spec) because the last-pushed argument
// sits on top of the stack.
template <typename TBackendInterface>
std::function<void(Stack&)> getCompileFunc() {
  return [](Stack& stack) {
    auto method_compile_spec = pop(stack).toGenericDict();
    auto processed = pop(stack);
    auto self = pop(stack).toCustomClass<TBackendInterface>();
    auto ret = self->compile(processed, method_compile_spec);
    push(stack, ret);
  };
}

// Returns the unboxed implementation of `execute` for TBackendInterface.
// Arguments are popped in reverse of the schema order (self, handle, input)
// because the last-pushed argument sits on top of the stack.
template <typename TBackendInterface>
std::function<void(Stack&)> getExecuteFunc() {
  return [](Stack& stack) {
    auto args = pop(stack);
    auto handle = pop(stack);
    auto self = pop(stack);
    auto backend = self.toCustomClass<TBackendInterface>();
    auto res = backend->execute(handle, args.toList());
    push(stack, res);
  };
}
} // namespace

// Static registration API for backends.
template <class TBackendInterface>
Expand All @@ -20,26 +94,22 @@ class backend {
public:
// Registers a new backend with /p name, and the given /p preprocess
// function.
backend(
const std::string& name,
const detail::BackendPreprocessFunction& preprocess)
: backend_name_(name) {
detail::registerBackendPreprocessFunction(name, preprocess);
backend(const std::string& name) : backend_name_(name) {
static auto cls =
torch::class_<TBackendInterface>(detail::kBackendsNamespace, name)
torch::class_<TBackendInterface>(kBackendsNamespace, name)
.def(torch::init<>())
._def_unboxed(
"is_available",
detail::getIsAvailableFunc<TBackendInterface>(),
detail::getIsAvailableSchema())
getIsAvailableFunc<TBackendInterface>(),
getIsAvailableSchema())
._def_unboxed(
"compile",
detail::getCompileFunc<TBackendInterface>(),
detail::getCompileSchema())
getCompileFunc<TBackendInterface>(),
getCompileSchema())
._def_unboxed(
"execute",
detail::getExecuteFunc<TBackendInterface>(),
detail::getExecuteSchema());
getExecuteFunc<TBackendInterface>(),
getExecuteSchema());
}
};

Expand Down
Loading

0 comments on commit 3551bd3

Please sign in to comment.