diff --git a/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc b/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc index b575afbe1b..e55b08794c 100644 --- a/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc +++ b/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc @@ -474,6 +474,13 @@ void TemperedLB::runLB(TimeType total_load) { } } +void TemperedLB::clearDataStructures() { + selected_.clear(); + underloaded_.clear(); + load_info_.clear(); + is_overloaded_ = is_underloaded_ = false; +} + void TemperedLB::doLBStages(TimeType start_imb) { decltype(this->cur_objs_) best_objs; LoadType best_load = 0; @@ -483,11 +490,7 @@ void TemperedLB::doLBStages(TimeType start_imb) { auto this_node = theContext()->getNode(); for (trial_ = 0; trial_ < num_trials_; ++trial_) { - // Clear out data structures - selected_.clear(); - underloaded_.clear(); - load_info_.clear(); - is_overloaded_ = is_underloaded_ = false; + clearDataStructures(); TimeType best_imb_this_trial = start_imb + 10; @@ -504,11 +507,7 @@ void TemperedLB::doLBStages(TimeType start_imb) { } this_new_load_ = this_load; } else { - // Clear out data structures from previous iteration - selected_.clear(); - underloaded_.clear(); - load_info_.clear(); - is_overloaded_ = is_underloaded_ = false; + clearDataStructures(); } vt_debug_print( @@ -682,7 +681,7 @@ void TemperedLB::informAsync() { auto propagate_epoch = theTerm()->makeEpochCollective("TemperedLB: informAsync"); // Underloaded start the round - if (is_underloaded_) { + if (canPropagate()) { uint8_t k_cur_async = 0; propagateRound(k_cur_async, false, propagate_epoch); } @@ -722,7 +721,7 @@ void TemperedLB::informSync() { underloaded_.insert(this_node); } - auto propagate_this_round = is_underloaded_; + auto propagate_this_round = canPropagate(); propagate_next_round_ = false; new_underloaded_ = underloaded_; new_load_info_ = load_info_; @@ -793,6 +792,9 @@ void TemperedLB::propagateRound(uint8_t k_cur, bool sync, EpochType epoch) { gen_propagate_.seed(seed_()); } + // remove 1 line + // selected = getSelectedNodessss + // selected = all? auto& selected = selected_; selected = underloaded_; if (selected.find(this_node) == selected.end()) { @@ -814,6 +816,8 @@ void TemperedLB::propagateRound(uint8_t k_cur, bool sync, EpochType epoch) { return; } + // extract generateRandomNode + // auto random_node = generateRandomNode(dist); // First, randomly select a node NodeType random_node = uninitialized_destination; @@ -1203,7 +1207,7 @@ void TemperedLB::decide() { int n_transfers = 0, n_rejected = 0; - if (is_overloaded_) { + if (canMigrate()) { std::vector under = makeUnderloaded(); std::unordered_map migrate_objs; diff --git a/src/vt/vrt/collection/balance/temperedlb/temperedlb.h b/src/vt/vrt/collection/balance/temperedlb/temperedlb.h index 8e6f68b5b6..8086e08f13 100644 --- a/src/vt/vrt/collection/balance/temperedlb/temperedlb.h +++ b/src/vt/vrt/collection/balance/temperedlb/temperedlb.h @@ -94,6 +94,10 @@ struct TemperedLB : BaseLB { void informSync(); void decide(); void migrate(); + void clearDataStructures(); + + virtual bool canMigrate() const { return is_overloaded_; } + virtual bool canPropagate() const { return is_underloaded_; } void propagateRound(uint8_t k_cur_async, bool sync, EpochType epoch = no_epoch); void propagateIncomingAsync(LoadMsgAsync* msg); diff --git a/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h b/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h index c008741b4a..df6cbe39af 100644 --- a/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h +++ b/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h @@ -61,6 +61,8 @@ struct TemperedWMin : TemperedLB { void inputParams(balance::SpecEntry* spec) override; protected: + bool canPropagate() const override { return true; } + TimeType getModeledWork(const elm::ElementIDStruct& obj) override; private: