diff --git a/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc b/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc
index 32af5af035..04be181aec 100644
--- a/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc
+++ b/src/vt/vrt/collection/balance/temperedlb/temperedlb.cc
@@ -42,6 +42,8 @@
 */
 
 #include "vt/config.h"
+#include "vt/configs/types/types_sentinels.h"
+#include "vt/configs/types/types_type.h"
 #include "vt/timing/timing.h"
 #include "vt/vrt/collection/balance/baselb/baselb.h"
 #include "vt/vrt/collection/balance/model/load_model.h"
@@ -474,6 +476,12 @@ void TemperedLB::runLB(TimeType total_load) {
   }
 }
 
+void TemperedLB::clearDataStructures() {
+  underloaded_.clear();
+  load_info_.clear();
+  is_overloaded_ = is_underloaded_ = false;
+}
+
 void TemperedLB::doLBStages(TimeType start_imb) {
   decltype(this->cur_objs_) best_objs;
   LoadType best_load = 0;
@@ -483,11 +491,7 @@ void TemperedLB::doLBStages(TimeType start_imb) {
   auto this_node = theContext()->getNode();
 
   for (trial_ = 0; trial_ < num_trials_; ++trial_) {
-    // Clear out data structures
-    selected_.clear();
-    underloaded_.clear();
-    load_info_.clear();
-    is_overloaded_ = is_underloaded_ = false;
+    clearDataStructures();
 
     TimeType best_imb_this_trial = start_imb + 10;
 
@@ -504,11 +508,7 @@ void TemperedLB::doLBStages(TimeType start_imb) {
         }
         this_new_load_ = this_load;
       } else {
-        // Clear out data structures from previous iteration
-        selected_.clear();
-        underloaded_.clear();
-        load_info_.clear();
-        is_overloaded_ = is_underloaded_ = false;
+        clearDataStructures();
       }
 
       vt_debug_print(
@@ -667,7 +667,7 @@ void TemperedLB::informAsync() {
   vtAssert(k_max_ > 0, "Number of rounds (k) must be greater than zero");
 
   auto const this_node = theContext()->getNode();
-  if (is_underloaded_) {
+  if (canPropagate()) {
     underloaded_.insert(this_node);
   }
 
@@ -682,7 +682,7 @@ void TemperedLB::informAsync() {
   auto propagate_epoch = theTerm()->makeEpochCollective("TemperedLB: informAsync");
 
   // Underloaded start the round
-  if (is_underloaded_) {
+  if (canPropagate()) {
     uint8_t k_cur_async = 0;
     propagateRound(k_cur_async, false, propagate_epoch);
   }
@@ -718,11 +718,11 @@ void TemperedLB::informSync() {
   vtAssert(k_max_ > 0, "Number of rounds (k) must be greater than zero");
 
   auto const this_node = theContext()->getNode();
-  if (is_underloaded_) {
+  if (canPropagate()) {
     underloaded_.insert(this_node);
   }
 
-  auto propagate_this_round = is_underloaded_;
+  auto propagate_this_round = canPropagate();
   propagate_next_round_ = false;
   new_underloaded_ = underloaded_;
   new_load_info_ = load_info_;
@@ -793,8 +793,9 @@ void TemperedLB::propagateRound(uint8_t k_cur, bool sync, EpochType epoch) {
     gen_propagate_.seed(seed_());
   }
 
-  auto& selected = selected_;
-  selected = underloaded_;
+  // Work on a local copy: inserting this_node below must not leak into
+  // underloaded_, where this rank is only added when canPropagate() holds
+  auto selected = underloaded_;
   if (selected.find(this_node) == selected.end()) {
     selected.insert(this_node);
   }
@@ -1203,7 +1204,7 @@ void TemperedLB::decide() {
 
   int n_transfers = 0, n_rejected = 0;
 
-  if (is_overloaded_) {
+  if (canMigrate()) {
     std::vector under = makeUnderloaded();
     std::unordered_map migrate_objs;
 
diff --git a/src/vt/vrt/collection/balance/temperedlb/temperedlb.h b/src/vt/vrt/collection/balance/temperedlb/temperedlb.h
index 6839ae6eb7..df3352e26d 100644
--- a/src/vt/vrt/collection/balance/temperedlb/temperedlb.h
+++ b/src/vt/vrt/collection/balance/temperedlb/temperedlb.h
@@ -94,6 +94,25 @@ struct TemperedLB : BaseLB {
   void informSync();
   void decide();
   void migrate();
+  /**
+   * \brief Clear underloaded_ and load_info_ and reset the overloaded and
+   * underloaded flags; called at the start of each trial and later iteration
+   */
+  void clearDataStructures();
+
+  /**
+   * \brief Decides whether the rank enters the transfer branch of decide()
+   *
+   * TemperedLB restricts this to overloaded ranks
+   */
+  virtual bool canMigrate() const { return is_overloaded_; }
+
+  /**
+   * \brief Decides whether the rank can initiate information propagation stage
+   *
+   * TemperedLB restricts this to underloaded ranks
+   */
+  virtual bool canPropagate() const { return is_underloaded_; }
 
   void propagateRound(uint8_t k_cur_async, bool sync, EpochType epoch = no_epoch);
   void propagateIncomingAsync(LoadMsgAsync* msg);
@@ -164,7 +183,6 @@ struct TemperedLB : BaseLB {
   objgroup::proxy::Proxy proxy_ = {};
   bool is_overloaded_ = false;
   bool is_underloaded_ = false;
-  std::unordered_set selected_ = {};
   std::unordered_set underloaded_ = {};
   std::unordered_set new_underloaded_ = {};
   std::unordered_map cur_objs_ = {};
diff --git a/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h b/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h
index d6f61bbaa4..a705489615 100644
--- a/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h
+++ b/src/vt/vrt/collection/balance/temperedwmin/temperedwmin.h
@@ -66,6 +66,11 @@ struct TemperedWMin : TemperedLB {
 protected:
   TimeType getModeledValue(const elm::ElementIDStruct& obj) override;
 
+  /**
+   * All ranks are allowed to initiate the information propagation stage
+   */
+  bool canPropagate() const override { return true; }
+
 private:
   std::unique_ptr total_work_model_ = nullptr;
   balance::LoadModel* load_model_ptr = nullptr;