Make Sampler a member of the class for llm/vlm pipelines #1347

Merged
22 changes: 17 additions & 5 deletions src/cpp/src/llm_pipeline.cpp
@@ -43,6 +43,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
std::string m_templated_chat_history = {};
std::vector<int64_t> m_tokenized_chat_history;
ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
+ Sampler m_sampler;

StatefulLLMPipeline(
const ov::InferRequest& request,
@@ -89,6 +90,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
// If eos_token_id was not provided, take value
if (m_generation_config.eos_token_id == -1)
m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
+
+ m_sampler = Sampler(m_tokenizer);
+ m_sampler.set_seed(m_generation_config.rng_seed);
}

StatefulLLMPipeline(
@@ -111,6 +115,11 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

auto start_time = std::chrono::steady_clock::now();
GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
+ // If eos_token_id was not provided, take value from default m_generation_config
+ if (config.eos_token_id == -1)
+     config.set_eos_token_id(m_generation_config.eos_token_id);
+ config.validate();
+
TokenizedInputs encoded_input;

if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
@@ -141,9 +150,6 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
// and find the difference as a prompt, so let's check it out and use the whole history in this case
if (!m_tokenized_chat_history.empty()) {
auto stop_tokens = config.stop_token_ids;
- // config could be reset by user and stop_tokens could be empty
- // but model/tokenizer still will rely to eos token, so let's add it
- stop_tokens.insert(m_tokenizer.get_eos_token_id());
size_t last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
m_trust_encoded_history = last_same_hist_token == SIZE_MAX;
}
@@ -343,9 +349,15 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
requests.push_back(sequence_group);
}

- Sampler sampler = Sampler(m_tokenizer);
+ if (m_sampler.get_seed() != config.rng_seed) {
+     m_sampler.set_seed(config.rng_seed);
+ }
+
std::tie(result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_model_runner, input_tokens, concatenated_attention_mask, streamer_ptr,
- sampler, requests, position_ids, std::nullopt, m_selected_beam);
+ m_sampler, requests, position_ids, std::nullopt, m_selected_beam);
+
+ for (auto& request : requests)
+     m_sampler.clear_request_info(request->get_request_id());
}

if (is_chat_conversation) {
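The pattern introduced in this file (and mirrored in the VLM pipeline below) is: construct the Sampler once as a class member, seed it from the pipeline's default generation config, re-seed it inside generate() only when the caller passes a different rng_seed, and clear its per-request state once the request is finished. The sketch below is a minimal, self-contained illustration of that flow; Sampler, GenerationConfig, and Pipeline here are simplified stand-ins, not the real ov::genai classes.

    // Simplified sketch (not project code): a pipeline that owns its sampler.
    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <random>

    struct Sampler {
        std::mt19937 rng_engine;
        size_t seed = std::mt19937::default_seed;
        std::map<uint64_t, int> per_request_state; // stand-in for logit processors / beam searchers

        void set_seed(size_t new_seed) { rng_engine.seed(new_seed); seed = new_seed; }
        size_t get_seed() const { return seed; }
        void clear_request_info(uint64_t request_id) { per_request_state.erase(request_id); }
    };

    struct GenerationConfig { size_t rng_seed = 0; };

    class Pipeline {
        Sampler m_sampler;            // member, created once instead of per generate() call
        GenerationConfig m_default;
    public:
        explicit Pipeline(const GenerationConfig& cfg) : m_default(cfg) {
            m_sampler.set_seed(m_default.rng_seed);
        }
        void generate(const GenerationConfig& config) {
            // Re-seed only when the requested seed differs from the one in use,
            // so repeated calls with the same seed keep advancing one RNG stream.
            if (m_sampler.get_seed() != config.rng_seed)
                m_sampler.set_seed(config.rng_seed);

            const uint64_t request_id = 0;
            m_sampler.per_request_state[request_id] = 1;       // pretend sampling happened
            std::cout << "seed " << m_sampler.get_seed()
                      << ", draw " << m_sampler.rng_engine() << '\n';
            m_sampler.clear_request_info(request_id);          // drop per-request state
        }
    };

    int main() {
        Pipeline pipe{GenerationConfig{42}};
        pipe.generate(GenerationConfig{42}); // same seed: RNG stream continues
        pipe.generate(GenerationConfig{7});  // caller changed the seed: re-seed once
    }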
7 changes: 6 additions & 1 deletion src/cpp/src/sampler.hpp
@@ -55,6 +55,7 @@ class Sampler {
std::map<uint64_t, GroupBeamSearcher> m_beam_search_info;

std::mt19937 rng_engine;
+ size_t seed = rng_engine.default_seed;
// { request_id, logit_processor }
std::map<uint64_t, LogitProcessor> m_logit_processors;

@@ -65,7 +66,11 @@
Sampler(Tokenizer & tokenizer) : m_tokenizer(tokenizer) {};

SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
- void set_seed(size_t seed) { rng_engine.seed(seed); }
+ void set_seed(size_t new_seed) {
+     rng_engine.seed(new_seed);
+     seed = new_seed;
+ }
+ size_t get_seed() { return seed; }

void clear_request_info(uint64_t request_id);

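The reason the Sampler now remembers its seed (instead of re-seeding unconditionally) is that seeding an std::mt19937 restarts its stream: if the pipeline re-seeded before every generate() with an unchanged config, every random-sampling call would reproduce the same draws. A small standalone illustration, independent of the project code:

    #include <cstddef>
    #include <iostream>
    #include <random>

    int main() {
        const size_t seed = 42;

        std::mt19937 reseeded;
        reseeded.seed(seed);
        auto a = reseeded();
        reseeded.seed(seed);       // what re-creating/re-seeding the sampler per call amounts to
        auto b = reseeded();
        std::cout << "re-seeded each call: " << a << " then " << b << " (identical)\n";

        std::mt19937 reused(seed); // what the member sampler with a remembered seed does
        std::cout << "reused engine:       " << reused() << " then " << reused() << " (stream advances)\n";
    }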
17 changes: 14 additions & 3 deletions src/cpp/src/visual_language/pipeline.cpp
@@ -61,6 +61,8 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
bool m_is_chat_conversation;
// InputsEmbedder
std::shared_ptr<InputsEmbedder> m_inputs_embedder;
+ // Component for applying sampling to lm outputs
+ Sampler m_sampler;

VLMPipelineImpl(
const std::filesystem::path& models_dir,
@@ -94,6 +96,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
if (m_generation_config.eos_token_id == -1) {
m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
}
+
+ m_sampler = Sampler(m_tokenizer);
+ m_sampler.set_seed(m_generation_config.rng_seed);
}

VLMPipelineImpl(
@@ -129,6 +134,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
if (m_generation_config.eos_token_id == -1) {
m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
}
+
+ m_sampler = Sampler(m_tokenizer);
+ m_sampler.set_seed(m_generation_config.rng_seed);
}

DecodedResults generate(
@@ -149,8 +157,6 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist);
}

- Sampler sampler = Sampler(m_tokenizer);
-
std::vector<SequenceGroup::Ptr> requests;
size_t request_id = 0;
size_t block_size = 1; // not used
@@ -190,10 +196,15 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds_size }};
std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), history_size);

+ if (m_sampler.get_seed() != generation_config.rng_seed) {
+     m_sampler.set_seed(generation_config.rng_seed);
+ }
+
ov::genai::EncodedResults encoded_result;
int32_t m_selected_beam = 0;
- std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, sampler, requests,
+ std::tie(encoded_result, m_selected_beam) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests,
position_ids, m_embedding, std::nullopt);
+ m_sampler.clear_request_info(0);

DecodedResults decoded;
for (size_t idx = 0; idx < encoded_result.tokens.size(); ++idx) {
7 changes: 5 additions & 2 deletions tests/python_tests/test_chat_generate_api.py
@@ -187,10 +187,13 @@ def test_set_chat_template():
model_descr = get_chat_models_list()[0]
model_id, path, tokenizer, model_opt, pipe = read_model((model_descr[0], model_descr[1] / '_test_chat'))
pipe.get_tokenizer().set_chat_template("{% for message in messages %}{{ message['content'] }}{% endfor %}")
+ config = ov_genai.GenerationConfig()
+ config.max_new_tokens = 1
+ config.do_sample = False
pipe.start_chat()
- generated = pipe.generate("a", max_new_tokens=1)
+ generated = pipe.generate("a", config)
pipe.finish_chat()
- reference = pipe.generate("a", max_new_tokens=1)
+ reference = pipe.generate("a", config)
assert generated == reference

prompts = [