
Commit 04d9728

[LLM/VLM] Stop generation when streaming callback returns true (#1410)
Affects only the stateful VLM and LLM pipelines and the continuous batching (CB) pipeline. The speculative decoding (SD) implementation should be fixed separately, since there both pipelines have to be aborted when an exception is thrown or generation is cancelled via the streaming callback.
1 parent 4d18f8b commit 04d9728
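
For context, the contract this commit enforces is that the streaming callback returns a bool, and returning true asks the pipeline to stop generating. Below is a minimal usage sketch against the C++ ov::genai::LLMPipeline API; the model directory, prompt, and the 20-chunk cutoff are illustrative assumptions, not part of this commit.

#include <functional>
#include <iostream>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Placeholder model directory and device for this sketch.
    ov::genai::LLMPipeline pipe("TinyLlama-1.1B-Chat-v1.0", "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 256;

    size_t printed_chunks = 0;
    // The callback receives decoded text chunks; returning true requests cancellation.
    std::function<bool(std::string)> streamer = [&printed_chunks](std::string subword) {
        std::cout << subword << std::flush;
        return ++printed_chunks >= 20;  // illustrative early-stop condition
    };

    // With this change, returning true above actually stops the stateful pipeline
    // instead of being ignored.
    pipe.generate("Why is the sky blue?", config, streamer);
    return 0;
}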

File tree

7 files changed: 113 additions, 114 deletions

src/cpp/src/continuous_batching_impl.cpp

Lines changed: 50 additions & 31 deletions

@@ -22,7 +22,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
     m_tokenizer = tokenizer;
     m_generation_config = generation_config;
     m_is_validation_mode_enabled = is_validation_mode_enabled;
-    
+
     ov::Core core;

     auto [core_properties, compile_properties] = utils::split_core_compile_config(properties);
@@ -255,18 +255,6 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
         }
     }, streamer);

-    OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()),
-        "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding");
-
-    std::vector<GenerationHandle> generations;
-    for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) {
-        OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch.");
-        generations.push_back(add_request(request_id, input_ids[request_id], sampling_params[request_id]));
-    }
-
-    std::vector<EncodedGenerationResult> results;
-    results.reserve(m_awaiting_requests.size());
-
     auto drop_requests = [&] () {
         for (const std::shared_ptr<ov::genai::SequenceGroup> request : m_requests) {
             for (const auto& sequence: request->get_sequences()) {
@@ -279,25 +267,40 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
        m_requests.clear();
    };

+    OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && sampling_params[0].num_return_sequences == 1 &&
+        (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()),
+        "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding");
+
+    std::vector<GenerationHandle> generations;
+    for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) {
+        OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch.");
+        generations.push_back(add_request(request_id, input_ids[request_id], sampling_params[request_id]));
+    }
+    auto all_requests = m_awaiting_requests; // we need to store all requests to get results from them once generation has finished
+
    bool continue_generation = true;
    while (has_non_finished_requests() && continue_generation) {
        try {
            step();
        } catch (...) {
-            drop_requests();
+            drop_requests(); // remove all requests from pipeline state in case of exception
            throw;
        }
-        if (streamer_ptr && generations.at(0)->can_read()) {
-            std::unordered_map<uint64_t, GenerationOutput> token = generations.at(0).get()->back();
+
+        auto & generation = generations.at(0);
+        if (streamer_ptr && generation->can_read()) {
+            std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
            for (const auto& gen_token : token.begin()->second.generated_ids) {
-                if (!streamer_ptr->put(gen_token)) {
+                continue_generation = !streamer_ptr->put(gen_token);
+                if (!continue_generation) {
+                    generation->drop();
                    break;
                }
            }
        }
    }

-    if (streamer_ptr) {
+    if (streamer_ptr) { // push streamer's cache
        streamer_ptr->end();
    }

@@ -307,16 +310,32 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
        OPENVINO_ASSERT(m_requests.empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
    }

-    for (size_t generation_idx = 0; generation_idx < generations.size(); ++generation_idx) {
-        const auto& generation = generations[generation_idx];
+    std::vector<EncodedGenerationResult> results;
+    results.reserve(all_requests.size());
+
+    for (size_t request_id = 0; request_id < all_requests.size(); ++request_id) {
+        const auto& request = all_requests[request_id];
+        auto sampling_params = request->get_sampling_parameters();
+        const auto& sequences = request->get_finished_sequences();
+        size_t num_outputs = std::min(sampling_params.num_return_sequences, sequences.size());
+
        EncodedGenerationResult result;
-        result.m_request_id = 1;
-        std::vector<GenerationOutput> generation_outputs = generation->read_all();
-        for (const auto& generation_output : generation_outputs) {
-            result.m_generation_ids.push_back(std::move(generation_output.generated_ids));
-            result.m_scores.push_back(generation_output.score);
+        result.m_request_id = request_id;
+        result.m_generation_ids.resize(num_outputs);
+        result.m_scores.resize(num_outputs);
+
+        for (size_t i = 0; i < num_outputs; ++i) {
+            const auto & sequence = sequences[i];
+            const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_probs();
+            const auto & generated_ids = sequence->get_generated_ids();
+
+            if (sampling_params.echo)
+                result.m_generation_ids[i] = request->get_prompt_ids();
+            std::copy(generated_ids.begin(), generated_ids.end(), std::back_inserter(result.m_generation_ids[i]));
+            result.m_scores[i] = score;
        }
-        result.m_status = generation->get_status();
+
+        result.m_status = generations[request_id]->get_status();
        results.push_back(std::move(result));
    }

@@ -408,7 +427,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
    for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
        SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
        // requests not scheduled, in decoding phase or not echoing are not processed
-        if (!sequence_group->is_scheduled() || sequence_group->get_context_len() > sequence_group->get_prompt_len() || 
+        if (!sequence_group->is_scheduled() || sequence_group->get_context_len() > sequence_group->get_prompt_len() ||
            !sequence_group->get_sampling_parameters().echo)
            continue;

@@ -421,10 +440,10 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(

        size_t num_prompt_tokens_processed = sequence_group->get_num_processed_tokens();
        OPENVINO_ASSERT(num_prompt_tokens_processed + actual_seq_len <= sequence_group->get_prompt_len());
-        
+
        // if we processed the whole prompt we don't include last logprob as it will be processed by the sampler (it's already completion)
-        // otherwise we include it as it will be used in the next part of the prompt 
-        int exclude_last_logprob = 1; 
+        // otherwise we include it as it will be used in the next part of the prompt
+        int exclude_last_logprob = 1;
        if (num_prompt_tokens_processed + actual_seq_len < sequence_group->get_prompt_len())
            exclude_last_logprob = 0;

@@ -435,7 +454,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
        for (int token_logits_offset = 0, token_id_offset = num_prompt_tokens_processed + 1;
             token_logits_offset < actual_seq_len - exclude_last_logprob;
             token_logits_offset++, token_id_offset++) {
-            
+
            const float* token_logits = (sequence_group_logits_data + token_logits_offset * vocab_size);
            int64_t token_id = sequence_group->get_prompt_ids()[token_id_offset];
            float token_logit = token_logits[token_id];
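
The streaming loop above now treats a true return from streamer_ptr->put() as a cancellation request and drops the generation handle, so the request finishes with GenerationStatus::DROPPED_BY_HANDLE rather than running to completion. A custom streamer that exercises this path could look like the sketch below; StopAfterN and its token limit are hypothetical, while put()/end() are the ov::genai::StreamerBase interface used in this diff.

#include <cstddef>
#include <cstdint>

#include "openvino/genai/streamer_base.hpp"

// Hypothetical streamer that cancels generation after a fixed number of tokens.
class StopAfterN : public ov::genai::StreamerBase {
public:
    explicit StopAfterN(size_t limit) : m_limit(limit) {}

    bool put(int64_t token) override {
        // Returning true asks the pipeline to stop; the generation handle is then dropped.
        return ++m_count >= m_limit;
    }

    void end() override {
        // Called once after generation so the streamer can flush any cached text.
    }

private:
    size_t m_limit;
    size_t m_count = 0;
};

An instance would be passed as std::make_shared<StopAfterN>(32) wherever a streamer is accepted, for example as the streamer argument of LLMPipeline::generate().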

src/cpp/src/generation_handle.cpp

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ GenerationStatus GenerationHandleImpl::get_status() {
 }

 bool GenerationHandleImpl::can_read() {
-    return !is_dropped() && m_generation_stream->can_read(); 
+    return !is_dropped() && m_generation_stream->can_read();
 }

 bool GenerationHandleImpl::is_dropped() {

src/cpp/src/generation_stream.hpp

Lines changed: 2 additions & 3 deletions

@@ -14,8 +14,6 @@ class GenerationStream {
     GenerationStatus m_status = GenerationStatus::RUNNING;
     SynchronizedQueue<GenerationOutputs> m_output_queue;

-    std::vector<uint64_t> last_sequence_ids;
-
 public:
     using Ptr = std::shared_ptr<GenerationStream>;

@@ -30,10 +28,11 @@ class GenerationStream {
         m_output_queue.push(std::move(outputs));
     }

-    // Retrieving vector of pairs <sequence_id, token_id> as we can generate multiple outputs for a single prompt
+    // Retrieving vector of pairs <sequence_id, token_ids> as we can generate multiple outputs for a single prompt
     GenerationOutputs back() {
         return m_output_queue.back();
     }
+
     GenerationOutputs read() {
         return m_output_queue.pull();
     }

src/cpp/src/llm_pipeline.cpp

Lines changed: 5 additions & 10 deletions

@@ -284,10 +284,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     }

     auto batch_size = input_ids.get_shape().at(0);
-    if ((batch_size != 1 || !(config.is_greedy_decoding() || config.is_multinomial())) && streamer_ptr) {
-        OPENVINO_THROW("Currently streaming is possible only with batch size=1 and "
-                        "only for greedy or multinomial decoding");
-    }
+    OPENVINO_ASSERT(streamer_ptr == nullptr || batch_size == 1 && config.num_return_sequences == 1 &&
+        (config.is_greedy_decoding() || config.is_multinomial()),
+        "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding");

     auto num_inputs = m_model_runner.get_compiled_model().inputs().size();
     OPENVINO_ASSERT(num_inputs == 4 || num_inputs == 3, "Model should have 3 or 4 inputs: "
@@ -587,9 +586,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
     std::vector<std::string> plain_replies;
     std::vector<float> plain_scores;
     for (GenerationResult& res : generated) {
-        if (GenerationStatus::FINISHED != res.m_status) {
-            OPENVINO_THROW("Got unfinished GenerationStatus");
-        }
+        OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
         std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_replies));
         std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
     }
@@ -645,9 +642,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
     std::vector<std::vector<int64_t>> plain_tokens;
     std::vector<float> plain_scores;
     for (EncodedGenerationResult& res : generated) {
-        if (GenerationStatus::FINISHED != res.m_status) {
-            OPENVINO_THROW("Got unfinished GenerationStatus");
-        }
+        OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
         std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_tokens));
         std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
     }

src/cpp/src/lm_encoding.cpp

Lines changed: 45 additions & 60 deletions

@@ -67,33 +67,49 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
         generations.push_back(std::make_shared<GenerationHandleImpl>(sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters()));
     }

+    auto active_sequence_groups{sequence_groups};
+
+    auto stream_generated_tokens = [&streamer_ptr, &generations, &active_sequence_groups]() {
+        GenerationHandle& handle = generations.at(0);
+        if (streamer_ptr && handle->can_read()) {
+            std::unordered_map<uint64_t, GenerationOutput> token = handle->back();
+            for (const auto& gen_token : token.begin()->second.generated_ids) {
+                if (streamer_ptr->put(gen_token)) {
+                    handle->drop();
+                    break;
+                }
+            }
+        }
+
+        // free non running requests
+        auto removed_it = std::remove_if(active_sequence_groups.begin(), active_sequence_groups.end(),
+            [](SequenceGroup::Ptr sg) -> bool {
+                return sg->has_finished() || sg->out_of_memory() || sg->handle_dropped();
+            });
+        active_sequence_groups.erase(removed_it, active_sequence_groups.end());
+    };
+
     ov::Shape prompts_shape = input_ids.get_shape();
     const size_t batch_size = prompts_shape[0];

     // Initialize results and performance metrics.
+
     EncodedResults results;
     auto& raw_perf_counters = results.perf_metrics.raw_metrics;
     raw_perf_counters.m_inference_durations = {{ MicroSeconds(0.0f) }};

     // Initialize inputs
-    if (m_embedding.has_value())
-        m_llm.set_tensor("inputs_embeds", input_ids);
-    else
-        m_llm.set_tensor("input_ids", input_ids);
-
+    m_llm.set_tensor(m_embedding.has_value() ? "inputs_embeds" : "input_ids", input_ids);
     m_llm.set_tensor("attention_mask", attention_mask);
-
     if (position_ids.has_value())
         m_llm.set_tensor("position_ids", *position_ids);

     ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size});
-    auto beam_data = beam_idx.data<int32_t>();
-    if (selected_beam_idx.has_value())
-        beam_data[0] = *selected_beam_idx;
-    else
-        std::fill_n(beam_data, batch_size, 0);
+    std::fill_n(beam_idx.data<int32_t>(), batch_size, selected_beam_idx.has_value() ? *selected_beam_idx : 0);
     m_llm.set_tensor("beam_idx", beam_idx);

+    // "Prompt" phase
+
     const auto infer_start = std::chrono::steady_clock::now();
     m_llm.infer();
     const auto infer_end = std::chrono::steady_clock::now();
@@ -109,35 +125,18 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
     for (auto& sequence_group : sequence_groups) {
         sequence_group->update_processed_tokens_num(sequence_group->get_prompt_len() - sequence_len);
         sequence_group->schedule_tokens(sequence_len);
-
     }

     std::map<size_t, size_t> beam_offets;
     for (size_t i = 0; i < sequence_groups.size(); i++)
         beam_offets.insert({sequence_groups.at(i)->get_request_id(), i});

     SamplerOutput sampler_output = sampler.sample(sequence_groups, logits);
+    stream_generated_tokens();

-    auto active_sequence_groups{sequence_groups};
-    auto get_active_sequence_groups = [](SequenceGroup::Ptr sg) { return sg->has_finished(); };
-
-    active_sequence_groups.erase(std::remove_if(active_sequence_groups.begin(),
-                                                active_sequence_groups.end(),
-                                                get_active_sequence_groups),
-                                 active_sequence_groups.end());
-
-    auto stream_generated_tokens = [&streamer_ptr, &generations]() {
-        if (streamer_ptr && generations.at(0).get()->can_read()) {
-            std::unordered_map<uint64_t, GenerationOutput> token = generations.at(0).get()->back();
-            for (const auto& gen_token : token.begin()->second.generated_ids) {
-                if (!streamer_ptr->put(gen_token)) {
-                    break;
-                }
-            }
-        }
-    };
+    // "Generation" phase

-    while (active_sequence_groups.size() > 0) {
+    while (!active_sequence_groups.empty()) {
         size_t total_num_tokens = 0;

         for (auto& sequence_group : active_sequence_groups) {
@@ -178,20 +177,13 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
         }

         for (size_t i = 0; i < sequence_groups.size(); i++) {
-            if (i == 0)
-                beam_offets[sequence_groups.at(i)->get_request_id()] = 0;
-            else {
-                beam_offets[sequence_groups.at(i)->get_request_id()] = sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i -1];
-            }
+            beam_offets[sequence_groups.at(i)->get_request_id()] = i == 0 ? 0 : (sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i - 1]);
         }

         if (m_embedding.has_value()) {
             const ov::Tensor& embed_prompt_tensor = (*m_embedding).infer(new_input_ids);
-
-            m_llm.get_tensor("inputs_embeds").set_shape(embed_prompt_tensor.get_shape());
             m_llm.set_tensor("inputs_embeds", embed_prompt_tensor);
         } else {
-            m_llm.get_tensor("input_ids").set_shape(new_input_ids.get_shape());
             m_llm.set_tensor("input_ids", new_input_ids);
         }

@@ -201,7 +193,6 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
             update_position_ids(m_llm.get_tensor("position_ids"), m_llm.get_tensor("attention_mask"));
         }

-        m_llm.get_tensor("beam_idx").set_shape({ total_num_tokens });
         m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()});

         const auto infer_start = std::chrono::steady_clock::now();
@@ -213,36 +204,30 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
         raw_perf_counters.m_new_token_times.emplace_back(infer_end);
         raw_perf_counters.m_batch_sizes.emplace_back(batch_size);

-        stream_generated_tokens();
-
         sampler_output = sampler.sample(active_sequence_groups, m_llm.get_tensor("logits"));
-
-        active_sequence_groups.erase(std::remove_if(active_sequence_groups.begin(),
-                                                    active_sequence_groups.end(),
-                                                    get_active_sequence_groups),
-                                     active_sequence_groups.end());
+        stream_generated_tokens();
     }

-    // to stream last token
-    stream_generated_tokens();
-    if (streamer_ptr) {
+    if (streamer_ptr) { // push streamer's cache
         streamer_ptr->end();
     }
-
+
+    // Collect results
+
     size_t next_selected_beam = 0;
     for (size_t i = 0; i < sequence_groups.size(); i++) {
         auto request = sequence_groups[i];
-        auto generation_outputs = generations[i]->read_all();
+        std::vector<GenerationOutput> generation_outputs;
+        auto sampling_params = request->get_sampling_parameters();
+        const auto& sequences = request->get_finished_sequences();
+        size_t num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, sequences.size());

-        std::sort(generation_outputs.begin(), generation_outputs.end(), [] (const GenerationOutput& r1, const GenerationOutput& r2) {
-            return r1.score > r2.score;
-        });
+        for (size_t seq_id = 0; seq_id < num_outputs; ++seq_id) {
+            const auto & sequence = sequences[seq_id];
+            const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_probs();

-        auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, generation_outputs.size());
-        for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) {
-            const auto& generation_output = generation_outputs[generation_output_idx];
-            results.tokens.push_back(std::move(generation_output.generated_ids));
-            results.scores.push_back(generation_output.score);
+            results.tokens.push_back(sequence->get_generated_ids());
+            results.scores.push_back(score);
         }
         // next_selected_beam = sampler.last_selected_beam(request);
     }
