Skip to content

Commit

Permalink
Add the collective operation broadcast (bcast).
Browse files Browse the repository at this point in the history
  • Loading branch information
Lukas Hübner committed Mar 25, 2022
1 parent 4f8b20d commit 69f7a9c
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 2 deletions.
137 changes: 137 additions & 0 deletions include/kamping/collectives/bcast.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// This file is part of KaMPI.ng.
//
// Copyright 2022 The KaMPI.ng Authors
//
// KaMPI.ng is free software : you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
// version. KaMPI.ng is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
// for more details.
//
// You should have received a copy of the GNU Lesser General Public License along with KaMPI.ng. If not, see
// <https://www.gnu.org/licenses/>.

#pragma once

#include "kamping/checking_casts.hpp"
#include "kamping/mpi_datatype.hpp"
#include "kamping/mpi_function_wrapper_helpers.hpp"
#include "kamping/named_parameter_selection.hpp"
#include "kamping/parameter_factories.hpp"
#include "kamping/parameter_objects.hpp"
#include "kamping/parameter_type_definitions.hpp"
#include <mpi.h>

namespace kamping::internal {

/// @brief CRTP mixin class for \c MPI_Bcast.
///
/// This class is only to be used as a super class of kamping::Communicator
template <typename Communicator>
class Bcast : public CRTPHelper<Communicator, Bcast> {
public:
/// @brief Wrapper for \c MPI_Bcast
///
/// This wrapper for \c MPI_Bcast sends data from the root to all other ranks.
/// The following buffers are required:
/// - \ref kamping::send_buf() containing the data that is sent to the other ranks.
/// The following parameters are optional:
/// - \ref kamping::root() specifying an alternative root. If not present, the default root of the \c Communicator
/// is used, see root().
/// - \ref kamping::recv_buf() containing a buffer for the output. Afterwards, at all other ranks, this buffer will
/// contain the data from the root send buffer.
/// @todo Describe what happens at the root
/// @tparam Args Automatically deducted template parameters.
/// @param args All required and any number of the optional buffers described above.
/// @return Result type wrapping the output buffer if not specified as input parameter.
template <typename... Args>
auto bcast(Args&&... args) {
static_assert(
all_parameters_are_rvalues<Args...>,
"All parameters have to be passed in as rvalue references, meaning that you must not hold a variable "
"returned by the named parameter helper functions like recv_buf().");

// Check and get all parameters.

// The parameter send_recv_buf() is required on all processes.
static_assert(
internal::has_parameter_type<internal::ParameterType::send_recv_buf, Args...>(),
"Missing required parameter send_recv_buf.");

const auto& send_recv_buf =
internal::select_parameter_type<internal::ParameterType::send_recv_buf>(args...).get();
using value_type = typename std::remove_reference_t<decltype(send_recv_buf)>::value_type;

auto&& root = internal::select_parameter_type_or_default<internal::ParameterType::root, internal::Root>(
std::tuple(this->underlying().root()), args...);

auto mpi_value_type = mpi_datatype<value_type>();

// Conduct some validity check on the parmeters.
KASSERT(this->underlying().is_valid_rank(root.rank()), "Invalid rank as root.", assert::light);

KASSERT(
this->underlying().rank() != root.rank() || send_recv_buf.size() > 0,
"The send_recv_buf() on root is empty.", assert::light);

KASSERT(
recv_buf_large_enough_on_all_processes(send_recv_buf, root.rank()),
"The receive buffer is too small on at least one rank.", assert::light_communication);

// Perform the broadcast.
// int size = 0;
// void* buffer = nullptr;
// if constexpr (internal:: ))

// The error code is unused if KTHROW is removed at compile time.
// KASSERT(size != 0, assert::light);
// KASSERT(buffer != nullptr, assert::light);
[[maybe_unused]] int err = MPI_Bcast(
send_recv_buf.data(), // buffer*
asserting_cast<int>(send_recv_buf.size()), // count
mpi_value_type, // datatype
root.rank(), // root
this->underlying().mpi_communicator() // MPI_Comm comm
);
THROW_IF_MPI_ERROR(err, MPI_Bcast);

return MPIResult(
std::move(send_recv_buf), internal::BufferCategoryNotUsed{}, internal::BufferCategoryNotUsed{},
internal::BufferCategoryNotUsed{});
}

protected:
Bcast() {}

private:
/// @brief Checks if the receive buffer is large enough to receive all elements on all ranks.
///
/// Broadcasts the size of the send buffer (which is equal to the recv_buf) from the root rank,
/// performs local comparison and collects the result using an allreduce.
/// @param send_recv_buf The send buffer on root, the receive buffer on all other ranks.
/// @param root The rank of the root process.
template <typename RecvBuf>
bool recv_buf_large_enough_on_all_processes(RecvBuf const& send_recv_buf, int const root) const {
uint64_t size = send_recv_buf.size();
MPI_Bcast(
&size, // src/dest buffer
1, // size
mpi_datatype<decltype(size)>(), // datatype
root, // root
this->underlying().mpi_communicator() // communicator
);
bool const local_buffer_large_enough = size <= send_recv_buf.size();
bool every_buffer_large_enough;
MPI_Allreduce(
&local_buffer_large_enough, // src buffer
&every_buffer_large_enough, // dest buffer
1, // count
mpi_datatype<bool>(), // datatype
MPI_LAND, // operation
this->underlying().mpi_communicator() // communicator
);
return every_buffer_large_enough;
}
}; // class Bcast

} // namespace kamping::internal
5 changes: 4 additions & 1 deletion include/kamping/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@

#include "error_handling.hpp"
#include "kamping/collectives/alltoall.hpp"
#include "kamping/collectives/bcast.hpp"
#include "kamping/collectives/reduce.hpp"
#include "kamping/kassert.hpp"

namespace kamping {

/// @brief Wrapper for MPI communicator providing access to \ref rank() and \ref size() of the communicator. The \ref
/// Communicator is also access point to all MPI communications provided by KaMPI.ng.
class Communicator : public internal::Alltoall<Communicator>, public internal::Reduce<Communicator> {
class Communicator : public internal::Alltoall<Communicator>,
public internal::Reduce<Communicator>,
public internal::Bcast<Communicator> {
public:
/// @brief Default constructor not specifying any MPI communicator and using \c MPI_COMM_WORLD by default.
Communicator() : Communicator(MPI_COMM_WORLD) {}
Expand Down
2 changes: 1 addition & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ kamping_register_mpi_test(test_mpi_datatype FILES mpi_datatype_test.cpp CORES 1)
kamping_register_mpi_test(test_mpi_communicator FILES mpi_communicator_test.cpp CORES 1 2 4)
kamping_register_mpi_test(test_mpi_alltoall FILES collectives/mpi_alltoall_test.cpp CORES 1 2 4)
kamping_register_mpi_test(test_mpi_reduce FILES collectives/mpi_reduce_test.cpp CORES 1 2 4)

kamping_register_mpi_test(test_mpi_bcast FILES collectives/mpi_bcast_test.cpp CORES 1 2 4)

kamping_register_compilation_failure_test(
test_mpi_datatype_unsupported_types
Expand Down
61 changes: 61 additions & 0 deletions tests/collectives/mpi_bcast_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@

// This file is part of KaMPI.ng.
//
// Copyright 2022 The KaMPI.ng Authors
//
// KaMPI.ng is free software : you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
// version. KaMPI.ng is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
// for more details.
//
// You should have received a copy of the GNU Lesser General Public License along with KaMPI.ng. If not, see
// <https://www.gnu.org/licenses/>.

#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
#include <numeric>

#include "../helpers_for_testing.hpp"
#include "kamping/communicator.hpp"


using namespace ::kamping;
using namespace ::testing;

TEST(BcastTest, SingleElement) {
Communicator comm;

// Basic use case, broadcast a single POD.
int value = comm.rank();
comm.bcast(send_recv_buf(value));
EXPECT_EQ(value, comm.root());

// The following test is only valid if the communicator has more than one rank.
if (comm.size() >= 2) {
// Broadcast a single POD to all processes, manually specify the root process.
value = comm.rank();
comm.bcast(send_recv_buf(value), root(1));
EXPECT_EQ(value, 1);

// Broadcast a single POD to all processes, use a non-default communicator's root.
value = comm.rank();
const int new_root = 1;
comm.root(new_root);
ASSERT_EQ(new_root, comm.root());
comm.bcast(send_recv_buf(value));
EXPECT_EQ(value, new_root);
}
}

TEST(BcastTest, Vector) {
Communicator comm;

std::vector<int> values(4);
if (comm.is_root()) {
std::fill(values.begin(), values.end(), comm.rank());
}

comm.bcast(send_recv_buf(values));
EXPECT_THAT(values, Each(Eq(comm.root())));
}

0 comments on commit 69f7a9c

Please sign in to comment.