Add slice before matmul transformation for CB scenario #1261
base: master
Conversation
```diff
@@ -37,6 +37,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con
     utils::apply_paged_attention_transformations(main_model, main_model_desc.scheduler_config.use_cache_eviction);
     utils::apply_paged_attention_transformations(draft_model, main_model_desc.scheduler_config.use_cache_eviction);
+    utils::apply_gather_before_matmul_transformation(main_model);
+    utils::apply_gather_before_matmul_transformation(draft_model);
```
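For context, the idea behind the transformation can be illustrated outside of OpenVINO: rather than running the vocab-sized LM-head matmul over the hidden states of every scheduled token, gather only the rows whose logits will actually be sampled. A minimal plain-C++ sketch of that idea (function and variable names are hypothetical, not the actual pass):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical illustration, not the real OpenVINO transformation: with hidden
// states for all scheduled tokens laid out as [num_tokens x hidden_size],
// gather only the rows to be sampled before the LM-head matmul, so logits are
// computed for a handful of tokens instead of the whole scheduled batch.
std::vector<float> gather_rows(const std::vector<float>& hidden,          // num_tokens * hidden_size
                               const std::vector<size_t>& rows_to_sample, // token rows needing logits
                               size_t hidden_size) {
    std::vector<float> gathered;
    gathered.reserve(rows_to_sample.size() * hidden_size);
    for (size_t row : rows_to_sample) {
        const float* src = hidden.data() + row * hidden_size;
        gathered.insert(gathered.end(), src, src + hidden_size);
    }
    return gathered;  // feed these few rows to the vocab-sized matmul
}
```

With speculative decoding this matters for both models, hence the calls on `main_model` and `draft_model` in the diff above.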
Please apply the same to the prompt lookup decoding pipeline.
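A sketch of what that request would look like, mirroring the diff above in the prompt lookup pipeline's constructor (the surrounding names are assumptions, not the actual code):

```cpp
// Hypothetical mirror of the change above for the prompt lookup pipeline;
// `model` and `scheduler_config` stand in for whatever that constructor uses.
utils::apply_paged_attention_transformations(model, scheduler_config.use_cache_eviction);
utils::apply_gather_before_matmul_transformation(model);
```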
```cpp
// context_len corresponds to first token within subgroup of scheduled tokens
size_t group_context_len = group_position_id;
// Next variables are only for sliced matmul case
size_t actual_seq_len = 0;
```
It should be per sequence, not per sequence group, but you increment this value within a loop over `num_running_sequences`, so you will end up with `seq_len * num_running_sequences` instead of `seq_len`.
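A compact illustration of the accumulation bug being described (names are assumptions):

```cpp
#include <cstddef>

// Incrementing a single counter inside the per-sequence loop yields
// seq_len * num_running_sequences rather than seq_len:
size_t buggy_seq_len(size_t seq_len, size_t num_running_sequences) {
    size_t actual_seq_len = 0;
    for (size_t i = 0; i < num_running_sequences; ++i) {
        actual_seq_len += seq_len;  // executed once per running sequence
    }
    return actual_seq_len;  // per-group total, not the per-sequence value
}
```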
```diff
@@ -756,8 +756,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             continue;

         size_t num_running_sequences = sequence_group->num_running_seqs();
-        size_t actual_seq_len = sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
-        size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
+        size_t actual_seq_len = sequence_group->get_seq_len_to_sample();
```
Suggested change:
```diff
-size_t actual_seq_len = sequence_group->get_seq_len_to_sample();
+size_t output_seq_len = sequence_group->get_output_seq_len();
```
IMO, it's a better name, as this function defines the number of tokens in the last matmul.
```diff
@@ -153,6 +173,7 @@ class ModelRunner {
                 subsequence_begins_data += 1;
                 block_indices_begins_data += 1;
             }
+            sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(actual_seq_len, num_scheduled_tokens) : num_scheduled_tokens);
```
Suggested change:
```diff
-sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(actual_seq_len, num_scheduled_tokens) : num_scheduled_tokens);
+sequence_group->set_seq_len_to_sample(seq_len);
```
Why not keep it just as simple as suggested?
```cpp
position_ids_data[token_id] = position_id;

if (matmul_gathering_is_required && sampling_is_required) {
    if (echo_output ||
```
Looks like `echo` is broken here, as `sampling_is_required` can be false when we process the prompt over several iterations (e.g. with dynamic split fuse and batch size 256, we require 4 iterations to process a prompt of 1024 tokens).
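The scenario in numbers, as a runnable sketch (the chunking logic is an assumption; the figures are from the comment): only the final chunk reaches the last prompt token, so any echo handling nested under `sampling_is_required` skips the first three quarters of the prompt.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    const size_t prompt_len = 1024, max_batch = 256;
    size_t processed = 0;
    while (processed < prompt_len) {
        size_t scheduled = std::min(max_batch, prompt_len - processed);
        processed += scheduled;
        // Sampling only happens once the whole prompt has been processed,
        // so this is false for the first three of the four iterations.
        bool sampling_is_required = processed >= prompt_len;
        std::printf("processed %4zu/%zu, sampling_is_required=%d\n",
                    processed, prompt_len, sampling_is_required);
    }
}
```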
```cpp
if (matmul_gathering_is_required && sampling_is_required) {
    if (echo_output ||
        group_position_id + token_id >= prompt_len - 1 &&
```
Why do we need this condition? It says that the token is actually not a prompt token, which is already guaranteed by `sampling_is_required`.
```cpp
if (matmul_gathering_is_required && sampling_is_required) {
    if (echo_output ||
        group_position_id + token_id >= prompt_len - 1 &&
        group_position_id + token_id >= num_scheduled_tokens - tokens_to_sample_per_sequence) {
```
What does this condition do? Let's consider an example:
- `group_position_id` (the same as content length / KV cache size) is 16
- we have scheduled 1 token (because of KV cache limitations)
- the number of candidates for speculative decoding is 3, so `tokens_to_sample_per_sequence` is 4

So the condition is: `16 + 0 >= 1 - 4`
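If those counters are `size_t` (an assumption, but typical for this code), the subtraction underflows and the comparison silently flips, which a tiny demo makes visible:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    size_t group_position_id = 16, token_id = 0;
    size_t num_scheduled_tokens = 1, tokens_to_sample_per_sequence = 4;
    // 1 - 4 in unsigned arithmetic wraps to SIZE_MAX - 2, so the condition
    // reads "16 >= huge value" and is false, not the "16 >= -3" it suggests.
    bool cond = group_position_id + token_id >=
                num_scheduled_tokens - tokens_to_sample_per_sequence;
    std::printf("condition: %d\n", cond);  // prints 0
}
```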
CVS-154930
CVS-155533