Skip to content

Commit

Permalink
create LLAMA iterators on the fly and revert hacks
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Nov 29, 2021
1 parent 55a91cf commit 8942dfa
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 39 deletions.
66 changes: 40 additions & 26 deletions examples/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,22 @@ void run(std::ostream& plotFile)

auto view = llama::allocView(mapping, thrustDeviceAlloc);

auto makeViewIteratorFromIndexCreator = [](decltype(view) view)
{ return [view] __host__ __device__(std::size_t i) mutable { return *(view.begin() + i); }; };
auto b = thrust::make_transform_iterator(
thrust::counting_iterator<std::size_t>{0},
makeViewIteratorFromIndexCreator(view));
auto e = thrust::make_transform_iterator(
thrust::counting_iterator<std::size_t>{N},
makeViewIteratorFromIndexCreator(view));
// auto b = view.begin();
// auto e = view.end();

auto r = (*b);
r(tag::eventId{}) = 0;

// touch memory once before running benchmarks
thrust::fill(thrust::device, view.begin(), view.end(), 0);
thrust::fill(thrust::device, b, e, 0);
syncWithCuda();

//#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
Expand Down Expand Up @@ -427,7 +441,7 @@ void run(std::ostream& plotFile)
}
else
{
thrust::tabulate(thrust::device, view.begin(), view.end(), InitOne{});
thrust::tabulate(thrust::device, b, e, InitOne{});
syncWithCuda();
}
tabulateTotal += stopwatch.printAndReset("tabulate", '\t');
Expand All @@ -453,7 +467,7 @@ void run(std::ostream& plotFile)
{
Stopwatch stopwatch;
if constexpr(usePSTL)
std::for_each(exec, view.begin(), view.end(), NormalizeVel{});
std::for_each(exec, b, e, NormalizeVel{});
else
{
thrust::for_each(
Expand All @@ -471,10 +485,10 @@ void run(std::ostream& plotFile)
thrust::device_vector<MassType> dst(N);
Stopwatch stopwatch;
if constexpr(usePSTL)
std::transform(exec, view.begin(), view.end(), dst.begin(), GetMass{});
std::transform(exec, b, e, dst.begin(), GetMass{});
else
{
thrust::transform(thrust::device, view.begin(), view.end(), dst.begin(), GetMass{});
thrust::transform(thrust::device, b, e, dst.begin(), GetMass{});
syncWithCuda();
}
transformTotal += stopwatch.printAndReset("transform", '\t');
Expand All @@ -489,8 +503,8 @@ void run(std::ostream& plotFile)
if constexpr(usePSTL)
std::transform_exclusive_scan(
exec,
view.begin(),
view.end(),
b,
e,
scan_result.begin(),
std::uint32_t{0},
std::plus<>{},
Expand All @@ -499,8 +513,8 @@ void run(std::ostream& plotFile)
{
thrust::transform_exclusive_scan(
thrust::device,
view.begin(),
view.end(),
b,
e,
scan_result.begin(),
Predicate{},
std::uint32_t{0},
Expand All @@ -516,29 +530,26 @@ void run(std::ostream& plotFile)
{
Stopwatch stopwatch;
if constexpr(usePSTL)
sink = std::transform_reduce(exec, view.begin(), view.end(), MassType{0}, std::plus<>{}, GetMass{});
sink = std::transform_reduce(exec, b, e, MassType{0}, std::plus<>{}, GetMass{});
else
{
sink = thrust::transform_reduce(
thrust::device,
view.begin(),
view.end(),
GetMass{},
MassType{0},
thrust::plus<>{});
sink = thrust::transform_reduce(thrust::device, b, e, GetMass{}, MassType{0}, thrust::plus<>{});
syncWithCuda();
}
transformReduceTotal += stopwatch.printAndReset("transform_reduce", '\t');
}

{
auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
auto db = thrust::make_transform_iterator(
thrust::counting_iterator<std::size_t>{0},
makeViewIteratorFromIndexCreator(dstView));
Stopwatch stopwatch;
if constexpr(usePSTL)
std::copy(exec, view.begin(), view.end(), dstView.begin());
std::copy(exec, b, e, db);
else
{
thrust::copy(thrust::device, view.begin(), view.end(), dstView.begin());
thrust::copy(thrust::device, b, e, db);
syncWithCuda();
}
copyTotal += stopwatch.printAndReset("copy", '\t');
Expand All @@ -548,12 +559,15 @@ void run(std::ostream& plotFile)

{
auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
auto db = thrust::make_transform_iterator(
thrust::counting_iterator<std::size_t>{0},
makeViewIteratorFromIndexCreator(dstView));
Stopwatch stopwatch;
if constexpr(usePSTL)
std::copy_if(exec, view.begin(), view.end(), dstView.begin(), Predicate{});
std::copy_if(exec, b, e, db, Predicate{});
else
{
thrust::copy_if(thrust::device, view.begin(), view.end(), dstView.begin(), Predicate{});
thrust::copy_if(thrust::device, b, e, db, Predicate{});
syncWithCuda();
}
copyIfTotal += stopwatch.printAndReset("copy_if", '\t');
Expand All @@ -564,10 +578,10 @@ void run(std::ostream& plotFile)
{
Stopwatch stopwatch;
if constexpr(usePSTL)
std::remove_if(exec, view.begin(), view.end(), Predicate{});
std::remove_if(exec, b, e, Predicate{});
else
{
thrust::remove_if(thrust::device, view.begin(), view.end(), Predicate{});
thrust::remove_if(thrust::device, b, e, Predicate{});
syncWithCuda();
}
removeIfTotal += stopwatch.printAndReset("remove_if", '\t');
Expand All @@ -576,14 +590,14 @@ void run(std::ostream& plotFile)
//{
// Stopwatch stopwatch;
// if constexpr(usePSTL)
// std::sort(std::execution::par, view.begin(), view.end(), Less{});
// std::sort(std::execution::par, b, e, Less{});
// else
// {
// thrust::sort(thrust::device, view.begin(), view.end(), Less{});
// thrust::sort(thrust::device, b, e, Less{});
// syncWithCuda();
// }
// sortTotal += stopwatch.printAndReset("sort", '\t');
// if(!thrust::is_sorted(thrust::device, view.begin(), view.end(), Less{}))
// if(!thrust::is_sorted(thrust::device, b, e, Less{}))
// std::cerr << "VALIDATION FAILED\n";
//}

Expand Down
10 changes: 5 additions & 5 deletions include/llama/ArrayIndexRange.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ namespace llama

current[0] = static_cast<difference_type>(current[0]) + n;
// current is either within bounds or at the end ([last + 1, 0, 0, ..., 0])
//assert(
// (current[0] < extents[0]
// || (current[0] == extents[0]
// && std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; })))
// && "Iterator was moved past the end");
assert(
(current[0] < extents[0]
|| (current[0] == extents[0]
&& std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; })))
&& "Iterator was moved past the end");

return *this;
}
Expand Down
12 changes: 6 additions & 6 deletions include/llama/View.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ namespace llama
LLAMA_FN_HOST_ACC_INLINE
constexpr auto operator*() const -> reference
{
return const_cast<View&>(view)(*arrayIndex);
return (*view)(*arrayIndex);
}

LLAMA_FN_HOST_ACC_INLINE
Expand Down Expand Up @@ -282,7 +282,7 @@ namespace llama
}

ArrayIndexIterator arrayIndex;
View view;
View* view;
};

/// Central LLAMA class holding memory for storage and giving access to values stored there defined by a mapping. A
Expand Down Expand Up @@ -423,25 +423,25 @@ namespace llama
LLAMA_FN_HOST_ACC_INLINE
auto begin() -> iterator
{
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), *this};
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), this};
}

LLAMA_FN_HOST_ACC_INLINE
auto begin() const -> const_iterator
{
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), *this};
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), this};
}

LLAMA_FN_HOST_ACC_INLINE
auto end() -> iterator
{
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), *this};
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), this};
}

LLAMA_FN_HOST_ACC_INLINE
auto end() const -> const_iterator
{
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), *this};
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), this};
}

Array<BlobType, Mapping::blobCount> storageBlobs;
Expand Down
3 changes: 1 addition & 2 deletions include/llama/VirtualRecord.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,7 @@ namespace llama
using ArrayIndex = typename View::Mapping::ArrayIndex;
using RecordDim = typename View::Mapping::RecordDim;

// std::conditional_t<OwnView, View, View&> view;
View view;
std::conditional_t<OwnView, View, View&> view;

public:
/// Subtree of the record dimension of View starting at BoundRecordCoord. If BoundRecordCoord is
Expand Down

0 comments on commit 8942dfa

Please sign in to comment.