Fix cache/output length calculation #738

Merged 1 commit on Nov 23, 2023
26 changes: 13 additions & 13 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -207,7 +207,6 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
         auto& seq = *state.sequences[idx];
 
         if (int step = r->inputs[rank_].getVal<int>("step", -1); step >= 0) {
-            /// TODO: revise step setting
             if (step <= seq.tokens.size()) {
                 seq.tokens.resize(step);
                 seq.cache_len = std::min(seq.cache_len, step);
@@ -1258,7 +1257,17 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>

     check_cuda_error(cudaStreamSynchronize(stream_));
 
-    // invariant: context_length = sequence_length + 1
+    // `SequenceManager` needs the real-time value of the cache length
+    // ! Must be done before incrementing `h_context_length`, because the generated token is NOT kv-cached yet
+    for (int i = 0; i < batch_size; ++i) {
+        if (state_->requests[i]) {
+            FT_CHECK(state_->sequences[i]);
+            state_->sequences[i]->cache_len = state_->h_context_length[i];
+        }
+    }
+
+    // invariant: context_length = sequence_length + 1, so that `h_context_length` includes all tokens (including the
+    // one just generated)
     for (int i = 0; i < batch_size; ++i) {
         ++state_->h_context_length[i];
     }
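
The ordering of the two loops above is the crux of the fix: `cache_len` must be captured from `h_context_length` before the increment, because the token sampled in this step has not been written to the KV cache yet. A minimal, self-contained sketch of that bookkeeping (the `Slot` struct and the concrete numbers are illustrative stand-ins, not TurboMind's actual types):

    #include <cassert>
    #include <vector>

    // Hypothetical, simplified stand-in for the per-slot state touched in Finish().
    struct Slot {
        int cache_len;       // tokens currently held in the KV cache
        int context_length;  // host-side mirror of h_context_length[i]
    };

    int main()
    {
        // A sequence that entered this decode step with 10 tokens in the KV cache;
        // one more token has just been sampled but is not cached yet.
        std::vector<Slot> batch{{/*cache_len=*/10, /*context_length=*/10}};

        for (auto& s : batch) {
            // 1) Publish the real-time cache length first: the new token is NOT
            //    kv-cached, so cache_len must keep the pre-increment value.
            s.cache_len = s.context_length;
        }
        for (auto& s : batch) {
            // 2) Now count the generated token, so context_length covers every
            //    token produced so far (the stated invariant).
            ++s.context_length;
        }

        assert(batch[0].cache_len == 10);       // KV cache really holds 10 tokens
        assert(batch[0].context_length == 11);  // output length includes the new token
        return 0;
    }

Updating `cache_len` only after the increment, as the block removed further down used to do, would report one more cached token than actually exists, which is exactly what `SequenceManager` must not see.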
@@ -1267,7 +1276,7 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>
     int* output_ptr = h_output_ids_;
     for (int i = 0; i < batch_size; ++i) {
         if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
-            const int count = state_->h_context_length[i] - 1 + int(g.step != g.max_init_ctx_len);
+            const int count = state_->h_context_length[i];
             // TODO: sync history output tokens when receiving the request and copy only the last token here
             std::copy(output_ptr, output_ptr + count, h_request_output_ids_ptrs_[i]);
             *h_request_seqlen_ptrs_[i] = count;
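
With the invariant established above, the output length is simply `h_context_length[i]`; the old `- 1 + int(g.step != g.max_init_ctx_len)` correction is no longer needed because the increment already folded in the freshly generated token. A worked example with assumed numbers (8 prompt tokens, 2 tokens generated in earlier steps):

    #include <cassert>

    int main()
    {
        // Assumed values: 8 prompt tokens plus 2 tokens generated in earlier steps
        // have been run through the model, so their KV entries exist.
        int h_context_length = 8 + 2;

        // During Finish(): publish the cache length, then count the token that was
        // just sampled from the last forward pass.
        const int cache_len = h_context_length;  // 10 tokens actually in the KV cache
        ++h_context_length;                      // 11, now includes the sampled token

        // Output length reported back to the request is the full context length.
        const int count = h_context_length;
        assert(count == 11);
        assert(count == cache_len + 1);  // exactly one token is not kv-cached yet
        return 0;
    }

The same reasoning removes the ad-hoc `+ 1 - static_cast<int>(force_stop)` adjustment in `Interrupt` in the last hunk below.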
@@ -1284,14 +1293,6 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>
         TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
     }
 
-    // `SequenceManager` needs real-time value of cache length
-    for (int i = 0; i < batch_size; ++i) {
-        if (state_->requests[i]) {
-            FT_CHECK(state_->sequences[i]);
-            state_->sequences[i]->cache_len = state_->h_context_length[i];
-        }
-    }
-
     std::vector<Signal> signals;
     {
         NvtxScope _("stream_and_completion_signal");
@@ -1343,8 +1344,7 @@ auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Signal
         FT_CHECK(sequence_manager_->Erase(state_->requests[index]->id));
     }
     else {
-        // Account for the last generated token if not a stop request (which doesn't generate)
-        const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(force_stop);
+        const int output_len = state_->h_context_length[index];
         auto& seq = *state_->sequences[index];
 
         // Update token IDs