Skip to content

Commit

Permalink
Merge pull request #2213 from DARMA-tasking/2212-add-message-properti…
Browse files Browse the repository at this point in the history
…es-for-generated-messages

2212 add message properties for generated messages
  • Loading branch information
lifflander authored Nov 30, 2023
2 parents 65e7ebc + e38bed2 commit 2e81943
Show file tree
Hide file tree
Showing 13 changed files with 322 additions and 42 deletions.
10 changes: 5 additions & 5 deletions src/vt/collective/reduce/reduce_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@

namespace vt { namespace collective { namespace reduce {

static std::unique_ptr<Reduce> makeReduceScope(detail::ReduceScope const& scope) {
return std::make_unique<Reduce>(scope);
}

ReduceManager::ReduceManager()
: reducers_( // default cons reducer for non-group
[](detail::ReduceScope const& scope) {
return std::make_unique<Reduce>(scope);
}
)
: reducers_(makeReduceScope)
{
// insert the default reducer scope
reducers_.make(
Expand Down
6 changes: 4 additions & 2 deletions src/vt/messaging/active.h
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,8 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
PendingSendType send(Node dest, Params&&... params) {
using Tuple = typename FuncTraits<decltype(f)>::TupleType;
using MsgT = ParamMsg<Tuple>;
auto msg = vt::makeMessage<MsgT>(std::forward<Params>(params)...);
auto msg = vt::makeMessage<MsgT>();
msg->setParams(std::forward<Params>(params)...);
auto han = auto_registry::makeAutoHandlerParam<decltype(f), f, MsgT>();
return sendMsg<MsgT>(dest.get(), han, msg, no_tag);
}
Expand All @@ -782,7 +783,8 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
PendingSendType broadcast(Params&&... params) {
using Tuple = typename FuncTraits<decltype(f)>::TupleType;
using MsgT = ParamMsg<Tuple>;
auto msg = vt::makeMessage<MsgT>(std::forward<Params>(params)...);
auto msg = vt::makeMessage<MsgT>();
msg->setParams(std::forward<Params>(params)...);
auto han = auto_registry::makeAutoHandlerParam<decltype(f), f, MsgT>();
constexpr bool deliver_to_sender = true;
return broadcastMsg<MsgT>(han, msg, deliver_to_sender, no_tag);
Expand Down
2 changes: 2 additions & 0 deletions src/vt/messaging/active.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,4 +512,6 @@ inline EpochType ActiveMessenger::setupEpochMsg(MsgSharedPtr<MsgT> const& msg) {

}} //end namespace vt::messaging

#include "vt/messaging/param_msg.impl.h"

#endif /*INCLUDED_VT_MESSAGING_ACTIVE_IMPL_H*/
146 changes: 132 additions & 14 deletions src/vt/messaging/param_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,103 @@

#include "vt/messaging/message/message_serialize.h"

namespace vt { namespace messaging {
namespace vt {

struct MsgProps {

MsgProps() = default;

MsgProps&& asLocationMsg(bool set = true) {
as_location_msg_ = set;
return std::move(*this);
}

MsgProps&& asTerminationMsg(bool set = true) {
as_termination_msg_ = set;
return std::move(*this);
}

MsgProps&& asCollectionMsg(bool set = true) {
as_collection_msg_ = set;
return std::move(*this);
}

MsgProps&& asSerializationMsg(bool set = true) {
as_serial_msg_ = set;
return std::move(*this);
}

MsgProps&& withEpoch(EpochType in_ep) {
ep_ = in_ep;
return std::move(*this);
}

MsgProps&& withPriority(PriorityType in_priority) {
#if vt_check_enabled(priorities)
priority_ = in_priority;
#endif
return std::move(*this);
}

MsgProps&& withPriorityLevel(PriorityLevelType in_priority_level) {
#if vt_check_enabled(priorities)
priority_level_ = in_priority_level;
#endif
return std::move(*this);
}

template <typename MsgPtrT>
void apply(MsgPtrT msg);

private:
bool as_location_msg_ = false;
bool as_termination_msg_ = false;
bool as_serial_msg_ = false;
bool as_collection_msg_ = false;
EpochType ep_ = no_epoch;
#if vt_check_enabled(priorities)
PriorityType priority_ = no_priority;
PriorityLevelType priority_level_ = no_priority_level;
#endif
};

} /* end namespace vt */

namespace vt::messaging::detail {

template <typename enabled_, typename... Params>
struct GetTraits;

template <>
struct GetTraits<std::enable_if_t<std::is_same_v<void, void>>> {
using TupleType = std::tuple<>;
};

template <typename Param, typename... Params>
struct GetTraits<
std::enable_if_t<std::is_same_v<MsgProps, Param>>, Param, Params...
> {
using TupleType = std::tuple<Params...>;
};

template <typename Param, typename... Params>
struct GetTraits<
std::enable_if_t<not std::is_same_v<MsgProps, Param>>, Param, Params...
> {
using TupleType = std::tuple<Param, Params...>;
};

template <typename Tuple>
struct GetTraitsTuple;

template <typename... Params>
struct GetTraitsTuple<std::tuple<Params...>> {
using TupleType = typename GetTraits<void, Params...>::TupleType;
};

} /* end namespace vt::messaging::detail */

namespace vt::messaging {

template <typename Tuple, typename enabled = void>
struct ParamMsg;
Expand All @@ -56,15 +152,24 @@ struct ParamMsg<
Tuple, std::enable_if_t<is_byte_copyable_t<Tuple>::value>
> : vt::Message
{
using TupleType = typename detail::GetTraitsTuple<Tuple>::TupleType;

ParamMsg() = default;

template <typename... Params>
explicit ParamMsg(Params&&... in_params)
: params(std::forward<Params>(in_params)...)
{ }
void setParams() { }

template <typename Param, typename... Params>
void setParams(Param&& p, Params&&... in_params) {
if constexpr (std::is_same_v<std::decay_t<Param>, MsgProps>) {
params = TupleType{std::forward<Params>(in_params)...};
p.apply(this);
} else {
params = TupleType{std::forward<Param>(p), std::forward<Params>(in_params)...};
}
}

Tuple params;
Tuple& getTuple() { return params; }
TupleType params;
TupleType& getTuple() { return params; }
};

template <typename Tuple>
Expand All @@ -75,16 +180,29 @@ struct ParamMsg<
using MessageParentType = vt::Message;
vt_msg_serialize_if_needed_by_parent_or_type1(Tuple); // by tup

using TupleType = typename detail::GetTraitsTuple<Tuple>::TupleType;

ParamMsg() = default;

template <typename... Params>
explicit ParamMsg(Params&&... in_params)
: params(std::make_unique<Tuple>(std::forward<Params>(in_params)...))
{ }
void setParams() {
params = std::make_unique<TupleType>();
}

template <typename Param, typename... Params>
void setParams(Param&& p, Params&&... in_params) {
if constexpr (std::is_same_v<std::decay_t<Param>, MsgProps>) {
params = std::make_unique<TupleType>(std::forward<Params>(in_params)...);
p.apply(this);
} else {
params = std::make_unique<TupleType>(
std::forward<Param>(p), std::forward<Params>(in_params)...
);
}
}

std::unique_ptr<Tuple> params;
std::unique_ptr<TupleType> params;

Tuple& getTuple() { return *params.get(); }
TupleType& getTuple() { return *params.get(); }

template <typename SerializerT>
void serialize(SerializerT& s) {
Expand All @@ -93,6 +211,6 @@ struct ParamMsg<
}
};

}} /* end namespace vt::messaging */
} /* end namespace vt::messaging */

#endif /*INCLUDED_VT_MESSAGING_PARAM_MSG_H*/
80 changes: 80 additions & 0 deletions src/vt/messaging/param_msg.impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
//@HEADER
// *****************************************************************************
//
// param_msg.impl.h
// DARMA/vt => Virtual Transport
//
// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC
// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
// Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact [email protected]
//
// *****************************************************************************
//@HEADER
*/

#if !defined INCLUDED_VT_MESSAGING_PARAM_MSG_IMPL_H
#define INCLUDED_VT_MESSAGING_PARAM_MSG_IMPL_H

#include "vt/messaging/param_msg.h"

namespace vt {

template <typename MsgPtrT>
void MsgProps::apply(MsgPtrT msg) {
if (as_location_msg_) {
theMsg()->markAsLocationMessage(msg);
}
if (as_termination_msg_) {
theMsg()->markAsTermMessage(msg);
}
if (as_serial_msg_) {
theMsg()->markAsSerialMsgMessage(msg);
}
if (as_collection_msg_) {
theMsg()->markAsCollectionMessage(msg);
}
if (ep_ != no_epoch) {
envelopeSetEpoch(msg->env, ep_);
}
#if vt_check_enabled(priorities)
if (priority_ != no_priority) {
envelopeSetPriority(msg->env, priority_);
}
if (priority_level_ != no_priority_level) {
envelopeSetPriorityLevel(msg->env, priority_level_);
}
#endif
}

} /* end namespace vt */

#endif /*INCLUDED_VT_MESSAGING_PARAM_MSG_IMPL_H*/
9 changes: 9 additions & 0 deletions src/vt/messaging/pending_send.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ struct PendingSend final {
*/
void release();

/**
* \internal \brief Get the message stored in the pending send
*
* \note Used for testing purposes
*
* \return a reference to the message
*/
MsgPtr<BaseMsgType>& getMsg() { return msg_; }

private:

/**
Expand Down
6 changes: 4 additions & 2 deletions src/vt/objgroup/proxy/proxy_objgroup.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ Proxy<ObjT>::broadcast(Params&&... params) const {
if constexpr (std::is_same_v<MsgT, NoMsg>) {
using Tuple = typename ObjFuncTraits<decltype(f)>::TupleType;
using SendMsgT = messaging::ParamMsg<Tuple>;
auto msg = vt::makeMessage<SendMsgT>(std::forward<Params>(params)...);
auto msg = vt::makeMessage<SendMsgT>();
msg->setParams(std::forward<Params>(params)...);
auto const ctrl = proxy::ObjGroupProxy::getID(proxy_);
auto const han = auto_registry::makeAutoHandlerObjGroupParam<
ObjT, decltype(f), f, SendMsgT
Expand All @@ -117,7 +118,8 @@ Proxy<ObjT>::multicast(GroupType type, Params&&... params) const{
if constexpr (std::is_same_v<MsgT, NoMsg>) {
using Tuple = typename ObjFuncTraits<decltype(f)>::TupleType;
using SendMsgT = messaging::ParamMsg<Tuple>;
auto msg = vt::makeMessage<SendMsgT>(std::forward<Params>(params)...);
auto msg = vt::makeMessage<SendMsgT>();
msg->setParams(std::forward<Params>(params)...);
vt::envelopeSetGroup(msg->env, type);
auto const ctrl = proxy::ObjGroupProxy::getID(proxy_);
auto const han = auto_registry::makeAutoHandlerObjGroupParam<
Expand Down
3 changes: 2 additions & 1 deletion src/vt/objgroup/proxy/proxy_objgroup_elm.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ ProxyElm<ObjT>::send(Params&&... params) const {
if constexpr (std::is_same_v<MsgT, NoMsg>) {
using Tuple = typename ObjFuncTraits<decltype(f)>::TupleType;
using SendMsgT = messaging::ParamMsg<Tuple>;
auto msg = vt::makeMessage<SendMsgT>(std::forward<Params>(params)...);
auto msg = vt::makeMessage<SendMsgT>();
msg->setParams(std::forward<Params>(params)...);
auto const ctrl = proxy::ObjGroupProxy::getID(proxy_);
auto const han = auto_registry::makeAutoHandlerObjGroupParam<
ObjT, decltype(f), f, SendMsgT
Expand Down
30 changes: 26 additions & 4 deletions src/vt/pipe/callback/cb_union/cb_raw_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,17 +234,39 @@ struct CallbackTyped : CallbackRawBaseSingle {
void sendTuple(std::tuple<Params...> tup) {
using Trait = CBTraits<Args...>;
using MsgT = messaging::ParamMsg<typename Trait::TupleType>;
auto msg = vt::makeMessage<MsgT>(std::move(tup));
auto msg = vt::makeMessage<MsgT>();
msg->setParams(std::move(tup));
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
}

template <typename... Params>
void send(Params&&... params) {
using Trait = CBTraits<Args...>;
if constexpr (std::is_same_v<typename Trait::MsgT, NoMsg>) {
using MsgT = messaging::ParamMsg<typename Trait::TupleType>;
auto msg = vt::makeMessage<MsgT>(std::forward<Params>(params)...);
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
// We have to go through some tricky code to make the MsgProps case work
// If we use the type for Params to send, it's possible that we have a
// type mismatch in the actual handler type. A possible edge case is when
// a char const* is sent, but the handler is a std::string. In this case,
// the ParamMsg will be cast incorrectly during the virual dispatch to a
// collection because callbacks don't have the collection type. Thus, the
// wrong ParamMsg will be cast to which requires serialization, leading to
// a failure.
if constexpr (sizeof...(Params) == sizeof...(Args) + 1) {
using MsgT = messaging::ParamMsg<
std::tuple<
std::decay_t<std::tuple_element_t<0, std::tuple<Params...>>>,
std::decay_t<Args>...
>
>;
auto msg = vt::makeMessage<MsgT>();
msg->setParams(std::forward<Params>(params)...);
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
} else {
using MsgT = messaging::ParamMsg<typename Trait::TupleType>;
auto msg = vt::makeMessage<MsgT>();
msg->setParams(std::forward<Params>(params)...);
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
}
} else {
using MsgT = typename Trait::MsgT;
auto msg = makeMessage<MsgT>(std::forward<Params>(params)...);
Expand Down
Loading

0 comments on commit 2e81943

Please sign in to comment.