Skip to content

Commit

Permalink
[SpecDecode] Fix sampler selection. (#1971)
Browse files Browse the repository at this point in the history
This PR temporarily fixes sampler selection logic for speculative
decoding. As GPU sampler support for speculative decoding is
not ready, speculative decoding will use cpu sampler.
  • Loading branch information
KnowingNothing authored Mar 20, 2024
1 parent 06d6115 commit 39d0865
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
3 changes: 2 additions & 1 deletion cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ class EngineImpl : public Engine {
}
LogitProcessor logit_processor =
this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder);
Sampler sampler = this->models_[0]->CreateSampler(max_num_tokens, trace_recorder);
Sampler sampler = this->models_[0]->CreateSampler(
max_num_tokens, static_cast<int>(this->models_.size()), trace_recorder);
// Step 3. Initialize engine actions that represent state transitions.
if (this->engine_mode_->enable_speculative) {
// Speculative decoding is only possible for more than one model.
Expand Down
7 changes: 5 additions & 2 deletions cpp/serve/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,11 @@ class ModelImpl : public ModelObj {
std::move(trace_recorder));
}

Sampler CreateSampler(int max_num_sample, Optional<EventTraceRecorder> trace_recorder) {
if (Sampler::SupportGPUSampler(device_)) {
Sampler CreateSampler(int max_num_sample, int num_models,
Optional<EventTraceRecorder> trace_recorder) {
if (num_models > 1) { // speculative decoding uses cpu sampler
return Sampler::CreateCPUSampler(std::move(trace_recorder));
} else if (Sampler::SupportGPUSampler(device_)) {
return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_,
std::move(trace_recorder));
} else {
Expand Down
2 changes: 1 addition & 1 deletion cpp/serve/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class ModelObj : public Object {
Optional<EventTraceRecorder> trace_recorder) = 0;

/*! \brief Create a sampler from this model. */
virtual Sampler CreateSampler(int max_num_sample,
virtual Sampler CreateSampler(int max_num_sample, int num_models,
Optional<EventTraceRecorder> trace_recorder) = 0;

/*!
Expand Down

0 comments on commit 39d0865

Please sign in to comment.