diff --git a/src/generators.cpp b/src/generators.cpp index fe005111a..237bc4c0e 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -15,13 +15,37 @@ static bool _ = (Ort::InitApi(), false); OrtGlobals::OrtGlobals() : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {} -std::unique_ptr& GetOrtGlobals() { +// Ensure Shutdown() has been called before process exit +struct ValidateShutdown { + ~ValidateShutdown() { + if (GetOrtGlobals()) { + std::cerr << "OGA Error: Shutdown must be called before process exit, please check the documentation for the proper API to call to ensure clean shutdown." << std::endl; + std::abort(); + } + } +}; + +std::unique_ptr& +GetOrtGlobals() { static auto globals = std::make_unique(); + static auto validate = std::make_unique(); // Must be after the above line so the destructor runs before the above destructor return globals; } +// Used by Shutdown() to display the counts and types of any leaked objects +template +bool LeakTypeList::Dump() { + ((LeakChecked::Count() != 0 ? std::cerr << "OGA Error: " << LeakChecked::Count() << " instances of " << typeid(Types).name() << " were leaked." << std::endl : std::cerr), ...); + return ((LeakChecked::Count() != 0) || ...); +} + void Shutdown() { - GetOrtGlobals().reset(); + if (LeakTypes::Dump()) { + std::cerr << " Please see the documentation for the API being used to ensure proper cleanup." << std::endl; + std::abort(); + } + + GetOrtGlobals().reset(); // Delete now because on process exit is too late } OrtEnv& GetOrtEnv() { diff --git a/src/generators.h b/src/generators.h index 7a8f08951..634e518b0 100644 --- a/src/generators.h +++ b/src/generators.h @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. - +// Licensed under the MIT License. #pragma once -// Licensed under the MIT License. 
#include #include #include +#include #include #include #include "filesystem.h" @@ -31,6 +31,7 @@ using cudaStream_t = void*; #endif +#include "leakcheck.h" #include "smartptrs.h" #include "models/onnxruntime_api.h" #include "models/debugging.h" @@ -55,7 +56,7 @@ enum struct DeviceType { std::string to_string(DeviceType device_type); -struct GeneratorParams : std::enable_shared_from_this { +struct GeneratorParams : std::enable_shared_from_this, LeakChecked { GeneratorParams() = default; // This constructor is only used if doing a custom model handler vs built-in GeneratorParams(const Model& model); @@ -125,7 +126,7 @@ struct GeneratorParams : std::enable_shared_from_this { // The model outlives the GeneratorParams }; -struct Generator { +struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone() const; diff --git a/src/leakcheck.h b/src/leakcheck.h new file mode 100644 index 000000000..b71161d9f --- /dev/null +++ b/src/leakcheck.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file will track the number of instances of each type that are created and destroyed. This is useful for +// debugging memory leaks. To use this, just add the type to the LeakTypeList in this file. Then have that type +// inherit from LeakChecked<(itself)>. +// +// Shutdown() calls LeakTypeList::Dump(), which prints the count of each tracked type that has leaked, and aborts if any were found. On process exit, ValidateShutdown only verifies that Shutdown() was called.
+ +namespace Generators { +struct GeneratorParams; +struct Generator; +struct Model; +struct Search; +struct Tensor; +struct Tokenizer; +struct TokenizerStream; + +template +struct LeakTypeList { + template + static constexpr bool is_tracked = (std::is_same_v || ...); + static bool Dump(); +}; + +using LeakTypes = LeakTypeList; + +template +struct LeakChecked { + static_assert(LeakTypes::is_tracked, "Please add this type to 'TrackedTypes' above"); + + LeakChecked() { ++count_; } + ~LeakChecked() { --count_; } + + static int Count() { return count_; } + + private: + static inline std::atomic count_; +}; + +} // namespace Generators \ No newline at end of file diff --git a/src/models/model.h b/src/models/model.h index f52f56499..552fa4b94 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -48,7 +48,7 @@ struct State { int current_batch_size_{0}; }; -struct TokenizerStream { +struct TokenizerStream : LeakChecked { TokenizerStream(const Tokenizer& tokenizer); const std::string& Decode(int32_t token); @@ -63,7 +63,7 @@ struct TokenizerStream { // Sequence length is vector.size()/count std::vector PadInputs(std::span> sequences, int32_t pad_token_id); -struct Tokenizer : std::enable_shared_from_this { +struct Tokenizer : std::enable_shared_from_this, LeakChecked { Tokenizer(Config& config); std::unique_ptr CreateStream() const; @@ -105,7 +105,7 @@ struct SessionInfo { std::unordered_map inputs_, outputs_; }; -struct Model : std::enable_shared_from_this { +struct Model : std::enable_shared_from_this, LeakChecked { Model(std::unique_ptr config); virtual ~Model(); diff --git a/src/search.h b/src/search.h index 901cb437f..8eb194131 100644 --- a/src/search.h +++ b/src/search.h @@ -5,7 +5,7 @@ namespace Generators { -struct Search { +struct Search : LeakChecked { Search(const GeneratorParams& params) : params_{params.shared_from_this()} {} virtual ~Search() = default; diff --git a/src/tensor.h b/src/tensor.h index 6fcde20c9..25b6dd706 100644 --- a/src/tensor.h +++ 
b/src/tensor.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. namespace Generators { -struct Tensor : std::enable_shared_from_this { +struct Tensor : std::enable_shared_from_this, LeakChecked { Tensor() = default; Tensor(std::unique_ptr ort_tensor) : ort_tensor_{std::move(ort_tensor)} {}