diff --git a/src/portfft/committed_descriptor.hpp b/src/portfft/committed_descriptor.hpp
new file mode 100644
index 00000000..cfa9e353
--- /dev/null
+++ b/src/portfft/committed_descriptor.hpp
@@ -0,0 +1,315 @@
+/***************************************************************************
+ *
+ *  Copyright (C) Codeplay Software Ltd.
+ *
+ *  Licensed under the Apache License, Version 2.0 (the "License");
+ *  you may not use this file except in compliance with the License.
+ *  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License.
+ *
+ *  Codeplay's portFFT
+ *
+ **************************************************************************/
+
+#ifndef PORTFFT_COMMITTED_DESCRIPTOR_HPP
+#define PORTFFT_COMMITTED_DESCRIPTOR_HPP
+
+#include <sycl/sycl.hpp>
+
+#include <complex>
+#include <vector>
+
+#include "enums.hpp"
+
+#include "committed_descriptor_impl.hpp"
+
+namespace portfft {
+
+template <typename Scalar, domain Domain>
+class committed_descriptor : private detail::committed_descriptor_impl<Scalar, Domain> {
+ public:
+  /**
+   * Alias for `Scalar`.
+   */
+  using scalar_type = Scalar;
+
+  /**
+   * std::complex with `Scalar` scalar.
+   */
+  using complex_type = std::complex<Scalar>;
+
+  // Use base class constructor
+  using detail::committed_descriptor_impl<Scalar, Domain>::committed_descriptor_impl;
+  // Use base class function without this->
+  using detail::committed_descriptor_impl<Scalar, Domain>::dispatch_direction;
+
+  /**
+   * Computes in-place forward FFT, working on a buffer.
+   *
+   * @param inout buffer containing input and output data
+   */
+  void compute_forward(sycl::buffer<complex_type, 1>& inout) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    // For now we can just call the out-of-place implementation.
+    // This might need to be changed once we implement support for large sizes that work in global memory.
+    compute_forward(inout, inout);
+  }
+
+  /**
+   * Computes in-place forward FFT, working on buffers.
+   *
+   * @param inout_real buffer containing the real part of the input and output data
+   * @param inout_imag buffer containing the imaginary part of the input and output data
+   */
+  void compute_forward(sycl::buffer<scalar_type, 1>& inout_real, sycl::buffer<scalar_type, 1>& inout_imag) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    // For now we can just call the out-of-place implementation.
+    // This might need to be changed once we implement support for large sizes that work in global memory.
+    compute_forward(inout_real, inout_imag, inout_real, inout_imag);
+  }
+
+  /**
+   * Computes in-place backward FFT, working on a buffer.
+   *
+   * @param inout buffer containing input and output data
+   */
+  void compute_backward(sycl::buffer<complex_type, 1>& inout) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    // For now we can just call the out-of-place implementation.
+    // This might need to be changed once we implement support for large sizes that work in global memory.
+    compute_backward(inout, inout);
+  }
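  // --------------------------------------------------------------------------------------------------------------
  // Reviewer note, not part of this patch: an illustrative sketch of driving the in-place buffer interface above.
  // It assumes `portfft::descriptor` exposes a `commit(sycl::queue&)` helper returning a `committed_descriptor`,
  // following the library's existing conventions; treat the setup as a hypothetical example, not verbatim test code.
  //
  //   std::vector<std::complex<float>> host_data(1024, {1.0F, 0.0F});
  //   portfft::descriptor<float, portfft::domain::COMPLEX> desc({1024});
  //   sycl::queue queue;
  //   auto committed = desc.commit(queue);
  //   {
  //     sycl::buffer<std::complex<float>, 1> inout(host_data.data(), sycl::range<1>(1024));
  //     committed.compute_forward(inout);  // forwards to the out-of-place overload
  //   }  // buffer destruction writes the results back to host_data
  // --------------------------------------------------------------------------------------------------------------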
+  /**
+   * Computes in-place backward FFT, working on buffers.
+   *
+   * @param inout_real buffer containing the real part of the input and output data
+   * @param inout_imag buffer containing the imaginary part of the input and output data
+   */
+  void compute_backward(sycl::buffer<scalar_type, 1>& inout_real, sycl::buffer<scalar_type, 1>& inout_imag) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    // For now we can just call the out-of-place implementation.
+    // This might need to be changed once we implement support for large sizes that work in global memory.
+    compute_backward(inout_real, inout_imag, inout_real, inout_imag);
+  }
+
+  /**
+   * Computes out-of-place forward FFT, working on buffers.
+   *
+   * @param in buffer containing input data
+   * @param out buffer containing output data
+   */
+  void compute_forward(const sycl::buffer<complex_type, 1>& in, sycl::buffer<complex_type, 1>& out) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::FORWARD);
+  }
+
+  /**
+   * Computes out-of-place forward FFT, working on buffers.
+   *
+   * @param in_real buffer containing the real part of the input data
+   * @param in_imag buffer containing the imaginary part of the input data
+   * @param out_real buffer containing the real part of the output data
+   * @param out_imag buffer containing the imaginary part of the output data
+   */
+  void compute_forward(const sycl::buffer<scalar_type, 1>& in_real, const sycl::buffer<scalar_type, 1>& in_imag,
+                       sycl::buffer<scalar_type, 1>& out_real, sycl::buffer<scalar_type, 1>& out_imag) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::FORWARD);
+  }
+
+  /**
+   * Computes out-of-place forward FFT, working on buffers (real-to-complex; not yet implemented).
+   *
+   * @param in buffer containing input data
+   * @param out buffer containing output data
+   */
+  void compute_forward(const sycl::buffer<Scalar, 1>& /*in*/, sycl::buffer<complex_type, 1>& /*out*/) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    throw unsupported_configuration("Real to complex FFTs not yet implemented.");
+  }
+
+  /**
+   * Computes out-of-place backward FFT, working on buffers.
+   *
+   * @param in buffer containing input data
+   * @param out buffer containing output data
+   */
+  void compute_backward(const sycl::buffer<complex_type, 1>& in, sycl::buffer<complex_type, 1>& out) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::BACKWARD);
+  }
+
+  /**
+   * Computes out-of-place backward FFT, working on buffers.
+   *
+   * @param in_real buffer containing the real part of the input data
+   * @param in_imag buffer containing the imaginary part of the input data
+   * @param out_real buffer containing the real part of the output data
+   * @param out_imag buffer containing the imaginary part of the output data
+   */
+  void compute_backward(const sycl::buffer<scalar_type, 1>& in_real, const sycl::buffer<scalar_type, 1>& in_imag,
+                        sycl::buffer<scalar_type, 1>& out_real, sycl::buffer<scalar_type, 1>& out_imag) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::BACKWARD);
+  }
+
+  /**
+   * Computes in-place forward FFT, working on USM memory.
+   *
+   * @param inout USM pointer to memory containing input and output data
+   * @param dependencies events that must complete before the computation
+   * @return sycl::event associated with this computation
+   */
+  sycl::event compute_forward(complex_type* inout, const std::vector<sycl::event>& dependencies = {}) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    // For now we can just call the out-of-place implementation.
+    // This might need to be changed once we implement support for large sizes that work in global memory.
+    return compute_forward(inout, inout, dependencies);
+  }
+  /**
+   * Computes in-place forward FFT, working on USM memory.
+   *
+   * @param inout_real USM pointer to memory containing the real part of the input and output data
+   * @param inout_imag USM pointer to memory containing the imaginary part of the input and output data
+   * @param dependencies events that must complete before the computation
+   * @return sycl::event associated with this computation
+   */
+  sycl::event compute_forward(scalar_type* inout_real, scalar_type* inout_imag,
+                              const std::vector<sycl::event>& dependencies = {}) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    // For now we can just call the out-of-place implementation.
+    // This might need to be changed once we implement support for large sizes that work in global memory.
+    return compute_forward(inout_real, inout_imag, inout_real, inout_imag, dependencies);
+  }
+
+  /**
+   * Computes in-place forward FFT, working on USM memory.
+   *
+   * @param inout USM pointer to memory containing input and output data
+   * @param dependencies events that must complete before the computation
+   * @return sycl::event associated with this computation
+   */
+  sycl::event compute_forward(Scalar* inout, const std::vector<sycl::event>& dependencies = {}) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    // For now we can just call the out-of-place implementation.
+    // This might need to be changed once we implement support for large sizes that work in global memory.
+    return compute_forward(inout, reinterpret_cast<complex_type*>(inout), dependencies);
+  }
+
+  /**
+   * Computes in-place backward FFT, working on USM memory.
+   *
+   * @param inout USM pointer to memory containing input and output data
+   * @param dependencies events that must complete before the computation
+   * @return sycl::event associated with this computation
+   */
+  sycl::event compute_backward(complex_type* inout, const std::vector<sycl::event>& dependencies = {}) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    return compute_backward(inout, inout, dependencies);
+  }
+
+  /**
+   * Computes in-place backward FFT, working on USM memory.
+   *
+   * @param inout_real USM pointer to memory containing the real part of the input and output data
+   * @param inout_imag USM pointer to memory containing the imaginary part of the input and output data
+   * @param dependencies events that must complete before the computation
+   * @return sycl::event associated with this computation
+   */
+  sycl::event compute_backward(scalar_type* inout_real, scalar_type* inout_imag,
+                               const std::vector<sycl::event>& dependencies = {}) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    return compute_backward(inout_real, inout_imag, inout_real, inout_imag, dependencies);
+  }
+
+  /**
+   * Computes out-of-place forward FFT, working on USM memory.
+   *
+   * @param in USM pointer to memory containing input data
+   * @param out USM pointer to memory containing output data
+   * @param dependencies events that must complete before the computation
+   * @return sycl::event associated with this computation
+   */
+  sycl::event compute_forward(const complex_type* in, complex_type* out,
+                              const std::vector<sycl::event>& dependencies = {}) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    return dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::FORWARD,
+                              dependencies);
+  }
+  /**
+   * Computes out-of-place forward FFT, working on USM memory.
+   *
+   * @param in_real USM pointer to memory containing the real part of the input data
+   * @param in_imag USM pointer to memory containing the imaginary part of the input data
+   * @param out_real USM pointer to memory containing the real part of the output data
+   * @param out_imag USM pointer to memory containing the imaginary part of the output data
+   * @param dependencies events that must complete before the computation
+   * @return sycl::event associated with this computation
+   */
+  sycl::event compute_forward(const scalar_type* in_real, const scalar_type* in_imag, scalar_type* out_real,
+                              scalar_type* out_imag, const std::vector<sycl::event>& dependencies = {}) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    return dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX,
+                              direction::FORWARD, dependencies);
+  }
+
+  /**
+   * Computes out-of-place forward FFT, working on USM memory (real-to-complex; not yet implemented).
+   *
+   * @param in USM pointer to memory containing input data
+   * @param out USM pointer to memory containing output data
+   * @param dependencies events that must complete before the computation
+   * @return sycl::event associated with this computation
+   */
+  sycl::event compute_forward(const Scalar* /*in*/, complex_type* /*out*/,
+                              const std::vector<sycl::event>& /*dependencies*/ = {}) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    throw unsupported_configuration("Real to complex FFTs not yet implemented.");
+    return {};
+  }
+
+  /**
+   * Computes out-of-place backward FFT, working on USM memory.
+   *
+   * @param in USM pointer to memory containing input data
+   * @param out USM pointer to memory containing output data
+   * @param dependencies events that must complete before the computation
+   * @return sycl::event associated with this computation
+   */
+  sycl::event compute_backward(const complex_type* in, complex_type* out,
+                               const std::vector<sycl::event>& dependencies = {}) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    return dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::BACKWARD,
+                              dependencies);
+  }
+
+  /**
+   * Computes out-of-place backward FFT, working on USM memory.
+   *
+   * @param in_real USM pointer to memory containing the real part of the input data
+   * @param in_imag USM pointer to memory containing the imaginary part of the input data
+   * @param out_real USM pointer to memory containing the real part of the output data
+   * @param out_imag USM pointer to memory containing the imaginary part of the output data
+   * @param dependencies events that must complete before the computation
+   * @return sycl::event associated with this computation
+   */
+  sycl::event compute_backward(const scalar_type* in_real, const scalar_type* in_imag, scalar_type* out_real,
+                               scalar_type* out_imag, const std::vector<sycl::event>& dependencies = {}) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    return dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX,
+                              direction::BACKWARD, dependencies);
+  }
+};
+
+}  // namespace portfft
+
+#endif
diff --git a/src/portfft/committed_descriptor_impl.hpp b/src/portfft/committed_descriptor_impl.hpp
new file mode 100644
index 00000000..e7d66328
--- /dev/null
+++ b/src/portfft/committed_descriptor_impl.hpp
@@ -0,0 +1,1299 @@
+/***************************************************************************
+ *
+ *  Copyright (C) Codeplay Software Ltd.
+ *
+ *  Licensed under the Apache License, Version 2.0 (the "License");
+ *  you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Codeplay's portFFT + * + **************************************************************************/ + +#ifndef PORTFFT_COMMITTED_DESCRIPTOR_IMPL_HPP +#define PORTFFT_COMMITTED_DESCRIPTOR_IMPL_HPP + +#include + +#include +#include +#include +#include +#include + +#include "common/exceptions.hpp" +#include "common/subgroup.hpp" +#include "defines.hpp" +#include "enums.hpp" +#include "specialization_constant.hpp" +#include "utils.hpp" + +namespace portfft { + +template +struct descriptor; + +namespace detail { + +template +class committed_descriptor_impl; + +template +std::vector compute_level( + const typename committed_descriptor_impl::kernel_data_struct& kd_struct, TIn input, Scalar* output, + TIn input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, + IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, + IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, + Idx total_factors, complex_storage storage, const std::vector& dependencies, sycl::queue& queue); + +template +sycl::event transpose_level(const typename committed_descriptor_impl::kernel_data_struct& kd_struct, + const Scalar* input, TOut output, const IdxGlobal* factors_triple, IdxGlobal committed_size, + Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx total_factors, + IdxGlobal output_offset, sycl::queue& queue, const std::vector& events, + complex_storage storage); + +// kernel names +// TODO: Remove all templates except Scalar, Domain and Memory and SubgroupSize +template +class workitem_kernel; +template +class subgroup_kernel; +template +class workgroup_kernel; +template +class global_kernel; +template +class transpose_kernel; + +/** + * Return the default strides for a given dft size + * + * @param lengths the dimensions of the dft + */ +inline std::vector get_default_strides(const std::vector& lengths) { + PORTFFT_LOG_FUNCTION_ENTRY(); + std::vector strides(lengths.size()); + std::size_t total_size = 1; + for (std::size_t i_plus1 = lengths.size(); i_plus1 > 0; i_plus1--) { + std::size_t i = i_plus1 - 1; + strides[i] = total_size; + total_size *= lengths[i]; + } + PORTFFT_LOG_TRACE("Default strides:", strides); + return strides; +} + +/** + * Return whether the given descriptor has default strides and distance for a given direction + * + * @tparam Descriptor Descriptor type + * @param desc Descriptor to check + * @param dir Direction + */ +template +bool has_default_strides_and_distance(const Descriptor& desc, direction dir) { + const auto default_strides = get_default_strides(desc.lengths); + const auto default_distance = desc.get_flattened_length(); + return desc.get_strides(dir) == default_strides && desc.get_distance(dir) == default_distance; +} + +/** + * Return whether the given descriptor has strides and distance consistent with the batch interleaved layout + * + * @tparam Descriptor Descriptor type + * @param desc Descriptor to check + * @param dir Direction + */ +template +bool is_batch_interleaved(const Descriptor& 
desc, direction dir) {
+  return desc.lengths.size() == 1 && desc.get_distance(dir) == 1 &&
+         desc.get_strides(dir).back() == desc.number_of_transforms;
+}
+
+/**
+ * Return an enum describing the layout of the data in the descriptor
+ *
+ * @tparam Descriptor Descriptor type
+ * @param desc Descriptor to check
+ * @param dir Direction
+ */
+template <typename Descriptor>
+detail::layout get_layout(const Descriptor& desc, direction dir) {
+  if (has_default_strides_and_distance(desc, dir)) {
+    return detail::layout::PACKED;
+  }
+  if (is_batch_interleaved(desc, dir)) {
+    return detail::layout::BATCH_INTERLEAVED;
+  }
+  return detail::layout::UNPACKED;
+}
+
+/*
+Compute functions in `committed_descriptor_impl` call `dispatch_kernel` and `dispatch_kernel_helper`. These two
+functions ensure the kernel is run with a supported subgroup size. Next, `dispatch_kernel_helper` calls `run_kernel`.
+The `run_kernel` member function picks the appropriate implementation and calls the static `run_kernel` of that
+implementation. The implementation-specific `run_kernel` handles the differences between forward and backward
+computations, casts the memory (USM or buffers) from complex to scalars and launches the kernel. Each function
+described in this doc has only one templated overload that handles both directions of transforms and both buffer and
+USM memory.
+
+Device functions make no assumptions about the size of a workgroup or the number of workgroups in a kernel. These
+numbers can be tuned for each device.
+
+Implementation-specific `run_kernel` functions make the size of the FFT that is handled by the individual workitems
+a compile-time constant. The one for the subgroup implementation also calls `cross_sg_dispatcher`, which makes the
+cross-subgroup factor of the FFT size a compile-time constant. They do that by using a switch on the FFT size for one
+workitem before calling `workitem_impl`, `subgroup_impl` or `workgroup_impl`. The `_impl` functions take the FFT size
+for one workitem as a template parameter. Only the calls that are determined to fit into the available registers
+(depending on the value of the PORTFFT_TARGET_REGS_PER_WI macro) are actually instantiated.
+
+The `_impl` functions iterate over the batch of problems, loading the data for each problem first into local memory
+and from there into private memory. This is done in two steps to avoid non-coalesced global memory accesses.
+`workitem_impl` loads one problem per workitem, `subgroup_impl` one problem per subgroup and `workgroup_impl` one
+problem per workgroup. After the computation, done by calls to `wi_dft` for workitem, `sg_dft` for subgroup and
+`wg_dft` for workgroup, the data is written out, going through local memory again.
+
+The computational parts of the implementations are further documented in the files with their implementations:
+`workitem.hpp`, `subgroup.hpp` and `workgroup.hpp`.
+*/
+
+/**
+ * A committed descriptor that contains everything that is needed to run FFT.
+ * + * @tparam Scalar type of the scalar used for computations + * @tparam Domain domain of the FFT + */ +template +class committed_descriptor_impl { + friend struct descriptor; + template + friend std::vector detail::compute_level( + const typename committed_descriptor_impl::kernel_data_struct& kd_struct, TIn input, + Scalar1* output, TIn input_imag, Scalar1* output_imag, const Scalar1* twiddles_ptr, + const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, + IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, + IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage, + const std::vector& dependencies, sycl::queue& queue); + + template + friend sycl::event detail::transpose_level( + const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const Scalar1* input, + TOut output, const IdxGlobal* factors_triple, IdxGlobal committed_size, Idx num_batches_in_l2, + IdxGlobal n_transforms, IdxGlobal batch_start, Idx total_factors, IdxGlobal output_offset, sycl::queue& queue, + const std::vector& events, complex_storage storage); + + /** + * Vector containing the sub-implementation level, kernel_ids and factors for each factor that requires a separate + * kernel. + */ + using kernel_ids_and_metadata_t = + std::vector, std::vector>>; + descriptor params; + sycl::queue queue; + sycl::device dev; + sycl::context ctx; + Idx n_compute_units; + std::vector supported_sg_sizes; + Idx local_memory_size; + IdxGlobal llc_size; + std::shared_ptr scratch_ptr_1; + std::shared_ptr scratch_ptr_2; + std::size_t scratch_space_required; + + struct kernel_data_struct { + sycl::kernel_bundle exec_bundle; + std::vector factors; + std::size_t length; + Idx used_sg_size; + Idx num_sgs_per_wg; + std::shared_ptr twiddles_forward; + detail::level level; + IdxGlobal batch_size; + std::size_t local_mem_required; + IdxGlobal global_range; + IdxGlobal local_range; + + kernel_data_struct(sycl::kernel_bundle&& exec_bundle, + const std::vector& factors, std::size_t length, Idx used_sg_size, Idx num_sgs_per_wg, + std::shared_ptr twiddles_forward, detail::level level) + : exec_bundle(std::move(exec_bundle)), + factors(factors), + length(length), + used_sg_size(used_sg_size), + num_sgs_per_wg(num_sgs_per_wg), + twiddles_forward(twiddles_forward), + level(level) {} + }; + + struct dimension_struct { + std::vector forward_kernels; + std::vector backward_kernels; + std::vector transpose_kernels; + std::shared_ptr factors_and_scan; + detail::level level; + std::size_t length; + Idx used_sg_size; + Idx num_batches_in_l2; + Idx num_factors; + + dimension_struct(std::vector forward_kernels, std::vector backward_kernels, + detail::level level, std::size_t length, Idx used_sg_size) + : forward_kernels(std::move(forward_kernels)), + backward_kernels(std::move(backward_kernels)), + level(level), + length(length), + used_sg_size(used_sg_size) {} + }; + + std::vector dimensions; + + template + auto dispatch(detail::level level, Args&&... 
args) { + switch (level) { + case detail::level::WORKITEM: + return Impl::template inner::execute(*this, args...); + case detail::level::SUBGROUP: + return Impl::template inner::execute(*this, args...); + case detail::level::WORKGROUP: + return Impl::template inner::execute(*this, args...); + case detail::level::GLOBAL: + return Impl::template inner::execute(*this, args...); + default: + // This should be unreachable + throw unsupported_configuration("Unimplemented"); + } + } + + template + auto dispatch(detail::level level, Args&&... args) { + switch (level) { + case detail::level::WORKITEM: + return Impl::template inner::execute(*this, args...); + case detail::level::SUBGROUP: + return Impl::template inner::execute(*this, args...); + case detail::level::WORKGROUP: + return Impl::template inner::execute(*this, args...); + case detail::level::GLOBAL: + return Impl::template inner::execute(*this, args...); + default: + // This should be unreachable + throw unsupported_configuration("Unimplemented"); + } + } + + template + auto dispatch(detail::level level, Args&&... args) { + switch (level) { + case detail::level::WORKITEM: + return Impl::template inner::execute(*this, + args...); + case detail::level::SUBGROUP: + return Impl::template inner::execute(*this, + args...); + case detail::level::WORKGROUP: + return Impl::template inner::execute( + *this, args...); + case detail::level::GLOBAL: + return Impl::template inner::execute(*this, + args...); + default: + // This should be unreachable + throw unsupported_configuration("Unimplemented"); + } + } + + /** + * Prepares the implementation for the particular problem size. That includes factorizing it and getting ids for the + * set of kernels that need to be JIT compiled. + * + * @tparam SubgroupSize size of the subgroup + * @param kernel_num the consecutive number of the kernel to prepare + * @return implementation to use for the dimension and a vector of tuples of: implementation to use for a kernel, + * vector of kernel ids, factors + */ + template + std::tuple prepare_implementation(std::size_t kernel_num) { + PORTFFT_LOG_FUNCTION_ENTRY(); + // TODO: check and support all the parameter values + if constexpr (Domain != domain::COMPLEX) { + throw unsupported_configuration("portFFT only supports complex to complex transforms"); + } + + std::vector ids; + std::vector factors; + IdxGlobal fft_size = static_cast(params.lengths[kernel_num]); + if (detail::fits_in_wi(fft_size)) { + ids = detail::get_ids(); + PORTFFT_LOG_TRACE("Prepared workitem impl for size: ", fft_size); + return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, factors}}}; + } + if (detail::fits_in_sg(fft_size, SubgroupSize)) { + Idx factor_sg = detail::factorize_sg(static_cast(fft_size), SubgroupSize); + Idx factor_wi = static_cast(fft_size) / factor_sg; + // This factorization is duplicated in the dispatch logic on the device. + // The CT and spec constant factors should match. 
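      // ------------------------------------------------------------------------------------------------------------
      // Aside (not part of this patch): conceptually, the subgroup factorization above picks a divisor of the FFT
      // size that fits in the subgroup, so `factor_sg` workitems each own a `factor_wi`-point sub-DFT. A minimal
      // sketch of that idea; the real detail::factorize_sg may use a different search strategy:
      //
      //   Idx factorize_sg_sketch(Idx fft_size, Idx subgroup_size) {
      //     for (Idx candidate = subgroup_size; candidate > 1; --candidate) {
      //       if (fft_size % candidate == 0) {
      //         return candidate;  // largest divisor not exceeding the subgroup size
      //       }
      //     }
      //     return 1;  // e.g. large primes: no cross-subgroup factor exists
      //   }
      //
      // For fft_size = 48 and subgroup_size = 16 this gives factor_sg = 16 and factor_wi = 48 / 16 = 3.
      // ------------------------------------------------------------------------------------------------------------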
+ factors.push_back(factor_wi); + factors.push_back(factor_sg); + ids = detail::get_ids(); + PORTFFT_LOG_TRACE("Prepared subgroup impl with factor_wi:", factor_wi, "and factor_sg:", factor_sg); + return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, factors}}}; + } + IdxGlobal n_idx_global = detail::factorize(fft_size); + if (detail::can_cast_safely(n_idx_global) && + detail::can_cast_safely(fft_size / n_idx_global)) { + if (n_idx_global == 1) { + throw unsupported_configuration("FFT size ", fft_size, " : Large Prime sized FFT currently is unsupported"); + } + Idx n = static_cast(n_idx_global); + Idx m = static_cast(fft_size / n_idx_global); + Idx factor_sg_n = detail::factorize_sg(n, SubgroupSize); + Idx factor_wi_n = n / factor_sg_n; + Idx factor_sg_m = detail::factorize_sg(m, SubgroupSize); + Idx factor_wi_m = m / factor_sg_m; + Idx temp_num_sgs_in_wg; + std::size_t local_memory_usage = num_scalars_in_local_mem( + detail::level::WORKGROUP, static_cast(fft_size), SubgroupSize, + {factor_sg_n, factor_wi_n, factor_sg_m, factor_wi_m}, temp_num_sgs_in_wg) * + sizeof(Scalar); + // Checks for PACKED layout only at the moment, as the other layout will not be supported + // by the global implementation. For such sizes, only PACKED layout will be supported + if (detail::fits_in_wi(factor_wi_n) && detail::fits_in_wi(factor_wi_m) && + (local_memory_usage <= static_cast(local_memory_size))) { + factors.push_back(factor_wi_n); + factors.push_back(factor_sg_n); + factors.push_back(factor_wi_m); + factors.push_back(factor_sg_m); + // This factorization of N and M is duplicated in the dispatch logic on the device. + // The CT and spec constant factors should match. + ids = detail::get_ids(); + PORTFFT_LOG_TRACE("Prepared workgroup impl with factor_wi_n:", factor_wi_n, " factor_sg_n:", factor_sg_n, + " factor_wi_m:", factor_wi_m, " factor_sg_m:", factor_sg_m); + return {detail::level::WORKGROUP, {{detail::level::WORKGROUP, ids, factors}}}; + } + } + PORTFFT_LOG_TRACE("Preparing global impl"); + std::vector, std::vector>> param_vec; + auto check_and_select_target_level = [&](IdxGlobal factor_size, bool batch_interleaved_layout = true) -> bool { + if (detail::fits_in_wi(factor_size)) { + // Throughout we have assumed there would always be enough local memory for the WI implementation. 
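        // ----------------------------------------------------------------------------------------------------------
        // Aside (not part of this patch): the subgroup feasibility test that follows is a byte-budget check against
        // device local memory. A simplified model of the arithmetic (assuming, as in the code below, that the
        // batch-interleaved layout keeps twice the data resident and that the factor's twiddles are stored as
        // interleaved complex values):
        //
        //   bool fits_sg_local_mem_sketch(std::size_t scalars_in_local_mem, std::size_t factor_size,
        //                                 std::size_t local_mem_bytes, bool batch_interleaved) {
        //     std::size_t data_bytes = scalars_in_local_mem * sizeof(float);
        //     if (batch_interleaved) {
        //       data_bytes *= 2;  // two batches resident at once
        //     }
        //     std::size_t twiddle_bytes = 2 * factor_size * sizeof(float);  // complex twiddles
        //     return data_bytes + twiddle_bytes < local_mem_bytes;
        //   }
        // ----------------------------------------------------------------------------------------------------------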
+ param_vec.emplace_back(detail::level::WORKITEM, + detail::get_ids(), + std::vector{static_cast(factor_size)}); + PORTFFT_LOG_TRACE("Workitem kernel for factor:", factor_size); + return true; + } + bool fits_in_local_memory_subgroup = [&]() { + Idx temp_num_sgs_in_wg; + IdxGlobal factor_sg = detail::factorize_sg(factor_size, SubgroupSize); + IdxGlobal factor_wi = factor_size / factor_sg; + if (detail::can_cast_safely(factor_sg) && detail::can_cast_safely(factor_wi)) { + if (batch_interleaved_layout) { + return (2 * + num_scalars_in_local_mem( + detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, + {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg) * + sizeof(Scalar) + + 2 * static_cast(factor_size) * sizeof(Scalar)) < + static_cast(local_memory_size); + } + return (num_scalars_in_local_mem( + detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, + {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg) * + sizeof(Scalar) + + 2 * static_cast(factor_size) * sizeof(Scalar)) < + static_cast(local_memory_size); + } + return false; + }(); + if (detail::fits_in_sg(factor_size, SubgroupSize) && fits_in_local_memory_subgroup && + !PORTFFT_SLOW_SG_SHUFFLES) { + Idx factor_sg = detail::factorize_sg(static_cast(factor_size), SubgroupSize); + Idx factor_wi = static_cast(factor_size) / factor_sg; + PORTFFT_LOG_TRACE("Subgroup kernel for factor:", factor_size, "with factor_wi:", factor_wi, + "and factor_sg:", factor_sg); + param_vec.emplace_back(detail::level::SUBGROUP, + detail::get_ids(), + std::vector{factor_sg, factor_wi}); + return true; + } + return false; + }; + detail::factorize_input(fft_size, check_and_select_target_level); + return {detail::level::GLOBAL, param_vec}; + } + + /** + * Struct for dispatching `set_spec_constants()` call. 
+   */
+  struct set_spec_constants_struct {
+    // Dummy parameter is needed as only partial specializations are allowed without specializing the containing class
+    template <detail::level Lev, typename Dummy>
+    struct inner {
+      static void execute(committed_descriptor_impl& desc, sycl::kernel_bundle<sycl::bundle_state::input>& in_bundle,
+                          std::size_t length, const std::vector<Idx>& factors, detail::level level, Idx factor_num,
+                          Idx num_factors);
+    };
+  };
+
+  /**
+   * Sets the implementation-dependent specialization constant values.
+   *
+   * @param top_level implementation to dispatch to
+   * @param in_bundle input kernel bundle to set spec constants for
+   * @param length length of the FFT
+   * @param factors factors of the corresponding length
+   * @param multiply_on_load whether the input data is multiplied with some data array before the FFT computation
+   * @param multiply_on_store whether the input data is multiplied with some data array after the FFT computation
+   * @param scale_factor_applied whether or not to apply the scale factor
+   * @param level sub-implementation to run, which will be set as a spec constant
+   * @param conjugate_on_load whether or not to take the conjugate of the input
+   * @param conjugate_on_store whether or not to take the conjugate of the output
+   * @param scale_factor scale to be applied to the result
+   * @param factor_num factor number which is set as a spec constant
+   * @param num_factors total number of factors of the committed size, set as a spec constant
+   */
+  void set_spec_constants(detail::level top_level, sycl::kernel_bundle<sycl::bundle_state::input>& in_bundle,
+                          std::size_t length, const std::vector<Idx>& factors,
+                          detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store,
+                          detail::apply_scale_factor scale_factor_applied, detail::level level,
+                          detail::complex_conjugate conjugate_on_load, detail::complex_conjugate conjugate_on_store,
+                          Scalar scale_factor, Idx factor_num = 0, Idx num_factors = 0) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    const Idx length_idx = static_cast<Idx>(length);
+    // These spec constants are used in all implementations, so we set them here
+    PORTFFT_LOG_TRACE("Setting specialization constants:");
+    PORTFFT_LOG_TRACE("SpecConstComplexStorage:", params.complex_storage);
+    in_bundle.template set_specialization_constant<detail::SpecConstComplexStorage>(params.complex_storage);
+    PORTFFT_LOG_TRACE("SpecConstNumRealsPerFFT:", 2 * length_idx);
+    in_bundle.template set_specialization_constant<detail::SpecConstNumRealsPerFFT>(2 * length_idx);
+    PORTFFT_LOG_TRACE("SpecConstWIScratchSize:", 2 * detail::wi_temps(length_idx));
+    in_bundle.template set_specialization_constant<detail::SpecConstWIScratchSize>(2 * detail::wi_temps(length_idx));
+    PORTFFT_LOG_TRACE("SpecConstMultiplyOnLoad:", multiply_on_load);
+    in_bundle.template set_specialization_constant<detail::SpecConstMultiplyOnLoad>(multiply_on_load);
+    PORTFFT_LOG_TRACE("SpecConstMultiplyOnStore:", multiply_on_store);
+    in_bundle.template set_specialization_constant<detail::SpecConstMultiplyOnStore>(multiply_on_store);
+    PORTFFT_LOG_TRACE("SpecConstApplyScaleFactor:", scale_factor_applied);
+    in_bundle.template set_specialization_constant<detail::SpecConstApplyScaleFactor>(scale_factor_applied);
+    PORTFFT_LOG_TRACE("SpecConstConjugateOnLoad:", conjugate_on_load);
+    in_bundle.template set_specialization_constant<detail::SpecConstConjugateOnLoad>(conjugate_on_load);
+    PORTFFT_LOG_TRACE("SpecConstConjugateOnStore:", conjugate_on_store);
+    in_bundle.template set_specialization_constant<detail::SpecConstConjugateOnStore>(conjugate_on_store);
+    PORTFFT_LOG_TRACE("get_spec_constant_scale:", scale_factor);
+    in_bundle.template set_specialization_constant<detail::get_spec_constant_scale<Scalar>()>(scale_factor);
+
+    dispatch<set_spec_constants_struct>(top_level, in_bundle, length, factors, level, factor_num, num_factors);
+  }
+
+  /**
+   * Struct for dispatching `num_scalars_in_local_mem()` call.
+   */
+  struct num_scalars_in_local_mem_struct {
+    // Dummy parameter is needed as only partial specializations are allowed without specializing the containing class
+    template <detail::level Lev, detail::layout LayoutIn, typename Dummy>
+    struct inner {
+      static std::size_t execute(committed_descriptor_impl& desc, std::size_t length, Idx used_sg_size,
+                                 const std::vector<Idx>& factors, Idx& num_sgs_per_wg);
+    };
+  };
+
+  /**
+   * Determine the number of scalars we need to have space for in the local memory. It may also modify
+   * `num_sgs_per_wg` to make the problem fit in the local memory.
+   *
+   * @tparam LayoutIn layout of the input data
+   * @param level the implementation that will be used
+   * @param length length of the FFT the kernel will execute
+   * @param used_sg_size subgroup size the kernel will use
+   * @param factors factorization of the FFT size the kernel will use
+   * @param[out] num_sgs_per_wg number of subgroups in a workgroup
+   * @return the number of scalars
+   */
+  template <detail::layout LayoutIn>
+  std::size_t num_scalars_in_local_mem(detail::level level, std::size_t length, Idx used_sg_size,
+                                       const std::vector<Idx>& factors, Idx& num_sgs_per_wg) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    return dispatch<num_scalars_in_local_mem_struct, LayoutIn>(level, length, used_sg_size, factors, num_sgs_per_wg);
+  }
+
+  /**
+   * Struct for dispatching `calculate_twiddles()` call.
+   */
+  struct calculate_twiddles_struct {
+    // Dummy parameter is needed as only partial specializations are allowed without specializing the containing class
+    template <detail::level Lev, typename Dummy>
+    struct inner {
+      static Scalar* execute(committed_descriptor_impl& desc, dimension_struct& dimension_data,
+                             std::vector<kernel_data_struct>& kernels);
+    };
+  };
+
+  /**
+   * Calculates twiddle factors for the implementation in use.
+   *
+   * @param level implementation selected for the committed size
+   * @param dimension_data dimension_struct corresponding to the dimension for which twiddles are being calculated
+   * @param kernels vector of kernels
+   * @return Scalar* USM pointer to the twiddle factors
+   */
+  Scalar* calculate_twiddles(detail::level level, dimension_struct& dimension_data,
+                             std::vector<kernel_data_struct>& kernels) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    return dispatch<calculate_twiddles_struct>(level, dimension_data, kernels);
+  }
+
+  /**
+   * Sets the specialization constants for all the kernel_ids contained in the vector returned from
+   * prepare_implementation.
+   *
+   * @tparam SubgroupSize subgroup size
+   * @param top_level selected level of implementation
+   * @param prepared_vec vector of tuples of: implementation to use for a kernel, vector of kernel ids, factors
+   * @param compute_direction direction of compute: forward or backward
+   * @param dimension_num which dimension the kernels are being built for
+   * @param skip_scaling whether or not to skip scaling
+   * @return vector of kernel_data_struct if all kernel builds are successful, std::nullopt otherwise
+   */
+  template <Idx SubgroupSize>
+  std::optional<std::vector<kernel_data_struct>> set_spec_constants_driver(detail::level top_level,
+                                                                           kernel_ids_and_metadata_t& prepared_vec,
+                                                                           direction compute_direction,
+                                                                           std::size_t dimension_num,
+                                                                           bool skip_scaling) {
+    Scalar scale_factor = compute_direction == direction::FORWARD ?
params.forward_scale : params.backward_scale; + detail::apply_scale_factor scale_factor_applied = detail::apply_scale_factor::APPLIED; + bool is_compatible = true; + if (skip_scaling) { + scale_factor_applied = detail::apply_scale_factor::NOT_APPLIED; + } + std::size_t counter = 0; + auto conjugate_on_load = detail::complex_conjugate::NOT_APPLIED; + auto conjugate_on_store = detail::complex_conjugate::NOT_APPLIED; + std::vector result; + for (auto& [level, ids, factors] : prepared_vec) { + auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), ids); + if (top_level == detail::level::GLOBAL) { + std::size_t factor_size = + static_cast(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())); + if (counter == prepared_vec.size() - 1) { + if (compute_direction == direction::BACKWARD) { + conjugate_on_store = detail::complex_conjugate::APPLIED; + } + set_spec_constants(detail::level::GLOBAL, in_bundle, factor_size, factors, + detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED, + detail::apply_scale_factor::APPLIED, level, conjugate_on_load, conjugate_on_store, + scale_factor, static_cast(counter), static_cast(prepared_vec.size())); + // reset conjugate_on_store + conjugate_on_store = detail::complex_conjugate::NOT_APPLIED; + } else { + if (counter == 0 && compute_direction == direction::BACKWARD) { + conjugate_on_load = detail::complex_conjugate::APPLIED; + } + set_spec_constants(detail::level::GLOBAL, in_bundle, factor_size, factors, + detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::APPLIED, + detail::apply_scale_factor::NOT_APPLIED, level, conjugate_on_load, conjugate_on_store, + scale_factor, static_cast(counter), static_cast(prepared_vec.size())); + // reset conjugate_on_load + conjugate_on_load = detail::complex_conjugate::NOT_APPLIED; + } + } else { + if (compute_direction == direction::BACKWARD) { + conjugate_on_load = detail::complex_conjugate::APPLIED; + conjugate_on_store = detail::complex_conjugate::APPLIED; + } + set_spec_constants(level, in_bundle, params.lengths[dimension_num], factors, + detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED, + scale_factor_applied, level, conjugate_on_load, conjugate_on_store, scale_factor); + } + try { + PORTFFT_LOG_TRACE("Building kernel bundle with subgroup size", SubgroupSize); + result.emplace_back(sycl::build(in_bundle), factors, params.lengths[dimension_num], SubgroupSize, + PORTFFT_SGS_IN_WG, std::shared_ptr(), level); + PORTFFT_LOG_TRACE("Kernel bundle build complete."); + } catch (std::exception& e) { + PORTFFT_LOG_WARNING("Build for subgroup size", SubgroupSize, "failed with message:\n", e.what()); + is_compatible = false; + break; + } + counter++; + } + if (is_compatible) { + return result; + } + return std::nullopt; + } + + /** + * Builds the kernel bundles with appropriate values of specialization constants for the first supported subgroup + * size. 
+ * + * @tparam SubgroupSize first subgroup size + * @tparam OtherSGSizes other subgroup sizes + * @param dimension_num The dimension for which the kernels are being built + * @param skip_scaling whether or not to skip scaling + * @return `dimension_struct` for the newly built kernels + */ + template + dimension_struct build_w_spec_const(std::size_t dimension_num, bool skip_scaling) { + PORTFFT_LOG_FUNCTION_ENTRY(); + if (std::count(supported_sg_sizes.begin(), supported_sg_sizes.end(), SubgroupSize)) { + auto [top_level, prepared_vec] = prepare_implementation(dimension_num); + bool is_compatible = true; + for (auto [level, ids, factors] : prepared_vec) { + is_compatible = is_compatible && sycl::is_compatible(ids, dev); + if (!is_compatible) { + break; + } + } + + if (is_compatible) { + auto forward_kernels = set_spec_constants_driver(top_level, prepared_vec, direction::FORWARD, + dimension_num, skip_scaling); + auto backward_kernels = set_spec_constants_driver(top_level, prepared_vec, direction::BACKWARD, + dimension_num, skip_scaling); + if (forward_kernels.has_value() && backward_kernels.has_value()) { + return {forward_kernels.value(), backward_kernels.value(), top_level, params.lengths[dimension_num], + SubgroupSize}; + } + } + } + if constexpr (sizeof...(OtherSGSizes) == 0) { + throw invalid_configuration("None of the compiled subgroup sizes are supported by the device"); + } else { + return build_w_spec_const(dimension_num, skip_scaling); + } + } + + /** + * Function which calculates the amount of scratch space required, and also pre computes the necessary scans required. + * @param num_global_level_dimensions number of global level dimensions in the committed size + */ + void allocate_scratch_and_precompute_scan(Idx num_global_level_dimensions) { + PORTFFT_LOG_FUNCTION_ENTRY(); + std::size_t n_kernels = params.lengths.size(); + if (num_global_level_dimensions == 1) { + std::size_t global_dimension = 0; + for (std::size_t i = 0; i < n_kernels; i++) { + if (dimensions.at(i).level == detail::level::GLOBAL) { + global_dimension = i; + break; + } + } + std::vector factors; + std::vector sub_batches; + std::vector inclusive_scan; + std::size_t cache_required_for_twiddles = 0; + for (const auto& kernel_data : dimensions.at(global_dimension).forward_kernels) { + IdxGlobal factor_size = static_cast( + std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies())); + cache_required_for_twiddles += + static_cast(2 * factor_size * kernel_data.batch_size) * sizeof(Scalar); + factors.push_back(factor_size); + sub_batches.push_back(kernel_data.batch_size); + } + dimensions.at(global_dimension).num_factors = static_cast(factors.size()); + std::size_t cache_space_left_for_batches = static_cast(llc_size) - cache_required_for_twiddles; + // TODO: In case of multi-dim (single dim global sized), this should be batches corresponding to that dim + dimensions.at(global_dimension).num_batches_in_l2 = static_cast(std::min( + static_cast(PORTFFT_MAX_CONCURRENT_KERNELS), + std::min(params.number_of_transforms, + std::max(std::size_t(1), cache_space_left_for_batches / + (2 * dimensions.at(global_dimension).length * sizeof(Scalar)))))); + scratch_space_required = 2 * dimensions.at(global_dimension).length * + static_cast(dimensions.at(global_dimension).num_batches_in_l2); + PORTFFT_LOG_TRACE("Allocating 2 scratch arrays of size", scratch_space_required, "scalars in global memory"); + scratch_ptr_1 = detail::make_shared(scratch_space_required, queue); + scratch_ptr_2 = 
detail::make_shared(scratch_space_required, queue); + inclusive_scan.push_back(factors.at(0)); + for (std::size_t i = 1; i < factors.size(); i++) { + inclusive_scan.push_back(inclusive_scan.at(i - 1) * factors.at(i)); + } + PORTFFT_LOG_TRACE("Dimension:", global_dimension, + "num_batches_in_l2:", dimensions.at(global_dimension).num_batches_in_l2, + "scan:", inclusive_scan); + dimensions.at(global_dimension).factors_and_scan = + detail::make_shared(factors.size() + sub_batches.size() + inclusive_scan.size(), queue); + queue.copy(factors.data(), dimensions.at(global_dimension).factors_and_scan.get(), factors.size()); + queue.copy(sub_batches.data(), dimensions.at(global_dimension).factors_and_scan.get() + factors.size(), + sub_batches.size()); + queue.copy(inclusive_scan.data(), + dimensions.at(global_dimension).factors_and_scan.get() + factors.size() + sub_batches.size(), + inclusive_scan.size()); + queue.wait(); + // build transpose kernels + std::size_t num_transposes_required = factors.size() - 1; + for (std::size_t i = 0; i < num_transposes_required; i++) { + std::vector ids; + auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), + detail::get_transpose_kernel_ids()); + PORTFFT_LOG_TRACE("Setting specialization constants for transpose kernel", i); + PORTFFT_LOG_TRACE("SpecConstComplexStorage:", params.complex_storage); + in_bundle.template set_specialization_constant(params.complex_storage); + PORTFFT_LOG_TRACE("GlobalSpecConstLevelNum:", i); + in_bundle.template set_specialization_constant(static_cast(i)); + PORTFFT_LOG_TRACE("GlobalSpecConstNumFactors:", factors.size()); + in_bundle.template set_specialization_constant( + static_cast(factors.size())); + dimensions.at(global_dimension) + .transpose_kernels.emplace_back( + sycl::build(in_bundle), + std::vector{static_cast(factors.at(i)), static_cast(sub_batches.at(i))}, 1, 1, 1, + std::shared_ptr(), detail::level::GLOBAL); + } + } else { + std::size_t max_encountered_global_size = 0; + for (std::size_t i = 0; i < n_kernels; i++) { + if (dimensions.at(i).level == detail::level::GLOBAL) { + max_encountered_global_size = max_encountered_global_size > dimensions.at(i).length + ? max_encountered_global_size + : dimensions.at(i).length; + } + } + // TODO: max_scratch_size should be max(global_size_1 * corresponding_batches_in_l2, global_size_1 * + // corresponding_batches_in_l2), in the case of multi-dim global FFTs. 
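      // ------------------------------------------------------------------------------------------------------------
      // Aside (not part of this patch), a worked example of the sizing below: a single-precision 2^24-point
      // global-level FFT with number_of_transforms = 4 needs 2 * 2^24 * 4 = 2^27 scalars, i.e. 512 MiB per scratch
      // array, and two such arrays are allocated for ping-ponging between factors.
      // ------------------------------------------------------------------------------------------------------------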
+ scratch_space_required = 2 * max_encountered_global_size * params.number_of_transforms; + scratch_ptr_1 = detail::make_shared(scratch_space_required, queue); + scratch_ptr_2 = detail::make_shared(scratch_space_required, queue); + for (std::size_t i = 0; i < n_kernels; i++) { + if (dimensions.at(i).level == detail::level::GLOBAL) { + std::vector factors; + std::vector sub_batches; + std::vector inclusive_scan; + for (const auto& kernel_data : dimensions.at(i).forward_kernels) { + IdxGlobal factor_size = static_cast( + std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies())); + factors.push_back(factor_size); + sub_batches.push_back(kernel_data.batch_size); + } + inclusive_scan.push_back(factors.at(0)); + for (std::size_t j = 1; j < factors.size(); j++) { + inclusive_scan.push_back(inclusive_scan.at(j - 1) * factors.at(j)); + } + dimensions.at(i).num_factors = static_cast(factors.size()); + dimensions.at(i).factors_and_scan = + detail::make_shared(factors.size() + sub_batches.size() + inclusive_scan.size(), queue); + queue.copy(factors.data(), dimensions.at(i).factors_and_scan.get(), factors.size()); + queue.copy(sub_batches.data(), dimensions.at(i).factors_and_scan.get() + factors.size(), sub_batches.size()); + queue.copy(inclusive_scan.data(), + dimensions.at(i).factors_and_scan.get() + factors.size() + sub_batches.size(), + inclusive_scan.size()); + queue.wait(); + // build transpose kernels + std::size_t num_transposes_required = factors.size() - 1; + for (std::size_t j = 0; j < num_transposes_required; j++) { + auto in_bundle = sycl::get_kernel_bundle( + queue.get_context(), detail::get_transpose_kernel_ids()); + PORTFFT_LOG_TRACE("Setting specilization constants for transpose kernel", j); + PORTFFT_LOG_TRACE("GlobalSpecConstLevelNum:", i); + in_bundle.template set_specialization_constant(static_cast(i)); + PORTFFT_LOG_TRACE("GlobalSpecConstNumFactors:", factors.size()); + in_bundle.template set_specialization_constant( + static_cast(factors.size())); + dimensions.at(i).transpose_kernels.emplace_back( + sycl::build(in_bundle), + std::vector{static_cast(factors.at(j)), static_cast(sub_batches.at(j))}, 1, 1, 1, + std::shared_ptr(), detail::level::GLOBAL); + } + } + } + } + } + + /** + * Constructor. 
+   *
+   * @param params descriptor this is created from
+   * @param queue queue to use when enqueueing device work
+   */
+  committed_descriptor_impl(const descriptor<Scalar, Domain>& params, sycl::queue& queue)
+      : params(params),
+        queue(queue),
+        dev(queue.get_device()),
+        ctx(queue.get_context()),
+        // get some properties we will use for tuning
+        n_compute_units(static_cast<Idx>(dev.get_info<sycl::info::device::max_compute_units>())),
+        supported_sg_sizes(dev.get_info<sycl::info::device::sub_group_sizes>()),
+        local_memory_size(static_cast<Idx>(queue.get_device().get_info<sycl::info::device::local_mem_size>())),
+        llc_size(static_cast<IdxGlobal>(queue.get_device().get_info<sycl::info::device::global_mem_cache_size>())) {
+    PORTFFT_LOG_FUNCTION_ENTRY();
+    PORTFFT_LOG_TRACE("Device info:");
+    PORTFFT_LOG_TRACE("n_compute_units:", n_compute_units);
+    PORTFFT_LOG_TRACE("supported_sg_sizes:", supported_sg_sizes);
+    PORTFFT_LOG_TRACE("local_memory_size:", local_memory_size);
+    PORTFFT_LOG_TRACE("llc_size:", llc_size);
+
+    // check it's suitable to run
+    const auto forward_layout = detail::get_layout(params, direction::FORWARD);
+    const auto backward_layout = detail::get_layout(params, direction::BACKWARD);
+    if (params.lengths.size() > 1) {
+      const bool supported_layout =
+          forward_layout == detail::layout::PACKED && backward_layout == detail::layout::PACKED;
+      if (!supported_layout) {
+        throw unsupported_configuration("Multi-dimensional transforms are only supported with default data layout");
+      }
+    } else {
+      const bool supported_layout =
+          (forward_layout == detail::layout::PACKED || forward_layout == detail::layout::BATCH_INTERLEAVED) &&
+          (backward_layout == detail::layout::PACKED || backward_layout == detail::layout::BATCH_INTERLEAVED);
+      if (!supported_layout) {
+        throw unsupported_configuration("Arbitrary strides are not supported");
+      }
+    }
+
+    // compile the kernels and precalculate twiddles
+    std::size_t n_kernels = params.lengths.size();
+    for (std::size_t i = 0; i < n_kernels; i++) {
+      bool skip_scaling = true;
+      if (i == n_kernels - 1) {
+        skip_scaling = false;
+      }
+      dimensions.emplace_back(build_w_spec_const<PORTFFT_SUBGROUP_SIZES>(i, skip_scaling));
+      dimensions.back().forward_kernels.at(0).twiddles_forward = std::shared_ptr<Scalar>(
+          calculate_twiddles(dimensions.back().level, dimensions.at(i), dimensions.back().forward_kernels),
+          [queue](Scalar* ptr) {
+            if (ptr != nullptr) {
+              sycl::free(ptr, queue);
+            }
+          });
+      // TODO: refactor multi-dimensional FFTs so that they can use a single pointer for twiddles (see the sketch
+      // below).
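      // ------------------------------------------------------------------------------------------------------------
      // Aside (not part of this patch): the twiddle lifetime management above, shown in isolation. A sketch of tying
      // a USM allocation to a shared_ptr with sycl::free as the deleter; `make_usm_shared` is a hypothetical helper,
      // not a portFFT API:
      //
      //   template <typename T>
      //   std::shared_ptr<T> make_usm_shared(std::size_t count, sycl::queue q) {
      //     T* raw = sycl::malloc_device<T>(count, q);
      //     return std::shared_ptr<T>(raw, [q](T* ptr) {
      //       if (ptr != nullptr) {
      //         sycl::free(ptr, q);
      //       }
      //     });
      //   }
      //
      // Capturing the queue by value keeps it (and its context) alive until the last shared_ptr copy is destroyed,
      // which is why the deleters here capture `queue` rather than `this`.
      // ------------------------------------------------------------------------------------------------------------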
+ dimensions.back().backward_kernels.at(0).twiddles_forward = std::shared_ptr( + calculate_twiddles(dimensions.back().level, dimensions.at(i), dimensions.back().backward_kernels), + [queue](Scalar* ptr) { + if (ptr != nullptr) { + PORTFFT_LOG_TRACE("Freeing the array for twiddle factors"); + sycl::free(ptr, queue); + } + }); + } + + bool is_scratch_required = false; + Idx num_global_level_dimensions = 0; + for (std::size_t i = 0; i < n_kernels; i++) { + if (dimensions.at(i).level == detail::level::GLOBAL) { + is_scratch_required = true; + num_global_level_dimensions++; + } + } + if (num_global_level_dimensions != 0) { + if (params.lengths.size() > 1) { + throw unsupported_configuration("Only 1D FFTs that do not fit in local memory are supported"); + } + if (params.get_distance(direction::FORWARD) != params.lengths[0] || + params.get_distance(direction::BACKWARD) != params.lengths[0]) { + throw unsupported_configuration("Large FFTs are currently only supported in non-strided format"); + } + } + + if (is_scratch_required) { + allocate_scratch_and_precompute_scan(num_global_level_dimensions); + } + } + + /** + * Utility function for copy constructor and copy assignment operator + * @param desc `committed_descriptor_impl` of which the copy is to be made + */ + void create_copy(const committed_descriptor_impl& desc) { + PORTFFT_LOG_FUNCTION_ENTRY(); +#define PORTFFT_COPY(x) this->x = desc.x; + PORTFFT_COPY(params) + PORTFFT_COPY(queue) + PORTFFT_COPY(dev) + PORTFFT_COPY(ctx) + PORTFFT_COPY(n_compute_units) + PORTFFT_COPY(supported_sg_sizes) + PORTFFT_COPY(local_memory_size) + PORTFFT_COPY(dimensions) + PORTFFT_COPY(scratch_space_required) + PORTFFT_COPY(llc_size) +#undef PORTFFT_COPY + + bool is_scratch_required = false; + for (std::size_t i = 0; i < desc.dimensions.size(); i++) { + if (desc.dimensions.at(i).level == detail::level::GLOBAL) { + is_scratch_required = true; + break; + } + } + if (is_scratch_required) { + PORTFFT_LOG_TRACE("Allocating 2 scratch arrays of size", desc.scratch_space_required, "Scalars in global memory"); + this->scratch_ptr_1 = + detail::make_shared(static_cast(desc.scratch_space_required), this->queue); + this->scratch_ptr_2 = + detail::make_shared(static_cast(desc.scratch_space_required), this->queue); + } + } + + public: + committed_descriptor_impl(const committed_descriptor_impl& desc) : params(desc.params) { // TODO params copied twice + PORTFFT_LOG_FUNCTION_ENTRY(); + create_copy(desc); + } + + committed_descriptor_impl& operator=(const committed_descriptor_impl& desc) { + PORTFFT_LOG_FUNCTION_ENTRY(); + if (this != &desc) { + create_copy(desc); + } + return *this; + } + + static_assert(std::is_same_v || std::is_same_v, + "Scalar must be either float or double!"); + + /** + * Destructor + */ + ~committed_descriptor_impl() { + PORTFFT_LOG_FUNCTION_ENTRY(); + queue.wait(); + } + + // default construction is not appropriate + committed_descriptor_impl() = delete; + + protected: + /** + * Dispatches to the implementation for the appropriate direction. + * + * @tparam TIn Type of the input buffer or USM pointer + * @tparam TOut Type of the output buffer or USM pointer + * @param in buffer or USM pointer to memory containing input data. Real part of input data if + * `descriptor.complex_storage` is split. + * @param out buffer or USM pointer to memory containing output data. Real part of input data if + * `descriptor.complex_storage` is split. + * @param in_imag buffer or USM pointer to memory containing imaginary part of the input data. 
Ignored if + * `descriptor.complex_storage` is interleaved. + * @param out_imag buffer or USM pointer to memory containing imaginary part of the output data. Ignored if + * `descriptor.complex_storage` is interleaved. + * @param used_storage how components of a complex value are stored - either split or interleaved + * @param compute_direction direction of compute, forward / backward + * @param dependencies events that must complete before the computation + * @return sycl::event + */ + template + sycl::event dispatch_direction(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, + complex_storage used_storage, direction compute_direction, + const std::vector& dependencies = {}) { + PORTFFT_LOG_FUNCTION_ENTRY(); +#ifndef PORTFFT_ENABLE_BUFFER_BUILDS + if constexpr (!std::is_pointer_v || !std::is_pointer_v) { + throw invalid_configuration("Buffer interface can not be called when buffer builds are disabled."); + } +#endif + if (used_storage != params.complex_storage) { + if (used_storage == complex_storage::SPLIT_COMPLEX) { + throw invalid_configuration( + "To use interface with split real and imaginary memory, descriptor.complex_storage must be set to " + "SPLIT_COMPLEX."); + } + throw invalid_configuration( + "To use interface with interleaved real and imaginary values, descriptor.complex_storage must be set to " + "INTERLEAVED_COMPLEX."); + } + if (compute_direction == direction::FORWARD) { + return dispatch_dimensions(in, out, in_imag, out_imag, dependencies, params.forward_strides, + params.backward_strides, params.forward_distance, params.backward_distance, + params.forward_offset, params.backward_offset, compute_direction); + } + return dispatch_dimensions(in, out, in_imag, out_imag, dependencies, params.backward_strides, + params.forward_strides, params.backward_distance, params.forward_distance, + params.backward_offset, params.forward_offset, compute_direction); + } + + /** + * Dispatches to the implementation for the appropriate number of dimensions. + * + * @tparam TIn Type of the input buffer or USM pointer + * @tparam TOut Type of the output buffer or USM pointer + * @param in buffer or USM pointer to memory containing input data. Real part of input data if + * `descriptor.complex_storage` is split. + * @param out buffer or USM pointer to memory containing output data. Real part of input data if + * `descriptor.complex_storage` is split. + * @param in_imag buffer or USM pointer to memory containing imaginary part of the input data. Ignored if + * `descriptor.complex_storage` is interleaved. + * @param out_imag buffer or USM pointer to memory containing imaginary part of the output data. Ignored if + * `descriptor.complex_storage` is interleaved. 
+ * @param dependencies events that must complete before the computation + * @param input_strides strides between input elements for each dimension of one FFT + * @param output_strides strides between output elements for each dimension of one FFT + * @param input_distance distance between the starts of input data for two consecutive FFTs + * @param output_distance distance between the starts of output data for two consecutive FFTs + * @param input_offset offset into input allocation where the data for FFTs start + * @param output_offset offset into output allocation where the data for FFTs start + * @param compute_direction direction of compute, forward / backward + * @return sycl::event + */ + template + sycl::event dispatch_dimensions(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, + const std::vector& dependencies, + const std::vector& input_strides, + const std::vector& output_strides, std::size_t input_distance, + std::size_t output_distance, std::size_t input_offset, std::size_t output_offset, + direction compute_direction) { + PORTFFT_LOG_FUNCTION_ENTRY(); + using TOutConst = std::conditional_t, const std::remove_pointer_t*, const TOut>; + std::size_t n_dimensions = params.lengths.size(); + std::size_t total_size = params.get_flattened_length(); + + const auto forward_layout = detail::get_layout(params, direction::FORWARD); + const auto backward_layout = detail::get_layout(params, direction::BACKWARD); + + // currently multi-dimensional transforms are implemented just for default (PACKED) data layout + const bool multi_dim_supported = + forward_layout == detail::layout::PACKED && backward_layout == detail::layout::PACKED; + if (n_dimensions != 1 && !multi_dim_supported) { + throw internal_error("Only default layout is supported for multi-dimensional transforms."); + } + + // product of sizes of all dimension inner relative to the one we are currently working on + std::size_t inner_size = 1; + // product of sizes of all dimension outer relative to the one we are currently working on + std::size_t outer_size = total_size / params.lengths.back(); + std::size_t input_stride_0 = input_strides.back(); + std::size_t output_stride_0 = output_strides.back(); + // distances are currently used just in the first dimension - these changes are meant for that one + // TODO fix this to support non-default layouts + if (input_stride_0 < input_distance) { // for example: batch interleaved input + input_distance = params.lengths.back(); + } + if (output_stride_0 < output_distance) { // for example: batch interleaved output + output_distance = params.lengths.back(); + } + + PORTFFT_LOG_TRACE("Dispatching the kernel for the last dimension"); + sycl::event previous_event = + dispatch_kernel_1d(in, out, in_imag, out_imag, dependencies, params.number_of_transforms * outer_size, + input_stride_0, output_stride_0, input_distance, output_distance, input_offset, + output_offset, dimensions.back(), compute_direction); + if (n_dimensions == 1) { + return previous_event; + } + std::vector previous_events{previous_event}; + std::vector next_events; + inner_size *= params.lengths.back(); + for (std::size_t i = n_dimensions - 2; i != static_cast(-1); i--) { + outer_size /= params.lengths[i]; + // TODO do everything from the next loop in a single kernel once we support more than one distance in the + // kernels. 
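      // ------------------------------------------------------------------------------------------------------------
      // Aside (not part of this patch), tracing the loop below for lengths {4, 8, 16} and one transform:
      //   inner_size = product of lengths of dimensions already transformed,
      //   outer_size = product of lengths of dimensions not yet transformed.
      // Each dimension-i pass launches number_of_transforms * outer_size kernels, one per contiguous slab of
      // inner_size * lengths[i] elements, and each kernel performs inner_size strided FFTs of length lengths[i].
      // Here: i = 1 gives 4 launches (slabs of 8 * 16 = 128 elements, 16 FFTs of length 8 each), then i = 0 gives
      // 1 launch (slab of 4 * 8 * 16 = 512 elements, 128 FFTs of length 4).
      // ------------------------------------------------------------------------------------------------------------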
std::size_t stride_between_kernels = inner_size * params.lengths[i]; + PORTFFT_LOG_TRACE("Dispatching the kernels for the dimension", i); + for (std::size_t j = 0; j < params.number_of_transforms * outer_size; j++) { + sycl::event e = dispatch_kernel_1d( + out, out, out_imag, out_imag, previous_events, inner_size, inner_size, inner_size, 1, 1, + output_offset + j * stride_between_kernels, output_offset + j * stride_between_kernels, dimensions[i], + compute_direction); + next_events.push_back(e); + } + inner_size *= params.lengths[i]; + std::swap(previous_events, next_events); + next_events.clear(); + } + return queue.single_task(previous_events, []() {}); // just to get an event that depends on all previous ones + } + + /** + * Dispatches the kernel with the first subgroup size that is supported by the device. + * + * @tparam TIn Type of the input buffer or USM pointer + * @tparam TOut Type of the output buffer or USM pointer + * @param in buffer or USM pointer to memory containing input data. Real part of the input data if + * `descriptor.complex_storage` is split. + * @param out buffer or USM pointer to memory containing output data. Real part of the output data if + * `descriptor.complex_storage` is split. + * @param in_imag buffer or USM pointer to memory containing imaginary part of the input data. Ignored if + * `descriptor.complex_storage` is interleaved. + * @param out_imag buffer or USM pointer to memory containing imaginary part of the output data. Ignored if + * `descriptor.complex_storage` is interleaved. + * @param dependencies events that must complete before the computation + * @param n_transforms number of FFT transforms to do in one call + * @param input_stride stride between input elements of one FFT + * @param output_stride stride between output elements of one FFT + * @param input_distance distance between the starts of input data for two consecutive FFTs + * @param output_distance distance between the starts of output data for two consecutive FFTs + * @param input_offset offset into input allocation where the data for FFTs start + * @param output_offset offset into output allocation where the data for FFTs start + * @param dimension_data data for the dimension this call will work on + * @param compute_direction direction of compute, forward / backward + * @return sycl::event + */ + template + sycl::event dispatch_kernel_1d(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, + const std::vector& dependencies, std::size_t n_transforms, + std::size_t input_stride, std::size_t output_stride, std::size_t input_distance, + std::size_t output_distance, std::size_t input_offset, std::size_t output_offset, + dimension_struct& dimension_data, direction compute_direction) { + PORTFFT_LOG_FUNCTION_ENTRY(); + return dispatch_kernel_1d_helper( + in, out, in_imag, out_imag, dependencies, n_transforms, input_stride, output_stride, input_distance, + output_distance, input_offset, output_offset, dimension_data, compute_direction); + } + + /** + * Helper for dispatching the kernel with the first subgroup size that is supported by the device. + * + * @tparam TIn Type of the input buffer or USM pointer + * @tparam TOut Type of the output buffer or USM pointer + * @tparam SubgroupSize first subgroup size + * @tparam OtherSGSizes other subgroup sizes + * @param in buffer or USM pointer to memory containing input data. Real part of the input data if + * `descriptor.complex_storage` is split. + * @param out buffer or USM pointer to memory containing output data. Real part of the output data if + * `descriptor.complex_storage` is split. + * @param in_imag buffer or USM pointer to memory containing imaginary part of the input data. Ignored if + * `descriptor.complex_storage` is interleaved. + * @param out_imag buffer or USM pointer to memory containing imaginary part of the output data. Ignored if + * `descriptor.complex_storage` is interleaved. + * @param dependencies events that must complete before the computation + * @param n_transforms number of FFT transforms to do in one call + * @param input_stride stride between input elements of one FFT + * @param output_stride stride between output elements of one FFT + * @param input_distance distance between the starts of input data for two consecutive FFTs + * @param output_distance distance between the starts of output data for two consecutive FFTs + * @param input_offset offset into input allocation where the data for FFTs start + * @param output_offset offset into output allocation where the data for FFTs start + * @param dimension_data data for the dimension this call will work on + * @param compute_direction direction of compute, forward / backward + * @return sycl::event + */ + template + sycl::event dispatch_kernel_1d_helper(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, + const std::vector& dependencies, std::size_t n_transforms, + std::size_t input_stride, std::size_t output_stride, std::size_t input_distance, + std::size_t output_distance, std::size_t input_offset, + std::size_t output_offset, dimension_struct& dimension_data, + direction compute_direction) { + PORTFFT_LOG_FUNCTION_ENTRY(); + if (SubgroupSize == dimension_data.used_sg_size) { + const bool input_packed = input_distance == dimension_data.length && input_stride == 1; + const bool output_packed = output_distance == dimension_data.length && output_stride == 1; + const bool input_batch_interleaved = input_distance == 1 && input_stride == n_transforms; + const bool output_batch_interleaved = output_distance == 1 && output_stride == n_transforms; + for (kernel_data_struct kernel_data : dimension_data.forward_kernels) { + std::size_t minimum_local_mem_required; + if (input_batch_interleaved) { + minimum_local_mem_required = num_scalars_in_local_mem( + kernel_data.level, kernel_data.length, SubgroupSize, kernel_data.factors, + kernel_data.num_sgs_per_wg) * + sizeof(Scalar); + PORTFFT_LOG_TRACE("Local mem required:", minimum_local_mem_required, "B. Available: ", local_memory_size, + "B."); + if (static_cast(minimum_local_mem_required) > local_memory_size) { + throw out_of_local_memory_error( + "Insufficient amount of local memory available: " + std::to_string(local_memory_size) + + "B. 
Required: " + std::to_string(minimum_local_mem_required) + "B."); + } + } + } + if (input_packed && output_packed) { + return run_kernel( + in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, dimension_data, + compute_direction); + } + if (input_batch_interleaved && output_packed && in != out) { + return run_kernel( + in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, dimension_data, + compute_direction); + } + if (input_packed && output_batch_interleaved && in != out) { + return run_kernel( + in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, dimension_data, + compute_direction); + } + if (input_batch_interleaved && output_batch_interleaved) { + return run_kernel( + in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, dimension_data, + compute_direction); + } + throw unsupported_configuration("Only PACKED or BATCH_INTERLEAVED transforms are supported"); + } + if constexpr (sizeof...(OtherSGSizes) == 0) { + throw invalid_configuration("None of the compiled subgroup sizes are supported by the device!"); + } else { + return dispatch_kernel_1d_helper( + in, out, in_imag, out_imag, dependencies, n_transforms, input_stride, output_stride, input_distance, + output_distance, input_offset, output_offset, dimension_data, compute_direction); + } + } + + /** + * Struct for dispatching `run_kernel()` call. + * + * @tparam LayoutIn Input Layout + * @tparam LayoutOut Output Layout + * @tparam SubgroupSize size of the subgroup + * @tparam TIn Type of the input USM pointer or buffer + * @tparam TOut Type of the output USM pointer or buffer + */ + template + struct run_kernel_struct { + // Dummy parameter is needed as only partial specializations are allowed without specializing the containing class + template + struct inner { + static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag, + TOut& out_imag, const std::vector& dependencies, std::size_t n_transforms, + std::size_t forward_offset, std::size_t backward_offset, + dimension_struct& dimension_data, direction compute_direction); + }; + }; + + /** + * Common interface, called by compute_forward and compute_backward, to run the kernel. + * + * @tparam LayoutIn Input Layout + * @tparam LayoutOut Output Layout + * @tparam SubgroupSize size of the subgroup + * @tparam TIn Type of the input USM pointer or buffer + * @tparam TOut Type of the output USM pointer or buffer + * @param in buffer or USM pointer to memory containing input data. Real part of the input data if + * `descriptor.complex_storage` is split. + * @param out buffer or USM pointer to memory containing output data. Real part of the output data if + * `descriptor.complex_storage` is split. + * @param in_imag buffer or USM pointer to memory containing imaginary part of the input data. Ignored if + * `descriptor.complex_storage` is interleaved. + * @param out_imag buffer or USM pointer to memory containing imaginary part of the output data. Ignored if + * `descriptor.complex_storage` is interleaved.
+ * @param dependencies events that must complete before the computation + * @param n_transforms number of FFT transforms to do in one call + * @param input_offset offset into input allocation where the data for FFTs start + * @param output_offset offset into output allocation where the data for FFTs start + * @param dimension_data data for the dimension this call will work on + * @param compute_direction direction of compute, forward / backward + * @return sycl::event + */ + template + sycl::event run_kernel(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, + const std::vector& dependencies, std::size_t n_transforms, + std::size_t input_offset, std::size_t output_offset, dimension_struct& dimension_data, + direction compute_direction) { + PORTFFT_LOG_FUNCTION_ENTRY(); + // mixing const and non-const inputs leads to hard-to-debug linking errors, as both use the same kernel name, but + // are called from different template instantiations. + static_assert(!std::is_pointer_v || std::is_const_v>, + "We do not differentiate kernel names between kernels with const and non-const USM inputs, so all " + "should be const"); + // kernel names currently assume both are the same. Mixing them without adding TOut to kernel names would lead to + // hard-to-debug linking errors + static_assert(std::is_pointer_v == std::is_pointer_v, + "Both input and output to the kernels should be the same - either buffers or USM"); + using TInReinterpret = decltype(detail::reinterpret(in)); + using TOutReinterpret = decltype(detail::reinterpret(out)); + std::size_t vec_multiplier = params.complex_storage == complex_storage::INTERLEAVED_COMPLEX ? 2 : 1; + return dispatch>( + dimension_data.level, detail::reinterpret(in), detail::reinterpret(out), + detail::reinterpret(in_imag), detail::reinterpret(out_imag), dependencies, + static_cast(n_transforms), static_cast(vec_multiplier * input_offset), + static_cast(vec_multiplier * output_offset), dimension_data, compute_direction); + } +}; + +} // namespace detail +} // namespace portfft + +#endif diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index bc40e55b..7798a789 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -431,7 +431,7 @@ static void dispatch_transpose_kernel_impl(const Scalar* input, Scalar* output, * @return sycl::event */ template -sycl::event transpose_level(const typename committed_descriptor::kernel_data_struct& kd_struct, +sycl::event transpose_level(const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const Scalar* input, TOut output, const IdxGlobal* factors_triple, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx total_factors, IdxGlobal output_offset, sycl::queue& queue, const std::vector& events, @@ -502,11 +502,12 @@ sycl::event transpose_level(const typename committed_descriptor: template std::vector compute_level( - const typename committed_descriptor::kernel_data_struct& kd_struct, const TIn input, Scalar* output, - const TIn input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, - IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, - IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, - Idx total_factors, complex_storage storage, const std::vector& dependencies, sycl::queue& queue) { + const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const TIn
input, + Scalar* output, const TIn input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, + const IdxGlobal* factors_triple, IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, + IdxGlobal input_global_offset, IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, + IdxGlobal batch_start, Idx factor_id, Idx total_factors, complex_storage storage, + const std::vector& dependencies, sycl::queue& queue) { PORTFFT_LOG_FUNCTION_ENTRY(); IdxGlobal local_range = kd_struct.local_range; IdxGlobal global_range = kd_struct.global_range; diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index ab41cc33..030a2a0d 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -24,1542 +24,15 @@ #include #include -#include -#include #include #include -#include "common/exceptions.hpp" -#include "common/subgroup.hpp" #include "defines.hpp" #include "enums.hpp" -#include "specialization_constant.hpp" -#include "utils.hpp" -namespace portfft { -template -class committed_descriptor; -namespace detail { -template -std::vector compute_level( - const typename committed_descriptor::kernel_data_struct& kd_struct, TIn input, Scalar* output, - TIn input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, const IdxGlobal* factors_triple, - IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, - IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, - Idx total_factors, complex_storage storage, const std::vector& dependencies, sycl::queue& queue); - -template -sycl::event transpose_level(const typename committed_descriptor::kernel_data_struct& kd_struct, - const Scalar* input, TOut output, const IdxGlobal* factors_triple, IdxGlobal committed_size, - Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx total_factors, - IdxGlobal output_offset, sycl::queue& queue, const std::vector& events, - complex_storage storage); - -// kernel names -// TODO: Remove all templates except Scalar, Domain and Memory and SubgroupSize -template -class workitem_kernel; -template -class subgroup_kernel; -template -class workgroup_kernel; -template -class global_kernel; -template -class transpose_kernel; - -/** - * Return the default strides for a given dft size - * - * @param lengths the dimensions of the dft - */ -inline std::vector get_default_strides(const std::vector& lengths) { - PORTFFT_LOG_FUNCTION_ENTRY(); - std::vector strides(lengths.size()); - std::size_t total_size = 1; - for (std::size_t i_plus1 = lengths.size(); i_plus1 > 0; i_plus1--) { - std::size_t i = i_plus1 - 1; - strides[i] = total_size; - total_size *= lengths[i]; - } - PORTFFT_LOG_TRACE("Default strides:", strides); - return strides; -} - -/** - * Return whether the given descriptor has default strides and distance for a given direction - * - * @tparam Descriptor Descriptor type - * @param desc Descriptor to check - * @param dir Direction - */ -template -bool has_default_strides_and_distance(const Descriptor& desc, direction dir) { - const auto default_strides = get_default_strides(desc.lengths); - const auto default_distance = desc.get_flattened_length(); - return desc.get_strides(dir) == default_strides && desc.get_distance(dir) == default_distance; -} - -/** - * Return whether the given descriptor has strides and distance consistent with the batch interleaved layout - * - * @tparam Descriptor Descriptor type - * @param desc Descriptor to check - * @param dir 
Direction - */ -template -bool is_batch_interleaved(const Descriptor& desc, direction dir) { - return desc.lengths.size() == 1 && desc.get_distance(dir) == 1 && - desc.get_strides(dir).back() == desc.number_of_transforms; -} - -/** - * Return an enum describing the layout of the data in the descriptor - * - * @tparam Descriptor Descriptor type - * @param desc Descriptor to check - * @param dir Direction - */ -template -detail::layout get_layout(const Descriptor& desc, direction dir) { - if (has_default_strides_and_distance(desc, dir)) { - return detail::layout::PACKED; - } - if (is_batch_interleaved(desc, dir)) { - return detail::layout::BATCH_INTERLEAVED; - } - return detail::layout::UNPACKED; -} - -} // namespace detail - -// forward declaration -template -struct descriptor; - -/* -Compute functions in the `committed_descriptor` call `dispatch_kernel` and `dispatch_kernel_helper`. These two functions -ensure the kernel is run with a supported subgroup size. Next `dispatch_kernel_helper` calls `run_kernel`. The -`run_kernel` member function picks the appropriate implementation and calls the static `run_kernel` of that implementation. -The implementation-specific `run_kernel` handles differences between forward and backward computations, casts the memory -(USM or buffers) from complex to scalars and launches the kernel. Each function described in this doc has only one -templated overload that handles both directions of transforms and both buffer and USM memory. - -Device functions make no assumptions on the size of a work group or the number of workgroups in a kernel. These numbers -can be tuned for each device. - -Implementation-specific `run_kernel` functions make the size of the FFT that is handled by the individual workitems a -compile-time constant. The one for the subgroup implementation also calls `cross_sg_dispatcher`, which makes the -cross-subgroup factor of the FFT size a compile-time constant. They do that by using a switch on the FFT size for one -workitem, before calling `workitem_impl`, `subgroup_impl` or `workgroup_impl`. The `_impl` functions take the FFT size -for one workitem as a template parameter. Only the calls that are determined to fit into the available registers (depending -on the value of the PORTFFT_TARGET_REGS_PER_WI macro) are actually instantiated. - -The `_impl` functions iterate over the batch of problems, loading the data for each first into -local memory and from there into private memory. This is done in these two steps to avoid non-coalesced global memory -accesses. `workitem_impl` loads one problem per workitem, `subgroup_impl` loads one problem per subgroup and -`workgroup_impl` loads one problem per workgroup. After the computation, done by calls to `wi_dft` for workitem, -`sg_dft` for subgroup and `wg_dft` for workgroup, the data is written out, going through local memory again. - -The computational parts of the implementations are further documented in the files with their implementations, -`workitem.hpp`, `subgroup.hpp` and `workgroup.hpp`. -*/ - -/** - * A committed descriptor that contains everything that is needed to run an FFT.
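 *
 * As a miniature illustration of the dispatch overview above (a sketch only,
 * not library code), the "switch on the FFT size for one workitem" pattern
 * looks like:
 *
 *   template <int FFTSizeWI> void wi_impl();  // size is a compile-time constant
 *
 *   void dispatch_wi(int fft_size_wi) {
 *     switch (fft_size_wi) {
 *       case 4: wi_impl<4>(); break;
 *       case 8: wi_impl<8>(); break;    // only sizes that fit the register
 *       case 16: wi_impl<16>(); break;  // budget are instantiated at all
 *       default: break;                 // not supported at this level
 *     }
 *   }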
- * - * @tparam Scalar type of the scalar used for computations - * @tparam Domain domain of the FFT - */ -template -class committed_descriptor { - using complex_type = std::complex; - - friend struct descriptor; - template - friend std::vector detail::compute_level( - const typename committed_descriptor::kernel_data_struct& kd_struct, TIn input, Scalar1* output, - TIn input_imag, Scalar1* output_imag, const Scalar1* twiddles_ptr, const IdxGlobal* factors_triple, - IdxGlobal intermediate_twiddle_offset, IdxGlobal subimpl_twiddle_offset, IdxGlobal input_global_offset, - IdxGlobal committed_size, Idx num_batches_in_l2, IdxGlobal n_transforms, IdxGlobal batch_start, Idx factor_id, - Idx total_factors, complex_storage storage, const std::vector& dependencies, sycl::queue& queue); - - template - friend sycl::event detail::transpose_level( - const typename committed_descriptor::kernel_data_struct& kd_struct, const Scalar1* input, - TOut output, const IdxGlobal* factors_triple, IdxGlobal committed_size, Idx num_batches_in_l2, - IdxGlobal n_transforms, IdxGlobal batch_start, Idx total_factors, IdxGlobal output_offset, sycl::queue& queue, - const std::vector& events, complex_storage storage); - - /** - * Vector containing the sub-implementation level, kernel_ids and factors for each factor that requires a separate - * kernel. - */ - using kernel_ids_and_metadata_t = - std::vector, std::vector>>; - descriptor params; - sycl::queue queue; - sycl::device dev; - sycl::context ctx; - Idx n_compute_units; - std::vector supported_sg_sizes; - Idx local_memory_size; - IdxGlobal llc_size; - std::shared_ptr scratch_ptr_1; - std::shared_ptr scratch_ptr_2; - std::size_t scratch_space_required; - - struct kernel_data_struct { - sycl::kernel_bundle exec_bundle; - std::vector factors; - std::size_t length; - Idx used_sg_size; - Idx num_sgs_per_wg; - std::shared_ptr twiddles_forward; - detail::level level; - IdxGlobal batch_size; - std::size_t local_mem_required; - IdxGlobal global_range; - IdxGlobal local_range; - - kernel_data_struct(sycl::kernel_bundle&& exec_bundle, - const std::vector& factors, std::size_t length, Idx used_sg_size, Idx num_sgs_per_wg, - std::shared_ptr twiddles_forward, detail::level level) - : exec_bundle(std::move(exec_bundle)), - factors(factors), - length(length), - used_sg_size(used_sg_size), - num_sgs_per_wg(num_sgs_per_wg), - twiddles_forward(twiddles_forward), - level(level) {} - }; +#include "committed_descriptor.hpp" - struct dimension_struct { - std::vector forward_kernels; - std::vector backward_kernels; - std::vector transpose_kernels; - std::shared_ptr factors_and_scan; - detail::level level; - std::size_t length; - Idx used_sg_size; - Idx num_batches_in_l2; - Idx num_factors; - - dimension_struct(std::vector forward_kernels, std::vector backward_kernels, - detail::level level, std::size_t length, Idx used_sg_size) - : forward_kernels(std::move(forward_kernels)), - backward_kernels(std::move(backward_kernels)), - level(level), - length(length), - used_sg_size(used_sg_size) {} - }; - - std::vector dimensions; - - template - auto dispatch(detail::level level, Args&&... 
args) { - switch (level) { - case detail::level::WORKITEM: - return Impl::template inner::execute(*this, args...); - case detail::level::SUBGROUP: - return Impl::template inner::execute(*this, args...); - case detail::level::WORKGROUP: - return Impl::template inner::execute(*this, args...); - case detail::level::GLOBAL: - return Impl::template inner::execute(*this, args...); - default: - // This should be unreachable - throw unsupported_configuration("Unimplemented"); - } - } - - template - auto dispatch(detail::level level, Args&&... args) { - switch (level) { - case detail::level::WORKITEM: - return Impl::template inner::execute(*this, args...); - case detail::level::SUBGROUP: - return Impl::template inner::execute(*this, args...); - case detail::level::WORKGROUP: - return Impl::template inner::execute(*this, args...); - case detail::level::GLOBAL: - return Impl::template inner::execute(*this, args...); - default: - // This should be unreachable - throw unsupported_configuration("Unimplemented"); - } - } - - template - auto dispatch(detail::level level, Args&&... args) { - switch (level) { - case detail::level::WORKITEM: - return Impl::template inner::execute(*this, - args...); - case detail::level::SUBGROUP: - return Impl::template inner::execute(*this, - args...); - case detail::level::WORKGROUP: - return Impl::template inner::execute( - *this, args...); - case detail::level::GLOBAL: - return Impl::template inner::execute(*this, - args...); - default: - // This should be unreachable - throw unsupported_configuration("Unimplemented"); - } - } - - /** - * Prepares the implementation for the particular problem size. That includes factorizing it and getting ids for the - * set of kernels that need to be JIT compiled. - * - * @tparam SubgroupSize size of the subgroup - * @param kernel_num the consecutive number of the kernel to prepare - * @return implementation to use for the dimension and a vector of tuples of: implementation to use for a kernel, - * vector of kernel ids, factors - */ - template - std::tuple prepare_implementation(std::size_t kernel_num) { - PORTFFT_LOG_FUNCTION_ENTRY(); - // TODO: check and support all the parameter values - if constexpr (Domain != domain::COMPLEX) { - throw unsupported_configuration("portFFT only supports complex to complex transforms"); - } - - std::vector ids; - std::vector factors; - IdxGlobal fft_size = static_cast(params.lengths[kernel_num]); - if (detail::fits_in_wi(fft_size)) { - ids = detail::get_ids(); - PORTFFT_LOG_TRACE("Prepared workitem impl for size: ", fft_size); - return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, factors}}}; - } - if (detail::fits_in_sg(fft_size, SubgroupSize)) { - Idx factor_sg = detail::factorize_sg(static_cast(fft_size), SubgroupSize); - Idx factor_wi = static_cast(fft_size) / factor_sg; - // This factorization is duplicated in the dispatch logic on the device. - // The CT and spec constant factors should match. 
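// Illustrative sketch, excluded from compilation: a worked example of this
// split. The model of detail::factorize_sg below is an assumption (largest
// divisor of the FFT size not exceeding the subgroup size), not verified
// library behaviour.
#if 0
#include <cstdio>

int factorize_sg_model(int fft_size, int subgroup_size) {
  for (int f = subgroup_size; f > 1; f--) {
    if (fft_size % f == 0) {
      return f;
    }
  }
  return 1;
}

int main() {
  int fft_size = 64;
  int subgroup_size = 32;
  int factor_sg = factorize_sg_model(fft_size, subgroup_size);  // 32
  int factor_wi = fft_size / factor_sg;                         // 2
  // A 64-point FFT per subgroup becomes a 32 x 2 decomposition: each work-item
  // holds a 2-point sub-FFT and the cross-subgroup factor is 32.
  std::printf("factor_sg=%d factor_wi=%d\n", factor_sg, factor_wi);
  return 0;
}
#endif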
- factors.push_back(factor_wi); - factors.push_back(factor_sg); - ids = detail::get_ids(); - PORTFFT_LOG_TRACE("Prepared subgroup impl with factor_wi:", factor_wi, "and factor_sg:", factor_sg); - return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, factors}}}; - } - IdxGlobal n_idx_global = detail::factorize(fft_size); - if (detail::can_cast_safely(n_idx_global) && - detail::can_cast_safely(fft_size / n_idx_global)) { - if (n_idx_global == 1) { - throw unsupported_configuration("FFT size ", fft_size, " : Large Prime sized FFT currently is unsupported"); - } - Idx n = static_cast(n_idx_global); - Idx m = static_cast(fft_size / n_idx_global); - Idx factor_sg_n = detail::factorize_sg(n, SubgroupSize); - Idx factor_wi_n = n / factor_sg_n; - Idx factor_sg_m = detail::factorize_sg(m, SubgroupSize); - Idx factor_wi_m = m / factor_sg_m; - Idx temp_num_sgs_in_wg; - std::size_t local_memory_usage = num_scalars_in_local_mem( - detail::level::WORKGROUP, static_cast(fft_size), SubgroupSize, - {factor_sg_n, factor_wi_n, factor_sg_m, factor_wi_m}, temp_num_sgs_in_wg) * - sizeof(Scalar); - // Checks for PACKED layout only at the moment, as the other layout will not be supported - // by the global implementation. For such sizes, only PACKED layout will be supported - if (detail::fits_in_wi(factor_wi_n) && detail::fits_in_wi(factor_wi_m) && - (local_memory_usage <= static_cast(local_memory_size))) { - factors.push_back(factor_wi_n); - factors.push_back(factor_sg_n); - factors.push_back(factor_wi_m); - factors.push_back(factor_sg_m); - // This factorization of N and M is duplicated in the dispatch logic on the device. - // The CT and spec constant factors should match. - ids = detail::get_ids(); - PORTFFT_LOG_TRACE("Prepared workgroup impl with factor_wi_n:", factor_wi_n, " factor_sg_n:", factor_sg_n, - " factor_wi_m:", factor_wi_m, " factor_sg_m:", factor_sg_m); - return {detail::level::WORKGROUP, {{detail::level::WORKGROUP, ids, factors}}}; - } - } - PORTFFT_LOG_TRACE("Preparing global impl"); - std::vector, std::vector>> param_vec; - auto check_and_select_target_level = [&](IdxGlobal factor_size, bool batch_interleaved_layout = true) -> bool { - if (detail::fits_in_wi(factor_size)) { - // Throughout we have assumed there would always be enough local memory for the WI implementation. 
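// Illustrative sketch, excluded from compilation: a toy model of how
// detail::factorize_input drives a callback like this one for the GLOBAL
// level. The greedy largest-divisor policy is an assumption made for the
// illustration, not the library's verified algorithm.
#if 0
#include <cstdio>

template <typename F>
void factorize_input_model(long n, long largest_factor, F&& accepts_factor) {
  while (n > 1) {
    long picked = 1;
    for (long f = largest_factor; f >= 2; f--) {
      if (n % f == 0 && accepts_factor(f)) {
        picked = f;
        break;
      }
    }
    if (picked == 1) {
      std::printf("unfactorizable remainder: %ld\n", n);  // e.g. a large prime
      return;
    }
    n /= picked;
  }
}

int main() {
  // 2^20 with factors capped at 64 decomposes as 64 * 64 * 64 * 4.
  factorize_input_model(1L << 20, 64, [](long f) {
    std::printf("kernel selected for factor %ld\n", f);  // WI or SG kernel here
    return true;
  });
  return 0;
}
#endif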
- param_vec.emplace_back(detail::level::WORKITEM, - detail::get_ids(), - std::vector{static_cast(factor_size)}); - PORTFFT_LOG_TRACE("Workitem kernel for factor:", factor_size); - return true; - } - bool fits_in_local_memory_subgroup = [&]() { - Idx temp_num_sgs_in_wg; - IdxGlobal factor_sg = detail::factorize_sg(factor_size, SubgroupSize); - IdxGlobal factor_wi = factor_size / factor_sg; - if (detail::can_cast_safely(factor_sg) && detail::can_cast_safely(factor_wi)) { - if (batch_interleaved_layout) { - return (2 * - num_scalars_in_local_mem( - detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, - {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg) * - sizeof(Scalar) + - 2 * static_cast(factor_size) * sizeof(Scalar)) < - static_cast(local_memory_size); - } - return (num_scalars_in_local_mem( - detail::level::SUBGROUP, static_cast(factor_size), SubgroupSize, - {static_cast(factor_sg), static_cast(factor_wi)}, temp_num_sgs_in_wg) * - sizeof(Scalar) + - 2 * static_cast(factor_size) * sizeof(Scalar)) < - static_cast(local_memory_size); - } - return false; - }(); - if (detail::fits_in_sg(factor_size, SubgroupSize) && fits_in_local_memory_subgroup && - !PORTFFT_SLOW_SG_SHUFFLES) { - Idx factor_sg = detail::factorize_sg(static_cast(factor_size), SubgroupSize); - Idx factor_wi = static_cast(factor_size) / factor_sg; - PORTFFT_LOG_TRACE("Subgroup kernel for factor:", factor_size, "with factor_wi:", factor_wi, - "and factor_sg:", factor_sg); - param_vec.emplace_back(detail::level::SUBGROUP, - detail::get_ids(), - std::vector{factor_sg, factor_wi}); - return true; - } - return false; - }; - detail::factorize_input(fft_size, check_and_select_target_level); - return {detail::level::GLOBAL, param_vec}; - } - - /** - * Struct for dispatching `set_spec_constants()` call. 
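 *
 * The `Dummy` parameter used by these dispatch structs deserves a note: C++
 * permits partial, but not explicit, specialization of a member class template
 * without specializing the containing class, so a trailing dummy parameter
 * keeps every specialization partial. In miniature (illustrative only):
 *
 *   template <typename Scalar>
 *   struct outer {
 *     template <detail::level Level, typename Dummy = void>
 *     struct inner;  // primary template
 *     template <typename Dummy>
 *     struct inner<detail::level::WORKITEM, Dummy> {  // partial: allowed
 *       static void execute();
 *     };
 *   };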
- */ - struct set_spec_constants_struct { - // Dummy parameter is needed as only partial specializations are allowed without specializing the containing class - template - struct inner { - static void execute(committed_descriptor& desc, sycl::kernel_bundle& in_bundle, - std::size_t length, const std::vector& factors, detail::level level, Idx factor_num, - Idx num_factors); - }; - }; - - /** - * Sets the implementation-dependent specialization constant value - * @param top_level implementation to dispatch to - * @param in_bundle input kernel bundle to set spec constants for - * @param length length of the FFT - * @param factors factors of the corresponding length - * @param multiply_on_load Whether the input data is multiplied with some data array before the FFT computation - * @param multiply_on_store Whether the input data is multiplied with some data array after the FFT computation - * @param scale_factor_applied whether or not to apply the scale factor - * @param level sub-implementation to run which will be set as a spec constant - * @param conjugate_on_load whether or not to take conjugate of the input - * @param conjugate_on_store whether or not to take conjugate of the output - * @param scale_factor Scale to be applied to the result - * @param factor_num factor number which is set as a spec constant - * @param num_factors total number of factors of the committed size, set as a spec constant - */ - void set_spec_constants(detail::level top_level, sycl::kernel_bundle& in_bundle, - std::size_t length, const std::vector& factors, - detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store, - detail::apply_scale_factor scale_factor_applied, detail::level level, - detail::complex_conjugate conjugate_on_load, detail::complex_conjugate conjugate_on_store, - Scalar scale_factor, Idx factor_num = 0, Idx num_factors = 0) { - PORTFFT_LOG_FUNCTION_ENTRY(); - const Idx length_idx = static_cast(length); - // These spec constants are used in all implementations, so we set them here - PORTFFT_LOG_TRACE("Setting specialization constants:"); - PORTFFT_LOG_TRACE("SpecConstComplexStorage:", params.complex_storage); - in_bundle.template set_specialization_constant(params.complex_storage); - PORTFFT_LOG_TRACE("SpecConstNumRealsPerFFT:", 2 * length_idx); - in_bundle.template set_specialization_constant(2 * length_idx); - PORTFFT_LOG_TRACE("SpecConstWIScratchSize:", 2 * detail::wi_temps(length_idx)); - in_bundle.template set_specialization_constant(2 * detail::wi_temps(length_idx)); - PORTFFT_LOG_TRACE("SpecConstMultiplyOnLoad:", multiply_on_load); - in_bundle.template set_specialization_constant(multiply_on_load); - PORTFFT_LOG_TRACE("SpecConstMultiplyOnStore:", multiply_on_store); - in_bundle.template set_specialization_constant(multiply_on_store); - PORTFFT_LOG_TRACE("SpecConstApplyScaleFactor:", scale_factor_applied); - in_bundle.template set_specialization_constant(scale_factor_applied); - PORTFFT_LOG_TRACE("SpecConstConjugateOnLoad:", conjugate_on_load); - in_bundle.template set_specialization_constant(conjugate_on_load); - PORTFFT_LOG_TRACE("SpecConstConjugateOnStore:", conjugate_on_store); - in_bundle.template set_specialization_constant(conjugate_on_store); - PORTFFT_LOG_TRACE("get_spec_constant_scale:", scale_factor); - in_bundle.template set_specialization_constant()>(scale_factor); - - dispatch(top_level, in_bundle, length, factors, level, factor_num, num_factors); - } - - /** - * Struct for dispatching `num_scalars_in_local_mem()` call.
- */ - struct num_scalars_in_local_mem_struct { - // Dummy parameter is needed as only partial specializations are allowed without specializing the containing class - template - struct inner { - static std::size_t execute(committed_descriptor& desc, std::size_t length, Idx used_sg_size, - const std::vector& factors, Idx& num_sgs_per_wg); - }; - }; - - /** - * Determine the number of scalars we need to have space for in the local memory. It may also modify `num_sgs_per_wg` - * to make the problem fit in the local memory. - * - * @param level the implementation that will be used - * @param length length of the FFT the kernel will execute - * @param used_sg_size subgroup size the kernel will use - * @param factors factorization of the FFT size the kernel will use - * @param[out] num_sgs_per_wg number of subgroups in a workgroup - * @return the number of scalars - */ - template - std::size_t num_scalars_in_local_mem(detail::level level, std::size_t length, Idx used_sg_size, - const std::vector& factors, Idx& num_sgs_per_wg) { - PORTFFT_LOG_FUNCTION_ENTRY(); - return dispatch(level, length, used_sg_size, factors, num_sgs_per_wg); - } - - /** - * Struct for dispatching `calculate_twiddles()` call. - */ - struct calculate_twiddles_struct { - // Dummy parameter is needed as only partial specializations are allowed without specializing the containing class - template - struct inner { - static Scalar* execute(committed_descriptor& desc, dimension_struct& dimension_data, - std::vector& kernels); - }; - }; - - /** - * Calculates twiddle factors for the implementation in use. - * @param level Implementation selected for the committed size - * @param dimension_data dimension_struct corresponding to the dimension for which twiddles are being calculated - * @param kernels vector of kernels - * @return Scalar* USM pointer to the twiddle factors - */ - Scalar* calculate_twiddles(detail::level level, dimension_struct& dimension_data, - std::vector& kernels) { - PORTFFT_LOG_FUNCTION_ENTRY(); - return dispatch(level, dimension_data, kernels); - } - - /** - * Sets the specialization constants for all the kernel_ids contained in the vector - * returned from prepare_implementation. - * @tparam SubgroupSize Subgroup size - * @param top_level selected level of implementation - * @param prepared_vec vector of tuples of: implementation to use for a kernel, - * vector of kernel ids, factors - * @param compute_direction direction of compute: forward or backward - * @param dimension_num the dimension for which the kernels are being built - * @param skip_scaling whether or not to skip scaling - * @return vector of kernel_data_struct if all kernel builds are successful, std::nullopt otherwise - */ - template - std::optional> set_spec_constants_driver(detail::level top_level, - kernel_ids_and_metadata_t& prepared_vec, - direction compute_direction, - std::size_t dimension_num, - bool skip_scaling) { - Scalar scale_factor = compute_direction == direction::FORWARD ?
params.forward_scale : params.backward_scale; - detail::apply_scale_factor scale_factor_applied = detail::apply_scale_factor::APPLIED; - bool is_compatible = true; - if (skip_scaling) { - scale_factor_applied = detail::apply_scale_factor::NOT_APPLIED; - } - std::size_t counter = 0; - auto conjugate_on_load = detail::complex_conjugate::NOT_APPLIED; - auto conjugate_on_store = detail::complex_conjugate::NOT_APPLIED; - std::vector result; - for (auto& [level, ids, factors] : prepared_vec) { - auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), ids); - if (top_level == detail::level::GLOBAL) { - std::size_t factor_size = - static_cast(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())); - if (counter == prepared_vec.size() - 1) { - if (compute_direction == direction::BACKWARD) { - conjugate_on_store = detail::complex_conjugate::APPLIED; - } - set_spec_constants(detail::level::GLOBAL, in_bundle, factor_size, factors, - detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED, - detail::apply_scale_factor::APPLIED, level, conjugate_on_load, conjugate_on_store, - scale_factor, static_cast(counter), static_cast(prepared_vec.size())); - // reset conjugate_on_store - conjugate_on_store = detail::complex_conjugate::NOT_APPLIED; - } else { - if (counter == 0 && compute_direction == direction::BACKWARD) { - conjugate_on_load = detail::complex_conjugate::APPLIED; - } - set_spec_constants(detail::level::GLOBAL, in_bundle, factor_size, factors, - detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::APPLIED, - detail::apply_scale_factor::NOT_APPLIED, level, conjugate_on_load, conjugate_on_store, - scale_factor, static_cast(counter), static_cast(prepared_vec.size())); - // reset conjugate_on_load - conjugate_on_load = detail::complex_conjugate::NOT_APPLIED; - } - } else { - if (compute_direction == direction::BACKWARD) { - conjugate_on_load = detail::complex_conjugate::APPLIED; - conjugate_on_store = detail::complex_conjugate::APPLIED; - } - set_spec_constants(level, in_bundle, params.lengths[dimension_num], factors, - detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED, - scale_factor_applied, level, conjugate_on_load, conjugate_on_store, scale_factor); - } - try { - PORTFFT_LOG_TRACE("Building kernel bundle with subgroup size", SubgroupSize); - result.emplace_back(sycl::build(in_bundle), factors, params.lengths[dimension_num], SubgroupSize, - PORTFFT_SGS_IN_WG, std::shared_ptr(), level); - PORTFFT_LOG_TRACE("Kernel bundle build complete."); - } catch (std::exception& e) { - PORTFFT_LOG_WARNING("Build for subgroup size", SubgroupSize, "failed with message:\n", e.what()); - is_compatible = false; - break; - } - counter++; - } - if (is_compatible) { - return result; - } - return std::nullopt; - } - - /** - * Builds the kernel bundles with appropriate values of specialization constants for the first supported subgroup - * size. 
- * - * @tparam SubgroupSize first subgroup size - * @tparam OtherSGSizes other subgroup sizes - * @param dimension_num The dimension for which the kernels are being built - * @param skip_scaling whether or not to skip scaling - * @return `dimension_struct` for the newly built kernels - */ - template - dimension_struct build_w_spec_const(std::size_t dimension_num, bool skip_scaling) { - PORTFFT_LOG_FUNCTION_ENTRY(); - if (std::count(supported_sg_sizes.begin(), supported_sg_sizes.end(), SubgroupSize)) { - auto [top_level, prepared_vec] = prepare_implementation(dimension_num); - bool is_compatible = true; - for (auto [level, ids, factors] : prepared_vec) { - is_compatible = is_compatible && sycl::is_compatible(ids, dev); - if (!is_compatible) { - break; - } - } - - if (is_compatible) { - auto forward_kernels = set_spec_constants_driver(top_level, prepared_vec, direction::FORWARD, - dimension_num, skip_scaling); - auto backward_kernels = set_spec_constants_driver(top_level, prepared_vec, direction::BACKWARD, - dimension_num, skip_scaling); - if (forward_kernels.has_value() && backward_kernels.has_value()) { - return {forward_kernels.value(), backward_kernels.value(), top_level, params.lengths[dimension_num], - SubgroupSize}; - } - } - } - if constexpr (sizeof...(OtherSGSizes) == 0) { - throw invalid_configuration("None of the compiled subgroup sizes are supported by the device"); - } else { - return build_w_spec_const(dimension_num, skip_scaling); - } - } - - /** - * Function which calculates the amount of scratch space required, and also precomputes the necessary scans. - * @param num_global_level_dimensions number of global level dimensions in the committed size - */ - void allocate_scratch_and_precompute_scan(Idx num_global_level_dimensions) { - PORTFFT_LOG_FUNCTION_ENTRY(); - std::size_t n_kernels = params.lengths.size(); - if (num_global_level_dimensions == 1) { - std::size_t global_dimension = 0; - for (std::size_t i = 0; i < n_kernels; i++) { - if (dimensions.at(i).level == detail::level::GLOBAL) { - global_dimension = i; - break; - } - } - std::vector factors; - std::vector sub_batches; - std::vector inclusive_scan; - std::size_t cache_required_for_twiddles = 0; - for (const auto& kernel_data : dimensions.at(global_dimension).forward_kernels) { - IdxGlobal factor_size = static_cast( - std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies())); - cache_required_for_twiddles += - static_cast(2 * factor_size * kernel_data.batch_size) * sizeof(Scalar); - factors.push_back(factor_size); - sub_batches.push_back(kernel_data.batch_size); - } - dimensions.at(global_dimension).num_factors = static_cast(factors.size()); - std::size_t cache_space_left_for_batches = static_cast(llc_size) - cache_required_for_twiddles; - // TODO: In case of multi-dim (single dim global sized), this should be batches corresponding to that dim - dimensions.at(global_dimension).num_batches_in_l2 = static_cast(std::min( - static_cast(PORTFFT_MAX_CONCURRENT_KERNELS), - std::min(params.number_of_transforms, - std::max(std::size_t(1), cache_space_left_for_batches / - (2 * dimensions.at(global_dimension).length * sizeof(Scalar)))))); - scratch_space_required = 2 * dimensions.at(global_dimension).length * - static_cast(dimensions.at(global_dimension).num_batches_in_l2); - PORTFFT_LOG_TRACE("Allocating 2 scratch arrays of size", scratch_space_required, "scalars in global memory"); - scratch_ptr_1 = detail::make_shared(scratch_space_required, queue); - scratch_ptr_2 =
detail::make_shared(scratch_space_required, queue); - inclusive_scan.push_back(factors.at(0)); - for (std::size_t i = 1; i < factors.size(); i++) { - inclusive_scan.push_back(inclusive_scan.at(i - 1) * factors.at(i)); - } - PORTFFT_LOG_TRACE("Dimension:", global_dimension, - "num_batches_in_l2:", dimensions.at(global_dimension).num_batches_in_l2, - "scan:", inclusive_scan); - dimensions.at(global_dimension).factors_and_scan = - detail::make_shared(factors.size() + sub_batches.size() + inclusive_scan.size(), queue); - queue.copy(factors.data(), dimensions.at(global_dimension).factors_and_scan.get(), factors.size()); - queue.copy(sub_batches.data(), dimensions.at(global_dimension).factors_and_scan.get() + factors.size(), - sub_batches.size()); - queue.copy(inclusive_scan.data(), - dimensions.at(global_dimension).factors_and_scan.get() + factors.size() + sub_batches.size(), - inclusive_scan.size()); - queue.wait(); - // build transpose kernels - std::size_t num_transposes_required = factors.size() - 1; - for (std::size_t i = 0; i < num_transposes_required; i++) { - std::vector ids; - auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), - detail::get_transpose_kernel_ids()); - PORTFFT_LOG_TRACE("Setting specialization constants for transpose kernel", i); - PORTFFT_LOG_TRACE("SpecConstComplexStorage:", params.complex_storage); - in_bundle.template set_specialization_constant(params.complex_storage); - PORTFFT_LOG_TRACE("GlobalSpecConstLevelNum:", i); - in_bundle.template set_specialization_constant(static_cast(i)); - PORTFFT_LOG_TRACE("GlobalSpecConstNumFactors:", factors.size()); - in_bundle.template set_specialization_constant( - static_cast(factors.size())); - dimensions.at(global_dimension) - .transpose_kernels.emplace_back( - sycl::build(in_bundle), - std::vector{static_cast(factors.at(i)), static_cast(sub_batches.at(i))}, 1, 1, 1, - std::shared_ptr(), detail::level::GLOBAL); - } - } else { - std::size_t max_encountered_global_size = 0; - for (std::size_t i = 0; i < n_kernels; i++) { - if (dimensions.at(i).level == detail::level::GLOBAL) { - max_encountered_global_size = max_encountered_global_size > dimensions.at(i).length - ? max_encountered_global_size - : dimensions.at(i).length; - } - } - // TODO: max_scratch_size should be max(global_size_1 * corresponding_batches_in_l2, global_size_1 * - // corresponding_batches_in_l2), in the case of multi-dim global FFTs. 
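// Illustrative sketch, excluded from compilation: the L2-residency arithmetic
// used for a single global dimension above, with made-up numbers (float
// scalars, an 8 MiB last-level cache, 2 MiB of twiddles, length 2^16, 64
// batches, and 16 standing in for PORTFFT_MAX_CONCURRENT_KERNELS).
#if 0
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  std::size_t llc_size = 8u * 1024u * 1024u;
  std::size_t twiddle_bytes = 2u * 1024u * 1024u;  // cache_required_for_twiddles
  std::size_t length = std::size_t(1) << 16;
  std::size_t batches = 64;
  std::size_t max_concurrent = 16;
  std::size_t cache_left = llc_size - twiddle_bytes;
  std::size_t bytes_per_batch = 2 * length * sizeof(float);  // interleaved complex
  std::size_t num_batches_in_l2 = std::min(
      max_concurrent,
      std::min(batches, std::max<std::size_t>(1, cache_left / bytes_per_batch)));
  // 6 MiB / 512 KiB = 12 resident batches; each scratch array holds
  // 2 * length * num_batches_in_l2 scalars.
  std::printf("num_batches_in_l2=%zu scratch_scalars=%zu\n", num_batches_in_l2,
              2 * length * num_batches_in_l2);
  return 0;
}
#endif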
- scratch_space_required = 2 * max_encountered_global_size * params.number_of_transforms; - scratch_ptr_1 = detail::make_shared(scratch_space_required, queue); - scratch_ptr_2 = detail::make_shared(scratch_space_required, queue); - for (std::size_t i = 0; i < n_kernels; i++) { - if (dimensions.at(i).level == detail::level::GLOBAL) { - std::vector factors; - std::vector sub_batches; - std::vector inclusive_scan; - for (const auto& kernel_data : dimensions.at(i).forward_kernels) { - IdxGlobal factor_size = static_cast( - std::accumulate(kernel_data.factors.begin(), kernel_data.factors.end(), 1, std::multiplies())); - factors.push_back(factor_size); - sub_batches.push_back(kernel_data.batch_size); - } - inclusive_scan.push_back(factors.at(0)); - for (std::size_t j = 1; j < factors.size(); j++) { - inclusive_scan.push_back(inclusive_scan.at(j - 1) * factors.at(j)); - } - dimensions.at(i).num_factors = static_cast(factors.size()); - dimensions.at(i).factors_and_scan = - detail::make_shared(factors.size() + sub_batches.size() + inclusive_scan.size(), queue); - queue.copy(factors.data(), dimensions.at(i).factors_and_scan.get(), factors.size()); - queue.copy(sub_batches.data(), dimensions.at(i).factors_and_scan.get() + factors.size(), sub_batches.size()); - queue.copy(inclusive_scan.data(), - dimensions.at(i).factors_and_scan.get() + factors.size() + sub_batches.size(), - inclusive_scan.size()); - queue.wait(); - // build transpose kernels - std::size_t num_transposes_required = factors.size() - 1; - for (std::size_t j = 0; j < num_transposes_required; j++) { - auto in_bundle = sycl::get_kernel_bundle( - queue.get_context(), detail::get_transpose_kernel_ids()); - PORTFFT_LOG_TRACE("Setting specialization constants for transpose kernel", j); - PORTFFT_LOG_TRACE("GlobalSpecConstLevelNum:", i); - in_bundle.template set_specialization_constant(static_cast(i)); - PORTFFT_LOG_TRACE("GlobalSpecConstNumFactors:", factors.size()); - in_bundle.template set_specialization_constant( - static_cast(factors.size())); - dimensions.at(i).transpose_kernels.emplace_back( - sycl::build(in_bundle), - std::vector{static_cast(factors.at(j)), static_cast(sub_batches.at(j))}, 1, 1, 1, - std::shared_ptr(), detail::level::GLOBAL); - } - } - } - } - - /** - * Constructor.
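 *
 * User code does not call this constructor directly; it is reached through the
 * descriptor's commit call. A hedged usage sketch, assuming the public
 * `descriptor::commit(queue)` entry point:
 *
 *   portfft::descriptor<float, portfft::domain::COMPLEX> desc({1024});
 *   auto committed = desc.commit(queue);  // ends up in this constructor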
- * - * @param params descriptor this is created from - * @param queue queue to use when enqueueing device work - */ - committed_descriptor(const descriptor& params, sycl::queue& queue) - : params(params), - queue(queue), - dev(queue.get_device()), - ctx(queue.get_context()), - // get some properties we will use for tuning - n_compute_units(static_cast(dev.get_info())), - supported_sg_sizes(dev.get_info()), - local_memory_size(static_cast(queue.get_device().get_info())), - llc_size(static_cast(queue.get_device().get_info())) { - PORTFFT_LOG_FUNCTION_ENTRY(); - PORTFFT_LOG_TRACE("Device info:"); - PORTFFT_LOG_TRACE("n_compute_units:", n_compute_units); - PORTFFT_LOG_TRACE("supported_sg_sizes:", supported_sg_sizes); - PORTFFT_LOG_TRACE("local_memory_size:", local_memory_size); - PORTFFT_LOG_TRACE("llc_size:", llc_size); - - // check it's suitable to run - const auto forward_layout = detail::get_layout(params, direction::FORWARD); - const auto backward_layout = detail::get_layout(params, direction::BACKWARD); - if (params.lengths.size() > 1) { - const bool supported_layout = - forward_layout == detail::layout::PACKED && backward_layout == detail::layout::PACKED; - if (!supported_layout) { - throw unsupported_configuration("Multi-dimensional transforms are only supported with default data layout"); - } - } else { - const bool supported_layout = - (forward_layout == detail::layout::PACKED || forward_layout == detail::layout::BATCH_INTERLEAVED) && - (backward_layout == detail::layout::PACKED || backward_layout == detail::layout::BATCH_INTERLEAVED); - if (!supported_layout) { - throw unsupported_configuration("Arbitrary strides are not supported"); - } - } - - // compile the kernels and precalculate twiddles - std::size_t n_kernels = params.lengths.size(); - for (std::size_t i = 0; i < n_kernels; i++) { - bool skip_scaling = true; - if (i == n_kernels - 1) { - skip_scaling = false; - } - dimensions.emplace_back(build_w_spec_const(i, skip_scaling)); - dimensions.back().forward_kernels.at(0).twiddles_forward = std::shared_ptr( - calculate_twiddles(dimensions.back().level, dimensions.at(i), dimensions.back().forward_kernels), - [queue](Scalar* ptr) { - if (ptr != nullptr) { - sycl::free(ptr, queue); - } - }); - // TODO: refactor multi-dimensional FFTs such that they can use a single pointer for twiddles.
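// Illustrative sketch, excluded from compilation: the ownership pattern used
// for the twiddle arrays above - a shared_ptr whose deleter returns USM memory
// to the queue. make_usm_array is an illustrative name, not library API.
#if 0
#include <memory>
#include <sycl/sycl.hpp>

std::shared_ptr<float> make_usm_array(std::size_t n, sycl::queue queue) {
  float* raw = sycl::malloc_device<float>(n, queue);
  // The queue is captured by value so the deleter remains valid for the whole
  // lifetime of the allocation, exactly as in the constructor above.
  return std::shared_ptr<float>(raw, [queue](float* ptr) {
    if (ptr != nullptr) {
      sycl::free(ptr, queue);
    }
  });
}
#endif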
- dimensions.back().backward_kernels.at(0).twiddles_forward = std::shared_ptr( - calculate_twiddles(dimensions.back().level, dimensions.at(i), dimensions.back().backward_kernels), - [queue](Scalar* ptr) { - if (ptr != nullptr) { - PORTFFT_LOG_TRACE("Freeing the array for twiddle factors"); - sycl::free(ptr, queue); - } - }); - } - - bool is_scratch_required = false; - Idx num_global_level_dimensions = 0; - for (std::size_t i = 0; i < n_kernels; i++) { - if (dimensions.at(i).level == detail::level::GLOBAL) { - is_scratch_required = true; - num_global_level_dimensions++; - } - } - if (num_global_level_dimensions != 0) { - if (params.lengths.size() > 1) { - throw unsupported_configuration("Only 1D FFTs that do not fit in local memory are supported"); - } - if (params.get_distance(direction::FORWARD) != params.lengths[0] || - params.get_distance(direction::BACKWARD) != params.lengths[0]) { - throw unsupported_configuration("Large FFTs are currently only supported in non-strided format"); - } - } - - if (is_scratch_required) { - allocate_scratch_and_precompute_scan(num_global_level_dimensions); - } - } - - /** - * Utility function for the copy constructor and copy assignment operator - * @param desc committed_descriptor of which the copy is to be made - */ - void create_copy(const committed_descriptor& desc) { - PORTFFT_LOG_FUNCTION_ENTRY(); -#define PORTFFT_COPY(x) this->x = desc.x; - PORTFFT_COPY(params) - PORTFFT_COPY(queue) - PORTFFT_COPY(dev) - PORTFFT_COPY(ctx) - PORTFFT_COPY(n_compute_units) - PORTFFT_COPY(supported_sg_sizes) - PORTFFT_COPY(local_memory_size) - PORTFFT_COPY(dimensions) - PORTFFT_COPY(scratch_space_required) - PORTFFT_COPY(llc_size) - -#undef PORTFFT_COPY - bool is_scratch_required = false; - for (std::size_t i = 0; i < desc.dimensions.size(); i++) { - if (desc.dimensions.at(i).level == detail::level::GLOBAL) { - is_scratch_required = true; - break; - } - } - if (is_scratch_required) { - PORTFFT_LOG_TRACE("Allocating 2 scratch arrays of size", desc.scratch_space_required, "Scalars in global memory"); - this->scratch_ptr_1 = - detail::make_shared(static_cast(desc.scratch_space_required), this->queue); - this->scratch_ptr_2 = - detail::make_shared(static_cast(desc.scratch_space_required), this->queue); - } - } - - public: - committed_descriptor(const committed_descriptor& desc) : params(desc.params) { - PORTFFT_LOG_FUNCTION_ENTRY(); - create_copy(desc); - } - - committed_descriptor& operator=(const committed_descriptor& desc) { - PORTFFT_LOG_FUNCTION_ENTRY(); - if (this != &desc) { - create_copy(desc); - } - return *this; - } - - static_assert(std::is_same_v || std::is_same_v, - "Scalar must be either float or double!"); - /** - * Alias for `Scalar`. - */ - using scalar_type = Scalar; - /** - * Alias for `Domain`. - */ - static constexpr domain DomainValue = Domain; - - /** - * Destructor - */ - ~committed_descriptor() { - PORTFFT_LOG_FUNCTION_ENTRY(); - queue.wait(); - } - - // default construction is not appropriate - committed_descriptor() = delete; - - /** - * Computes in-place forward FFT, working on a buffer. - * - * @param inout buffer containing input and output data - */ - void compute_forward(sycl::buffer& inout) { - PORTFFT_LOG_FUNCTION_ENTRY(); - // For now we can just call out-of-place implementation. - // This might need to be changed once we implement support for large sizes that work in global memory. - compute_forward(inout, inout); - } - - /** - * Computes in-place forward FFT, working on buffers.
- * - * @param inout_real buffer containing real part of the input and output data - * @param inout_imag buffer containing imaginary part of the input and output data - */ - void compute_forward(sycl::buffer& inout_real, sycl::buffer& inout_imag) { - PORTFFT_LOG_FUNCTION_ENTRY(); - // For now we can just call out-of-place implementation. - // This might need to be changed once we implement support for large sizes that work in global memory. - compute_forward(inout_real, inout_imag, inout_real, inout_imag); - } - - /** - * Computes in-place backward FFT, working on a buffer. - * - * @param inout buffer containing input and output data - */ - void compute_backward(sycl::buffer& inout) { - PORTFFT_LOG_FUNCTION_ENTRY(); - // For now we can just call out-of-place implementation. - // This might need to be changed once we implement support for large sizes that work in global memory. - compute_backward(inout, inout); - } - - /** - * Computes in-place backward FFT, working on buffers. - * - * @param inout_real buffer containing real part of the input and output data - * @param inout_imag buffer containing imaginary part of the input and output data - */ - void compute_backward(sycl::buffer& inout_real, sycl::buffer& inout_imag) { - PORTFFT_LOG_FUNCTION_ENTRY(); - // For now we can just call out-of-place implementation. - // This might need to be changed once we implement support for large sizes that work in global memory. - compute_backward(inout_real, inout_imag, inout_real, inout_imag); - } - - /** - * Computes out-of-place forward FFT, working on buffers. - * - * @param in buffer containing input data - * @param out buffer containing output data - */ - void compute_forward(const sycl::buffer& in, sycl::buffer& out) { - PORTFFT_LOG_FUNCTION_ENTRY(); - dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::FORWARD); - } - - /** - * Computes out-of-place forward FFT, working on buffers. - * - * @param in_real buffer containing real part of the input data - * @param in_imag buffer containing imaginary part of the input data - * @param out_real buffer containing real part of the output data - * @param out_imag buffer containing imaginary part of the output data - */ - void compute_forward(const sycl::buffer& in_real, const sycl::buffer& in_imag, - sycl::buffer& out_real, sycl::buffer& out_imag) { - PORTFFT_LOG_FUNCTION_ENTRY(); - dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::FORWARD); - } - - /** - * Computes out-of-place forward FFT, working on buffers. 
- * - * @param in buffer containing input data - * @param out buffer containing output data - */ - void compute_forward(const sycl::buffer& /*in*/, sycl::buffer& /*out*/) { - PORTFFT_LOG_FUNCTION_ENTRY(); - throw unsupported_configuration("Real to complex FFTs not yet implemented."); - } - - /** - * Compute out of place backward FFT, working on buffers - * - * @param in buffer containing input data - * @param out buffer containing output data - */ - void compute_backward(const sycl::buffer& in, sycl::buffer& out) { - PORTFFT_LOG_FUNCTION_ENTRY(); - dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::BACKWARD); - } - - /** - * Compute out of place backward FFT, working on buffers - * - * @param in_real buffer containing real part of the input data - * @param in_imag buffer containing imaginary part of the input data - * @param out_real buffer containing real part of the output data - * @param out_imag buffer containing imaginary part of the output data - */ - void compute_backward(const sycl::buffer& in_real, const sycl::buffer& in_imag, - sycl::buffer& out_real, sycl::buffer& out_imag) { - PORTFFT_LOG_FUNCTION_ENTRY(); - dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::BACKWARD); - } - - /** - * Computes in-place forward FFT, working on USM memory. - * - * @param inout USM pointer to memory containing input and output data - * @param dependencies events that must complete before the computation - * @return sycl::event associated with this computation - */ - sycl::event compute_forward(complex_type* inout, const std::vector& dependencies = {}) { - PORTFFT_LOG_FUNCTION_ENTRY(); - // For now we can just call out-of-place implementation. - // This might need to be changed once we implement support for large sizes that work in global memory. - return compute_forward(inout, inout, dependencies); - } - - /** - * Computes in-place forward FFT, working on USM memory. - * - * @param inout_real USM pointer to memory containing real part of the input and output data - * @param inout_imag USM pointer to memory containing imaginary part of the input and output data - * @param dependencies events that must complete before the computation - * @return sycl::event associated with this computation - */ - sycl::event compute_forward(scalar_type* inout_real, scalar_type* inout_imag, - const std::vector& dependencies = {}) { - PORTFFT_LOG_FUNCTION_ENTRY(); - // For now we can just call out-of-place implementation. - // This might need to be changed once we implement support for large sizes that work in global memory. - return compute_forward(inout_real, inout_imag, inout_real, inout_imag, dependencies); - } - - /** - * Computes in-place forward FFT, working on USM memory. - * - * @param inout USM pointer to memory containing input and output data - * @param dependencies events that must complete before the computation - * @return sycl::event associated with this computation - */ - sycl::event compute_forward(Scalar* inout, const std::vector& dependencies = {}) { - PORTFFT_LOG_FUNCTION_ENTRY(); - // For now we can just call out-of-place implementation. - // This might need to be changed once we implement support for large sizes that work in global memory. - return compute_forward(inout, reinterpret_cast(inout), dependencies); - } - - /** - * Computes in-place backward FFT, working on USM memory. 
-   *
-   * @param inout USM pointer to memory containing input and output data
-   * @param dependencies events that must complete before the computation
-   * @return sycl::event associated with this computation
-   */
-  sycl::event compute_backward(complex_type* inout, const std::vector<sycl::event>& dependencies = {}) {
-    PORTFFT_LOG_FUNCTION_ENTRY();
-    return compute_backward(inout, inout, dependencies);
-  }
-
-  /**
-   * Computes in-place backward FFT, working on USM memory.
-   *
-   * @param inout_real USM pointer to memory containing real part of the input and output data
-   * @param inout_imag USM pointer to memory containing imaginary part of the input and output data
-   * @param dependencies events that must complete before the computation
-   * @return sycl::event associated with this computation
-   */
-  sycl::event compute_backward(scalar_type* inout_real, scalar_type* inout_imag,
-                               const std::vector<sycl::event>& dependencies = {}) {
-    PORTFFT_LOG_FUNCTION_ENTRY();
-    return compute_backward(inout_real, inout_imag, inout_real, inout_imag, dependencies);
-  }
-
-  /**
-   * Computes out-of-place forward FFT, working on USM memory.
-   *
-   * @param in USM pointer to memory containing input data
-   * @param out USM pointer to memory containing output data
-   * @param dependencies events that must complete before the computation
-   * @return sycl::event associated with this computation
-   */
-  sycl::event compute_forward(const complex_type* in, complex_type* out,
-                              const std::vector<sycl::event>& dependencies = {}) {
-    PORTFFT_LOG_FUNCTION_ENTRY();
-    return dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::FORWARD, dependencies);
-  }
-
-  /**
-   * Computes out-of-place forward FFT, working on USM memory.
-   *
-   * @param in_real USM pointer to memory containing real part of the input data
-   * @param in_imag USM pointer to memory containing imaginary part of the input data
-   * @param out_real USM pointer to memory containing real part of the output data
-   * @param out_imag USM pointer to memory containing imaginary part of the output data
-   * @param dependencies events that must complete before the computation
-   * @return sycl::event associated with this computation
-   */
-  sycl::event compute_forward(const scalar_type* in_real, const scalar_type* in_imag, scalar_type* out_real,
-                              scalar_type* out_imag, const std::vector<sycl::event>& dependencies = {}) {
-    PORTFFT_LOG_FUNCTION_ENTRY();
-    return dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::FORWARD,
-                              dependencies);
-  }
-
-  /**
-   * Computes out-of-place forward FFT, working on USM memory.
-   *
-   * @param in USM pointer to memory containing input data
-   * @param out USM pointer to memory containing output data
-   * @param dependencies events that must complete before the computation
-   * @return sycl::event associated with this computation
-   */
-  sycl::event compute_forward(const Scalar* /*in*/, complex_type* /*out*/,
-                              const std::vector<sycl::event>& /*dependencies*/ = {}) {
-    PORTFFT_LOG_FUNCTION_ENTRY();
-    throw unsupported_configuration("Real to complex FFTs not yet implemented.");
-    return {};
-  }
-
-  /**
-   * Computes out-of-place backward FFT, working on USM memory.
-   *
-   * @param in USM pointer to memory containing input data
-   * @param out USM pointer to memory containing output data
-   * @param dependencies events that must complete before the computation
-   * @return sycl::event associated with this computation
-   */
-  sycl::event compute_backward(const complex_type* in, complex_type* out,
-                               const std::vector<sycl::event>& dependencies = {}) {
-    PORTFFT_LOG_FUNCTION_ENTRY();
-    return dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::BACKWARD,
-                              dependencies);
-  }
-
-  /**
-   * Computes out-of-place backward FFT, working on USM memory.
-   *
-   * @param in_real USM pointer to memory containing real part of the input data
-   * @param in_imag USM pointer to memory containing imaginary part of the input data
-   * @param out_real USM pointer to memory containing real part of the output data
-   * @param out_imag USM pointer to memory containing imaginary part of the output data
-   * @param dependencies events that must complete before the computation
-   * @return sycl::event associated with this computation
-   */
-  sycl::event compute_backward(const scalar_type* in_real, const scalar_type* in_imag, scalar_type* out_real,
-                               scalar_type* out_imag, const std::vector<sycl::event>& dependencies = {}) {
-    PORTFFT_LOG_FUNCTION_ENTRY();
-    return dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::BACKWARD,
-                              dependencies);
-  }
-
- private:
-  /**
-   * Dispatches to the implementation for the appropriate direction.
-   *
-   * @tparam TIn Type of the input buffer or USM pointer
-   * @tparam TOut Type of the output buffer or USM pointer
-   * @param in buffer or USM pointer to memory containing input data. Real part of input data if
-   * `descriptor.complex_storage` is split.
-   * @param out buffer or USM pointer to memory containing output data. Real part of output data if
-   * `descriptor.complex_storage` is split.
-   * @param in_imag buffer or USM pointer to memory containing imaginary part of the input data. Ignored if
-   * `descriptor.complex_storage` is interleaved.
-   * @param out_imag buffer or USM pointer to memory containing imaginary part of the output data. Ignored if
-   * `descriptor.complex_storage` is interleaved.
-   * @param used_storage how components of a complex value are stored - either split or interleaved
-   * @param compute_direction direction of compute, forward / backward
-   * @param dependencies events that must complete before the computation
-   * @return sycl::event
-   */
-  template <typename TIn, typename TOut>
-  sycl::event dispatch_direction(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
-                                 complex_storage used_storage, direction compute_direction,
-                                 const std::vector<sycl::event>& dependencies = {}) {
-    PORTFFT_LOG_FUNCTION_ENTRY();
-#ifndef PORTFFT_ENABLE_BUFFER_BUILDS
-    if constexpr (!std::is_pointer_v<TIn> || !std::is_pointer_v<TOut>) {
-      throw invalid_configuration("Buffer interface cannot be called when buffer builds are disabled.");
-    }
-#endif
-    if (used_storage != params.complex_storage) {
-      if (used_storage == complex_storage::SPLIT_COMPLEX) {
-        throw invalid_configuration(
-            "To use interface with split real and imaginary memory, descriptor.complex_storage must be set to "
-            "SPLIT_COMPLEX.");
-      }
-      throw invalid_configuration(
-          "To use interface with interleaved real and imaginary values, descriptor.complex_storage must be set to "
-          "INTERLEAVED_COMPLEX.");
-    }
-    if (compute_direction == direction::FORWARD) {
-      return dispatch_dimensions(in, out, in_imag, out_imag, dependencies, params.forward_strides,
-                                 params.backward_strides, params.forward_distance, params.backward_distance,
-                                 params.forward_offset, params.backward_offset, compute_direction);
-    }
-    return dispatch_dimensions(in, out, in_imag, out_imag, dependencies, params.backward_strides,
-                               params.forward_strides, params.backward_distance, params.forward_distance,
-                               params.backward_offset, params.forward_offset, compute_direction);
-  }
-
-  /**
-   * Dispatches to the implementation for the appropriate number of dimensions.
-   *
-   * @tparam TIn Type of the input buffer or USM pointer
-   * @tparam TOut Type of the output buffer or USM pointer
-   * @param in buffer or USM pointer to memory containing input data. Real part of input data if
-   * `descriptor.complex_storage` is split.
-   * @param out buffer or USM pointer to memory containing output data. Real part of output data if
-   * `descriptor.complex_storage` is split.
-   * @param in_imag buffer or USM pointer to memory containing imaginary part of the input data. Ignored if
-   * `descriptor.complex_storage` is interleaved.
-   * @param out_imag buffer or USM pointer to memory containing imaginary part of the output data. Ignored if
-   * `descriptor.complex_storage` is interleaved.
-   * @param dependencies events that must complete before the computation
-   * @param input_strides strides between input elements for each dimension of one FFT
-   * @param output_strides strides between output elements for each dimension of one FFT
-   * @param input_distance distance between the starts of input data for two consecutive FFTs
-   * @param output_distance distance between the starts of output data for two consecutive FFTs
-   * @param input_offset offset into input allocation where the data for FFTs start
-   * @param output_offset offset into output allocation where the data for FFTs start
-   * @param compute_direction direction of compute, forward / backward
-   * @return sycl::event
-   */
-  template <typename TIn, typename TOut>
-  sycl::event dispatch_dimensions(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
-                                  const std::vector<sycl::event>& dependencies,
-                                  const std::vector<std::size_t>& input_strides,
-                                  const std::vector<std::size_t>& output_strides, std::size_t input_distance,
-                                  std::size_t output_distance, std::size_t input_offset, std::size_t output_offset,
-                                  direction compute_direction) {
-    PORTFFT_LOG_FUNCTION_ENTRY();
-    using TOutConst = std::conditional_t<std::is_pointer_v<TOut>, const std::remove_pointer_t<TOut>*, const TOut>;
-    std::size_t n_dimensions = params.lengths.size();
-    std::size_t total_size = params.get_flattened_length();
-
-    const auto forward_layout = detail::get_layout(params, direction::FORWARD);
-    const auto backward_layout = detail::get_layout(params, direction::BACKWARD);
-
-    // currently multi-dimensional transforms are implemented just for default (PACKED) data layout
-    const bool multi_dim_supported =
-        forward_layout == detail::layout::PACKED && backward_layout == detail::layout::PACKED;
-    if (n_dimensions != 1 && !multi_dim_supported) {
-      throw internal_error("Only default layout is supported for multi-dimensional transforms.");
-    }
-
-    // product of sizes of all dimensions inner relative to the one we are currently working on
-    std::size_t inner_size = 1;
-    // product of sizes of all dimensions outer relative to the one we are currently working on
-    std::size_t outer_size = total_size / params.lengths.back();
-    std::size_t input_stride_0 = input_strides.back();
-    std::size_t output_stride_0 = output_strides.back();
-    // distances are currently used just in the first dimension - these changes are meant for that one
-    // TODO fix this to support non-default layouts
-    if (input_stride_0 < input_distance) {  // for example: batch interleaved input
-      input_distance = params.lengths.back();
-    }
-    if (output_stride_0 < output_distance) {  // for example: batch interleaved output
-      output_distance = params.lengths.back();
-    }
-
-    PORTFFT_LOG_TRACE("Dispatching the kernel for the last dimension");
-    sycl::event previous_event =
-        dispatch_kernel_1d(in, out, in_imag, out_imag, dependencies, params.number_of_transforms * outer_size,
-                           input_stride_0, output_stride_0, input_distance, output_distance, input_offset,
-                           output_offset, dimensions.back(), compute_direction);
-    if (n_dimensions == 1) {
-      return previous_event;
-    }
-    std::vector<sycl::event> previous_events{previous_event};
-    std::vector<sycl::event> next_events;
-    inner_size *= params.lengths.back();
-    for (std::size_t i = n_dimensions - 2; i != static_cast<std::size_t>(-1); i--) {
-      outer_size /= params.lengths[i];
-      // TODO do everything from the next loop in a single kernel once we support more than one distance in the
-      // kernels.
-      std::size_t stride_between_kernels = inner_size * params.lengths[i];
-      PORTFFT_LOG_TRACE("Dispatching the kernels for the dimension", i);
-      for (std::size_t j = 0; j < params.number_of_transforms * outer_size; j++) {
-        sycl::event e = dispatch_kernel_1d<TOutConst, TOut>(
-            out, out, out_imag, out_imag, previous_events, inner_size, inner_size, inner_size, 1, 1,
-            output_offset + j * stride_between_kernels, output_offset + j * stride_between_kernels, dimensions[i],
-            compute_direction);
-        next_events.push_back(e);
-      }
-      inner_size *= params.lengths[i];
-      std::swap(previous_events, next_events);
-      next_events.clear();
-    }
-    return queue.single_task(previous_events, []() {});  // just to get an event that depends on all previous ones
-  }
-
-  /**
-   * Dispatches the kernel with the first subgroup size that is supported by the device.
-   *
-   * @tparam TIn Type of the input buffer or USM pointer
-   * @tparam TOut Type of the output buffer or USM pointer
-   * @param in buffer or USM pointer to memory containing input data. Real part of input data if
-   * `descriptor.complex_storage` is split.
-   * @param out buffer or USM pointer to memory containing output data. Real part of output data if
-   * `descriptor.complex_storage` is split.
-   * @param in_imag buffer or USM pointer to memory containing imaginary part of the input data. Ignored if
-   * `descriptor.complex_storage` is interleaved.
-   * @param out_imag buffer or USM pointer to memory containing imaginary part of the output data. Ignored if
-   * `descriptor.complex_storage` is interleaved.
-   * @param dependencies events that must complete before the computation
-   * @param n_transforms number of FFT transforms to do in one call
-   * @param input_stride stride between input elements of one FFT
-   * @param output_stride stride between output elements of one FFT
-   * @param input_distance distance between the starts of input data for two consecutive FFTs
-   * @param output_distance distance between the starts of output data for two consecutive FFTs
-   * @param input_offset offset into input allocation where the data for FFTs start
-   * @param output_offset offset into output allocation where the data for FFTs start
-   * @param dimension_data data for the dimension this call will work on
-   * @param compute_direction direction of compute, forward / backward
-   * @return sycl::event
-   */
-  template <typename TIn, typename TOut>
-  sycl::event dispatch_kernel_1d(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
-                                 const std::vector<sycl::event>& dependencies, std::size_t n_transforms,
-                                 std::size_t input_stride, std::size_t output_stride, std::size_t input_distance,
-                                 std::size_t output_distance, std::size_t input_offset, std::size_t output_offset,
-                                 dimension_struct& dimension_data, direction compute_direction) {
-    PORTFFT_LOG_FUNCTION_ENTRY();
-    return dispatch_kernel_1d_helper<TIn, TOut, PORTFFT_SUBGROUP_SIZES>(
-        in, out, in_imag, out_imag, dependencies, n_transforms, input_stride, output_stride, input_distance,
-        output_distance, input_offset, output_offset, dimension_data, compute_direction);
-  }
-
-  /**
-   * Helper for dispatching the kernel with the first subgroup size that is supported by the device.
-   *
-   * @tparam TIn Type of the input buffer or USM pointer
-   * @tparam TOut Type of the output buffer or USM pointer
-   * @tparam SubgroupSize first subgroup size
-   * @tparam OtherSGSizes other subgroup sizes
-   * @param in buffer or USM pointer to memory containing input data. Real part of input data if
-   * `descriptor.complex_storage` is split.
-   * @param out buffer or USM pointer to memory containing output data. Real part of output data if
-   * `descriptor.complex_storage` is split.
-   * @param in_imag buffer or USM pointer to memory containing imaginary part of the input data. Ignored if
-   * `descriptor.complex_storage` is interleaved.
-   * @param out_imag buffer or USM pointer to memory containing imaginary part of the output data. Ignored if
-   * `descriptor.complex_storage` is interleaved.
-   * @param dependencies events that must complete before the computation
-   * @param n_transforms number of FFT transforms to do in one call
-   * @param input_stride stride between input elements of one FFT
-   * @param output_stride stride between output elements of one FFT
-   * @param input_distance distance between the starts of input data for two consecutive FFTs
-   * @param output_distance distance between the starts of output data for two consecutive FFTs
-   * @param input_offset offset into input allocation where the data for FFTs start
-   * @param output_offset offset into output allocation where the data for FFTs start
-   * @param dimension_data data for the dimension this call will work on
-   * @param compute_direction direction of compute, forward / backward
-   * @return sycl::event
-   */
-  template <typename TIn, typename TOut, Idx SubgroupSize, Idx... OtherSGSizes>
-  sycl::event dispatch_kernel_1d_helper(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
-                                        const std::vector<sycl::event>& dependencies, std::size_t n_transforms,
-                                        std::size_t input_stride, std::size_t output_stride, std::size_t input_distance,
-                                        std::size_t output_distance, std::size_t input_offset,
-                                        std::size_t output_offset, dimension_struct& dimension_data,
-                                        direction compute_direction) {
-    PORTFFT_LOG_FUNCTION_ENTRY();
-    if (SubgroupSize == dimension_data.used_sg_size) {
-      const bool input_packed = input_distance == dimension_data.length && input_stride == 1;
-      const bool output_packed = output_distance == dimension_data.length && output_stride == 1;
-      const bool input_batch_interleaved = input_distance == 1 && input_stride == n_transforms;
-      const bool output_batch_interleaved = output_distance == 1 && output_stride == n_transforms;
-      for (kernel_data_struct kernel_data : dimension_data.forward_kernels) {
-        std::size_t minimum_local_mem_required;
-        if (input_batch_interleaved) {
-          minimum_local_mem_required = num_scalars_in_local_mem<detail::layout::BATCH_INTERLEAVED>(
-                                           kernel_data.level, kernel_data.length, SubgroupSize, kernel_data.factors,
-                                           kernel_data.num_sgs_per_wg) *
-                                       sizeof(Scalar);
-          PORTFFT_LOG_TRACE("Local mem required:", minimum_local_mem_required, "B. Available: ", local_memory_size,
-                            "B.");
-          if (static_cast<Idx>(minimum_local_mem_required) > local_memory_size) {
-            throw out_of_local_memory_error(
-                "Insufficient amount of local memory available: " + std::to_string(local_memory_size) +
-                "B. Required: " + std::to_string(minimum_local_mem_required) + "B.");
Required: " + std::to_string(minimum_local_mem_required) + "B."); - } - } - } - if (input_packed && output_packed) { - return run_kernel( - in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, dimension_data, - compute_direction); - } - if (input_batch_interleaved && output_packed && in != out) { - return run_kernel( - in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, dimension_data, - compute_direction); - } - if (input_packed && output_batch_interleaved && in != out) { - return run_kernel( - in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, dimension_data, - compute_direction); - } - if (input_batch_interleaved && output_batch_interleaved) { - return run_kernel( - in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, dimension_data, - compute_direction); - } - throw unsupported_configuration("Only PACKED or BATCH_INTERLEAVED transforms are supported"); - } - if constexpr (sizeof...(OtherSGSizes) == 0) { - throw invalid_configuration("None of the compiled subgroup sizes are supported by the device!"); - } else { - return dispatch_kernel_1d_helper( - in, out, in_imag, out_imag, dependencies, n_transforms, input_stride, output_stride, input_distance, - output_distance, input_offset, output_offset, dimension_data, compute_direction); - } - } - - /** - * Struct for dispatching `run_kernel()` call. - * - * @tparam LayoutIn Input Layout - * @tparam LayoutOut Output Layout - * @tparam SubgroupSize size of the subgroup - * @tparam TIn Type of the input USM pointer or buffer - * @tparam TOut Type of the output USM pointer or buffer - */ - template - struct run_kernel_struct { - // Dummy parameter is needed as only partial specializations are allowed without specializing the containing class - template - struct inner { - static sycl::event execute(committed_descriptor& desc, const TIn& in, TOut& out, const TIn& in_imag, - TOut& out_imag, const std::vector& dependencies, std::size_t n_transforms, - std::size_t forward_offset, std::size_t backward_offset, - dimension_struct& dimension_data, direction compute_direction); - }; - }; - - /** - * Common interface to run the kernel called by compute_forward and compute_backward - * - * @tparam LayoutIn Input Layout - * @tparam LayoutOut Output Layout - * @tparam SubgroupSize size of the subgroup - * @tparam TIn Type of the input USM pointer or buffer - * @tparam TOut Type of the output USM pointer or buffer - * @param in buffer or USM pointer to memory containing input data. Real part of input data if - * `descriptor.complex_storage` is split. - * @param out buffer or USM pointer to memory containing output data. Real part of input data if - * `descriptor.complex_storage` is split. - * @param in_imag buffer or USM pointer to memory containing imaginary part of the input data. Ignored if - * `descriptor.complex_storage` is interleaved. - * @param out_imag buffer or USM pointer to memory containing imaginary part of the output data. Ignored if - * `descriptor.complex_storage` is interleaved. 
-   * @param dependencies events that must complete before the computation
-   * @param n_transforms number of FFT transforms to do in one call
-   * @param input_offset offset into input allocation where the data for FFTs start
-   * @param output_offset offset into output allocation where the data for FFTs start
-   * @param dimension_data data for the dimension this call will work on
-   * @param compute_direction direction of FFT, forward / backward
-   * @return sycl::event
-   */
-  template <detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize, typename TIn, typename TOut>
-  sycl::event run_kernel(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
-                         const std::vector<sycl::event>& dependencies, std::size_t n_transforms,
-                         std::size_t input_offset, std::size_t output_offset, dimension_struct& dimension_data,
-                         direction compute_direction) {
-    PORTFFT_LOG_FUNCTION_ENTRY();
-    // mixing const and non-const inputs leads to hard-to-debug linking errors, as both use the same kernel name, but
-    // are called from different template instantiations.
-    static_assert(!std::is_pointer_v<TIn> || std::is_const_v<std::remove_pointer_t<TIn>>,
-                  "We do not differentiate kernel names between kernels with const and non-const USM inputs, so all "
-                  "should be const");
-    // kernel names currently assume both are the same. Mixing them without adding TOut to kernel names would lead to
-    // hard-to-debug linking errors
-    static_assert(std::is_pointer_v<TIn> == std::is_pointer_v<TOut>,
-                  "Both input and output to the kernels should be the same - either buffers or USM");
-    using TInReinterpret = decltype(detail::reinterpret<const Scalar>(in));
-    using TOutReinterpret = decltype(detail::reinterpret<Scalar>(out));
-    std::size_t vec_multiplier = params.complex_storage == complex_storage::INTERLEAVED_COMPLEX ? 2 : 1;
-    return dispatch<run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TInReinterpret, TOutReinterpret>>(
-        dimension_data.level, detail::reinterpret<const Scalar>(in), detail::reinterpret<Scalar>(out),
-        detail::reinterpret<const Scalar>(in_imag), detail::reinterpret<Scalar>(out_imag), dependencies,
-        static_cast<IdxGlobal>(n_transforms), static_cast<IdxGlobal>(vec_multiplier * input_offset),
-        static_cast<IdxGlobal>(vec_multiplier * output_offset), dimension_data, compute_direction);
-  }
-};
+namespace portfft {
 
 /**
  * A descriptor containing FFT problem parameters.
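
Reviewer note: the compute_* overloads deleted from descriptor.hpp above are re-exposed with identical signatures by the new public committed_descriptor wrapper added in committed_descriptor.hpp, so user code should keep compiling. As a quick sanity check of that public surface, here is a minimal usage sketch. It assumes the usual portFFT flow of committing a descriptor to a SYCL queue and the umbrella header portfft/portfft.hpp; the FFT length, queue selection, and data initialisation are illustrative only.

    #include <complex>

    #include <sycl/sycl.hpp>

    #include <portfft/portfft.hpp>

    int main() {
      sycl::queue queue;

      // Describe one 1D complex-to-complex FFT of length 1024 (illustrative)
      // and commit it; commit() returns the committed_descriptor wrapper.
      portfft::descriptor<float, portfft::domain::COMPLEX> desc({1024});
      auto committed = desc.commit(queue);

      // Default INTERLEAVED_COMPLEX storage: a single std::complex<float> USM array.
      std::complex<float>* inout = sycl::malloc_device<std::complex<float>>(1024, queue);
      // ... fill `inout` with input data ...

      // In-place forward FFT; the returned sycl::event tracks completion.
      committed.compute_forward(inout).wait();

      sycl::free(inout, queue);
      return 0;
    }
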
diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp
index 2c476912..f0d009d3 100644
--- a/src/portfft/dispatcher/global_dispatcher.hpp
+++ b/src/portfft/dispatcher/global_dispatcher.hpp
@@ -104,12 +104,10 @@ inline IdxGlobal increment_twiddle_offset(detail::level level, Idx factor_size)
   return 0;
 }
 
-}  // namespace detail
-
 template <typename Scalar, domain Domain>
 template <typename Dummy>
-struct committed_descriptor<Scalar, Domain>::calculate_twiddles_struct::inner<detail::level::GLOBAL, Dummy> {
-  static Scalar* execute(committed_descriptor& desc, dimension_struct& /*dimension_data*/,
+struct committed_descriptor_impl<Scalar, Domain>::calculate_twiddles_struct::inner<detail::level::GLOBAL, Dummy> {
+  static Scalar* execute(committed_descriptor_impl& desc, dimension_struct& /*dimension_data*/,
                          std::vector<kernel_data_struct>& kernels) {
     PORTFFT_LOG_FUNCTION_ENTRY();
     std::vector<IdxGlobal> factors_idx_global;
@@ -256,8 +254,8 @@ struct committed_descriptor<Scalar, Domain>::calculate_twiddles_struct::inner<d
 
 template <typename Scalar, domain Domain>
 template <typename Dummy>
-struct committed_descriptor<Scalar, Domain>::set_spec_constants_struct::inner<detail::level::GLOBAL, Dummy> {
-  static void execute(committed_descriptor& /*desc*/, sycl::kernel_bundle<sycl::bundle_state::input>& in_bundle,
+struct committed_descriptor_impl<Scalar, Domain>::set_spec_constants_struct::inner<detail::level::GLOBAL, Dummy> {
+  static void execute(committed_descriptor_impl& /*desc*/, sycl::kernel_bundle<sycl::bundle_state::input>& in_bundle,
                       std::size_t length, const std::vector<Idx>& factors, detail::level level, Idx factor_num,
                       Idx num_factors) {
     PORTFFT_LOG_FUNCTION_ENTRY();
@@ -282,9 +280,9 @@ struct committed_descriptor<Scalar, Domain>::set_spec_constants_struct::inner<d
 
 template <typename Scalar, domain Domain>
 template <detail::layout LayoutIn, typename Dummy>
-struct committed_descriptor<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::GLOBAL, LayoutIn,
-                                                                                    Dummy> {
-  static std::size_t execute(committed_descriptor& /*desc*/, std::size_t /*length*/, Idx /*used_sg_size*/,
+struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::GLOBAL,
+                                                                                         LayoutIn, Dummy> {
+  static std::size_t execute(committed_descriptor_impl& /*desc*/, std::size_t /*length*/, Idx /*used_sg_size*/,
                              const std::vector<Idx>& /*factors*/, Idx& /*num_sgs_per_wg*/) {
     PORTFFT_LOG_FUNCTION_ENTRY();
     // No work required as all work is done in calculate_twiddles.
@@ -295,10 +293,10 @@ struct committed_descriptor<Scalar, Domain>::num_scalars_in_local_mem_struct::in
 
 template <typename Scalar, domain Domain>
 template <detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize, typename TIn, typename TOut>
 template <typename Dummy>
-struct committed_descriptor<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
-                                                               TOut>::inner<detail::level::GLOBAL, Dummy> {
-  static sycl::event execute(committed_descriptor& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
-                             const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
+struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
+                                                                    TOut>::inner<detail::level::GLOBAL, Dummy> {
+  static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag,
+                             TOut& out_imag, const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
                              IdxGlobal input_offset, IdxGlobal output_offset, dimension_struct& dimension_data,
                              direction compute_direction) {
     PORTFFT_LOG_FUNCTION_ENTRY();
@@ -404,6 +402,7 @@ struct committed_descriptor<Scalar, Domain>::run_kernel_struct<LayoutIn, Layout
   }
 };
 
+}  // namespace detail
 }  // namespace portfft
 
 #endif  // PORTFFT_DISPATCHER_GLOBAL_DISPATCHER_HPP

diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp
--- a/src/portfft/dispatcher/subgroup_dispatcher.hpp
+++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp
 
 template <typename Scalar, domain Domain>
 template <typename Dummy>
-struct committed_descriptor<Scalar, Domain>::calculate_twiddles_struct::inner<detail::level::SUBGROUP, Dummy> {
-  static Scalar* execute(committed_descriptor& desc, dimension_struct& /*dimension_data*/,
+struct committed_descriptor_impl<Scalar, Domain>::calculate_twiddles_struct::inner<detail::level::SUBGROUP, Dummy> {
+  static Scalar* execute(committed_descriptor_impl& desc, dimension_struct& /*dimension_data*/,
                          std::vector<kernel_data_struct>& kernels) {
     PORTFFT_LOG_FUNCTION_ENTRY();
     const auto& kernel_data = kernels.at(0);
@@ -617,10 +616,10 @@ struct committed_descriptor<Scalar, Domain>::calculate_twiddles_struct::inner<d
 
 template <typename Scalar, domain Domain>
 template <detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize, typename TIn, typename TOut>
 template <typename Dummy>
-struct committed_descriptor<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
-                                                               TOut>::inner<detail::level::SUBGROUP, Dummy> {
-  static sycl::event execute(committed_descriptor& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
-                             const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
+struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
+                                                                    TOut>::inner<detail::level::SUBGROUP, Dummy> {
+  static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag,
+                             TOut& out_imag, const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
                              IdxGlobal input_offset, IdxGlobal output_offset, dimension_struct& dimension_data,
                              direction compute_direction) {
     PORTFFT_LOG_FUNCTION_ENTRY();
@@ -676,8 +675,8 @@ struct committed_descriptor<Scalar, Domain>::run_kernel_struct<LayoutIn, Layout
 
 template <typename Scalar, domain Domain>
 template <typename Dummy>
-struct committed_descriptor<Scalar, Domain>::set_spec_constants_struct::inner<detail::level::SUBGROUP, Dummy> {
-  static void execute(committed_descriptor& /*desc*/, sycl::kernel_bundle<sycl::bundle_state::input>& in_bundle,
+struct committed_descriptor_impl<Scalar, Domain>::set_spec_constants_struct::inner<detail::level::SUBGROUP, Dummy> {
+  static void execute(committed_descriptor_impl& /*desc*/, sycl::kernel_bundle<sycl::bundle_state::input>& in_bundle,
                       std::size_t /*length*/, const std::vector<Idx>& factors, detail::level /*level*/,
                       Idx /*factor_num*/, Idx /*num_factors*/) {
     PORTFFT_LOG_FUNCTION_ENTRY();
@@ -690,9 +689,9 @@ struct committed_descriptor<Scalar, Domain>::set_spec_constants_struct::inner<d
 
 template <typename Scalar, domain Domain>
 template <detail::layout LayoutIn, typename Dummy>
-struct committed_descriptor<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::SUBGROUP, LayoutIn,
-                                                                                    Dummy> {
-  static std::size_t execute(committed_descriptor& desc, std::size_t length, Idx used_sg_size,
+struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::SUBGROUP,
+                                                                                         LayoutIn, Dummy> {
+  static std::size_t execute(committed_descriptor_impl& desc, std::size_t length, Idx used_sg_size,
                              const std::vector<Idx>& factors, Idx& num_sgs_per_wg) {
     PORTFFT_LOG_FUNCTION_ENTRY();
     Idx dft_length = static_cast<Idx>(length);
@@ -718,6 +717,7 @@ struct committed_descriptor<Scalar, Domain>::num_scalars_in_local_mem_struct::in
   }
 };
 
+}  // namespace detail
 }  // namespace portfft
 
 #endif  // PORTFFT_DISPATCHER_SUBGROUP_DISPATCHER_HPP

diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp
index dc047550..a0a65dd6 100644
--- a/src/portfft/dispatcher/workgroup_dispatcher.hpp
+++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp
@@ -272,15 +272,14 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima
   }
   global_data.log_message_global(__func__, "exited");
 }
-}  // namespace detail
 
 template <typename Scalar, domain Domain>
 template <detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize, typename TIn, typename TOut>
 template <typename Dummy>
-struct committed_descriptor<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
-                                                               TOut>::inner<detail::level::WORKGROUP, Dummy> {
-  static sycl::event execute(committed_descriptor& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
-                             const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
+struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
+                                                                    TOut>::inner<detail::level::WORKGROUP, Dummy> {
+  static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag,
+                             TOut& out_imag, const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
                              IdxGlobal input_offset, IdxGlobal output_offset, dimension_struct& dimension_data,
                              direction compute_direction) {
     PORTFFT_LOG_FUNCTION_ENTRY();
@@ -343,8 +342,8 @@ struct committed_descriptor<Scalar, Domain>::run_kernel_struct<LayoutIn, Layout
 
 template <typename Scalar, domain Domain>
 template <typename Dummy>
-struct committed_descriptor<Scalar, Domain>::set_spec_constants_struct::inner<detail::level::WORKGROUP, Dummy> {
-  static void execute(committed_descriptor& /*desc*/, sycl::kernel_bundle<sycl::bundle_state::input>& in_bundle,
+struct committed_descriptor_impl<Scalar, Domain>::set_spec_constants_struct::inner<detail::level::WORKGROUP, Dummy> {
+  static void execute(committed_descriptor_impl& /*desc*/, sycl::kernel_bundle<sycl::bundle_state::input>& in_bundle,
                       std::size_t length, const std::vector<Idx>& /*factors*/, detail::level /*level*/,
                       Idx /*factor_num*/, Idx /*num_factors*/) {
     PORTFFT_LOG_FUNCTION_ENTRY();
@@ -356,9 +355,9 @@ struct committed_descriptor<Scalar, Domain>::set_spec_constants_struct::inner<d
 
 template <typename Scalar, domain Domain>
 template <detail::layout LayoutIn, typename Dummy>
-struct committed_descriptor<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::WORKGROUP, LayoutIn,
-                                                                                    Dummy> {
-  static std::size_t execute(committed_descriptor& /*desc*/, std::size_t length, Idx used_sg_size,
+struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::WORKGROUP,
+                                                                                         LayoutIn, Dummy> {
+  static std::size_t execute(committed_descriptor_impl& /*desc*/, std::size_t length, Idx used_sg_size,
                              const std::vector<Idx>& factors, Idx& /*num_sgs_per_wg*/) {
     PORTFFT_LOG_FUNCTION_ENTRY();
     std::size_t n = static_cast<std::size_t>(factors[0]) * static_cast<std::size_t>(factors[1]);
@@ -374,8 +373,8 @@ struct committed_descriptor<Scalar, Domain>::num_scalars_in_local_mem_struct::in
 
 template <typename Scalar, domain Domain>
 template <typename Dummy>
-struct committed_descriptor<Scalar, Domain>::calculate_twiddles_struct::inner<detail::level::WORKGROUP, Dummy> {
-  static Scalar* execute(committed_descriptor& desc, dimension_struct& /*dimension_data*/,
+struct committed_descriptor_impl<Scalar, Domain>::calculate_twiddles_struct::inner<detail::level::WORKGROUP, Dummy> {
+  static Scalar* execute(committed_descriptor_impl& desc, dimension_struct& /*dimension_data*/,
                          std::vector<kernel_data_struct>& kernels) {
     PORTFFT_LOG_FUNCTION_ENTRY();
     const auto& kernel_data = kernels.at(0);
@@ -435,6 +434,7 @@ struct committed_descriptor<Scalar, Domain>::calculate_twiddles_struct::inner<d
   }
 };
 
+}  // namespace detail
 }  // namespace portfft
 
 #endif  // PORTFFT_DISPATCHER_WORKGROUP_DISPATCHER_HPP

diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp
--- a/src/portfft/dispatcher/workitem_dispatcher.hpp
+++ b/src/portfft/dispatcher/workitem_dispatcher.hpp
 
 template <typename Scalar, domain Domain>
 template <detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize, typename TIn, typename TOut>
 template <typename Dummy>
-struct committed_descriptor<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
-                                                               TOut>::inner<detail::level::WORKITEM, Dummy> {
-  static sycl::event execute(committed_descriptor& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag,
-                             const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
+struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
+                                                                    TOut>::inner<detail::level::WORKITEM, Dummy> {
+  static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag,
+                             TOut& out_imag, const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
                              IdxGlobal input_offset, IdxGlobal output_offset, dimension_struct& dimension_data,
                              direction compute_direction) {
     PORTFFT_LOG_FUNCTION_ENTRY();
@@ -333,8 +332,8 @@ struct committed_descriptor<Scalar, Domain>::run_kernel_struct<LayoutIn, Layout
 
 template <typename Scalar, domain Domain>
 template <typename Dummy>
-struct committed_descriptor<Scalar, Domain>::set_spec_constants_struct::inner<detail::level::WORKITEM, Dummy> {
-  static void execute(committed_descriptor& /*desc*/, sycl::kernel_bundle<sycl::bundle_state::input>& in_bundle,
+struct committed_descriptor_impl<Scalar, Domain>::set_spec_constants_struct::inner<detail::level::WORKITEM, Dummy> {
+  static void execute(committed_descriptor_impl& /*desc*/, sycl::kernel_bundle<sycl::bundle_state::input>& in_bundle,
                       std::size_t length, const std::vector<Idx>& /*factors*/, detail::level /*level*/,
                       Idx /*factor_num*/, Idx /*num_factors*/) {
     PORTFFT_LOG_FUNCTION_ENTRY();
@@ -346,9 +345,9 @@ struct committed_descriptor<Scalar, Domain>::set_spec_constants_struct::inner<d
 
 template <typename Scalar, domain Domain>
 template <detail::layout LayoutIn, typename Dummy>
-struct committed_descriptor<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::WORKITEM, LayoutIn,
-                                                                                    Dummy> {
-  static std::size_t execute(committed_descriptor& desc, std::size_t length, Idx used_sg_size,
+struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::WORKITEM,
+                                                                                         LayoutIn, Dummy> {
+  static std::size_t execute(committed_descriptor_impl& desc, std::size_t length, Idx used_sg_size,
                              const std::vector<Idx>& /*factors*/, Idx& num_sgs_per_wg) {
     PORTFFT_LOG_FUNCTION_ENTRY();
     Idx num_scalars_per_sg = detail::pad_local(2 * static_cast<Idx>(length) * used_sg_size, 1);
@@ -361,14 +360,15 @@ struct committed_descriptor<Scalar, Domain>::num_scalars_in_local_mem_struct::in
 
 template <typename Scalar, domain Domain>
 template <typename Dummy>
-struct committed_descriptor<Scalar, Domain>::calculate_twiddles_struct::inner<detail::level::WORKITEM, Dummy> {
-  static Scalar* execute(committed_descriptor& /*desc*/, dimension_struct& /*dimension_data*/,
+struct committed_descriptor_impl<Scalar, Domain>::calculate_twiddles_struct::inner<detail::level::WORKITEM, Dummy> {
+  static Scalar* execute(committed_descriptor_impl& /*desc*/, dimension_struct& /*dimension_data*/,
                          std::vector<kernel_data_struct>& /*kernels*/) {
     PORTFFT_LOG_FUNCTION_ENTRY();
     return nullptr;
   }
 };
 
+}  // namespace detail
 }  // namespace portfft
 
 #endif  // PORTFFT_DISPATCHER_WORKITEM_DISPATCHER_HPP
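
Closing note on the pattern these hunks rename over and over: calculate_twiddles_struct, set_spec_constants_struct, num_scalars_in_local_mem_struct and run_kernel_struct all dispatch through a nested inner struct carrying a trailing Dummy type parameter. As the comment in the deleted descriptor.hpp code explains, a member class template cannot be explicitly (fully) specialized for each detail::level without specializing the enclosing class template, but it can be partially specialized; the Dummy parameter is what keeps each per-level implementation a partial specialization. A stripped-down, self-contained sketch of the idiom (the names below are illustrative, not portFFT's):

    #include <iostream>

    enum class level { WORKITEM, SUBGROUP, WORKGROUP, GLOBAL };

    template <typename Scalar>
    struct committed {
      // Primary member class template. Fully specializing it per level inside
      // `committed` would be ill-formed, so a Dummy parameter is added...
      template <level Lev, typename Dummy = void>
      struct inner {
        static void execute() { std::cout << "generic fallback\n"; }
      };

      // ...making each per-level implementation a *partial* specialization,
      // which is allowed in class scope.
      template <typename Dummy>
      struct inner<level::SUBGROUP, Dummy> {
        static void execute() { std::cout << "subgroup implementation\n"; }
      };
    };

    int main() {
      committed<float>::inner<level::WORKITEM>::execute();  // prints "generic fallback"
      committed<float>::inner<level::SUBGROUP>::execute();  // prints "subgroup implementation"
    }

Each dispatcher header patched above then only has to provide, per level, the matching partial specialization of each inner — which is why this part of the diff is almost entirely a mechanical rename from committed_descriptor to committed_descriptor_impl.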