@@ -67,33 +67,49 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
        generations.push_back(std::make_shared<GenerationHandleImpl>(sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters()));
    }

+    auto active_sequence_groups{sequence_groups};
+
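+    // Streams freshly generated tokens of the first request to the user's
+    // streamer, then prunes groups that are finished, out of memory, or
+    // whose handle was dropped, so later iterations only see live requests.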
+    auto stream_generated_tokens = [&streamer_ptr, &generations, &active_sequence_groups]() {
+        GenerationHandle& handle = generations.at(0);
+        if (streamer_ptr && handle->can_read()) {
+            std::unordered_map<uint64_t, GenerationOutput> token = handle->back();
+            for (const auto& gen_token : token.begin()->second.generated_ids) {
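+                // put() returning true signals that the client wants generation
+                // stopped, so the handle is dropped before leaving the loop.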
+                if (streamer_ptr->put(gen_token)) {
+                    handle->drop();
+                    break;
+                }
+            }
+        }
+
+        // free non running requests
+        auto removed_it = std::remove_if(active_sequence_groups.begin(), active_sequence_groups.end(),
+            [](SequenceGroup::Ptr sg) -> bool {
+                return sg->has_finished() || sg->out_of_memory() || sg->handle_dropped();
+            });
+        active_sequence_groups.erase(removed_it, active_sequence_groups.end());
+    };
+
    ov::Shape prompts_shape = input_ids.get_shape();
    const size_t batch_size = prompts_shape[0];

    // Initialize results and performance metrics.
+
    EncodedResults results;
    auto& raw_perf_counters = results.perf_metrics.raw_metrics;
    raw_perf_counters.m_inference_durations = {{ MicroSeconds(0.0f) }};

    // Initialize inputs
-    if (m_embedding.has_value())
-        m_llm.set_tensor("inputs_embeds", input_ids);
-    else
-        m_llm.set_tensor("input_ids", input_ids);
-
+    m_llm.set_tensor(m_embedding.has_value() ? "inputs_embeds" : "input_ids", input_ids);
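+    // (When m_embedding is set, the input_ids argument already contains
+    // embeddings rather than token ids, hence the tensor name switch.)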
    m_llm.set_tensor("attention_mask", attention_mask);
-
    if (position_ids.has_value())
        m_llm.set_tensor("position_ids", *position_ids);

    ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size});
-    auto beam_data = beam_idx.data<int32_t>();
-    if (selected_beam_idx.has_value())
-        beam_data[0] = *selected_beam_idx;
-    else
-        std::fill_n(beam_data, batch_size, 0);
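+    // Unlike the removed branchy version, which only wrote beam_data[0] when a
+    // beam was selected, every element of beam_idx is now initialized.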
+    std::fill_n(beam_idx.data<int32_t>(), batch_size, selected_beam_idx.has_value() ? *selected_beam_idx : 0);
    m_llm.set_tensor("beam_idx", beam_idx);

+    // "Prompt" phase
+
    const auto infer_start = std::chrono::steady_clock::now();
    m_llm.infer();
    const auto infer_end = std::chrono::steady_clock::now();
@@ -109,35 +125,18 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
    for (auto& sequence_group : sequence_groups) {
        sequence_group->update_processed_tokens_num(sequence_group->get_prompt_len() - sequence_len);
        sequence_group->schedule_tokens(sequence_len);
-
    }

    std::map<size_t, size_t> beam_offets;
    for (size_t i = 0; i < sequence_groups.size(); i++)
        beam_offets.insert({sequence_groups.at(i)->get_request_id(), i});

    SamplerOutput sampler_output = sampler.sample(sequence_groups, logits);
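+    // Tokens sampled from the prompt logits are now streamed right away
+    // instead of at the top of the first decode iteration.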
+    stream_generated_tokens();

-    auto active_sequence_groups{sequence_groups};
-    auto get_active_sequence_groups = [](SequenceGroup::Ptr sg) { return sg->has_finished(); };
-
-    active_sequence_groups.erase(std::remove_if(active_sequence_groups.begin(),
-                                                active_sequence_groups.end(),
-                                                get_active_sequence_groups),
-                                 active_sequence_groups.end());
-
-    auto stream_generated_tokens = [&streamer_ptr, &generations]() {
-        if (streamer_ptr && generations.at(0).get()->can_read()) {
-            std::unordered_map<uint64_t, GenerationOutput> token = generations.at(0).get()->back();
-            for (const auto& gen_token : token.begin()->second.generated_ids) {
-                if (!streamer_ptr->put(gen_token)) {
-                    break;
-                }
-            }
-        }
-    };
+    // "Generation" phase

-    while (active_sequence_groups.size() > 0) {
+    while (!active_sequence_groups.empty()) {
        size_t total_num_tokens = 0;

        for (auto& sequence_group : active_sequence_groups) {
@@ -178,20 +177,13 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
        }

        for (size_t i = 0; i < sequence_groups.size(); i++) {
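+            // beam_offets[request_id] is a prefix sum: the number of running
+            // sequences in all preceding groups.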
-            if (i == 0)
-                beam_offets[sequence_groups.at(i)->get_request_id()] = 0;
-            else {
-                beam_offets[sequence_groups.at(i)->get_request_id()] = sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i - 1];
-            }
+            beam_offets[sequence_groups.at(i)->get_request_id()] = i == 0 ? 0 : (sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i - 1]);
        }

        if (m_embedding.has_value()) {
            const ov::Tensor& embed_prompt_tensor = (*m_embedding).infer(new_input_ids);
-
-            m_llm.get_tensor("inputs_embeds").set_shape(embed_prompt_tensor.get_shape());
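+            // Assumption: set_tensor() adopts the shape of the tensor it is
+            // given, which is why the explicit set_shape() calls were dropped.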
            m_llm.set_tensor("inputs_embeds", embed_prompt_tensor);
        } else {
-            m_llm.get_tensor("input_ids").set_shape(new_input_ids.get_shape());
            m_llm.set_tensor("input_ids", new_input_ids);
        }

@@ -201,7 +193,6 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
            update_position_ids(m_llm.get_tensor("position_ids"), m_llm.get_tensor("attention_mask"));
        }

-        m_llm.get_tensor("beam_idx").set_shape({ total_num_tokens });
        m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()});

        const auto infer_start = std::chrono::steady_clock::now();
@@ -213,36 +204,30 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
        raw_perf_counters.m_new_token_times.emplace_back(infer_end);
        raw_perf_counters.m_batch_sizes.emplace_back(batch_size);

-        stream_generated_tokens();
-
        sampler_output = sampler.sample(active_sequence_groups, m_llm.get_tensor("logits"));
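+        // Sampling now happens before streaming, so stream_generated_tokens()
+        // pushes the tokens that this iteration just accepted.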
-
-        active_sequence_groups.erase(std::remove_if(active_sequence_groups.begin(),
-                                                    active_sequence_groups.end(),
-                                                    get_active_sequence_groups),
-                                     active_sequence_groups.end());
+        stream_generated_tokens();
    }

-    // to stream last token
-    stream_generated_tokens();
-    if (streamer_ptr) {
+    if (streamer_ptr) { // push streamer's cache
        streamer_ptr->end();
    }
-
+
+    // Collect results
+
    size_t next_selected_beam = 0;
    for (size_t i = 0; i < sequence_groups.size(); i++) {
        auto request = sequence_groups[i];
-        auto generation_outputs = generations[i]->read_all();
+        std::vector<GenerationOutput> generation_outputs;
+        auto sampling_params = request->get_sampling_parameters();
+        const auto& sequences = request->get_finished_sequences();
+        size_t num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, sequences.size());

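+        // Results are now taken straight from the group's finished sequences
+        // instead of draining the handle via read_all() and sorting by score.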
-        std::sort(generation_outputs.begin(), generation_outputs.end(), [](const GenerationOutput& r1, const GenerationOutput& r2) {
-            return r1.score > r2.score;
-        });
+        for (size_t seq_id = 0; seq_id < num_outputs; ++seq_id) {
+            const auto& sequence = sequences[seq_id];
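+            // Beam search reports the beam score (presumably including the
+            // length penalty from sampling_params); other sampling modes report
+            // the sequence's cumulative log-prob.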
+            const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_probs();

-        auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, generation_outputs.size());
-        for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) {
-            const auto& generation_output = generation_outputs[generation_output_idx];
-            results.tokens.push_back(std::move(generation_output.generated_ids));
-            results.scores.push_back(generation_output.score);
+            results.tokens.push_back(sequence->get_generated_ids());
+            results.scores.push_back(score);
        }
        // next_selected_beam = sampler.last_selected_beam(request);
    }