Skip to content

Commit

Permalink
StaticLLMPipeline: Support more generation options (#1431)
Browse files Browse the repository at this point in the history
Co-authored-by: Ilya Lavrenov <[email protected]>
  • Loading branch information
TolyaTalamanov and ilya-lavrenov authored Jan 4, 2025
1 parent 002f84f commit 31d632b
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 50 deletions.
96 changes: 67 additions & 29 deletions src/cpp/src/llm_pipeline_static.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
// Copyright (C) 2024 Intel Corporation
// Copyright (C) 2024-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "llm_pipeline_static.hpp"

#include "sampler.hpp"

#include <fstream>
#include <regex>

Expand Down Expand Up @@ -235,12 +237,12 @@ enum class GenerateHint {

std::string to_string(GenerateHint h) {
switch(h) {
case GenerateHint::FAST_COMPILE :
case GenerateHint::FAST_COMPILE :
return "FAST_COMPILE";
case GenerateHint::BEST_PERF :
case GenerateHint::BEST_PERF :
return "BEST_PERF";
default:
OPENVINO_THROW("Unsupported value for type GenerateHint provided");
OPENVINO_THROW("Unsupported value for type GenerateHint provided");
}
}

Expand Down Expand Up @@ -632,6 +634,19 @@ void copy_columns_by_row_chunks(const ov::Tensor& src, ov::Tensor& dst) {
}
}

void stream_generated_tokens(std::shared_ptr<ov::genai::StreamerBase> streamer_ptr,
ov::genai::GenerationHandle& handle) {
if (streamer_ptr && handle->can_read()) {
std::unordered_map<uint64_t, ov::genai::GenerationOutput> token = handle->back();
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
handle->drop();
break;
}
}
}
}

} // anonymous namespace

namespace ov {
Expand All @@ -643,7 +658,8 @@ StaticLLMPipeline::StaticLLMPipeline(
const std::string& device,
const ov::AnyMap& config
) : LLMPipelineImplBase(tokenizer,
utils::from_config_json_if_exists(models_path)) {
utils::from_config_json_if_exists(models_path)),
m_sampler(m_tokenizer) {
auto properties = config;
/* NB: Static LLM pipeline consists of two models,
first to process the input prompt (prefill),
Expand Down Expand Up @@ -672,6 +688,8 @@ StaticLLMPipeline::StaticLLMPipeline(
if (m_generation_config.eos_token_id == -1) {
m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
}

m_sampler.set_seed(m_generation_config.rng_seed);
};

StaticLLMPipeline::StaticLLMPipeline(
Expand All @@ -688,8 +706,7 @@ StaticLLMPipeline::StaticLLMPipeline(
const std::string& device,
const ov::AnyMap& properties,
const ov::genai::GenerationConfig& generation_config
) : LLMPipelineImplBase(tokenizer, generation_config) {

) : LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
bool use_blobs = false;
auto anyopt = get_option<bool>(properties, "USE_BLOBS");
if (anyopt.has_value()) {
Expand All @@ -708,6 +725,8 @@ StaticLLMPipeline::StaticLLMPipeline(
if (m_generation_config.eos_token_id == -1) {
m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
}

m_sampler.set_seed(m_generation_config.rng_seed);
}

void StaticLLMPipeline::setupAndCompileModels(
Expand Down Expand Up @@ -955,7 +974,10 @@ EncodedResults StaticLLMPipeline::generate(
attention_mask = data->attention_mask;
}

if (input_ids.get_shape().at(0) > 1u) {
ov::Shape prompts_shape = input_ids.get_shape();
const size_t batch_size = prompts_shape[0];

if (batch_size > 1u) {
OPENVINO_THROW("Currently only batch size=1 is supported");
}

Expand All @@ -974,12 +996,14 @@ EncodedResults StaticLLMPipeline::generate(
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
}

if (!config.is_greedy_decoding()) {
OPENVINO_THROW("Currently only greedy decoding is supported");
if (!config.is_greedy_decoding() && !config.is_multinomial()) {
OPENVINO_THROW("Currently only greedy and multinomial decoding are supported");
}

if (config.num_return_sequences != 1u) {
OPENVINO_THROW("Currently only \"num_return_sequences\" equal to 1 is supported!");
}

ov::Shape prompts_shape = input_ids.get_shape();
const size_t batch_size = prompts_shape[0];
ov::genai::EncodedResults results;
auto& raw_perf_counters = results.perf_metrics.raw_metrics;
// NB: Only batch=1 is supported now
Expand Down Expand Up @@ -1016,11 +1040,21 @@ EncodedResults StaticLLMPipeline::generate(

// NB: Now there are prompt_len tokens in KV-cache
m_kvcache_desc.num_stored_tokens += static_cast<uint32_t>(prompt_len);
int64_t last_token = utils::argmax(m_prefill_request.get_tensor("logits"), 0);
results.tokens[0].push_back(last_token);
if (streamer_ptr && streamer_ptr->put(last_token)) {
return results;
}

auto logits = m_prefill_request.get_tensor("logits");
int64_t output_sequence_len = logits.get_shape().at(1);

auto sequence_group = std::make_shared<SequenceGroup>(
0 /* request_id */, padded_input_ids, config, 1 /* block_size */);
sequence_group->update_processed_tokens_num(m_kvcache_desc.max_prompt_size - output_sequence_len);
sequence_group->schedule_tokens(output_sequence_len);

// NB: Controls what tokens are ready to be pushed into the streamer
GenerationHandle handle = std::make_shared<GenerationHandleImpl>(
sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters());

SamplerOutput sampler_output = m_sampler.sample({sequence_group}, logits);
stream_generated_tokens(streamer_ptr, handle);

// Outputs: logits, ...
const auto kStartOutputKVCacheLayers = 1u;
Expand Down Expand Up @@ -1061,30 +1095,28 @@ EncodedResults StaticLLMPipeline::generate(
std::fill(attention_mask_data, attention_mask_data + m_kvcache_desc.num_stored_tokens - 1u, 1u);
attention_mask_data[m_kvcache_desc.total_size - 1] = 1u;

const size_t max_tokens = config.get_max_new_tokens(prompt_len);
for (int i = 0; i < max_tokens - 1; ++i) {
input_ids_data[0] = last_token;
while (sequence_group->is_running()) {
sequence_group->schedule_tokens(1);
const auto running_sequences = sequence_group->get_running_sequences();
OPENVINO_ASSERT(running_sequences.size() == 1u);

input_ids_data[0] = running_sequences.front()->get_generated_ids().back();
position_ids_data[0] = m_kvcache_desc.num_stored_tokens;
attention_mask_data[m_kvcache_desc.num_stored_tokens - 1] = 1u;

m_kvcache_request.infer();
m_kvcache_desc.num_stored_tokens += 1;

last_token = utils::argmax(m_kvcache_request.get_tensor("logits"), 0);
results.tokens[0].push_back(last_token);

raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now());
raw_perf_counters.m_batch_sizes.emplace_back(batch_size);
if (streamer_ptr && streamer_ptr->put(last_token)) {
break;
}

if (last_token == config.eos_token_id && !config.ignore_eos) {
break;
}
SamplerOutput sampler_output = m_sampler.sample(
{sequence_group}, m_kvcache_request.get_tensor("logits"));
stream_generated_tokens(streamer_ptr, handle);

// NB: KV-cache is full, further generation is impossible
if (m_kvcache_desc.num_stored_tokens == m_kvcache_desc.total_size) {
sequence_group->set_out_of_memory();
break;
}

Expand All @@ -1108,6 +1140,12 @@ EncodedResults StaticLLMPipeline::generate(
streamer_ptr->end();
}

OPENVINO_ASSERT(sequence_group->get_finished_sequences().size() == 1u);
auto sequence = sequence_group->get_finished_sequences().front();
results.tokens[0] = sequence->get_generated_ids();
results.scores[0] = sequence->get_cumulative_log_prob();
m_sampler.clear_request_info(sequence_group->get_request_id());

auto stop_time = std::chrono::steady_clock::now();
// If is called without tokenization then that stat will not be reported.
auto& metrics = results.perf_metrics;
Expand Down
5 changes: 4 additions & 1 deletion src/cpp/src/llm_pipeline_static.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// Copyright (C) 2024 Intel Corporation
// Copyright (C) 2024-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <filesystem>

#include "llm_pipeline_base.hpp"
#include "sampler.hpp"

namespace ov {
namespace genai {
Expand Down Expand Up @@ -77,6 +78,8 @@ class StaticLLMPipeline final : public LLMPipelineImplBase {
bool v_tensors_transposed;
};

Sampler m_sampler;

KVCacheDesc m_kvcache_desc;
ov::InferRequest m_kvcache_request;
ov::InferRequest m_prefill_request;
Expand Down
4 changes: 2 additions & 2 deletions src/cpp/src/sampler.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2023-2024 Intel Corporation
// Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "sampler.hpp"
Expand Down Expand Up @@ -743,7 +743,7 @@ process_stop_strings(const std::set<std::string>& stop_strings, Tokenizer& token
return result;
}

SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
SamplerOutput Sampler::sample(const std::vector<SequenceGroup::Ptr> & sequence_groups,
ov::Tensor logits,
bool is_validation_mode_enabled) {
const float * logits_data = logits.data<float>();
Expand Down
4 changes: 2 additions & 2 deletions src/cpp/src/sampler.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

// Copyright (C) 2023-2024 Intel Corporation
// Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once
Expand Down Expand Up @@ -67,7 +67,7 @@ class Sampler {
Sampler() = default;
Sampler(Tokenizer & tokenizer) : m_tokenizer(tokenizer) {};

SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
SamplerOutput sample(const std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
void set_seed(size_t new_seed) {
rng_engine.seed(new_seed);
seed = new_seed;
Expand Down
6 changes: 3 additions & 3 deletions src/cpp/src/sequence_group.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2023-2024 Intel Corporation
// Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once
Expand Down Expand Up @@ -292,8 +292,8 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
}

size_t num_finished_seqs() const {
return std::count_if(m_sequences.begin(), m_sequences.end(), [] (Sequence::CPtr seq) {
return seq->has_finished();
return std::count_if(m_sequences.begin(), m_sequences.end(), [this] (Sequence::CPtr seq) {
return seq->has_finished() || seq->out_of_memory() || handle_dropped();
});
}

Expand Down
Loading

0 comments on commit 31d632b

Please sign in to comment.