Commit 76b65ab
fix
olpipi committed Nov 28, 2024
1 parent e2bc9c5 commit 76b65ab
Showing 3 changed files with 6 additions and 15 deletions.
8 changes: 4 additions & 4 deletions src/cpp/src/model_runner.hpp

@@ -129,8 +129,8 @@ class ModelRunner {
         size_t num_running_sequences = running_sequences.size();
         size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
         size_t group_position_id = sequence_group->get_num_processed_tokens();
-        auto prompt_len = sequence_group->get_prompt_len();
-        size_t tokens_num_to_sample = 0;
+        size_t prompt_len = sequence_group->get_prompt_len();
+        size_t seq_len_after_gather = 0;
 
         // spec: In case of multiple input tokens for current sequence (prompt_len > 1),
         // context_len corresponds to first token within subgroup of scheduled tokens
@@ -148,7 +148,7 @@ class ModelRunner {
                 if (matmul_gathering_is_required) {
                     if (group_position_id + token_id >= prompt_len - 1) {
                         gather_indice_values.push_back(gathering_current_index);
-                        tokens_num_to_sample++;
+                        seq_len_after_gather++;
                     }
                 }
                 position_ids_data[token_id] = position_id;
@@ -169,7 +169,7 @@ class ModelRunner {
             subsequence_begins_data += 1;
             block_indices_begins_data += 1;
         }
-        sequence_group->set_seq_len_to_sample(tokens_num_to_sample);
+        sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(seq_len_after_gather, num_scheduled_tokens) : num_scheduled_tokens);
     }
 
     // typical LLM parameters
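Read together, these hunks replace the implicit tokens_num_to_sample counter with an explicit rule: when matmul gathering is enabled, the sample length is the number of tokens that survive the gather, clamped to the scheduled-token count; otherwise every scheduled token is a sampling candidate. A minimal standalone sketch of that rule, using hypothetical names that mirror the diff rather than the real ModelRunner interface:

// Sketch of the new seq-len-to-sample rule, assuming the parameter names
// below (not the actual ModelRunner types). With gathering enabled, only
// tokens at positions >= prompt_len - 1 survive the gather; without it,
// all scheduled tokens are sampling candidates.
#include <algorithm>
#include <cstddef>

std::size_t compute_seq_len_to_sample(bool gathering_required,
                                      std::size_t group_position_id,
                                      std::size_t prompt_len,
                                      std::size_t num_scheduled_tokens) {
    if (!gathering_required)
        return num_scheduled_tokens;
    std::size_t seq_len_after_gather = 0;
    for (std::size_t token_id = 0; token_id < num_scheduled_tokens; ++token_id) {
        // mirrors the gather condition in the diff above
        if (group_position_id + token_id >= prompt_len - 1)
            ++seq_len_after_gather;
    }
    // clamp, matching the std::min in the added line
    return std::min(seq_len_after_gather, num_scheduled_tokens);
}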
5 changes: 2 additions & 3 deletions src/cpp/src/sampler.cpp

@@ -755,9 +755,8 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             continue;
 
         size_t num_running_sequences = sequence_group->num_running_seqs();
-        size_t actual_seq_len = sequence_group->is_matmul_sliced() ?
-            sequence_group->get_seq_len_to_sample() : sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
-        size_t padded_amount_of_processed_tokens = std::max(sequence_group->get_num_scheduled_tokens(), batch_seq_len);
+        size_t actual_seq_len = sequence_group->get_seq_len_to_sample();
+        size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
         const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();
 
         const auto request_id = sequence_group->get_request_id();
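Because ModelRunner now always writes the value, the sampler can drop the is_matmul_sliced() branch and read get_seq_len_to_sample() unconditionally. A toy illustration of the remaining padding arithmetic, with made-up numbers and under the assumption that batch_seq_len is the longest per-group length in the batch:

// Toy example of the padding line above; values are invented for illustration.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    std::size_t batch_seq_len = 16;                  // batch-wide (padded) length, assumed
    std::size_t seq_len_to_sample[] = {16, 5, 1};    // per-group values written by the runner

    for (std::size_t len : seq_len_to_sample) {
        // each group's slice of the logits is padded up to the batch-wide length
        std::size_t padded = std::max(len, batch_seq_len);
        std::printf("seq_len=%zu -> padded=%zu\n", len, padded);
    }
    return 0;
}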
8 changes: 0 additions & 8 deletions src/cpp/src/sequence_group.hpp

@@ -222,8 +222,6 @@ class SequenceGroup {
     bool m_is_gen_paused = false;
     // seq len to sample at current iteration
     size_t m_seq_len_to_sample = 0;
-    // flag shows wheather last matmul was sliced
-    bool m_sliced_matmul = false;
 
     SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
         : m_request_id(request_id),
@@ -399,11 +397,6 @@ class SequenceGroup {
 
     void set_seq_len_to_sample(size_t len) {
         m_seq_len_to_sample = len;
-        m_sliced_matmul = true;
-    }
-
-    bool is_matmul_sliced() const {
-        return m_sliced_matmul;
     }
 
     /**
@@ -456,7 +449,6 @@ class SequenceGroup {
         m_num_scheduled_tokens = 0;
         m_num_validation_tokens = 0;
         m_seq_len_to_sample = 0;
-        m_sliced_matmul = false;
     }
 
     bool is_scheduled() const {
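What remains after the removal is a single member with a plain setter. A reduced sketch of the affected part of SequenceGroup, omitting everything this commit does not touch (the real class has many more fields and methods; get_seq_len_to_sample is the accessor the sampler reads):

// Reduced sketch of the post-commit state, not the full class.
#include <cstddef>

class SequenceGroup {
    // seq len to sample at current iteration; now always written by
    // ModelRunner, so the m_sliced_matmul flag is no longer needed
    std::size_t m_seq_len_to_sample = 0;

public:
    void set_seq_len_to_sample(std::size_t len) {
        m_seq_len_to_sample = len;
    }

    std::size_t get_seq_len_to_sample() const {
        return m_seq_len_to_sample;
    }
};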
