Skip to content

Commit 3e188fe

Browse files
add .ptd support to extension/module
Differential Revision: D69478424 Pull Request resolved: #8421
1 parent 72432ba commit 3e188fe

File tree

10 files changed

+149
-36
lines changed

10 files changed

+149
-36
lines changed

CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,11 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
258258
set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
259259
endif()
260260

261+
if(EXECUTORCH_BUILD_EXTENSION_MODULE)
262+
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
263+
set(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON)
264+
endif()
265+
261266
if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
262267
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
263268
set(EXECUTORCH_BUILD_KERNELS_CUSTOM ON)

extension/flat_tensor/targets.bzl

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ def define_common_targets():
99
exported_headers = ["flat_tensor_data_map.h"],
1010
deps = [
1111
"//executorch/extension/flat_tensor/serialize:generated_headers",
12-
"//executorch/extension/flat_tensor/serialize:flat_tensor_header",
1312
"//executorch/runtime/core:core",
1413
"//executorch/runtime/core:evalue",
1514
"//executorch/runtime/core:named_data_map",
1615
"//executorch/runtime/core/exec_aten:lib",
1716
"//executorch/runtime/core/exec_aten/util:tensor_util",
1817
],
18+
exported_deps = [
19+
"//executorch/extension/flat_tensor/serialize:flat_tensor_header",
20+
],
1921
visibility = [
2022
"//executorch/...",
2123
],

extension/module/CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ if(CMAKE_TOOLCHAIN_IOS
2727
else()
2828
add_library(extension_module SHARED ${_extension_module__srcs})
2929
endif()
30-
target_link_libraries(extension_module PRIVATE executorch extension_data_loader)
30+
target_link_libraries(extension_module PRIVATE executorch extension_data_loader extension_flat_tensor)
3131
target_include_directories(extension_module PUBLIC ${EXECUTORCH_ROOT}/..)
3232
target_compile_options(
3333
extension_module PUBLIC -Wno-deprecated-declarations -fPIC
@@ -37,7 +37,7 @@ target_compile_options(
3737
# after cleaning up CMake targets.
3838
add_library(extension_module_static STATIC ${_extension_module__srcs})
3939
target_link_libraries(
40-
extension_module_static PRIVATE executorch extension_data_loader
40+
extension_module_static PRIVATE executorch extension_data_loader extension_flat_tensor
4141
)
4242
target_include_directories(extension_module_static PUBLIC ${EXECUTORCH_ROOT}/..)
4343
target_compile_options(

extension/module/module.cpp

+80-25
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/extension/data_loader/file_data_loader.h>
1212
#include <executorch/extension/data_loader/mmap_data_loader.h>
13+
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
1314
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
1415
#include <executorch/runtime/platform/runtime.h>
1516

@@ -36,73 +37,125 @@
3637
namespace executorch {
3738
namespace extension {
3839

40+
namespace {
41+
runtime::Result<std::unique_ptr<runtime::DataLoader>> load_file(
42+
const std::string& file_path,
43+
Module::LoadMode mode) {
44+
std::unique_ptr<runtime::DataLoader> res = nullptr;
45+
switch (mode) {
46+
case Module::LoadMode::File:
47+
res = ET_UNWRAP_UNIQUE(FileDataLoader::from(file_path.c_str()));
48+
break;
49+
case Module::LoadMode::Mmap:
50+
res = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
51+
file_path.c_str(), MmapDataLoader::MlockConfig::NoMlock));
52+
break;
53+
case Module::LoadMode::MmapUseMlock:
54+
res = ET_UNWRAP_UNIQUE(MmapDataLoader::from(file_path.c_str()));
55+
break;
56+
case Module::LoadMode::MmapUseMlockIgnoreErrors:
57+
res = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
58+
file_path.c_str(),
59+
MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
60+
break;
61+
}
62+
return res;
63+
}
64+
} // namespace
65+
66+
Module::Module(
67+
const std::string& file_path,
68+
const LoadMode load_mode,
69+
std::unique_ptr<runtime::EventTracer> event_tracer)
70+
: file_path_(file_path),
71+
load_mode_(load_mode),
72+
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
73+
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
74+
event_tracer_(std::move(event_tracer)),
75+
data_map_loader_(nullptr),
76+
data_map_(nullptr) {
77+
runtime::runtime_init();
78+
}
79+
3980
Module::Module(
4081
const std::string& file_path,
82+
const std::string& data_map_path,
4183
const LoadMode load_mode,
4284
std::unique_ptr<runtime::EventTracer> event_tracer)
4385
: file_path_(file_path),
86+
data_map_path_(data_map_path),
4487
load_mode_(load_mode),
4588
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
4689
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
47-
event_tracer_(std::move(event_tracer)) {
90+
event_tracer_(std::move(event_tracer)),
91+
data_map_loader_(nullptr),
92+
data_map_(nullptr) {
4893
runtime::runtime_init();
4994
}
5095

5196
Module::Module(
5297
std::unique_ptr<runtime::DataLoader> data_loader,
5398
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
5499
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
55-
std::unique_ptr<runtime::EventTracer> event_tracer)
100+
std::unique_ptr<runtime::EventTracer> event_tracer,
101+
std::unique_ptr<runtime::DataLoader> data_map_loader)
56102
: data_loader_(std::move(data_loader)),
57103
memory_allocator_(
58104
memory_allocator ? std::move(memory_allocator)
59105
: std::make_unique<MallocMemoryAllocator>()),
60106
temp_allocator_(
61107
temp_allocator ? std::move(temp_allocator)
62108
: std::make_unique<MallocMemoryAllocator>()),
63-
event_tracer_(std::move(event_tracer)) {
109+
event_tracer_(std::move(event_tracer)),
110+
data_map_loader_(std::move(data_map_loader)),
111+
data_map_(nullptr) {
64112
runtime::runtime_init();
65113
}
66114

67115
Module::Module(
68116
std::shared_ptr<runtime::Program> program,
69117
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
70118
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
71-
std::unique_ptr<runtime::EventTracer> event_tracer)
119+
std::unique_ptr<runtime::EventTracer> event_tracer,
120+
std::unique_ptr<runtime::DataLoader> data_map_loader)
72121
: program_(std::move(program)),
73122
memory_allocator_(
74123
memory_allocator ? std::move(memory_allocator)
75124
: std::make_unique<MallocMemoryAllocator>()),
76125
temp_allocator_(
77126
temp_allocator ? std::move(temp_allocator)
78127
: std::make_unique<MallocMemoryAllocator>()),
79-
event_tracer_(std::move(event_tracer)) {
128+
event_tracer_(std::move(event_tracer)),
129+
data_map_loader_(std::move(data_map_loader)),
130+
data_map_(nullptr) {
80131
runtime::runtime_init();
81132
}
82133

83134
runtime::Error Module::load(const runtime::Program::Verification verification) {
84135
if (!is_loaded()) {
136+
// Load the program
85137
if (!data_loader_) {
86-
switch (load_mode_) {
87-
case LoadMode::File:
88-
data_loader_ =
89-
ET_UNWRAP_UNIQUE(FileDataLoader::from(file_path_.c_str()));
90-
break;
91-
case LoadMode::Mmap:
92-
data_loader_ = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
93-
file_path_.c_str(), MmapDataLoader::MlockConfig::NoMlock));
94-
break;
95-
case LoadMode::MmapUseMlock:
96-
data_loader_ =
97-
ET_UNWRAP_UNIQUE(MmapDataLoader::from(file_path_.c_str()));
98-
break;
99-
case LoadMode::MmapUseMlockIgnoreErrors:
100-
data_loader_ = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
101-
file_path_.c_str(),
102-
MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
103-
break;
138+
auto res = load_file(file_path_, load_mode_);
139+
if (!res.ok()) {
140+
return res.error();
104141
}
105-
};
142+
data_loader_ = std::move(res.get());
143+
}
144+
// If a .ptd path was given, load it.
145+
if (data_map_path_ != "") {
146+
auto res = load_file(data_map_path_, load_mode_);
147+
if (!res.ok()) {
148+
return res.error();
149+
}
150+
data_map_loader_ = std::move(res.get());
151+
}
152+
// If we have a .ptd loader, then load the map.
153+
if (data_map_loader_) {
154+
data_map_ =
155+
ET_UNWRAP_UNIQUE(FlatTensorDataMap::load(data_map_loader_.get()));
156+
}
157+
// else: either the map itself was provided or we have no data map, either
158+
// way no work to do.
106159
auto program = ET_UNWRAP_UNIQUE(
107160
runtime::Program::load(data_loader_.get(), verification));
108161
program_ = std::shared_ptr<runtime::Program>(
@@ -130,6 +183,7 @@ runtime::Error Module::load_method(
130183
ET_CHECK_OK_OR_RETURN_ERROR(load());
131184

132185
MethodHolder method_holder;
186+
133187
const auto method_metadata =
134188
ET_UNWRAP(program_->method_meta(method_name.c_str()));
135189
const auto planned_buffersCount =
@@ -155,7 +209,8 @@ runtime::Error Module::load_method(
155209
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
156210
method_name.c_str(),
157211
method_holder.memory_manager.get(),
158-
event_tracer ? event_tracer : this->event_tracer()));
212+
event_tracer ? event_tracer : this->event_tracer(),
213+
data_map_.get()));
159214
method_holder.inputs.resize(method_holder.method->inputs_size());
160215
methods_.emplace(method_name, std::move(method_holder));
161216
}

extension/module/module.h

+24-3
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,21 @@ class Module {
5151
const LoadMode load_mode = LoadMode::MmapUseMlock,
5252
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
5353

54+
/**
55+
* Constructs an instance by loading a program from a file with specified
56+
* memory locking behavior, with external data supplied by a .ptd file.
57+
*
58+
* @param[in] file_path The path to the ExecuTorch program file to load.
59+
* @param[in] data_map_path The path to a .ptd file holding external data (e.g. weights) for the program.
60+
* @param[in] load_mode The loading mode to use.
61+
* @param[in] event_tracer An EventTracer used for tracking and logging events.
62+
*/
63+
explicit Module(
64+
const std::string& file_path,
65+
const std::string& data_map_path,
66+
const LoadMode load_mode = LoadMode::MmapUseMlock,
67+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
68+
5469
/**
5570
* Constructs an instance with the provided data loader and memory allocator.
5671
*
@@ -59,12 +74,14 @@ class Module {
5974
* @param[in] temp_allocator A MemoryAllocator to use when allocating
6075
* temporary data during kernel or delegate execution.
6176
* @param[in] event_tracer An EventTracer used for tracking and logging events.
77+
* @param[in] data_map_loader A DataLoader used for loading external weights.
6278
*/
6379
explicit Module(
6480
std::unique_ptr<runtime::DataLoader> data_loader,
6581
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
6682
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
67-
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
83+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
84+
std::unique_ptr<runtime::DataLoader> data_map_loader = nullptr);
6885

6986
/**
7087
* Constructs an instance using an existing shared program.
@@ -75,12 +92,14 @@ class Module {
7592
* @param[in] temp_allocator A MemoryAllocator to use when allocating
7693
* temporary data.
7794
* @param[in] event_tracer An EventTracer used for tracking and logging events.
95+
* @param[in] data_map_loader A DataLoader used for loading external weights.
7896
*/
7997
explicit Module(
8098
std::shared_ptr<runtime::Program> program,
8199
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
82100
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
83-
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
101+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
102+
std::unique_ptr<runtime::DataLoader> data_map_loader = nullptr);
84103

85104
Module(const Module&) = delete;
86105
Module& operator=(const Module&) = delete;
@@ -433,14 +452,16 @@ class Module {
433452
std::vector<runtime::EValue> inputs;
434453
};
435454

436-
private:
437455
std::string file_path_;
456+
std::string data_map_path_;
438457
LoadMode load_mode_{LoadMode::MmapUseMlock};
439458
std::shared_ptr<runtime::Program> program_;
440459
std::unique_ptr<runtime::DataLoader> data_loader_;
441460
std::unique_ptr<runtime::MemoryAllocator> memory_allocator_;
442461
std::unique_ptr<runtime::MemoryAllocator> temp_allocator_;
443462
std::unique_ptr<runtime::EventTracer> event_tracer_;
463+
std::unique_ptr<runtime::DataLoader> data_map_loader_;
464+
std::unique_ptr<runtime::NamedDataMap> data_map_;
444465

445466
protected:
446467
std::unordered_map<std::string, MethodHolder> methods_;

extension/module/targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def define_common_targets():
2525
"//executorch/extension/memory_allocator:malloc_memory_allocator",
2626
"//executorch/extension/data_loader:file_data_loader",
2727
"//executorch/extension/data_loader:mmap_data_loader",
28+
"//executorch/extension/flat_tensor:flat_tensor_data_map",
2829
],
2930
exported_deps = [
3031
"//executorch/runtime/executor:program" + aten_suffix,

extension/module/test/module_test.cpp

+21-4
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,20 @@ using namespace ::executorch::runtime;
2222
class ModuleTest : public ::testing::Test {
2323
protected:
2424
static void SetUpTestSuite() {
25-
model_path_ = std::getenv("RESOURCES_PATH") + std::string("/add.pte");
25+
std::string resources_path;
26+
if (const char* env = std::getenv("RESOURCES_PATH")) {
27+
resources_path = env;
28+
}
29+
model_path_ = resources_path + "/add.pte";
30+
linear_path_ = resources_path + "/linear.pte";
31+
linear_data_path_ = resources_path + "/linear.ptd";
2632
}
2733

28-
static std::string model_path_;
34+
static inline std::string model_path_;
35+
static inline std::string linear_path_;
36+
static inline std::string linear_data_path_;
2937
};
3038

31-
std::string ModuleTest::model_path_;
32-
3339
TEST_F(ModuleTest, TestLoad) {
3440
Module module(model_path_);
3541

@@ -435,3 +441,14 @@ TEST_F(ModuleTest, TestSetOutputInvalidType) {
435441

436442
EXPECT_NE(module.set_output(EValue()), Error::Ok);
437443
}
444+
445+
TEST_F(ModuleTest, TestPTD) {
446+
Module module(linear_path_, linear_data_path_);
447+
448+
ASSERT_EQ(module.load_method("forward"), Error::Ok);
449+
450+
auto tensor1 =
451+
make_tensor_ptr({3, 3}, {2.f, 3.f, 4.f, 2.f, 3.f, 4.f, 2.f, 3.f, 4.f});
452+
453+
ASSERT_EQ(module.forward(tensor1).error(), Error::Ok);
454+
}
+13-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
11
## Resources
22

3-
### model.pte
3+
### add.pte, linear.pte, linear.ptd
44
- Internally generated after D62209852, 2024-09-06 with:
55
```
66
buck2 run fbcode//executorch/examples/portable/scripts:export -- --model_name="add"
77
```
8+
9+
and
10+
11+
```
12+
buck2 run fbcode//executorch/examples/portable/scripts:export -- --model_name="linear" -e
13+
```
814
- In OSS, the same files can be generated after [#5145](https://github.com/pytorch/executorch/pull/5145), 2024-09-06 with:
915
```
1016
python -m examples.portable.scripts.export --model_name="add"
1117
```
18+
19+
and
20+
21+
```
22+
python -m examples.portable.scripts.export --model_name="linear" -e
23+
```
336 Bytes
Binary file not shown.
1.18 KB
Binary file not shown.

0 commit comments

Comments
 (0)