Skip to content

Commit

Permalink
#2240: Add helpers and use Kokkos::View for internals of Rabenseifner…
Browse files Browse the repository at this point in the history
… when user's payload is View
  • Loading branch information
JacobDomagala committed Jul 18, 2024
1 parent 28139a7 commit c5232dc
Show file tree
Hide file tree
Showing 6 changed files with 417 additions and 165 deletions.
30 changes: 2 additions & 28 deletions src/vt/collective/reduce/allreduce/data_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,11 @@ class DataHandler<ScalarType, typename std::enable_if<std::is_arithmetic<ScalarT

static std::vector<ScalarType> toVec(const ScalarType& data) { return std::vector<ScalarType>{data}; }
static ScalarType fromVec(const std::vector<ScalarType>& data) { return data[0]; }
static ScalarType fromMemory(ScalarType* data, size_t) {
static ScalarType fromMemory(const ScalarType* data, size_t) {
return *data;
}

// static const ScalarType* data(const ScalarType& data) { return &data; }
static size_t size(const ScalarType&) { return 1; }
// static ScalarType& at(ScalarType& data, size_t) { return data; }
// static void set(ScalarType& data, size_t, const ScalarType& value) { data = value; }
// static ScalarType split(ScalarType&, size_t, size_t) { return ScalarType{}; }
};

template <typename T>
Expand All @@ -84,20 +80,11 @@ class DataHandler<std::vector<T>> {

static const std::vector<T>& toVec(const std::vector<T>& data) { return data; }
static std::vector<T> fromVec(const std::vector<T>& data) { return data; }
static std::vector<T> fromMemory(T* data, size_t count) {
static std::vector<T> fromMemory(const T* data, size_t count) {
return std::vector<T>(data, data + count);
}

// static const T* data(const std::vector<T>& data) {return data.data(); }
static size_t size(const std::vector<T>& data) { return data.size(); }
// static T at(const std::vector<T>& data, size_t idx) { return data[idx]; }
// static T& at(std::vector<T>& data, size_t idx) { return data[idx]; }
// static void set(std::vector<T>& data, size_t idx, const T& value) {
// data[idx] = value;
// }
// static std::vector<T> split(std::vector<T>& data, size_t start, size_t end) {
// return std::vector<T>{data.begin() + start, data.begin() + end};
// }
};

#if MAGISTRATE_KOKKOS_ENABLED
Expand Down Expand Up @@ -129,20 +116,7 @@ class DataHandler<Kokkos::View<T*, Kokkos::HostSpace, Props...>> {
return view;
}

// static const T* data(const ViewType& data) {return data.data(); }
static size_t size(const ViewType& data) { return data.extent(0); }

// static T at(const ViewType& data, size_t idx) { return data(idx); }

// static T& at(ViewType& data, size_t idx) { return data(idx); }

// static void set(ViewType& data, size_t idx, const T& value) {
// data(idx) = value;
// }

// static ViewType split(ViewType& data, size_t start, size_t end) {
// return Kokkos::subview(data, std::make_pair(start, end));
// }
};

#endif // MAGISTRATE_KOKKOS_ENABLED
Expand Down
195 changes: 195 additions & 0 deletions src/vt/collective/reduce/allreduce/helpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
/*
//@HEADER
// *****************************************************************************
//
// helpers.h
// DARMA/vt => Virtual Transport
//
// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC
// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
// Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact [email protected]
//
// *****************************************************************************
//@HEADER
*/

#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H
#include "data_handler.h"
#include "rabenseifner_msg.h"
#include "vt/messaging/message/shared_message.h"
#include <vector>

namespace vt::collective::reduce::allreduce {

template <typename Scalar, typename DataT>
struct DataHelper {
using DataType = DataHandler<DataT>;

template <typename... Args>
static void assign(std::vector<Scalar>& dest, Args&&... data) {
dest = DataHandler<DataT>::toVec(std::forward<Args>(data)...);
}

static MsgPtr<RabenseifnerMsg<Scalar, DataT>> createMessage(
const std::vector<Scalar>& payload, size_t begin, size_t count, size_t id,
int32_t step = 0) {
return vt::makeMessage<RabenseifnerMsg<Scalar, DataT>>(
payload.data() + begin, count, id, step);
}

static void copy(
std::vector<Scalar>& dest, size_t start_idx, RabenseifnerMsg<Scalar, DataT>* msg) {
for (uint32_t i = 0; i < msg->size_; i++) {
dest[start_idx + i] = msg->val_[i];
}
}

template <template <typename Arg> class Op>
static void reduce(
std::vector<Scalar>& dest, size_t start_idx, RabenseifnerMsg<Scalar, DataT>* msg) {
for (uint32_t i = 0; i < msg->size_; i++) {
Op<Scalar>()(dest[start_idx + i], msg->val_[i]);
}
}

static void invoke() { }

static bool empty(const std::vector<Scalar>& payload) {
return payload.empty();
}
};

#if MAGISTRATE_KOKKOS_ENABLED

template <typename Scalar>
struct DataHelper<Scalar, Kokkos::View<Scalar*, Kokkos::HostSpace>> {
using DataT = Kokkos::View<Scalar*, Kokkos::HostSpace>;
using DataType = DataHandler<DataT>;

template <typename... Args>
static void assign(DataT& dest, Args&&... data) {
dest = {std::forward<Args>(data)...};
}

static MsgPtr<RabenseifnerMsg<Scalar, DataT>> createMessage(
const DataT& payload, size_t begin, size_t count, size_t id,
int32_t step = 0) {
return vt::makeMessage<RabenseifnerMsg<Scalar, DataT>>(
Kokkos::subview(payload, std::make_pair(begin, begin + count)), id, step
);
}

static void
copy(DataT& dest, size_t start_idx, RabenseifnerMsg<Scalar, DataT>* msg) {
Kokkos::parallel_for(
"Rabenseifner::copy", msg->val_.extent(0),
KOKKOS_LAMBDA(const int i) { dest(start_idx + i) = msg->val_(i); }
);
}

template <template <typename Arg> class Op>
static void reduce(
DataT& dest, size_t start_idx, RabenseifnerMsg<Scalar, DataT>* msg) {
Kokkos::parallel_for(
"Rabenseifner::reduce", msg->val_.extent(0), KOKKOS_LAMBDA(const int i) {
Op<Scalar>()(dest(start_idx + i), msg->val_(i));
}
);
}

static void invoke() { }

static bool empty(const DataT& payload) {
return payload.extent(0) == 0;
}
};

#endif // MAGISTRATE_KOKKOS_ENABLED

struct StateBase {
size_t size_ = {};

bool finished_adjustment_part_ = false;

int32_t mask_ = 1;
int32_t step_ = 0;
bool initialized_ = false;
bool completed_ = false;

// Scatter
int32_t scatter_mask_ = 1;
int32_t scatter_step_ = 0;
int32_t scatter_num_recv_ = 0;
std::vector<bool> scatter_steps_recv_ = {};
std::vector<bool> scatter_steps_reduced_ = {};

bool finished_scatter_part_ = false;

// Gather
int32_t gather_step_ = 0;
int32_t gather_mask_ = 1;
int32_t gather_num_recv_ = 0;
std::vector<bool> gather_steps_recv_ = {};
std::vector<bool> gather_steps_reduced_ = {};

std::vector<uint32_t> r_index_ = {};
std::vector<uint32_t> r_count_ = {};
std::vector<uint32_t> s_index_ = {};
std::vector<uint32_t> s_count_ = {};
};

template <typename Scalar, typename DataT>
struct State : StateBase {
std::vector<Scalar> val_ = {};

MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> left_adjust_message_ = nullptr;
MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> right_adjust_message_ = nullptr;
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> scatter_messages_ = {};
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> gather_messages_ = {};
};

#if MAGISTRATE_KOKKOS_ENABLED
template <typename Scalar>
struct State<Scalar, Kokkos::View<Scalar*, Kokkos::HostSpace>> : StateBase {
using DataT = Kokkos::View<Scalar*, Kokkos::HostSpace>;

Kokkos::View<Scalar*, Kokkos::HostSpace> val_ = {};

MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> left_adjust_message_ = nullptr;
MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> right_adjust_message_ = nullptr;
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> scatter_messages_ = {};
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> gather_messages_ = {};
};
#endif //MAGISTRATE_KOKKOS_ENABLED

} // namespace vt::collective::reduce::allreduce
#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H*/
Loading

0 comments on commit c5232dc

Please sign in to comment.