Commit de5e34f

Merge branch 'mengfeil/weekly' of https://github.com/intel/torch-xpu-ops into mengfeil/weekly

mengfei25 committed Jul 29, 2024
2 parents aa09785 + d77a710 commit de5e34f
Showing 26 changed files with 614 additions and 159 deletions.
@@ -93,7 +93,7 @@ tacotron2,fail_to_run,fail_to_run,fail_to_run,fail_to_run,fail_to_run
timm_efficientdet,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load
timm_efficientnet,pass,pass,pass,pass,pass
timm_nfnet,pass,pass,pass,pass,pass
timm_regnet,pass,pass,pass,pass,pass
timm_regnet,pass,fail_accuracy,pass,pass,pass
timm_resnest,pass,pass,pass,pass,pass
timm_vision_transformer,pass,pass,pass,pass,pass
timm_vision_transformer_large,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip
6 changes: 3 additions & 3 deletions cmake/BuildFlags.cmake
@@ -59,10 +59,10 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "MSVC"
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-approx-func)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Wno-absolute-value)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -no-ftz)
# Equivalent to build option -fpreview-breaking-changes for SYCL compiler.
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -D__INTEL_PREVIEW_BREAKING_CHANGES)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -D_GLIBCXX_USE_CXX11_ABI=${GLIBCXX_USE_CXX11_ABI})
endif()
# TODO: Align with PyTorch and switch to ABI=0 eventually, after
# resolving incompatible implementation in SYCL runtime.
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -D_GLIBCXX_USE_CXX11_ABI=1)
set(SYCL_FLAGS ${SYCL_FLAGS} ${SYCL_KERNEL_OPTIONS})

set(TORCH_XPU_OPS_FLAGS ${SYCL_HOST_FLAGS})
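Note on the change above: the kernel options now pin -D_GLIBCXX_USE_CXX11_ABI=1 instead of following GLIBCXX_USE_CXX11_ABI, with the TODO to align with PyTorch's ABI=0 once the SYCL runtime incompatibility is resolved. A hypothetical CMake sketch, not part of this commit, of how the value could instead be queried from the installed PyTorch at configure time (Python_EXECUTABLE and TORCH_CXX11_ABI are illustrative names):

```cmake
# Hypothetical sketch: derive the ABI macro from the installed PyTorch instead
# of hardcoding it. Assumes a Python interpreter has already been located
# (Python_EXECUTABLE, e.g. via find_package(Python COMPONENTS Interpreter)).
execute_process(
  COMMAND "${Python_EXECUTABLE}" -c
          "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
  OUTPUT_VARIABLE TORCH_CXX11_ABI
  OUTPUT_STRIP_TRAILING_WHITESPACE)
set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI})
```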
5 changes: 2 additions & 3 deletions cmake/Modules/FindSYCL.cmake
@@ -407,7 +407,6 @@ macro(SYCL_LINK_DEVICE_OBJECTS output_file sycl_target)
OUTPUT ${output_file}
DEPENDS ${object_files}
COMMAND ${SYCL_EXECUTABLE}
-fsycl
${SYCL_device_link_flags}
-fsycl-link ${object_files}
-Xs "\"${SYCL_OFFLINE_COMPILER_FLAGS}\""
@@ -471,7 +470,7 @@ macro(SYCL_ADD_LIBRARY sycl_target)
target_link_libraries(
${sycl_target}
${SYCL_LINK_LIBRARIES_KEYWORD}
${SYCL_LIBRARIES})
${SYCL_LIBRARY})

set_target_properties(${sycl_target}
PROPERTIES
@@ -530,7 +529,7 @@ macro(SYCL_ADD_EXECUTABLE sycl_target)
target_link_libraries(
${sycl_target}
${SYCL_LINK_LIBRARIES_KEYWORD}
${SYCL_LIBRARIES})
${SYCL_LIBRARY})

set_target_properties(${sycl_target}
PROPERTIES
73 changes: 18 additions & 55 deletions cmake/Modules/FindSYCLToolkit.cmake
@@ -25,16 +25,20 @@ This will define the following variables:
#]=======================================================================]

set(SYCLTOOLKIT_FOUND False)
include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake)
include(${TORCH_ROOT}/cmake/Modules/FindSYCLToolkit.cmake)

set(SYCL_ROOT "")
if(DEFINED ENV{SYCL_ROOT})
set(SYCL_ROOT $ENV{SYCL_ROOT})
elseif(DEFINED ENV{CMPLR_ROOT})
set(SYCL_ROOT $ENV{CMPLR_ROOT})
if(NOT SYCL_FOUND)
set(SYCLTOOLKIT_FOUND FALSE)
return()
endif()

if(SYCLTOOLKIT_FOUND)
return()
endif()
set(SYCLTOOLKIT_FOUND TRUE)

include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake)

if(WIN32)
set(SYCL_EXECUTABLE_NAME icx)
else()
@@ -71,43 +75,6 @@ if(nocmplr)
set(SYCL_NOT_FOUND_MESSAGE "${SYCL_REASON_FAILURE}")
endif()

find_file(
SYCL_INCLUDE_DIR
NAMES include
HINTS ${SYCL_ROOT}
NO_DEFAULT_PATH
)

find_file(
SYCL_INCLUDE_SYCL_DIR
NAMES sycl
HINTS ${SYCL_ROOT}/include
NO_DEFAULT_PATH
)

list(APPEND SYCL_INCLUDE_DIR ${SYCL_INCLUDE_SYCL_DIR})

find_file(
SYCL_LIBRARY_DIR
NAMES lib lib64
HINTS ${SYCL_ROOT}
NO_DEFAULT_PATH
)

find_library(
SYCL_LIBRARY
NAMES sycl
HINTS ${SYCL_LIBRARY_DIR}
NO_DEFAULT_PATH
)

if((NOT SYCL_INCLUDE_DIR) OR (NOT SYCL_LIBRARY_DIR) OR (NOT SYCL_LIBRARY))
set(SYCLTOOLKIT_FOUND False)
set(SYCL_REASON_FAILURE "SYCL sdk is incomplete!!")
set(SYCL_NOT_FOUND_MESSAGE "${SYCL_REASON_FAILURE}")
return()
endif()

# Function to write a test case to verify SYCL features.

function(SYCL_CMPLR_TEST_WRITE src)
@@ -202,6 +169,13 @@ set(SYCL_FLAGS "")
set(SYCL_LINK_FLAGS "")
list(APPEND SYCL_FLAGS "-fsycl")
list(APPEND SYCL_LINK_FLAGS "-fsycl")
if(LINUX)
string(REGEX MATCH "libsycl-preview.so" is_abi_neutral ${SYCL_LIBRARY})
if(is_abi_neutral)
list(APPEND SYCL_FLAGS "-fpreview-breaking-changes")
list(APPEND SYCL_LINK_FLAGS "-fpreview-breaking-changes")
endif()
endif()

set(SYCL_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SYCL_FLAGS}")

@@ -249,14 +223,3 @@ endif()

message(DEBUG "The SYCL compiler is ${SYCL_COMPILER}")
message(DEBUG "The SYCL Flags are ${SYCL_FLAGS}")

# Avoid module variables conflict due to calling find_package recursively
# e.g. find_package -> add_subdirectory(entering in a sub-project) -> find_package
# find_package_handle_standard_args(
# SYCLToolkit
# FOUND_VAR SYCLTOOLKIT_FOUND
# REQUIRED_VARS SYCL_INCLUDE_DIR SYCL_LIBRARY_DIR SYCL_LIBRARY SYCL_FLAGS
# VERSION_VAR SYCL_LANGUAGE_VERSION
# REASON_FAILURE_MESSAGE "${SYCL_REASON_FAILURE}")
set(SYCLTOOLKIT_FOUND True)
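For context, the rewritten module now delegates toolkit discovery to PyTorch's own FindSYCLToolkit.cmake and only adds -fpreview-breaking-changes when the ABI-neutral libsycl-preview.so is detected. A hedged sketch, not from this commit, of how a downstream CMake target might consume the variables this module populates (my_sycl_lib and the module-path setup are illustrative placeholders):

```cmake
# Hedged sketch: consume the variables set by FindSYCLToolkit.cmake from a
# downstream target. Assumes the module directory is on CMAKE_MODULE_PATH and
# a target named my_sycl_lib already exists.
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/Modules")
find_package(SYCLToolkit)
if(SYCLTOOLKIT_FOUND)
  target_compile_options(my_sycl_lib PRIVATE ${SYCL_FLAGS})
  target_link_options(my_sycl_lib PRIVATE ${SYCL_LINK_FLAGS})
  target_link_libraries(my_sycl_lib PRIVATE ${SYCL_LIBRARY})
endif()
```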

18 changes: 0 additions & 18 deletions cmake/SYCL.cmake
@@ -25,24 +25,6 @@ if(NOT SYCL_VERSION)
return()
endif()

find_library(SYCL_LIBRARIES sycl HINTS ${SYCL_LIBRARY_DIR})
# On Windows, currently there's no sycl.lib. Only sycl7.lib with version suffix,
# where the current version of the SYCL runtime is 7.
# Until oneAPI adds support to sycl.lib without the version suffix,
# sycl_runtime_version needs to be hardcoded and uplifted when SYCL runtime version uplifts.
# TODO: remove this when sycl.lib is supported on Windows
if(WIN32)
set(sycl_runtime_version 7)
find_library(
SYCL_LIBRARIES
NAMES "sycl${sycl_runtime_version}"
HINTS ${SYCL_LIBRARY_DIR}
)
if(SYCL_LIBRARIES STREQUAL "SYCL_LIBRARIES-NOTFOUND")
message(FATAL_ERROR "Cannot find a SYCL library on Windows")
endif()
endif()

set(SYCL_COMPILER_VERSION)
file(READ ${SYCL_VERSION} version_contents)
string(REGEX MATCHALL "__SYCL_COMPILER_VERSION +[0-9]+" VERSION_LINE "${version_contents}")
2 changes: 1 addition & 1 deletion src/ATen/CMakeLists.txt
@@ -1,6 +1,6 @@
# ATen XPU sources

file(GLOB xpu_cpp "xpu/*.cpp", "native/xpu/*.cpp", "native/sparse/*.cpp")
file(GLOB xpu_cpp "xpu/*.cpp" "native/xpu/*.cpp" "native/sparse/*.cpp")
file(GLOB xpu_sycl "native/xpu/sycl/*.cpp")

list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
195 changes: 195 additions & 0 deletions src/ATen/native/xpu/Histogram.cpp
@@ -0,0 +1,195 @@
#include <ATen/Context.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/Resize.h>
#include <ATen/native/xpu/sycl/HistogramKernels.h>
#include <ATen/xpu/XPUNativeFunctions.h>

namespace at {

/* Checks properties of input tensors input, bins, and weight.
*/
void histogramdd_check_inputs(
const Tensor& input,
const Tensor& bins,
const std::optional<Tensor>& weight) {
if (weight.has_value()) {
TORCH_CHECK(
weight->device() == input.device(),
"weight and input need to be on the same device.")
}
auto input_dtype = input.dtype();
auto bins_dtype = bins.dtype();
TORCH_CHECK(
input_dtype == bins_dtype,
"torch.histogramdd: input tensor and bins tensors should",
" have the same dtype, but got input with dtype ",
input_dtype,
" and bins with dtype ",
bins_dtype);

const int64_t bins_dim = bins.dim();
TORCH_CHECK(
bins_dim == 1,
"torch.histogramdd: bins tensor should have one dimension,",
" but got ",
bins_dim,
" dimensions in the bin tensor");

const int64_t numel = bins.numel();
TORCH_CHECK(
numel > 0,
"torch.histogramdd: bins tensor should have at least 1 element,",
" but got ",
numel,
" elements in the bin tensor");

if (weight.has_value()) {
TORCH_CHECK(
input.dtype() == weight.value().dtype(),
"torch.histogramdd: if weight tensor is provided, ",
"input tensor and weight tensor should have the same dtype, ",
"but got input(",
input.dtype(),
")",
", and weight(",
weight.value().dtype(),
")");

auto input_sizes = input.sizes().vec();

auto weight_sizes = weight.value().sizes().vec();
if (weight_sizes.empty()) {
// correctly handle scalars
weight_sizes = {1};
}

TORCH_CHECK(
input_sizes == weight_sizes,
"torch.histogramdd: if weight tensor is provided it should have",
" the same shape as the input tensor excluding its innermost ",
"dimension, but got input with shape ",
input.sizes(),
" and weight ",
"with shape ",
weight.value().sizes());
}
}

/* Checks properties of output tensors hist and bin_edges, then resizes them.
*/
void histogramdd_prepare_out(
const Tensor& input,
int64_t bin_ct,
const Tensor& hist,
const Tensor& bin_edges) {
TORCH_CHECK(
input.dtype() == hist.dtype(),
"torch.histogram: input tensor and hist tensor should",
" have the same dtype, but got input ",
input.dtype(),
" and hist ",
hist.dtype());

TORCH_CHECK(
input.dtype() == bin_edges.dtype(),
"torch.histogram: input tensor and bin_edges tensor should",
" have the same dtype, but got input ",
input.dtype(),
" and bin_edges ",
bin_edges.dtype());

TORCH_CHECK(
bin_ct > 0, "torch.histogram(): bins must be > 0, but got ", bin_ct);

at::native::resize_output(bin_edges, {bin_ct + 1});

at::native::resize_output(hist, {bin_ct});
}

void histogramdd_prepare_out(
const Tensor& input,
const Tensor& bins,
const Tensor& hist,
const Tensor& bin_edges) {
int64_t bin_ct = bins.numel() - 1;
histogramdd_prepare_out(input, bin_ct, hist, bin_edges);
}

static Tensor& histogramdd_out(
const Tensor& self,
const Tensor& bins,
const std::optional<Tensor>& weight,
bool density,
Tensor& hist,
Tensor& bin_edges) {
histogramdd_check_inputs(self, bins, weight);
histogramdd_prepare_out(self, bins, hist, bin_edges);

bin_edges.copy_(bins);

at::native::xpu::histogramdd_kernel(self, weight, density, hist, bin_edges);
return hist;
}

std::tuple<Tensor&, Tensor&> XPUNativeFunctions::histogram_out(
const Tensor& self,
const Tensor& bins,
const std::optional<Tensor>& weight,
bool density,
Tensor& hist,
Tensor& bin_edges) {
Tensor reshaped_self = self.reshape({self.numel()});
std::optional<Tensor> reshaped_weight = weight.has_value()
? weight.value().reshape({weight.value().numel()})
: weight;

histogramdd_out(
reshaped_self, bins, reshaped_weight, density, hist, bin_edges);

return std::forward_as_tuple(hist, bin_edges);
}

std::tuple<Tensor, Tensor> XPUNativeFunctions::histogram(
const Tensor& self,
const Tensor& bins,
const std::optional<Tensor>& weight,
bool density) {
Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
Tensor bin_edges = at::empty({0}, bins.options(), MemoryFormat::Contiguous);
return histogram_out(self, bins, weight, density, hist, bin_edges);
}

std::tuple<Tensor&, Tensor&> XPUNativeFunctions::histogram_out(
const Tensor& self,
int64_t bin_ct,
std::optional<c10::ArrayRef<double>> range,
const std::optional<Tensor>& weight,
bool density,
Tensor& hist,
Tensor& bin_edges) {
Tensor reshaped_self = self.reshape({self.numel()});
std::optional<Tensor> reshaped_weight = weight.has_value()
? weight.value().reshape({weight.value().numel()})
: weight;

histogramdd_prepare_out(reshaped_self, bin_ct, hist, bin_edges);
histogramdd_check_inputs(reshaped_self, bin_edges, reshaped_weight);

at::native::xpu::histogramdd_linear_kernel(
reshaped_self, bin_ct, range, reshaped_weight, density, hist, bin_edges);
return std::forward_as_tuple(hist, bin_edges);
}

std::tuple<Tensor, Tensor> XPUNativeFunctions::histogram(
const Tensor& self,
int64_t bin_ct,
std::optional<c10::ArrayRef<double>> range,
const std::optional<Tensor>& weight,
bool density) {
Tensor hist = at::empty({0}, self.options(), MemoryFormat::Contiguous);
Tensor bin_edges_out = at::empty({0}, self.options());
return histogram_out(
self, bin_ct, range, weight, density, hist, bin_edges_out);
}

} // namespace at
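With this file in place, histogram calls on XPU tensors dispatch to the new SYCL kernels. A minimal libtorch usage sketch, assuming a build with XPU support and an available XPU device (purely illustrative, not from the repository):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Assumes a libtorch build with XPU support and at least one XPU device.
  auto x = torch::randn({1000}, torch::device(torch::kXPU));

  // Uniform-width bins: 10 bins spanning the observed min/max of x
  // (exercises the bin_ct overload added above).
  auto [hist, edges] = torch::histogram(x, /*bins=*/10);

  // Explicit bin edges: exercises the Tensor-bins overload added above.
  auto bins = torch::linspace(-3.0, 3.0, 11, torch::device(torch::kXPU));
  auto [hist2, edges2] = torch::histogram(x, bins);

  std::cout << hist << "\n" << hist2 << std::endl;
  return 0;
}
```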