Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
olpipi committed Nov 27, 2024
1 parent 678580d commit e2bc9c5
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
2 changes: 0 additions & 2 deletions src/cpp/src/model_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ class ModelRunner {
gather_indice_values.push_back(gathering_current_index);
tokens_num_to_sample++;
}
} else {
tokens_num_to_sample++;
}
position_ids_data[token_id] = position_id;
}
Expand Down
3 changes: 2 additions & 1 deletion src/cpp/src/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,8 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
continue;

size_t num_running_sequences = sequence_group->num_running_seqs();
size_t actual_seq_len = sequence_group->get_seq_len_to_sample(); // points to a token which needs to be sampled
size_t actual_seq_len = sequence_group->is_matmul_sliced() ?
sequence_group->get_seq_len_to_sample() : sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
size_t padded_amount_of_processed_tokens = std::max(sequence_group->get_num_scheduled_tokens(), batch_seq_len);
const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();

Expand Down
10 changes: 9 additions & 1 deletion src/cpp/src/sequence_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ class SequenceGroup {
bool m_is_gen_paused = false;
// seq len to sample at current iteration
size_t m_seq_len_to_sample = 0;

// flag shows whether the last matmul was sliced
bool m_sliced_matmul = false;

SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
: m_request_id(request_id),
Expand Down Expand Up @@ -398,6 +399,11 @@ class SequenceGroup {

void set_seq_len_to_sample(size_t len) {
m_seq_len_to_sample = len;
m_sliced_matmul = true;
}

bool is_matmul_sliced() const {
return m_sliced_matmul;
}

/**
Expand Down Expand Up @@ -449,6 +455,8 @@ class SequenceGroup {
void clear_scheduled_tokens() {
m_num_scheduled_tokens = 0;
m_num_validation_tokens = 0;
m_seq_len_to_sample = 0;
m_sliced_matmul = false;
}

bool is_scheduled() const {
Expand Down

0 comments on commit e2bc9c5

Please sign in to comment.