Skip to content

Commit

Permalink
#2092: message: move message around instead of copying
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander committed Feb 16, 2023
1 parent 8aea696 commit 1319baf
Show file tree
Hide file tree
Showing 14 changed files with 101 additions and 72 deletions.
4 changes: 2 additions & 2 deletions src/vt/collective/reduce/reduce.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ void Reduce::reduceRootRecv(MsgT* msg) {

auto const from_node = theContext()->getFromNodeCurrentTask();
auto m = promoteMsg(msg);
runnable::makeRunnable(m, false, handler, from_node)
runnable::makeRunnable(std::move(m), false, handler, from_node)
.withTDEpochFromMsg()
.run();
}
Expand Down Expand Up @@ -256,7 +256,7 @@ void Reduce::startReduce(detail::ReduceStamp id, bool use_num_contrib) {

// this needs to run inline.. threaded not allowed for reduction
// combination
runnable::makeRunnable(state.msgs[0], false, handler, from_node)
runnable::makeRunnable(std::move(state.msgs[0]), false, handler, from_node)
.withTDEpochFromMsg()
.run();
}
Expand Down
5 changes: 3 additions & 2 deletions src/vt/messaging/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ EventType ActiveMessenger::doMessageSend(
if (is_obj) {
objgroup::dispatchObjGroup(base, han, dest, nullptr);
} else {
runnable::makeRunnable(base, true, envelopeGetHandler(msg->env), dest)
runnable::makeRunnable(std::move(base), true, envelopeGetHandler(msg->env), dest)
.withTDEpochFromMsg(is_term)
.withLBData(&bare_handler_lb_data_, bare_handler_dummy_elm_id_for_lb_data_)
.enqueue();
Expand Down Expand Up @@ -963,7 +963,8 @@ void ActiveMessenger::prepareActiveMsgToRun(
if (is_obj) {
objgroup::dispatchObjGroup(base, handler, from_node, cont);
} else {
runnable::makeRunnable(base, not is_term, handler, from_node)
auto m = base;
runnable::makeRunnable(std::move(m), not is_term, handler, from_node)
.withContinuation(cont)
.withTDEpochFromMsg(is_term)
.withLBData(&bare_handler_lb_data_, bare_handler_dummy_elm_id_for_lb_data_)
Expand Down
5 changes: 5 additions & 0 deletions src/vt/messaging/message/smart_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ struct MsgSharedPtr final {
/*N.B. retain ORIGINAL-type implementation*/ impl_);
}

template <typename U>
MsgSharedPtr<U>* reinterpretAs() {
return reinterpret_cast<MsgSharedPtr<U>*>(this);
}

/// [obsolete] Use to() as MsgVirtualPtr <-> MsgSharedPtr.
/// Both methods are equivalent in funciton.
template <typename U>
Expand Down
12 changes: 9 additions & 3 deletions src/vt/messaging/pending_send.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,23 @@ PendingSend::PendingSend(PendingSend&& in) noexcept
std::swap(msg_, in.msg_);
std::swap(epoch_action_, in.epoch_action_);
std::swap(send_action_, in.send_action_);
std::swap(send_move_action_, in.send_move_action_);
}

void PendingSend::sendMsg() {
if (send_action_ == nullptr) {
if (send_action_ == nullptr and send_move_action_ == nullptr) {
theMsg()->doMessageSend(msg_);
} else {
send_action_(msg_);
if (send_action_) {
send_action_(msg_);
} else {
send_move_action_(std::move(msg_));
}
}
consumeMsg();
msg_ = nullptr;
send_action_ = nullptr;
send_move_action_ = nullptr;
}

EpochType PendingSend::getProduceEpochFromMsg() const {
Expand All @@ -95,7 +101,7 @@ void PendingSend::consumeMsg() {
}

void PendingSend::release() {
bool send_msg = msg_ != nullptr || send_action_ != nullptr;
bool send_msg = msg_ != nullptr || send_action_ != nullptr || send_move_action_ != nullptr;
vtAssert(!send_msg || !epoch_action_, "cannot have both a message and epoch action");
if (send_msg) {
sendMsg();
Expand Down
17 changes: 17 additions & 0 deletions src/vt/messaging/pending_send.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ namespace vt { namespace messaging {
struct PendingSend final {
/// Function for complex action on send---takes a message to operate on
using SendActionType = std::function<void(MsgSharedPtr<BaseMsgType>&)>;
using SendActionMoveType = std::function<void(MsgSharedPtr<BaseMsgType>&&)>;
using EpochActionType = std::function<void()>;

/**
Expand All @@ -91,6 +92,21 @@ struct PendingSend final {
produceMsg();
}

/**
* \brief Construct a pending send.
*
* \param[in] in_msg the message to send
* \param[in] in_move_action the action to run, where the msg is moved.
*/
PendingSend(
MsgSharedPtr<BaseMsgType>&& in_msg,
SendActionMoveType in_move_action = nullptr
) : msg_(std::move(in_msg))
, send_move_action_(in_move_action)
{
produceMsg();
}

/**
* \brief Construct a pending send that invokes a callback.
*
Expand Down Expand Up @@ -182,6 +198,7 @@ struct PendingSend final {
private:
MsgPtr<BaseMsgType> msg_ = nullptr;
SendActionType send_action_ = {};
SendActionMoveType send_move_action_ = {};
EpochActionType epoch_action_ = {};
EpochType epoch_produced_ = no_epoch;
};
Expand Down
6 changes: 3 additions & 3 deletions src/vt/objgroup/manager.static.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void invoke(messaging::MsgPtrThief<MsgT> msg, HandlerType han, NodeType dest_nod
auto const& elm_id = holder->getElmID();
auto elm = holder->getPtr();
auto lb_data = &holder->getLBData();
runnable::makeRunnable(msg.msg_, false, han, this_node)
runnable::makeRunnable(std::move(msg.msg_), false, han, this_node)
.withObjGroup(elm)
.withTDEpochFromMsg()
.withLBData(lb_data, elm_id)
Expand All @@ -107,13 +107,13 @@ namespace detail {

template <typename MsgT, typename ObjT>
void dispatchImpl(
MsgSharedPtr<MsgT> const& msg, HandlerType han, NodeType from_node,
MsgSharedPtr<MsgT> msg, HandlerType han, NodeType from_node,
ActionType cont, ObjT* obj
) {
auto holder = detail::getHolderBase(han);
auto const& elm_id = holder->getElmID();
auto lb_data = &holder->getLBData();
runnable::makeRunnable(msg, true, han, from_node)
runnable::makeRunnable(std::move(msg), true, han, from_node)
.withContinuation(cont)
.withObjGroup(obj)
.withLBData(lb_data, elm_id)
Expand Down
2 changes: 1 addition & 1 deletion src/vt/pipe/callback/handler_send/callback_send_tl.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void CallbackSendTypeless::trigger(MsgT* msg, PipeType const& pipe) {
);
auto pmsg = promoteMsg(msg);
if (this_node == send_node_) {
runnable::makeRunnable(pmsg, true, handler_, this_node)
runnable::makeRunnable(std::move(pmsg), true, handler_, this_node)
.withTDEpochFromMsg()
.enqueue();
} else {
Expand Down
28 changes: 13 additions & 15 deletions src/vt/runnable/make_runnable.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,18 @@ struct RunnableMaker {
* \internal \brief Construct the builder. Shall not be called directly.
*
* \param[in] in_impl the runnable
* \param[in] in_msg the associated message
* \param[in] in_has_msg whether we have a message
* \param[in] in_handler the handler
* \param[in] in_han_type the type of handler
* \param[in] in_from_node the from node for the runnable
*/
RunnableMaker(
RunnableNew* in_impl, MsgSharedPtr<MsgT> const& in_msg,
RunnableNew* in_impl, bool in_has_msg,
HandlerType in_handler, NodeType in_from_node
) : impl_(in_impl),
msg_(in_msg),
has_msg_(in_has_msg),
handler_(in_handler),
from_node_(in_from_node),
has_msg_(in_msg != nullptr)
from_node_(in_from_node)
{ }
RunnableMaker(RunnableMaker const&) = delete;
RunnableMaker(RunnableMaker&&) = default;
Expand Down Expand Up @@ -119,7 +118,7 @@ struct RunnableMaker {
RunnableMaker&& withTDEpochFromMsg(bool is_term = false) {
is_term_ = is_term;
if (not is_term) {
impl_->addContextTD(msg_);
impl_->addContextTD(impl_->getMsg());
}
return std::move(*this);
}
Expand Down Expand Up @@ -219,7 +218,7 @@ struct RunnableMaker {
template <typename ElmT>
RunnableMaker&& withLBData(ElmT* elm) {
#if vt_check_enabled(lblite)
impl_->addContextLB(elm, msg_.get());
impl_->addContextLB(elm, impl_->getMsg().get());
#endif
return std::move(*this);
}
Expand All @@ -239,7 +238,7 @@ struct RunnableMaker {
uint64_t idx1, uint64_t idx2, uint64_t idx3, uint64_t idx4
) {
impl_->addContextTrace(
msg_, trace_event, handler_, from_node_, idx1, idx2, idx3, idx4
impl_->getMsg(), trace_event, handler_, from_node_, idx1, idx2, idx3, idx4
);
return std::move(*this);
}
Expand Down Expand Up @@ -309,13 +308,12 @@ struct RunnableMaker {

private:
RunnableNew* impl_ = nullptr;
MsgSharedPtr<MsgT> const& msg_;
bool has_msg_ = false;
HandlerType handler_ = uninitialized_handler;
bool set_handler_ = false;
NodeType from_node_ = uninitialized_destination;
bool is_done_ = false;
bool is_term_ = false;
bool has_msg_ = true;
};

/**
Expand All @@ -331,19 +329,19 @@ struct RunnableMaker {
*/
template <typename U>
RunnableMaker<U> makeRunnable(
MsgSharedPtr<U> const& msg, bool is_threaded, HandlerType handler, NodeType from
MsgSharedPtr<U>&& msg, bool is_threaded, HandlerType handler, NodeType from
) {
auto r = new RunnableNew(msg, is_threaded);
auto r = new RunnableNew(std::move(msg), is_threaded);
#if vt_check_enabled(trace_enabled)
auto const han_type = HandlerManager::getHandlerRegistryType(handler);
if (han_type == auto_registry::RegistryTypeEnum::RegVrt or
han_type == auto_registry::RegistryTypeEnum::RegGeneral or
han_type == auto_registry::RegistryTypeEnum::RegObjGroup) {
r->addContextTrace(msg, handler, from);
r->addContextTrace(r->getMsg(), handler, from);
}
#endif
r->addContextSetContext(r, from);
return RunnableMaker<U>{r, msg, handler, from};
return RunnableMaker<U>{r, true, handler, from};
}

/**
Expand All @@ -362,7 +360,7 @@ inline RunnableMaker<BaseMsgType> makeRunnableVoid(
auto r = new RunnableNew(is_threaded);
// @todo: figure out how to trace this?
r->addContextSetContext(r, from);
return RunnableMaker<BaseMsgType>{r, nullptr, handler, from};
return RunnableMaker<BaseMsgType>{r, false, handler, from};
}

}} /* end namespace vt::runnable */
Expand Down
6 changes: 4 additions & 2 deletions src/vt/runnable/runnable.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ struct RunnableNew {
* \param[in] in_is_threaded whether the handler can be run with a thread
*/
template <typename U>
RunnableNew(MsgSharedPtr<U> const& in_msg, bool in_is_threaded)
: msg_(in_msg.template to<BaseMsgType>())
RunnableNew(MsgSharedPtr<U>&& in_msg, bool in_is_threaded)
: msg_(std::move(*in_msg.template reinterpretAs<BaseMsgType>()))
#if vt_check_enabled(fcontext)
, is_threaded_(in_is_threaded)
#endif
Expand All @@ -130,6 +130,8 @@ struct RunnableNew {
RunnableNew& operator=(RunnableNew&&) = default;

public:
MsgSharedPtr<BaseMsgType> const& getMsg() { return msg_; }

/**
* \brief Add a new \c SetContext for this handler
*
Expand Down
10 changes: 5 additions & 5 deletions src/vt/serialization/messaging/serialized_messenger.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ template <typename UserMsgT>
user_msg.template to<BaseMsgType>(), handler, sys_msg->from_node, nullptr
);
} else {
runnable::makeRunnable(user_msg, true, handler, sys_msg->from_node)
runnable::makeRunnable(std::move(user_msg), true, handler, sys_msg->from_node)
.withTDEpochFromMsg()
.enqueue();
}
Expand Down Expand Up @@ -146,7 +146,7 @@ template <typename UserMsgT>
msg.template to<BaseMsgType>(), handler, node, action
);
} else {
runnable::makeRunnable(msg, true, handler, node)
runnable::makeRunnable(std::move(msg), true, handler, node)
.withTDEpoch(epoch, not is_valid_epoch)
.withContinuation(action)
.enqueue();
Expand Down Expand Up @@ -195,7 +195,7 @@ template <typename UserMsgT, typename BaseEagerMsgT>
user_msg.template to<BaseMsgType>(), handler, sys_msg->from_node, nullptr
);
} else {
runnable::makeRunnable(user_msg, true, handler, sys_msg->from_node)
runnable::makeRunnable(std::move(user_msg), true, handler, sys_msg->from_node)
.withTDEpochFromMsg()
.enqueue();
}
Expand Down Expand Up @@ -428,14 +428,14 @@ template <typename MsgT, typename BaseT>
);

auto base_msg = user_msg.template to<BaseMsgType>();
return messaging::PendingSend(base_msg, [=](MsgPtr<BaseMsgType> in) {
return messaging::PendingSend(std::move(base_msg), [=](MsgPtr<BaseMsgType>&& in) mutable {
bool const is_obj = HandlerManager::isHandlerObjGroup(typed_handler);
if (is_obj) {
objgroup::dispatchObjGroup(
user_msg.template to<BaseMsgType>(), typed_handler, node, nullptr
);
} else {
runnable::makeRunnable(user_msg, true, typed_handler, node)
runnable::makeRunnable(std::move(user_msg), true, typed_handler, node)
.withTDEpochFromMsg()
.enqueue();
}
Expand Down
14 changes: 7 additions & 7 deletions src/vt/topos/location/location.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ struct EntityLocationCoord : LocationCoord {
template <typename MessageT, ActiveTypedFnType<MessageT> *f>
void routeMsgHandler(
EntityID const& id, NodeType const& home_node,
MsgSharedPtr<MessageT> const& msg
MsgSharedPtr<MessageT>&& msg
);

/**
Expand All @@ -253,15 +253,15 @@ struct EntityLocationCoord : LocationCoord {
* \param[in] m message shared pointer
*/
template <typename MessageT, ActiveTypedFnType<MessageT> *f>
void routePreparedMsgHandler(MsgSharedPtr<MessageT> const& msg);
void routePreparedMsgHandler(MsgSharedPtr<MessageT>&& msg);

/**
* \brief Route a message with a custom handler
*
* \param[in] m message shared pointer
*/
template <typename MessageT>
void routePreparedMsg(MsgSharedPtr<MessageT> const& msg);
void routePreparedMsg(MsgSharedPtr<MessageT>&& msg);

/**
* \brief Route a message with a custom handler where the element is local
Expand All @@ -270,7 +270,7 @@ struct EntityLocationCoord : LocationCoord {
* \param[in] obj the object pointer
*/
template <typename MessageT, ActiveTypedFnType<MessageT> *f>
void routeMsgHandlerLocal(MsgSharedPtr<MessageT> const& msg, void* obj);
void routeMsgHandlerLocal(MsgSharedPtr<MessageT>&& msg, void* obj);

/**
* \brief Route a message to the default handler
Expand All @@ -283,7 +283,7 @@ struct EntityLocationCoord : LocationCoord {
template <typename MessageT>
void routeMsg(
EntityID const& id, NodeType const& home_node,
MsgSharedPtr<MessageT> const& msg,
MsgSharedPtr<MessageT>&& msg,
NodeType from_node = uninitialized_destination
);

Expand Down Expand Up @@ -410,7 +410,7 @@ struct EntityLocationCoord : LocationCoord {
template <typename MessageT>
void routeMsgEager(
EntityID const& id, NodeType const& home_node,
MsgSharedPtr<MessageT> const& msg
MsgSharedPtr<MessageT>&& msg
);

/**
Expand All @@ -424,7 +424,7 @@ struct EntityLocationCoord : LocationCoord {
template <typename MessageT>
void routeMsgNode(
EntityID const& id, NodeType const& home_node, NodeType const& to_node,
MsgSharedPtr<MessageT> const& msg
MsgSharedPtr<MessageT>&& msg
);

/**
Expand Down
Loading

0 comments on commit 1319baf

Please sign in to comment.