Skip to content

Commit

Permalink
Update FFT integration code
Browse files Browse the repository at this point in the history
  • Loading branch information
CuiYifeng committed Dec 20, 2024
1 parent 09a64d9 commit d7e546b
Showing 1 changed file with 46 additions and 17 deletions.
63 changes: 46 additions & 17 deletions src/ATen/native/xpu/mkl/SpectralOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,22 @@ void _mkl_dft(
checked_signal_sizes.begin() + 1, checked_signal_sizes.end());

std::shared_ptr<dft_config_t> desc_config(new dft_config_t);
desc_config->set_value(config_param::PLACEMENT, DFTI_NOT_INPLACE);
desc_config->set_value(config_param::NUMBER_OF_TRANSFORMS, batch);
//desc_config->set_value(config_param::PLACEMENT, static_cast<std::underlying_type_t<config_value>>(config_value::NOT_INPLACE));
//desc_config->set_value(config_param::PLACEMENT, static_cast<int64_t>(config_value::NOT_INPLACE));
//desc_config->set_value(config_param::NUMBER_OF_TRANSFORMS, batch);

auto istrides = input.strides();
auto ostrides = output.strides();
int64_t idist = istrides[0];
int64_t odist = ostrides[0];

if (!inverse) {
desc_config->set_value(config_param::FWD_DISTANCE, idist);
desc_config->set_value(config_param::BWD_DISTANCE, odist);
} else {
desc_config->set_value(config_param::FWD_DISTANCE, odist);
desc_config->set_value(config_param::BWD_DISTANCE, idist);
}
//if (!inverse) {
// desc_config->set_value(config_param::FWD_DISTANCE, idist);
// desc_config->set_value(config_param::BWD_DISTANCE, odist);
//} else {
// desc_config->set_value(config_param::FWD_DISTANCE, odist);
// desc_config->set_value(config_param::BWD_DISTANCE, idist);
//}

std::vector<int64_t> fwd_strides(1 + signal_ndim, 0),
bwd_strides(1 + signal_ndim, 0);
Expand All @@ -163,36 +164,64 @@ void _mkl_dft(
}
}

desc_config->set_strides(fwd_strides, bwd_strides);
//desc_config->set_strides(fwd_strides, bwd_strides);

//if (!complex_input || !complex_output) {
// desc_config->set_value(
// config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
//}

auto desc = descriptor<prec, signal_type>(mkl_signal_sizes);
desc.set_value(config_param::PLACEMENT, config_value::NOT_INPLACE);
desc.set_value(config_param::NUMBER_OF_TRANSFORMS, batch);

if (!inverse) {
desc.set_value(config_param::FWD_DISTANCE, idist);
desc.set_value(config_param::BWD_DISTANCE, odist);
} else {
desc.set_value(config_param::FWD_DISTANCE, odist);
desc.set_value(config_param::BWD_DISTANCE, idist);
}

if (!fwd_strides.empty()) {
desc.set_value(config_param::FWD_STRIDES, fwd_strides.data());
}
if (!bwd_strides.empty()) {
desc.set_value(config_param::BWD_STRIDES, bwd_strides.data());
}

if (!complex_input || !complex_output) {
desc_config->set_value(
desc.set_value(
config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
}

auto desc =
dft_desc_t<prec, signal_type>(queue, mkl_signal_sizes, desc_config);
desc.set_value(
oneapi::mkl::dft::config_param::WORKSPACE,
oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
desc.commit(queue);
//auto desc =
// dft_desc_t<prec, signal_type>(queue, mkl_signal_sizes, desc_config);

// Obtain the size of workspace required after commit.
size_t workspaceSizeBytes = 0;
desc.raw().get_value(
desc.get_value(
oneapi::mkl::dft::config_param::WORKSPACE_BYTES, &workspaceSizeBytes);

// Allocate USM workspace and provide it to the descriptor.
Tensor workspaceBuf = at::empty(
{(long)(workspaceSizeBytes / sizeof(double))},
input.options().dtype(at::kDouble),
c10::nullopt);
desc.raw().set_workspace((double*)workspaceBuf.data_ptr());
desc.set_workspace((double*)workspaceBuf.data_ptr());

auto in_data = (scalar_t*)input.data_ptr();
auto out_data = (scalar_t*)output.data_ptr();

sycl::event event;
if (!inverse) {
event = compute_forward(desc.raw(), in_data, out_data);
event = compute_forward(desc, in_data, out_data);
} else {
event = compute_backward(desc.raw(), in_data, out_data);
event = compute_backward(desc, in_data, out_data);
}
event.wait_and_throw();
queue.throw_asynchronous();
Expand Down

0 comments on commit d7e546b

Please sign in to comment.