Commit

fix race condition (#460)
akhoroshev authored Sep 26, 2023
1 parent 327deae commit a54e3e0
Showing 1 changed file with 14 additions and 0 deletions.
src/turbomind/models/llama/LlamaBatch.cc
@@ -30,6 +30,9 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r

     auto invalidate = [](const char* type, std::shared_ptr<Request>& req, int ec) {
         TM_LOG_WARNING("[verifyRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec);
+        // We don't need a barrier here because this lambda is called
+        // only for new requests, which are visible only to the rank 0
+        // thread.
         req->signal.set_value(ec);
         req.reset();
     };
@@ -139,6 +142,12 @@ void LlamaBatch<T>::handleStopRequests(const std::vector<std::shared_ptr<Request
         check_cuda_error(cudaMemsetAsync(sequence_length.getPtr<int>(), 0, sizeof(int), stream_));
         check_cuda_error(cudaStreamSynchronize(stream_));
     }
+
+    // Once the signal is set, the threads in LlamaV2::forward may exit
+    // and free the input/output tensors. Therefore we must make sure
+    // that no thread from LlamaV2::internalThreadEntry is still
+    // accessing those tensors.
+    llama_->shared_state_->barrier->wait();
     if (rank_ == 0) {
         r->signal.set_value(ec);
     }
@@ -1112,6 +1121,11 @@ void LlamaBatch<T>::finishRequest(int index, bool force_end)
         llama_->kv_cache_mgr_->update(cached_seq_[index], stream_);
     }
 
+    // Once the signal is set, the threads in LlamaV2::forward may exit
+    // and free the input/output tensors. Therefore we must make sure
+    // that no thread from LlamaV2::internalThreadEntry is still
+    // accessing those tensors.
+    llama_->shared_state_->barrier->wait();
     if (rank_ == 0) {
         requests_[index]->signal.set_value(0);
     }
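The synchronization pattern this commit applies, as a minimal standalone sketch: every tensor-parallel rank rendezvouses at a barrier once it is done reading a request's tensors, and only then does rank 0 fulfill the promise that lets the waiting client free them. The sketch is illustrative, not TurboMind's implementation: it assumes C++20 std::barrier and std::promise in place of TurboMind's Barrier class and Request::signal, and world_size / rank are hypothetical stand-ins for the tensor-parallel group.

// Sketch only: std::barrier stands in for TurboMind's Barrier, and the
// promise stands in for Request::signal. All ranks reach the barrier
// before rank 0 signals, so no rank can still be using the tensors when
// the client frees them.
#include <barrier>
#include <cstdio>
#include <future>
#include <thread>
#include <vector>

int main() {
    constexpr int world_size = 4;       // hypothetical TP group size
    std::barrier      sync(world_size);
    std::promise<int> signal;           // stands in for req->signal
    auto done = signal.get_future();

    std::vector<std::jthread> ranks;
    for (int rank = 0; rank < world_size; ++rank) {
        ranks.emplace_back([&, rank] {
            // ... each rank reads the request's input/output tensors ...
            std::printf("rank %d done with the tensors\n", rank);

            sync.arrive_and_wait();     // nobody passes until ALL ranks
                                        // are done with the tensors
            if (rank == 0) {
                signal.set_value(0);    // safe: tensors are unused now
            }
        });
    }

    done.wait();  // after this, the client may free the tensors
    return 0;
}

Without the barrier, rank 0 could set the promise while other ranks were still mid-read, which is exactly the race the commit fixes. verifyRequests needs no barrier because only the rank 0 thread has seen those new requests yet.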
