forked from jaredhoberock/bulk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
async_reduce.cu
63 lines (46 loc) · 1.37 KB
/
async_reduce.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <thrust/device_vector.h>
#include <bulk/bulk.hpp>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <iostream>
// Device-side functor: spins until *wait_for_me reads true, then reduces
// [first, last) with thrust::reduce and stores the sum through result.
struct reduce_kernel
{
  // wait_for_me  flag polled through a volatile pointer, so it is re-read on every poll
  // first, last  range handed to thrust::reduce
  // result       destination that receives the reduction
  template<typename Iterator, typename Pointer>
  __device__ void operator()(volatile bool *wait_for_me, Iterator first, Iterator last, Pointer result)
  {
    // Busy-wait: print a message after every poll that finds the flag still unset.
    for(;;)
    {
      if(*wait_for_me)
      {
        break;
      }

      printf("waiting...\n");
    }

    // The flag has been observed as set; perform the reduction from device code.
    *result = thrust::reduce(thrust::device, first, last);
  }
};
// Device-side functor that raises a boolean flag.
struct greenlight
{
  // set_me  location of the flag; it is overwritten with true
  __device__ void operator()(bool *set_me)
  {
    set_me[0] = true;
  }
};
// Terminates the program with a diagnostic when a CUDA runtime call fails.
static void check_cuda(cudaError_t error, const char *what)
{
  if(error != cudaSuccess)
  {
    std::cerr << "CUDA error in " << what << ": " << cudaGetErrorString(error) << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

// Launches a reduction on one stream that spins on a device flag, then launches
// a second task on another stream that sets the flag, and finally compares the
// asynchronously produced sum with a reference reduction.
//
// Returns 0 when both sums agree and 1 otherwise, so the check is still
// effective when assert() is compiled out with NDEBUG.
int main()
{
  cudaStream_t s1, s2;
  check_cuda(cudaStreamCreate(&s1), "cudaStreamCreate(s1)");
  check_cuda(cudaStreamCreate(&s2), "cudaStreamCreate(s2)");

  using bulk::par;
  using bulk::async;

  thrust::device_vector<int> vec(1 << 20);
  thrust::sequence(vec.begin(), vec.end());

  thrust::device_vector<int> result(1);

  // The flag must start out false, otherwise reduce_kernel would not wait at all.
  thrust::device_vector<bool> flag(1, false);

  // note we launch the reduction before the greenlight
  async(par(s1,1), reduce_kernel(), thrust::raw_pointer_cast(flag.data()), vec.begin(), vec.end(), result.begin());
  async(par(s2,1), greenlight(), thrust::raw_pointer_cast(flag.data()));

  // Wait explicitly for both tasks instead of relying on the implicit
  // synchronization of the legacy default stream before result is read.
  check_cuda(cudaStreamSynchronize(s2), "cudaStreamSynchronize(s2)");
  check_cuda(cudaStreamSynchronize(s1), "cudaStreamSynchronize(s1)");

  check_cuda(cudaStreamDestroy(s1), "cudaStreamDestroy(s1)");
  check_cuda(cudaStreamDestroy(s2), "cudaStreamDestroy(s2)");

  // NOTE(review): the sum of 0 .. 2^20-1 does not fit in int, so both
  // reductions presumably wrap the same way; confirm before changing the
  // element type or the vector size.
  const int expected = thrust::reduce(vec.begin(), vec.end());
  const int actual = result[0];

  std::cout << "result: " << expected << std::endl;
  std::cout << "asynchronous result: " << actual << std::endl;

  assert(expected == actual);

  return expected == actual ? 0 : 1;
}