diff --git a/include/simsycl/detail/parallel_for.hh b/include/simsycl/detail/parallel_for.hh index 3fea9d8..9ada8e9 100644 --- a/include/simsycl/detail/parallel_for.hh +++ b/include/simsycl/detail/parallel_for.hh @@ -64,64 +64,88 @@ template void cooperative_for_nd_range(const sycl::device &device, const sycl::nd_range &range, const std::vector &local_memory, const nd_kernel &kernel); -template -void execute_parallel_for( - const sycl::range &range, const Offset &offset, KernelFunc &&func, Params &&...args) { +template +void execute_parallel_for(const sycl::range &range, const Offset &offset, sycl::kernel_handler kh, + const KernelFunc &func, + Reducers &...reducers) // +{ register_kernel_on_static_construction(); - const simple_kernel> kernel( - [&](const sycl::item &item) { func(item, std::forward(args)...); }); + + simple_kernel> kernel; + if constexpr(std::is_invocable_v>, Reducers &..., + sycl::kernel_handler>) { + kernel = [&](const sycl::item &item) { func(item, reducers..., kh); }; + } else { + static_assert( + std::is_invocable_v>, Reducers &...>); + kernel = [&](const sycl::item &item) { func(item, reducers...); }; + } sequential_for(range, offset, kernel); } -template +template void execute_parallel_for(const sycl::device &device, const sycl::nd_range &range, - const std::vector &local_memory, KernelFunc &&func, Params &&...args) { - const nd_kernel kernel( - [&](const sycl::nd_item &item) { func(item, std::forward(args)...); }); + const std::vector &local_memory, sycl::kernel_handler kh, const KernelFunc &func, + Reducers &...reducers) // +{ register_kernel_on_static_construction(); + + nd_kernel kernel; + if constexpr(std::is_invocable_v, Reducers &..., + sycl::kernel_handler>) { + kernel = [&](const sycl::nd_item &item) { func(item, reducers..., kh); }; + } else { + static_assert(std::is_invocable_v, Reducers &...>); + kernel = [&](const sycl::nd_item &item) { func(item, reducers...); }; + } cooperative_for_nd_range(device, range, local_memory, kernel); } template -void execute_single_task(KernelFunc &&func) { +void execute_single_task(sycl::kernel_handler kh, KernelFunc &&func) { register_kernel_on_static_construction(); - func(); + if constexpr(std::is_invocable_v) { + func(kh); + } else { + static_assert(std::is_invocable_v); + func(); + } } template -void dispatch_parallel_for(const sycl::range &range, ParamTuple &¶ms, +void dispatch_parallel_for(const sycl::range &range, sycl::kernel_handler kh, ParamTuple &¶ms, std::index_sequence /* reduction_indices */, std::index_sequence /* kernel_index */) { auto &kernel_func = std::get(params); - execute_parallel_for(range, no_offset, kernel_func, std::get(params)...); + execute_parallel_for(range, no_offset, kh, kernel_func, std::get(params)...); } -template +template void dispatch_parallel_for(const sycl::device &device, const sycl::nd_range &range, - const std::vector &local_memory, ParamTuple &¶ms, + const std::vector &local_memory, sycl::kernel_handler kh, RestTuple &&rest, std::index_sequence /* reduction_indices */, std::index_sequence /* kernel_index */) { - const auto &kernel_func = std::get(params); - execute_parallel_for(device, range, local_memory, kernel_func, std::get(params)...); + const auto &kernel_func = std::get(rest); + execute_parallel_for(device, range, local_memory, kh, kernel_func, std::get(rest)...); } template 0), int> = 0> -void parallel_for(sycl::range num_work_items, Rest &&...rest) { - dispatch_parallel_for(num_work_items, std::forward_as_tuple(std::forward(rest)...), +void parallel_for(sycl::range num_work_items, sycl::kernel_handler kh, Rest &&...rest) { + dispatch_parallel_for(num_work_items, kh, std::forward_as_tuple(std::forward(rest)...), std::make_index_sequence(), std::index_sequence()); } template -void parallel_for( - sycl::range num_work_items, sycl::id work_item_offset, KernelFunc &&kernel_func) { - execute_parallel_for(num_work_items, work_item_offset, kernel_func); +void parallel_for(sycl::range num_work_items, sycl::id work_item_offset, + sycl::kernel_handler kh, const KernelFunc &kernel_func) { + execute_parallel_for(num_work_items, work_item_offset, kh, kernel_func); } template 0), int> = 0> void parallel_for(const sycl::device &device, sycl::nd_range execution_range, - const std::vector &local_memory, Rest &&...rest) { - detail::dispatch_parallel_for(device, execution_range, local_memory, + const std::vector &local_memory, sycl::kernel_handler kh, Rest &&...rest) { + detail::dispatch_parallel_for(device, execution_range, local_memory, kh, std::forward_as_tuple(std::forward(rest)...), std::make_index_sequence(), std::index_sequence()); } @@ -133,10 +157,20 @@ template template void parallel_for_work_group(const sycl::device &device, sycl::range num_work_groups, std::optional> work_group_size, const std::vector &local_memory, - const WorkgroupFunctionType &kernel_func) { + sycl::kernel_handler kh, const WorkgroupFunctionType &kernel_func) // +{ register_kernel_on_static_construction(); + + hierarchical_kernel kernel; + if constexpr(std::is_invocable_v, sycl::kernel_handler>) { + kernel = [&](const sycl::group &group) { kernel_func(group, kh); }; + } else { + static_assert(std::is_invocable_v>); + kernel = kernel_func; + } + const auto local_allocations = prepare_hierarchical_parallel_for(device, work_group_size, local_memory); - sequential_for_work_group(num_work_groups, work_group_size, hierarchical_kernel(kernel_func)); + sequential_for_work_group(num_work_groups, work_group_size, kernel); } } // namespace simsycl::detail diff --git a/include/simsycl/sycl/forward.hh b/include/simsycl/sycl/forward.hh index 620b323..4014e54 100644 --- a/include/simsycl/sycl/forward.hh +++ b/include/simsycl/sycl/forward.hh @@ -129,6 +129,9 @@ class sampled_image; template class sampled_image_accessor; +template +class specialization_id; + class stream; template diff --git a/include/simsycl/sycl/handler.hh b/include/simsycl/sycl/handler.hh index 96c6751..668355e 100644 --- a/include/simsycl/sycl/handler.hh +++ b/include/simsycl/sycl/handler.hh @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -72,43 +73,45 @@ class handler { template void single_task(const KernelType &kernel_func) { - detail::execute_single_task(kernel_func); + detail::execute_single_task(kernel_handler(this), kernel_func); } template requires(sizeof...(Rest) > 0) void parallel_for(size_t num_work_items, Rest &&...rest) { - detail::parallel_for(range<1>(num_work_items), std::forward(rest)...); + detail::parallel_for(range<1>(num_work_items), kernel_handler(this), std::forward(rest)...); } template requires(sizeof...(Rest) > 0 && Dimensions > 0) void parallel_for(range num_work_items, Rest &&...rest) { - detail::parallel_for(num_work_items, std::forward(rest)...); + detail::parallel_for(num_work_items, kernel_handler(this), std::forward(rest)...); } template SIMSYCL_DETAIL_DEPRECATED_IN_SYCL void parallel_for( range num_work_items, id work_item_offset, KernelType &&kernel_func) { - detail::parallel_for(num_work_items, work_item_offset, kernel_func); + detail::parallel_for(num_work_items, work_item_offset, kernel_handler(this), kernel_func); } template requires(sizeof...(Rest) > 0) void parallel_for(nd_range execution_range, Rest &&...rest) { - detail::parallel_for(m_device, execution_range, m_local_memory, std::forward(rest)...); + detail::parallel_for( + m_device, execution_range, m_local_memory, kernel_handler(this), std::forward(rest)...); } template void parallel_for_work_group(range num_work_groups, const WorkgroupFunctionType &kernel_func) { - detail::parallel_for_work_group(m_device, num_work_groups, {}, m_local_memory, kernel_func); + detail::parallel_for_work_group( + m_device, num_work_groups, {}, m_local_memory, kernel_handler(this), kernel_func); } template void parallel_for_work_group(range num_work_groups, range work_group_size, const WorkgroupFunctionType &kernel_func) { detail::parallel_for_work_group( - m_device, num_work_groups, {work_group_size}, m_local_memory, kernel_func); + m_device, num_work_groups, {work_group_size}, m_local_memory, kernel_handler(this), kernel_func); } void single_task(const kernel &kernel_object) { @@ -206,10 +209,23 @@ class handler { void use_kernel_bundle(const kernel_bundle &exec_bundle); template - void set_specialization_constant(typename std::remove_reference_t::value_type value); + void set_specialization_constant(typename std::remove_reference_t::value_type value) { + static_assert(detail::is_specialization_id_v>); + if(auto existing = find_specialization_constant(this, &SpecName)) { + *existing = value; + } else { + m_specialization_constants.emplace_back(&SpecName, value); + } + } template - typename std::remove_reference_t::value_type get_specialization_constant(); + typename std::remove_reference_t::value_type get_specialization_constant() const { + static_assert(detail::is_specialization_id_v>); + if(auto existing = find_specialization_constant(this, &SpecName)) { + return std::any_cast::value_type>(*existing); + } + return detail::get_specialization_default(SpecName); + } private: friend handler simsycl::detail::make_handler(const sycl::device &device); @@ -217,10 +233,26 @@ class handler { device m_device; std::vector m_local_memory; + std::vector> m_specialization_constants; explicit handler(const device &device) : m_device(device) {} + + static auto find_specialization_constant(auto self, const void *spec_id) + -> decltype(&self->m_specialization_constants[0].second) { + if(const auto it = std::find_if(self->m_specialization_constants.begin(), + self->m_specialization_constants.end(), [&](const auto &pair) { return pair.first == spec_id; }); + it != self->m_specialization_constants.end()) { + return &it->second; + } + return nullptr; + } }; +template +typename std::remove_reference_t::value_type kernel_handler::get_specialization_constant() { + return m_cgh->get_specialization_constant(); +} + } // namespace simsycl::sycl namespace simsycl::detail { diff --git a/include/simsycl/sycl/kernel.hh b/include/simsycl/sycl/kernel.hh index 9c53da7..04fe300 100644 --- a/include/simsycl/sycl/kernel.hh +++ b/include/simsycl/sycl/kernel.hh @@ -12,6 +12,63 @@ #include +namespace simsycl::detail { + +template +struct is_specialization_id : std::false_type {}; + +template +struct is_specialization_id> : std::true_type {}; + +template +inline constexpr bool is_specialization_id_v = is_specialization_id::value; + +template +const T &get_specialization_default(const sycl::specialization_id &id) { + return id.m_default_value; +} + +} // namespace simsycl::detail + +namespace simsycl::sycl { + +template +class specialization_id { + public: + using value_type = T; + + template + explicit constexpr specialization_id(Args &&...args) + requires(std::is_constructible_v) + : m_default_value(std::forward(args)...) {} + + specialization_id(const specialization_id &rhs) = delete; + specialization_id(specialization_id &&rhs) = delete; + specialization_id &operator=(const specialization_id &rhs) = delete; + specialization_id &operator=(specialization_id &&rhs) = delete; + + private: + template + friend const U &detail::get_specialization_default(const sycl::specialization_id &id); + + value_type m_default_value; +}; + +class kernel_handler { + public: + template + typename std::remove_reference_t::value_type get_specialization_constant(); + // implemented in handler.hh + + private: + friend class handler; + explicit kernel_handler(handler *cgh) : m_cgh(cgh) {} + + handler *m_cgh; +}; + +} // namespace simsycl::sycl + namespace simsycl::detail { struct kernel_id_state { @@ -226,12 +283,6 @@ kernel kernel_bundle::get_kernel() const return get_kernel(get_kernel_id()); } -class kernel_handler { - public: - template - typename std::remove_reference_t::value_type get_specialization_constant(); -}; - template kernel_id get_kernel_id() { return detail::get_kernel_id(typeid(KernelName *)); diff --git a/include/simsycl/sycl/queue.hh b/include/simsycl/sycl/queue.hh index 0b2e6fd..19540d0 100644 --- a/include/simsycl/sycl/queue.hh +++ b/include/simsycl/sycl/queue.hh @@ -112,25 +112,19 @@ class queue final : public detail::reference_type, template event single_task(const KernelFunc &kernel_func) { - auto status = detail::event_state::submit_and_start(); - detail::execute_single_task(kernel_func); - return status.end(); + return submit([&](handler &cgh) { cgh.single_task(kernel_func); }); } template event single_task(event dep_event, const KernelType &kernel_func) { (void)dep_event; - auto status = detail::event_state::submit_and_start(); - detail::execute_single_task(kernel_func); - return status.end(); + return submit([&](handler &cgh) { cgh.single_task(kernel_func); }); } template event single_task(const std::vector &dep_events, const KernelType &kernel_func) { (void)dep_events; - auto status = detail::event_state::submit_and_start(); - detail::execute_single_task(kernel_func); - return status.end(); + return submit([&](handler &cgh) { cgh.single_task(kernel_func); }); } template, template 0), int> = 0> event simple_parallel_for(range num_work_items, Rest &&...rest) { - auto status = detail::event_state::submit_and_start(); - simsycl::detail::parallel_for(num_work_items, std::forward(rest)...); - return status.end(); + return submit([&](handler &cgh) { cgh.parallel_for(num_work_items, std::forward(rest)...); }); } template 0), int> = 0> event parallel_for_nd_range(nd_range execution_range, Rest &&...rest) { - auto status = detail::event_state::submit_and_start(); - simsycl::detail::parallel_for(get_device(), execution_range, {}, std::forward(rest)...); - return status.end(); + return submit( + [&](handler &cgh) { cgh.parallel_for(execution_range, std::forward(rest)...); }); } };