diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index 475197a38..dbe90a2c0 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -6,17 +6,6 @@ namespace Generators { -namespace { - -void DumpStore(const std::unordered_map>& store) { - for (const auto& [name, ort_value] : store) { - std::cout << name << " "; - } - std::cout << std::endl; -} - -} // namespace - DecoderOnlyPipelineModel::DecoderOnlyPipelineModel(std::unique_ptr config, OrtEnv& ort_env) : Model{std::move(config)} { for (const auto& model : config_->model.decoder.pipeline) { @@ -130,7 +119,8 @@ RoamingArray DecoderOnlyPipelineState::Run(int current_length, RoamingArr continue; } - // Clear the intermediate pipeline state from previous runs. + // Clear the intermediate pipeline state outputs from the previous runs. + // These outputs will be replaced by the outputs from the current run. for (const auto& output_name : pipeline_state->output_names_) { if (auto iter = ortvalue_store_.find(output_name); iter != ortvalue_store_.end()) { ortvalue_store_.erase(iter); @@ -183,7 +173,9 @@ RoamingArray DecoderOnlyPipelineState::Run(int current_length, RoamingArr } } - // Output of pipeline models could also be managed inputs for the subsequent pipeline model. + // Output of pipeline models could also be managed inputs. + // For example, the output of a pipeline model could be the key-value cache. + // In such cases, use the managed output buffers and register them with the pipeline model as outputs. for (const auto& input_name : input_names_) { if (pipeline_state->HasOutput(input_name)) { if (!pipeline_state->SupportsPrimaryDevice()) { @@ -212,7 +204,7 @@ RoamingArray DecoderOnlyPipelineState::Run(int current_length, RoamingArr // Run the intermediate pipeline state pipeline_state->Run(current_length, next_tokens, next_indices); - // Store the non-managed outputs from the current pipeline state in the ortvalue pool. 
+ // Transfer ownership of all the non-managed outputs from the current pipeline state to the ortvalue store. // All non managed outputs are assumed to be on CPU for (size_t i = 0; i < pipeline_state->output_names_.size(); ++i) { if (std::none_of(output_names_.begin(), output_names_.end(), @@ -229,10 +221,10 @@ RoamingArray DecoderOnlyPipelineState::Run(int current_length, RoamingArr } } + // Clear the outputs of the pipeline models that are only run on prompt since this cannot happen earlier. if (!first_run_) { for (auto& pipeline_state : pipeline_states_) { if (!model_.config_->model.decoder.pipeline[pipeline_state->id_].run_on_token_gen) { - // Clear the ortvalue store. for (const auto& output_name : pipeline_state->output_names_) { if (auto iter = ortvalue_store_.find(output_name); iter != ortvalue_store_.end()) { ortvalue_store_.erase(iter); @@ -242,8 +234,6 @@ RoamingArray DecoderOnlyPipelineState::Run(int current_length, RoamingArr } } - // DumpStore(ortvalue_store_); - first_run_ = false; return logits_.Get(); @@ -258,7 +248,7 @@ void DecoderOnlyPipelineState::UpdateInputsOutputs(const RoamingArray& } OrtValue* DecoderOnlyPipelineState::GetOutput(const char* name) { - // Check the ortvalue pool to search if name is one of the non-managed output. + // Check the ortvalue store to see whether name is one of the non-managed outputs. auto it = ortvalue_store_.find(name); if (it != ortvalue_store_.end()) { return it->second.get();