# CTOParallelFor with BoxND / add AnyCTO (#4109)
## Summary

This PR adds support for BoxND to CTOParallelFor by adding the AnyCTO
function, which can be used to implement compile time options with any
kernel launching function such as ParallelFor, ParallelForRNG, launch,
etc.

I'm not sure if AnyCTO is a good name; are there other suggestions?

## Additional background

AnyCTO Examples:
``` C++
    int A_runtime_option = ...;
    int B_runtime_option = ...;
    enum A_options : int { A0, A1, A2, A3 };
    enum B_options : int { B0, B1 };
    AnyCTO(TypeList<CompileTimeOptions<A0,A1,A2,A3>,
                    CompileTimeOptions<B0,B1>>{},
        {A_runtime_option, B_runtime_option},
        [&](auto cto_func){
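            // launcher: cto_func is the kernel function with the compile time options bound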
            ParallelForRNG(N, cto_func);
        },
        [=] AMREX_GPU_DEVICE (int i, const RandomEngine& engine,
                              auto A_control, auto B_control)
        {
            ...
            if constexpr (A_control.value == A0) {
                ...
            } else if constexpr (A_control.value == A1) {
                ...
            } else if constexpr (A_control.value == A2) {
                ...
            } else {
                ...
            }
            if constexpr (A_control.value != A3 && B_control.value == B1) {
                ...
            }
            ...
        }
    );


    constexpr int nthreads_per_block = ...;
    int nblocks = ...;
    AnyCTO(TypeList<CompileTimeOptions<A0,A1,A2,A3>,
                    CompileTimeOptions<B0,B1>>{},
        {A_runtime_option, B_runtime_option},
        [&](auto cto_func){
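            // any kernel launching function works here, e.g. amrex::launch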
            launch<nthreads_per_block>(nblocks, Gpu::gpuStream(), cto_func);
        },
        [=] AMREX_GPU_DEVICE (auto A_control, auto B_control){
            ...
        }
    );
```
Additionally, .GetOptions() can be used to access the compile time options
in the function that launches the kernel:
```C++
    int nthreads_per_block = ...;
    AnyCTO(TypeList<CompileTimeOptions<128,256,512,1024>>{},
        {nthreads_per_block},
        [&](auto cto_func){
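            // GetOptions() is static constexpr, so its result can be used in constant expressions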
            constexpr std::array<int, 1> ctos = cto_func.GetOptions();
            constexpr int c_nthreads_per_block = ctos[0];
            ParallelFor<c_nthreads_per_block>(N, cto_func);
        },
        [=] AMREX_GPU_DEVICE (int i, auto){
            ...
        }
    );


    BoxND<6> box6D = ...;
    int dims_needed = ...;
    AnyCTO(TypeList<CompileTimeOptions<1,2,3,4,5,6>>{},
        {dims_needed},
        [&](auto cto_func){
            constexpr std::array<int, 1> ctos = cto_func.GetOptions();
            constexpr int c_dims_needed = ctos[0];
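            // shrink the 6D box to a BoxND with only the dimensions needed at run time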
            const auto box = BoxShrink<c_dims_needed>(box6D);
            ParallelFor(box, cto_func);
        },
        [=] AMREX_GPU_DEVICE (auto intvect, auto) -> decltype(void(intvect.size())) {
            ...
        }
    );
```
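For background, the dispatch mechanism works as follows: a fold expression compares the runtime options against every combination in the cartesian product of the option lists, and the launcher is called with a wrapper that appends the matching `std::integral_constant`s to the kernel's arguments. Below is a minimal, self-contained sketch of that technique. The names `MyTypeList`, `MyCTOWrapper`, `TryCombination`, and `Dispatch` are illustrative stand-ins, not the actual AMReX implementation, and only a single option list is handled for brevity:
```C++
#include <array>
#include <cstddef>
#include <cstdio>
#include <type_traits>

// Illustrative stand-in for amrex::TypeList.
template <class... Ts> struct MyTypeList {};

// Illustrative stand-in for the CTOWrapper in this PR: forwards calls to the
// kernel f with the selected compile time options appended as integral_constants.
template <class F, int... ctr>
struct MyCTOWrapper {
    F f;
    template <class... Args>
    auto operator() (Args... args) const {
        return f(args..., std::integral_constant<int, ctr>{}...);
    }
};

// Call the launcher l with one combination if it matches the runtime options.
template <class L, class F, class... As>
bool TryCombination (L const& l, F const& f, MyTypeList<As...>,
                     std::array<int, sizeof...(As)> const& opts)
{
    if (opts == std::array<int, sizeof...(As)>{As::value...}) {
        l(MyCTOWrapper<F, As::value...>{f});
        return true;
    }
    return false;
}

// Fold expression over all combinations; exactly one TryCombination fires.
template <class L, class F, class... Combos, std::size_t N>
void Dispatch (L const& l, F const& f, MyTypeList<Combos...>,
               std::array<int, N> const& opts)
{
    bool found = (false || ... || TryCombination(l, f, Combos{}, opts));
    (void) found; // AMReX asserts that a matching option was found
}

int main ()
{
    int runtime_option = 2; // imagine this comes from an input file

    Dispatch(
        [](auto cto_func) { cto_func(7); },  // "launcher": stands in for ParallelFor etc.
        [](int i, auto control) {            // "kernel": control.value is a compile time constant
            if constexpr (control.value == 2) {
                std::printf("i=%d: branch compiled for option 2\n", i);
            } else {
                std::printf("i=%d: other branch\n", i);
            }
        },
        MyTypeList<MyTypeList<std::integral_constant<int, 1>>,
                   MyTypeList<std::integral_constant<int, 2>>,
                   MyTypeList<std::integral_constant<int, 3>>>{},
        std::array<int, 1>{runtime_option});
}
```
Since each combination instantiates the kernel with distinct `integral_constant` types, every `constexpr if` branch is compiled exactly once, and the runtime comparison merely selects which instantiation gets launched.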
AlexanderSinn committed Sep 2, 2024 · 1 parent a31abb5 · commit de4dc97
Showing 1 changed file with 174 additions and 87 deletions: Src/Base/AMReX_CTOParallelForImpl.H
@@ -3,7 +3,7 @@
 
 #include <AMReX_BLassert.H>
 #include <AMReX_Box.H>
-#include <AMReX_Tuple.H>
+#include <AMReX_TypeList.H>
 
 #include <array>
 #include <type_traits>
@@ -18,125 +18,212 @@ namespace amrex {
 
 template <int... ctr>
 struct CompileTimeOptions {
-    // TypeList is defined in AMReX_Tuple.H
+    // TypeList is defined in AMReX_TypeList.H
     using list_type = TypeList<std::integral_constant<int, ctr>...>;
 };
 
 #if (__cplusplus >= 201703L)
 
 namespace detail
 {
-    template <int MT, typename T, class F, typename... As>
-    std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T,Box>, bool>
-    ParallelFor_helper2 (T const& N, F const& f, TypeList<As...>,
-                         std::array<int,sizeof...(As)> const& runtime_options)
-    {
-        if (runtime_options == std::array<int,sizeof...(As)>{As::value...}) {
-            if constexpr (std::is_integral_v<T>) {
-                ParallelFor<MT>(N, [f] AMREX_GPU_DEVICE (T i) noexcept
-                {
-                    f(i, As{}...);
-                });
-            } else {
-                ParallelFor<MT>(N, [f] AMREX_GPU_DEVICE (int i, int j, int k) noexcept
-                {
-                    f(i, j, k, As{}...);
-                });
-            }
-            return true;
-        } else {
-            return false;
-        }
-    }
-
-    template <int MT, typename T, class F, typename... As>
-    std::enable_if_t<std::is_integral_v<T>, bool>
-    ParallelFor_helper2 (Box const& box, T ncomp, F const& f, TypeList<As...>,
-                         std::array<int,sizeof...(As)> const& runtime_options)
-    {
-        if (runtime_options == std::array<int,sizeof...(As)>{As::value...}) {
-            ParallelFor<MT>(box, ncomp, [f] AMREX_GPU_DEVICE (int i, int j, int k, T n) noexcept
-            {
-                f(i, j, k, n, As{}...);
-            });
-            return true;
-        } else {
-            return false;
-        }
-    }
-
-    template <int MT, typename T, class F, typename... PPs, typename RO>
-    std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T,Box>>
-    ParallelFor_helper1 (T const& N, F const& f, TypeList<PPs...>,
-                         RO const& runtime_options)
-    {
-        bool found_option = (false || ... ||
-                             ParallelFor_helper2<MT>(N, f,
-                                                     PPs{}, runtime_options));
-        amrex::ignore_unused(found_option);
-        AMREX_ASSERT(found_option);
-    }
-
-    template <int MT, typename T, class F, typename... PPs, typename RO>
-    std::enable_if_t<std::is_integral_v<T>>
-    ParallelFor_helper1 (Box const& box, T ncomp, F const& f, TypeList<PPs...>,
-                         RO const& runtime_options)
-    {
-        bool found_option = (false || ... ||
-                             ParallelFor_helper2<MT>(box, ncomp, f,
-                                                     PPs{}, runtime_options));
-        amrex::ignore_unused(found_option);
-        AMREX_ASSERT(found_option);
-    }
+    template<class F, int... ctr>
+    struct CTOWrapper {
+        F f;
+
+        template<class... Args>
+        AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
+        auto operator() (Args... args) const noexcept
+            -> decltype(f(args..., std::integral_constant<int, ctr>{}...)) {
+            return f(args..., std::integral_constant<int, ctr>{}...);
+        }
+
+        AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
+        static constexpr
+        std::array<int, sizeof...(ctr)> GetOptions () noexcept {
+            return {ctr...};
+        }
+    };
+
+    template <class L, class F, typename... As>
+    bool
+    AnyCTO_helper2 (const L& l, const F& f, TypeList<As...>,
+                    std::array<int,sizeof...(As)> const& runtime_options)
+    {
+        if (runtime_options == std::array<int,sizeof...(As)>{As::value...}) {
+            l(CTOWrapper<F, As::value...>{f});
+            return true;
+        } else {
+            return false;
+        }
+    }
+
+    template <class L, class F, typename... PPs, typename RO>
+    void
+    AnyCTO_helper1 (const L& l, const F& f, TypeList<PPs...>, RO const& runtime_options)
+    {
+        bool found_option = (false || ... || AnyCTO_helper2(l, f, PPs{}, runtime_options));
+        amrex::ignore_unused(found_option);
+        AMREX_ASSERT(found_option);
+    }
 }
 
 #endif

-template <int MT, typename T, class F, typename... CTOs>
-std::enable_if_t<std::is_integral_v<T>>
-ParallelFor (TypeList<CTOs...> /*list_of_compile_time_options*/,
-             std::array<int,sizeof...(CTOs)> const& runtime_options,
-             T N, F&& f)
-{
-#if (__cplusplus >= 201703L)
-    detail::ParallelFor_helper1<MT>(N, std::forward<F>(f),
-                                    CartesianProduct(typename CTOs::list_type{}...),
-                                    runtime_options);
-#else
-    amrex::ignore_unused(N, f, runtime_options);
-    static_assert(std::is_integral<F>::value, "This requires C++17");
-#endif
-}
+/**
+ * \brief Compile time optimization of kernels with run time options.
+ *
+ * This is a generalized version of ParallelFor with CTOs that can support any function that
+ * takes in one lambda to launch a GPU kernel, such as ParallelFor, ParallelForRNG, launch, etc.
+ * It uses a fold expression to generate kernel launches for all combinations
+ * of the run time options. The kernel function can use constexpr if to
+ * discard unused code blocks for better run time performance. In the
+ * example below, the code will be expanded into 4*2=8 normal ParallelForRNGs
+ * for all combinations of the run time parameters.
+ \verbatim
+     int A_runtime_option = ...;
+     int B_runtime_option = ...;
+     enum A_options : int { A0, A1, A2, A3 };
+     enum B_options : int { B0, B1 };
+     AnyCTO(TypeList<CompileTimeOptions<A0,A1,A2,A3>,
+                     CompileTimeOptions<B0,B1>>{},
+         {A_runtime_option, B_runtime_option},
+         [&](auto cto_func){
+             ParallelForRNG(N, cto_func);
+         },
+         [=] AMREX_GPU_DEVICE (int i, const RandomEngine& engine,
+                               auto A_control, auto B_control)
+         {
+             ...
+             if constexpr (A_control.value == A0) {
+                 ...
+             } else if constexpr (A_control.value == A1) {
+                 ...
+             } else if constexpr (A_control.value == A2) {
+                 ...
+             } else {
+                 ...
+             }
+             if constexpr (A_control.value != A3 && B_control.value == B1) {
+                 ...
+             }
+             ...
+         }
+     );
+
+     constexpr int nthreads_per_block = ...;
+     int nblocks = ...;
+     AnyCTO(TypeList<CompileTimeOptions<A0,A1,A2,A3>,
+                     CompileTimeOptions<B0,B1>>{},
+         {A_runtime_option, B_runtime_option},
+         [&](auto cto_func){
+             launch<nthreads_per_block>(nblocks, Gpu::gpuStream(), cto_func);
+         },
+         [=] AMREX_GPU_DEVICE (auto A_control, auto B_control){
+             ...
+         }
+     );
+ \endverbatim
+ * The static member function cto_func.GetOptions() can be used to obtain the runtime_options
+ * passed into AnyCTO, but at compile time. This enables some advanced use cases,
+ * such as changing the number of threads per block or the dimensionality of ParallelFor at runtime.
+ * In the second example, the trailing -> decltype(void(intvect.size())) is necessary to
+ * disambiguate IntVectND<1> and int for the first argument of the kernel function.
+ \verbatim
+     int nthreads_per_block = ...;
+     AnyCTO(TypeList<CompileTimeOptions<128,256,512,1024>>{},
+         {nthreads_per_block},
+         [&](auto cto_func){
+             constexpr std::array<int, 1> ctos = cto_func.GetOptions();
+             constexpr int c_nthreads_per_block = ctos[0];
+             ParallelFor<c_nthreads_per_block>(N, cto_func);
+         },
+         [=] AMREX_GPU_DEVICE (int i, auto){
+             ...
+         }
+     );
+
+     BoxND<6> box6D = ...;
+     int dims_needed = ...;
+     AnyCTO(TypeList<CompileTimeOptions<1,2,3,4,5,6>>{},
+         {dims_needed},
+         [&](auto cto_func){
+             constexpr std::array<int, 1> ctos = cto_func.GetOptions();
+             constexpr int c_dims_needed = ctos[0];
+             const auto box = BoxShrink<c_dims_needed>(box6D);
+             ParallelFor(box, cto_func);
+         },
+         [=] AMREX_GPU_DEVICE (auto intvect, auto) -> decltype(void(intvect.size())) {
+             ...
+         }
+     );
+ \endverbatim
+ * Note that due to a limitation of CUDA's extended device lambda, the
+ * constexpr if block cannot be the one that captures a variable first.
+ * If nvcc complains about it, you will have to manually capture it outside
+ * constexpr if. Alternatively, the constexpr if can be replaced with a regular if.
+ * Compilers can still perform the same optimizations since the condition is known at compile time.
+ * The data type for the parameters is int.
+ *
+ * \param list_of_compile_time_options list of all possible values of the parameters.
+ * \param runtime_options the run time parameters.
+ * \param l a callable object containing a CPU function that launches the provided GPU kernel.
+ * \param f a callable object containing the GPU kernel with optimizations.
+ */
+template <class L, class F, typename... CTOs>
+void AnyCTO ([[maybe_unused]] TypeList<CTOs...> list_of_compile_time_options,
+             std::array<int,sizeof...(CTOs)> const& runtime_options,
+             L&& l, F&& f)
+{
+#if (__cplusplus >= 201703L)
+    detail::AnyCTO_helper1(std::forward<L>(l), std::forward<F>(f),
+                           CartesianProduct(typename CTOs::list_type{}...),
+                           runtime_options);
+#else
+    amrex::ignore_unused(runtime_options, l, f);
+    static_assert(std::is_integral<F>::value, "This requires C++17");
+#endif
+}

-template <int MT, class F, typename... CTOs>
-void ParallelFor (TypeList<CTOs...> /*list_of_compile_time_options*/,
-                  std::array<int,sizeof...(CTOs)> const& runtime_options,
-                  Box const& box, F&& f)
-{
-#if (__cplusplus >= 201703L)
-    detail::ParallelFor_helper1<MT>(box, std::forward<F>(f),
-                                    CartesianProduct(typename CTOs::list_type{}...),
-                                    runtime_options);
-#else
-    amrex::ignore_unused(box, f, runtime_options);
-    static_assert(std::is_integral<F>::value, "This requires C++17");
-#endif
-}
-
-template <int MT, typename T, class F, typename... CTOs>
-std::enable_if_t<std::is_integral_v<T>>
-ParallelFor (TypeList<CTOs...> /*list_of_compile_time_options*/,
-             std::array<int,sizeof...(CTOs)> const& runtime_options,
-             Box const& box, T ncomp, F&& f)
-{
-#if (__cplusplus >= 201703L)
-    detail::ParallelFor_helper1<MT>(box, ncomp, std::forward<F>(f),
-                                    CartesianProduct(typename CTOs::list_type{}...),
-                                    runtime_options);
-#else
-    amrex::ignore_unused(box, ncomp, f, runtime_options);
-    static_assert(std::is_integral<F>::value, "This requires C++17");
-#endif
-}
+template <int MT, typename T, class F, typename... CTOs>
+std::enable_if_t<std::is_integral_v<T>>
+ParallelFor (TypeList<CTOs...> ctos,
+             std::array<int,sizeof...(CTOs)> const& runtime_options,
+             T N, F&& f)
+{
+    AnyCTO(ctos, runtime_options,
+        [&](auto cto_func){
+            ParallelFor<MT>(N, cto_func);
+        },
+        std::forward<F>(f)
+    );
+}
+
+template <int MT, class F, int dim, typename... CTOs>
+void ParallelFor (TypeList<CTOs...> ctos,
+                  std::array<int,sizeof...(CTOs)> const& runtime_options,
+                  BoxND<dim> const& box, F&& f)
+{
+    AnyCTO(ctos, runtime_options,
+        [&](auto cto_func){
+            ParallelFor<MT>(box, cto_func);
+        },
+        std::forward<F>(f)
+    );
+}
+
+template <int MT, typename T, class F, int dim, typename... CTOs>
+std::enable_if_t<std::is_integral_v<T>>
+ParallelFor (TypeList<CTOs...> ctos,
+             std::array<int,sizeof...(CTOs)> const& runtime_options,
+             BoxND<dim> const& box, T ncomp, F&& f)
+{
+    AnyCTO(ctos, runtime_options,
+        [&](auto cto_func){
+            ParallelFor<MT>(box, ncomp, cto_func);
+        },
+        std::forward<F>(f)
+    );
+}

@@ -164,7 +251,7 @@ ParallelFor (TypeList<CTOs...> /*list_of_compile_time_options*/,
          ...
      } else if constexpr (A_control.value == A2) {
          ...
-     else {
+     } else {
          ...
      }
      if constexpr (A_control.value != A3 && B_control.value == B1) {
@@ -218,7 +305,7 @@ ParallelFor (TypeList<CTOs...> ctos,
          ...
      } else if constexpr (A_control.value == A2) {
          ...
-     else {
+     } else {
          ...
      }
      if constexpr (A_control.value != A3 && B_control.value == B1) {
@@ -237,10 +324,10 @@ ParallelFor (TypeList<CTOs...> ctos,
  * \param box a Box specifying the 3D for loop's range.
  * \param f a callable object taking three integers and working on the given cell.
  */
-template <class F, typename... CTOs>
+template <class F, int dim, typename... CTOs>
 void ParallelFor (TypeList<CTOs...> ctos,
                   std::array<int,sizeof...(CTOs)> const& option,
-                  Box const& box, F&& f)
+                  BoxND<dim> const& box, F&& f)
 {
     ParallelFor<AMREX_GPU_MAX_THREADS>(ctos, option, box, std::forward<F>(f));
 }
@@ -271,7 +358,7 @@ void ParallelFor (TypeList<CTOs...> ctos,
          ...
      } else if constexpr (A_control.value == A2) {
          ...
-     else {
+     } else {
          ...
      }
      if constexpr (A_control.value != A3 && B_control.value == B1) {
@@ -291,11 +378,11 @@ void ParallelFor (TypeList<CTOs...> ctos,
  * \param ncomp an integer specifying the range for iteration over components.
  * \param f a callable object taking three integers and working on the given cell.
  */
-template <typename T, class F, typename... CTOs>
+template <typename T, class F, int dim, typename... CTOs>
 std::enable_if_t<std::is_integral_v<T>>
 ParallelFor (TypeList<CTOs...> ctos,
              std::array<int,sizeof...(CTOs)> const& option,
-             Box const& box, T ncomp, F&& f)
+             BoxND<dim> const& box, T ncomp, F&& f)
 {
     ParallelFor<AMREX_GPU_MAX_THREADS>(ctos, option, box, ncomp, std::forward<F>(f));
 }
