Skip to content

Commit 14892b8

Browse files
authored
Refactor descriptor (#137)
* logging * use env variables * added warnings * format * addressed review comments * addressed more comments * format * split commited_descriptor class * one more trace * split files * formatq * address Hugh's comments * format
1 parent 9732dd1 commit 14892b8

File tree

8 files changed

+1671
-1584
lines changed

8 files changed

+1671
-1584
lines changed

src/portfft/committed_descriptor.hpp

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
/***************************************************************************
2+
*
3+
* Copyright (C) Codeplay Software Ltd.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
* Codeplay's portFFT
18+
*
19+
**************************************************************************/
20+
21+
#ifndef PORTFFT_COMMITTED_DESCRIPTOR_HPP
22+
#define PORTFFT_COMMITTED_DESCRIPTOR_HPP
23+
24+
#include <sycl/sycl.hpp>
25+
26+
#include <complex>
27+
#include <vector>
28+
29+
#include "enums.hpp"
30+
31+
#include "committed_descriptor_impl.hpp"
32+
33+
namespace portfft {
34+
35+
template <typename Scalar, domain Domain>
36+
class committed_descriptor : private detail::committed_descriptor_impl<Scalar, Domain> {
37+
public:
38+
/**
39+
* Alias for `Scalar`.
40+
*/
41+
using scalar_type = Scalar;
42+
43+
/**
44+
* std::complex with `Scalar` scalar.
45+
*/
46+
using complex_type = std::complex<Scalar>;
47+
48+
// Use base class constructor
49+
using detail::committed_descriptor_impl<Scalar, Domain>::committed_descriptor_impl;
50+
// Use base class function without this->
51+
using detail::committed_descriptor_impl<Scalar, Domain>::dispatch_direction;
52+
53+
/**
54+
* Computes in-place forward FFT, working on a buffer.
55+
*
56+
* @param inout buffer containing input and output data
57+
*/
58+
void compute_forward(sycl::buffer<complex_type, 1>& inout) {
59+
PORTFFT_LOG_FUNCTION_ENTRY();
60+
// For now we can just call out-of-place implementation.
61+
// This might need to be changed once we implement support for large sizes that work in global memory.
62+
compute_forward(inout, inout);
63+
}
64+
65+
/**
66+
* Computes in-place forward FFT, working on buffers.
67+
*
68+
* @param inout_real buffer containing real part of the input and output data
69+
* @param inout_imag buffer containing imaginary part of the input and output data
70+
*/
71+
void compute_forward(sycl::buffer<scalar_type, 1>& inout_real, sycl::buffer<scalar_type, 1>& inout_imag) {
72+
PORTFFT_LOG_FUNCTION_ENTRY();
73+
// For now we can just call out-of-place implementation.
74+
// This might need to be changed once we implement support for large sizes that work in global memory.
75+
compute_forward(inout_real, inout_imag, inout_real, inout_imag);
76+
}
77+
78+
/**
79+
* Computes in-place backward FFT, working on a buffer.
80+
*
81+
* @param inout buffer containing input and output data
82+
*/
83+
void compute_backward(sycl::buffer<complex_type, 1>& inout) {
84+
PORTFFT_LOG_FUNCTION_ENTRY();
85+
// For now we can just call out-of-place implementation.
86+
// This might need to be changed once we implement support for large sizes that work in global memory.
87+
compute_backward(inout, inout);
88+
}
89+
90+
/**
91+
* Computes in-place backward FFT, working on buffers.
92+
*
93+
* @param inout_real buffer containing real part of the input and output data
94+
* @param inout_imag buffer containing imaginary part of the input and output data
95+
*/
96+
void compute_backward(sycl::buffer<scalar_type, 1>& inout_real, sycl::buffer<scalar_type, 1>& inout_imag) {
97+
PORTFFT_LOG_FUNCTION_ENTRY();
98+
// For now we can just call out-of-place implementation.
99+
// This might need to be changed once we implement support for large sizes that work in global memory.
100+
compute_backward(inout_real, inout_imag, inout_real, inout_imag);
101+
}
102+
103+
/**
104+
* Computes out-of-place forward FFT, working on buffers.
105+
*
106+
* @param in buffer containing input data
107+
* @param out buffer containing output data
108+
*/
109+
void compute_forward(const sycl::buffer<complex_type, 1>& in, sycl::buffer<complex_type, 1>& out) {
110+
PORTFFT_LOG_FUNCTION_ENTRY();
111+
dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::FORWARD);
112+
}
113+
114+
/**
115+
* Computes out-of-place forward FFT, working on buffers.
116+
*
117+
* @param in_real buffer containing real part of the input data
118+
* @param in_imag buffer containing imaginary part of the input data
119+
* @param out_real buffer containing real part of the output data
120+
* @param out_imag buffer containing imaginary part of the output data
121+
*/
122+
void compute_forward(const sycl::buffer<scalar_type, 1>& in_real, const sycl::buffer<scalar_type, 1>& in_imag,
123+
sycl::buffer<scalar_type, 1>& out_real, sycl::buffer<scalar_type, 1>& out_imag) {
124+
PORTFFT_LOG_FUNCTION_ENTRY();
125+
dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::FORWARD);
126+
}
127+
128+
/**
129+
* Computes out-of-place forward FFT, working on buffers.
130+
*
131+
* @param in buffer containing input data
132+
* @param out buffer containing output data
133+
*/
134+
void compute_forward(const sycl::buffer<Scalar, 1>& /*in*/, sycl::buffer<complex_type, 1>& /*out*/) {
135+
PORTFFT_LOG_FUNCTION_ENTRY();
136+
throw unsupported_configuration("Real to complex FFTs not yet implemented.");
137+
}
138+
139+
/**
140+
* Compute out of place backward FFT, working on buffers
141+
*
142+
* @param in buffer containing input data
143+
* @param out buffer containing output data
144+
*/
145+
void compute_backward(const sycl::buffer<complex_type, 1>& in, sycl::buffer<complex_type, 1>& out) {
146+
PORTFFT_LOG_FUNCTION_ENTRY();
147+
dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::BACKWARD);
148+
}
149+
150+
/**
151+
* Compute out of place backward FFT, working on buffers
152+
*
153+
* @param in_real buffer containing real part of the input data
154+
* @param in_imag buffer containing imaginary part of the input data
155+
* @param out_real buffer containing real part of the output data
156+
* @param out_imag buffer containing imaginary part of the output data
157+
*/
158+
void compute_backward(const sycl::buffer<scalar_type, 1>& in_real, const sycl::buffer<scalar_type, 1>& in_imag,
159+
sycl::buffer<scalar_type, 1>& out_real, sycl::buffer<scalar_type, 1>& out_imag) {
160+
PORTFFT_LOG_FUNCTION_ENTRY();
161+
dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::BACKWARD);
162+
}
163+
164+
/**
165+
* Computes in-place forward FFT, working on USM memory.
166+
*
167+
* @param inout USM pointer to memory containing input and output data
168+
* @param dependencies events that must complete before the computation
169+
* @return sycl::event associated with this computation
170+
*/
171+
sycl::event compute_forward(complex_type* inout, const std::vector<sycl::event>& dependencies = {}) {
172+
PORTFFT_LOG_FUNCTION_ENTRY();
173+
// For now we can just call out-of-place implementation.
174+
// This might need to be changed once we implement support for large sizes that work in global memory.
175+
return compute_forward(inout, inout, dependencies);
176+
}
177+
178+
/**
179+
* Computes in-place forward FFT, working on USM memory.
180+
*
181+
* @param inout_real USM pointer to memory containing real part of the input and output data
182+
* @param inout_imag USM pointer to memory containing imaginary part of the input and output data
183+
* @param dependencies events that must complete before the computation
184+
* @return sycl::event associated with this computation
185+
*/
186+
sycl::event compute_forward(scalar_type* inout_real, scalar_type* inout_imag,
187+
const std::vector<sycl::event>& dependencies = {}) {
188+
PORTFFT_LOG_FUNCTION_ENTRY();
189+
// For now we can just call out-of-place implementation.
190+
// This might need to be changed once we implement support for large sizes that work in global memory.
191+
return compute_forward(inout_real, inout_imag, inout_real, inout_imag, dependencies);
192+
}
193+
194+
/**
195+
* Computes in-place forward FFT, working on USM memory.
196+
*
197+
* @param inout USM pointer to memory containing input and output data
198+
* @param dependencies events that must complete before the computation
199+
* @return sycl::event associated with this computation
200+
*/
201+
sycl::event compute_forward(Scalar* inout, const std::vector<sycl::event>& dependencies = {}) {
202+
PORTFFT_LOG_FUNCTION_ENTRY();
203+
// For now we can just call out-of-place implementation.
204+
// This might need to be changed once we implement support for large sizes that work in global memory.
205+
return compute_forward(inout, reinterpret_cast<complex_type*>(inout), dependencies);
206+
}
207+
208+
/**
209+
* Computes in-place backward FFT, working on USM memory.
210+
*
211+
* @param inout USM pointer to memory containing input and output data
212+
* @param dependencies events that must complete before the computation
213+
* @return sycl::event associated with this computation
214+
*/
215+
sycl::event compute_backward(complex_type* inout, const std::vector<sycl::event>& dependencies = {}) {
216+
PORTFFT_LOG_FUNCTION_ENTRY();
217+
return compute_backward(inout, inout, dependencies);
218+
}
219+
220+
/**
221+
* Computes in-place backward FFT, working on USM memory.
222+
*
223+
* @param inout_real USM pointer to memory containing real part of the input and output data
224+
* @param inout_imag USM pointer to memory containing imaginary part of the input and output data
225+
* @param dependencies events that must complete before the computation
226+
* @return sycl::event associated with this computation
227+
*/
228+
sycl::event compute_backward(scalar_type* inout_real, scalar_type* inout_imag,
229+
const std::vector<sycl::event>& dependencies = {}) {
230+
PORTFFT_LOG_FUNCTION_ENTRY();
231+
return compute_backward(inout_real, inout_imag, inout_real, inout_imag, dependencies);
232+
}
233+
234+
/**
235+
* Computes out-of-place forward FFT, working on USM memory.
236+
*
237+
* @param in USM pointer to memory containing input data
238+
* @param out USM pointer to memory containing output data
239+
* @param dependencies events that must complete before the computation
240+
* @return sycl::event associated with this computation
241+
*/
242+
sycl::event compute_forward(const complex_type* in, complex_type* out,
243+
const std::vector<sycl::event>& dependencies = {}) {
244+
PORTFFT_LOG_FUNCTION_ENTRY();
245+
return dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::FORWARD, dependencies);
246+
}
247+
248+
/**
249+
* Computes out-of-place forward FFT, working on USM memory.
250+
*
251+
* @param in_real USM pointer to memory containing real part of the input data
252+
* @param in_imag USM pointer to memory containing imaginary part of the input data
253+
* @param out_real USM pointer to memory containing real part of the output data
254+
* @param out_imag USM pointer to memory containing imaginary part of the output data
255+
* @param dependencies events that must complete before the computation
256+
* @return sycl::event associated with this computation
257+
*/
258+
sycl::event compute_forward(const scalar_type* in_real, const scalar_type* in_imag, scalar_type* out_real,
259+
scalar_type* out_imag, const std::vector<sycl::event>& dependencies = {}) {
260+
PORTFFT_LOG_FUNCTION_ENTRY();
261+
return dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::FORWARD,
262+
dependencies);
263+
}
264+
265+
/**
266+
* Computes out-of-place forward FFT, working on USM memory.
267+
*
268+
* @param in USM pointer to memory containing input data
269+
* @param out USM pointer to memory containing output data
270+
* @param dependencies events that must complete before the computation
271+
* @return sycl::event associated with this computation
272+
*/
273+
sycl::event compute_forward(const Scalar* /*in*/, complex_type* /*out*/,
274+
const std::vector<sycl::event>& /*dependencies*/ = {}) {
275+
PORTFFT_LOG_FUNCTION_ENTRY();
276+
throw unsupported_configuration("Real to complex FFTs not yet implemented.");
277+
return {};
278+
}
279+
280+
/**
281+
* Computes out-of-place backward FFT, working on USM memory.
282+
*
283+
* @param in USM pointer to memory containing input data
284+
* @param out USM pointer to memory containing output data
285+
* @param dependencies events that must complete before the computation
286+
* @return sycl::event associated with this computation
287+
*/
288+
sycl::event compute_backward(const complex_type* in, complex_type* out,
289+
const std::vector<sycl::event>& dependencies = {}) {
290+
PORTFFT_LOG_FUNCTION_ENTRY();
291+
return dispatch_direction(in, out, in, out, complex_storage::INTERLEAVED_COMPLEX, direction::BACKWARD,
292+
dependencies);
293+
}
294+
295+
/**
296+
* Computes out-of-place backward FFT, working on USM memory.
297+
*
298+
* @param in_real USM pointer to memory containing real part of the input data
299+
* @param in_imag USM pointer to memory containing imaginary part of the input data
300+
* @param out_real USM pointer to memory containing real part of the output data
301+
* @param out_imag USM pointer to memory containing imaginary part of the output data
302+
* @param dependencies events that must complete before the computation
303+
* @return sycl::event associated with this computation
304+
*/
305+
sycl::event compute_backward(const scalar_type* in_real, const scalar_type* in_imag, scalar_type* out_real,
306+
scalar_type* out_imag, const std::vector<sycl::event>& dependencies = {}) {
307+
PORTFFT_LOG_FUNCTION_ENTRY();
308+
return dispatch_direction(in_real, out_real, in_imag, out_imag, complex_storage::SPLIT_COMPLEX, direction::BACKWARD,
309+
dependencies);
310+
}
311+
};
312+
313+
} // namespace portfft
314+
315+
#endif

0 commit comments

Comments
 (0)