Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add slice before matmut transformation for CB scenario #1261

Open
wants to merge 9 commits into
base: master
Choose a base branch
from

Conversation

olpipi
Copy link
Collaborator

@olpipi olpipi commented Nov 27, 2024

CVS-154930
CVS-155533

src/cpp/src/model_runner.hpp Outdated Show resolved Hide resolved
src/cpp/src/sampler.cpp Outdated Show resolved Hide resolved
src/cpp/src/model_runner.hpp Outdated Show resolved Hide resolved
src/cpp/src/model_runner.hpp Outdated Show resolved Hide resolved
@ilya-lavrenov ilya-lavrenov added this to the 2025.0 milestone Nov 27, 2024
@olpipi olpipi marked this pull request as ready for review November 28, 2024 18:27
src/cpp/src/utils/paged_attention_transformations.cpp Outdated Show resolved Hide resolved
src/cpp/src/utils.cpp Show resolved Hide resolved
src/cpp/src/sampler.cpp Outdated Show resolved Hide resolved
src/cpp/src/utils/paged_attention_transformations.cpp Outdated Show resolved Hide resolved
src/cpp/src/model_runner.hpp Outdated Show resolved Hide resolved
src/cpp/src/model_runner.hpp Outdated Show resolved Hide resolved
src/cpp/src/sampler.cpp Outdated Show resolved Hide resolved
src/cpp/src/sampler.cpp Outdated Show resolved Hide resolved
@github-actions github-actions bot added the category: LLM LLM pipeline (stateful, static) label Dec 6, 2024
@mlukasze mlukasze requested a review from iefode December 9, 2024 05:50
@olpipi olpipi force-pushed the slice_matmul_cb branch 2 times, most recently from 5fc1067 to c82d6d1 Compare December 12, 2024 11:58
@@ -37,6 +37,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con

utils::apply_paged_attention_transformations(main_model, main_model_desc.scheduler_config.use_cache_eviction);
utils::apply_paged_attention_transformations(draft_model, main_model_desc.scheduler_config.use_cache_eviction);
utils::apply_gather_before_matmul_transformation(main_model);
utils::apply_gather_before_matmul_transformation(draft_model);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please, the same to prompt lookup decoding pipeline

// context_len corresponds to first token within subgroup of scheduled tokens
size_t group_context_len = group_position_id;
// Next variables are only for sliced matmul case
size_t actual_seq_len = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should per sequence, not per sequence group, but you increment this value within a loop over num_running_sequences, so finally you will have seq_len * num_running_sequences instead of seq_len

@@ -756,8 +756,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
continue;

size_t num_running_sequences = sequence_group->num_running_seqs();
size_t actual_seq_len = sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
size_t actual_seq_len = sequence_group->get_seq_len_to_sample();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
size_t actual_seq_len = sequence_group->get_seq_len_to_sample();
size_t output_seq_len = sequence_group->get_output_seq_len();

IMO, it's a better name as by this function we define a number of tokens in last matmul

@@ -153,6 +173,7 @@ class ModelRunner {
subsequence_begins_data += 1;
block_indices_begins_data += 1;
}
sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(actual_seq_len, num_scheduled_tokens) : num_scheduled_tokens);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(actual_seq_len, num_scheduled_tokens) : num_scheduled_tokens);
sequence_group->set_seq_len_to_sample(seq_len);

why not just as simple as suggested?


position_ids_data[token_id] = position_id;

if (matmul_gathering_is_required && sampling_is_required) {
if (echo_output ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like echo is broken here as sampling_is_required can be false when we process prompt using several iterations (e.g. with dynamic split fuse and batch size 256, we require 4 iterations to process prompt 1024)


if (matmul_gathering_is_required && sampling_is_required) {
if (echo_output ||
group_position_id + token_id >= prompt_len - 1 &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this condition? it says that token is actually not prompt token, which is guaranteed by sampling_is_required

if (matmul_gathering_is_required && sampling_is_required) {
if (echo_output ||
group_position_id + token_id >= prompt_len - 1 &&
group_position_id + token_id >= num_scheduled_tokens - tokens_to_sample_per_sequence) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this condition do?

let's consider example:
group_position_id (the same as content length / KV cache size) is 16
we have scheduled 1 token (because of KV cache limitations)
number of candidates for speculative decoding is 3, so tokens_to_sample_per_sequence is 4
so, the condition is:

16 + 0 >= 1 - 4

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
category: continuous batching Continuous batching category: LLM LLM pipeline (stateful, static) category: sampling Sampling / Decoding algorithms category: speculative decoding Speculative decoding no-match-files
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants