@@ -134,21 +134,22 @@ void _mkl_dft(
134
134
checked_signal_sizes.begin () + 1 , checked_signal_sizes.end ());
135
135
136
136
std::shared_ptr<dft_config_t > desc_config (new dft_config_t );
137
- desc_config->set_value (config_param::PLACEMENT, DFTI_NOT_INPLACE);
138
- desc_config->set_value (config_param::NUMBER_OF_TRANSFORMS, batch);
137
+ // desc_config->set_value(config_param::PLACEMENT, static_cast<std::underlying_type_t<config_value>>(config_value::NOT_INPLACE));
138
+ // desc_config->set_value(config_param::PLACEMENT, static_cast<int64_t>(config_value::NOT_INPLACE));
139
+ // desc_config->set_value(config_param::NUMBER_OF_TRANSFORMS, batch);
139
140
140
141
auto istrides = input.strides ();
141
142
auto ostrides = output.strides ();
142
143
int64_t idist = istrides[0 ];
143
144
int64_t odist = ostrides[0 ];
144
145
145
- if (!inverse) {
146
- desc_config->set_value (config_param::FWD_DISTANCE, idist);
147
- desc_config->set_value (config_param::BWD_DISTANCE, odist);
148
- } else {
149
- desc_config->set_value (config_param::FWD_DISTANCE, odist);
150
- desc_config->set_value (config_param::BWD_DISTANCE, idist);
151
- }
146
+ // if (!inverse) {
147
+ // desc_config->set_value(config_param::FWD_DISTANCE, idist);
148
+ // desc_config->set_value(config_param::BWD_DISTANCE, odist);
149
+ // } else {
150
+ // desc_config->set_value(config_param::FWD_DISTANCE, odist);
151
+ // desc_config->set_value(config_param::BWD_DISTANCE, idist);
152
+ // }
152
153
153
154
std::vector<int64_t > fwd_strides (1 + signal_ndim, 0 ),
154
155
bwd_strides (1 + signal_ndim, 0 );
@@ -163,36 +164,64 @@ void _mkl_dft(
163
164
}
164
165
}
165
166
166
- desc_config->set_strides (fwd_strides, bwd_strides);
167
+ // desc_config->set_strides(fwd_strides, bwd_strides);
168
+
169
+ // if (!complex_input || !complex_output) {
170
+ // desc_config->set_value(
171
+ // config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
172
+ // }
173
+
174
+ auto desc = descriptor<prec, signal_type>(mkl_signal_sizes);
175
+ desc.set_value (config_param::PLACEMENT, config_value::NOT_INPLACE);
176
+ desc.set_value (config_param::NUMBER_OF_TRANSFORMS, batch);
177
+
178
+ if (!inverse) {
179
+ desc.set_value (config_param::FWD_DISTANCE, idist);
180
+ desc.set_value (config_param::BWD_DISTANCE, odist);
181
+ } else {
182
+ desc.set_value (config_param::FWD_DISTANCE, odist);
183
+ desc.set_value (config_param::BWD_DISTANCE, idist);
184
+ }
185
+
186
+ if (!fwd_strides.empty ()) {
187
+ desc.set_value (config_param::FWD_STRIDES, fwd_strides.data ());
188
+ }
189
+ if (!bwd_strides.empty ()) {
190
+ desc.set_value (config_param::BWD_STRIDES, bwd_strides.data ());
191
+ }
167
192
168
193
if (!complex_input || !complex_output) {
169
- desc_config-> set_value (
194
+ desc. set_value (
170
195
config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
171
196
}
172
197
173
- auto desc =
174
- dft_desc_t <prec, signal_type>(queue, mkl_signal_sizes, desc_config);
198
+ desc.set_value (
199
+ oneapi::mkl::dft::config_param::WORKSPACE,
200
+ oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
201
+ desc.commit (queue);
202
+ // auto desc =
203
+ // dft_desc_t<prec, signal_type>(queue, mkl_signal_sizes, desc_config);
175
204
176
205
// Obtain the size of workspace required after commit.
177
206
size_t workspaceSizeBytes = 0 ;
178
- desc.raw (). get_value (
207
+ desc.get_value (
179
208
oneapi::mkl::dft::config_param::WORKSPACE_BYTES, &workspaceSizeBytes);
180
209
181
210
// Allocate USM workspace and provide it to the descriptor.
182
211
Tensor workspaceBuf = at::empty (
183
212
{(long )(workspaceSizeBytes / sizeof (double ))},
184
213
input.options ().dtype (at::kDouble ),
185
214
c10::nullopt);
186
- desc.raw (). set_workspace ((double *)workspaceBuf.data_ptr ());
215
+ desc.set_workspace ((double *)workspaceBuf.data_ptr ());
187
216
188
217
auto in_data = (scalar_t *)input.data_ptr ();
189
218
auto out_data = (scalar_t *)output.data_ptr ();
190
219
191
220
sycl::event event;
192
221
if (!inverse) {
193
- event = compute_forward (desc. raw () , in_data, out_data);
222
+ event = compute_forward (desc, in_data, out_data);
194
223
} else {
195
- event = compute_backward (desc. raw () , in_data, out_data);
224
+ event = compute_backward (desc, in_data, out_data);
196
225
}
197
226
event.wait_and_throw ();
198
227
queue.throw_asynchronous ();
0 commit comments