From 0bc97409458fd22021c01fcdfef114c80a5988a9 Mon Sep 17 00:00:00 2001 From: Alexander Kalistratov Date: Sat, 21 Oct 2023 02:29:06 +0200 Subject: [PATCH] Implement pairwise distance on 2d grid --- .../pairwise_distance_numba_dpex_k.py | 6 ++-- .../pairwise_distance_numba_mlir_k.py | 6 ++-- .../_pairwise_distance_kernel.hpp | 30 ++++++++----------- .../_pairwise_distance_sycl.cpp | 23 +++++++++----- 4 files changed, 34 insertions(+), 31 deletions(-) diff --git a/dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_dpex_k.py b/dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_dpex_k.py index 057b8298..e0fa9982 100644 --- a/dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_dpex_k.py +++ b/dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_dpex_k.py @@ -8,8 +8,8 @@ @dpex.kernel def _pairwise_distance_kernel(X1, X2, D): - i = dpex.get_global_id(0) - j = dpex.get_global_id(1) + i = dpex.get_global_id(1) + j = dpex.get_global_id(0) X1_cols = X1.shape[1] @@ -21,4 +21,4 @@ def _pairwise_distance_kernel(X1, X2, D): def pairwise_distance(X1, X2, D): - _pairwise_distance_kernel[dpex.Range(X1.shape[0], X2.shape[0])](X1, X2, D) + _pairwise_distance_kernel[dpex.Range(X2.shape[0], X1.shape[0])](X1, X2, D) diff --git a/dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_k.py b/dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_k.py index d2e68272..02477177 100644 --- a/dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_k.py +++ b/dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_k.py @@ -8,8 +8,8 @@ @nb.kernel(gpu_fp64_truncate="auto") def _pairwise_distance_kernel(X1, X2, D): - i = nb.get_global_id(0) - j = nb.get_global_id(1) + i = nb.get_global_id(1) + j = nb.get_global_id(0) X1_cols = X1.shape[1] @@ -22,5 +22,5 @@ def _pairwise_distance_kernel(X1, X2, D): def pairwise_distance(X1, X2, D): _pairwise_distance_kernel[ - (X1.shape[0], X2.shape[0]), nb.DEFAULT_LOCAL_SIZE + (X2.shape[0], X1.shape[0]), nb.DEFAULT_LOCAL_SIZE ](X1, X2, D) diff --git a/dpbench/benchmarks/pairwise_distance/pairwise_distance_sycl_native_ext/pairwise_distance_sycl/_pairwise_distance_kernel.hpp b/dpbench/benchmarks/pairwise_distance/pairwise_distance_sycl_native_ext/pairwise_distance_sycl/_pairwise_distance_kernel.hpp index 6e15959c..6ecc3b12 100644 --- a/dpbench/benchmarks/pairwise_distance/pairwise_distance_sycl_native_ext/pairwise_distance_sycl/_pairwise_distance_kernel.hpp +++ b/dpbench/benchmarks/pairwise_distance/pairwise_distance_sycl_native_ext/pairwise_distance_sycl/_pairwise_distance_kernel.hpp @@ -6,34 +6,28 @@ using namespace sycl; -#ifdef __DO_FLOAT__ -#define SQRT(x) sqrtf(x) -#else -#define SQRT(x) sqrt(x) -#endif +template class PairwiseDistanceKernel; template void pairwise_distance_impl(queue Queue, - size_t npoints, + size_t x1_npoints, + size_t x2_npoints, size_t ndims, const FpTy *p1, const FpTy *p2, FpTy *distance_op) { Queue.submit([&](handler &h) { - h.parallel_for( - range<1>{npoints}, [=](id<1> myID) { - size_t i = myID[0]; - for (size_t j = 0; j < npoints; j++) { - FpTy d = 0.; - for (size_t k = 0; k < ndims; k++) { - auto tmp = p1[i * ndims + k] - p2[j * ndims + k]; - d += tmp * tmp; - } - if (d != 0.0) { - distance_op[i * npoints + j] = sqrt(d); - } + h.parallel_for>( + range<2>{x1_npoints, x2_npoints}, [=](id<2> myID) { + auto i = myID[0]; + auto j = myID[1]; + FpTy d = 0.; + for (size_t k = 0; k < ndims; k++) { + auto tmp = p1[i * ndims + k] - p2[j * ndims + k]; + d += tmp * tmp; } + distance_op[i * x2_npoints + j] = sycl::sqrt(d); }); }); diff --git a/dpbench/benchmarks/pairwise_distance/pairwise_distance_sycl_native_ext/pairwise_distance_sycl/_pairwise_distance_sycl.cpp b/dpbench/benchmarks/pairwise_distance/pairwise_distance_sycl_native_ext/pairwise_distance_sycl/_pairwise_distance_sycl.cpp index 3a5dc725..2d89606a 100644 --- a/dpbench/benchmarks/pairwise_distance/pairwise_distance_sycl_native_ext/pairwise_distance_sycl/_pairwise_distance_sycl.cpp +++ b/dpbench/benchmarks/pairwise_distance/pairwise_distance_sycl_native_ext/pairwise_distance_sycl/_pairwise_distance_sycl.cpp @@ -49,18 +49,27 @@ void pairwise_distance_sync(dpctl::tensor::usm_ndarray X1, { sycl::event res_ev; auto Queue = X1.get_queue(); - auto ndims = 3; - auto npoints = X1.get_size() / ndims; + auto ndims = X1.get_shape(1); + auto x1_npoints = X1.get_shape(0); + auto x2_npoints = X2.get_shape(0); if (!ensure_compatibility(X1, X2, D)) throw std::runtime_error("Input arrays are not acceptable."); - if (X1.get_typenum() != UAR_DOUBLE || X2.get_typenum() != UAR_DOUBLE) { - throw std::runtime_error("Expected a double precision FP array."); + if (X1.get_typenum() == UAR_FLOAT) { + pairwise_distance_impl(Queue, x1_npoints, x2_npoints, ndims, + X1.get_data(), X2.get_data(), + D.get_data()); + } + else if (X1.get_typenum() == UAR_DOUBLE) { + pairwise_distance_impl(Queue, x1_npoints, x2_npoints, ndims, + X1.get_data(), X2.get_data(), + D.get_data()); + } + else { + throw std::runtime_error( + "Expected a double or single precision FP array."); } - - pairwise_distance_impl(Queue, npoints, ndims, X1.get_data(), - X2.get_data(), D.get_data()); } PYBIND11_MODULE(_pairwise_distance_sycl, m)