Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for resource leaks on Shutdown() to help users #799

Merged
merged 10 commits into from
Aug 17, 2024
28 changes: 26 additions & 2 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,37 @@ static bool _ = (Ort::InitApi(), false);

OrtGlobals::OrtGlobals() : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {}

std::unique_ptr<OrtGlobals>& GetOrtGlobals() {
// Ensure Shutdown() has been called before process exit
struct ValidateShutdown {
~ValidateShutdown() {
if (GetOrtGlobals()) {
std::cerr << "OGA Error: Shutdown must be called before process exit, please check the documentation for the proper API to call to ensure clean shutdown." << std::endl;
std::abort();
}
}
};

std::unique_ptr<OrtGlobals>&
GetOrtGlobals() {
static auto globals = std::make_unique<OrtGlobals>();
static auto validate = std::make_unique<ValidateShutdown>(); // Must be after the above line so the destructor runs before the above destructor
return globals;
}

// Used by Shutdown() to display the counts and types of any leaked objects
template <typename... Types>
bool LeakTypeList<Types...>::Dump() {
((LeakChecked<Types>::Count() != 0 ? std::cerr << "OGA Error: " << LeakChecked<Types>::Count() << " instances of " << typeid(Types).name() << " were leaked." << std::endl : std::cerr), ...);
return ((LeakChecked<Types>::Count() != 0) || ...);
}

void Shutdown() {
GetOrtGlobals().reset();
if (LeakTypes::Dump()) {
std::cerr << " Please see the documentation for the API being used to ensure proper cleanup." << std::endl;
std::abort();
}

GetOrtGlobals().reset(); // Delete now because on process exit is too late
}

OrtEnv& GetOrtEnv() {
Expand Down
9 changes: 5 additions & 4 deletions src/generators.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

// Licensed under the MIT License.
#pragma once

// Licensed under the MIT License.
#include <algorithm>
#include <array>
#include <assert.h>
#include <atomic>
#include <cmath>
#include <cstring>
#include "filesystem.h"
Expand All @@ -31,6 +31,7 @@
using cudaStream_t = void*;
#endif

#include "leakcheck.h"
#include "smartptrs.h"
#include "models/onnxruntime_api.h"
#include "models/debugging.h"
Expand All @@ -55,7 +56,7 @@ enum struct DeviceType {

std::string to_string(DeviceType device_type);

struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
struct GeneratorParams : std::enable_shared_from_this<GeneratorParams>, LeakChecked<GeneratorParams> {
GeneratorParams() = default; // This constructor is only used if doing a custom model handler vs built-in
GeneratorParams(const Model& model);

Expand Down Expand Up @@ -125,7 +126,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
// The model outlives the GeneratorParams
};

struct Generator {
struct Generator : LeakChecked<Generator> {
Generator(const Model& model, const GeneratorParams& params);

bool IsDone() const;
Expand Down
41 changes: 41 additions & 0 deletions src/leakcheck.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This file will track the number of instances of each type that are created and destroyed. This is useful for
// debugging memory leaks. To use this, just add the type to the LeakTypeList in this file. Then have that type
// inherit from LeakChecked<(itself)>.
//
// On process exit, ValidateShutdown() will call LeakTypeList::Dump() and print out any types that have leaked.

namespace Generators {
struct GeneratorParams;
struct Generator;
struct Model;
struct Search;
struct Tensor;
struct Tokenizer;
struct TokenizerStream;

template <typename... Types>
struct LeakTypeList {
template <typename T>
static constexpr bool is_tracked = (std::is_same_v<T, Types> || ...);
static bool Dump();
};

using LeakTypes = LeakTypeList<GeneratorParams, Generator, Model, Search, Tensor, Tokenizer, TokenizerStream>;

template <typename T>
struct LeakChecked {
static_assert(LeakTypes::is_tracked<T>, "Please add this type to 'TrackedTypes' above");

LeakChecked() { ++count_; }
~LeakChecked() { --count_; }

static int Count() { return count_; }

private:
static inline std::atomic<int> count_;
};

} // namespace Generators
6 changes: 3 additions & 3 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct State {
int current_batch_size_{0};
};

struct TokenizerStream {
struct TokenizerStream : LeakChecked<TokenizerStream> {
TokenizerStream(const Tokenizer& tokenizer);

const std::string& Decode(int32_t token);
Expand All @@ -63,7 +63,7 @@ struct TokenizerStream {
// Sequence length is vector.size()/count
std::vector<int32_t> PadInputs(std::span<std::span<const int32_t>> sequences, int32_t pad_token_id);

struct Tokenizer : std::enable_shared_from_this<Tokenizer> {
struct Tokenizer : std::enable_shared_from_this<Tokenizer>, LeakChecked<Tokenizer> {
Tokenizer(Config& config);

std::unique_ptr<TokenizerStream> CreateStream() const;
Expand Down Expand Up @@ -105,7 +105,7 @@ struct SessionInfo {
std::unordered_map<std::string, ONNXTensorElementDataType> inputs_, outputs_;
};

struct Model : std::enable_shared_from_this<Model> {
struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
Model(std::unique_ptr<Config> config);
virtual ~Model();

Expand Down
2 changes: 1 addition & 1 deletion src/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace Generators {

struct Search {
struct Search : LeakChecked<Search> {
Search(const GeneratorParams& params) : params_{params.shared_from_this()} {}
virtual ~Search() = default;

Expand Down
2 changes: 1 addition & 1 deletion src/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.
namespace Generators {

struct Tensor : std::enable_shared_from_this<Tensor> {
struct Tensor : std::enable_shared_from_this<Tensor>, LeakChecked<Tensor> {
Tensor() = default;
Tensor(std::unique_ptr<OrtValue> ort_tensor) : ort_tensor_{std::move(ort_tensor)} {}

Expand Down
Loading