diff --git a/include/range_mapper.h b/include/range_mapper.h index 44b93a865..879bdebaa 100644 --- a/include/range_mapper.h +++ b/include/range_mapper.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -159,7 +160,7 @@ namespace access { fixed(const subrange& sr) : m_sr(sr) {} template - subrange operator()(const chunk&) const { + subrange operator()(const chunk& /* chnk */) const { return m_sr; } @@ -216,6 +217,49 @@ namespace access { neighborhood(size_t, size_t)->neighborhood<2>; neighborhood(size_t, size_t, size_t)->neighborhood<3>; + struct kernel_dim { + explicit kernel_dim(const int d) : m_dim(d) {} + + template + subrange<1> operator()(const chunk& chnk) const { + return {chnk.offset[m_dim], chnk.range[m_dim]}; + }; + + private: + int m_dim; + }; + + template + struct components { + constexpr static int buffer_dims = sizeof...(ComponentMappers); + + explicit components(const ComponentMappers&... components) : m_mappers{components...} {} + + template + subrange operator()(const chunk& chnk, const range& buffer_size) const { + return apply(chnk, buffer_size, std::make_integer_sequence()); + } + + private: + std::tuple m_mappers; + + template + void apply_component(const chunk& chnk, const range& buffer_size, subrange& out_sr) const { + static_assert(Dim < buffer_dims); + const auto component_sr = detail::invoke_range_mapper_for_kernel(std::get(m_mappers), chnk, range<1>(buffer_size[Dim])); + out_sr.offset[Dim] = component_sr.offset[0]; + out_sr.range[Dim] = component_sr.range[0]; + } + + template + subrange apply( + const chunk& chnk, const range& buffer_size, std::integer_sequence /* seq */) const { + subrange sr; + (apply_component(chnk, buffer_size, sr), ...); + return sr; + } + }; + } // namespace access namespace experimental::access { diff --git a/test/runtime_tests.cc b/test/runtime_tests.cc index 3e3453ab0..802c3ac93 100644 --- a/test/runtime_tests.cc +++ b/test/runtime_tests.cc @@ -184,6 +184,51 @@ namespace detail { } } + TEST_CASE("kernel_dim built-in range mapper behaves as expected", "[range-mapper]") { + using celerity::access::kernel_dim; + { + range_mapper rm{kernel_dim(0), cl::sycl::access::mode::read, range<1>{128}}; + auto sr = rm.map_1(chunk<1>{{1}, {4}, {7}}); + CHECK(sr.offset == id<1>{1}); + CHECK(sr.range == range<1>{4}); + } + { + range_mapper rm{kernel_dim(1), cl::sycl::access::mode::read, range<1>{128}}; + auto sr = rm.map_1(chunk<3>{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + CHECK(sr.offset == id<1>{2}); + CHECK(sr.range == range<1>{5}); + } + } + + TEST_CASE("components built-in range mapper behaves as expected", "[range-mapper]") { + using celerity::access::components; + using celerity::access::kernel_dim; + { + range_mapper rm{components(all(), all(), all()), cl::sycl::access::mode::read, range<3>{128, 128, 128}}; + auto sr = rm.map_3(chunk<3>{{1, 2, 3}, {40, 50, 60}, {70, 80, 90}}); + CHECK(sr.offset == id<3>{0, 0, 0}); + CHECK(sr.range == range<3>{128, 128, 128}); + } + { + range_mapper rm{components(fixed(subrange<1>(19, 31)), fixed(subrange<1>(15, 44))), cl::sycl::access::mode::read, range<2>{128, 128}}; + auto sr = rm.map_2(chunk<3>{{1, 2, 3}, {40, 50, 60}, {70, 80, 90}}); + CHECK(sr.offset == id<2>{19, 15}); + CHECK(sr.range == range<2>{31, 44}); + } + { + range_mapper rm{components(kernel_dim(2), kernel_dim(0), kernel_dim(1)), cl::sycl::access::mode::read, range<3>{128, 128, 128}}; + auto sr = rm.map_3(chunk<3>{{1, 2, 3}, {40, 50, 60}, {70, 80, 90}}); + CHECK(sr.offset == id<3>{3, 1, 2}); + CHECK(sr.range == range<3>{60, 40, 50}); + } + { + range_mapper rm{components(kernel_dim(0), kernel_dim(0)), cl::sycl::access::mode::read, range<2>{128, 128}}; + auto sr = rm.map_2(chunk<1>{{1}, {40}, {70}}); + CHECK(sr.offset == id<2>{1, 1}); + CHECK(sr.range == range<2>{40, 40}); + } + } + TEST_CASE("even_split built-in range mapper behaves as expected", "[range-mapper]") { { range_mapper rm{even_split<3>(), cl::sycl::access::mode::read, range<3>{128, 345, 678}}; diff --git a/website/pages/en/index.js b/website/pages/en/index.js index daaa2501f..1c89f0ce1 100644 --- a/website/pages/en/index.js +++ b/website/pages/en/index.js @@ -120,9 +120,7 @@ int main() { distr_queue q; q.submit([&](handler &cgh) { // (2) specify data access patterns to enable distributed execution - accessor m(matrix, cgh, [size](chunk<1> chnk) { - return subrange<2>({chnk.offset[0], 0}, {chnk.range[0], size}); - }, read_only); + accessor m(matrix, cgh, access::components(access::kernel_dim(0), access::all()), read_only); accessor v(vector, cgh, access::one_to_one(), read_only); accessor r(result, cgh, access::one_to_one(), write_only, no_init);