diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 89a87a02276db..efe628b354811 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -250,6 +250,12 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = "session.optimized_model_external_initializers_min_size_in_bytes"; +// When loading model from memory buffer and the model has external initializers +// Use this config to set the external data file folder path +// All external data files should be in the same folder +static const char* const kOrtSessionOptionsModelExternalInitializersFileFolderPath = + "session.model_external_initializers_file_folder_path"; + // Use this config when saving pre-packed constant initializers to an external data file. // This allows you to memory map pre-packed initializers on model load and leave it to // to the OS the amount of memory consumed by the pre-packed initializers. Otherwise, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 26ffeb93ab3b6..afd1e24bd4742 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1090,7 +1090,14 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; - return onnxruntime::Model::Load(std::move(model_proto), PathString(), model, + + std::string external_data_folder_path = session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsModelExternalInitializersFileFolderPath, ""); + if (!external_data_folder_path.empty() && model_location_.empty()) { + model_location_ = ToPathString(external_data_folder_path + "/virtual_model.onnx"); + } + + return onnxruntime::Model::Load(std::move(model_proto), model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, ModelOptions(true, strict_shape_type_inference)); }; @@ -1120,8 +1127,15 @@ common::Status InferenceSession::LoadOnnxModel(ModelProto model_proto) { #endif const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; + + std::string external_data_folder_path = session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsModelExternalInitializersFileFolderPath, ""); + if (!external_data_folder_path.empty() && model_location_.empty()) { + model_location_ = ToPathString(external_data_folder_path + "/virtual_model.onnx"); + } + // This call will move model_proto to the constructed model instance - return onnxruntime::Model::Load(std::move(model_proto), PathString(), model, + return onnxruntime::Model::Load(std::move(model_proto), model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, ModelOptions(true, strict_shape_type_inference)); }; @@ -1157,7 +1171,14 @@ common::Status InferenceSession::Load(std::istream& model_istream, bool allow_re kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; ModelOptions model_opts(allow_released_opsets_only, strict_shape_type_inference); - return onnxruntime::Model::Load(std::move(model_proto), PathString(), model, + + std::string external_data_folder_path = session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsModelExternalInitializersFileFolderPath, ""); + if (!external_data_folder_path.empty() && model_location_.empty()) { + model_location_ = ToPathString(external_data_folder_path + "/virtual_model.onnx"); + } + + return onnxruntime::Model::Load(std::move(model_proto), model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, model_opts); }; diff --git a/onnxruntime/test/shared_lib/test_model_loading.cc b/onnxruntime/test/shared_lib/test_model_loading.cc index 5694398b9cb10..89b12ec61649e 100644 --- a/onnxruntime/test/shared_lib/test_model_loading.cc +++ b/onnxruntime/test/shared_lib/test_model_loading.cc @@ -215,6 +215,52 @@ TEST(CApiTest, TestLoadModelFromArrayWithExternalInitializersFromFileArrayPathRo #endif } +// The model has external data, Test loading model from array +// Extra API required to set the external data path +TEST(CApiTest, TestLoadModelFromArrayWithExternalInitializersViaSetExternalDataPath) { + std::string model_file_name = "conv_qdq_external_ini.onnx"; + std::string external_bin_name = "conv_qdq_external_ini.bin"; + std::string test_folder = "testdata/"; + std::string model_path = test_folder + model_file_name; + std::vector buffer; + ReadFileToBuffer(model_path.c_str(), buffer); + + std::vector external_bin_buffer; + std::string external_bin_path = test_folder + external_bin_name; + ReadFileToBuffer(external_bin_path.c_str(), external_bin_buffer); + + Ort::SessionOptions so; + std::string optimized_model_file_name(model_file_name); + auto length = optimized_model_file_name.length(); + optimized_model_file_name.insert(length - 5, "_opt"); + std::string optimized_file_path(test_folder + optimized_model_file_name); + PathString optimized_file_path_t(optimized_file_path.begin(), optimized_file_path.end()); + + // Dump the optimized model with external data so that it will unpack the external data from the loaded model + so.SetOptimizedModelFilePath(optimized_file_path_t.c_str()); + + // set the model external file folder path + so.AddConfigEntry(kOrtSessionOptionsModelExternalInitializersFileFolderPath, test_folder.c_str()); + + std::string opt_bin_file_name(optimized_model_file_name); + opt_bin_file_name.replace(optimized_model_file_name.length() - 4, 4, "bin"); + so.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL); + so.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersFileName, opt_bin_file_name.c_str()); + so.AddConfigEntry(kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes, "10"); + + Ort::Session session(*ort_env.get(), buffer.data(), buffer.size(), so); + + std::string generated_bin_path = test_folder + opt_bin_file_name; + std::vector generated_bin_buffer; + ReadFileToBuffer(generated_bin_path.c_str(), generated_bin_buffer); + + ASSERT_EQ(external_bin_buffer, generated_bin_buffer); + + // Cleanup. + ASSERT_EQ(std::remove(optimized_file_path.c_str()), 0); + ASSERT_EQ(std::remove(generated_bin_path.c_str()), 0); +} + #ifndef _WIN32 struct FileDescriptorTraits { using Handle = int;