Skip to content

Commit

Permalink
Move lm model to async infer
Browse files: browse the repository at this point in the history
  • Loading branch information
sbalandi committed Dec 23, 2024
1 parent 3496d45 commit a4b9743
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/cpp/src/lm_encoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
}
}
}
};

// free non-running requests
auto free_non_running_requests = [&streamer_ptr, &generations, &active_sequence_groups]() {
auto removed_it = std::remove_if(active_sequence_groups.begin(), active_sequence_groups.end(),
[](SequenceGroup::Ptr sg) -> bool {
return sg->has_finished() || sg->out_of_memory() || sg->handle_dropped();
Expand Down Expand Up @@ -130,7 +131,7 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
beam_offets.insert({sequence_groups.at(i)->get_request_id(), i});

SamplerOutput sampler_output = sampler.sample(sequence_groups, logits);
stream_generated_tokens();
free_non_running_requests();

// "Generation" phase

Expand Down Expand Up @@ -194,7 +195,12 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()});

const auto infer_start = std::chrono::steady_clock::now();
m_llm.infer();
m_llm.start_async();

stream_generated_tokens();

m_llm.wait();

const auto infer_end = std::chrono::steady_clock::now();
const auto infer_ms = PerfMetrics::get_microsec(infer_end - infer_start);
raw_perf_counters.m_inference_durations[0] += MicroSeconds(infer_ms);
Expand All @@ -203,9 +209,10 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
raw_perf_counters.m_batch_sizes.emplace_back(batch_size);

sampler_output = sampler.sample(active_sequence_groups, m_llm.get_tensor("logits"));
stream_generated_tokens();
free_non_running_requests();
}

stream_generated_tokens();
if (streamer_ptr) { // push streamer's cache
streamer_ptr->end();
}
Expand Down

0 comments on commit a4b9743

Please sign in to comment.