diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c72a82ca..8ef8eff20 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ Versioning](http://semver.org/spec/v2.0.0.html). - Introduce new experimental `for_each_item` utility to iterate over a celerity range (#199) - Add new environment variables `CELERITY_HORIZON_STEP` and `CELERITY_HORIZON_MAX_PARALLELISM` to control Horizon generation (#199) +- Add new `experimental::constrain_split` API to limit how a kernel can be split (#?) ## Changed diff --git a/include/handler.h b/include/handler.h index 440119c36..4b6bb5820 100644 --- a/include/handler.h +++ b/include/handler.h @@ -31,6 +31,11 @@ namespace celerity { class handler; +namespace experimental { + template + void constrain_split(handler& cgh, const range& constraint); +} + namespace detail { class device_queue; class task_manager; @@ -353,7 +358,8 @@ class handler { */ template void host_task(range global_range, id global_offset, Functor&& kernel) { - const detail::task_geometry geometry{Dims, detail::range_cast<3>(global_range), detail::id_cast<3>(global_offset), {1, 1, 1}}; + const detail::task_geometry geometry{ + Dims, detail::range_cast<3>(global_range), detail::id_cast<3>(global_offset), get_constrained_granularity(range(detail::unit_range))}; auto launcher = make_host_task_launcher(detail::range_cast<3>(global_range), 0, std::forward(kernel)); create_host_compute_task(geometry, std::move(launcher)); } @@ -372,6 +378,8 @@ class handler { friend detail::hydration_id detail::add_requirement(handler& cgh, const detail::buffer_id bid, std::unique_ptr rm); friend void detail::add_requirement(handler& cgh, const detail::host_object_id hoid, const experimental::side_effect_order order, const bool is_void); friend void detail::add_reduction(handler& cgh, const detail::reduction_info& rinfo); + template + friend void experimental::constrain_split(handler& cgh, const range& constraint); friend void detail::extend_lifetime(handler& cgh, std::shared_ptr state); detail::task_id m_tid; @@ -383,6 +391,7 @@ class handler { size_t m_num_collective_nodes; detail::hydration_id m_next_accessor_hydration_id = 1; std::vector> m_attached_state; + range<3> m_split_constraint = detail::unit_range; handler(detail::task_id tid, size_t num_collective_nodes) : m_tid(tid), m_num_collective_nodes(num_collective_nodes) {} @@ -409,7 +418,8 @@ class handler { granularity[d] = local_range[d]; } } - const detail::task_geometry geometry{Dims, detail::range_cast<3>(global_range), detail::id_cast<3>(global_offset), granularity}; + const detail::task_geometry geometry{ + Dims, detail::range_cast<3>(global_range), detail::id_cast<3>(global_offset), get_constrained_granularity(detail::range_cast(granularity))}; auto launcher = make_device_kernel_launcher( global_range, global_offset, local_range, std::forward(kernel), std::index_sequence_for(), reductions...); create_device_compute_task(geometry, detail::kernel_debug_name(), std::move(launcher)); @@ -434,6 +444,23 @@ class handler { void extend_lifetime(std::shared_ptr state) { m_attached_state.emplace_back(std::move(state)); } + template + void experimental_constrain_split(const range& constraint) { + assert(m_task == nullptr); + m_split_constraint = detail::range_cast<3>(constraint); + } + + template + range<3> get_constrained_granularity(const range& granularity) const { + range<3> result = detail::range_cast<3>(granularity); + for(size_t i = 0; i < Dims; ++i) { + const auto lcm = std::lcm(granularity[i], m_split_constraint[i]); + if(lcm == 0) { throw std::runtime_error("Split constraint cannot be 0"); } + result[i] = lcm; + } + return result; + } + void create_host_compute_task(detail::task_geometry geometry, std::unique_ptr launcher) { assert(m_task == nullptr); if(geometry.global_size.size() == 0) { @@ -622,4 +649,11 @@ template return detail::reduction_impl(vars, cgh, identity, combiner, prop_list); } +namespace experimental { + template + void constrain_split(handler& cgh, const range& constraint) { + cgh.experimental_constrain_split(constraint); + } +} // namespace experimental + } // namespace celerity diff --git a/test/distributed_graph_generator_test_utils.h b/test/distributed_graph_generator_test_utils.h index c77fd14e8..36aac31d7 100644 --- a/test/distributed_graph_generator_test_utils.h +++ b/test/distributed_graph_generator_test_utils.h @@ -76,6 +76,11 @@ class task_builder { return chain([&host_obj, order](handler& cgh) { host_obj.add_side_effect(cgh, order); }); } + template + step constrain_split(const range& constraint) { + return chain([constraint](handler& cgh) { experimental::constrain_split(cgh, constraint); }); + } + private: dist_cdag_test_context& m_dctx; std::deque m_actions; diff --git a/test/graph_gen_granularity_tests.cc b/test/graph_gen_granularity_tests.cc index 9036f0677..1eaf04ad5 100644 --- a/test/graph_gen_granularity_tests.cc +++ b/test/graph_gen_granularity_tests.cc @@ -57,6 +57,23 @@ TEST_CASE("distributed_graph_generator respects task granularity when splitting" } } +TEST_CASE("distributed_graph_generator respects split constraints", "[distributed_graph_generator]") { + const size_t num_nodes = 4; + dist_cdag_test_context dctx(num_nodes); + + // Split constraints use the same underlying mechanisms as task granularity (tested above), so we'll keep this brief + const auto tid_a = dctx.device_compute(range<1>{128}).constrain_split(range<1>{64}).submit(); + REQUIRE(dctx.query(tid_a).count() == 2); + CHECK(dynamic_cast(dctx.query(tid_a).get_raw(0)[0])->get_execution_range().range == range<3>{64, 1, 1}); + CHECK(dynamic_cast(dctx.query(tid_a).get_raw(1)[0])->get_execution_range().range == range<3>{64, 1, 1}); + + // The more interesting aspect is that a constrained nd-range kernel uses the least common multiple of the two constraints + const auto tid_b = dctx.device_compute(nd_range<1>{{192}, {32}}).constrain_split(range<1>{3}).submit(); + REQUIRE(dctx.query(tid_b).count() == 2); + CHECK(dynamic_cast(dctx.query(tid_b).get_raw(0)[0])->get_execution_range().range == range<3>{96, 1, 1}); + CHECK(dynamic_cast(dctx.query(tid_b).get_raw(1)[0])->get_execution_range().range == range<3>{96, 1, 1}); +} + template class simple_task;