diff --git a/examples/thrust/thrust.cu b/examples/thrust/thrust.cu index 5b2e1eac4f..5c8778a24f 100644 --- a/examples/thrust/thrust.cu +++ b/examples/thrust/thrust.cu @@ -321,6 +321,29 @@ auto thrustDeviceAlloc = [](auto alignment, std::size_t size) return p; }; +template +struct ViewIteratorAt +{ + View view; + + LLAMA_FN_HOST_ACC_INLINE auto operator()(std::size_t i) + { + return *(view.begin() + i); + } +}; + +template +auto viewIteratorAt(View& view, std::size_t index) +{ + ViewIteratorAt t{view}; + using ViewTransformIterator = thrust::transform_iterator< + decltype(t), + thrust::counting_iterator, + typename View::iterator::reference, + typename View::iterator::value_type>; + return ViewTransformIterator{thrust::counting_iterator{index}, t}; +} + template void run(std::ostream& plotFile) { @@ -374,15 +397,8 @@ void run(std::ostream& plotFile) std::cout << mappingName << '\n'; auto view = llama::allocView(mapping, thrustDeviceAlloc); - - auto makeViewIteratorFromIndexCreator = [](decltype(view) view) - { return [view] __host__ __device__(std::size_t i) mutable { return *(view.begin() + i); }; }; - auto b = thrust::make_transform_iterator( - thrust::counting_iterator{0}, - makeViewIteratorFromIndexCreator(view)); - auto e = thrust::make_transform_iterator( - thrust::counting_iterator{N}, - makeViewIteratorFromIndexCreator(view)); + auto b = viewIteratorAt(view, 0); + auto e = viewIteratorAt(view, N); // auto b = view.begin(); // auto e = view.end(); @@ -541,9 +557,7 @@ void run(std::ostream& plotFile) { auto dstView = llama::allocView(mapping, thrustDeviceAlloc); - auto db = thrust::make_transform_iterator( - thrust::counting_iterator{0}, - makeViewIteratorFromIndexCreator(dstView)); + auto db = viewIteratorAt(dstView, 0); Stopwatch stopwatch; if constexpr(usePSTL) std::copy(exec, b, e, db); @@ -559,9 +573,7 @@ void run(std::ostream& plotFile) { auto dstView = llama::allocView(mapping, thrustDeviceAlloc); - auto db = thrust::make_transform_iterator( - thrust::counting_iterator{0}, - makeViewIteratorFromIndexCreator(dstView)); + auto db = viewIteratorAt(dstView, 0); Stopwatch stopwatch; if constexpr(usePSTL) std::copy_if(exec, b, e, db, Predicate{});