From dc8fa68e6658a2b75fb52ea7d2ad2cd9cbdced7c Mon Sep 17 00:00:00 2001
From: sbalandi
Date: Mon, 18 Nov 2024 17:26:39 +0000
Subject: [PATCH] chat history update

---
 src/cpp/src/llm_pipeline.cpp | 90 ++++++++++++++++++------------------
 src/cpp/src/lm_encoding.cpp  |  2 +-
 src/cpp/src/utils.cpp        |  5 +-
 src/cpp/src/utils.hpp        |  2 +-
 4 files changed, 51 insertions(+), 48 deletions(-)

diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index c2b2b6813f..29714f3dcc 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -21,6 +21,8 @@
 #include "sampler.hpp"
 #include "lm_encoding.hpp"
 
+#include "debug_utils.hpp"
+
 namespace ov {
 namespace genai {
 
@@ -30,9 +32,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
 
     bool is_chat_conversation = false;
     bool m_is_cache_empty = true;
+    std::optional<int64_t> m_selected_beam = std::nullopt;
     ChatHistory m_history;
     std::string m_templated_chat_history = {};
-    ov::Tensor m_tokenized_chat_history = ov::Tensor(ov::element::i64, {0, 0});
+    std::vector<int64_t> m_tokenized_chat_history;
 
     StatefulLLMPipeline(
         const ov::InferRequest& request,
@@ -211,51 +214,47 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                         "(input_ids, attention_mask, position_ids, beam_idx) "
                         "but you have '" + std::to_string(num_inputs) + "' inputs");
 
+        if (is_chat_conversation) {
+            std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));
+        }
+        ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
+        bool kv_history_available = m_selected_beam.has_value();
         size_t kv_cache_len = 0;
         ov::Tensor concatenated_attention_mask;
         if (is_chat_conversation && !m_is_cache_empty) {
-            OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
-            // If history is saved in KV cache, concatenate new attention_mask with the already existing.
-            // Between subsequent runs attention_mask should not be modified.
-            auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
-            auto prompt_len = attention_mask.get_shape()[1];
-            kv_cache_len = atten_mask_history.get_shape()[1];
-
-            ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
-            auto start_atten_hst = atten_mask_history.data<int64_t>();
-            std::copy(start_atten_hst, start_atten_hst + kv_cache_len,
-                new_atten_mask.data<int64_t>());
-            std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
-                new_atten_mask.data<int64_t>() + kv_cache_len);
-            concatenated_attention_mask = new_atten_mask;
+            if (kv_history_available) {
+                OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
+                // If history is saved in KV cache, concatenate new attention_mask with the already existing.
+                // Between subsequent runs attention_mask should not be modified.
+                auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
+                auto prompt_len = attention_mask.get_shape()[1];
+                kv_cache_len = atten_mask_history.get_shape()[1];
+
+                ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
+                auto start_atten_hst = atten_mask_history.data<int64_t>();
+                std::copy(start_atten_hst, start_atten_hst + kv_cache_len,
+                    new_atten_mask.data<int64_t>());
+                std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
+                    new_atten_mask.data<int64_t>() + kv_cache_len);
+                concatenated_attention_mask = new_atten_mask;
+            } else {
+                attention_mask = ov::genai::utils::init_attention_mask(tokenized_chat_history);
+                concatenated_attention_mask = attention_mask;
+            }
         } else {
             concatenated_attention_mask = attention_mask;
         }
 
-        if (is_chat_conversation) {
-            ov::Tensor new_tokenized_chat_history = ov::Tensor{ov::element::i64, {batch_size, m_tokenized_chat_history.get_shape().at(1) + input_ids.get_shape().at(1)}};
-            auto start_chat_hst = m_tokenized_chat_history.data<int64_t>();
-            std::copy(start_chat_hst, start_chat_hst + m_tokenized_chat_history.get_size(), new_tokenized_chat_history.data<int64_t>());
-            std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(),
-                new_tokenized_chat_history.data<int64_t>() + m_tokenized_chat_history.get_size());
-
-            m_tokenized_chat_history = new_tokenized_chat_history;
-        }
-
         bool position_ids_available = (num_inputs == 4);
         std::optional<ov::Tensor> position_ids = std::nullopt;
         if (position_ids_available) {
-            if (is_chat_conversation && config.is_beam_search()) {
-                position_ids = ov::Tensor{ov::element::i64, m_tokenized_chat_history.get_shape()};
-                size_t start_pos = kv_cache_len - (m_tokenized_chat_history.get_shape().at(1) - input_ids.get_shape().at(1));
-                size_t seq_length = m_tokenized_chat_history.get_shape().at(1);
-
-                utils::initialize_position_ids(*position_ids, concatenated_attention_mask, seq_length, start_pos);
+            if (is_chat_conversation && !kv_history_available) {
+                position_ids = ov::Tensor{ov::element::i64, tokenized_chat_history.get_shape()};
             } else {
                 position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()};
-                utils::initialize_position_ids(*position_ids, attention_mask, attention_mask.get_shape()[1], kv_cache_len);
             }
+            utils::initialize_position_ids(*position_ids, attention_mask, kv_cache_len);
         }
 
         if(m_adapter_controller) {
@@ -266,11 +265,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         std::vector<SequenceGroup::Ptr> requests;
         size_t block_size = 1;
         bool enable_prefix_caching = false;
-
         for (size_t request_id = 0; request_id < batch_size; request_id++) {
             SequenceGroup::Ptr sequence_group;
-            if (is_chat_conversation && !m_is_cache_empty) {
-                sequence_group = std::make_shared<SequenceGroup>(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching);
+            if (is_chat_conversation) {
+                sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
             } else {
                 size_t seq_len = input_ids.get_shape().at(1);
                 size_t batch_offset = request_id * seq_len;
@@ -286,24 +284,26 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
 
         Sampler sampler = Sampler(m_tokenizer);
         // we can't properly refer to history in case of chat scenario with beam search, so reset_kv_state and use the whole history for each new propmt
-        auto input_tokens = (is_chat_conversation && config.is_beam_search()) ? m_tokenized_chat_history : input_ids;
+        auto input_tokens = input_ids;
+        if (is_chat_conversation && !kv_history_available) {
+            input_tokens = tokenized_chat_history;
+        }
         result = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask, streamer_ptr, sampler, requests, position_ids, std::nullopt);
+        m_selected_beam = 0;
 
         if (!is_chat_conversation || config.is_beam_search()) {
             reset_kv_state();
+            m_selected_beam = std::nullopt;
         }
 
         if (is_chat_conversation) {
             m_is_cache_empty = false;
+        }
 
+        if (is_chat_conversation) {
             // remove eos token, if it is at the end
             auto last_token = result.tokens[0].back() == config.eos_token_id ? result.tokens[0].size() - 1 : result.tokens[0].size();
-            ov::Tensor new_tokenized_chat_history = ov::Tensor{ov::element::i64, {batch_size, m_tokenized_chat_history.get_shape().at(1) + last_token}};
-            auto start_chat_hst = m_tokenized_chat_history.data<int64_t>();
-            std::copy(start_chat_hst, start_chat_hst + m_tokenized_chat_history.get_size(), new_tokenized_chat_history.data<int64_t>());
-            std::copy(result.tokens[0].begin(), result.tokens[0].begin() + last_token, new_tokenized_chat_history.data<int64_t>() + m_tokenized_chat_history.get_size());
-
-            m_tokenized_chat_history = new_tokenized_chat_history;
+            std::copy(result.tokens[0].begin(), result.tokens[0].begin() + last_token, std::back_inserter(m_tokenized_chat_history));
         }
 
         auto stop_time = std::chrono::steady_clock::now();
@@ -319,12 +319,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
 
     void start_chat(const std::string& system_message) override {
         is_chat_conversation = true;
+        m_selected_beam = std::nullopt;
        if (!m_is_cache_empty) {
             reset_kv_state();
             m_is_cache_empty = true;
             m_history = {};
             m_templated_chat_history = "";
-            m_tokenized_chat_history = ov::Tensor(ov::element::i64, {0, 0});
+            m_tokenized_chat_history = {};
         }
         if (system_message.empty())
             return;
@@ -337,12 +338,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
 
     void finish_chat() override {
         is_chat_conversation = false;
+        m_selected_beam = std::nullopt;
         if (!m_is_cache_empty) {
             reset_kv_state();
             m_is_cache_empty = true;
             m_history.clear();
             m_templated_chat_history.clear();
-            m_tokenized_chat_history = ov::Tensor(ov::element::i64, {0, 0});
+            m_tokenized_chat_history = {};
         }
     }
 };
diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp
index 23c3ba8973..61f27a7a44 100644
--- a/src/cpp/src/lm_encoding.cpp
+++ b/src/cpp/src/lm_encoding.cpp
@@ -152,7 +152,7 @@ EncodedResults get_lm_encoded_results(
 
                 // apply strides to shift to a next sequence
                 input_ids_data += num_scheduled_tokens;
-                // for different sequences iteration of beams started from 0, but we collect it to one input_ids#
+                // for different sequences iteration of beams started from 0, but we collect it to one input_ids
                 next_beams.push_back(beam_idxs[sequence->get_id()] + beam_offets.at(sequence_group->get_request_id()));
             }
         }
diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp
index 2334de4b3a..50c2e0c49e 100644
--- a/src/cpp/src/utils.cpp
+++ b/src/cpp/src/utils.cpp
@@ -61,7 +61,7 @@ int64_t argmax(const ov::Tensor& logits, const size_t batch_idx) {
 /**
  * Initializes position ids based on attention mask and starting position
  */
-void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, size_t seq_length, int64_t start_pos) {
+void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, int64_t start_pos) {
     OPENVINO_ASSERT(position_ids.get_element_type() == ov::element::i64,
                     "position_ids tensor element type should be an i64");
     OPENVINO_ASSERT(position_ids.get_shape().size() == 2,
@@ -72,6 +72,7 @@ void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attenti
                     "attention_mask tensor should of rank 2 with shape [batch_size, seq_len]");
 
     const size_t batch_size = attention_mask.get_shape()[0];
+    const size_t seq_length = attention_mask.get_shape()[1];
     const int64_t* attention_mask_data = attention_mask.data<int64_t>();
     int64_t* position_ids_data = position_ids.data<int64_t>();
 
@@ -96,7 +97,7 @@ void initialize_beam_inputs(const ov::Tensor& input_ids, const ov::Tensor& atten
 
     ov::Tensor position_ids = request.get_tensor("position_ids");
     position_ids.set_shape(input_shape);
-    initialize_position_ids(position_ids, attention_mask, attention_mask.get_shape()[1]);
+    initialize_position_ids(position_ids, attention_mask);
 
     ov::Tensor beam_idx = request.get_tensor("beam_idx");
     beam_idx.set_shape({input_shape.at(0)});
diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp
index fd4e251f59..9adc46c87a 100644
--- a/src/cpp/src/utils.hpp
+++ b/src/cpp/src/utils.hpp
@@ -18,7 +18,7 @@ void print_tensor(const ov::Tensor& tensor);
 
 int64_t argmax(const ov::Tensor& logits, const size_t batch_idx);
 
-void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, size_t seq_length, int64_t start_pos = 0);
+void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, int64_t start_pos = 0);
 
 ov::Tensor extend_attention(ov::Tensor attention_mask);
 