Skip to content

Commit 0bc9740

Browse files
Implement pairwise distance on 2d grid
1 parent 94c2d62 commit 0bc9740

File tree

4 files changed

+34 -31 lines changed

4 files changed

+34 -31 lines changed

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_dpex_k.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,8 @@
88

99
@dpex.kernel
1010
def _pairwise_distance_kernel(X1, X2, D):
11-
i = dpex.get_global_id(0)
12-
j = dpex.get_global_id(1)
11+
i = dpex.get_global_id(1)
12+
j = dpex.get_global_id(0)
1313

1414
X1_cols = X1.shape[1]
1515

@@ -21,4 +21,4 @@ def _pairwise_distance_kernel(X1, X2, D):
2121

2222

2323
def pairwise_distance(X1, X2, D):
24-
_pairwise_distance_kernel[dpex.Range(X1.shape[0], X2.shape[0])](X1, X2, D)
24+
_pairwise_distance_kernel[dpex.Range(X2.shape[0], X1.shape[0])](X1, X2, D)

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_k.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
@nb.kernel(gpu_fp64_truncate="auto")
1010
def _pairwise_distance_kernel(X1, X2, D):
11-
i = nb.get_global_id(0)
12-
j = nb.get_global_id(1)
11+
i = nb.get_global_id(1)
12+
j = nb.get_global_id(0)
1313

1414
X1_cols = X1.shape[1]
1515

@@ -22,5 +22,5 @@ def _pairwise_distance_kernel(X1, X2, D):
2222

2323
def pairwise_distance(X1, X2, D):
2424
_pairwise_distance_kernel[
25-
(X1.shape[0], X2.shape[0]), nb.DEFAULT_LOCAL_SIZE
25+
(X2.shape[0], X1.shape[0]), nb.DEFAULT_LOCAL_SIZE
2626
](X1, X2, D)

dpbench/benchmarks/pairwise_distance/pairwise_distance_sycl_native_ext/pairwise_distance_sycl/_pairwise_distance_kernel.hpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,28 @@
66

77
using namespace sycl;
88

9-
#ifdef __DO_FLOAT__
10-
#define SQRT(x) sqrtf(x)
11-
#else
12-
#define SQRT(x) sqrt(x)
13-
#endif
9+
template <typename T> class PairwiseDistanceKernel;
1410

1511
template <typename FpTy>
1612
void pairwise_distance_impl(queue Queue,
17-
size_t npoints,
13+
size_t x1_npoints,
14+
size_t x2_npoints,
1815
size_t ndims,
1916
const FpTy *p1,
2017
const FpTy *p2,
2118
FpTy *distance_op)
2219
{
2320
Queue.submit([&](handler &h) {
24-
h.parallel_for<class PairwiseDistanceKernel>(
25-
range<1>{npoints}, [=](id<1> myID) {
26-
size_t i = myID[0];
27-
for (size_t j = 0; j < npoints; j++) {
28-
FpTy d = 0.;
29-
for (size_t k = 0; k < ndims; k++) {
30-
auto tmp = p1[i * ndims + k] - p2[j * ndims + k];
31-
d += tmp * tmp;
32-
}
33-
if (d != 0.0) {
34-
distance_op[i * npoints + j] = sqrt(d);
35-
}
21+
h.parallel_for<PairwiseDistanceKernel<FpTy>>(
22+
range<2>{x1_npoints, x2_npoints}, [=](id<2> myID) {
23+
auto i = myID[0];
24+
auto j = myID[1];
25+
FpTy d = 0.;
26+
for (size_t k = 0; k < ndims; k++) {
27+
auto tmp = p1[i * ndims + k] - p2[j * ndims + k];
28+
d += tmp * tmp;
3629
}
30+
distance_op[i * x2_npoints + j] = sycl::sqrt(d);
3731
});
3832
});
3933

dpbench/benchmarks/pairwise_distance/pairwise_distance_sycl_native_ext/pairwise_distance_sycl/_pairwise_distance_sycl.cpp

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -49,18 +49,27 @@ void pairwise_distance_sync(dpctl::tensor::usm_ndarray X1,
4949
{
5050
sycl::event res_ev;
5151
auto Queue = X1.get_queue();
52-
auto ndims = 3;
53-
auto npoints = X1.get_size() / ndims;
52+
auto ndims = X1.get_shape(1);
53+
auto x1_npoints = X1.get_shape(0);
54+
auto x2_npoints = X2.get_shape(0);
5455

5556
if (!ensure_compatibility(X1, X2, D))
5657
throw std::runtime_error("Input arrays are not acceptable.");
5758

58-
if (X1.get_typenum() != UAR_DOUBLE || X2.get_typenum() != UAR_DOUBLE) {
59-
throw std::runtime_error("Expected a double precision FP array.");
59+
if (X1.get_typenum() == UAR_FLOAT) {
60+
pairwise_distance_impl(Queue, x1_npoints, x2_npoints, ndims,
61+
X1.get_data<float>(), X2.get_data<float>(),
62+
D.get_data<float>());
63+
}
64+
else if (X1.get_typenum() == UAR_DOUBLE) {
65+
pairwise_distance_impl(Queue, x1_npoints, x2_npoints, ndims,
66+
X1.get_data<double>(), X2.get_data<double>(),
67+
D.get_data<double>());
68+
}
69+
else {
70+
throw std::runtime_error(
71+
"Expected a double or single precision FP array.");
6072
}
61-
62-
pairwise_distance_impl(Queue, npoints, ndims, X1.get_data<double>(),
63-
X2.get_data<double>(), D.get_data<double>());
6473
}
6574

6675
PYBIND11_MODULE(_pairwise_distance_sycl, m)

0 commit comments

Comments
 (0)