Skip to content

Commit

Permalink
Move beam search in case of chat scenario to sampler.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
sbalandi committed Dec 19, 2024
1 parent 17f4eb3 commit e2752c1
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 576 deletions.
455 changes: 0 additions & 455 deletions src/cpp/src/group_beam_searcher.cpp

This file was deleted.

122 changes: 72 additions & 50 deletions src/cpp/src/llm_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,22 @@
namespace ov {
namespace genai {

// Forward declaration of beam_search; its definition is not visible in this file.
// Returns the encoded generation results paired with an int32_t. The call site in
// this file passes m_selected_beam as `selected_beam_idx` and stores the returned
// int32_t back into m_selected_beam.
// NOTE(review): presumably the int32_t is the index of the beam selected for the
// next chat turn — confirm against the definition.
std::pair<EncodedResults, int32_t> beam_search(
ov::InferRequest& lm,
ov::Tensor prompts,
ov::Tensor attention_mask,
GenerationConfig config,
std::optional<ov::Tensor> position_ids,
std::optional<int32_t> selected_beam_idx
);

class StatefulLLMPipeline final : public LLMPipelineImplBase {
public:
ov::InferRequest m_model_runner;
bool is_chat_conversation = false;
bool m_trust_encoded_history = true;
std::optional<int32_t> m_selected_beam = std::nullopt;
ChatHistory m_history;
std::string m_templated_chat_history = {};
std::vector<int64_t> m_tokenized_chat_history;
ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
size_t m_to_remove_from_hist = 0;
size_t m_kv_cache_seq_length_axis = 2;
// Tail of previous output in chat mode is missing in KV cache, let's keep it
std::optional<int64_t> m_last_disappeared_token = std::nullopt;
// If the sequence contains symbols that the tokenizer may encode ambiguously, we need to trim the KV cache.
// If beam search sampling is used in chat mode, we need to remove the model's last answer from the KV cache and add the best answer to the history.
// So let's keep track of the number of tokens to trim from the KV cache and the number of tokens to keep in the history.
ov::genai::utils::HistoryRemoveManager m_to_remove_from_hist = {0, 0};

StatefulLLMPipeline(
const ov::InferRequest& request,
Expand Down Expand Up @@ -160,8 +155,17 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

if (m_tokenized_chat_history.empty()) {
encoded_input = new_chat_tokens;
} else if (last_same_hist_token != SIZE_MAX) {
m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token;
} else if (last_same_hist_token != SIZE_MAX || m_to_remove_from_hist.is_kv_cache_need_to_update()) {
// is_kv_cache_need_to_update() will be true here if beam search is activated.
// In beam search mode we want to remove the whole history of the last model answer from the KV cache and add the best answer directly.
// Any difference between the model answer and the decoded answer will still be smaller than the entire history, so let's use the data from m_to_remove_from_hist.
if (m_to_remove_from_hist.is_kv_cache_need_to_update()) {
last_same_hist_token = m_to_remove_from_hist.last_hist_token_to_unchange;
} else {
m_to_remove_from_hist.num_token_to_remove_from_kv_cache = m_tokenized_chat_history.size() - last_same_hist_token;
// If the previous generation finished because the max length was reached, the KV cache is missing the last token, so let's keep it.
m_to_remove_from_hist.num_token_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0;
}

ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
{1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token},
Expand All @@ -174,12 +178,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
{1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token});
new_tensor.copy_to(encoded_input.input_ids);
encoded_input.attention_mask = new_attention_mask;

m_selected_beam = std::nullopt;
m_last_disappeared_token = std::nullopt;
} else {
encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
}
m_templated_chat_history = new_templated_chat_history;

m_tokenized_chat_history.clear();
m_tokenized_chat_history.reserve(new_chat_tokens.input_ids.get_size());
std::copy_n(new_chat_tokens.input_ids.data<int64_t>(), new_chat_tokens.input_ids.get_size(),
Expand Down Expand Up @@ -261,6 +265,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));

// Tail of previous output in chat mode is missing in KV cache.
if (is_chat_conversation && m_last_disappeared_token.has_value()) {
attention_mask = ov::genai::utils::push_front_inputs(attention_mask, 1);
input_ids = ov::genai::utils::push_front_inputs(input_ids, *m_last_disappeared_token);
}

GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;

// If eos_token_id was not provided, take value from default m_generation_config
Expand Down Expand Up @@ -292,7 +302,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
"(input_ids, attention_mask, position_ids, beam_idx) "
"but you have '" + std::to_string(num_inputs) + "' inputs");

ov::genai::utils::trim_kv_cache(m_model_runner, m_to_remove_from_hist, m_kv_cache_seq_length_axis, m_adapter_controller);
ov::genai::utils::trim_kv_cache(m_model_runner, m_to_remove_from_hist.num_token_to_remove_from_kv_cache, m_kv_cache_seq_length_axis, m_adapter_controller);

size_t kv_cache_len = 0;
ov::Tensor concatenated_attention_mask;
Expand All @@ -302,10 +312,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
// Between subsequent runs attention_mask should not be modified.
auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
auto prompt_len = attention_mask.get_shape()[1];
kv_cache_len = atten_mask_history.get_shape()[1] - m_to_remove_from_hist;

kv_cache_len = atten_mask_history.get_shape()[1] - m_to_remove_from_hist.num_token_to_remove_from_kv_cache;

ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
auto start_atten_hst = atten_mask_history.data<int64_t>() + kv_cache_len * (*m_selected_beam);
auto start_atten_hst = atten_mask_history.data<int64_t>();

std::copy(start_atten_hst, start_atten_hst + kv_cache_len,
new_atten_mask.data<int64_t>());
std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
Expand All @@ -315,6 +327,8 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
concatenated_attention_mask = attention_mask;
}

size_t prev_attn_mask_size = concatenated_attention_mask.get_shape()[1];

bool position_ids_available = (num_inputs == 4);
std::optional<ov::Tensor> position_ids = std::nullopt;
if (position_ids_available) {
Expand All @@ -328,48 +342,56 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

if (is_chat_conversation && !m_trust_encoded_history) {
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_to_remove_from_hist.reset();
}

ov::genai::EncodedResults result;
if (config.is_beam_search() && is_chat_conversation) {
std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask,
config, position_ids, m_selected_beam);
} else {
std::vector<SequenceGroup::Ptr> requests;
size_t block_size = 1;
bool enable_prefix_caching = false;

for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation) {
ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
const int64_t* prompt_start = input_ids.data<const int64_t>() + batch_offset;
std::vector<int64_t> tokenized_prompt(prompt_start, prompt_start + seq_len);
std::vector<SequenceGroup::Ptr> requests;
size_t block_size = 1;
bool enable_prefix_caching = false;

sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_prompt, config, block_size, enable_prefix_caching);
}
for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation) {
ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
const int64_t* prompt_start = input_ids.data<const int64_t>() + batch_offset;
std::vector<int64_t> tokenized_prompt(prompt_start, prompt_start + seq_len);

sequence_group->set_sequence_group_ptr(sequence_group);
requests.push_back(sequence_group);
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_prompt, config, block_size, enable_prefix_caching);
}

Sampler sampler = Sampler(m_tokenizer);
std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr,
sampler, requests, position_ids, std::nullopt, m_selected_beam);
sequence_group->set_sequence_group_ptr(sequence_group);
requests.push_back(sequence_group);
}

Sampler sampler = Sampler(m_tokenizer);
ov::genai::EncodedResults result = ov::genai::get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
streamer_ptr, sampler, requests, position_ids, std::nullopt);

if (is_chat_conversation) {
// Force removal of the last answer from the KV cache.
if (config.is_beam_search() && m_chat_input_type != ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
m_to_remove_from_hist.last_hist_token_to_unchange = m_tokenized_chat_history.size();
m_to_remove_from_hist.num_token_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
}

// In chat mode there will be only one request.
if (requests[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH)
m_last_disappeared_token = result.tokens[0].back();
else
m_last_disappeared_token = std::nullopt;

std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
} else {
reset_kv_state();
m_selected_beam = std::nullopt;
}

if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));

auto stop_time = std::chrono::steady_clock::now();

// If is called without tokenization then that stat will not be reported.
Expand All @@ -383,10 +405,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

void start_chat(const std::string& system_message) override {
is_chat_conversation = true;
m_selected_beam = std::nullopt;
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_to_remove_from_hist.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
if (!m_tokenized_chat_history.empty()) {
reset_kv_state();
m_history = {};
Expand All @@ -404,10 +426,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

void finish_chat() override {
is_chat_conversation = false;
m_selected_beam = std::nullopt;
m_trust_encoded_history = true;
m_to_remove_from_hist = 0;
m_to_remove_from_hist.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
if (!m_tokenized_chat_history.empty()) {
reset_kv_state();
m_history.clear();
Expand Down
60 changes: 22 additions & 38 deletions src/cpp/src/lm_encoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
#include <regex>
#include <vector>

#include "utils.hpp"
#include "debug_utils.hpp"
#include "lm_encoding.hpp"
#include "openvino/genai/perf_metrics.hpp"

#include "debug_utils.hpp"

#include "utils.hpp"

namespace ov {
namespace genai {
Expand Down Expand Up @@ -51,16 +50,15 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector<i
}


std::pair<EncodedResults, int32_t> get_lm_encoded_results(
EncodedResults get_lm_encoded_results(
ov::InferRequest& m_llm,
const ov::Tensor& input_ids,
const ov::Tensor& attention_mask,
const std::shared_ptr<StreamerBase>& streamer_ptr,
Sampler& sampler,
std::vector<SequenceGroup::Ptr> sequence_groups,
std::optional<ov::Tensor> position_ids,
std::optional<EmbeddingsModel> m_embedding,
std::optional<int32_t> selected_beam_idx
std::optional<EmbeddingsModel> m_embedding
) {
std::vector<GenerationHandle> generations;
for (SequenceGroup::Ptr sequence_group : sequence_groups) {
Expand Down Expand Up @@ -88,10 +86,7 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(

ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {batch_size});
auto beam_data = beam_idx.data<int32_t>();
if (selected_beam_idx.has_value())
beam_data[0] = *selected_beam_idx;
else
std::fill_n(beam_data, batch_size, 0);
std::fill_n(beam_data, batch_size, 0);
m_llm.set_tensor("beam_idx", beam_idx);

const auto infer_start = std::chrono::steady_clock::now();
Expand All @@ -109,7 +104,6 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
for (auto& sequence_group : sequence_groups) {
sequence_group->update_processed_tokens_num(sequence_group->get_prompt_len() - sequence_len);
sequence_group->schedule_tokens(sequence_len);

}

std::map<size_t, size_t> beam_offets;
Expand Down Expand Up @@ -172,38 +166,34 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
// apply strides to shift to a next sequence
input_ids_data += num_scheduled_tokens;

// for different sequences iteration of beams started from 0, but we collect it to one input_ids#
// For each sequence group the beam indices start from 0, but we collect them all into a single input_ids.
next_beams.push_back(beam_idxs[sequence->get_id()] + beam_offets.at(sequence_group->get_request_id()));
}
}

for (size_t i = 0; i < sequence_groups.size(); i++) {
for (size_t i = 0; i < active_sequence_groups.size(); i++) {
if (i == 0)
beam_offets[sequence_groups.at(i)->get_request_id()] = 0;
beam_offets[active_sequence_groups.at(i)->get_request_id()] = 0;
else {
beam_offets[sequence_groups.at(i)->get_request_id()] = sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i -1];
beam_offets[active_sequence_groups.at(i)->get_request_id()] = active_sequence_groups.at(i - 1)->num_running_seqs() + beam_offets[i -1];
}
}

if (m_embedding.has_value()) {
const ov::Tensor& embed_prompt_tensor = (*m_embedding).infer(new_input_ids);

m_llm.get_tensor("inputs_embeds").set_shape(embed_prompt_tensor.get_shape());
m_llm.set_tensor("inputs_embeds", embed_prompt_tensor);
} else {
m_llm.get_tensor("input_ids").set_shape(new_input_ids.get_shape());
m_llm.set_tensor("input_ids", new_input_ids);
}

m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()});

update_attention_mask_with_beams(m_llm.get_tensor("attention_mask"), next_beams);

if (position_ids.has_value()) {
update_position_ids(m_llm.get_tensor("position_ids"), m_llm.get_tensor("attention_mask"));
}

m_llm.get_tensor("beam_idx").set_shape({ total_num_tokens });
m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()});

const auto infer_start = std::chrono::steady_clock::now();
m_llm.infer();
const auto infer_end = std::chrono::steady_clock::now();
Expand All @@ -228,26 +218,20 @@ std::pair<EncodedResults, int32_t> get_lm_encoded_results(
if (streamer_ptr) {
streamer_ptr->end();
}

size_t next_selected_beam = 0;
for (size_t i = 0; i < sequence_groups.size(); i++) {
auto request = sequence_groups[i];
auto generation_outputs = generations[i]->read_all();

std::sort(generation_outputs.begin(), generation_outputs.end(), [] (const GenerationOutput& r1, const GenerationOutput& r2) {
return r1.score > r2.score;
});

auto num_outputs = std::min(request->get_sampling_parameters().num_return_sequences, generation_outputs.size());
for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) {
const auto& generation_output = generation_outputs[generation_output_idx];
results.tokens.push_back(std::move(generation_output.generated_ids));
results.scores.push_back(generation_output.score);

for (auto& sequence_group : sequence_groups) {
// Sequences are sorted by cumulative_log_prob with length_penalty.
auto outputs = sequence_group->get_finished_sequences();

auto num_outputs = std::min(sequence_group->get_sampling_parameters().num_return_sequences, outputs.size());
for (size_t output_idx = 0; output_idx < num_outputs; ++output_idx) {
const auto& output = outputs[output_idx];
results.tokens.push_back(output->get_generated_ids());
results.scores.push_back(output->get_cumulative_score_with_length_penalty(sequence_group->get_sampling_parameters()));
}
// next_selected_beam = sampler.last_selected_beam(request);
}

return {results, next_selected_beam};
return results;
}

} // namespace genai
Expand Down
Loading

0 comments on commit e2752c1

Please sign in to comment.