Skip to content

Commit

Permalink
Fix memory leak (#488)
Browse files Browse the repository at this point in the history
* Fix memory leak

* modern c++
  • Loading branch information
lvhan028 authored Sep 26, 2023
1 parent 97dcdff commit 5d87c20
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/turbomind/models/llama/LlamaWeight.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ LlamaWeight<T>::~LlamaWeight()

pre_decoder_embedding_table = nullptr;
post_decoder_embedding_kernel = nullptr;

for (auto& p : decoder_layer_weights) {
delete p;
}
}

template<typename T>
Expand Down
6 changes: 3 additions & 3 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,13 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
cuda_device_prop_ptr.get());

return std::make_unique<LlamaTritonSharedModelInstance<T>>(
LlamaTritonSharedModelInstance<T>{std::move(llama),
shared_weights_[device_id],
std::move(allocator),
LlamaTritonSharedModelInstance<T>{std::move(allocator),
std::move(cublas_algo_map),
std::move(cublas_wrapper_mutex),
std::move(cublas_wrapper),
std::move(cuda_device_prop_ptr),
shared_weights_[device_id],
std::move(llama),
session_len_});
}

Expand Down
4 changes: 2 additions & 2 deletions src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ namespace ft = turbomind;

template<typename T>
struct LlamaTritonSharedModelInstance {
std::unique_ptr<ft::LlamaV2<T>> llm;
std::shared_ptr<ft::LlamaWeight<T>> llm_weight;
std::unique_ptr<ft::Allocator<ft::AllocatorType::CUDA>> allocator;
std::unique_ptr<ft::cublasAlgoMap> cublas_algo_map;
std::unique_ptr<std::mutex> cublas_wrapper_mutex;
std::unique_ptr<ft::cublasMMWrapper> cublas_wrapper;
std::unique_ptr<cudaDeviceProp> cuda_device_prop_ptr;
std::shared_ptr<ft::LlamaWeight<T>> llm_weight;
std::unique_ptr<ft::LlamaV2<T>> llm;
const int session_len;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ struct AbstractTransformerModel;
struct AbstractTransformerModelInstance;

struct AbstractTransformerModelInstance {
virtual ~AbstractTransformerModelInstance() = default;
virtual std::shared_ptr<std::vector<triton::Tensor>>
forward(std::shared_ptr<std::vector<triton::Tensor>> input_tensors) = 0;

Expand Down

0 comments on commit 5d87c20

Please sign in to comment.