From 8942dfa8afee678bff83d4841cd0d56140efb035 Mon Sep 17 00:00:00 2001
From: Bernhard Manfred Gruber
Date: Sun, 5 Sep 2021 16:41:18 +0200
Subject: [PATCH] create LLAMA iterators on the fly and revert hacks

---
 examples/thrust/thrust.cu         | 66 +++++++++++++++++++------------
 include/llama/ArrayIndexRange.hpp | 10 ++---
 include/llama/View.hpp            | 12 +++---
 include/llama/VirtualRecord.hpp   |  3 +-
 4 files changed, 52 insertions(+), 39 deletions(-)

diff --git a/examples/thrust/thrust.cu b/examples/thrust/thrust.cu
index 82034895a0..5b2e1eac4f 100644
--- a/examples/thrust/thrust.cu
+++ b/examples/thrust/thrust.cu
@@ -375,8 +375,22 @@ void run(std::ostream& plotFile)
 
     auto view = llama::allocView(mapping, thrustDeviceAlloc);
 
+    auto makeViewIteratorFromIndexCreator = [](decltype(view) view)
+    { return [view] __host__ __device__(std::size_t i) mutable { return *(view.begin() + i); }; };
+    auto b = thrust::make_transform_iterator(
+        thrust::counting_iterator{0},
+        makeViewIteratorFromIndexCreator(view));
+    auto e = thrust::make_transform_iterator(
+        thrust::counting_iterator{N},
+        makeViewIteratorFromIndexCreator(view));
+    // auto b = view.begin();
+    // auto e = view.end();
+
+    auto r = (*b);
+    r(tag::eventId{}) = 0;
+
     // touch memory once before running benchmarks
-    thrust::fill(thrust::device, view.begin(), view.end(), 0);
+    thrust::fill(thrust::device, b, e, 0);
     syncWithCuda();
 
     //#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
@@ -427,7 +441,7 @@ void run(std::ostream& plotFile)
         }
         else
         {
-            thrust::tabulate(thrust::device, view.begin(), view.end(), InitOne{});
+            thrust::tabulate(thrust::device, b, e, InitOne{});
             syncWithCuda();
         }
         tabulateTotal += stopwatch.printAndReset("tabulate", '\t');
@@ -453,7 +467,7 @@ void run(std::ostream& plotFile)
     {
         Stopwatch stopwatch;
         if constexpr(usePSTL)
-            std::for_each(exec, view.begin(), view.end(), NormalizeVel{});
+            std::for_each(exec, b, e, NormalizeVel{});
         else
         {
             thrust::for_each(
@@ -471,10 +485,10 @@ void run(std::ostream& plotFile)
         thrust::device_vector dst(N);
         Stopwatch stopwatch;
         if constexpr(usePSTL)
-            std::transform(exec, view.begin(), view.end(), dst.begin(), GetMass{});
+            std::transform(exec, b, e, dst.begin(), GetMass{});
         else
         {
-            thrust::transform(thrust::device, view.begin(), view.end(), dst.begin(), GetMass{});
+            thrust::transform(thrust::device, b, e, dst.begin(), GetMass{});
             syncWithCuda();
         }
         transformTotal += stopwatch.printAndReset("transform", '\t');
@@ -489,8 +503,8 @@ void run(std::ostream& plotFile)
         if constexpr(usePSTL)
             std::transform_exclusive_scan(
                 exec,
-                view.begin(),
-                view.end(),
+                b,
+                e,
                 scan_result.begin(),
                 std::uint32_t{0},
                 std::plus<>{},
@@ -499,8 +513,8 @@ void run(std::ostream& plotFile)
         {
             thrust::transform_exclusive_scan(
                 thrust::device,
-                view.begin(),
-                view.end(),
+                b,
+                e,
                 scan_result.begin(),
                 Predicate{},
                 std::uint32_t{0},
@@ -516,16 +530,10 @@ void run(std::ostream& plotFile)
     {
         Stopwatch stopwatch;
         if constexpr(usePSTL)
-            sink = std::transform_reduce(exec, view.begin(), view.end(), MassType{0}, std::plus<>{}, GetMass{});
+            sink = std::transform_reduce(exec, b, e, MassType{0}, std::plus<>{}, GetMass{});
         else
         {
-            sink = thrust::transform_reduce(
-                thrust::device,
-                view.begin(),
-                view.end(),
-                GetMass{},
-                MassType{0},
-                thrust::plus<>{});
+            sink = thrust::transform_reduce(thrust::device, b, e, GetMass{}, MassType{0}, thrust::plus<>{});
             syncWithCuda();
         }
         transformReduceTotal += stopwatch.printAndReset("transform_reduce", '\t');
@@ -533,12 +541,15 @@ void run(std::ostream& plotFile)
 
     {
         auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
+        auto db = thrust::make_transform_iterator(
+            thrust::counting_iterator{0},
+            makeViewIteratorFromIndexCreator(dstView));
         Stopwatch stopwatch;
         if constexpr(usePSTL)
-            std::copy(exec, view.begin(), view.end(), dstView.begin());
+            std::copy(exec, b, e, db);
         else
         {
-            thrust::copy(thrust::device, view.begin(), view.end(), dstView.begin());
+            thrust::copy(thrust::device, b, e, db);
             syncWithCuda();
         }
         copyTotal += stopwatch.printAndReset("copy", '\t');
@@ -548,12 +559,15 @@ void run(std::ostream& plotFile)
 
     {
         auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
+        auto db = thrust::make_transform_iterator(
+            thrust::counting_iterator{0},
+            makeViewIteratorFromIndexCreator(dstView));
         Stopwatch stopwatch;
         if constexpr(usePSTL)
-            std::copy_if(exec, view.begin(), view.end(), dstView.begin(), Predicate{});
+            std::copy_if(exec, b, e, db, Predicate{});
         else
        {
-            thrust::copy_if(thrust::device, view.begin(), view.end(), dstView.begin(), Predicate{});
+            thrust::copy_if(thrust::device, b, e, db, Predicate{});
             syncWithCuda();
        }
         copyIfTotal += stopwatch.printAndReset("copy_if", '\t');
@@ -564,10 +578,10 @@ void run(std::ostream& plotFile)
     {
         Stopwatch stopwatch;
         if constexpr(usePSTL)
-            std::remove_if(exec, view.begin(), view.end(), Predicate{});
+            std::remove_if(exec, b, e, Predicate{});
         else
         {
-            thrust::remove_if(thrust::device, view.begin(), view.end(), Predicate{});
+            thrust::remove_if(thrust::device, b, e, Predicate{});
             syncWithCuda();
         }
         removeIfTotal += stopwatch.printAndReset("remove_if", '\t');
@@ -576,14 +590,14 @@ void run(std::ostream& plotFile)
     //{
     //    Stopwatch stopwatch;
     //    if constexpr(usePSTL)
-    //        std::sort(std::execution::par, view.begin(), view.end(), Less{});
+    //        std::sort(std::execution::par, b, e, Less{});
     //    else
     //    {
-    //        thrust::sort(thrust::device, view.begin(), view.end(), Less{});
+    //        thrust::sort(thrust::device, b, e, Less{});
     //        syncWithCuda();
     //    }
     //    sortTotal += stopwatch.printAndReset("sort", '\t');
-    //    if(!thrust::is_sorted(thrust::device, view.begin(), view.end(), Less{}))
+    //    if(!thrust::is_sorted(thrust::device, b, e, Less{}))
     //        std::cerr << "VALIDATION FAILED\n";
     //}
 
diff --git a/include/llama/ArrayIndexRange.hpp b/include/llama/ArrayIndexRange.hpp
index 6c7ffa91ff..9539bbd892 100644
--- a/include/llama/ArrayIndexRange.hpp
+++ b/include/llama/ArrayIndexRange.hpp
@@ -118,11 +118,11 @@ namespace llama
             current[0] = static_cast(current[0]) + n;
 
             // current is either within bounds or at the end ([last + 1, 0, 0, ..., 0])
-            //assert(
-            //    (current[0] < extents[0]
-            //     || (current[0] == extents[0]
-            //         && std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; })))
-            //    && "Iterator was moved past the end");
+            assert(
+                (current[0] < extents[0]
+                 || (current[0] == extents[0]
+                     && std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; })))
+                && "Iterator was moved past the end");
 
             return *this;
         }
diff --git a/include/llama/View.hpp b/include/llama/View.hpp
index 22b49227f2..d101c426be 100644
--- a/include/llama/View.hpp
+++ b/include/llama/View.hpp
@@ -187,7 +187,7 @@ namespace llama
         LLAMA_FN_HOST_ACC_INLINE
         constexpr auto operator*() const -> reference
         {
-            return const_cast(view)(*arrayIndex);
+            return (*view)(*arrayIndex);
         }
 
         LLAMA_FN_HOST_ACC_INLINE
@@ -282,7 +282,7 @@ namespace llama
         }
 
         ArrayIndexIterator arrayIndex;
-        View view;
+        View* view;
     };
 
     /// Central LLAMA class holding memory for storage and giving access to values stored there defined by a mapping. A
@@ -423,25 +423,25 @@ namespace llama
        LLAMA_FN_HOST_ACC_INLINE
        auto begin() -> iterator
        {
-           return {ArrayIndexRange{mapping().extents()}.begin(), *this};
+           return {ArrayIndexRange{mapping().extents()}.begin(), this};
        }
 
        LLAMA_FN_HOST_ACC_INLINE
        auto begin() const -> const_iterator
        {
-           return {ArrayIndexRange{mapping().extents()}.begin(), *this};
+           return {ArrayIndexRange{mapping().extents()}.begin(), this};
        }
 
        LLAMA_FN_HOST_ACC_INLINE
        auto end() -> iterator
        {
-           return {ArrayIndexRange{mapping().extents()}.end(), *this};
+           return {ArrayIndexRange{mapping().extents()}.end(), this};
        }
 
        LLAMA_FN_HOST_ACC_INLINE
        auto end() const -> const_iterator
        {
-           return {ArrayIndexRange{mapping().extents()}.end(), *this};
+           return {ArrayIndexRange{mapping().extents()}.end(), this};
        }
 
        Array storageBlobs;
diff --git a/include/llama/VirtualRecord.hpp b/include/llama/VirtualRecord.hpp
index af237d419d..cbd86ec2ef 100644
--- a/include/llama/VirtualRecord.hpp
+++ b/include/llama/VirtualRecord.hpp
@@ -352,8 +352,7 @@ namespace llama
        using ArrayIndex = typename View::Mapping::ArrayIndex;
        using RecordDim = typename View::Mapping::RecordDim;
 
-        // std::conditional_t view;
-        View view;
+        std::conditional_t view;
 
    public:
        /// Subtree of the record dimension of View starting at BoundRecordCoord. If BoundRecordCoord is
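
Note on the pattern above (an illustration, not part of the patch): the hunks in examples/thrust/thrust.cu stop relying on View::begin()/end() and instead build Thrust-compatible iterators on the fly from thrust::counting_iterator plus thrust::make_transform_iterator, so only a flat index and a copy of the view travel into device code. The following self-contained sketch shows the same iterator construction on a plain thrust::device_vector rather than a LLAMA view; the names ElementAt, data, and n are illustrative assumptions and do not appear in LLAMA or in this commit.

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/fill.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>

#include <cstdio>

// Maps a flat index to a writable element of the underlying storage, playing the
// role of the view-capturing lambda (makeViewIteratorFromIndexCreator) in the patch.
struct ElementAt
{
    float* data;

    __host__ __device__ auto operator()(std::size_t i) const -> float&
    {
        return data[i];
    }
};

int main()
{
    const std::size_t n = 1000;
    thrust::device_vector<float> data(n);

    // Build begin/end iterators on the fly: a counting_iterator produces the flat
    // indices and a transform_iterator turns each index into an element reference.
    const ElementAt at{thrust::raw_pointer_cast(data.data())};
    auto b = thrust::make_transform_iterator(thrust::counting_iterator<std::size_t>{0}, at);
    auto e = thrust::make_transform_iterator(thrust::counting_iterator<std::size_t>{n}, at);

    // The iterator pair then works with ordinary Thrust algorithms, as in the patch.
    thrust::fill(thrust::device, b, e, 1.0f);
    const float sum = thrust::reduce(thrust::device, b, e);
    std::printf("sum = %f\n", sum); // expected: 1000.0
    return 0;
}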