Skip to content

Commit d7e546b

Browse files
committed
Update FFT integration code
1 parent 09a64d9 commit d7e546b

File tree

1 file changed

+46
-17
lines changed

1 file changed

+46
-17
lines changed

src/ATen/native/xpu/mkl/SpectralOps.cpp

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -134,21 +134,22 @@ void _mkl_dft(
134134
checked_signal_sizes.begin() + 1, checked_signal_sizes.end());
135135

136136
std::shared_ptr<dft_config_t> desc_config(new dft_config_t);
137-
desc_config->set_value(config_param::PLACEMENT, DFTI_NOT_INPLACE);
138-
desc_config->set_value(config_param::NUMBER_OF_TRANSFORMS, batch);
137+
//desc_config->set_value(config_param::PLACEMENT, static_cast<std::underlying_type_t<config_value>>(config_value::NOT_INPLACE));
138+
//desc_config->set_value(config_param::PLACEMENT, static_cast<int64_t>(config_value::NOT_INPLACE));
139+
//desc_config->set_value(config_param::NUMBER_OF_TRANSFORMS, batch);
139140

140141
auto istrides = input.strides();
141142
auto ostrides = output.strides();
142143
int64_t idist = istrides[0];
143144
int64_t odist = ostrides[0];
144145

145-
if (!inverse) {
146-
desc_config->set_value(config_param::FWD_DISTANCE, idist);
147-
desc_config->set_value(config_param::BWD_DISTANCE, odist);
148-
} else {
149-
desc_config->set_value(config_param::FWD_DISTANCE, odist);
150-
desc_config->set_value(config_param::BWD_DISTANCE, idist);
151-
}
146+
//if (!inverse) {
147+
// desc_config->set_value(config_param::FWD_DISTANCE, idist);
148+
// desc_config->set_value(config_param::BWD_DISTANCE, odist);
149+
//} else {
150+
// desc_config->set_value(config_param::FWD_DISTANCE, odist);
151+
// desc_config->set_value(config_param::BWD_DISTANCE, idist);
152+
//}
152153

153154
std::vector<int64_t> fwd_strides(1 + signal_ndim, 0),
154155
bwd_strides(1 + signal_ndim, 0);
@@ -163,36 +164,64 @@ void _mkl_dft(
163164
}
164165
}
165166

166-
desc_config->set_strides(fwd_strides, bwd_strides);
167+
//desc_config->set_strides(fwd_strides, bwd_strides);
168+
169+
//if (!complex_input || !complex_output) {
170+
// desc_config->set_value(
171+
// config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
172+
//}
173+
174+
auto desc = descriptor<prec, signal_type>(mkl_signal_sizes);
175+
desc.set_value(config_param::PLACEMENT, config_value::NOT_INPLACE);
176+
desc.set_value(config_param::NUMBER_OF_TRANSFORMS, batch);
177+
178+
if (!inverse) {
179+
desc.set_value(config_param::FWD_DISTANCE, idist);
180+
desc.set_value(config_param::BWD_DISTANCE, odist);
181+
} else {
182+
desc.set_value(config_param::FWD_DISTANCE, odist);
183+
desc.set_value(config_param::BWD_DISTANCE, idist);
184+
}
185+
186+
if (!fwd_strides.empty()) {
187+
desc.set_value(config_param::FWD_STRIDES, fwd_strides.data());
188+
}
189+
if (!bwd_strides.empty()) {
190+
desc.set_value(config_param::BWD_STRIDES, bwd_strides.data());
191+
}
167192

168193
if (!complex_input || !complex_output) {
169-
desc_config->set_value(
194+
desc.set_value(
170195
config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
171196
}
172197

173-
auto desc =
174-
dft_desc_t<prec, signal_type>(queue, mkl_signal_sizes, desc_config);
198+
desc.set_value(
199+
oneapi::mkl::dft::config_param::WORKSPACE,
200+
oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
201+
desc.commit(queue);
202+
//auto desc =
203+
// dft_desc_t<prec, signal_type>(queue, mkl_signal_sizes, desc_config);
175204

176205
// Obtain the size of workspace required after commit.
177206
size_t workspaceSizeBytes = 0;
178-
desc.raw().get_value(
207+
desc.get_value(
179208
oneapi::mkl::dft::config_param::WORKSPACE_BYTES, &workspaceSizeBytes);
180209

181210
// Allocate USM workspace and provide it to the descriptor.
182211
Tensor workspaceBuf = at::empty(
183212
{(long)(workspaceSizeBytes / sizeof(double))},
184213
input.options().dtype(at::kDouble),
185214
c10::nullopt);
186-
desc.raw().set_workspace((double*)workspaceBuf.data_ptr());
215+
desc.set_workspace((double*)workspaceBuf.data_ptr());
187216

188217
auto in_data = (scalar_t*)input.data_ptr();
189218
auto out_data = (scalar_t*)output.data_ptr();
190219

191220
sycl::event event;
192221
if (!inverse) {
193-
event = compute_forward(desc.raw(), in_data, out_data);
222+
event = compute_forward(desc, in_data, out_data);
194223
} else {
195-
event = compute_backward(desc.raw(), in_data, out_data);
224+
event = compute_backward(desc, in_data, out_data);
196225
}
197226
event.wait_and_throw();
198227
queue.throw_asynchronous();

0 commit comments

Comments
 (0)