Skip to content

Commit

Permalink
Clean code and add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani committed Sep 20, 2024
1 parent 0291560 commit 0e4693f
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions src/models/decoder_only_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,6 @@

namespace Generators {

namespace {

void DumpStore(const std::unordered_map<std::string, std::unique_ptr<OrtValue>>& store) {
for (const auto& [name, ort_value] : store) {
std::cout << name << " ";
}
std::cout << std::endl;
}

} // namespace

DecoderOnlyPipelineModel::DecoderOnlyPipelineModel(std::unique_ptr<Config> config, OrtEnv& ort_env)
: Model{std::move(config)} {
for (const auto& model : config_->model.decoder.pipeline) {
Expand Down Expand Up @@ -130,7 +119,8 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
continue;
}

// Clear the intermediate pipeline state from previous runs.
// Clear the intermediate pipeline state outputs from the previous runs.
// These outputs will be replaced by the outputs from the current run.
for (const auto& output_name : pipeline_state->output_names_) {
if (auto iter = ortvalue_store_.find(output_name); iter != ortvalue_store_.end()) {
ortvalue_store_.erase(iter);
Expand Down Expand Up @@ -183,7 +173,9 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
}
}

// Output of pipeline models could also be managed inputs for the subsequent pipeline model.
// Output of pipeline models could also be managed inputs.
// For example, the output of a pipeline model could be the key-value cache.
// In such cases, use the managed output buffers and register them with the pipeline model as outputs.
for (const auto& input_name : input_names_) {
if (pipeline_state->HasOutput(input_name)) {
if (!pipeline_state->SupportsPrimaryDevice()) {
Expand Down Expand Up @@ -212,7 +204,7 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
// Run the intermediate pipeline state
pipeline_state->Run(current_length, next_tokens, next_indices);

// Store the non-managed outputs from the current pipeline state in the ortvalue pool.
// Transfer ownership of all the non-managed outputs from the current pipeline state to the ortvalue store.
// All non-managed outputs are assumed to be on the CPU.
for (size_t i = 0; i < pipeline_state->output_names_.size(); ++i) {
if (std::none_of(output_names_.begin(), output_names_.end(),
Expand All @@ -229,10 +221,10 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
}
}

// Clear the outputs of the pipeline models that are only run on prompt since this cannot happen earlier.
if (!first_run_) {
for (auto& pipeline_state : pipeline_states_) {
if (!model_.config_->model.decoder.pipeline[pipeline_state->id_].run_on_token_gen) {
// Clear the ortvalue store.
for (const auto& output_name : pipeline_state->output_names_) {
if (auto iter = ortvalue_store_.find(output_name); iter != ortvalue_store_.end()) {
ortvalue_store_.erase(iter);
Expand All @@ -242,8 +234,6 @@ RoamingArray<float> DecoderOnlyPipelineState::Run(int current_length, RoamingArr
}
}

// DumpStore(ortvalue_store_);

first_run_ = false;

return logits_.Get();
Expand All @@ -258,7 +248,7 @@ void DecoderOnlyPipelineState::UpdateInputsOutputs(const RoamingArray<int32_t>&
}

OrtValue* DecoderOnlyPipelineState::GetOutput(const char* name) {
// Check the ortvalue pool to see whether name is one of the non-managed outputs.
// Check the ortvalue store to see whether name is one of the non-managed outputs.
auto it = ortvalue_store_.find(name);
if (it != ortvalue_store_.end()) {
return it->second.get();
Expand Down

0 comments on commit 0e4693f

Please sign in to comment.