Skip to content

Commit

Permalink
Init MKL for Pytorch XPU and enable fft_c2c
Browse files Browse the repository at this point in the history
  • Loading branch information
CuiYifeng committed Aug 26, 2024
1 parent 12d7ee7 commit 8f8096a
Show file tree
Hide file tree
Showing 13 changed files with 760 additions and 2 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ set(TORCH_XPU_OPS_ROOT ${PROJECT_SOURCE_DIR})
list(APPEND CMAKE_MODULE_PATH ${TORCH_XPU_OPS_ROOT}/cmake/Modules)

include(${TORCH_XPU_OPS_ROOT}/cmake/SYCL.cmake)
include(${TORCH_XPU_OPS_ROOT}/cmake/ONEMKL.cmake)
include(${TORCH_XPU_OPS_ROOT}/cmake/BuildFlags.cmake)

if(BUILD_TEST)
Expand Down
66 changes: 66 additions & 0 deletions cmake/Modules/FindONEMKL.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
set(ONEMKL_FOUND FALSE)

set(ONEMKL_LIBRARIES)

# In order to be compatible with various situations of Pytorch development
# bundle setup, ENV{MKLROOT} and SYCL_ROOT will be checked sequentially to get
# the root directory of oneMKL.
if(DEFINED ENV{MKLROOT})
# Directly get the root directory of oneMKL if ENV{MKLROOT} exists.
set(ONEMKL_ROOT $ENV{MKLROOT})
elseif(SYCL_FOUND)
# oneMKL configuration may not be imported into the build system. Get the root
# directory of oneMKL based on the root directory of compiler relatively.
get_filename_component(ONEMKL_ROOT "${SYCL_ROOT}/../../mkl/latest" REALPATH)
endif()

if(NOT DEFINED ONEMKL_ROOT)
message(
WARNING
"Cannot find either ENV{MKLROOT} or SYCL_ROOT, please setup oneAPI environment before building!!"
)
return()
endif()

if(NOT EXISTS ${ONEMKL_ROOT})
message(
WARNING
"${ONEMKL_ROOT} not found, please setup oneAPI environment before building!!"
)
return()
endif()

find_file(
ONEMKL_INCLUDE_DIR
NAMES include
HINTS ${ONEMKL_ROOT}
NO_DEFAULT_PATH)

find_file(
ONEMKL_LIB_DIR
NAMES lib
HINTS ${ONEMKL_ROOT}
NO_DEFAULT_PATH)

if((ONEMKL_INCLUDE_DIR STREQUAL "ONEMKL_INCLUDE_DIR-NOTFOUND")
OR (ONEMKL_LIB_DIR STREQUAL "ONEMKL_LIB_DIR-NOTFOUND"))
message(WARNING "oneMKL sdk is incomplete!!")
return()
endif()

if(WIN32)
set(MKL_LIB_NAMES "mkl_intel_lp64" "mkl_intel_thread" "mkl_core" "mkl_sycl")
else()
set(MKL_LIB_NAMES "mkl_intel_lp64" "mkl_gnu_thread" "mkl_core" "mkl_sycl_dft")
endif()

foreach(LIB_NAME IN LISTS MKL_LIB_NAMES)
find_library(
${LIB_NAME}_library
NAMES ${LIB_NAME}
HINTS ${ONEMKL_LIB_DIR}
NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH)
list(APPEND ONEMKL_LIBRARIES ${${LIB_NAME}_library})
endforeach()

set(ONEMKL_FOUND TRUE)
11 changes: 11 additions & 0 deletions cmake/ONEMKL.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
find_package(ONEMKL)
if(NOT ONEMKL_FOUND)
message(FATAL_ERROR "Can NOT find ONEMKL cmake helpers module!")
endif()

set(TORCH_XPU_OPS_ONEMKL_INCLUDE_DIR ${ONEMKL_INCLUDE_DIR})

set(TORCH_XPU_OPS_ONEMKL_LIBRARIES ${ONEMKL_LIBRARIES})

list(INSERT TORCH_XPU_OPS_ONEMKL_LIBRARIES 0 "-Wl,--no-as-needed")
list(APPEND TORCH_XPU_OPS_ONEMKL_LIBRARIES "-Wl,--as-needed")
3 changes: 3 additions & 0 deletions src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# ATen XPU sources

file(GLOB xpu_cpp "xpu/*.cpp" "native/xpu/*.cpp" "native/sparse/*.cpp")
file(GLOB xpu_mkl "native/xpu/mkl/*.cpp")
file(GLOB xpu_sycl "native/xpu/sycl/*.cpp")

list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
list(APPEND ATen_XPU_MKL_SRCS ${xpu_mkl})
list(APPEND ATen_XPU_SYCL_SRCS ${xpu_sycl})

set(ATen_XPU_CPP_SRCS ${ATen_XPU_CPP_SRCS} PARENT_SCOPE)
set(ATen_XPU_MKL_SRCS ${ATen_XPU_MKL_SRCS} PARENT_SCOPE)
set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE)
43 changes: 43 additions & 0 deletions src/ATen/native/xpu/SpectralOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include <ATen/ATen.h>
#include <ATen/native/Resize.h>
#include <ATen/native/xpu/mkl/SpectralOps.h>
#include <ATen/xpu/XPUNativeFunctions.h>

namespace at {

Tensor XPUNativeFunctions::_fft_c2c(
const Tensor& self,
IntArrayRef dim,
int64_t normalization,
bool forward) {
TORCH_CHECK(self.is_complex());

#if defined(__MKL_FALLBACK_TO_CPU)
Tensor out_cpu = native::_fft_c2c_mkl(
self.to(Device(at::kCPU)), dim, normalization, forward);
return out_cpu.to(Device(at::kXPU));
#else
return native::xpu::_fft_c2c_mkl(self, dim, normalization, forward);
#endif // __MKL_FALLBACK_TO_CPU
}

Tensor& XPUNativeFunctions::_fft_c2c_out(
const Tensor& self,
IntArrayRef dim,
int64_t normalization,
bool forward,
Tensor& out) {
TORCH_CHECK(self.is_complex());

#if defined(__MKL_FALLBACK_TO_CPU)
Tensor out_cpu = out.to(Device(at::kCPU));
native::_fft_c2c_mkl_out(
self.to(Device(at::kCPU)), dim, normalization, forward, out_cpu);
out.copy_(out_cpu);
return out;
#else
return native::xpu::_fft_c2c_mkl_out(self, dim, normalization, forward, out);
#endif // __MKL_FALLBACK_TO_CPU
}

} // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"_efficient_attention_forward",
"_embedding_bag_dense_backward",
"_embedding_bag_per_sample_weights_backward",
"_fft_c2c",
"_fft_c2r",
"_fft_r2c",
"_flash_attention_forward",
Expand Down
Loading

0 comments on commit 8f8096a

Please sign in to comment.