
Commit

chat history update
sbalandi committed Nov 18, 2024
1 parent 3f9d590 commit dc8fa68
Showing 4 changed files with 51 additions and 48 deletions.
90 changes: 46 additions & 44 deletions src/cpp/src/llm_pipeline.cpp
@@ -21,6 +21,8 @@
#include "sampler.hpp"
#include "lm_encoding.hpp"

#include "debug_utils.hpp"

namespace ov {
namespace genai {

@@ -30,9 +32,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

bool is_chat_conversation = false;
bool m_is_cache_empty = true;
std::optional<int32_t> m_selected_beam = std::nullopt;
ChatHistory m_history;
std::string m_templated_chat_history = {};
ov::Tensor m_tokenized_chat_history = ov::Tensor(ov::element::i64, {0, 0});
std::vector<int64_t> m_tokenized_chat_history;

StatefulLLMPipeline(
const ov::InferRequest& request,
@@ -211,51 +214,47 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
"(input_ids, attention_mask, position_ids, beam_idx) "
"but you have '" + std::to_string(num_inputs) + "' inputs");

if (is_chat_conversation) {
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));
}
ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
bool kv_history_available = m_selected_beam.has_value();

size_t kv_cache_len = 0;
ov::Tensor concatenated_attention_mask;
if (is_chat_conversation && !m_is_cache_empty) {
OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
// If history is saved in KV cache, concatenate new attention_mask with the already existing.
// Between subsequent runs attention_mask should not be modified.
auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
auto prompt_len = attention_mask.get_shape()[1];
kv_cache_len = atten_mask_history.get_shape()[1];

ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
auto start_atten_hst = atten_mask_history.data<int64_t>();
std::copy(start_atten_hst, start_atten_hst + kv_cache_len,
new_atten_mask.data<int64_t>());
std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
new_atten_mask.data<int64_t>() + kv_cache_len);
concatenated_attention_mask = new_atten_mask;
if (kv_history_available) {
OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
// If history is saved in KV cache, concatenate new attention_mask with the already existing.
// Between subsequent runs attention_mask should not be modified.
auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
auto prompt_len = attention_mask.get_shape()[1];
kv_cache_len = atten_mask_history.get_shape()[1];

ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
auto start_atten_hst = atten_mask_history.data<int64_t>();
std::copy(start_atten_hst, start_atten_hst + kv_cache_len,
new_atten_mask.data<int64_t>());
std::copy(attention_mask.data<int64_t>(), attention_mask.data<int64_t>() + prompt_len,
new_atten_mask.data<int64_t>() + kv_cache_len);
concatenated_attention_mask = new_atten_mask;
} else {
attention_mask = ov::genai::utils::init_attention_mask(tokenized_chat_history);
concatenated_attention_mask = attention_mask;
}
} else {
concatenated_attention_mask = attention_mask;
}

if (is_chat_conversation) {
ov::Tensor new_tokenized_chat_history = ov::Tensor{ov::element::i64, {batch_size, m_tokenized_chat_history.get_shape().at(1) + input_ids.get_shape().at(1)}};
auto start_chat_hst = m_tokenized_chat_history.data<int64_t>();
std::copy(start_chat_hst, start_chat_hst + m_tokenized_chat_history.get_size(), new_tokenized_chat_history.data<int64_t>());
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(),
new_tokenized_chat_history.data<int64_t>() + m_tokenized_chat_history.get_size());

m_tokenized_chat_history = new_tokenized_chat_history;
}

bool position_ids_available = (num_inputs == 4);
std::optional<ov::Tensor> position_ids = std::nullopt;
if (position_ids_available) {
if (is_chat_conversation && config.is_beam_search()) {
position_ids = ov::Tensor{ov::element::i64, m_tokenized_chat_history.get_shape()};
size_t start_pos = kv_cache_len - (m_tokenized_chat_history.get_shape().at(1) - input_ids.get_shape().at(1));
size_t seq_length = m_tokenized_chat_history.get_shape().at(1);

utils::initialize_position_ids(*position_ids, concatenated_attention_mask, seq_length, start_pos);
if (is_chat_conversation && !kv_history_available) {
position_ids = ov::Tensor{ov::element::i64, tokenized_chat_history.get_shape()};
} else {
position_ids = ov::Tensor{ov::element::i64, input_ids.get_shape()};
utils::initialize_position_ids(*position_ids, attention_mask, attention_mask.get_shape()[1], kv_cache_len);
}
utils::initialize_position_ids(*position_ids, attention_mask, kv_cache_len);
}

if(m_adapter_controller) {
@@ -266,11 +265,10 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
std::vector<SequenceGroup::Ptr> requests;
size_t block_size = 1;
bool enable_prefix_caching = false;

for (size_t request_id = 0; request_id < batch_size; request_id++) {
SequenceGroup::Ptr sequence_group;
if (is_chat_conversation && !m_is_cache_empty) {
sequence_group = std::make_shared<SequenceGroup>(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching);
if (is_chat_conversation) {
sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
} else {
size_t seq_len = input_ids.get_shape().at(1);
size_t batch_offset = request_id * seq_len;
@@ -286,24 +284,26 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

Sampler sampler = Sampler(m_tokenizer);
// we can't properly refer to history in case of chat scenario with beam search, so reset_kv_state and use the whole history for each new prompt
auto input_tokens = (is_chat_conversation && config.is_beam_search()) ? m_tokenized_chat_history : input_ids;
auto input_tokens = input_ids;
if (is_chat_conversation && !kv_history_available) {
input_tokens = tokenized_chat_history;
}
result = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask, streamer_ptr, sampler, requests, position_ids, std::nullopt);

m_selected_beam = 0;
if (!is_chat_conversation || config.is_beam_search()) {
reset_kv_state();
m_selected_beam = std::nullopt;
}

if (is_chat_conversation) {
m_is_cache_empty = false;
}

if (is_chat_conversation) {
// remove eos token, if it is at the end
auto last_token = result.tokens[0].back() == config.eos_token_id ? result.tokens[0].size() - 1 : result.tokens[0].size();
ov::Tensor new_tokenized_chat_history = ov::Tensor{ov::element::i64, {batch_size, m_tokenized_chat_history.get_shape().at(1) + last_token}};
auto start_chat_hst = m_tokenized_chat_history.data<int64_t>();
std::copy(start_chat_hst, start_chat_hst + m_tokenized_chat_history.get_size(), new_tokenized_chat_history.data<int64_t>());
std::copy(result.tokens[0].begin(), result.tokens[0].begin() + last_token, new_tokenized_chat_history.data<int64_t>() + m_tokenized_chat_history.get_size());

m_tokenized_chat_history = new_tokenized_chat_history;
std::copy(result.tokens[0].begin(), result.tokens[0].begin() + last_token, std::back_inserter(m_tokenized_chat_history));
}

auto stop_time = std::chrono::steady_clock::now();
@@ -319,12 +319,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

void start_chat(const std::string& system_message) override {
is_chat_conversation = true;
m_selected_beam = std::nullopt;
if (!m_is_cache_empty) {
reset_kv_state();
m_is_cache_empty = true;
m_history = {};
m_templated_chat_history = "";
m_tokenized_chat_history = ov::Tensor(ov::element::i64, {0, 0});
m_tokenized_chat_history = {};
}
if (system_message.empty())
return;
@@ -337,12 +338,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

void finish_chat() override {
is_chat_conversation = false;
m_selected_beam = std::nullopt;
if (!m_is_cache_empty) {
reset_kv_state();
m_is_cache_empty = true;
m_history.clear();
m_templated_chat_history.clear();
m_tokenized_chat_history = ov::Tensor(ov::element::i64, {0, 0});
m_tokenized_chat_history = {};
}
}
};
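The core of this file's change is that the accumulated chat history is now kept as a plain std::vector<int64_t> that is appended to with std::back_inserter and only wrapped in an ov::Tensor view when a tensor is actually needed, instead of reallocating a 2-D tensor and copying the whole history on every turn. A minimal sketch of the pattern, with illustrative names rather than the pipeline's actual members:

```cpp
#include <cstdint>
#include <vector>

#include <openvino/openvino.hpp>

// Sketch of a vector-backed chat history: token ids are appended as plain
// int64_t values and exposed to the model as a zero-copy i64 tensor view of
// shape [1, size]. The view is valid only while the vector is alive and not
// reallocated, so it is created right before it is used.
struct TokenizedChatHistory {
    std::vector<int64_t> tokens;

    // Append newly tokenized prompt ids (or generated answer ids) to the history.
    void append(const ov::Tensor& input_ids) {
        const auto* data = input_ids.data<int64_t>();
        tokens.insert(tokens.end(), data, data + input_ids.get_size());
    }

    // Wrap the vector in a tensor of shape [1, tokens.size()] without copying.
    ov::Tensor as_tensor() {
        return ov::Tensor(ov::element::i64, {1, tokens.size()}, tokens.data());
    }
};
```

Compared with the previous ov::Tensor member, appending a turn no longer requires allocating a new_tokenized_chat_history tensor and copying the entire history into it, which is exactly the code removed above.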
2 changes: 1 addition & 1 deletion src/cpp/src/lm_encoding.cpp
@@ -152,7 +152,7 @@ EncodedResults get_lm_encoded_results(
// apply strides to shift to a next sequence
input_ids_data += num_scheduled_tokens;

// for different sequences iteration of beams started from 0, but we collect it to one input_ids#
// for different sequences iteration of beams started from 0, but we collect it to one input_ids
next_beams.push_back(beam_idxs[sequence->get_id()] + beam_offets.at(sequence_group->get_request_id()));
}
}
5 changes: 3 additions & 2 deletions src/cpp/src/utils.cpp
@@ -61,7 +61,7 @@ int64_t argmax(const ov::Tensor& logits, const size_t batch_idx) {
/**
* Initializes position ids based on attention mask and starting position
*/
void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, size_t seq_length, int64_t start_pos) {
void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, int64_t start_pos) {
OPENVINO_ASSERT(position_ids.get_element_type() == ov::element::i64,
"position_ids tensor element type should be an i64");
OPENVINO_ASSERT(position_ids.get_shape().size() == 2,
@@ -72,6 +72,7 @@ void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, int64_t start_pos) {
"attention_mask tensor should of rank 2 with shape [batch_size, seq_len]");

const size_t batch_size = attention_mask.get_shape()[0];
const size_t seq_length = attention_mask.get_shape()[1];

const int64_t* attention_mask_data = attention_mask.data<int64_t>();
int64_t* position_ids_data = position_ids.data<int64_t>();
@@ -96,7 +97,7 @@ void initialize_beam_inputs(const ov::Tensor& input_ids, const ov::Tensor& atten

ov::Tensor position_ids = request.get_tensor("position_ids");
position_ids.set_shape(input_shape);
initialize_position_ids(position_ids, attention_mask, attention_mask.get_shape()[1]);
initialize_position_ids(position_ids, attention_mask);

ov::Tensor beam_idx = request.get_tensor("beam_idx");
beam_idx.set_shape({input_shape.at(0)});
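For context, the seq_length parameter is dropped because the helper can read it from attention_mask's shape. The loop that actually fills the ids lies outside the shown hunk, so the following is a sketch of the expected behaviour (each position id is the count of attended tokens before it in its row, offset by start_pos), not necessarily the exact implementation:

```cpp
#include <cstdint>

#include <openvino/openvino.hpp>

// Sketch (assumed behaviour) of initialize_position_ids after the change:
// position_ids must be a preallocated i64 tensor of shape [batch_size, seq_len],
// and seq_len is taken from attention_mask rather than passed separately.
void initialize_position_ids_sketch(ov::Tensor& position_ids,
                                    const ov::Tensor& attention_mask,
                                    int64_t start_pos = 0) {
    const size_t batch_size = attention_mask.get_shape()[0];
    const size_t seq_length = attention_mask.get_shape()[1];

    const int64_t* attention_mask_data = attention_mask.data<int64_t>();
    int64_t* position_ids_data = position_ids.data<int64_t>();

    for (size_t batch = 0; batch < batch_size; ++batch) {
        int64_t sum = start_pos;
        for (size_t i = 0; i < seq_length; ++i) {
            const size_t offset = batch * seq_length + i;
            position_ids_data[offset] = sum;  // id = attended tokens seen so far + start_pos
            if (attention_mask_data[offset] == 1) {
                ++sum;
            }
        }
    }
}
```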
2 changes: 1 addition & 1 deletion src/cpp/src/utils.hpp
@@ -18,7 +18,7 @@ void print_tensor(const ov::Tensor& tensor);

int64_t argmax(const ov::Tensor& logits, const size_t batch_idx);

void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, size_t seq_length, int64_t start_pos = 0);
void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, int64_t start_pos = 0);

ov::Tensor extend_attention(ov::Tensor attention_mask);

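Dropping seq_length from the declaration removes a value that could silently disagree with the mask's real shape. A hypothetical call site under the new signature, assuming the ov::genai::utils namespace implied by the utils:: calls in llm_pipeline.cpp:

```cpp
#include <cstdint>

#include <openvino/openvino.hpp>

#include "utils.hpp"  // assumed to declare ov::genai::utils::initialize_position_ids

// Hypothetical helper: build position ids for a prompt, offset by the number of
// tokens already held in the KV cache; seq_len is read from the mask inside
// initialize_position_ids, so it is no longer passed explicitly.
ov::Tensor make_position_ids(const ov::Tensor& attention_mask, int64_t kv_cache_len) {
    ov::Tensor position_ids{ov::element::i64, attention_mask.get_shape()};
    ov::genai::utils::initialize_position_ids(position_ids, attention_mask, kv_cache_len);
    return position_ids;
}
```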
