diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index 995f15b710..5d8d7d0411 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -30,6 +30,9 @@ void LlamaBatch::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r
     auto invalidate = [](const char* type, std::shared_ptr<Request>& req, int ec) {
         TM_LOG_WARNING("[verifyRequests] Skipping invalid %s request for id %ld, code = %d",
                        type, (long)req->id, ec);
+        // We don't need a barrier here because this lambda
+        // is called only for new requests, which are visible
+        // only to the rank = 0 thread.
         req->signal.set_value(ec);
         req.reset();
     };
@@ -139,6 +142,12 @@ void LlamaBatch::handleStopRequests(const std::vector<std::shared_ptr<Request>>
             check_cuda_error(cudaMemsetAsync(/* … */, 0, sizeof(int), stream_));
             check_cuda_error(cudaStreamSynchronize(stream_));
         }
+
+        // Once the signal is set, the thread in LlamaV2::forward can
+        // return and free the request's input/output tensors. So we
+        // must make sure that no thread from LlamaV2::internalThreadEntry
+        // is still accessing those tensors before setting the signal.
+        llama_->shared_state_->barrier->wait();
         if (rank_ == 0) {
             r->signal.set_value(ec);
         }
@@ -1112,6 +1121,11 @@ void LlamaBatch::finishRequest(int index, bool force_end)
         llama_->kv_cache_mgr_->update(cached_seq_[index], stream_);
     }
 
+    // Once the signal is set, the thread in LlamaV2::forward can
+    // return and free the request's input/output tensors. So we
+    // must make sure that no thread from LlamaV2::internalThreadEntry
+    // is still accessing those tensors before setting the signal.
+    llama_->shared_state_->barrier->wait();
     if (rank_ == 0) {
         requests_[index]->signal.set_value(0);
     }
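
The race the two `barrier->wait()` calls guard against is easy to reproduce in isolation: fulfilling the request's promise lets the caller of `LlamaV2::forward` free the input/output tensors, so every rank must first confirm it has stopped touching them. Below is a minimal, self-contained sketch of that handshake, not the TurboMind implementation: the `Barrier` class, the `signal` promise and the `worker` thread are stand-ins invented for the example, mimicking the roles of `shared_state_->barrier`, `Request::signal` and the non-zero-rank threads running `LlamaV2::internalThreadEntry`.

```cpp
// Sketch of "barrier before set_value": the promise must not be fulfilled
// until every worker has finished reading the shared buffers, because the
// waiter frees them immediately afterwards.
#include <condition_variable>
#include <cstdio>
#include <future>
#include <mutex>
#include <thread>
#include <vector>

// Simple reusable barrier for the demo (TurboMind has its own implementation).
class Barrier {
public:
    explicit Barrier(int count): threshold_(count), count_(count) {}
    void wait()
    {
        std::unique_lock<std::mutex> lock(mutex_);
        const int gen = generation_;
        if (--count_ == 0) {
            ++generation_;
            count_ = threshold_;
            cv_.notify_all();
        }
        else {
            cv_.wait(lock, [&] { return gen != generation_; });
        }
    }
private:
    std::mutex              mutex_;
    std::condition_variable cv_;
    int                     threshold_;
    int                     count_;
    int                     generation_{0};
};

int main()
{
    constexpr int world_size = 2;
    Barrier       barrier(world_size);

    std::vector<int>  output_tensor(16, 0);  // stands in for the request's output tensors
    std::promise<int> signal;                // stands in for Request::signal
    auto              done = signal.get_future();

    // "rank > 0" worker: still reading the output tensor while the request finishes.
    std::thread worker([&] {
        int sum = 0;
        for (int v : output_tensor) {
            sum += v;
        }
        std::printf("worker read sum = %d\n", sum);
        barrier.wait();  // announces "done touching the tensors"
    });

    // "rank 0": must not fulfill the promise before every rank passed the barrier,
    // because the waiter (think LlamaV2::forward) frees the tensors right after.
    barrier.wait();       // corresponds to llama_->shared_state_->barrier->wait()
    signal.set_value(0);  // corresponds to requests_[index]->signal.set_value(0)

    done.get();                  // the client returns from forward() ...
    output_tensor.clear();       // ... and frees its buffers; safe, nobody reads them anymore
    output_tensor.shrink_to_fit();

    worker.join();
    return 0;
}
```

With the `barrier.wait()` in place, `set_value()` can only run after every rank has finished with the buffers; dropping it would let the buffers be freed while a worker may still be reading them, which is exactly the window the patch closes. `verifyRequests` skips the barrier because the requests it rejects are new and have only ever been visible to the rank 0 thread.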