From 27a445648fc020cbf2e80560532b96d3978ac37f Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Thu, 15 Aug 2024 10:42:48 -0700 Subject: [PATCH 01/10] Resource check test --- src/generators.cpp | 19 ++++++++++++++++++- src/generators.h | 14 ++++++++++++-- src/models/model.h | 2 +- test/c_api_tests.cpp | 2 ++ 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index fe005111a..d3be0c6f0 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -15,7 +15,24 @@ static bool _ = (Ort::InitApi(), false); OrtGlobals::OrtGlobals() : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {} -std::unique_ptr& GetOrtGlobals() { +std::atomic TrackedResource::count_{}; + +// Validate process exit conditions, as this is done atexit, we print errors to stderr and throw an exception to stop the process +void ValidateShutdown() { + if (GetOrtGlobals()) { + std::cerr << "Shutdown must be called before process exit, please check the documentation for the proper API to call to ensure clean shutdown." << std::endl; + std::abort(); + } + if (TrackedResource::Count()) { + std::cerr << "Resources leaked: " + std::to_string(TrackedResource::Count()) + " All Oga resources must be cleaned up before shutdown." << std::endl; + std::abort(); + } +} + +static bool _1 = (std::atexit(ValidateShutdown), false); // Call ValidateShutdown at exit + +std::unique_ptr& +GetOrtGlobals() { static auto globals = std::make_unique(); return globals; } diff --git a/src/generators.h b/src/generators.h index 7a8f08951..a6134515b 100644 --- a/src/generators.h +++ b/src/generators.h @@ -55,7 +55,17 @@ enum struct DeviceType { std::string to_string(DeviceType device_type); -struct GeneratorParams : std::enable_shared_from_this { +struct TrackedResource { + TrackedResource() { count_++; } + ~TrackedResource() { count_--; } + + static int Count() { return count_; } + + private: + static std::atomic count_; +}; + +struct GeneratorParams : std::enable_shared_from_this, TrackedResource { GeneratorParams() = default; // This constructor is only used if doing a custom model handler vs built-in GeneratorParams(const Model& model); @@ -125,7 +135,7 @@ struct GeneratorParams : std::enable_shared_from_this { // The model outlives the GeneratorParams }; -struct Generator { +struct Generator : TrackedResource { Generator(const Model& model, const GeneratorParams& params); bool IsDone() const; diff --git a/src/models/model.h b/src/models/model.h index f52f56499..6760fa300 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -105,7 +105,7 @@ struct SessionInfo { std::unordered_map inputs_, outputs_; }; -struct Model : std::enable_shared_from_this { +struct Model : std::enable_shared_from_this, TrackedResource { Model(std::unique_ptr config); virtual ~Model(); diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index eba5aff15..d49f96e5a 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -170,6 +170,8 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { const auto* expected_output_start = &expected_output[i * max_length]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); } + + model.release(); } #endif From a0688dca5a5aa5e08df8cdb05d471512816bed5b Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Thu, 15 Aug 2024 13:11:05 -0700 Subject: [PATCH 02/10] Linux fixes --- src/generators.cpp | 4 ++-- src/generators.h | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index d3be0c6f0..64fcef701 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -20,11 +20,11 @@ std::atomic TrackedResource::count_{}; // Validate process exit conditions, as this is done atexit, we print errors to stderr and throw an exception to stop the process void ValidateShutdown() { if (GetOrtGlobals()) { - std::cerr << "Shutdown must be called before process exit, please check the documentation for the proper API to call to ensure clean shutdown." << std::endl; + std::cerr << "OGA Error: Shutdown must be called before process exit, please check the documentation for the proper API to call to ensure clean shutdown." << std::endl; std::abort(); } if (TrackedResource::Count()) { - std::cerr << "Resources leaked: " + std::to_string(TrackedResource::Count()) + " All Oga resources must be cleaned up before shutdown." << std::endl; + std::cerr << "OGA Error: " + std::to_string(TrackedResource::Count()) + " resources leaked. All Oga resources must be cleaned up before shutdown." << std::endl; std::abort(); } } diff --git a/src/generators.h b/src/generators.h index a6134515b..913d0b4f3 100644 --- a/src/generators.h +++ b/src/generators.h @@ -56,10 +56,10 @@ enum struct DeviceType { std::string to_string(DeviceType device_type); struct TrackedResource { - TrackedResource() { count_++; } - ~TrackedResource() { count_--; } + TrackedResource() { ++count_; } + ~TrackedResource() { --count_; } - static int Count() { return count_; } + static int Count() { return count_.load(); } private: static std::atomic count_; From b8f21b56bd73b42f448c15ede3cfe8de3004dbb7 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Thu, 15 Aug 2024 13:35:15 -0700 Subject: [PATCH 03/10] Fix build error on Linux --- src/generators.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/generators.h b/src/generators.h index 913d0b4f3..ca88f7df0 100644 --- a/src/generators.h +++ b/src/generators.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include "filesystem.h" @@ -59,7 +60,7 @@ struct TrackedResource { TrackedResource() { ++count_; } ~TrackedResource() { --count_; } - static int Count() { return count_.load(); } + static int Count() { return count_; } private: static std::atomic count_; From e46d3c7fc682778240aaa3cfd7ad16b6c7a20273 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Thu, 15 Aug 2024 20:07:40 -0700 Subject: [PATCH 04/10] Say which type is leaked --- src/generators.cpp | 10 ++++++---- src/generators.h | 27 +++++++++++++++++++++------ src/models/model.h | 6 +++--- src/search.h | 2 +- 4 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 64fcef701..4b53a93cd 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -15,7 +15,11 @@ static bool _ = (Ort::InitApi(), false); OrtGlobals::OrtGlobals() : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {} -std::atomic TrackedResource::count_{}; +template +bool LeakTypeList::Dump() { + ((LeakChecked::Count() != 0 ? std::cerr << "OGA Error: " << LeakChecked::Count() << " instances of " << typeid(Types).name() << " were leaked." << std::endl : void()), ...); + return ((LeakChecked::Count() != 0) || ...); +} // Validate process exit conditions, as this is done atexit, we print errors to stderr and throw an exception to stop the process void ValidateShutdown() { @@ -23,10 +27,8 @@ void ValidateShutdown() { std::cerr << "OGA Error: Shutdown must be called before process exit, please check the documentation for the proper API to call to ensure clean shutdown." << std::endl; std::abort(); } - if (TrackedResource::Count()) { - std::cerr << "OGA Error: " + std::to_string(TrackedResource::Count()) + " resources leaked. All Oga resources must be cleaned up before shutdown." << std::endl; + if (LeakTypes::Dump()) std::abort(); - } } static bool _1 = (std::atexit(ValidateShutdown), false); // Call ValidateShutdown at exit diff --git a/src/generators.h b/src/generators.h index ca88f7df0..564b2ab23 100644 --- a/src/generators.h +++ b/src/generators.h @@ -40,10 +40,13 @@ using cudaStream_t = void*; #include "tensor.h" namespace Generators { +struct GeneratorParams; +struct Generator; struct Model; struct State; struct Search; struct Tokenizer; +struct TokenizerStream; // OgaSequences are a vector of int32 vectors using TokenSequences = std::vector>; @@ -56,17 +59,29 @@ enum struct DeviceType { std::string to_string(DeviceType device_type); -struct TrackedResource { - TrackedResource() { ++count_; } - ~TrackedResource() { --count_; } +template +struct LeakTypeList { + template + static constexpr bool is_tracked = (std::is_same_v || ...); + static bool Dump(); +}; + +using LeakTypes = LeakTypeList; + +template +struct LeakChecked { + static_assert(LeakTypes::is_tracked, "Please add this type to 'TrackedTypes' above"); + + LeakChecked() { ++count_; } + ~LeakChecked() { --count_; } static int Count() { return count_; } private: - static std::atomic count_; + static inline std::atomic count_; }; -struct GeneratorParams : std::enable_shared_from_this, TrackedResource { +struct GeneratorParams : std::enable_shared_from_this, LeakChecked { GeneratorParams() = default; // This constructor is only used if doing a custom model handler vs built-in GeneratorParams(const Model& model); @@ -136,7 +151,7 @@ struct GeneratorParams : std::enable_shared_from_this, TrackedR // The model outlives the GeneratorParams }; -struct Generator : TrackedResource { +struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone() const; diff --git a/src/models/model.h b/src/models/model.h index 6760fa300..552fa4b94 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -48,7 +48,7 @@ struct State { int current_batch_size_{0}; }; -struct TokenizerStream { +struct TokenizerStream : LeakChecked { TokenizerStream(const Tokenizer& tokenizer); const std::string& Decode(int32_t token); @@ -63,7 +63,7 @@ struct TokenizerStream { // Sequence length is vector.size()/count std::vector PadInputs(std::span> sequences, int32_t pad_token_id); -struct Tokenizer : std::enable_shared_from_this { +struct Tokenizer : std::enable_shared_from_this, LeakChecked { Tokenizer(Config& config); std::unique_ptr CreateStream() const; @@ -105,7 +105,7 @@ struct SessionInfo { std::unordered_map inputs_, outputs_; }; -struct Model : std::enable_shared_from_this, TrackedResource { +struct Model : std::enable_shared_from_this, LeakChecked { Model(std::unique_ptr config); virtual ~Model(); diff --git a/src/search.h b/src/search.h index 901cb437f..8eb194131 100644 --- a/src/search.h +++ b/src/search.h @@ -5,7 +5,7 @@ namespace Generators { -struct Search { +struct Search : LeakChecked { Search(const GeneratorParams& params) : params_{params.shared_from_this()} {} virtual ~Search() = default; From 71d4a3cf520380c1b50b5efd0dec35770f185ac9 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Thu, 15 Aug 2024 20:13:03 -0700 Subject: [PATCH 05/10] Add help message after leak check --- src/generators.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/generators.cpp b/src/generators.cpp index 4b53a93cd..0f0f27788 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -27,8 +27,10 @@ void ValidateShutdown() { std::cerr << "OGA Error: Shutdown must be called before process exit, please check the documentation for the proper API to call to ensure clean shutdown." << std::endl; std::abort(); } - if (LeakTypes::Dump()) + if (LeakTypes::Dump()) { + std::cerr << " Please see the documentation for the API being used to ensure proper cleanup." << std::endl; std::abort(); + } } static bool _1 = (std::atexit(ValidateShutdown), false); // Call ValidateShutdown at exit From 7ad9fcdbed27007d84d3027dc748a0b140bd9085 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Thu, 15 Aug 2024 20:34:15 -0700 Subject: [PATCH 06/10] Cleanup and rearrange files. --- src/generators.h | 29 ++--------------------------- src/leakcheck.h | 35 +++++++++++++++++++++++++++++++++++ src/tensor.h | 2 +- 3 files changed, 38 insertions(+), 28 deletions(-) create mode 100644 src/leakcheck.h diff --git a/src/generators.h b/src/generators.h index 564b2ab23..634e518b0 100644 --- a/src/generators.h +++ b/src/generators.h @@ -1,8 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. - +// Licensed under the MIT License. #pragma once -// Licensed under the MIT License. #include #include #include @@ -32,6 +31,7 @@ using cudaStream_t = void*; #endif +#include "leakcheck.h" #include "smartptrs.h" #include "models/onnxruntime_api.h" #include "models/debugging.h" @@ -40,13 +40,10 @@ using cudaStream_t = void*; #include "tensor.h" namespace Generators { -struct GeneratorParams; -struct Generator; struct Model; struct State; struct Search; struct Tokenizer; -struct TokenizerStream; // OgaSequences are a vector of int32 vectors using TokenSequences = std::vector>; @@ -59,28 +56,6 @@ enum struct DeviceType { std::string to_string(DeviceType device_type); -template -struct LeakTypeList { - template - static constexpr bool is_tracked = (std::is_same_v || ...); - static bool Dump(); -}; - -using LeakTypes = LeakTypeList; - -template -struct LeakChecked { - static_assert(LeakTypes::is_tracked, "Please add this type to 'TrackedTypes' above"); - - LeakChecked() { ++count_; } - ~LeakChecked() { --count_; } - - static int Count() { return count_; } - - private: - static inline std::atomic count_; -}; - struct GeneratorParams : std::enable_shared_from_this, LeakChecked { GeneratorParams() = default; // This constructor is only used if doing a custom model handler vs built-in GeneratorParams(const Model& model); diff --git a/src/leakcheck.h b/src/leakcheck.h new file mode 100644 index 000000000..27974843e --- /dev/null +++ b/src/leakcheck.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Generators { +struct GeneratorParams; +struct Generator; +struct Model; +struct Search; +struct Tensor; +struct Tokenizer; +struct TokenizerStream; + +template +struct LeakTypeList { + template + static constexpr bool is_tracked = (std::is_same_v || ...); + static bool Dump(); +}; + +using LeakTypes = LeakTypeList; + +template +struct LeakChecked { + static_assert(LeakTypes::is_tracked, "Please add this type to 'TrackedTypes' above"); + + LeakChecked() { ++count_; } + ~LeakChecked() { --count_; } + + static int Count() { return count_; } + + private: + static inline std::atomic count_; +}; + +} // namespace Generators \ No newline at end of file diff --git a/src/tensor.h b/src/tensor.h index 6fcde20c9..25b6dd706 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. namespace Generators { -struct Tensor : std::enable_shared_from_this { +struct Tensor : std::enable_shared_from_this, LeakChecked { Tensor() = default; Tensor(std::unique_ptr ort_tensor) : ort_tensor_{std::move(ort_tensor)} {} From 3256dcc88c24300452678c885637652266c36435 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Thu, 15 Aug 2024 20:35:40 -0700 Subject: [PATCH 07/10] Fix linux build error --- src/generators.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/generators.cpp b/src/generators.cpp index 0f0f27788..6137591e4 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -17,7 +17,7 @@ OrtGlobals::OrtGlobals() : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVE template bool LeakTypeList::Dump() { - ((LeakChecked::Count() != 0 ? std::cerr << "OGA Error: " << LeakChecked::Count() << " instances of " << typeid(Types).name() << " were leaked." << std::endl : void()), ...); + ((LeakChecked::Count() != 0 ? std::cerr << "OGA Error: " << LeakChecked::Count() << " instances of " << typeid(Types).name() << " were leaked." << std::endl : std::cerr), ...); return ((LeakChecked::Count() != 0) || ...); } From cba41f8f0b0458d42768a84b737d0d1fe4d6b2de Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Thu, 15 Aug 2024 21:49:11 -0700 Subject: [PATCH 08/10] Polish --- src/leakcheck.h | 8 +++++++- test/c_api_tests.cpp | 2 -- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/leakcheck.h b/src/leakcheck.h index 27974843e..b71161d9f 100644 --- a/src/leakcheck.h +++ b/src/leakcheck.h @@ -1,6 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// This file will track the number of instances of each type that are created and destroyed. This is useful for +// debugging memory leaks. To use this, just add the type to the LeakTypeList in this file. Then have that type +// inherit from LeakChecked<(itself)>. +// +// On process exit, ValidateShutdown() will call LeakTypeList::Dump() and print out any types that have leaked. + namespace Generators { struct GeneratorParams; struct Generator; @@ -32,4 +38,4 @@ struct LeakChecked { static inline std::atomic count_; }; -} // namespace Generators \ No newline at end of file +} // namespace Generators \ No newline at end of file diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index d49f96e5a..eba5aff15 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -170,8 +170,6 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { const auto* expected_output_start = &expected_output[i * max_length]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); } - - model.release(); } #endif From 3d24c3f214db5895cceb5d93ada8413629c02ac4 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Thu, 15 Aug 2024 22:40:57 -0700 Subject: [PATCH 09/10] Move leak check to Shutdown() --- src/generators.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 6137591e4..c6367435b 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -27,10 +27,6 @@ void ValidateShutdown() { std::cerr << "OGA Error: Shutdown must be called before process exit, please check the documentation for the proper API to call to ensure clean shutdown." << std::endl; std::abort(); } - if (LeakTypes::Dump()) { - std::cerr << " Please see the documentation for the API being used to ensure proper cleanup." << std::endl; - std::abort(); - } } static bool _1 = (std::atexit(ValidateShutdown), false); // Call ValidateShutdown at exit @@ -42,7 +38,16 @@ GetOrtGlobals() { } void Shutdown() { - GetOrtGlobals().reset(); + auto& globals = GetOrtGlobals(); + if (!globals) + return; + + if (LeakTypes::Dump()) { + std::cerr << " Please see the documentation for the API being used to ensure proper cleanup." << std::endl; + std::abort(); + } + + globals.reset(); // Delete now because on process exit is too late } OrtEnv& GetOrtEnv() { From fa1f947d9e79f85b87d22e4172bc18668d711740 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Fri, 16 Aug 2024 12:39:19 -0700 Subject: [PATCH 10/10] Small cleanup to be more consistent. --- src/generators.cpp | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index c6367435b..237bc4c0e 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -15,39 +15,37 @@ static bool _ = (Ort::InitApi(), false); OrtGlobals::OrtGlobals() : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {} -template -bool LeakTypeList::Dump() { - ((LeakChecked::Count() != 0 ? std::cerr << "OGA Error: " << LeakChecked::Count() << " instances of " << typeid(Types).name() << " were leaked." << std::endl : std::cerr), ...); - return ((LeakChecked::Count() != 0) || ...); -} - -// Validate process exit conditions, as this is done atexit, we print errors to stderr and throw an exception to stop the process -void ValidateShutdown() { - if (GetOrtGlobals()) { - std::cerr << "OGA Error: Shutdown must be called before process exit, please check the documentation for the proper API to call to ensure clean shutdown." << std::endl; - std::abort(); +// Ensure Shutdown() has been called before process exit +struct ValidateShutdown { + ~ValidateShutdown() { + if (GetOrtGlobals()) { + std::cerr << "OGA Error: Shutdown must be called before process exit, please check the documentation for the proper API to call to ensure clean shutdown." << std::endl; + std::abort(); + } } -} - -static bool _1 = (std::atexit(ValidateShutdown), false); // Call ValidateShutdown at exit +}; std::unique_ptr& GetOrtGlobals() { static auto globals = std::make_unique(); + static auto validate = std::make_unique(); // Must be after the above line so the destructor runs before the above destructor return globals; } -void Shutdown() { - auto& globals = GetOrtGlobals(); - if (!globals) - return; +// Used by Shutdown() to display the counts and types of any leaked objects +template +bool LeakTypeList::Dump() { + ((LeakChecked::Count() != 0 ? std::cerr << "OGA Error: " << LeakChecked::Count() << " instances of " << typeid(Types).name() << " were leaked." << std::endl : std::cerr), ...); + return ((LeakChecked::Count() != 0) || ...); +} +void Shutdown() { if (LeakTypes::Dump()) { std::cerr << " Please see the documentation for the API being used to ensure proper cleanup." << std::endl; std::abort(); } - globals.reset(); // Delete now because on process exit is too late + GetOrtGlobals().reset(); // Delete now because on process exit is too late } OrtEnv& GetOrtEnv() {