Skip to content

Commit

Permalink
Merge pull request #2196 from DARMA-tasking/2195-fix-addaction-semantics
Browse files Browse the repository at this point in the history
#2195: termination: make actions run consistently in a runnable
  • Loading branch information
lifflander authored Oct 4, 2023
2 parents 9b8df4f + 4345027 commit 2ff5cc2
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 75 deletions.
135 changes: 74 additions & 61 deletions src/vt/termination/term_action.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
#include "vt/termination/term_action.h"
#include "vt/termination/term_common.h"
#include "vt/termination/termination.h"
#include "vt/registry/auto/auto_registry_interface.h"
#include "vt/runnable/make_runnable.h"

namespace vt { namespace term {

Expand All @@ -65,92 +67,103 @@ void TermAction::afterAddEpochAction(EpochType const& epoch) {
* Produce a unit of any epoch type to inhibit global termination when
* local termination of a specific epoch is waiting for detection
*/
theTerm()->produce(term::any_epoch_sentinel);

auto const status = testEpochTerminated(epoch);
if (status == TermStatusEnum::Terminated) {
triggerAllEpochActions(epoch);
}
}

void TermAction::addActionEpoch(EpochType const& epoch, ActionType action) {
if (epoch == term::any_epoch_sentinel) {
return addAction(action);
} else {
auto epoch_iter = epoch_actions_.find(epoch);
if (epoch_iter == epoch_actions_.end()) {
epoch_actions_.emplace(
std::piecewise_construct,
std::forward_as_tuple(epoch),
std::forward_as_tuple(ActionContType{action})
);
} else {
epoch_iter->second.emplace_back(action);
}
queueActions(epoch);
}
afterAddEpochAction(epoch);
}

void TermAction::clearActions() {
global_term_actions_.clear();
/*static*/ void TermAction::runActions(ActionMsg* msg) {
theTerm()->triggerAllEpochActions(msg->ep, msg->encapsulated_epoch);
}

void TermAction::clearActionsEpoch(EpochType const& epoch) {
void TermAction::queueActions(EpochType epoch) {
if (epoch == term::any_epoch_sentinel) {
return clearActions();
// @todo: should this be delayed also?
for (auto&& action : global_term_actions_) {
action();
}

global_term_actions_.clear();
} else {
auto iter = epoch_actions_.find(epoch);
if (iter != epoch_actions_.end()) {
auto const& epoch_actions_count = iter->second.size();
epoch_actions_.erase(iter);
/*
* Consume units of epoch-specific actions are cleared to match the
* production in addActionEpoch
*/
theTerm()->consume(term::any_epoch_sentinel, epoch_actions_count);
auto make_runnable = [&](EpochType encap_epoch){
auto msg = makeMessage<ActionMsg>(epoch, encap_epoch);
auto const han = auto_registry::makeAutoHandler<ActionMsg, runActions>();
auto const this_node = theContext()->getNode();
runnable::makeRunnable(msg, true, han, this_node)
.withTDEpoch(encap_epoch)
.enqueue();
};

if (auto iter = epoch_actions_.find(epoch); iter != epoch_actions_.end()) {
for (auto const& [encapsulated_epoch, _] : iter->second) {
make_runnable(encapsulated_epoch);
}
}

if (auto iter = epoch_callable_actions_.find(epoch);
iter != epoch_callable_actions_.end()) {
for (auto const& [encapsulated_epoch, _] : iter->second) {
make_runnable(encapsulated_epoch);
}
}
}
}

void TermAction::triggerAllActions(EpochType const& epoch) {
void TermAction::addActionEpoch(EpochType const& epoch, ActionType action) {
if (epoch == term::any_epoch_sentinel) {
for (auto&& action : global_term_actions_) {
action();
}

global_term_actions_.clear();
return addAction(action);
} else {
return triggerAllEpochActions(epoch);
auto encapsulated_epoch = getCurrentEpoch();
theTerm()->produce(encapsulated_epoch);
epoch_actions_[epoch][encapsulated_epoch].push_back(action);
}
afterAddEpochAction(epoch);
}

void TermAction::triggerAllEpochActions(EpochType const& epoch) {
void TermAction::produceOn(EpochType epoch) const {
theTerm()->produce(epoch);
}

void TermAction::triggerAllEpochActions(
EpochType epoch, EpochType encapsulated_epoch
) {
// Run through the normal ActionType elements associated with this epoch
std::size_t epoch_actions_count = 0;
auto iter = epoch_actions_.find(epoch);
if (iter != epoch_actions_.end()) {
epoch_actions_count += iter->second.size();
for (auto&& action : iter->second) {
action();
if (auto iter = epoch_actions_.find(epoch);
iter != epoch_actions_.end()) {
if (auto iter2 = iter->second.find(encapsulated_epoch);
iter2 != iter->second.end()) {
for (auto&& action : iter2->second) {
theTerm()->consume(encapsulated_epoch);
action();
}
iter->second.erase(iter2);
}
if (iter->second.size() == 0) {
epoch_actions_.erase(iter);
}
epoch_actions_.erase(iter);
}

// Run through the callables associated with this epoch
auto iter2 = epoch_callable_actions_.find(epoch);
if (iter2 != epoch_callable_actions_.end()) {
epoch_actions_count += iter2->second.size();

for (auto&& action : iter2->second) {
action->invoke();
if (auto iter = epoch_callable_actions_.find(epoch);
iter != epoch_callable_actions_.end()) {
if (auto iter2 = iter->second.find(encapsulated_epoch);
iter2 != iter->second.end()) {
for (auto&& action : iter2->second) {
theTerm()->consume(encapsulated_epoch);
action->invoke();
}
iter->second.erase(iter2);
}
if (iter->second.size() == 0) {
epoch_callable_actions_.erase(iter);
}

epoch_callable_actions_.erase(iter2);
}
/*
* Consume number of action units of any epoch type to match the production
* in addActionEpoch() so global termination can now be detected
*/
theTerm()->consume(term::any_epoch_sentinel, epoch_actions_count);
}

EpochType TermAction::getCurrentEpoch() const {
return theTerm()->getEpoch();
}

}} /* end namespace vt::term */
30 changes: 23 additions & 7 deletions src/vt/termination/term_action.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "vt/termination/term_common.h"
#include "vt/termination/term_state.h"
#include "vt/termination/term_terminated.h"
#include "vt/messaging/message/message.h"

#include <vector>
#include <unordered_map>
Expand Down Expand Up @@ -92,9 +93,12 @@ struct CallableHolder : CallableBase {

struct TermAction : TermTerminated {
using TermStateType = TermState;
using ActionContType = std::vector<ActionType>;
using ActionVecType = std::vector<ActionType>;
using ActionContType = std::unordered_map<EpochType, ActionVecType>;
using CallableActionType = std::unique_ptr<CallableBase>;
using CallableVecType = std::vector<CallableActionType>;
using CallableVecType = std::unordered_map<
EpochType, std::vector<CallableActionType>
>;
using CallableContType = std::unordered_map<EpochType,CallableVecType>;
using EpochActionContType = std::unordered_map<EpochType,ActionContType>;
using EpochStateType = std::unordered_map<EpochType,TermStateType>;
Expand All @@ -106,20 +110,32 @@ struct TermAction : TermTerminated {
void addAction(ActionType action);
void addAction(EpochType const& epoch, ActionType action);
void addActionEpoch(EpochType const& epoch, ActionType action);
void clearActions();
void clearActionsEpoch(EpochType const& epoch);

template <typename Callable>
void addActionUnique(EpochType const& epoch, Callable&& c);

struct ActionMsg : vt::Message {
ActionMsg(EpochType in_ep, EpochType in_encapsulated_epoch)
: ep(in_ep),
encapsulated_epoch(in_encapsulated_epoch)
{ }
EpochType ep = no_epoch;
EpochType encapsulated_epoch = no_epoch;
};

static void runActions(ActionMsg* msg);

EpochType getCurrentEpoch() const;
void produceOn(EpochType epoch) const;

protected:
void triggerAllActions(EpochType const& epoch);
void triggerAllEpochActions(EpochType const& epoch);
void queueActions(EpochType epoch);
void triggerAllEpochActions(EpochType epoch, EpochType encapsulated_epoch);
void afterAddEpochAction(EpochType const& epoch);

protected:
// Container for hold global termination actions
ActionContType global_term_actions_ = {};
ActionVecType global_term_actions_ = {};
// Container to hold actions to perform when an epoch has terminated
EpochActionContType epoch_actions_ = {};
// Container for "callables"; restricted in semantic wrt std::function
Expand Down
6 changes: 5 additions & 1 deletion src/vt/termination/term_action.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ template <typename Callable>
void TermAction::addActionUnique(EpochType const& epoch, Callable&& c) {
std::unique_ptr<CallableBase> callable =
std::make_unique<CallableHolder<Callable>>(std::move(c));
epoch_callable_actions_[epoch].emplace_back(std::move(callable));
auto encapsulated_epoch = getCurrentEpoch();
produceOn(encapsulated_epoch);
epoch_callable_actions_[epoch][encapsulated_epoch].emplace_back(
std::move(callable)
);
afterAddEpochAction(epoch);
}

Expand Down
2 changes: 1 addition & 1 deletion src/vt/termination/termination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ void TerminationDetector::epochTerminated(EpochType const& epoch, CallFromEnum f
}

// Trigger actions associated with epoch
triggerAllActions(epoch);
queueActions(epoch);

// Update the window for the epoch archetype
updateResolvedEpochs(epoch);
Expand Down
30 changes: 25 additions & 5 deletions tests/unit/termination/test_term_chaining.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ struct TestTermChaining : TestParallelHarness {
}

static void start_chain() {
EpochType epoch1 = theTerm()->makeEpochRooted();
EpochType epoch1 = theTerm()->makeEpochRooted("rooted epoch 1");
vt::theMsg()->pushEpoch(epoch1);
auto msg = makeMessage<TestMsg>();
chain.add(
Expand All @@ -145,7 +145,7 @@ struct TestTermChaining : TestParallelHarness {
vt::theMsg()->popEpoch(epoch1);
vt::theTerm()->finishedEpoch(epoch1);

EpochType epoch2 = theTerm()->makeEpochRooted();
EpochType epoch2 = theTerm()->makeEpochRooted("rooted epoch 2");
vt::theMsg()->pushEpoch(epoch2);
auto msg2 = makeMessage<TestMsg>();
chain.add(
Expand Down Expand Up @@ -213,7 +213,7 @@ TEST_F(TestTermChaining, test_termination_chaining_1) {

auto const& this_node = theContext()->getNode();

epoch = theTerm()->makeEpochCollective();
epoch = theTerm()->makeEpochCollective("top chain");

handler_count = 0;

Expand All @@ -224,16 +224,18 @@ TEST_F(TestTermChaining, test_termination_chaining_1) {
start_chain();
theTerm()->finishedEpoch(epoch);
theMsg()->popEpoch(epoch);
fmt::print("before run 1\n");
vt_print(gen, "before run 1\n");
vt::runSchedulerThrough(epoch);
fmt::print("after run 1\n");
vt_print(gen, "after run 1\n");

EXPECT_EQ(handler_count, 4);
} else {
theMsg()->pushEpoch(epoch);
theTerm()->finishedEpoch(epoch);
theMsg()->popEpoch(epoch);
vt_print(gen, "before run 1 (other)\n");
vt::runSchedulerThrough(epoch);
vt_print(gen, "after run 1 (other)\n");
EXPECT_EQ(handler_count, 13);
}
}
Expand All @@ -253,4 +255,22 @@ TEST_F(TestTermChaining, test_termination_chaining_collective_1) {
}
}

TEST_F(TestTermChaining, test_termination_action_grouping) {
SET_NUM_NODES_CONSTRAINT(2);

auto ep1 = theTerm()->makeEpochCollective();
theMsg()->pushEpoch(ep1);

{ // scope for illustration
auto ep2 = theTerm()->makeEpochCollective();
theTerm()->finishedEpoch(ep2);

theTerm()->addAction(ep2, [ep1]{
EXPECT_EQ(theTerm()->getEpoch(), ep1);
});
}
theMsg()->popEpoch(ep1);
theTerm()->finishedEpoch(ep1);
}

}}} // end namespace vt::tests::unit

0 comments on commit 2ff5cc2

Please sign in to comment.