10
10
11
11
#include < executorch/extension/data_loader/file_data_loader.h>
12
12
#include < executorch/extension/data_loader/mmap_data_loader.h>
13
+ #include < executorch/extension/flat_tensor/flat_tensor_data_map.h>
13
14
#include < executorch/extension/memory_allocator/malloc_memory_allocator.h>
14
15
#include < executorch/runtime/platform/runtime.h>
15
16
36
37
namespace executorch {
37
38
namespace extension {
38
39
40
+ namespace {
41
+ runtime::Result<std::unique_ptr<runtime::DataLoader>> load_file (
42
+ const std::string& file_path,
43
+ Module::LoadMode mode) {
44
+ std::unique_ptr<runtime::DataLoader> res = nullptr ;
45
+ switch (mode) {
46
+ case Module::LoadMode::File:
47
+ res = ET_UNWRAP_UNIQUE (FileDataLoader::from (file_path.c_str ()));
48
+ break ;
49
+ case Module::LoadMode::Mmap:
50
+ res = ET_UNWRAP_UNIQUE (MmapDataLoader::from (
51
+ file_path.c_str (), MmapDataLoader::MlockConfig::NoMlock));
52
+ break ;
53
+ case Module::LoadMode::MmapUseMlock:
54
+ res = ET_UNWRAP_UNIQUE (MmapDataLoader::from (file_path.c_str ()));
55
+ break ;
56
+ case Module::LoadMode::MmapUseMlockIgnoreErrors:
57
+ res = ET_UNWRAP_UNIQUE (MmapDataLoader::from (
58
+ file_path.c_str (),
59
+ MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
60
+ break ;
61
+ }
62
+ return res;
63
+ }
64
+ } // namespace
65
+
66
+ Module::Module (
67
+ const std::string& file_path,
68
+ const LoadMode load_mode,
69
+ std::unique_ptr<runtime::EventTracer> event_tracer)
70
+ : file_path_(file_path),
71
+ load_mode_ (load_mode),
72
+ memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
73
+ temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
74
+ event_tracer_(std::move(event_tracer)),
75
+ data_map_loader_(nullptr ),
76
+ data_map_(nullptr ) {
77
+ runtime::runtime_init ();
78
+ }
79
+
39
80
Module::Module (
40
81
const std::string& file_path,
82
+ const std::string& data_map_path,
41
83
const LoadMode load_mode,
42
84
std::unique_ptr<runtime::EventTracer> event_tracer)
43
85
: file_path_(file_path),
86
+ data_map_path_(data_map_path),
44
87
load_mode_(load_mode),
45
88
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
46
89
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
47
- event_tracer_(std::move(event_tracer)) {
90
+ event_tracer_(std::move(event_tracer)),
91
+ data_map_loader_(nullptr ),
92
+ data_map_(nullptr ) {
48
93
runtime::runtime_init ();
49
94
}
50
95
51
96
Module::Module (
52
97
std::unique_ptr<runtime::DataLoader> data_loader,
53
98
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
54
99
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
55
- std::unique_ptr<runtime::EventTracer> event_tracer)
100
+ std::unique_ptr<runtime::EventTracer> event_tracer,
101
+ std::unique_ptr<runtime::DataLoader> data_map_loader)
56
102
: data_loader_(std::move(data_loader)),
57
103
memory_allocator_(
58
104
memory_allocator ? std::move(memory_allocator)
59
105
: std::make_unique<MallocMemoryAllocator>()),
60
106
temp_allocator_(
61
107
temp_allocator ? std::move(temp_allocator)
62
108
: std::make_unique<MallocMemoryAllocator>()),
63
- event_tracer_(std::move(event_tracer)) {
109
+ event_tracer_(std::move(event_tracer)),
110
+ data_map_loader_(std::move(data_map_loader)),
111
+ data_map_(nullptr ) {
64
112
runtime::runtime_init ();
65
113
}
66
114
67
115
Module::Module (
68
116
std::shared_ptr<runtime::Program> program,
69
117
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
70
118
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
71
- std::unique_ptr<runtime::EventTracer> event_tracer)
119
+ std::unique_ptr<runtime::EventTracer> event_tracer,
120
+ std::unique_ptr<runtime::DataLoader> data_map_loader)
72
121
: program_(std::move(program)),
73
122
memory_allocator_(
74
123
memory_allocator ? std::move(memory_allocator)
75
124
: std::make_unique<MallocMemoryAllocator>()),
76
125
temp_allocator_(
77
126
temp_allocator ? std::move(temp_allocator)
78
127
: std::make_unique<MallocMemoryAllocator>()),
79
- event_tracer_(std::move(event_tracer)) {
128
+ event_tracer_(std::move(event_tracer)),
129
+ data_map_loader_(std::move(data_map_loader)),
130
+ data_map_(nullptr ) {
80
131
runtime::runtime_init ();
81
132
}
82
133
83
134
runtime::Error Module::load (const runtime::Program::Verification verification) {
84
135
if (!is_loaded ()) {
136
+ // Load the program
85
137
if (!data_loader_) {
86
- switch (load_mode_) {
87
- case LoadMode::File:
88
- data_loader_ =
89
- ET_UNWRAP_UNIQUE (FileDataLoader::from (file_path_.c_str ()));
90
- break ;
91
- case LoadMode::Mmap:
92
- data_loader_ = ET_UNWRAP_UNIQUE (MmapDataLoader::from (
93
- file_path_.c_str (), MmapDataLoader::MlockConfig::NoMlock));
94
- break ;
95
- case LoadMode::MmapUseMlock:
96
- data_loader_ =
97
- ET_UNWRAP_UNIQUE (MmapDataLoader::from (file_path_.c_str ()));
98
- break ;
99
- case LoadMode::MmapUseMlockIgnoreErrors:
100
- data_loader_ = ET_UNWRAP_UNIQUE (MmapDataLoader::from (
101
- file_path_.c_str (),
102
- MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
103
- break ;
138
+ auto res = load_file (file_path_, load_mode_);
139
+ if (!res.ok ()) {
140
+ return res.error ();
104
141
}
105
- };
142
+ data_loader_ = std::move (res.get ());
143
+ }
144
+ // If a .ptd path was given load it.
145
+ if (data_map_path_ != " " ) {
146
+ auto res = load_file (data_map_path_, load_mode_);
147
+ if (!res.ok ()) {
148
+ return res.error ();
149
+ }
150
+ data_map_loader_ = std::move (res.get ());
151
+ }
152
+ // If we have a .ptd loader, then load the map.
153
+ if (data_map_loader_) {
154
+ data_map_ =
155
+ ET_UNWRAP_UNIQUE (FlatTensorDataMap::load (data_map_loader_.get ()));
156
+ }
157
+ // else: either the map itself was provided or we have no data map, either
158
+ // way no work to do.
106
159
auto program = ET_UNWRAP_UNIQUE (
107
160
runtime::Program::load (data_loader_.get (), verification));
108
161
program_ = std::shared_ptr<runtime::Program>(
@@ -130,6 +183,7 @@ runtime::Error Module::load_method(
130
183
ET_CHECK_OK_OR_RETURN_ERROR (load ());
131
184
132
185
MethodHolder method_holder;
186
+
133
187
const auto method_metadata =
134
188
ET_UNWRAP (program_->method_meta (method_name.c_str ()));
135
189
const auto planned_buffersCount =
@@ -155,7 +209,8 @@ runtime::Error Module::load_method(
155
209
method_holder.method = ET_UNWRAP_UNIQUE (program_->load_method (
156
210
method_name.c_str (),
157
211
method_holder.memory_manager .get (),
158
- event_tracer ? event_tracer : this ->event_tracer ()));
212
+ event_tracer ? event_tracer : this ->event_tracer (),
213
+ data_map_.get ()));
159
214
method_holder.inputs .resize (method_holder.method ->inputs_size ());
160
215
methods_.emplace (method_name, std::move (method_holder));
161
216
}
0 commit comments